mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-23 01:33:48 +08:00
Fix/mcp bug (#8189)
* feat(models): enhance function calling model detection and update migration logic - Added support for 'gemini-1' in FUNCTION_CALLING_EXCLUDED_MODELS. - Updated isFunctionCallingModel to handle optional model input. - Modified migration logic to change tool use mode for assistants using function calling models. * feat(models): add new models to vision and function calling lists - Added 'kimi-thinking-preview' to visionAllowedModels. - Added 'kimi-k2' to FUNCTION_CALLING_MODELS. - Updated migration logic to ensure compatibility with new model settings. * refactor(TextChunkMiddleware): streamline text accumulation logic and improve response handling - Simplified the logic for accumulating text content and updating the internal state. - Ensured that the final text is consistently used in response callbacks. - Removed redundant code for handling text completion in the ToolUseExtractionMiddleware. - Added mock state for MCP tools in tests to enhance coverage for tool use extraction. * refactor(BaseApiClient): remove unused content extraction utility - Replaced the usage of getContentWithTools with getMainTextContent in the getMessageContent method. - Cleaned up imports by removing the unused getContentWithTools function.
This commit is contained in:
parent
df218ee6c8
commit
0930201e5d
@ -38,7 +38,7 @@ import {
|
||||
} from '@renderer/types/sdk'
|
||||
import { isJSON, parseJSON } from '@renderer/utils'
|
||||
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||
import { findFileBlocks, getContentWithTools, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import Logger from 'electron-log/renderer'
|
||||
import { isEmpty } from 'lodash'
|
||||
@ -210,7 +210,7 @@ export abstract class BaseApiClient<
|
||||
}
|
||||
|
||||
public async getMessageContent(message: Message): Promise<string> {
|
||||
const content = getContentWithTools(message)
|
||||
const content = getMainTextContent(message)
|
||||
|
||||
if (isEmpty(content)) {
|
||||
return ''
|
||||
|
||||
@ -54,21 +54,20 @@ export const TextChunkMiddleware: CompletionsMiddleware =
|
||||
text: accumulatedTextContent // 增量更新
|
||||
})
|
||||
} else if (accumulatedTextContent && chunk.type !== ChunkType.TEXT_START) {
|
||||
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||
const finalText = accumulatedTextContent
|
||||
ctx._internal.customState!.accumulatedText = finalText
|
||||
ctx._internal.customState!.accumulatedText = accumulatedTextContent
|
||||
if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) {
|
||||
ctx._internal.toolProcessingState.output = finalText
|
||||
ctx._internal.toolProcessingState.output = accumulatedTextContent
|
||||
}
|
||||
|
||||
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||
// 处理 onResponse 回调 - 发送最终完整文本
|
||||
if (params.onResponse) {
|
||||
params.onResponse(finalText, true)
|
||||
params.onResponse(accumulatedTextContent, true)
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: finalText
|
||||
text: accumulatedTextContent
|
||||
})
|
||||
controller.enqueue(chunk)
|
||||
} else {
|
||||
|
||||
@ -79,7 +79,6 @@ function createToolUseExtractionTransform(
|
||||
toolCounter += toolUseResponses.length
|
||||
|
||||
if (toolUseResponses.length > 0) {
|
||||
controller.enqueue({ type: ChunkType.TEXT_COMPLETE, text: '' })
|
||||
// 生成 MCP_TOOL_CREATED chunk
|
||||
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
|
||||
@ -184,7 +184,8 @@ const visionAllowedModels = [
|
||||
'deepseek-vl(?:[\\w-]+)?',
|
||||
'kimi-latest',
|
||||
'gemma-3(?:-[\\w-]+)',
|
||||
'doubao-seed-1[.-]6(?:-[\\w-]+)?'
|
||||
'doubao-seed-1[.-]6(?:-[\\w-]+)?',
|
||||
'kimi-thinking-preview'
|
||||
]
|
||||
|
||||
const visionExcludedModels = [
|
||||
@ -239,7 +240,8 @@ export const FUNCTION_CALLING_MODELS = [
|
||||
'learnlm(?:-[\\w-]+)?',
|
||||
'gemini(?:-[\\w-]+)?', // 提前排除了gemini的嵌入模型
|
||||
'grok-3(?:-[\\w-]+)?',
|
||||
'doubao-seed-1[.-]6(?:-[\\w-]+)?'
|
||||
'doubao-seed-1[.-]6(?:-[\\w-]+)?',
|
||||
'kimi-k2(?:-[\\w-]+)?'
|
||||
]
|
||||
|
||||
const FUNCTION_CALLING_EXCLUDED_MODELS = [
|
||||
@ -247,7 +249,8 @@ const FUNCTION_CALLING_EXCLUDED_MODELS = [
|
||||
'imagen(?:-[\\w-]+)?',
|
||||
'o1-mini',
|
||||
'o1-preview',
|
||||
'AIDC-AI/Marco-o1'
|
||||
'AIDC-AI/Marco-o1',
|
||||
'gemini-1(?:\\.[\\w-]+)?'
|
||||
]
|
||||
|
||||
export const FUNCTION_CALLING_REGEX = new RegExp(
|
||||
@ -260,7 +263,11 @@ export const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp(
|
||||
'i'
|
||||
)
|
||||
|
||||
export function isFunctionCallingModel(model: Model): boolean {
|
||||
export function isFunctionCallingModel(model?: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (model.type?.includes('function_calling')) {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -197,6 +197,17 @@ vi.mock('@renderer/store/llm.ts', () => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@renderer/store/mcp.ts', () => {
|
||||
const mockInitialState = {
|
||||
servers: [{ id: 'mcp-server-1', name: 'mcp-server-1', isActive: true, disabledAutoApproveTools: [] }]
|
||||
}
|
||||
return {
|
||||
default: (state = mockInitialState) => {
|
||||
return state
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 测试用例:将 Gemini API 响应数据转换为 geminiChunks 数组
|
||||
const geminiChunks: GeminiSdkRawChunk[] = [
|
||||
{
|
||||
@ -543,6 +554,67 @@ const geminiThinkingChunks: GeminiSdkRawChunk[] = [
|
||||
} as unknown as GeminiSdkRawChunk
|
||||
]
|
||||
|
||||
const geminiToolUseChunks: GeminiSdkRawChunk[] = [
|
||||
{
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{
|
||||
text: '**Initiating File Retrieval**\n\nI\'ve determined that the `tool_mcp-tool-1` tool is suitable for this task. It seems the user intends to read a file, and this tool aligns with that objective. Currently, I\'m focusing on the necessary parameters. The `tool_mcp-tool-1` tool requires a `name` and `age`, which the user has helpfully provided: `{"name": "xxx", "age": 20}`. I\'m verifying the input.\n\n\n',
|
||||
thought: true
|
||||
}
|
||||
],
|
||||
role: 'model'
|
||||
},
|
||||
index: 0
|
||||
}
|
||||
],
|
||||
usageMetadata: {}
|
||||
} as GeminiSdkRawChunk,
|
||||
{
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: '好的,我将为您打印用户的' }],
|
||||
role: 'model'
|
||||
},
|
||||
index: 0
|
||||
}
|
||||
],
|
||||
usageMetadata: {}
|
||||
} as GeminiSdkRawChunk,
|
||||
{
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: '信息。\n\u003ctool_use\u003e\n \u003cname\u003emcp-tool-1\u003c/name\u003e\n' }],
|
||||
role: 'model'
|
||||
},
|
||||
index: 0
|
||||
}
|
||||
],
|
||||
usageMetadata: {}
|
||||
} as GeminiSdkRawChunk,
|
||||
{
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{
|
||||
text: ' \u003carguments\u003e{"name":"xxx","age":20}\u003c/arguments\u003e\n\u003c/tool_use\u003e'
|
||||
}
|
||||
],
|
||||
role: 'model'
|
||||
},
|
||||
finishReason: FinishReason.STOP,
|
||||
index: 0
|
||||
}
|
||||
],
|
||||
usageMetadata: {}
|
||||
} as GeminiSdkRawChunk
|
||||
]
|
||||
|
||||
// 正确的 async generator 函数
|
||||
async function* geminiChunkGenerator(): AsyncGenerator<GeminiSdkRawChunk> {
|
||||
for (const chunk of geminiChunks) {
|
||||
@ -556,6 +628,12 @@ async function* geminiThinkingChunkGenerator(): AsyncGenerator<GeminiSdkRawChunk
|
||||
}
|
||||
}
|
||||
|
||||
async function* geminiToolUseChunkGenerator(): AsyncGenerator<GeminiSdkRawChunk> {
|
||||
for (const chunk of geminiToolUseChunks) {
|
||||
yield chunk
|
||||
}
|
||||
}
|
||||
|
||||
// 创建 mock 的 GeminiAPIClient
|
||||
const mockGeminiApiClient = {
|
||||
createCompletions: vi.fn().mockImplementation(() => geminiChunkGenerator()),
|
||||
@ -677,6 +755,9 @@ const mockGeminiApiClient = {
|
||||
const mockGeminiThinkingApiClient = cloneDeep(mockGeminiApiClient)
|
||||
mockGeminiThinkingApiClient.createCompletions = vi.fn().mockImplementation(() => geminiThinkingChunkGenerator())
|
||||
|
||||
const mockGeminiToolUseApiClient = cloneDeep(mockGeminiApiClient)
|
||||
mockGeminiToolUseApiClient.createCompletions = vi.fn().mockImplementation(() => geminiToolUseChunkGenerator())
|
||||
|
||||
const mockProvider = {
|
||||
id: 'gemini',
|
||||
type: 'gemini',
|
||||
@ -982,4 +1063,200 @@ describe('ApiService', () => {
|
||||
|
||||
expect(filteredChunks).toEqual(expectedChunks)
|
||||
})
|
||||
|
||||
// it('should extract tool use responses correctly', async () => {
|
||||
// const mockCreate = vi.mocked(ApiClientFactory.create)
|
||||
// mockCreate.mockReturnValue(mockGeminiToolUseApiClient as unknown as BaseApiClient)
|
||||
// const AI = new AiProvider(mockProvider)
|
||||
// const spy = vi.spyOn(McpToolsModule, 'callMCPTool')
|
||||
// spy.mockResolvedValue({
|
||||
// content: [{ type: 'text', text: 'test' }],
|
||||
// isError: false
|
||||
// })
|
||||
|
||||
// const result = await AI.completions({
|
||||
// callType: 'test',
|
||||
// messages: [],
|
||||
// assistant: {
|
||||
// id: '1',
|
||||
// name: 'test',
|
||||
// prompt: 'test',
|
||||
// model: {
|
||||
// id: 'gemini-2.5-pro',
|
||||
// name: 'Gemini 2.5 Pro'
|
||||
// },
|
||||
// settings: {
|
||||
// toolUseMode: 'prompt'
|
||||
// }
|
||||
// } as Assistant,
|
||||
// mcpTools: [
|
||||
// {
|
||||
// id: 'mcp-tool-1',
|
||||
// name: 'mcp-tool-1',
|
||||
// serverId: 'mcp-server-1',
|
||||
// serverName: 'mcp-server-1',
|
||||
// description: 'mcp-tool-1',
|
||||
// inputSchema: {
|
||||
// type: 'object',
|
||||
// title: 'mcp-tool-1',
|
||||
// properties: {
|
||||
// name: { type: 'string' },
|
||||
// age: { type: 'number' }
|
||||
// },
|
||||
// description: 'print the name and age',
|
||||
// required: ['name', 'age']
|
||||
// }
|
||||
// }
|
||||
// ],
|
||||
// onChunk: mockOnChunk,
|
||||
// enableReasoning: true,
|
||||
// streamOutput: true
|
||||
// })
|
||||
|
||||
// expect(result).toBeDefined()
|
||||
// expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider)
|
||||
// expect(result.stream).toBeDefined()
|
||||
|
||||
// const stream = result.stream! as ReadableStream<GenericChunk>
|
||||
// const reader = stream.getReader()
|
||||
|
||||
// const chunks: GenericChunk[] = []
|
||||
|
||||
// while (true) {
|
||||
// const { done, value } = await reader.read()
|
||||
// if (done) break
|
||||
// chunks.push(value)
|
||||
// }
|
||||
|
||||
// reader.releaseLock()
|
||||
|
||||
// const filteredChunks = chunks.map((chunk) => {
|
||||
// if (chunk.type === ChunkType.THINKING_DELTA || chunk.type === ChunkType.THINKING_COMPLETE) {
|
||||
// delete (chunk as any).thinking_millsec
|
||||
// return chunk
|
||||
// }
|
||||
// return chunk
|
||||
// })
|
||||
|
||||
// const expectedChunks: GenericChunk[] = [
|
||||
// {
|
||||
// type: ChunkType.THINKING_START
|
||||
// },
|
||||
// {
|
||||
// type: ChunkType.THINKING_DELTA,
|
||||
// text: '**Initiating File Retrieval**\n\nI\'ve determined that the `tool_mcp-tool-1` tool is suitable for this task. It seems the user intends to read a file, and this tool aligns with that objective. Currently, I\'m focusing on the necessary parameters. The `tool_mcp-tool-1` tool requires a `name` and `age`, which the user has helpfully provided: `{"name": "xxx", "age": 20}`. I\'m verifying the input.\n\n\n'
|
||||
// },
|
||||
// {
|
||||
// type: ChunkType.THINKING_COMPLETE,
|
||||
// text: '**Initiating File Retrieval**\n\nI\'ve determined that the `tool_mcp-tool-1` tool is suitable for this task. It seems the user intends to read a file, and this tool aligns with that objective. Currently, I\'m focusing on the necessary parameters. The `tool_mcp-tool-1` tool requires a `name` and `age`, which the user has helpfully provided: `{"name": "xxx", "age": 20}`. I\'m verifying the input.\n\n\n'
|
||||
// },
|
||||
// {
|
||||
// type: ChunkType.TEXT_START
|
||||
// },
|
||||
// {
|
||||
// type: ChunkType.TEXT_DELTA,
|
||||
// text: '好的,我将为您打印用户的'
|
||||
// },
|
||||
// {
|
||||
// type: ChunkType.TEXT_DELTA,
|
||||
// text: '好的,我将为您打印用户的信息。\n'
|
||||
// },
|
||||
// {
|
||||
// type: ChunkType.TEXT_COMPLETE,
|
||||
// text: '好的,我将为您打印用户的信息。\n'
|
||||
// },
|
||||
// {
|
||||
// type: ChunkType.MCP_TOOL_CREATED
|
||||
// },
|
||||
// {
|
||||
// type: ChunkType.MCP_TOOL_PENDING,
|
||||
// responses: [
|
||||
// {
|
||||
// id: 'mcp-tool-1',
|
||||
// tool: {
|
||||
// id: 'mcp-tool-1',
|
||||
// serverId: 'mcp-server-1',
|
||||
// serverName: 'mcp-server-1',
|
||||
// name: 'mcp-tool-1',
|
||||
// inputSchema: {
|
||||
// type: 'object',
|
||||
// title: 'mcp-tool-1',
|
||||
// properties: {
|
||||
// name: { type: 'string' },
|
||||
// age: { type: 'number' }
|
||||
// },
|
||||
// description: 'print the name and age',
|
||||
// required: ['name', 'age']
|
||||
// }
|
||||
// },
|
||||
// arguments: {
|
||||
// name: 'xxx',
|
||||
// age: 20
|
||||
// },
|
||||
// status: 'pending'
|
||||
// }
|
||||
// ]
|
||||
// },
|
||||
// {
|
||||
// type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
||||
// responses: [
|
||||
// {
|
||||
// id: 'mcp-tool-1',
|
||||
// tool: {
|
||||
// id: 'mcp-tool-1',
|
||||
// serverId: 'mcp-server-1',
|
||||
// serverName: 'mcp-server-1',
|
||||
// name: 'mcp-tool-1',
|
||||
// inputSchema: {
|
||||
// type: 'object',
|
||||
// title: 'mcp-tool-1',
|
||||
// properties: {
|
||||
// name: { type: 'string' },
|
||||
// age: { type: 'number' }
|
||||
// },
|
||||
// description: 'print the name and age',
|
||||
// required: ['name', 'age']
|
||||
// }
|
||||
// },
|
||||
// arguments: {
|
||||
// name: 'xxx',
|
||||
// age: 20
|
||||
// },
|
||||
// status: 'invoking'
|
||||
// }
|
||||
// ]
|
||||
// },
|
||||
// {
|
||||
// type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
// responses: [
|
||||
// {
|
||||
// id: 'mcp-tool-1',
|
||||
// tool: {
|
||||
// id: 'mcp-tool-1',
|
||||
// serverId: 'mcp-server-1',
|
||||
// serverName: 'mcp-server-1',
|
||||
// name: 'mcp-tool-1',
|
||||
// inputSchema: {
|
||||
// type: 'object',
|
||||
// title: 'mcp-tool-1',
|
||||
// properties: {
|
||||
// name: { type: 'string' },
|
||||
// age: { type: 'number' }
|
||||
// },
|
||||
// description: 'print the name and age',
|
||||
// required: ['name', 'age']
|
||||
// }
|
||||
// },
|
||||
// arguments: {
|
||||
// name: 'xxx',
|
||||
// age: 20
|
||||
// },
|
||||
// status: 'done'
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// ]
|
||||
|
||||
// expect(filteredChunks).toEqual(expectedChunks)
|
||||
// })
|
||||
})
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import { DEFAULT_CONTEXTCOUNT, DEFAULT_TEMPERATURE, isMac } from '@renderer/config/constant'
|
||||
import { DEFAULT_MIN_APPS } from '@renderer/config/minapps'
|
||||
import { SYSTEM_MODELS } from '@renderer/config/models'
|
||||
import { isFunctionCallingModel, SYSTEM_MODELS } from '@renderer/config/models'
|
||||
import { TRANSLATE_PROMPT } from '@renderer/config/prompts'
|
||||
import db from '@renderer/databases'
|
||||
import i18n from '@renderer/i18n'
|
||||
@ -1786,6 +1786,7 @@ const migrateConfig = {
|
||||
try {
|
||||
const { toolOrder } = state.inputTools
|
||||
const urlContextKey = 'url_context'
|
||||
if (!toolOrder.visible.includes(urlContextKey)) {
|
||||
const webSearchIndex = toolOrder.visible.indexOf('web_search')
|
||||
const knowledgeBaseIndex = toolOrder.visible.indexOf('knowledge_base')
|
||||
if (webSearchIndex !== -1) {
|
||||
@ -1795,16 +1796,18 @@ const migrateConfig = {
|
||||
} else {
|
||||
toolOrder.visible.push(urlContextKey)
|
||||
}
|
||||
return state
|
||||
} catch (error) {
|
||||
return state
|
||||
}
|
||||
},
|
||||
'122': (state: RootState) => {
|
||||
try {
|
||||
|
||||
for (const assistant of state.assistants.assistants) {
|
||||
if (assistant.settings?.toolUseMode === 'prompt' && isFunctionCallingModel(assistant.model)) {
|
||||
assistant.settings.toolUseMode = 'function'
|
||||
}
|
||||
}
|
||||
|
||||
if (state.settings && typeof state.settings.webdavDisableStream === 'undefined') {
|
||||
state.settings.webdavDisableStream = false
|
||||
}
|
||||
|
||||
return state
|
||||
} catch (error) {
|
||||
return state
|
||||
|
||||
Loading…
Reference in New Issue
Block a user