From 0930201e5ddc70c44c07dfe6d407e6825650fd4c Mon Sep 17 00:00:00 2001 From: SuYao Date: Wed, 16 Jul 2025 15:04:19 +0800 Subject: [PATCH] Fix/mcp bug (#8189) * feat(models): enhance function calling model detection and update migration logic - Added support for 'gemini-1' in FUNCTION_CALLING_EXCLUDED_MODELS. - Updated isFunctionCallingModel to handle optional model input. - Modified migration logic to change tool use mode for assistants using function calling models. * feat(models): add new models to vision and function calling lists - Added 'kimi-thinking-preview' to visionAllowedModels. - Added 'kimi-k2' to FUNCTION_CALLING_MODELS. - Updated migration logic to ensure compatibility with new model settings. * refactor(TextChunkMiddleware): streamline text accumulation logic and improve response handling - Simplified the logic for accumulating text content and updating the internal state. - Ensured that the final text is consistently used in response callbacks. - Removed redundant code for handling text completion in the ToolUseExtractionMiddleware. - Added mock state for MCP tools in tests to enhance coverage for tool use extraction. * refactor(BaseApiClient): remove unused content extraction utility - Replaced the usage of getContentWithTools with getMainTextContent in the getMessageContent method. - Cleaned up imports by removing the unused getContentWithTools function. --- .../src/aiCore/clients/BaseApiClient.ts | 4 +- .../middleware/core/TextChunkMiddleware.ts | 15 +- .../feat/ToolUseExtractionMiddleware.ts | 1 - src/renderer/src/config/models.ts | 15 +- .../src/services/__tests__/ApiService.test.ts | 277 ++++++++++++++++++ src/renderer/src/store/migrate.ts | 35 ++- 6 files changed, 316 insertions(+), 31 deletions(-) diff --git a/src/renderer/src/aiCore/clients/BaseApiClient.ts b/src/renderer/src/aiCore/clients/BaseApiClient.ts index c23c34351f..3b904685d1 100644 --- a/src/renderer/src/aiCore/clients/BaseApiClient.ts +++ b/src/renderer/src/aiCore/clients/BaseApiClient.ts @@ -38,7 +38,7 @@ import { } from '@renderer/types/sdk' import { isJSON, parseJSON } from '@renderer/utils' import { addAbortController, removeAbortController } from '@renderer/utils/abortController' -import { findFileBlocks, getContentWithTools, getMainTextContent } from '@renderer/utils/messageUtils/find' +import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { defaultTimeout } from '@shared/config/constant' import Logger from 'electron-log/renderer' import { isEmpty } from 'lodash' @@ -210,7 +210,7 @@ export abstract class BaseApiClient< } public async getMessageContent(message: Message): Promise { - const content = getContentWithTools(message) + const content = getMainTextContent(message) if (isEmpty(content)) { return '' diff --git a/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts b/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts index db9fdb253a..cfaf70299f 100644 --- a/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts +++ b/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts @@ -54,21 +54,20 @@ export const TextChunkMiddleware: CompletionsMiddleware = text: accumulatedTextContent // 增量更新 }) } else if (accumulatedTextContent && chunk.type !== ChunkType.TEXT_START) { - if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) { - const finalText = accumulatedTextContent - ctx._internal.customState!.accumulatedText = finalText - if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) { - ctx._internal.toolProcessingState.output = finalText - } + ctx._internal.customState!.accumulatedText = accumulatedTextContent + if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) { + ctx._internal.toolProcessingState.output = accumulatedTextContent + } + if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) { // 处理 onResponse 回调 - 发送最终完整文本 if (params.onResponse) { - params.onResponse(finalText, true) + params.onResponse(accumulatedTextContent, true) } controller.enqueue({ type: ChunkType.TEXT_COMPLETE, - text: finalText + text: accumulatedTextContent }) controller.enqueue(chunk) } else { diff --git a/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts b/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts index 3e606f6683..b53d7348f1 100644 --- a/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts +++ b/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts @@ -79,7 +79,6 @@ function createToolUseExtractionTransform( toolCounter += toolUseResponses.length if (toolUseResponses.length > 0) { - controller.enqueue({ type: ChunkType.TEXT_COMPLETE, text: '' }) // 生成 MCP_TOOL_CREATED chunk const mcpToolCreatedChunk: MCPToolCreatedChunk = { type: ChunkType.MCP_TOOL_CREATED, diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index c162dab6ea..64c38127fc 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -184,7 +184,8 @@ const visionAllowedModels = [ 'deepseek-vl(?:[\\w-]+)?', 'kimi-latest', 'gemma-3(?:-[\\w-]+)', - 'doubao-seed-1[.-]6(?:-[\\w-]+)?' + 'doubao-seed-1[.-]6(?:-[\\w-]+)?', + 'kimi-thinking-preview' ] const visionExcludedModels = [ @@ -239,7 +240,8 @@ export const FUNCTION_CALLING_MODELS = [ 'learnlm(?:-[\\w-]+)?', 'gemini(?:-[\\w-]+)?', // 提前排除了gemini的嵌入模型 'grok-3(?:-[\\w-]+)?', - 'doubao-seed-1[.-]6(?:-[\\w-]+)?' + 'doubao-seed-1[.-]6(?:-[\\w-]+)?', + 'kimi-k2(?:-[\\w-]+)?' ] const FUNCTION_CALLING_EXCLUDED_MODELS = [ @@ -247,7 +249,8 @@ const FUNCTION_CALLING_EXCLUDED_MODELS = [ 'imagen(?:-[\\w-]+)?', 'o1-mini', 'o1-preview', - 'AIDC-AI/Marco-o1' + 'AIDC-AI/Marco-o1', + 'gemini-1(?:\\.[\\w-]+)?' ] export const FUNCTION_CALLING_REGEX = new RegExp( @@ -260,7 +263,11 @@ export const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp( 'i' ) -export function isFunctionCallingModel(model: Model): boolean { +export function isFunctionCallingModel(model?: Model): boolean { + if (!model) { + return false + } + if (model.type?.includes('function_calling')) { return true } diff --git a/src/renderer/src/services/__tests__/ApiService.test.ts b/src/renderer/src/services/__tests__/ApiService.test.ts index b88d0d0775..239e73b2a6 100644 --- a/src/renderer/src/services/__tests__/ApiService.test.ts +++ b/src/renderer/src/services/__tests__/ApiService.test.ts @@ -197,6 +197,17 @@ vi.mock('@renderer/store/llm.ts', () => { } }) +vi.mock('@renderer/store/mcp.ts', () => { + const mockInitialState = { + servers: [{ id: 'mcp-server-1', name: 'mcp-server-1', isActive: true, disabledAutoApproveTools: [] }] + } + return { + default: (state = mockInitialState) => { + return state + } + } +}) + // 测试用例:将 Gemini API 响应数据转换为 geminiChunks 数组 const geminiChunks: GeminiSdkRawChunk[] = [ { @@ -543,6 +554,67 @@ const geminiThinkingChunks: GeminiSdkRawChunk[] = [ } as unknown as GeminiSdkRawChunk ] +const geminiToolUseChunks: GeminiSdkRawChunk[] = [ + { + candidates: [ + { + content: { + parts: [ + { + text: '**Initiating File Retrieval**\n\nI\'ve determined that the `tool_mcp-tool-1` tool is suitable for this task. It seems the user intends to read a file, and this tool aligns with that objective. Currently, I\'m focusing on the necessary parameters. The `tool_mcp-tool-1` tool requires a `name` and `age`, which the user has helpfully provided: `{"name": "xxx", "age": 20}`. I\'m verifying the input.\n\n\n', + thought: true + } + ], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: {} + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [{ text: '好的,我将为您打印用户的' }], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: {} + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [{ text: '信息。\n\u003ctool_use\u003e\n \u003cname\u003emcp-tool-1\u003c/name\u003e\n' }], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: {} + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [ + { + text: ' \u003carguments\u003e{"name":"xxx","age":20}\u003c/arguments\u003e\n\u003c/tool_use\u003e' + } + ], + role: 'model' + }, + finishReason: FinishReason.STOP, + index: 0 + } + ], + usageMetadata: {} + } as GeminiSdkRawChunk +] + // 正确的 async generator 函数 async function* geminiChunkGenerator(): AsyncGenerator { for (const chunk of geminiChunks) { @@ -556,6 +628,12 @@ async function* geminiThinkingChunkGenerator(): AsyncGenerator { + for (const chunk of geminiToolUseChunks) { + yield chunk + } +} + // 创建 mock 的 GeminiAPIClient const mockGeminiApiClient = { createCompletions: vi.fn().mockImplementation(() => geminiChunkGenerator()), @@ -677,6 +755,9 @@ const mockGeminiApiClient = { const mockGeminiThinkingApiClient = cloneDeep(mockGeminiApiClient) mockGeminiThinkingApiClient.createCompletions = vi.fn().mockImplementation(() => geminiThinkingChunkGenerator()) +const mockGeminiToolUseApiClient = cloneDeep(mockGeminiApiClient) +mockGeminiToolUseApiClient.createCompletions = vi.fn().mockImplementation(() => geminiToolUseChunkGenerator()) + const mockProvider = { id: 'gemini', type: 'gemini', @@ -982,4 +1063,200 @@ describe('ApiService', () => { expect(filteredChunks).toEqual(expectedChunks) }) + + // it('should extract tool use responses correctly', async () => { + // const mockCreate = vi.mocked(ApiClientFactory.create) + // mockCreate.mockReturnValue(mockGeminiToolUseApiClient as unknown as BaseApiClient) + // const AI = new AiProvider(mockProvider) + // const spy = vi.spyOn(McpToolsModule, 'callMCPTool') + // spy.mockResolvedValue({ + // content: [{ type: 'text', text: 'test' }], + // isError: false + // }) + + // const result = await AI.completions({ + // callType: 'test', + // messages: [], + // assistant: { + // id: '1', + // name: 'test', + // prompt: 'test', + // model: { + // id: 'gemini-2.5-pro', + // name: 'Gemini 2.5 Pro' + // }, + // settings: { + // toolUseMode: 'prompt' + // } + // } as Assistant, + // mcpTools: [ + // { + // id: 'mcp-tool-1', + // name: 'mcp-tool-1', + // serverId: 'mcp-server-1', + // serverName: 'mcp-server-1', + // description: 'mcp-tool-1', + // inputSchema: { + // type: 'object', + // title: 'mcp-tool-1', + // properties: { + // name: { type: 'string' }, + // age: { type: 'number' } + // }, + // description: 'print the name and age', + // required: ['name', 'age'] + // } + // } + // ], + // onChunk: mockOnChunk, + // enableReasoning: true, + // streamOutput: true + // }) + + // expect(result).toBeDefined() + // expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider) + // expect(result.stream).toBeDefined() + + // const stream = result.stream! as ReadableStream + // const reader = stream.getReader() + + // const chunks: GenericChunk[] = [] + + // while (true) { + // const { done, value } = await reader.read() + // if (done) break + // chunks.push(value) + // } + + // reader.releaseLock() + + // const filteredChunks = chunks.map((chunk) => { + // if (chunk.type === ChunkType.THINKING_DELTA || chunk.type === ChunkType.THINKING_COMPLETE) { + // delete (chunk as any).thinking_millsec + // return chunk + // } + // return chunk + // }) + + // const expectedChunks: GenericChunk[] = [ + // { + // type: ChunkType.THINKING_START + // }, + // { + // type: ChunkType.THINKING_DELTA, + // text: '**Initiating File Retrieval**\n\nI\'ve determined that the `tool_mcp-tool-1` tool is suitable for this task. It seems the user intends to read a file, and this tool aligns with that objective. Currently, I\'m focusing on the necessary parameters. The `tool_mcp-tool-1` tool requires a `name` and `age`, which the user has helpfully provided: `{"name": "xxx", "age": 20}`. I\'m verifying the input.\n\n\n' + // }, + // { + // type: ChunkType.THINKING_COMPLETE, + // text: '**Initiating File Retrieval**\n\nI\'ve determined that the `tool_mcp-tool-1` tool is suitable for this task. It seems the user intends to read a file, and this tool aligns with that objective. Currently, I\'m focusing on the necessary parameters. The `tool_mcp-tool-1` tool requires a `name` and `age`, which the user has helpfully provided: `{"name": "xxx", "age": 20}`. I\'m verifying the input.\n\n\n' + // }, + // { + // type: ChunkType.TEXT_START + // }, + // { + // type: ChunkType.TEXT_DELTA, + // text: '好的,我将为您打印用户的' + // }, + // { + // type: ChunkType.TEXT_DELTA, + // text: '好的,我将为您打印用户的信息。\n' + // }, + // { + // type: ChunkType.TEXT_COMPLETE, + // text: '好的,我将为您打印用户的信息。\n' + // }, + // { + // type: ChunkType.MCP_TOOL_CREATED + // }, + // { + // type: ChunkType.MCP_TOOL_PENDING, + // responses: [ + // { + // id: 'mcp-tool-1', + // tool: { + // id: 'mcp-tool-1', + // serverId: 'mcp-server-1', + // serverName: 'mcp-server-1', + // name: 'mcp-tool-1', + // inputSchema: { + // type: 'object', + // title: 'mcp-tool-1', + // properties: { + // name: { type: 'string' }, + // age: { type: 'number' } + // }, + // description: 'print the name and age', + // required: ['name', 'age'] + // } + // }, + // arguments: { + // name: 'xxx', + // age: 20 + // }, + // status: 'pending' + // } + // ] + // }, + // { + // type: ChunkType.MCP_TOOL_IN_PROGRESS, + // responses: [ + // { + // id: 'mcp-tool-1', + // tool: { + // id: 'mcp-tool-1', + // serverId: 'mcp-server-1', + // serverName: 'mcp-server-1', + // name: 'mcp-tool-1', + // inputSchema: { + // type: 'object', + // title: 'mcp-tool-1', + // properties: { + // name: { type: 'string' }, + // age: { type: 'number' } + // }, + // description: 'print the name and age', + // required: ['name', 'age'] + // } + // }, + // arguments: { + // name: 'xxx', + // age: 20 + // }, + // status: 'invoking' + // } + // ] + // }, + // { + // type: ChunkType.MCP_TOOL_COMPLETE, + // responses: [ + // { + // id: 'mcp-tool-1', + // tool: { + // id: 'mcp-tool-1', + // serverId: 'mcp-server-1', + // serverName: 'mcp-server-1', + // name: 'mcp-tool-1', + // inputSchema: { + // type: 'object', + // title: 'mcp-tool-1', + // properties: { + // name: { type: 'string' }, + // age: { type: 'number' } + // }, + // description: 'print the name and age', + // required: ['name', 'age'] + // } + // }, + // arguments: { + // name: 'xxx', + // age: 20 + // }, + // status: 'done' + // } + // ] + // } + // ] + + // expect(filteredChunks).toEqual(expectedChunks) + // }) }) diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 0bb1d9f6fe..71990b05b5 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -1,7 +1,7 @@ import { nanoid } from '@reduxjs/toolkit' import { DEFAULT_CONTEXTCOUNT, DEFAULT_TEMPERATURE, isMac } from '@renderer/config/constant' import { DEFAULT_MIN_APPS } from '@renderer/config/minapps' -import { SYSTEM_MODELS } from '@renderer/config/models' +import { isFunctionCallingModel, SYSTEM_MODELS } from '@renderer/config/models' import { TRANSLATE_PROMPT } from '@renderer/config/prompts' import db from '@renderer/databases' import i18n from '@renderer/i18n' @@ -1786,25 +1786,28 @@ const migrateConfig = { try { const { toolOrder } = state.inputTools const urlContextKey = 'url_context' - const webSearchIndex = toolOrder.visible.indexOf('web_search') - const knowledgeBaseIndex = toolOrder.visible.indexOf('knowledge_base') - if (webSearchIndex !== -1) { - toolOrder.visible.splice(webSearchIndex, 0, urlContextKey) - } else if (knowledgeBaseIndex !== -1) { - toolOrder.visible.splice(knowledgeBaseIndex, 0, urlContextKey) - } else { - toolOrder.visible.push(urlContextKey) + if (!toolOrder.visible.includes(urlContextKey)) { + const webSearchIndex = toolOrder.visible.indexOf('web_search') + const knowledgeBaseIndex = toolOrder.visible.indexOf('knowledge_base') + if (webSearchIndex !== -1) { + toolOrder.visible.splice(webSearchIndex, 0, urlContextKey) + } else if (knowledgeBaseIndex !== -1) { + toolOrder.visible.splice(knowledgeBaseIndex, 0, urlContextKey) + } else { + toolOrder.visible.push(urlContextKey) + } } - return state - } catch (error) { - return state - } - }, - '122': (state: RootState) => { - try { + + for (const assistant of state.assistants.assistants) { + if (assistant.settings?.toolUseMode === 'prompt' && isFunctionCallingModel(assistant.model)) { + assistant.settings.toolUseMode = 'function' + } + } + if (state.settings && typeof state.settings.webdavDisableStream === 'undefined') { state.settings.webdavDisableStream = false } + return state } catch (error) { return state