diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts index 5495bb0d36..54242d12e9 100644 --- a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts +++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts @@ -135,11 +135,11 @@ export class AiSdkToChunkAdapter { // === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) === - case 'tool-input-start': - case 'tool-input-delta': - case 'tool-input-end': - this.toolCallHandler.handleToolCallCreated(chunk) - break + // case 'tool-input-start': + // case 'tool-input-delta': + // case 'tool-input-end': + // this.toolCallHandler.handleToolCallCreated(chunk) + // break // case 'tool-input-delta': // this.toolCallHandler.handleToolCallCreated(chunk) diff --git a/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts b/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts index 4ddc9dd927..7467685923 100644 --- a/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts +++ b/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts @@ -208,14 +208,14 @@ export class ToolCallChunkHandler { id: toolCallId, tool: tool, arguments: args, - status: 'invoking', + status: 'pending', toolCallId: toolCallId } // 调用 onChunk if (this.onChunk) { this.onChunk({ - type: ChunkType.MCP_TOOL_IN_PROGRESS, + type: ChunkType.MCP_TOOL_PENDING, responses: [toolResponse] }) } diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index 1fb283d317..ec0ef75f71 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -236,7 +236,7 @@ export default class ModernAiProvider { const streamResult = await executor.streamText( modelId, - params, + { ...params, experimental_context: { onChunk: config.onChunk } }, middlewares.length > 0 ? { middlewares } : undefined ) diff --git a/src/renderer/src/aiCore/utils/mcp.ts b/src/renderer/src/aiCore/utils/mcp.ts index 91099fe282..4ae56d722e 100644 --- a/src/renderer/src/aiCore/utils/mcp.ts +++ b/src/renderer/src/aiCore/utils/mcp.ts @@ -1,9 +1,10 @@ -import { Tool } from '@cherrystudio/ai-core' import { loggerService } from '@logger' // import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types' import { MCPTool, MCPToolResponse } from '@renderer/types' -import { callMCPTool } from '@renderer/utils/mcp-tools' -import { jsonSchema, tool } from 'ai' +import { Chunk, ChunkType } from '@renderer/types/chunk' +import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools' +import { requestToolConfirmation } from '@renderer/utils/userConfirmation' +import { jsonSchema, type Tool, tool } from 'ai' import { JSONSchema7 } from 'json-schema' const logger = loggerService.withContext('MCP-utils') @@ -31,18 +32,52 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record { + execute: async (params, { toolCallId, experimental_context }) => { + const { onChunk } = experimental_context as { onChunk: (chunk: Chunk) => void } // 创建适配的 MCPToolResponse 对象 const toolResponse: MCPToolResponse = { - id: `tool_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`, + id: toolCallId, tool: mcpTool, arguments: params, - status: 'invoking', - toolCallId: `call_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` + status: 'pending', + toolCallId } try { - // 复用现有的 callMCPTool 函数 + // 检查是否启用自动批准 + const server = getMcpServerByTool(mcpTool) + const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server) + + let confirmed = true + if (!isAutoApproveEnabled) { + // 请求用户确认 + logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`) + confirmed = await requestToolConfirmation(toolResponse.id) + } + + if (!confirmed) { + // 用户拒绝执行工具 + logger.debug(`User cancelled tool execution: ${mcpTool.name}`) + return { + content: [ + { + type: 'text', + text: `User declined to execute tool "${mcpTool.name}".` + } + ], + isError: false + } + } + + // 用户确认或自动批准,执行工具 + toolResponse.status = 'invoking' + logger.debug(`Executing tool: ${mcpTool.name}`) + + onChunk({ + type: ChunkType.MCP_TOOL_IN_PROGRESS, + responses: [toolResponse] + }) + const result = await callMCPTool(toolResponse) // 返回结果,AI SDK 会处理序列化