diff --git a/src/renderer/src/providers/AiProvider/GeminiProvider.ts b/src/renderer/src/providers/AiProvider/GeminiProvider.ts
index 52323efc1d..0a49ebe573 100644
--- a/src/renderer/src/providers/AiProvider/GeminiProvider.ts
+++ b/src/renderer/src/providers/AiProvider/GeminiProvider.ts
@@ -127,7 +127,6 @@ export default class GeminiProvider extends BaseProvider {
    * @returns The message contents
    */
   private async getMessageContents(message: Message): Promise<Content> {
-    console.log('getMessageContents', message)
     const role = message.role === 'user' ? 'user' : 'model'
     const parts: Part[] = [{ text: await this.getMessageContent(message) }]
     // Add any generated images from previous responses
@@ -197,6 +196,50 @@ export default class GeminiProvider extends BaseProvider {
     }
   }
 
+  private async getImageFileContents(message: Message): Promise<Content> {
+    const role = message.role === 'user' ? 'user' : 'model'
+    const content = getMainTextContent(message)
+    const parts: Part[] = [{ text: content }]
+    const imageBlocks = findImageBlocks(message)
+    for (const imageBlock of imageBlocks) {
+      if (
+        imageBlock.metadata?.generateImageResponse?.images &&
+        imageBlock.metadata.generateImageResponse.images.length > 0
+      ) {
+        for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
+          if (imageUrl && imageUrl.startsWith('data:')) {
+            // Extract base64 data and mime type from the data URL
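+            // A data URL has the shape "data:<mimeType>;base64,<payload>", so
+            // matches[1] is the mime type and matches[2] is the base64 payload.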
+            const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
+            if (matches && matches.length === 3) {
+              const mimeType = matches[1]
+              const base64Data = matches[2]
+              parts.push({
+                inlineData: {
+                  data: base64Data,
+                  mimeType: mimeType
+                } as Part['inlineData']
+              })
+            }
+          }
+        }
+      }
+      const file = imageBlock.file
+      if (file) {
+        const base64Data = await window.api.file.base64Image(file.id + file.ext)
+        parts.push({
+          inlineData: {
+            data: base64Data.base64,
+            mimeType: base64Data.mime
+          } as Part['inlineData']
+        })
+      }
+    }
+    return {
+      role,
+      parts: parts
+    }
+  }
+
   /**
    * Get the safety settings
    * @returns The safety settings
@@ -284,6 +327,18 @@ export default class GeminiProvider extends BaseProvider {
   }: CompletionsParams): Promise<void> {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
+    let canGenerateImage = false
+    if (isGenerateImageModel(model)) {
+      if (model.id === 'gemini-2.0-flash-exp') {
+        canGenerateImage = assistant.enableGenerateImage!
+      } else {
+        canGenerateImage = true
+      }
+    }
+    if (canGenerateImage) {
+      await this.generateImageByChat({ messages, assistant, onChunk })
+      return
+    }
     const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
 
     const userMessages = filterUserRoleStartMessages(
@@ -320,21 +375,10 @@ export default class GeminiProvider extends BaseProvider {
       })
     }
 
-    let canGenerateImage = false
-    if (isGenerateImageModel(model)) {
-      if (model.id === 'gemini-2.0-flash-exp') {
-        canGenerateImage = assistant.enableGenerateImage!
-      } else {
-        canGenerateImage = true
-      }
-    }
-
     const generateContentConfig: GenerateContentConfig = {
-      responseModalities: canGenerateImage ? [Modality.TEXT, Modality.IMAGE] : undefined,
-      responseMimeType: canGenerateImage ? 'text/plain' : undefined,
       safetySettings: this.getSafetySettings(),
       // generate image don't need system instruction
-      systemInstruction: isGemmaModel(model) || canGenerateImage ? undefined : systemInstruction,
+      systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
       temperature: assistant?.settings?.temperature,
       topP: assistant?.settings?.topP,
       maxOutputTokens: maxTokens,
@@ -528,11 +572,6 @@ export default class GeminiProvider extends BaseProvider {
           onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
         }
 
-        const generateImage = this.processGeminiImageResponse(chunk, onChunk)
-        if (generateImage?.images?.length) {
-          onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
-        }
-
         if (chunk.candidates?.[0]?.finishReason) {
           if (chunk.text) {
             onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
@@ -950,8 +989,97 @@ export default class GeminiProvider extends BaseProvider {
     return data.embeddings?.[0]?.values?.length || 0
   }
 
-  public generateImageByChat(): Promise<void> {
-    throw new Error('Method not implemented.')
+  public async generateImageByChat({ messages, assistant, onChunk }: CompletionsParams): Promise<void> {
+    const defaultModel = getDefaultModel()
+    const model = assistant.model || defaultModel
+    const { contextCount, maxTokens } = getAssistantSettings(assistant)
+    const userMessages = filterUserRoleStartMessages(
+      filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
+    )
+
+    const userLastMessage = userMessages.pop()
+    const { abortController } = this.createAbortController(userLastMessage?.id, true)
+    const { signal } = abortController
+    const generateContentConfig: GenerateContentConfig = {
+      responseModalities: [Modality.TEXT, Modality.IMAGE],
+      responseMimeType: 'text/plain',
+      safetySettings: this.getSafetySettings(),
+      temperature: assistant?.settings?.temperature,
+      topP: assistant?.settings?.topP,
+      maxOutputTokens: maxTokens,
+      abortSignal: signal,
+      ...this.getCustomParameters(assistant)
+    }
+    const history: Content[] = []
+    try {
+      for (const message of userMessages) {
+        history.push(await this.getImageFileContents(message))
+      }
+
+      let time_first_token_millsec = 0
+      const start_time_millsec = new Date().getTime()
+      onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
+      const chat = this.sdk.chats.create({
+        model: model.id,
+        config: generateContentConfig,
+        history: history
+      })
+      let content = ''
+      const finalUsage: Usage = {
+        prompt_tokens: 0,
+        completion_tokens: 0,
+        total_tokens: 0
+      }
+      const userMessage: Content = await this.getImageFileContents(userLastMessage!)
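+      // Earlier turns were passed to chats.create() as history, so only the
+      // parts of the latest user turn are sent with this streaming request.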
+      const response = await chat.sendMessageStream({
+        message: userMessage.parts!,
+        config: {
+          ...generateContentConfig,
+          abortSignal: signal
+        }
+      })
+      for await (const chunk of response as AsyncGenerator<GenerateContentResponse>) {
+        if (time_first_token_millsec === 0) {
+          time_first_token_millsec = new Date().getTime()
+        }
+
+        if (chunk.text !== undefined) {
+          content += chunk.text
+          onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
+        }
+        const generateImage = this.processGeminiImageResponse(chunk, onChunk)
+        if (generateImage?.images?.length) {
+          onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
+        }
+        if (chunk.candidates?.[0]?.finishReason) {
+          if (chunk.text) {
+            onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
+          }
+          if (chunk.usageMetadata) {
+            finalUsage.prompt_tokens = chunk.usageMetadata.promptTokenCount || 0
+            finalUsage.completion_tokens = chunk.usageMetadata.candidatesTokenCount || 0
+            finalUsage.total_tokens = chunk.usageMetadata.totalTokenCount || 0
+          }
+        }
+      }
+      onChunk({
+        type: ChunkType.BLOCK_COMPLETE,
+        response: {
+          usage: finalUsage,
+          metrics: {
+            completion_tokens: finalUsage.completion_tokens,
+            time_completion_millsec: new Date().getTime() - start_time_millsec,
+            time_first_token_millsec: time_first_token_millsec - start_time_millsec
+          }
+        }
+      })
+    } catch (error) {
+      console.error('[generateImageByChat] error', error)
+      onChunk({
+        type: ChunkType.ERROR,
+        error
+      })
+    }
   }
 
   public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts
index 4e586a9a66..5d40439626 100644
--- a/src/renderer/src/store/thunk/messageThunk.ts
+++ b/src/renderer/src/store/thunk/messageThunk.ts
@@ -613,7 +613,6 @@ const fetchAndProcessAssistantResponseImpl = async (
       const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
       response.usage = usage
     }
-    console.log('response', response)
   }
 
   if (response && response.metrics) {
     if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) {