From e197a6f3e43c33dc4366d24806fd421ca9aebe13 Mon Sep 17 00:00:00 2001
From: SuYao
Date: Wed, 18 Jun 2025 17:40:46 +0800
Subject: [PATCH] fix: initialize messageContents and improve message handling
 in GeminiAPIClient; add new Gemini model to configuration (#7307)

* fix: initialize messageContents and improve message handling in GeminiAPIClient; add new Gemini model to configuration

* refactor: streamline message handling in GeminiAPIClient; enhance message extraction from SDK payload
---
 .../aiCore/clients/gemini/GeminiAPIClient.ts | 57 +++++++++++--------
 src/renderer/src/config/models.ts            |  4 +-
 2 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts b/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts
index 60c55f866f..3255b8cf86 100644
--- a/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts
+++ b/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts
@@ -466,7 +466,7 @@ export class GeminiAPIClient extends BaseApiClient<
       systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools, assistant)
     }
 
-    let messageContents: Content
+    let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
     const history: Content[] = []
     // 3. 处理用户消息
     if (typeof messages === 'string') {
@@ -475,10 +475,12 @@ export class GeminiAPIClient extends BaseApiClient<
         parts: [{ text: messages }]
       }
     } else {
-      const userLastMessage = messages.pop()!
-      messageContents = await this.convertMessageToSdkParam(userLastMessage)
-      for (const message of messages) {
-        history.push(await this.convertMessageToSdkParam(message))
+      const userLastMessage = messages.pop()
+      if (userLastMessage) {
+        messageContents = await this.convertMessageToSdkParam(userLastMessage)
+        for (const message of messages) {
+          history.push(await this.convertMessageToSdkParam(message))
+        }
       }
     }
 
@@ -491,6 +493,10 @@ export class GeminiAPIClient extends BaseApiClient<
     if (isGemmaModel(model) && assistant.prompt) {
       const isFirstMessage = history.length === 0
       if (isFirstMessage && messageContents) {
+        const userMessageText =
+          messageContents.parts && messageContents.parts.length > 0
+            ? (messageContents.parts[0] as Part).text || ''
+            : ''
         const systemMessage = [
           {
             text:
@@ -498,7 +504,7 @@ export class GeminiAPIClient extends BaseApiClient<
               systemInstruction +
               '\n' +
               'user\n' +
-              (messageContents?.parts?.[0] as Part).text +
+              userMessageText +
               ''
           }
         ] as Part[]
@@ -515,13 +521,7 @@ export class GeminiAPIClient extends BaseApiClient<
 
     const newMessageContents =
       isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
-        ? {
-            ...messageContents,
-            parts: [
-              ...(messageContents.parts || []),
-              ...(recursiveSdkMessages[recursiveSdkMessages.length - 1].parts || [])
-            ]
-          }
+        ? recursiveSdkMessages[recursiveSdkMessages.length - 1]
        : messageContents
 
     const generateContentConfig: GenerateContentConfig = {
@@ -555,7 +555,7 @@ export class GeminiAPIClient extends BaseApiClient<
   getResponseChunkTransformer(): ResponseChunkTransformer {
     return () => ({
       async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController) {
-        let toolCalls: FunctionCall[] = []
+        const toolCalls: FunctionCall[] = []
         if (chunk.candidates && chunk.candidates.length > 0) {
           for (const candidate of chunk.candidates) {
             if (candidate.content) {
@@ -583,6 +583,8 @@ export class GeminiAPIClient extends BaseApiClient<
                     ]
                   }
                 })
+              } else if (part.functionCall) {
+                toolCalls.push(part.functionCall)
               }
             })
           }
@@ -597,9 +599,6 @@ export class GeminiAPIClient extends BaseApiClient<
             }
           } as LLMWebSearchCompleteChunk)
         }
-        if (chunk.functionCalls) {
-          toolCalls = toolCalls.concat(chunk.functionCalls)
-        }
         controller.enqueue({
           type: ChunkType.LLM_RESPONSE_COMPLETE,
           response: {
@@ -702,12 +701,11 @@ export class GeminiAPIClient extends BaseApiClient<
         .filter((p) => p !== undefined)
     )
 
-    const userMessage: Content = {
-      role: 'user',
-      parts: parts
+    const lastMessage = currentReqMessages[currentReqMessages.length - 1]
+    if (lastMessage) {
+      lastMessage.parts?.push(...parts)
     }
-
-    return [...currentReqMessages, userMessage]
+    return currentReqMessages
   }
 
   override estimateMessageTokens(message: GeminiSdkMessageParam): number {
@@ -734,7 +732,20 @@ export class GeminiAPIClient extends BaseApiClient<
   }
 
   public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
-    return sdkPayload.history || []
+    const messageParam: GeminiSdkMessageParam = {
+      role: 'user',
+      parts: []
+    }
+    if (Array.isArray(sdkPayload.message)) {
+      sdkPayload.message.forEach((part) => {
+        if (typeof part === 'string') {
+          messageParam.parts?.push({ text: part })
+        } else if (typeof part === 'object') {
+          messageParam.parts?.push(part)
+        }
+      })
+    }
+    return [messageParam, ...(sdkPayload.history || [])]
   }
 
   private async uploadFile(file: FileType): Promise {
diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts
index 2ffdac76c2..173753f8ed 100644
--- a/src/renderer/src/config/models.ts
+++ b/src/renderer/src/config/models.ts
@@ -2313,6 +2313,7 @@ export const GEMINI_SEARCH_MODELS = [
   'gemini-2.0-flash-lite',
   'gemini-2.0-flash-exp',
   'gemini-2.0-flash-001',
+  'gemini-2.5-pro',
   'gemini-2.0-pro-exp-02-05',
   'gemini-2.0-pro-exp',
   'gemini-2.5-pro-exp',
@@ -2322,7 +2323,8 @@ export const GEMINI_SEARCH_MODELS = [
   'gemini-2.5-pro-preview-05-06',
   'gemini-2.5-flash-preview',
   'gemini-2.5-flash-preview-04-17',
-  'gemini-2.5-flash-preview-05-20'
+  'gemini-2.5-flash-preview-05-20',
+  'gemini-2.5-flash-lite-preview-06-17'
 ]
 
 export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
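
Reviewer note (illustration only, not part of the patch): the reworked extractMessagesFromSdkPayload folds the payload's message parts into a synthesized user entry and places it ahead of the existing history. Below is a minimal, self-contained TypeScript sketch of that normalization, assuming sdkPayload.message arrives as an array of strings and/or Part-like objects; the Part/Content types are stubbed locally here, whereas the real code uses the @google/genai SDK types.

// Sketch only: local stand-ins for the SDK's Part/Content types.
type Part = { text?: string }
type Content = { role: string; parts?: Part[] }

// Mirrors the new extraction logic: string entries become text parts, object
// entries are passed through unchanged, and the synthesized user message is
// returned ahead of the existing chat history.
function extractMessages(message: Array<string | Part>, history: Content[] = []): Content[] {
  const messageParam: Content = { role: 'user', parts: [] }
  for (const part of message) {
    if (typeof part === 'string') {
      messageParam.parts?.push({ text: part })
    } else if (typeof part === 'object') {
      messageParam.parts?.push(part)
    }
  }
  return [messageParam, ...history]
}

// Example: yields [{ role: 'user', parts: [{ text: 'hi' }, { text: 'ctx' }] }, { role: 'model', ... }]
console.log(extractMessages(['hi', { text: 'ctx' }], [{ role: 'model', parts: [{ text: 'earlier reply' }] }]))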