feat: add test case for mcp response in apiservice (#8300)

* refactor: 将工具调用逻辑移动到中间件文件

* feat(日志): 在流处理中添加调试日志记录

添加对分块数据的调试日志记录,便于跟踪流处理过程中的数据流动

* test(api-service): 添加工具调用响应测试用例
This commit is contained in:
Phantom 2025-07-21 14:48:24 +08:00 committed by GitHub
parent 8967a82107
commit e7fd97deef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 480 additions and 400 deletions

View File

@ -561,6 +561,7 @@ export class GeminiAPIClient extends BaseApiClient<
let isFirstThinkingChunk = true let isFirstThinkingChunk = true
return () => ({ return () => ({
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) { async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
logger.silly('chunk', chunk)
if (chunk.candidates && chunk.candidates.length > 0) { if (chunk.candidates && chunk.candidates.length > 0) {
for (const candidate of chunk.candidates) { for (const candidate of chunk.candidates) {
if (candidate.content) { if (candidate.content) {

View File

@ -64,6 +64,7 @@ const FinalChunkConsumerMiddleware: CompletionsMiddleware =
try { try {
while (true) { while (true) {
const { done, value: chunk } = await reader.read() const { done, value: chunk } = await reader.read()
logger.silly('chunk', chunk)
if (done) { if (done) {
logger.debug(`Input stream finished.`) logger.debug(`Input stream finished.`)
break break

View File

@ -1,8 +1,15 @@
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types' import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk' import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk' import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
import { parseAndCallTools } from '@renderer/utils/mcp-tools' import {
callMCPTool,
getMcpServerByTool,
isToolAutoApproved,
parseToolUse,
upsertMCPToolResponse
} from '@renderer/utils/mcp-tools'
import { confirmSameNameTools, requestToolConfirmation, setToolIdToNameMapping } from '@renderer/utils/userConfirmation'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types' import { CompletionsContext, CompletionsMiddleware } from '../types'
@ -369,4 +376,207 @@ function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] {
return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload) return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload)
} }
export default McpToolChunkMiddleware export async function parseAndCallTools<R>(
  tools: MCPToolResponse[],
  allToolResponses: MCPToolResponse[],
  onChunk: CompletionsParams['onChunk'],
  convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
  model: Model,
  mcpTools?: MCPTool[],
  abortSignal?: AbortSignal,
  topicId?: CompletionsParams['topicId']
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
// Overload: takes raw assistant output text; tool-use requests are parsed out of it.
export async function parseAndCallTools<R>(
  content: string,
  allToolResponses: MCPToolResponse[],
  onChunk: CompletionsParams['onChunk'],
  convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
  model: Model,
  mcpTools?: MCPTool[],
  abortSignal?: AbortSignal,
  topicId?: CompletionsParams['topicId']
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
/**
 * Confirms and executes a batch of MCP tool calls.
 *
 * For every tool response (given directly as an array, or parsed from `content`
 * via `parseToolUse` when it is a string) this function:
 *  1. reports it as `pending` through `upsertMCPToolResponse`/`onChunk`,
 *  2. waits for user confirmation unless the tool is auto-approved,
 *  3. runs each confirmed tool with `callMCPTool`, streaming
 *     `invoking` -> `done` (or `cancelled`) status updates as they happen,
 *  4. converts successful results with `convertToMessage` and collects them.
 *
 * All tools are processed concurrently; `allToolResponses` is mutated in place
 * as the shared status store. NOTE(review): because execution is concurrent,
 * the order of `toolResults` follows completion order, not request order —
 * presumably acceptable to callers; confirm if ordering matters.
 *
 * @param content           Pre-parsed tool responses, or raw model text to scan for tool use.
 * @param allToolResponses  Accumulator of every tool response seen so far (mutated in place).
 * @param onChunk           Stream callback used to surface status/image chunks.
 * @param convertToMessage  Maps a finished call to an SDK message; `undefined` results are dropped.
 * @param model             Model the call is attributed to (`model.name` is forwarded to `callMCPTool`).
 * @param mcpTools          Known tools; used only when parsing `content` as text.
 * @param abortSignal       Aborts a pending user-confirmation wait.
 * @param topicId           Optional topic id forwarded to `callMCPTool`.
 * @returns `toolResults` (converted messages) and the subset of responses that
 *          were confirmed and successfully converted.
 */
export async function parseAndCallTools<R>(
  content: string | MCPToolResponse[],
  allToolResponses: MCPToolResponse[],
  onChunk: CompletionsParams['onChunk'],
  convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
  model: Model,
  mcpTools?: MCPTool[],
  abortSignal?: AbortSignal,
  topicId?: CompletionsParams['topicId']
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }> {
  const toolResults: R[] = []
  let curToolResponses: MCPToolResponse[] = []
  if (Array.isArray(content)) {
    curToolResponses = content
  } else {
    // process tool use: extract tool invocations embedded in the raw text
    curToolResponses = parseToolUse(content, mcpTools || [], 0)
  }
  // Nothing to do when the model requested no tools.
  if (!curToolResponses || curToolResponses.length === 0) {
    return { toolResults, confirmedToolResponses: [] }
  }
  // Surface every requested tool as 'pending' before any confirmation starts.
  for (const toolResponse of curToolResponses) {
    upsertMCPToolResponse(
      allToolResponses,
      {
        ...toolResponse,
        status: 'pending'
      },
      onChunk!
    )
  }
  // Build one confirmation promise per tool and handle each as soon as it resolves.
  const confirmedTools: MCPToolResponse[] = []
  const pendingPromises: Promise<void>[] = []
  curToolResponses.forEach((toolResponse) => {
    const server = getMcpServerByTool(toolResponse.tool)
    const isAutoApproveEnabled = isToolAutoApproved(toolResponse.tool, server)
    let confirmationPromise: Promise<boolean>
    if (isAutoApproveEnabled) {
      // Auto-approved tools skip the user prompt entirely.
      confirmationPromise = Promise.resolve(true)
    } else {
      setToolIdToNameMapping(toolResponse.id, toolResponse.tool.name)
      confirmationPromise = requestToolConfirmation(toolResponse.id, abortSignal).then((confirmed) => {
        if (confirmed && server) {
          // Auto-confirm other pending tools with the same name
          confirmSameNameTools(toolResponse.tool.name)
        }
        return confirmed
      })
    }
    const processingPromise = confirmationPromise
      .then(async (confirmed) => {
        if (confirmed) {
          // Immediately switch to the 'invoking' state
          upsertMCPToolResponse(
            allToolResponses,
            {
              ...toolResponse,
              status: 'invoking'
            },
            onChunk!
          )
          // Execute the tool call
          try {
            const images: string[] = []
            const toolCallResponse = await callMCPTool(toolResponse, topicId, model.name)
            // Immediately switch to the 'done' state with the tool's response attached
            upsertMCPToolResponse(
              allToolResponses,
              {
                ...toolResponse,
                status: 'done',
                response: toolCallResponse
              },
              onChunk!
            )
            // Collect image payloads from the response as base64 data URLs
            for (const content of toolCallResponse.content) {
              if (content.type === 'image' && content.data) {
                images.push(`data:${content.mimeType};base64,${content.data}`)
              }
            }
            if (images.length) {
              onChunk?.({
                type: ChunkType.IMAGE_CREATED
              })
              onChunk?.({
                type: ChunkType.IMAGE_COMPLETE,
                image: {
                  type: 'base64',
                  images: images
                }
              })
            }
            // Convert the result to an SDK message and record it on success
            const convertedMessage = convertToMessage(toolResponse, toolCallResponse, model)
            if (convertedMessage) {
              confirmedTools.push(toolResponse)
              toolResults.push(convertedMessage)
            }
          } catch (error) {
            logger.error(`Error executing tool ${toolResponse.id}:`, error)
            // Switch to the error state ('done' with isError: true)
            upsertMCPToolResponse(
              allToolResponses,
              {
                ...toolResponse,
                status: 'done',
                response: {
                  isError: true,
                  content: [
                    {
                      type: 'text',
                      text: `Error executing tool: ${error instanceof Error ? error.message : 'Unknown error'}`
                    }
                  ]
                }
              },
              onChunk!
            )
          }
        } else {
          // User declined: immediately switch to the 'cancelled' state
          upsertMCPToolResponse(
            allToolResponses,
            {
              ...toolResponse,
              status: 'cancelled',
              response: {
                isError: false,
                content: [
                  {
                    type: 'text',
                    text: 'Tool call cancelled by user.'
                  }
                ]
              }
            },
            onChunk!
          )
        }
      })
      .catch((error) => {
        logger.error(`Error waiting for tool confirmation ${toolResponse.id}:`, error)
        // Confirmation itself failed (e.g. aborted): switch to 'cancelled' with an error payload
        upsertMCPToolResponse(
          allToolResponses,
          {
            ...toolResponse,
            status: 'cancelled',
            response: {
              isError: true,
              content: [
                {
                  type: 'text',
                  text: `Error in confirmation process: ${error instanceof Error ? error.message : 'Unknown error'}`
                }
              ]
            }
          },
          onChunk!
        )
      })
    pendingPromises.push(processingPromise)
  })
  // Wait for all tools to finish (each tool's status was already streamed in real time)
  await Promise.all(pendingPromises)
  return { toolResults, confirmedToolResponses: confirmedTools }
}

View File

@ -6,8 +6,10 @@ import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient' import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas' import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { Assistant, Provider, WebSearchSource } from '@renderer/types' import { isVisionModel } from '@renderer/config/models'
import { Assistant, MCPCallToolResponse, MCPToolResponse, Model, Provider, WebSearchSource } from '@renderer/types'
import { import {
Chunk,
ChunkType, ChunkType,
LLMResponseCompleteChunk, LLMResponseCompleteChunk,
LLMWebSearchCompleteChunk, LLMWebSearchCompleteChunk,
@ -15,7 +17,15 @@ import {
TextStartChunk, TextStartChunk,
ThinkingStartChunk ThinkingStartChunk
} from '@renderer/types/chunk' } from '@renderer/types/chunk'
import { GeminiSdkRawChunk, OpenAISdkRawChunk, OpenAISdkRawContentSource } from '@renderer/types/sdk' import {
GeminiSdkMessageParam,
GeminiSdkRawChunk,
GeminiSdkToolCall,
OpenAISdkRawChunk,
OpenAISdkRawContentSource
} from '@renderer/types/sdk'
import * as McpToolsModule from '@renderer/utils/mcp-tools'
import { mcpToolCallResponseToGeminiMessage } from '@renderer/utils/mcp-tools'
import { cloneDeep } from 'lodash' import { cloneDeep } from 'lodash'
import OpenAI from 'openai' import OpenAI from 'openai'
import { ChatCompletionChunk } from 'openai/resources' import { ChatCompletionChunk } from 'openai/resources'
@ -610,11 +620,32 @@ const geminiToolUseChunks: GeminiSdkRawChunk[] = [
], ],
role: 'model' role: 'model'
}, },
finishReason: FinishReason.STOP,
index: 0 index: 0
} }
], ],
usageMetadata: {} usageMetadata: {}
} as GeminiSdkRawChunk,
{
candidates: [
{
content: {
parts: [
{
functionCall: {
name: 'mcp-tool-1',
args: {
name: 'alice',
age: 13
}
} as GeminiSdkToolCall
}
],
role: 'model'
},
finishReason: FinishReason.STOP
}
],
usageMetadata: {}
} as GeminiSdkRawChunk } as GeminiSdkRawChunk
] ]
@ -895,6 +926,7 @@ const mockOpenaiApiClient = {
choice.delta && choice.delta &&
Object.keys(choice.delta).length > 0 && Object.keys(choice.delta).length > 0 &&
(!('content' in choice.delta) || (!('content' in choice.delta) ||
(choice.delta.tool_calls && choice.delta.tool_calls.length > 0) ||
(typeof choice.delta.content === 'string' && choice.delta.content !== '') || (typeof choice.delta.content === 'string' && choice.delta.content !== '') ||
(typeof (choice.delta as any).reasoning_content === 'string' && (typeof (choice.delta as any).reasoning_content === 'string' &&
(choice.delta as any).reasoning_content !== '') || (choice.delta as any).reasoning_content !== '') ||
@ -1141,6 +1173,14 @@ mockGeminiThinkingApiClient.createCompletions = vi.fn().mockImplementation(() =>
const mockGeminiToolUseApiClient = cloneDeep(mockGeminiApiClient) const mockGeminiToolUseApiClient = cloneDeep(mockGeminiApiClient)
mockGeminiToolUseApiClient.createCompletions = vi.fn().mockImplementation(() => geminiToolUseChunkGenerator()) mockGeminiToolUseApiClient.createCompletions = vi.fn().mockImplementation(() => geminiToolUseChunkGenerator())
mockGeminiToolUseApiClient.convertMcpToolResponseToSdkMessageParam = vi
.fn()
.mockImplementation(
(mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model): GeminiSdkMessageParam | undefined => {
// mcp使用tooluse
return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
}
)
const mockProvider = { const mockProvider = {
id: 'gemini', id: 'gemini',
@ -1537,199 +1577,235 @@ describe('ApiService', () => {
expect(filteredChunks).toEqual(expectedChunks) expect(filteredChunks).toEqual(expectedChunks)
}) })
// it('should extract tool use responses correctly', async () => { it('should extract tool use responses correctly', async () => {
// const mockCreate = vi.mocked(ApiClientFactory.create) const mockCreate = vi.mocked(ApiClientFactory.create)
// mockCreate.mockReturnValue(mockGeminiToolUseApiClient as unknown as BaseApiClient) mockCreate.mockReturnValue(mockGeminiToolUseApiClient as unknown as BaseApiClient)
// const AI = new AiProvider(mockProvider) const AI = new AiProvider(mockProvider)
// const spy = vi.spyOn(McpToolsModule, 'callMCPTool')
// spy.mockResolvedValue({
// content: [{ type: 'text', text: 'test' }],
// isError: false
// })
// const result = await AI.completions({ const mcpChunks: GenericChunk[] = []
// callType: 'test', const firstResponseChunks: GenericChunk[] = []
// messages: [],
// assistant: {
// id: '1',
// name: 'test',
// prompt: 'test',
// model: {
// id: 'gemini-2.5-pro',
// name: 'Gemini 2.5 Pro'
// },
// settings: {
// toolUseMode: 'prompt'
// }
// } as Assistant,
// mcpTools: [
// {
// id: 'mcp-tool-1',
// name: 'mcp-tool-1',
// serverId: 'mcp-server-1',
// serverName: 'mcp-server-1',
// description: 'mcp-tool-1',
// inputSchema: {
// type: 'object',
// title: 'mcp-tool-1',
// properties: {
// name: { type: 'string' },
// age: { type: 'number' }
// },
// description: 'print the name and age',
// required: ['name', 'age']
// }
// }
// ],
// onChunk: mockOnChunk,
// enableReasoning: true,
// streamOutput: true
// })
// expect(result).toBeDefined() const spy = vi.spyOn(McpToolsModule, 'callMCPTool')
// expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider) spy.mockResolvedValue({
// expect(result.stream).toBeDefined() content: [{ type: 'text', text: 'test' }],
isError: false
})
// const stream = result.stream! as ReadableStream<GenericChunk> const onChunk = vi.fn((chunk: Chunk) => {
// const reader = stream.getReader() mcpChunks.push(chunk)
})
// const chunks: GenericChunk[] = [] const result = await AI.completions({
callType: 'test',
messages: [],
assistant: {
id: '1',
name: 'test',
prompt: 'test',
model: {
id: 'gemini-2.5-pro',
name: 'Gemini 2.5 Pro'
},
settings: {
toolUseMode: 'prompt'
}
} as Assistant,
mcpTools: [
{
id: 'mcp-tool-1',
name: 'mcp-tool-1',
serverId: 'mcp-server-1',
serverName: 'mcp-server-1',
description: 'mcp-tool-1',
inputSchema: {
type: 'object',
title: 'mcp-tool-1',
properties: {
name: { type: 'string' },
age: { type: 'number' }
},
description: 'print the name and age',
required: ['name', 'age']
}
}
],
onChunk,
enableReasoning: true,
streamOutput: true
})
// while (true) { expect(result).toBeDefined()
// const { done, value } = await reader.read() expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider)
// if (done) break expect(result.stream).toBeDefined()
// chunks.push(value)
// }
// reader.releaseLock() const stream = result.stream! as ReadableStream<GenericChunk>
const reader = stream.getReader()
// const filteredChunks = chunks.map((chunk) => { while (true) {
// if (chunk.type === ChunkType.THINKING_DELTA || chunk.type === ChunkType.THINKING_COMPLETE) { const { done, value: chunk } = await reader.read()
// delete (chunk as any).thinking_millsec if (done) break
// return chunk firstResponseChunks.push(chunk)
// } }
// return chunk
// })
// const expectedChunks: GenericChunk[] = [ reader.releaseLock()
// {
// type: ChunkType.THINKING_START
// },
// {
// type: ChunkType.THINKING_DELTA,
// text: '**Initiating File Retrieval**\n\nI\'ve determined that the `tool_mcp-tool-1` tool is suitable for this task. It seems the user intends to read a file, and this tool aligns with that objective. Currently, I\'m focusing on the necessary parameters. The `tool_mcp-tool-1` tool requires a `name` and `age`, which the user has helpfully provided: `{"name": "xxx", "age": 20}`. I\'m verifying the input.\n\n\n'
// },
// {
// type: ChunkType.THINKING_COMPLETE,
// text: '**Initiating File Retrieval**\n\nI\'ve determined that the `tool_mcp-tool-1` tool is suitable for this task. It seems the user intends to read a file, and this tool aligns with that objective. Currently, I\'m focusing on the necessary parameters. The `tool_mcp-tool-1` tool requires a `name` and `age`, which the user has helpfully provided: `{"name": "xxx", "age": 20}`. I\'m verifying the input.\n\n\n'
// },
// {
// type: ChunkType.TEXT_START
// },
// {
// type: ChunkType.TEXT_DELTA,
// text: '好的,我将为您打印用户的'
// },
// {
// type: ChunkType.TEXT_DELTA,
// text: '好的,我将为您打印用户的信息。\n'
// },
// {
// type: ChunkType.TEXT_COMPLETE,
// text: '好的,我将为您打印用户的信息。\n'
// },
// {
// type: ChunkType.MCP_TOOL_CREATED
// },
// {
// type: ChunkType.MCP_TOOL_PENDING,
// responses: [
// {
// id: 'mcp-tool-1',
// tool: {
// id: 'mcp-tool-1',
// serverId: 'mcp-server-1',
// serverName: 'mcp-server-1',
// name: 'mcp-tool-1',
// inputSchema: {
// type: 'object',
// title: 'mcp-tool-1',
// properties: {
// name: { type: 'string' },
// age: { type: 'number' }
// },
// description: 'print the name and age',
// required: ['name', 'age']
// }
// },
// arguments: {
// name: 'xxx',
// age: 20
// },
// status: 'pending'
// }
// ]
// },
// {
// type: ChunkType.MCP_TOOL_IN_PROGRESS,
// responses: [
// {
// id: 'mcp-tool-1',
// tool: {
// id: 'mcp-tool-1',
// serverId: 'mcp-server-1',
// serverName: 'mcp-server-1',
// name: 'mcp-tool-1',
// inputSchema: {
// type: 'object',
// title: 'mcp-tool-1',
// properties: {
// name: { type: 'string' },
// age: { type: 'number' }
// },
// description: 'print the name and age',
// required: ['name', 'age']
// }
// },
// arguments: {
// name: 'xxx',
// age: 20
// },
// status: 'invoking'
// }
// ]
// },
// {
// type: ChunkType.MCP_TOOL_COMPLETE,
// responses: [
// {
// id: 'mcp-tool-1',
// tool: {
// id: 'mcp-tool-1',
// serverId: 'mcp-server-1',
// serverName: 'mcp-server-1',
// name: 'mcp-tool-1',
// inputSchema: {
// type: 'object',
// title: 'mcp-tool-1',
// properties: {
// name: { type: 'string' },
// age: { type: 'number' }
// },
// description: 'print the name and age',
// required: ['name', 'age']
// }
// },
// arguments: {
// name: 'xxx',
// age: 20
// },
// status: 'done'
// }
// ]
// }
// ]
// expect(filteredChunks).toEqual(expectedChunks) const filteredFirstResponseChunks = firstResponseChunks.map((chunk) => {
// }) if (chunk.type === ChunkType.THINKING_DELTA || chunk.type === ChunkType.THINKING_COMPLETE) {
delete (chunk as any).thinking_millsec
return chunk
}
return chunk
})
const expectedFirstResponseChunks: GenericChunk[] = [
{
type: ChunkType.THINKING_START
},
{
type: ChunkType.THINKING_DELTA,
text: '**Initiating File Retrieval**\n\nI\'ve determined that the `tool_mcp-tool-1` tool is suitable for this task. It seems the user intends to read a file, and this tool aligns with that objective. Currently, I\'m focusing on the necessary parameters. The `tool_mcp-tool-1` tool requires a `name` and `age`, which the user has helpfully provided: `{"name": "xxx", "age": 20}`. I\'m verifying the input.\n\n\n'
},
{
type: ChunkType.THINKING_COMPLETE,
text: '**Initiating File Retrieval**\n\nI\'ve determined that the `tool_mcp-tool-1` tool is suitable for this task. It seems the user intends to read a file, and this tool aligns with that objective. Currently, I\'m focusing on the necessary parameters. The `tool_mcp-tool-1` tool requires a `name` and `age`, which the user has helpfully provided: `{"name": "xxx", "age": 20}`. I\'m verifying the input.\n\n\n'
},
{
type: ChunkType.TEXT_START
},
{
type: ChunkType.TEXT_DELTA,
text: '好的,我将为您打印用户的'
},
{
type: ChunkType.TEXT_DELTA,
text: '好的,我将为您打印用户的信息。\n'
},
{
type: ChunkType.TEXT_COMPLETE,
text: '好的,我将为您打印用户的信息。\n'
},
{
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
completion_tokens: 0,
prompt_tokens: 0,
total_tokens: 0
}
}
}
]
const expectedMcpResponseChunks: GenericChunk[] = [
{
type: ChunkType.MCP_TOOL_PENDING,
responses: [
{
id: 'mcp-tool-1-0',
tool: {
description: 'mcp-tool-1',
id: 'mcp-tool-1',
serverId: 'mcp-server-1',
serverName: 'mcp-server-1',
name: 'mcp-tool-1',
inputSchema: {
type: 'object',
title: 'mcp-tool-1',
properties: {
name: { type: 'string' },
age: { type: 'number' }
},
description: 'print the name and age',
required: ['name', 'age']
}
},
toolUseId: 'mcp-tool-1',
arguments: {
name: 'xxx',
age: 20
},
status: 'pending'
}
]
},
{
type: ChunkType.MCP_TOOL_IN_PROGRESS,
responses: [
{
id: 'mcp-tool-1-0',
response: undefined,
tool: {
description: 'mcp-tool-1',
id: 'mcp-tool-1',
serverId: 'mcp-server-1',
serverName: 'mcp-server-1',
name: 'mcp-tool-1',
inputSchema: {
type: 'object',
title: 'mcp-tool-1',
properties: {
name: { type: 'string' },
age: { type: 'number' }
},
description: 'print the name and age',
required: ['name', 'age']
}
},
toolUseId: 'mcp-tool-1',
arguments: {
name: 'xxx',
age: 20
},
status: 'invoking'
}
]
},
{
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [
{
id: 'mcp-tool-1-0',
tool: {
description: 'mcp-tool-1',
id: 'mcp-tool-1',
serverId: 'mcp-server-1',
serverName: 'mcp-server-1',
name: 'mcp-tool-1',
inputSchema: {
type: 'object',
title: 'mcp-tool-1',
properties: {
name: { type: 'string' },
age: { type: 'number' }
},
description: 'print the name and age',
required: ['name', 'age']
}
},
response: {
content: [
{
text: 'test',
type: 'text'
}
],
isError: false
},
toolUseId: 'mcp-tool-1',
arguments: {
name: 'xxx',
age: 20
},
status: 'done'
}
]
},
{
type: ChunkType.LLM_RESPONSE_CREATED
}
]
expect(filteredFirstResponseChunks).toEqual(expectedFirstResponseChunks)
expect(mcpChunks).toEqual(expectedMcpResponseChunks)
})
}) })

View File

@ -27,9 +27,6 @@ import {
ChatCompletionTool ChatCompletionTool
} from 'openai/resources' } from 'openai/resources'
import { CompletionsParams } from '../aiCore/middleware/schemas'
import { confirmSameNameTools, requestToolConfirmation, setToolIdToNameMapping } from './userConfirmation'
const logger = loggerService.withContext('Utils:MCPTools') const logger = loggerService.withContext('Utils:MCPTools')
const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install' const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
@ -534,211 +531,6 @@ export function parseToolUse(content: string, mcpTools: MCPTool[], startIdx: num
return tools return tools
} }
// Overload: accepts an already-parsed list of tool responses.
export async function parseAndCallTools<R>(
  tools: MCPToolResponse[],
  allToolResponses: MCPToolResponse[],
  onChunk: CompletionsParams['onChunk'],
  convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
  model: Model,
  mcpTools?: MCPTool[],
  abortSignal?: AbortSignal,
  topicId?: CompletionsParams['topicId']
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
// Overload: accepts raw assistant text; tool-use requests are parsed out of it.
export async function parseAndCallTools<R>(
  content: string,
  allToolResponses: MCPToolResponse[],
  onChunk: CompletionsParams['onChunk'],
  convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
  model: Model,
  mcpTools?: MCPTool[],
  abortSignal?: AbortSignal,
  topicId?: CompletionsParams['topicId']
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
/**
 * Confirms and runs a batch of MCP tool invocations, streaming per-tool
 * status transitions (pending -> invoking -> done / cancelled) to the UI
 * through `upsertMCPToolResponse`/`onChunk` while they happen.
 *
 * Execution is concurrent across tools; `allToolResponses` is the shared,
 * mutated status store. Successful responses are mapped to messages with
 * `convertToMessage`; tools whose conversion returns `undefined` are not
 * counted as confirmed.
 *
 * @param content           Tool responses, or raw model text to scan via `parseToolUse`.
 * @param allToolResponses  Running list of all tool responses (mutated in place).
 * @param onChunk           Callback that receives status and image chunks.
 * @param convertToMessage  Converts a completed call into an SDK message for the follow-up request.
 * @param model             Model attribution; `model.name` is passed to `callMCPTool`.
 * @param mcpTools          Available tools, used only for text parsing.
 * @param abortSignal       Cancels an outstanding confirmation prompt.
 * @param topicId           Optional topic id forwarded to `callMCPTool`.
 * @returns Converted tool results plus the responses that were confirmed and converted.
 */
export async function parseAndCallTools<R>(
  content: string | MCPToolResponse[],
  allToolResponses: MCPToolResponse[],
  onChunk: CompletionsParams['onChunk'],
  convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
  model: Model,
  mcpTools?: MCPTool[],
  abortSignal?: AbortSignal,
  topicId?: CompletionsParams['topicId']
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }> {
  const toolResults: R[] = []
  let curToolResponses: MCPToolResponse[] = []
  if (Array.isArray(content)) {
    curToolResponses = content
  } else {
    // process tool use: parse tool invocations embedded in the model's text output
    curToolResponses = parseToolUse(content, mcpTools || [], 0)
  }
  // Short-circuit when no tool calls were requested.
  if (!curToolResponses || curToolResponses.length === 0) {
    return { toolResults, confirmedToolResponses: [] }
  }
  // Announce every tool as 'pending' before any confirmation begins.
  for (const toolResponse of curToolResponses) {
    upsertMCPToolResponse(
      allToolResponses,
      {
        ...toolResponse,
        status: 'pending'
      },
      onChunk!
    )
  }
  // Create a confirmation promise per tool and process each confirmation immediately.
  const confirmedTools: MCPToolResponse[] = []
  const pendingPromises: Promise<void>[] = []
  curToolResponses.forEach((toolResponse) => {
    const server = getMcpServerByTool(toolResponse.tool)
    const isAutoApproveEnabled = isToolAutoApproved(toolResponse.tool, server)
    let confirmationPromise: Promise<boolean>
    if (isAutoApproveEnabled) {
      // No prompt needed for auto-approved tools.
      confirmationPromise = Promise.resolve(true)
    } else {
      setToolIdToNameMapping(toolResponse.id, toolResponse.tool.name)
      confirmationPromise = requestToolConfirmation(toolResponse.id, abortSignal).then((confirmed) => {
        if (confirmed && server) {
          // Auto-confirm the other pending tools that share this tool's name
          confirmSameNameTools(toolResponse.tool.name)
        }
        return confirmed
      })
    }
    const processingPromise = confirmationPromise
      .then(async (confirmed) => {
        if (confirmed) {
          // Immediately update to the 'invoking' state
          upsertMCPToolResponse(
            allToolResponses,
            {
              ...toolResponse,
              status: 'invoking'
            },
            onChunk!
          )
          // Perform the actual tool call
          try {
            const images: string[] = []
            const toolCallResponse = await callMCPTool(toolResponse, topicId, model.name)
            // Immediately update to the 'done' state, attaching the response
            upsertMCPToolResponse(
              allToolResponses,
              {
                ...toolResponse,
                status: 'done',
                response: toolCallResponse
              },
              onChunk!
            )
            // Gather any image content as base64 data URLs
            for (const content of toolCallResponse.content) {
              if (content.type === 'image' && content.data) {
                images.push(`data:${content.mimeType};base64,${content.data}`)
              }
            }
            if (images.length) {
              onChunk?.({
                type: ChunkType.IMAGE_CREATED
              })
              onChunk?.({
                type: ChunkType.IMAGE_COMPLETE,
                image: {
                  type: 'base64',
                  images: images
                }
              })
            }
            // Convert the response into a message and record it
            const convertedMessage = convertToMessage(toolResponse, toolCallResponse, model)
            if (convertedMessage) {
              confirmedTools.push(toolResponse)
              toolResults.push(convertedMessage)
            }
          } catch (error) {
            logger.error(`Error executing tool ${toolResponse.id}:`, error)
            // Update to the error state ('done' with an isError response)
            upsertMCPToolResponse(
              allToolResponses,
              {
                ...toolResponse,
                status: 'done',
                response: {
                  isError: true,
                  content: [
                    {
                      type: 'text',
                      text: `Error executing tool: ${error instanceof Error ? error.message : 'Unknown error'}`
                    }
                  ]
                }
              },
              onChunk!
            )
          }
        } else {
          // The user declined: immediately update to the 'cancelled' state
          upsertMCPToolResponse(
            allToolResponses,
            {
              ...toolResponse,
              status: 'cancelled',
              response: {
                isError: false,
                content: [
                  {
                    type: 'text',
                    text: 'Tool call cancelled by user.'
                  }
                ]
              }
            },
            onChunk!
          )
        }
      })
      .catch((error) => {
        logger.error(`Error waiting for tool confirmation ${toolResponse.id}:`, error)
        // Confirmation flow failed (e.g. aborted): update to 'cancelled' with an error payload
        upsertMCPToolResponse(
          allToolResponses,
          {
            ...toolResponse,
            status: 'cancelled',
            response: {
              isError: true,
              content: [
                {
                  type: 'text',
                  text: `Error in confirmation process: ${error instanceof Error ? error.message : 'Unknown error'}`
                }
              ]
            }
          },
          onChunk!
        )
      })
    pendingPromises.push(processingPromise)
  })
  // Wait for every tool to finish (individual statuses were already streamed in real time)
  await Promise.all(pendingPromises)
  return { toolResults, confirmedToolResponses: confirmedTools }
}
export function mcpToolCallResponseToOpenAICompatibleMessage( export function mcpToolCallResponseToOpenAICompatibleMessage(
mcpToolResponse: MCPToolResponse, mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse, resp: MCPCallToolResponse,