feat: implement tool call progress handling and status updates (#7303)

* feat: implement tool call progress handling and status updates

- Update MCP tool response handling to include 'pending' and 'cancelled' statuses.
- Introduce new IPC channel for progress updates.
- Enhance UI components to reflect tool call statuses, including pending and cancelled states.
- Add localization for new status messages in multiple languages.
- Refactor message handling logic to accommodate new tool response types.

* fix: adjust alignment of action tool container in MessageTools component

- Change justify-content from flex-end to flex-start to improve layout consistency.

* feat: enhance tool confirmation handling and update related components

- Introduced a new tool confirmation mechanism in userConfirmation.ts, allowing for individual tool confirmations.
- Updated GeminiAPIClient and OpenAIResponseAPIClient to include tool configuration options.
- Refactored MessageTools component to utilize new confirmation functions and improved styling.
- Enhanced mcp-tools.ts to manage tool invocation and confirmation processes more effectively, ensuring real-time status updates.

* refactor(McpToolChunkMiddleware): enhance tool execution handling and confirmation tracking

- Updated createToolHandlingTransform to manage confirmed tool calls and results more effectively.
- Refactored executeToolCalls and executeToolUseResponses to return both tool results and confirmed tool calls.
- Adjusted buildParamsWithToolResults to utilize confirmed tool calls for building new request messages.
- Improved error handling in messageThunk for tool call status updates, ensuring accurate block ID mapping.

* feat(McpToolChunkMiddleware, ToolUseExtractionMiddleware, mcp-tools, userConfirmation): enhance tool execution and confirmation handling

- Updated McpToolChunkMiddleware to execute tool calls and responses asynchronously, improving performance and response handling.
- Enhanced ToolUseExtractionMiddleware to generate unique tool IDs for better tracking.
- Modified parseToolUse function to accept a starting index for tool extraction.
- Improved user confirmation handling with abort signal support to manage tool action confirmations more effectively.
- Updated SYSTEM_PROMPT to clarify the use of multiple tools per message.

* fix(tagExtraction): update test expectations for tag extraction results

- Adjusted expected length of results from 7 to 9 to reflect changes in tag extraction logic.
- Modified content assertions for specific tag contents to ensure accurate validation of extracted tags.

* refactor(GeminiAPIClient, OpenAIResponseAPIClient): remove unused function calling configurations

- Removed the unused FunctionCallingConfigMode from GeminiAPIClient to streamline the code.
- Eliminated the parallel_tool_calls property from OpenAIResponseAPIClient, simplifying the tool call configuration.

* feat(McpToolChunkMiddleware): enhance LLM response handling and tool call confirmation

- Added notification to UI for new LLM response processing before recursive calls in createToolHandlingTransform.
- Improved tool call confirmation logic in executeToolCalls to match tool IDs more accurately, enhancing response validation.

* refactor(McpToolChunkMiddleware, ToolUseExtractionMiddleware, messageThunk): remove unnecessary console logs

- Eliminated redundant console log statements in McpToolChunkMiddleware, ToolUseExtractionMiddleware, and messageThunk to clean up the code and improve performance.
- Focused on enhancing readability and maintainability by reducing clutter in the logging output.

* refactor(McpToolChunkMiddleware): remove redundant logging statements

- Eliminated unnecessary logging in createToolHandlingTransform to streamline the code and enhance readability.
- Focused on reducing clutter in the logging output while maintaining error handling functionality.

* feat: enhance action button functionality with cancel and confirm options

* refactor(AbortHandlerMiddleware, McpToolChunkMiddleware, ToolUseExtractionMiddleware, messageThunk): improve error handling and code clarity

- Updated AbortHandlerMiddleware to skip abort status checks if an error chunk is received, enhancing error handling logic.
- Replaced console.error with Logger.error in McpToolChunkMiddleware for consistent logging practices.
- Refined ToolUseExtractionMiddleware to improve tool use extraction logic and ensure proper handling of tool_use tags.
- Enhanced messageThunk to include initialPlaceholderBlockId in block ID checks, improving error state management.

* refactor(ToolUseExtractionMiddleware): enhance tool use parsing logic with counter

- Introduced a toolCounter to track the number of tool use responses processed.
- Updated parseToolUse function calls to include the toolCounter, improving the extraction logic and ensuring accurate response handling.

* feat(McpService, IpcChannel, MessageTools): implement tool abort functionality

- Added Mcp_AbortTool channel to handle tool abortion requests.
- Implemented abortTool method in McpService to manage active tool calls and provide logging.
- Updated MessageTools component to include an abort button for ongoing tool calls, enhancing user control.
- Modified API calls to support optional callId for better tracking of tool executions.
- Added localization strings for tool abort messages in multiple languages.

---------

Co-authored-by: Vaayne <liu.vaayne@gmail.com>
This commit is contained in:
SuYao 2025-07-08 17:17:58 +08:00 committed by GitHub
parent 4f7ca3ede8
commit fba6c1642d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 1419 additions and 301 deletions

View File

@ -107,7 +107,7 @@
"@langchain/community": "^0.3.36",
"@langchain/ollama": "^0.2.1",
"@mistralai/mistralai": "^1.6.0",
"@modelcontextprotocol/sdk": "^1.11.4",
"@modelcontextprotocol/sdk": "^1.12.3",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@playwright/test": "^1.52.0",

View File

@ -74,6 +74,8 @@ export enum IpcChannel {
Mcp_ServersChanged = 'mcp:servers-changed',
Mcp_ServersUpdated = 'mcp:servers-updated',
Mcp_CheckConnectivity = 'mcp:check-connectivity',
Mcp_SetProgress = 'mcp:set-progress',
Mcp_AbortTool = 'mcp:abort-tool',
// Python
Python_Execute = 'python:execute',

View File

@ -501,6 +501,10 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource)
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo)
ipcMain.handle(IpcChannel.Mcp_CheckConnectivity, mcpService.checkMcpConnectivity)
ipcMain.handle(IpcChannel.Mcp_AbortTool, mcpService.abortTool)
ipcMain.handle(IpcChannel.Mcp_SetProgress, (_, progress: number) => {
mainWindow.webContents.send('mcp-progress', progress)
})
// Register Python execution handler
ipcMain.handle(

View File

@ -28,6 +28,7 @@ import { app } from 'electron'
import Logger from 'electron-log'
import { EventEmitter } from 'events'
import { memoize } from 'lodash'
import { v4 as uuidv4 } from 'uuid'
import { CacheService } from './CacheService'
import { CallBackServer } from './mcp/oauth/callback'
@ -71,6 +72,7 @@ function withCache<T extends unknown[], R>(
class McpService {
private clients: Map<string, Client> = new Map()
private pendingClients: Map<string, Promise<Client>> = new Map()
private activeToolCalls: Map<string, AbortController> = new Map()
constructor() {
this.initClient = this.initClient.bind(this)
@ -84,6 +86,7 @@ class McpService {
this.removeServer = this.removeServer.bind(this)
this.restartServer = this.restartServer.bind(this)
this.stopServer = this.stopServer.bind(this)
this.abortTool = this.abortTool.bind(this)
this.cleanup = this.cleanup.bind(this)
}
@ -455,10 +458,14 @@ class McpService {
*/
public async callTool(
_: Electron.IpcMainInvokeEvent,
{ server, name, args }: { server: MCPServer; name: string; args: any }
{ server, name, args, callId }: { server: MCPServer; name: string; args: any; callId?: string }
): Promise<MCPCallToolResponse> {
const toolCallId = callId || uuidv4()
const abortController = new AbortController()
this.activeToolCalls.set(toolCallId, abortController)
try {
Logger.info('[MCP] Calling:', server.name, name, args)
Logger.info('[MCP] Calling:', server.name, name, args, 'callId:', toolCallId)
if (typeof args === 'string') {
try {
args = JSON.parse(args)
@ -468,12 +475,19 @@ class McpService {
}
const client = await this.initClient(server)
const result = await client.callTool({ name, arguments: args }, undefined, {
timeout: server.timeout ? server.timeout * 1000 : 60000 // Default timeout of 1 minute
onprogress: (process) => {
console.log('[MCP] Progress:', process.progress / (process.total || 1))
window.api.mcp.setProgress(process.progress / (process.total || 1))
},
timeout: server.timeout ? server.timeout * 1000 : 60000, // Default timeout of 1 minute
signal: this.activeToolCalls.get(toolCallId)?.signal
})
return result as MCPCallToolResponse
} catch (error) {
Logger.error(`[MCP] Error calling tool ${name} on ${server.name}:`, error)
throw error
} finally {
this.activeToolCalls.delete(toolCallId)
}
}
@ -664,6 +678,20 @@ class McpService {
delete env.http_proxy
delete env.https_proxy
}
// 实现 abortTool 方法
public async abortTool(_: Electron.IpcMainInvokeEvent, callId: string) {
const activeToolCall = this.activeToolCalls.get(callId)
if (activeToolCall) {
activeToolCall.abort()
this.activeToolCalls.delete(callId)
Logger.info(`[MCP] Aborted tool call: ${callId}`)
return true
} else {
Logger.warn(`[MCP] No active tool call found for callId: ${callId}`)
return false
}
}
}
export default new McpService()

View File

@ -228,8 +228,8 @@ const api = {
restartServer: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_RestartServer, server),
stopServer: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_StopServer, server),
listTools: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_ListTools, server),
callTool: ({ server, name, args }: { server: MCPServer; name: string; args: any }) =>
ipcRenderer.invoke(IpcChannel.Mcp_CallTool, { server, name, args }),
callTool: ({ server, name, args, callId }: { server: MCPServer; name: string; args: any; callId?: string }) =>
ipcRenderer.invoke(IpcChannel.Mcp_CallTool, { server, name, args, callId }),
listPrompts: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_ListPrompts, server),
getPrompt: ({ server, name, args }: { server: MCPServer; name: string; args?: Record<string, any> }) =>
ipcRenderer.invoke(IpcChannel.Mcp_GetPrompt, { server, name, args }),
@ -237,7 +237,9 @@ const api = {
getResource: ({ server, uri }: { server: MCPServer; uri: string }) =>
ipcRenderer.invoke(IpcChannel.Mcp_GetResource, { server, uri }),
getInstallInfo: () => ipcRenderer.invoke(IpcChannel.Mcp_GetInstallInfo),
checkMcpConnectivity: (server: any) => ipcRenderer.invoke(IpcChannel.Mcp_CheckConnectivity, server)
checkMcpConnectivity: (server: any) => ipcRenderer.invoke(IpcChannel.Mcp_CheckConnectivity, server),
abortTool: (callId: string) => ipcRenderer.invoke(IpcChannel.Mcp_AbortTool, callId),
setProgress: (progress: number) => ipcRenderer.invoke(IpcChannel.Mcp_SetProgress, progress)
},
python: {
execute: (script: string, context?: Record<string, any>, timeout?: number) =>

View File

@ -67,7 +67,12 @@ export const AbortHandlerMiddleware: CompletionsMiddleware =
const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
new TransformStream<Chunk, Chunk | ErrorChunk>({
transform(chunk, controller) {
// 检查 abort 状态
// 如果已经收到错误块,不再检查 abort 状态
if (chunk.type === ChunkType.ERROR) {
controller.enqueue(chunk)
return
}
if (abortSignal?.aborted) {
// 转换为 ErrorChunk
const errorChunk: ErrorChunk = {

View File

@ -136,7 +136,6 @@ function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: Generi
Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
}
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
Logger.debug(`[${MIDDLEWARE_NAME}] LLM_RESPONSE_COMPLETE chunk received:`, ctx._internal)
// 从LLM_RESPONSE_COMPLETE chunk中提取usage数据
if (chunk.response?.usage) {
accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)

View File

@ -89,6 +89,11 @@ function createToolHandlingTransform(
let hasToolUseResponses = false
let streamEnded = false
// 存储已执行的工具结果
const executedToolResults: SdkMessageParam[] = []
const executedToolCalls: SdkToolCall[] = []
const executionPromises: Promise<void>[] = []
return new TransformStream({
async transform(chunk: GenericChunk, controller) {
try {
@ -98,22 +103,64 @@ function createToolHandlingTransform(
// 1. 处理Function Call方式的工具调用
if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
toolCalls.push(...createdChunk.tool_calls)
hasToolCalls = true
for (const toolCall of createdChunk.tool_calls) {
toolCalls.push(toolCall)
const executionPromise = (async () => {
try {
const result = await executeToolCalls(
ctx,
[toolCall],
mcpTools,
allToolResponses,
currentParams.onChunk,
currentParams.assistant.model!
)
// 缓存执行结果
executedToolResults.push(...result.toolResults)
executedToolCalls.push(...result.confirmedToolCalls)
} catch (error) {
console.error(`🔧 [${MIDDLEWARE_NAME}] Error executing tool call asynchronously:`, error)
}
})()
executionPromises.push(executionPromise)
}
}
// 2. 处理Tool Use方式的工具调用
if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
toolUseResponses.push(...createdChunk.tool_use_responses)
hasToolUseResponses = true
for (const toolUseResponse of createdChunk.tool_use_responses) {
toolUseResponses.push(toolUseResponse)
const executionPromise = (async () => {
try {
const result = await executeToolUseResponses(
ctx,
[toolUseResponse], // 单个执行
mcpTools,
allToolResponses,
currentParams.onChunk,
currentParams.assistant.model!
)
// 缓存执行结果
executedToolResults.push(...result.toolResults)
} catch (error) {
console.error(`🔧 [${MIDDLEWARE_NAME}] Error executing tool use response asynchronously:`, error)
// 错误时不影响其他工具的执行
}
})()
executionPromises.push(executionPromise)
}
}
// 不转发MCP工具进展chunks避免重复处理
return
} else {
controller.enqueue(chunk)
}
// 转发其他所有chunk
controller.enqueue(chunk)
} catch (error) {
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
controller.error(error)
@ -121,43 +168,33 @@ function createToolHandlingTransform(
},
async flush(controller) {
const shouldExecuteToolCalls = hasToolCalls && toolCalls.length > 0
const shouldExecuteToolUseResponses = hasToolUseResponses && toolUseResponses.length > 0
if (!streamEnded && (shouldExecuteToolCalls || shouldExecuteToolUseResponses)) {
// 在流结束时等待所有异步工具执行完成,然后进行递归调用
if (!streamEnded && (hasToolCalls || hasToolUseResponses)) {
streamEnded = true
try {
let toolResult: SdkMessageParam[] = []
if (shouldExecuteToolCalls) {
toolResult = await executeToolCalls(
ctx,
toolCalls,
mcpTools,
allToolResponses,
currentParams.onChunk,
currentParams.assistant.model!
)
} else if (shouldExecuteToolUseResponses) {
toolResult = await executeToolUseResponses(
ctx,
toolUseResponses,
mcpTools,
allToolResponses,
currentParams.onChunk,
currentParams.assistant.model!
)
}
if (toolResult.length > 0) {
await Promise.all(executionPromises)
if (executedToolResults.length > 0) {
const output = ctx._internal.toolProcessingState?.output
const newParams = buildParamsWithToolResults(
ctx,
currentParams,
output,
executedToolResults,
executedToolCalls
)
// 在递归调用前通知UI开始新的LLM响应处理
if (currentParams.onChunk) {
currentParams.onChunk({
type: ChunkType.LLM_RESPONSE_CREATED
})
}
const newParams = buildParamsWithToolResults(ctx, currentParams, output, toolResult, toolCalls)
await executeWithToolHandling(newParams, depth + 1)
}
} catch (error) {
console.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
controller.error(error)
} finally {
hasToolCalls = false
@ -178,8 +215,7 @@ async function executeToolCalls(
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
model: Model
): Promise<SdkMessageParam[]> {
// 转换为MCPToolResponse格式
): Promise<{ toolResults: SdkMessageParam[]; confirmedToolCalls: SdkToolCall[] }> {
const mcpToolResponses: ToolCallResponse[] = toolCalls
.map((toolCall) => {
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
@ -192,11 +228,11 @@ async function executeToolCalls(
if (mcpToolResponses.length === 0) {
console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`)
return []
return { toolResults: [], confirmedToolCalls: [] }
}
// 使用现有的parseAndCallTools函数执行工具
const toolResults = await parseAndCallTools(
const { toolResults, confirmedToolResponses } = await parseAndCallTools(
mcpToolResponses,
allToolResponses,
onChunk,
@ -204,10 +240,24 @@ async function executeToolCalls(
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
},
model,
mcpTools
mcpTools,
ctx._internal?.flowControl?.abortSignal
)
return toolResults
// 找出已确认工具对应的原始toolCalls
const confirmedToolCalls = toolCalls.filter((toolCall) => {
return confirmedToolResponses.find((confirmed) => {
// 根据不同的ID字段匹配原始toolCall
return (
('name' in toolCall &&
(toolCall.name?.includes(confirmed.tool.name) || toolCall.name?.includes(confirmed.tool.id))) ||
confirmed.tool.name === toolCall.id ||
confirmed.tool.id === toolCall.id
)
})
})
return { toolResults, confirmedToolCalls }
}
/**
@ -221,9 +271,9 @@ async function executeToolUseResponses(
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
model: Model
): Promise<SdkMessageParam[]> {
): Promise<{ toolResults: SdkMessageParam[] }> {
// 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse
const toolResults = await parseAndCallTools(
const { toolResults } = await parseAndCallTools(
toolUseResponses,
allToolResponses,
onChunk,
@ -231,10 +281,11 @@ async function executeToolUseResponses(
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
},
model,
mcpTools
mcpTools,
ctx._internal?.flowControl?.abortSignal
)
return toolResults
return { toolResults }
}
/**
@ -245,7 +296,7 @@ function buildParamsWithToolResults(
currentParams: CompletionsParams,
output: SdkRawOutput | string | undefined,
toolResults: SdkMessageParam[],
toolCalls: SdkToolCall[]
confirmedToolCalls: SdkToolCall[]
): CompletionsParams {
// 获取当前已经转换好的reqMessages如果没有则使用原始messages
const currentReqMessages = getCurrentReqMessages(ctx)
@ -253,7 +304,7 @@ function buildParamsWithToolResults(
const apiClient = ctx.apiClientInstance
// 从回复中构建助手消息
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, confirmedToolCalls)
if (output && ctx._internal.toolProcessingState) {
ctx._internal.toolProcessingState.output = undefined

View File

@ -22,7 +22,8 @@ const TOOL_USE_TAG_CONFIG: TagConfig = {
* 1. <tool_use></tool_use>
* 2. ToolUseResponse
* 3. MCP_TOOL_CREATED chunk McpToolChunkMiddleware
* 4. 使
* 4. tool_use
* 5. 使
*
* McpToolChunkMiddleware
*/
@ -32,13 +33,10 @@ export const ToolUseExtractionMiddleware: CompletionsMiddleware =
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const mcpTools = params.mcpTools || []
// 如果没有工具,直接调用下一个中间件
if (!mcpTools || mcpTools.length === 0) return next(ctx, params)
// 调用下游中间件
const result = await next(ctx, params)
// 响应后处理:处理工具使用标签提取
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
@ -60,7 +58,9 @@ function createToolUseExtractionTransform(
_ctx: CompletionsContext,
mcpTools: MCPTool[]
): TransformStream<GenericChunk, GenericChunk> {
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
const toolUseExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
let hasAnyToolUse = false
let toolCounter = 0
return new TransformStream({
async transform(chunk: GenericChunk, controller) {
@ -68,30 +68,37 @@ function createToolUseExtractionTransform(
// 处理文本内容,检测工具使用标签
if (chunk.type === ChunkType.TEXT_DELTA) {
const textChunk = chunk as TextDeltaChunk
const extractionResults = tagExtractor.processText(textChunk.text)
for (const result of extractionResults) {
// 处理 tool_use 标签
const toolUseResults = toolUseExtractor.processText(textChunk.text)
for (const result of toolUseResults) {
if (result.complete && result.tagContentExtracted) {
// 提取到完整的工具使用内容,解析并转换为 SDK ToolCall 格式
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools)
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools, toolCounter)
toolCounter += toolUseResponses.length
if (toolUseResponses.length > 0) {
// 生成 MCP_TOOL_CREATED chunk,复用现有的处理流程
// 生成 MCP_TOOL_CREATED chunk
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
type: ChunkType.MCP_TOOL_CREATED,
tool_use_responses: toolUseResponses
}
controller.enqueue(mcpToolCreatedChunk)
// 标记已有工具调用
hasAnyToolUse = true
}
} else if (!result.isTagContent && result.content) {
// 发送标签外的正常文本内容
const cleanTextChunk: TextDeltaChunk = {
...textChunk,
text: result.content
if (!hasAnyToolUse) {
const cleanTextChunk: TextDeltaChunk = {
...textChunk,
text: result.content
}
controller.enqueue(cleanTextChunk)
}
controller.enqueue(cleanTextChunk)
}
// 注意标签内的内容不会作为TEXT_DELTA转发,避免重复显示
// tool_use 标签内的内容不转发,避免重复显示
}
return
}
@ -105,16 +112,17 @@ function createToolUseExtractionTransform(
},
async flush(controller) {
// 检查是否有未完成的标签内容
const finalResult = tagExtractor.finalize()
if (finalResult && finalResult.tagContentExtracted) {
const toolUseResponses = parseToolUse(finalResult.tagContentExtracted, mcpTools)
// 检查是否有未完成的 tool_use 标签内容
const finalToolUseResult = toolUseExtractor.finalize()
if (finalToolUseResult && finalToolUseResult.tagContentExtracted) {
const toolUseResponses = parseToolUse(finalToolUseResult.tagContentExtracted, mcpTools, toolCounter)
if (toolUseResponses.length > 0) {
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
type: ChunkType.MCP_TOOL_CREATED,
tool_use_responses: toolUseResponses
}
controller.enqueue(mcpToolCreatedChunk)
hasAnyToolUse = true
}
}
}

View File

@ -706,8 +706,12 @@
"success.yuque.export": "Successfully exported to Yuque",
"switch.disabled": "Please wait for the current reply to complete",
"tools": {
"pending": "Pending",
"cancelled": "Cancelled",
"completed": "Completed",
"invoking": "Invoking",
"aborted": "Tool call aborted",
"abort_failed": "Tool call abort failed",
"error": "Error occurred",
"raw": "Raw",
"preview": "Preview"

View File

@ -706,9 +706,13 @@
"tools": {
"completed": "完了",
"invoking": "呼び出し中",
"aborted": "ツール呼び出し中断",
"abort_failed": "ツール呼び出し中断失敗",
"error": "エラーが発生しました",
"raw": "生データ",
"preview": "プレビュー"
"preview": "プレビュー",
"pending": "保留中",
"cancelled": "キャンセル"
},
"topic.added": "新しいトピックが追加されました",
"upgrade.success.button": "再起動",

View File

@ -705,11 +705,15 @@
"success.yuque.export": "Успешный экспорт в Yuque",
"switch.disabled": "Пожалуйста, дождитесь завершения текущего ответа",
"tools": {
"aborted": "Вызов инструмента прерван",
"abort_failed": "Вызов инструмента прерван",
"completed": "Завершено",
"invoking": "Вызов",
"error": "Произошла ошибка",
"raw": "Исходный",
"preview": "Предпросмотр"
"preview": "Предпросмотр",
"pending": "Ожидание",
"cancelled": "Отменено"
},
"topic.added": "Новый топик добавлен",
"upgrade.success.button": "Перезапустить",

View File

@ -706,8 +706,12 @@
"success.yuque.export": "成功导出到语雀",
"switch.disabled": "请等待当前回复完成后操作",
"tools": {
"pending": "等待中",
"cancelled": "已取消",
"completed": "已完成",
"invoking": "调用中",
"aborted": "工具调用已中断",
"abort_failed": "工具调用中断失败",
"error": "发生错误",
"raw": "原始",
"preview": "预览"

View File

@ -706,11 +706,15 @@
"success.yuque.export": "成功匯出到語雀",
"switch.disabled": "請等待當前回覆完成",
"tools": {
"aborted": "工具調用已中斷",
"abort_failed": "工具調用中斷失敗",
"completed": "已完成",
"invoking": "調用中",
"error": "發生錯誤",
"raw": "原始碼",
"preview": "預覽"
"preview": "預覽",
"pending": "等待中",
"cancelled": "已取消"
},
"topic.added": "新話題已新增",
"upgrade.success.button": "重新啟動",

View File

@ -1,8 +1,12 @@
import { CheckOutlined, ExpandOutlined, LoadingOutlined, WarningOutlined } from '@ant-design/icons'
import { CheckOutlined, CloseOutlined, LoadingOutlined, WarningOutlined } from '@ant-design/icons'
import { useCodeStyle } from '@renderer/context/CodeStyleProvider'
import { useSettings } from '@renderer/hooks/useSettings'
import type { ToolMessageBlock } from '@renderer/types/newMessage'
import { Collapse, message as antdMessage, Modal, Tabs, Tooltip } from 'antd'
import { cancelToolAction, confirmToolAction } from '@renderer/utils/userConfirmation'
import { Collapse, message as antdMessage, Tooltip } from 'antd'
import { message } from 'antd'
import Logger from 'electron-log/renderer'
import { PauseCircle } from 'lucide-react'
import { FC, memo, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -14,12 +18,24 @@ interface Props {
const MessageTools: FC<Props> = ({ block }) => {
const [activeKeys, setActiveKeys] = useState<string[]>([])
const [copiedMap, setCopiedMap] = useState<Record<string, boolean>>({})
const [expandedResponse, setExpandedResponse] = useState<{ content: string; title: string } | null>(null)
const { t } = useTranslation()
const { messageFont, fontSize } = useSettings()
const toolResponse = block.metadata?.rawMcpToolResponse
const { id, tool, status, response } = toolResponse!
const isPending = status === 'pending'
const isInvoking = status === 'invoking'
const isDone = status === 'done'
const argsString = useMemo(() => {
if (toolResponse?.arguments) {
return JSON.stringify(toolResponse.arguments, null, 2)
}
return 'No arguments'
}, [toolResponse])
const resultString = useMemo(() => {
try {
return JSON.stringify(
@ -50,13 +66,34 @@ const MessageTools: FC<Props> = ({ block }) => {
setActiveKeys(Array.isArray(keys) ? keys : [keys])
}
const handleConfirmTool = () => {
confirmToolAction(id)
}
const handleCancelTool = () => {
cancelToolAction(id)
}
const handleAbortTool = async () => {
if (toolResponse?.id) {
try {
const success = await window.api.mcp.abortTool(toolResponse.id)
if (success) {
message.success({ content: t('message.tools.aborted'), key: 'abort-tool' })
} else {
message.error({ content: t('message.tools.abort_failed'), key: 'abort-tool' })
}
} catch (error) {
Logger.error('Failed to abort tool:', error)
message.error({ content: t('message.tools.abort_failed'), key: 'abort-tool' })
}
}
}
// Format tool responses for collapse items
const getCollapseItems = () => {
const items: { key: string; label: React.ReactNode; children: React.ReactNode }[] = []
const { id, tool, status, response } = toolResponse
const isInvoking = status === 'invoking'
const isDone = status === 'done'
const hasError = isDone && response?.isError === true
const hasError = response?.isError === true
const result = {
params: toolResponse.arguments,
response: toolResponse.response
@ -68,34 +105,93 @@ const MessageTools: FC<Props> = ({ block }) => {
<MessageTitleLabel>
<TitleContent>
<ToolName>{tool.name}</ToolName>
<StatusIndicator $isInvoking={isInvoking} $hasError={hasError}>
{isInvoking
? t('message.tools.invoking')
: hasError
? t('message.tools.error')
: t('message.tools.completed')}
{isInvoking && <LoadingOutlined spin style={{ marginLeft: 6 }} />}
{isDone && !hasError && <CheckOutlined style={{ marginLeft: 6 }} />}
{hasError && <WarningOutlined style={{ marginLeft: 6 }} />}
<StatusIndicator status={status} hasError={hasError}>
{(() => {
switch (status) {
case 'pending':
return (
<>
{t('message.tools.pending')}
<LoadingOutlined spin style={{ marginLeft: 6 }} />
</>
)
case 'invoking':
return (
<>
{t('message.tools.invoking')}
<LoadingOutlined spin style={{ marginLeft: 6 }} />
</>
)
case 'cancelled':
return (
<>
{t('message.tools.cancelled')}
<CloseOutlined style={{ marginLeft: 6 }} />
</>
)
case 'done':
if (hasError) {
return (
<>
{t('message.tools.error')}
<WarningOutlined style={{ marginLeft: 6 }} />
</>
)
} else {
return (
<>
{t('message.tools.completed')}
<CheckOutlined style={{ marginLeft: 6 }} />
</>
)
}
default:
return ''
}
})()}
</StatusIndicator>
</TitleContent>
<ActionButtonsContainer>
{isDone && response && (
{isPending && (
<>
<Tooltip title={t('common.expand')} mouseEnterDelay={0.5}>
<Tooltip title={t('common.cancel')} mouseEnterDelay={0.3}>
<ActionButton
className="message-action-button"
onClick={(e) => {
e.stopPropagation()
setExpandedResponse({
content: JSON.stringify(response, null, 2),
title: tool.name
})
handleCancelTool()
}}
aria-label={t('common.expand')}>
<ExpandOutlined />
aria-label={t('common.cancel')}>
<CloseOutlined style={{ fontSize: '14px' }} />
</ActionButton>
</Tooltip>
<Tooltip title={t('common.confirm')} mouseEnterDelay={0.3}>
<ActionButton
className="confirm-button"
onClick={(e) => {
e.stopPropagation()
handleConfirmTool()
}}
aria-label={t('common.confirm')}>
<CheckOutlined style={{ fontSize: '14px' }} />
</ActionButton>
</Tooltip>
</>
)}
{isInvoking && toolResponse?.id && (
<Tooltip title={t('chat.input.pause')} mouseEnterDelay={0.3}>
<ActionButton
className="abort-button"
onClick={(e) => {
e.stopPropagation()
handleAbortTool()
}}
aria-label={t('chat.input.pause')}>
<PauseCircle color="var(--color-error)" size={14} />
</ActionButton>
</Tooltip>
)}
{isDone && response && (
<>
<Tooltip title={t('common.copy')} mouseEnterDelay={0.5}>
<ActionButton
className="message-action-button"
@ -113,98 +209,38 @@ const MessageTools: FC<Props> = ({ block }) => {
</ActionButtonsContainer>
</MessageTitleLabel>
),
children: isDone && result && (
<ToolResponseContainer
style={{
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
fontSize: '12px'
}}>
<CollapsedContent isExpanded={activeKeys.includes(id)} resultString={resultString} />
</ToolResponseContainer>
)
children:
isDone && result ? (
<ToolResponseContainer
style={{
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
fontSize
}}>
<CollapsedContent isExpanded={activeKeys.includes(id)} resultString={resultString} />
</ToolResponseContainer>
) : argsString ? (
<>
<ToolResponseContainer>
<CollapsedContent isExpanded={activeKeys.includes(id)} resultString={argsString} />
</ToolResponseContainer>
</>
) : null
})
return items
}
const renderPreview = (content: string) => {
if (!content) return null
try {
const parsedResult = JSON.parse(content)
switch (parsedResult.content[0]?.type) {
case 'text':
return <PreviewBlock>{parsedResult.content[0].text}</PreviewBlock>
default:
return <PreviewBlock>{content}</PreviewBlock>
}
} catch (e) {
console.error('failed to render the preview of mcp results:', e)
return <PreviewBlock>{content}</PreviewBlock>
}
}
return (
<>
<ToolContainer>
<CollapseContainer
activeKey={activeKeys}
size="small"
onChange={handleCollapseChange}
className="message-tools-container"
items={getCollapseItems()}
expandIcon={({ isActive }) => (
<CollapsibleIcon className={`iconfont ${isActive ? 'icon-chevron-down' : 'icon-chevron-right'}`} />
)}
expandIconPosition="end"
/>
<Modal
title={expandedResponse?.title}
open={!!expandedResponse}
onCancel={() => setExpandedResponse(null)}
footer={null}
width="80%"
centered
transitionName="animation-move-down"
styles={{ body: { maxHeight: '80vh', overflow: 'auto' } }}>
{expandedResponse && (
<ExpandedResponseContainer
style={{
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
fontSize
}}>
<Tabs
tabBarExtraContent={
<ActionButton
className="copy-expanded-button"
onClick={() => {
navigator.clipboard.writeText(
typeof expandedResponse.content === 'string'
? expandedResponse.content
: JSON.stringify(expandedResponse.content, null, 2)
)
antdMessage.success({ content: t('message.copied'), key: 'copy-expanded' })
}}
aria-label={t('common.copy')}>
<i className="iconfont icon-copy"></i>
</ActionButton>
}
items={[
{
key: 'preview',
label: t('message.tools.preview'),
children: <CollapsedContent isExpanded={true} resultString={resultString} />
},
{
key: 'raw',
label: t('message.tools.raw'),
children: renderPreview(expandedResponse.content)
}
]}
/>
</ExpandedResponseContainer>
)}
</Modal>
</>
</ToolContainer>
)
}
@ -230,15 +266,25 @@ const CollapsedContent: FC<{ isExpanded: boolean; resultString: string }> = ({ i
}
const CollapseContainer = styled(Collapse)`
margin-top: 10px;
margin-bottom: 12px;
border-radius: 8px;
border: none;
overflow: hidden;
.ant-collapse-header {
background-color: var(--color-bg-2);
transition: background-color 0.2s;
display: flex;
align-items: center;
.ant-collapse-expand-icon {
height: 100% !important;
}
.ant-collapse-arrow {
height: 28px !important;
svg {
width: 14px;
height: 14px;
}
}
&:hover {
background-color: var(--color-bg-3);
}
@ -249,6 +295,15 @@ const CollapseContainer = styled(Collapse)`
}
`
const ToolContainer = styled.div`
margin-top: 10px;
margin-bottom: 12px;
border: 1px solid var(--color-border);
background-color: var(--color-bg-2);
border-radius: 8px;
overflow: hidden;
`
const MarkdownContainer = styled.div`
& pre {
background: transparent !important;
@ -267,6 +322,7 @@ const MessageTitleLabel = styled.div`
min-height: 26px;
gap: 10px;
padding: 0;
margin-left: 4px;
`
const TitleContent = styled.div`
@ -282,18 +338,27 @@ const ToolName = styled.span`
font-size: 13px;
`
const StatusIndicator = styled.span<{ $isInvoking: boolean; $hasError?: boolean }>`
const StatusIndicator = styled.span<{ status: string; hasError?: boolean }>`
color: ${(props) => {
if (props.$hasError) return 'var(--color-error, #ff4d4f)'
if (props.$isInvoking) return 'var(--color-primary)'
return 'var(--color-success, #52c41a)'
switch (props.status) {
case 'pending':
return 'var(--color-text-2)'
case 'invoking':
return 'var(--color-primary)'
case 'cancelled':
return 'var(--color-error, #ff4d4f)' // Assuming cancelled should also be an error color
case 'done':
return props.hasError ? 'var(--color-error, #ff4d4f)' : 'var(--color-success, #52c41a)'
default:
return 'var(--color-text)'
}
}};
font-size: 11px;
display: flex;
align-items: center;
opacity: 0.85;
border-left: 1px solid var(--color-border);
padding-left: 8px;
padding-left: 12px;
`
const ActionButtonsContainer = styled.div`
@ -307,18 +372,30 @@ const ActionButton = styled.button`
border: none;
color: var(--color-text-2);
cursor: pointer;
padding: 4px 8px;
padding: 4px;
display: flex;
align-items: center;
justify-content: center;
opacity: 0.7;
transition: all 0.2s;
border-radius: 4px;
gap: 4px;
min-width: 28px;
height: 28px;
&:hover {
opacity: 1;
color: var(--color-text);
background-color: var(--color-bg-1);
background-color: var(--color-bg-3);
}
&.confirm-button {
color: var(--color-primary);
&:hover {
background-color: var(--color-primary-bg);
color: var(--color-primary);
}
}
&:focus-visible {
@ -332,12 +409,6 @@ const ActionButton = styled.button`
}
`
const CollapsibleIcon = styled.i`
color: var(--color-text-2);
font-size: 12px;
transition: transform 0.2s;
`
const ToolResponseContainer = styled.div`
border-radius: 0 0 4px 4px;
overflow: auto;
@ -346,35 +417,4 @@ const ToolResponseContainer = styled.div`
position: relative;
`
const PreviewBlock = styled.div`
margin: 0;
white-space: pre-wrap;
word-break: break-word;
color: var(--color-text);
user-select: text;
`
const ExpandedResponseContainer = styled.div`
background: var(--color-bg-1);
border-radius: 8px;
padding: 16px;
position: relative;
.copy-expanded-button {
position: absolute;
top: 10px;
right: 10px;
background-color: var(--color-bg-2);
border-radius: 4px;
z-index: 1;
}
pre {
margin: 0;
white-space: pre-wrap;
word-break: break-word;
color: var(--color-text);
}
`
export default memo(MessageTools)

View File

@ -16,6 +16,7 @@ export interface StreamProcessorCallbacks {
onThinkingChunk?: (text: string, thinking_millsec?: number) => void
onThinkingComplete?: (text: string, thinking_millsec?: number) => void
// A tool call response chunk (from MCP)
onToolCallPending?: (toolResponse: MCPToolResponse) => void
onToolCallInProgress?: (toolResponse: MCPToolResponse) => void
onToolCallComplete?: (toolResponse: MCPToolResponse) => void
// External tool call in progress
@ -69,6 +70,10 @@ export function createStreamProcessor(callbacks: StreamProcessorCallbacks = {})
if (callbacks.onThinkingComplete) callbacks.onThinkingComplete(data.text, data.thinking_millsec)
break
}
case ChunkType.MCP_TOOL_PENDING: {
if (callbacks.onToolCallPending) data.responses.forEach((toolResp) => callbacks.onToolCallPending!(toolResp))
break
}
case ChunkType.MCP_TOOL_IN_PROGRESS: {
if (callbacks.onToolCallInProgress)
data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp))

View File

@ -529,12 +529,13 @@ const fetchAndProcessAssistantResponseImpl = async (
}
thinkingBlockId = null
},
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
onToolCallPending: (toolResponse: MCPToolResponse) => {
if (initialPlaceholderBlockId) {
lastBlockType = MessageBlockType.TOOL
const changes = {
type: MessageBlockType.TOOL,
status: MessageBlockStatus.PROCESSING,
status: MessageBlockStatus.PENDING,
toolName: toolResponse.tool.name,
metadata: { rawMcpToolResponse: toolResponse }
}
toolBlockId = initialPlaceholderBlockId
@ -542,14 +543,37 @@ const fetchAndProcessAssistantResponseImpl = async (
dispatch(updateOneBlock({ id: toolBlockId, changes }))
saveUpdatedBlockToDB(toolBlockId, assistantMsgId, topicId, getState)
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId)
} else if (toolResponse.status === 'invoking') {
} else if (toolResponse.status === 'pending') {
const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, {
toolName: toolResponse.tool.name,
status: MessageBlockStatus.PROCESSING,
status: MessageBlockStatus.PENDING,
metadata: { rawMcpToolResponse: toolResponse }
})
toolBlockId = toolBlock.id
handleBlockTransition(toolBlock, MessageBlockType.TOOL)
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlock.id)
} else {
console.warn(
`[onToolCallPending] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
)
}
},
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
// 根据 toolResponse.id 查找对应的块ID
const targetBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
if (targetBlockId && toolResponse.status === 'invoking') {
const changes = {
status: MessageBlockStatus.PROCESSING,
metadata: { rawMcpToolResponse: toolResponse }
}
dispatch(updateOneBlock({ id: targetBlockId, changes }))
saveUpdatedBlockToDB(targetBlockId, assistantMsgId, topicId, getState)
} else if (!targetBlockId) {
console.warn(
`[onToolCallInProgress] No block ID found for tool ID: ${toolResponse.id}. Available mappings:`,
Array.from(toolCallIdToBlockIdMap.entries())
)
} else {
console.warn(
`[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
@ -559,14 +583,17 @@ const fetchAndProcessAssistantResponseImpl = async (
onToolCallComplete: (toolResponse: MCPToolResponse) => {
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
toolCallIdToBlockIdMap.delete(toolResponse.id)
if (toolResponse.status === 'done' || toolResponse.status === 'error') {
if (toolResponse.status === 'done' || toolResponse.status === 'error' || toolResponse.status === 'cancelled') {
if (!existingBlockId) {
console.error(
`[onToolCallComplete] No existing block found for completed/error tool call ID: ${toolResponse.id}. Cannot update.`
)
return
}
const finalStatus = toolResponse.status === 'done' ? MessageBlockStatus.SUCCESS : MessageBlockStatus.ERROR
const finalStatus =
toolResponse.status === 'done' || toolResponse.status === 'cancelled'
? MessageBlockStatus.SUCCESS
: MessageBlockStatus.ERROR
const changes: Partial<ToolMessageBlock> = {
content: toolResponse.response,
status: finalStatus,
@ -583,6 +610,7 @@ const fetchAndProcessAssistantResponseImpl = async (
`[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
)
}
toolBlockId = null
},
onExternalToolInProgress: async () => {
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
@ -762,7 +790,14 @@ const fetchAndProcessAssistantResponseImpl = async (
})
}
const possibleBlockId =
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
mainTextBlockId ||
thinkingBlockId ||
toolBlockId ||
imageBlockId ||
citationBlockId ||
initialPlaceholderBlockId ||
lastBlockId
if (possibleBlockId) {
// 更改上一个block的状态为ERROR
const changes: Partial<MessageBlock> = {
@ -801,7 +836,13 @@ const fetchAndProcessAssistantResponseImpl = async (
const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg]
const possibleBlockId =
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
mainTextBlockId ||
thinkingBlockId ||
toolBlockId ||
imageBlockId ||
citationBlockId ||
initialPlaceholderBlockId ||
lastBlockId
if (possibleBlockId) {
const changes: Partial<MessageBlock> = {
status: MessageBlockStatus.SUCCESS
@ -1109,7 +1150,6 @@ export const resendMessageThunk =
// 没有相关的助手消息就创建一个或多个
if (userMessageToResend?.mentions?.length) {
console.log('userMessageToResend.mentions', userMessageToResend.mentions)
for (const mention of userMessageToResend.mentions) {
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
askId: userMessageToResend.id,

View File

@ -13,6 +13,7 @@ export enum ChunkType {
KNOWLEDGE_SEARCH_IN_PROGRESS = 'knowledge_search_in_progress',
KNOWLEDGE_SEARCH_COMPLETE = 'knowledge_search_complete',
MCP_TOOL_CREATED = 'mcp_tool_created',
MCP_TOOL_PENDING = 'mcp_tool_pending',
MCP_TOOL_IN_PROGRESS = 'mcp_tool_in_progress',
MCP_TOOL_COMPLETE = 'mcp_tool_complete',
EXTERNEL_TOOL_COMPLETE = 'externel_tool_complete',
@ -260,6 +261,11 @@ export interface MCPToolCreatedChunk {
tool_use_responses?: ToolUseResponse[] // 工具使用响应
}
export interface MCPToolPendingChunk {
type: ChunkType.MCP_TOOL_PENDING
responses: MCPToolResponse[]
}
export interface MCPToolInProgressChunk {
/**
* The type of the chunk
@ -353,6 +359,7 @@ export type Chunk =
| KnowledgeSearchInProgressChunk // 知识库搜索进行中
| KnowledgeSearchCompleteChunk // 知识库搜索完成
| MCPToolCreatedChunk // MCP工具被大模型创建
| MCPToolPendingChunk // MCP工具调用等待中
| MCPToolInProgressChunk // MCP工具调用中
| MCPToolCompleteChunk // MCP工具调用完成
| ExternalToolCompleteChunk // 外部工具调用完成外部工具包含搜索互联网知识库MCP服务器

View File

@ -683,11 +683,13 @@ export interface MCPConfig {
isBunInstalled: boolean
}
export type MCPToolResponseStatus = 'pending' | 'cancelled' | 'invoking' | 'done' | 'error'
interface BaseToolResponse {
id: string // unique id
tool: MCPTool
arguments: Record<string, unknown> | undefined
status: string // 'invoking' | 'done'
status: MCPToolResponseStatus
response?: any
}

View File

@ -0,0 +1,693 @@
import { describe, expect, test } from 'vitest'
import { TagConfig, TagExtractor } from '../tagExtraction'
describe('TagExtractor', () => {
describe('基本标签提取', () => {
test('应该正确提取简单的标签内容', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('<think>Hello World</think>')
expect(results).toHaveLength(2)
expect(results[0]).toEqual({
content: 'Hello World',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: 'Hello World'
})
})
test('应该处理标签前后的普通文本', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('前文<think>思考内容</think>后文')
expect(results).toHaveLength(4)
expect(results[0]).toEqual({
content: '前文',
isTagContent: false,
complete: false
})
expect(results[1]).toEqual({
content: '思考内容',
isTagContent: true,
complete: false
})
expect(results[2]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '思考内容'
})
expect(results[3]).toEqual({
content: '后文',
isTagContent: false,
complete: false
})
})
test('应该处理空标签', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('<think></think>')
expect(results).toHaveLength(0)
})
})
describe('分块处理', () => {
test('应该正确处理分块的标签内容', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
let results = extractor.processText('<think>第一')
expect(results).toHaveLength(1)
expect(results[0]).toEqual({
content: '第一',
isTagContent: true,
complete: false
})
results = extractor.processText('部分内容')
expect(results).toHaveLength(1)
expect(results[0]).toEqual({
content: '部分内容',
isTagContent: true,
complete: false
})
results = extractor.processText('</think>')
expect(results).toHaveLength(1)
expect(results[0]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '第一部分内容'
})
})
test('应该处理分块的开始标签', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
let results = extractor.processText('<thi')
expect(results).toHaveLength(0)
results = extractor.processText('nk>内容</think>')
expect(results).toHaveLength(2)
expect(results[0]).toEqual({
content: '内容',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '内容'
})
})
test('应该处理模拟可读流的分块数据', async () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
// 模拟流式数据块
const streamChunks = [
'这是普通文本',
'<thi',
'nk>这是第一个',
'思考内容',
'</think>',
'中间的一些文本',
'<think>第二',
'个思考内容',
'</thi',
'nk>',
'结束文本'
]
const allResults: any[] = []
// 模拟异步流式处理
for (const chunk of streamChunks) {
await new Promise((resolve) => setTimeout(resolve, 10)) // 模拟异步延迟
const results = extractor.processText(chunk)
allResults.push(...results)
}
// 验证结果
expect(allResults).toHaveLength(9)
// 第一个普通文本
expect(allResults[0]).toEqual({
content: '这是普通文本',
isTagContent: false,
complete: false
})
// 第一个思考标签内容
expect(allResults[1]).toEqual({
content: '这是第一个',
isTagContent: true,
complete: false
})
expect(allResults[2]).toEqual({
content: '思考内容',
isTagContent: true,
complete: false
})
// 第一个完整的标签内容提取
expect(allResults[3]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '这是第一个思考内容'
})
// 中间文本
expect(allResults[4]).toEqual({
content: '中间的一些文本',
isTagContent: false,
complete: false
})
// 第二个思考标签内容
expect(allResults[5]).toEqual({
content: '第二',
isTagContent: true,
complete: false
})
// 第二个完整的标签内容提取和结束文本
expect(allResults[6]).toEqual({
content: '个思考内容',
isTagContent: true,
complete: false
})
expect(allResults[7]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '第二个思考内容'
})
expect(allResults[8]).toEqual({
content: '结束文本',
isTagContent: false,
complete: false
})
})
})
describe('多个标签处理', () => {
test('应该处理连续的多个标签', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('<think>第一个</think><think>第二个</think>')
expect(results).toHaveLength(4)
expect(results[0]).toEqual({
content: '第一个',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '第一个'
})
expect(results[2]).toEqual({
content: '第二个',
isTagContent: true,
complete: false
})
expect(results[3]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '第二个'
})
})
test('应该处理标签间的文本内容', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('<think>思考1</think>中间文本<think>思考2</think>')
expect(results).toHaveLength(5)
expect(results[0]).toEqual({
content: '思考1',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '思考1'
})
expect(results[2]).toEqual({
content: '中间文本',
isTagContent: false,
complete: false
})
expect(results[3]).toEqual({
content: '思考2',
isTagContent: true,
complete: false
})
expect(results[4]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '思考2'
})
})
test('应该处理三个连续标签的分次输出', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
// 第一次输入:包含两个完整标签和第三个标签的开始
let results = extractor.processText('<think>第一个</think><think>第二个</think><think>第三个开始')
expect(results).toHaveLength(5)
expect(results[0]).toEqual({
content: '第一个',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '第一个'
})
expect(results[2]).toEqual({
content: '第二个',
isTagContent: true,
complete: false
})
expect(results[3]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '第二个'
})
expect(results[4]).toEqual({
content: '第三个开始',
isTagContent: true,
complete: false
})
// 第二次输入:继续第三个标签的内容
results = extractor.processText('继续内容')
expect(results).toHaveLength(1)
expect(results[0]).toEqual({
content: '继续内容',
isTagContent: true,
complete: false
})
// 第三次输入:完成第三个标签
results = extractor.processText('结束</think>')
expect(results).toHaveLength(2)
expect(results[0]).toEqual({
content: '结束',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '第三个开始继续内容结束'
})
})
test('应该处理三个连续标签的另一种分次输出模式', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
// 第一次输入:第一个完整标签
let results = extractor.processText('<think>第一个思考</think>')
expect(results).toHaveLength(2)
expect(results[0]).toEqual({
content: '第一个思考',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '第一个思考'
})
// 第二次输入:第二个完整标签和第三个标签的部分内容
results = extractor.processText('<think>第二个思考</think><think>第三个开')
expect(results).toHaveLength(3)
expect(results[0]).toEqual({
content: '第二个思考',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '第二个思考'
})
expect(results[2]).toEqual({
content: '第三个开',
isTagContent: true,
complete: false
})
// 第三次输入:完成第三个标签
results = extractor.processText('始部分</think>')
expect(results).toHaveLength(2)
expect(results[0]).toEqual({
content: '始部分',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '第三个开始部分'
})
})
})
describe('不完整标签处理', () => {
test('应该处理只有开始标签的情况', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('<think>未完成的思考')
expect(results).toHaveLength(1)
expect(results[0]).toEqual({
content: '未完成的思考',
isTagContent: true,
complete: false
})
})
test('应该处理文本中间截断的标签', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('正常文本<thi')
expect(results).toHaveLength(1)
expect(results[0]).toEqual({
content: '正常文本',
isTagContent: false,
complete: false
})
})
})
describe('finalize 方法', () => {
test('应该返回未完成的标签内容', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
extractor.processText('<think>未完成的内容')
const result = extractor.finalize()
expect(result).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '未完成的内容'
})
})
test('当没有未完成内容时应该返回 null', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
extractor.processText('<think>完整内容</think>')
const result = extractor.finalize()
expect(result).toBeNull()
})
test('对于普通文本应该返回 null', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
extractor.processText('只是普通文本')
const result = extractor.finalize()
expect(result).toBeNull()
})
})
describe('reset 方法', () => {
test('应该重置所有内部状态', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
// 处理一些文本以改变内部状态
extractor.processText('<think>一些内容')
// 重置
extractor.reset()
// 重置后应该能正常处理新的文本
const results = extractor.processText('<think>新内容</think>')
expect(results).toHaveLength(2)
expect(results[0]).toEqual({
content: '新内容',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '新内容'
})
})
test('重置后 finalize 应该返回 null', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
extractor.processText('<think>未完成')
extractor.reset()
const result = extractor.finalize()
expect(result).toBeNull()
})
})
describe('不同标签配置', () => {
test('应该处理工具使用标签', () => {
const config: TagConfig = {
openingTag: '<tool_use>',
closingTag: '</tool_use>'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('<tool_use>{"name": "search"}</tool_use>')
expect(results).toHaveLength(2)
expect(results[0]).toEqual({
content: '{"name": "search"}',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '{"name": "search"}'
})
})
test('应该处理自定义标签', () => {
const config: TagConfig = {
openingTag: '[START]',
closingTag: '[END]'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('前文[START]中间内容[END]后文')
expect(results).toHaveLength(4)
expect(results[0]).toEqual({
content: '前文',
isTagContent: false,
complete: false
})
expect(results[1]).toEqual({
content: '中间内容',
isTagContent: true,
complete: false
})
expect(results[2]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '中间内容'
})
expect(results[3]).toEqual({
content: '后文',
isTagContent: false,
complete: false
})
})
})
describe('边界情况', () => {
test('应该处理空字符串输入', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('')
expect(results).toHaveLength(0)
})
test('应该处理只包含标签的输入', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('<think></think>')
expect(results).toHaveLength(0)
})
test('应该处理标签内容包含相似文本的情况', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
const results = extractor.processText('<think>我在<thinking>思考</think>')
expect(results).toHaveLength(2)
expect(results[0]).toEqual({
content: '我在<thinking>思考',
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: '我在<thinking>思考'
})
})
test('应该处理换行符和特殊字符', () => {
const config: TagConfig = {
openingTag: '<think>',
closingTag: '</think>'
}
const extractor = new TagExtractor(config)
const content = '多行\n内容\t带制表符'
const results = extractor.processText(`<think>${content}</think>`)
expect(results).toHaveLength(2)
expect(results[0]).toEqual({
content: content,
isTagContent: true,
complete: false
})
expect(results[1]).toEqual({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: content
})
})
})
})

View File

@ -14,9 +14,8 @@ import {
Model,
ToolUseResponse
} from '@renderer/types'
import type { MCPToolCompleteChunk, MCPToolInProgressChunk } from '@renderer/types/chunk'
import type { MCPToolCompleteChunk, MCPToolInProgressChunk, MCPToolPendingChunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { SdkMessageParam } from '@renderer/types/sdk'
import { isArray, isObject, pull, transform } from 'lodash'
import { nanoid } from 'nanoid'
import OpenAI from 'openai'
@ -28,6 +27,7 @@ import {
} from 'openai/resources'
import { CompletionsParams } from '../aiCore/middleware/schemas'
import { requestToolConfirmation } from './userConfirmation'
const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
const EXTRA_SCHEMA_KEYS = ['schema', 'headers']
@ -278,7 +278,8 @@ export async function callMCPTool(toolResponse: MCPToolResponse): Promise<MCPCal
const resp = await window.api.mcp.callTool({
server,
name: toolResponse.tool.name,
args: toolResponse.arguments
args: toolResponse.arguments,
callId: toolResponse.id
})
if (toolResponse.tool.serverName === MCP_AUTO_INSTALL_SERVER_NAME) {
if (resp.data) {
@ -400,7 +401,7 @@ export function geminiFunctionCallToMcpTool(
export function upsertMCPToolResponse(
results: MCPToolResponse[],
resp: MCPToolResponse,
onChunk: (chunk: MCPToolInProgressChunk | MCPToolCompleteChunk) => void
onChunk: (chunk: MCPToolPendingChunk | MCPToolInProgressChunk | MCPToolCompleteChunk) => void
) {
const index = results.findIndex((ret) => ret.id === resp.id)
let result = resp
@ -416,10 +417,29 @@ export function upsertMCPToolResponse(
} else {
results.push(resp)
}
onChunk({
type: resp.status === 'invoking' ? ChunkType.MCP_TOOL_IN_PROGRESS : ChunkType.MCP_TOOL_COMPLETE,
responses: [result]
})
switch (resp.status) {
case 'pending':
onChunk({
type: ChunkType.MCP_TOOL_PENDING,
responses: [result]
})
break
case 'invoking':
onChunk({
type: ChunkType.MCP_TOOL_IN_PROGRESS,
responses: [result]
})
break
case 'cancelled':
case 'done':
onChunk({
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [result]
})
break
default:
break
}
}
export function filterMCPTools(
@ -441,7 +461,7 @@ export function getMcpServerByTool(tool: MCPTool) {
return servers.find((s) => s.id === tool.serverId)
}
export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseResponse[] {
export function parseToolUse(content: string, mcpTools: MCPTool[], startIdx: number = 0): ToolUseResponse[] {
if (!content || !mcpTools || mcpTools.length === 0) {
return []
}
@ -461,7 +481,7 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseRespo
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
const tools: ToolUseResponse[] = []
let match
let idx = 0
let idx = startIdx
// Find all tool use blocks
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
// const fullMatch = match[0]
@ -505,8 +525,9 @@ export async function parseAndCallTools<R>(
onChunk: CompletionsParams['onChunk'],
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
model: Model,
mcpTools?: MCPTool[]
): Promise<SdkMessageParam[]>
mcpTools?: MCPTool[],
abortSignal?: AbortSignal
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
export async function parseAndCallTools<R>(
content: string,
@ -514,8 +535,9 @@ export async function parseAndCallTools<R>(
onChunk: CompletionsParams['onChunk'],
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
model: Model,
mcpTools?: MCPTool[]
): Promise<SdkMessageParam[]>
mcpTools?: MCPTool[],
abortSignal?: AbortSignal
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
export async function parseAndCallTools<R>(
content: string | MCPToolResponse[],
@ -523,68 +545,172 @@ export async function parseAndCallTools<R>(
onChunk: CompletionsParams['onChunk'],
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
model: Model,
mcpTools?: MCPTool[]
): Promise<R[]> {
mcpTools?: MCPTool[],
abortSignal?: AbortSignal
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }> {
const toolResults: R[] = []
let curToolResponses: MCPToolResponse[] = []
if (Array.isArray(content)) {
curToolResponses = content
} else {
// process tool use
curToolResponses = parseToolUse(content, mcpTools || [])
curToolResponses = parseToolUse(content, mcpTools || [], 0)
}
if (!curToolResponses || curToolResponses.length === 0) {
return toolResults
return { toolResults, confirmedToolResponses: [] }
}
for (let i = 0; i < curToolResponses.length; i++) {
const toolResponse = curToolResponses[i]
for (const toolResponse of curToolResponses) {
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'invoking'
status: 'pending'
},
onChunk!
)
}
const toolPromises = curToolResponses.map(async (toolResponse) => {
const images: string[] = []
const toolCallResponse = await callMCPTool(toolResponse)
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'done',
response: toolCallResponse
},
onChunk!
)
// 创建工具确认Promise映射并立即处理每个确认
const confirmedTools: MCPToolResponse[] = []
const pendingPromises: Promise<void>[] = []
for (const content of toolCallResponse.content) {
if (content.type === 'image' && content.data) {
images.push(`data:${content.mimeType};base64,${content.data}`)
}
}
curToolResponses.forEach((toolResponse) => {
const confirmationPromise = requestToolConfirmation(toolResponse.id, abortSignal)
if (images.length) {
onChunk?.({
type: ChunkType.IMAGE_CREATED
})
onChunk?.({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: images
const processingPromise = confirmationPromise
.then(async (confirmed) => {
if (confirmed) {
// 立即更新为invoking状态
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'invoking'
},
onChunk!
)
// 执行工具调用
try {
const images: string[] = []
const toolCallResponse = await callMCPTool(toolResponse)
// 立即更新为done状态
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'done',
response: toolCallResponse
},
onChunk!
)
// 处理图片
for (const content of toolCallResponse.content) {
if (content.type === 'image' && content.data) {
images.push(`data:${content.mimeType};base64,${content.data}`)
}
}
if (images.length) {
onChunk?.({
type: ChunkType.IMAGE_CREATED
})
onChunk?.({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: images
}
})
}
// 转换消息并添加到结果
const convertedMessage = convertToMessage(toolResponse, toolCallResponse, model)
if (convertedMessage) {
confirmedTools.push(toolResponse)
toolResults.push(convertedMessage)
}
} catch (error) {
Logger.error(`🔧 [MCP] Error executing tool ${toolResponse.id}:`, error)
// 更新为错误状态
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'done',
response: {
isError: true,
content: [
{
type: 'text',
text: `Error executing tool: ${error instanceof Error ? error.message : 'Unknown error'}`
}
]
}
},
onChunk!
)
}
} else {
// 立即更新为cancelled状态
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'cancelled',
response: {
isError: false,
content: [
{
type: 'text',
text: 'Tool call cancelled by user.'
}
]
}
},
onChunk!
)
}
})
}
.catch((error) => {
Logger.error(`🔧 [MCP] Error waiting for tool confirmation ${toolResponse.id}:`, error)
// 立即更新为cancelled状态
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'cancelled',
response: {
isError: true,
content: [
{
type: 'text',
text: `Error in confirmation process: ${error instanceof Error ? error.message : 'Unknown error'}`
}
]
}
},
onChunk!
)
})
return convertToMessage(toolResponse, toolCallResponse, model)
pendingPromises.push(processingPromise)
})
toolResults.push(...(await Promise.all(toolPromises)).filter((t) => typeof t !== 'undefined'))
return toolResults
Logger.info(
`🔧 [MCP] Waiting for tool confirmations:`,
curToolResponses.map((t) => t.id)
)
// 等待所有工具处理完成(但每个工具的状态已经实时更新)
await Promise.all(pendingPromises)
Logger.info(`🔧 [MCP] All tools processed. Confirmed tools: ${confirmedTools.length}`)
return { toolResults, confirmedToolResponses: confirmedTools }
}
export function mcpToolCallResponseToOpenAICompatibleMessage(

View File

@ -2,7 +2,7 @@ import store from '@renderer/store'
import { Assistant, MCPTool } from '@renderer/types'
export const SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \
You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
You can use one or more tools per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
## Tool Use Formatting

View File

@ -0,0 +1,86 @@
import Logger from '@renderer/config/logger'
// 存储每个工具的确认Promise的resolve函数
const toolConfirmResolvers = new Map<string, (value: boolean) => void>()
// 存储每个工具的abort监听器清理函数
const abortListeners = new Map<string, () => void>()
export function requestUserConfirmation(): Promise<boolean> {
return new Promise((resolve) => {
const globalKey = '_global'
toolConfirmResolvers.set(globalKey, resolve)
})
}
export function requestToolConfirmation(toolId: string, abortSignal?: AbortSignal): Promise<boolean> {
return new Promise((resolve) => {
if (abortSignal?.aborted) {
resolve(false)
return
}
toolConfirmResolvers.set(toolId, resolve)
if (abortSignal) {
const abortListener = () => {
const resolver = toolConfirmResolvers.get(toolId)
if (resolver) {
resolver(false)
toolConfirmResolvers.delete(toolId)
abortListeners.delete(toolId)
}
}
abortSignal.addEventListener('abort', abortListener)
// 存储清理函数
const cleanup = () => {
abortSignal.removeEventListener('abort', abortListener)
abortListeners.delete(toolId)
}
abortListeners.set(toolId, cleanup)
}
})
}
export function confirmToolAction(toolId: string) {
const resolve = toolConfirmResolvers.get(toolId)
if (resolve) {
resolve(true)
toolConfirmResolvers.delete(toolId)
// 清理abort监听器
const cleanup = abortListeners.get(toolId)
if (cleanup) {
cleanup()
}
} else {
Logger.warn(`🔧 [userConfirmation] No resolver found for tool: ${toolId}`)
}
}
export function cancelToolAction(toolId: string) {
const resolve = toolConfirmResolvers.get(toolId)
if (resolve) {
resolve(false)
toolConfirmResolvers.delete(toolId)
// 清理abort监听器
const cleanup = abortListeners.get(toolId)
if (cleanup) {
cleanup()
}
} else {
Logger.warn(`🔧 [userConfirmation] No resolver found for tool: ${toolId}`)
}
}
// 获取所有待确认的工具ID
export function getPendingToolIds(): string[] {
return Array.from(toolConfirmResolvers.keys()).filter((id) => id !== '_global')
}
// 检查某个工具是否在等待确认
export function isToolPending(toolId: string): boolean {
return toolConfirmResolvers.has(toolId)
}

View File

@ -3756,11 +3756,11 @@ __metadata:
languageName: node
linkType: hard
"@modelcontextprotocol/sdk@npm:^1.11.4":
version: 1.11.4
resolution: "@modelcontextprotocol/sdk@npm:1.11.4"
"@modelcontextprotocol/sdk@npm:^1.12.3":
version: 1.12.3
resolution: "@modelcontextprotocol/sdk@npm:1.12.3"
dependencies:
ajv: "npm:^8.17.1"
ajv: "npm:^6.12.6"
content-type: "npm:^1.0.5"
cors: "npm:^2.8.5"
cross-spawn: "npm:^7.0.5"
@ -3771,7 +3771,7 @@ __metadata:
raw-body: "npm:^3.0.0"
zod: "npm:^3.23.8"
zod-to-json-schema: "npm:^3.24.1"
checksum: 10c0/797694937e65ccc02e8dc63db711d9d96fbc49b49e6d246e6fed95d8d2bfe98ef203207224e39c9fc3b54da182da865a5d311ea06ef939c5c57ce0cd27c0f546
checksum: 10c0/8bc0b91e596ec886efc64d68ae8474247647405f1a5ae407e02439c74c2a03528b3fbdce8f9352d9c2df54aa4548411e1aa1816ab3b09e045c2ff4202e2fd374
languageName: node
linkType: hard
@ -7092,7 +7092,7 @@ __metadata:
"@libsql/client": "npm:0.14.0"
"@libsql/win32-x64-msvc": "npm:^0.4.7"
"@mistralai/mistralai": "npm:^1.6.0"
"@modelcontextprotocol/sdk": "npm:^1.11.4"
"@modelcontextprotocol/sdk": "npm:^1.12.3"
"@mozilla/readability": "npm:^0.6.0"
"@notionhq/client": "npm:^2.2.15"
"@playwright/test": "npm:^1.52.0"
@ -7349,7 +7349,7 @@ __metadata:
languageName: node
linkType: hard
"ajv@npm:^6.10.0, ajv@npm:^6.12.0, ajv@npm:^6.12.4":
"ajv@npm:^6.10.0, ajv@npm:^6.12.0, ajv@npm:^6.12.4, ajv@npm:^6.12.6":
version: 6.12.6
resolution: "ajv@npm:6.12.6"
dependencies:
@ -7361,7 +7361,7 @@ __metadata:
languageName: node
linkType: hard
"ajv@npm:^8.0.0, ajv@npm:^8.17.1, ajv@npm:^8.6.3":
"ajv@npm:^8.0.0, ajv@npm:^8.6.3":
version: 8.17.1
resolution: "ajv@npm:8.17.1"
dependencies: