Review notes (text before the first "diff --git" line is ignored when the patch is applied):
- Context (" ") and removed ("-") lines are reproduced as received; only added ("+") lines were
  revised, and the hunk line counts / offsets were updated to match.
- The patch was received with its line breaks collapsed; indentation below was reconstructed with
  two-space style and should be confirmed against the target files.
- The "index" lines still carry the blob ids of the patch as received; regenerate them if needed.

diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts
index 730bd76d78..251082de75 100644
--- a/src/renderer/src/config/models.ts
+++ b/src/renderer/src/config/models.ts
@@ -2145,7 +2145,13 @@ export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [
   'stabilityai/stable-diffusion-xl-base-1.0'
 ]
 
-export const GENERATE_IMAGE_MODELS = ['gemini-2.0-flash-exp-image-generation', 'gemini-2.0-flash-exp']
+export const GENERATE_IMAGE_MODELS = [
+  'gemini-2.0-flash-exp-image-generation',
+  'gemini-2.0-flash-exp',
+  'grok-2-image-1212',
+  'gpt-4o-image',
+  'gpt-image-1'
+]
 
 export const GEMINI_SEARCH_MODELS = [
   'gemini-2.0-flash',
diff --git a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts b/src/renderer/src/providers/AiProvider/AnthropicProvider.ts
index 0c936d6951..c7a1473a62 100644
--- a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts
+++ b/src/renderer/src/providers/AiProvider/AnthropicProvider.ts
@@ -507,6 +507,14 @@ export default class AnthropicProvider extends BaseProvider {
     return []
   }
 
+  /**
+   * Chat-driven image generation is not implemented for this provider.
+   * @throws Always rejects with 'Method not implemented.'
+   */
+  public async generateImageByChat(): Promise<void> {
+    throw new Error('Method not implemented.')
+  }
+
   /**
    * Generate suggestions
    * @returns The suggestions
diff --git a/src/renderer/src/providers/AiProvider/BaseProvider.ts b/src/renderer/src/providers/AiProvider/BaseProvider.ts
index 142f5c0661..182dd4b432 100644
--- a/src/renderer/src/providers/AiProvider/BaseProvider.ts
+++ b/src/renderer/src/providers/AiProvider/BaseProvider.ts
@@ -39,6 +39,8 @@ export default abstract class BaseProvider {
   abstract check(model: Model): Promise<{ valid: boolean; error: Error | null }>
   abstract models(): Promise
   abstract generateImage(params: GenerateImageParams): Promise
+  /** Chat-driven image generation; providers in this patch without support throw 'Method not implemented.' */
+  abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
   abstract getEmbeddingDimensions(model: Model): Promise
 
   public getBaseURL(): string {
diff --git a/src/renderer/src/providers/AiProvider/GeminiProvider.ts b/src/renderer/src/providers/AiProvider/GeminiProvider.ts
index 50498f956c..446f2911c0 100644
--- a/src/renderer/src/providers/AiProvider/GeminiProvider.ts
+++ b/src/renderer/src/providers/AiProvider/GeminiProvider.ts
@@ -718,4 +718,13 @@ export default class GeminiProvider extends BaseProvider {
     })
     return data.embeddings?.[0]?.values?.length || 0
   }
+
+  /**
+   * Chat-driven image generation is not implemented for this provider.
+   * Declared async so the failure is a rejected promise rather than a synchronous throw.
+   * @throws Always rejects with 'Method not implemented.'
+   */
+  public async generateImageByChat(): Promise<void> {
+    throw new Error('Method not implemented.')
+  }
 }
diff --git a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts
index 52ebe8bda7..b482d5d9d7 100644
--- a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts
+++ b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts
@@ -307,6 +307,11 @@ export default class OpenAIProvider extends BaseProvider {
    * @returns The completions
    */
   async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise {
+    // Assistants with image generation enabled skip the chat completion flow below.
+    if (assistant.enableGenerateImage) {
+      await this.generateImageByChat({ messages, assistant, onChunk, onFilterMessages })
+      return
+    }
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -893,4 +898,46 @@ export default class OpenAIProvider extends BaseProvider {
     const { token } = await window.api.copilot.getToken(defaultHeaders)
     this.sdk.apiKey = token
   }
+
+  /**
+   * Generate images from the latest user message and hand them to onChunk.
+   *
+   * The prompt is the content of the last message whose role is 'user' (empty string
+   * when there is none); the model is assistant.model or the default model.
+   */
+  public async generateImageByChat({ messages, assistant, onChunk }: CompletionsParams): Promise<void> {
+    const model = assistant.model || getDefaultModel()
+    const lastUserMessage = messages.findLast((m) => m.role === 'user')
+    // NOTE(review): createAbortController is defined outside this patch; confirm that
+    // signalPromise settles after a successful request and that nothing is left registered.
+    const { abortController, signalPromise } = this.createAbortController(lastUserMessage?.id, true)
+
+    const response = await this.sdk.images.generate(
+      {
+        model: model.id,
+        prompt: lastUserMessage?.content || ''
+      },
+      { signal: abortController.signal }
+    )
+
+    // Let a rejection from the abort promise propagate; a catch that only rethrows adds nothing.
+    await signalPromise?.promise
+
+    // url is absent when the API answers with base64 data (the documented behaviour of
+    // gpt-image-1), so fall back to a data URL instead of silently dropping the image.
+    // NOTE(review): assumes PNG output and that consumers of type 'url' accept data: URLs.
+    const images = response.data.flatMap((item) => {
+      if (item.url) return [item.url]
+      if (item.b64_json) return [`data:image/png;base64,${item.b64_json}`]
+      return []
+    })
+
+    return onChunk({
+      text: '',
+      generateImage: {
+        type: 'url',
+        images
+      }
+    })
+  }
 }
diff --git a/src/renderer/src/providers/AiProvider/index.ts b/src/renderer/src/providers/AiProvider/index.ts
index e518c218e5..32350693d4 100644
--- a/src/renderer/src/providers/AiProvider/index.ts
+++ b/src/renderer/src/providers/AiProvider/index.ts
@@ -109,6 +109,19 @@ export default class AiProvider {
     return this.sdk.generateImage(params)
   }
 
+  /**
+   * Delegate chat-driven image generation to the wrapped provider (this.sdk).
+   * Only messages, assistant, onChunk and onFilterMessages are forwarded.
+   */
+  public async generateImageByChat({
+    messages,
+    assistant,
+    onChunk,
+    onFilterMessages
+  }: CompletionsParams): Promise<void> {
+    return this.sdk.generateImageByChat({ messages, assistant, onChunk, onFilterMessages })
+  }
+
   public async getEmbeddingDimensions(model: Model): Promise {
     return this.sdk.getEmbeddingDimensions(model)
   }