mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-05 20:41:30 +08:00
feat: implement tool call progress handling and status updates (#7303)
* feat: implement tool call progress handling and status updates - Update MCP tool response handling to include 'pending' and 'cancelled' statuses. - Introduce new IPC channel for progress updates. - Enhance UI components to reflect tool call statuses, including pending and cancelled states. - Add localization for new status messages in multiple languages. - Refactor message handling logic to accommodate new tool response types. * fix: adjust alignment of action tool container in MessageTools component - Change justify-content from flex-end to flex-start to improve layout consistency. * feat: enhance tool confirmation handling and update related components - Introduced a new tool confirmation mechanism in userConfirmation.ts, allowing for individual tool confirmations. - Updated GeminiAPIClient and OpenAIResponseAPIClient to include tool configuration options. - Refactored MessageTools component to utilize new confirmation functions and improved styling. - Enhanced mcp-tools.ts to manage tool invocation and confirmation processes more effectively, ensuring real-time status updates. * refactor(McpToolChunkMiddleware): enhance tool execution handling and confirmation tracking - Updated createToolHandlingTransform to manage confirmed tool calls and results more effectively. - Refactored executeToolCalls and executeToolUseResponses to return both tool results and confirmed tool calls. - Adjusted buildParamsWithToolResults to utilize confirmed tool calls for building new request messages. - Improved error handling in messageThunk for tool call status updates, ensuring accurate block ID mapping. * feat(McpToolChunkMiddleware, ToolUseExtractionMiddleware, mcp-tools, userConfirmation): enhance tool execution and confirmation handling - Updated McpToolChunkMiddleware to execute tool calls and responses asynchronously, improving performance and response handling. - Enhanced ToolUseExtractionMiddleware to generate unique tool IDs for better tracking. - Modified parseToolUse function to accept a starting index for tool extraction. - Improved user confirmation handling with abort signal support to manage tool action confirmations more effectively. - Updated SYSTEM_PROMPT to clarify the use of multiple tools per message. * fix(tagExtraction): update test expectations for tag extraction results - Adjusted expected length of results from 7 to 9 to reflect changes in tag extraction logic. - Modified content assertions for specific tag contents to ensure accurate validation of extracted tags. * refactor(GeminiAPIClient, OpenAIResponseAPIClient): remove unused function calling configurations - Removed the unused FunctionCallingConfigMode from GeminiAPIClient to streamline the code. - Eliminated the parallel_tool_calls property from OpenAIResponseAPIClient, simplifying the tool call configuration. * feat(McpToolChunkMiddleware): enhance LLM response handling and tool call confirmation - Added notification to UI for new LLM response processing before recursive calls in createToolHandlingTransform. - Improved tool call confirmation logic in executeToolCalls to match tool IDs more accurately, enhancing response validation. * refactor(McpToolChunkMiddleware, ToolUseExtractionMiddleware, messageThunk): remove unnecessary console logs - Eliminated redundant console log statements in McpToolChunkMiddleware, ToolUseExtractionMiddleware, and messageThunk to clean up the code and improve performance. - Focused on enhancing readability and maintainability by reducing clutter in the logging output. * refactor(McpToolChunkMiddleware): remove redundant logging statements - Eliminated unnecessary logging in createToolHandlingTransform to streamline the code and enhance readability. - Focused on reducing clutter in the logging output while maintaining error handling functionality. * feat: enhance action button functionality with cancel and confirm options * refactor(AbortHandlerMiddleware, McpToolChunkMiddleware, ToolUseExtractionMiddleware, messageThunk): improve error handling and code clarity - Updated AbortHandlerMiddleware to skip abort status checks if an error chunk is received, enhancing error handling logic. - Replaced console.error with Logger.error in McpToolChunkMiddleware for consistent logging practices. - Refined ToolUseExtractionMiddleware to improve tool use extraction logic and ensure proper handling of tool_use tags. - Enhanced messageThunk to include initialPlaceholderBlockId in block ID checks, improving error state management. * refactor(ToolUseExtractionMiddleware): enhance tool use parsing logic with counter - Introduced a toolCounter to track the number of tool use responses processed. - Updated parseToolUse function calls to include the toolCounter, improving the extraction logic and ensuring accurate response handling. * feat(McpService, IpcChannel, MessageTools): implement tool abort functionality - Added Mcp_AbortTool channel to handle tool abortion requests. - Implemented abortTool method in McpService to manage active tool calls and provide logging. - Updated MessageTools component to include an abort button for ongoing tool calls, enhancing user control. - Modified API calls to support optional callId for better tracking of tool executions. - Added localization strings for tool abort messages in multiple languages. --------- Co-authored-by: Vaayne <liu.vaayne@gmail.com>
This commit is contained in:
parent
4f7ca3ede8
commit
fba6c1642d
@ -107,7 +107,7 @@
|
||||
"@langchain/community": "^0.3.36",
|
||||
"@langchain/ollama": "^0.2.1",
|
||||
"@mistralai/mistralai": "^1.6.0",
|
||||
"@modelcontextprotocol/sdk": "^1.11.4",
|
||||
"@modelcontextprotocol/sdk": "^1.12.3",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
"@playwright/test": "^1.52.0",
|
||||
|
||||
@ -74,6 +74,8 @@ export enum IpcChannel {
|
||||
Mcp_ServersChanged = 'mcp:servers-changed',
|
||||
Mcp_ServersUpdated = 'mcp:servers-updated',
|
||||
Mcp_CheckConnectivity = 'mcp:check-connectivity',
|
||||
Mcp_SetProgress = 'mcp:set-progress',
|
||||
Mcp_AbortTool = 'mcp:abort-tool',
|
||||
|
||||
// Python
|
||||
Python_Execute = 'python:execute',
|
||||
|
||||
@ -501,6 +501,10 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource)
|
||||
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo)
|
||||
ipcMain.handle(IpcChannel.Mcp_CheckConnectivity, mcpService.checkMcpConnectivity)
|
||||
ipcMain.handle(IpcChannel.Mcp_AbortTool, mcpService.abortTool)
|
||||
ipcMain.handle(IpcChannel.Mcp_SetProgress, (_, progress: number) => {
|
||||
mainWindow.webContents.send('mcp-progress', progress)
|
||||
})
|
||||
|
||||
// Register Python execution handler
|
||||
ipcMain.handle(
|
||||
|
||||
@ -28,6 +28,7 @@ import { app } from 'electron'
|
||||
import Logger from 'electron-log'
|
||||
import { EventEmitter } from 'events'
|
||||
import { memoize } from 'lodash'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
import { CacheService } from './CacheService'
|
||||
import { CallBackServer } from './mcp/oauth/callback'
|
||||
@ -71,6 +72,7 @@ function withCache<T extends unknown[], R>(
|
||||
class McpService {
|
||||
private clients: Map<string, Client> = new Map()
|
||||
private pendingClients: Map<string, Promise<Client>> = new Map()
|
||||
private activeToolCalls: Map<string, AbortController> = new Map()
|
||||
|
||||
constructor() {
|
||||
this.initClient = this.initClient.bind(this)
|
||||
@ -84,6 +86,7 @@ class McpService {
|
||||
this.removeServer = this.removeServer.bind(this)
|
||||
this.restartServer = this.restartServer.bind(this)
|
||||
this.stopServer = this.stopServer.bind(this)
|
||||
this.abortTool = this.abortTool.bind(this)
|
||||
this.cleanup = this.cleanup.bind(this)
|
||||
}
|
||||
|
||||
@ -455,10 +458,14 @@ class McpService {
|
||||
*/
|
||||
public async callTool(
|
||||
_: Electron.IpcMainInvokeEvent,
|
||||
{ server, name, args }: { server: MCPServer; name: string; args: any }
|
||||
{ server, name, args, callId }: { server: MCPServer; name: string; args: any; callId?: string }
|
||||
): Promise<MCPCallToolResponse> {
|
||||
const toolCallId = callId || uuidv4()
|
||||
const abortController = new AbortController()
|
||||
this.activeToolCalls.set(toolCallId, abortController)
|
||||
|
||||
try {
|
||||
Logger.info('[MCP] Calling:', server.name, name, args)
|
||||
Logger.info('[MCP] Calling:', server.name, name, args, 'callId:', toolCallId)
|
||||
if (typeof args === 'string') {
|
||||
try {
|
||||
args = JSON.parse(args)
|
||||
@ -468,12 +475,19 @@ class McpService {
|
||||
}
|
||||
const client = await this.initClient(server)
|
||||
const result = await client.callTool({ name, arguments: args }, undefined, {
|
||||
timeout: server.timeout ? server.timeout * 1000 : 60000 // Default timeout of 1 minute
|
||||
onprogress: (process) => {
|
||||
console.log('[MCP] Progress:', process.progress / (process.total || 1))
|
||||
window.api.mcp.setProgress(process.progress / (process.total || 1))
|
||||
},
|
||||
timeout: server.timeout ? server.timeout * 1000 : 60000, // Default timeout of 1 minute
|
||||
signal: this.activeToolCalls.get(toolCallId)?.signal
|
||||
})
|
||||
return result as MCPCallToolResponse
|
||||
} catch (error) {
|
||||
Logger.error(`[MCP] Error calling tool ${name} on ${server.name}:`, error)
|
||||
throw error
|
||||
} finally {
|
||||
this.activeToolCalls.delete(toolCallId)
|
||||
}
|
||||
}
|
||||
|
||||
@ -664,6 +678,20 @@ class McpService {
|
||||
delete env.http_proxy
|
||||
delete env.https_proxy
|
||||
}
|
||||
|
||||
// 实现 abortTool 方法
|
||||
public async abortTool(_: Electron.IpcMainInvokeEvent, callId: string) {
|
||||
const activeToolCall = this.activeToolCalls.get(callId)
|
||||
if (activeToolCall) {
|
||||
activeToolCall.abort()
|
||||
this.activeToolCalls.delete(callId)
|
||||
Logger.info(`[MCP] Aborted tool call: ${callId}`)
|
||||
return true
|
||||
} else {
|
||||
Logger.warn(`[MCP] No active tool call found for callId: ${callId}`)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default new McpService()
|
||||
|
||||
@ -228,8 +228,8 @@ const api = {
|
||||
restartServer: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_RestartServer, server),
|
||||
stopServer: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_StopServer, server),
|
||||
listTools: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_ListTools, server),
|
||||
callTool: ({ server, name, args }: { server: MCPServer; name: string; args: any }) =>
|
||||
ipcRenderer.invoke(IpcChannel.Mcp_CallTool, { server, name, args }),
|
||||
callTool: ({ server, name, args, callId }: { server: MCPServer; name: string; args: any; callId?: string }) =>
|
||||
ipcRenderer.invoke(IpcChannel.Mcp_CallTool, { server, name, args, callId }),
|
||||
listPrompts: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_ListPrompts, server),
|
||||
getPrompt: ({ server, name, args }: { server: MCPServer; name: string; args?: Record<string, any> }) =>
|
||||
ipcRenderer.invoke(IpcChannel.Mcp_GetPrompt, { server, name, args }),
|
||||
@ -237,7 +237,9 @@ const api = {
|
||||
getResource: ({ server, uri }: { server: MCPServer; uri: string }) =>
|
||||
ipcRenderer.invoke(IpcChannel.Mcp_GetResource, { server, uri }),
|
||||
getInstallInfo: () => ipcRenderer.invoke(IpcChannel.Mcp_GetInstallInfo),
|
||||
checkMcpConnectivity: (server: any) => ipcRenderer.invoke(IpcChannel.Mcp_CheckConnectivity, server)
|
||||
checkMcpConnectivity: (server: any) => ipcRenderer.invoke(IpcChannel.Mcp_CheckConnectivity, server),
|
||||
abortTool: (callId: string) => ipcRenderer.invoke(IpcChannel.Mcp_AbortTool, callId),
|
||||
setProgress: (progress: number) => ipcRenderer.invoke(IpcChannel.Mcp_SetProgress, progress)
|
||||
},
|
||||
python: {
|
||||
execute: (script: string, context?: Record<string, any>, timeout?: number) =>
|
||||
|
||||
@ -67,7 +67,12 @@ export const AbortHandlerMiddleware: CompletionsMiddleware =
|
||||
const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
|
||||
new TransformStream<Chunk, Chunk | ErrorChunk>({
|
||||
transform(chunk, controller) {
|
||||
// 检查 abort 状态
|
||||
// 如果已经收到错误块,不再检查 abort 状态
|
||||
if (chunk.type === ChunkType.ERROR) {
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
|
||||
if (abortSignal?.aborted) {
|
||||
// 转换为 ErrorChunk
|
||||
const errorChunk: ErrorChunk = {
|
||||
|
||||
@ -136,7 +136,6 @@ function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: Generi
|
||||
Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
|
||||
}
|
||||
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||
Logger.debug(`[${MIDDLEWARE_NAME}] LLM_RESPONSE_COMPLETE chunk received:`, ctx._internal)
|
||||
// 从LLM_RESPONSE_COMPLETE chunk中提取usage数据
|
||||
if (chunk.response?.usage) {
|
||||
accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)
|
||||
|
||||
@ -89,6 +89,11 @@ function createToolHandlingTransform(
|
||||
let hasToolUseResponses = false
|
||||
let streamEnded = false
|
||||
|
||||
// 存储已执行的工具结果
|
||||
const executedToolResults: SdkMessageParam[] = []
|
||||
const executedToolCalls: SdkToolCall[] = []
|
||||
const executionPromises: Promise<void>[] = []
|
||||
|
||||
return new TransformStream({
|
||||
async transform(chunk: GenericChunk, controller) {
|
||||
try {
|
||||
@ -98,22 +103,64 @@ function createToolHandlingTransform(
|
||||
|
||||
// 1. 处理Function Call方式的工具调用
|
||||
if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
|
||||
toolCalls.push(...createdChunk.tool_calls)
|
||||
hasToolCalls = true
|
||||
|
||||
for (const toolCall of createdChunk.tool_calls) {
|
||||
toolCalls.push(toolCall)
|
||||
|
||||
const executionPromise = (async () => {
|
||||
try {
|
||||
const result = await executeToolCalls(
|
||||
ctx,
|
||||
[toolCall],
|
||||
mcpTools,
|
||||
allToolResponses,
|
||||
currentParams.onChunk,
|
||||
currentParams.assistant.model!
|
||||
)
|
||||
|
||||
// 缓存执行结果
|
||||
executedToolResults.push(...result.toolResults)
|
||||
executedToolCalls.push(...result.confirmedToolCalls)
|
||||
} catch (error) {
|
||||
console.error(`🔧 [${MIDDLEWARE_NAME}] Error executing tool call asynchronously:`, error)
|
||||
}
|
||||
})()
|
||||
|
||||
executionPromises.push(executionPromise)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 处理Tool Use方式的工具调用
|
||||
if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
|
||||
toolUseResponses.push(...createdChunk.tool_use_responses)
|
||||
hasToolUseResponses = true
|
||||
for (const toolUseResponse of createdChunk.tool_use_responses) {
|
||||
toolUseResponses.push(toolUseResponse)
|
||||
const executionPromise = (async () => {
|
||||
try {
|
||||
const result = await executeToolUseResponses(
|
||||
ctx,
|
||||
[toolUseResponse], // 单个执行
|
||||
mcpTools,
|
||||
allToolResponses,
|
||||
currentParams.onChunk,
|
||||
currentParams.assistant.model!
|
||||
)
|
||||
|
||||
// 缓存执行结果
|
||||
executedToolResults.push(...result.toolResults)
|
||||
} catch (error) {
|
||||
console.error(`🔧 [${MIDDLEWARE_NAME}] Error executing tool use response asynchronously:`, error)
|
||||
// 错误时不影响其他工具的执行
|
||||
}
|
||||
})()
|
||||
|
||||
executionPromises.push(executionPromise)
|
||||
}
|
||||
}
|
||||
|
||||
// 不转发MCP工具进展chunks,避免重复处理
|
||||
return
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
|
||||
// 转发其他所有chunk
|
||||
controller.enqueue(chunk)
|
||||
} catch (error) {
|
||||
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
|
||||
controller.error(error)
|
||||
@ -121,43 +168,33 @@ function createToolHandlingTransform(
|
||||
},
|
||||
|
||||
async flush(controller) {
|
||||
const shouldExecuteToolCalls = hasToolCalls && toolCalls.length > 0
|
||||
const shouldExecuteToolUseResponses = hasToolUseResponses && toolUseResponses.length > 0
|
||||
|
||||
if (!streamEnded && (shouldExecuteToolCalls || shouldExecuteToolUseResponses)) {
|
||||
// 在流结束时等待所有异步工具执行完成,然后进行递归调用
|
||||
if (!streamEnded && (hasToolCalls || hasToolUseResponses)) {
|
||||
streamEnded = true
|
||||
|
||||
try {
|
||||
let toolResult: SdkMessageParam[] = []
|
||||
|
||||
if (shouldExecuteToolCalls) {
|
||||
toolResult = await executeToolCalls(
|
||||
ctx,
|
||||
toolCalls,
|
||||
mcpTools,
|
||||
allToolResponses,
|
||||
currentParams.onChunk,
|
||||
currentParams.assistant.model!
|
||||
)
|
||||
} else if (shouldExecuteToolUseResponses) {
|
||||
toolResult = await executeToolUseResponses(
|
||||
ctx,
|
||||
toolUseResponses,
|
||||
mcpTools,
|
||||
allToolResponses,
|
||||
currentParams.onChunk,
|
||||
currentParams.assistant.model!
|
||||
)
|
||||
}
|
||||
|
||||
if (toolResult.length > 0) {
|
||||
await Promise.all(executionPromises)
|
||||
if (executedToolResults.length > 0) {
|
||||
const output = ctx._internal.toolProcessingState?.output
|
||||
const newParams = buildParamsWithToolResults(
|
||||
ctx,
|
||||
currentParams,
|
||||
output,
|
||||
executedToolResults,
|
||||
executedToolCalls
|
||||
)
|
||||
|
||||
// 在递归调用前通知UI开始新的LLM响应处理
|
||||
if (currentParams.onChunk) {
|
||||
currentParams.onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_CREATED
|
||||
})
|
||||
}
|
||||
|
||||
const newParams = buildParamsWithToolResults(ctx, currentParams, output, toolResult, toolCalls)
|
||||
await executeWithToolHandling(newParams, depth + 1)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
|
||||
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
|
||||
controller.error(error)
|
||||
} finally {
|
||||
hasToolCalls = false
|
||||
@ -178,8 +215,7 @@ async function executeToolCalls(
|
||||
allToolResponses: MCPToolResponse[],
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
model: Model
|
||||
): Promise<SdkMessageParam[]> {
|
||||
// 转换为MCPToolResponse格式
|
||||
): Promise<{ toolResults: SdkMessageParam[]; confirmedToolCalls: SdkToolCall[] }> {
|
||||
const mcpToolResponses: ToolCallResponse[] = toolCalls
|
||||
.map((toolCall) => {
|
||||
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||
@ -192,11 +228,11 @@ async function executeToolCalls(
|
||||
|
||||
if (mcpToolResponses.length === 0) {
|
||||
console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`)
|
||||
return []
|
||||
return { toolResults: [], confirmedToolCalls: [] }
|
||||
}
|
||||
|
||||
// 使用现有的parseAndCallTools函数执行工具
|
||||
const toolResults = await parseAndCallTools(
|
||||
const { toolResults, confirmedToolResponses } = await parseAndCallTools(
|
||||
mcpToolResponses,
|
||||
allToolResponses,
|
||||
onChunk,
|
||||
@ -204,10 +240,24 @@ async function executeToolCalls(
|
||||
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||
},
|
||||
model,
|
||||
mcpTools
|
||||
mcpTools,
|
||||
ctx._internal?.flowControl?.abortSignal
|
||||
)
|
||||
|
||||
return toolResults
|
||||
// 找出已确认工具对应的原始toolCalls
|
||||
const confirmedToolCalls = toolCalls.filter((toolCall) => {
|
||||
return confirmedToolResponses.find((confirmed) => {
|
||||
// 根据不同的ID字段匹配原始toolCall
|
||||
return (
|
||||
('name' in toolCall &&
|
||||
(toolCall.name?.includes(confirmed.tool.name) || toolCall.name?.includes(confirmed.tool.id))) ||
|
||||
confirmed.tool.name === toolCall.id ||
|
||||
confirmed.tool.id === toolCall.id
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
return { toolResults, confirmedToolCalls }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -221,9 +271,9 @@ async function executeToolUseResponses(
|
||||
allToolResponses: MCPToolResponse[],
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
model: Model
|
||||
): Promise<SdkMessageParam[]> {
|
||||
): Promise<{ toolResults: SdkMessageParam[] }> {
|
||||
// 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse
|
||||
const toolResults = await parseAndCallTools(
|
||||
const { toolResults } = await parseAndCallTools(
|
||||
toolUseResponses,
|
||||
allToolResponses,
|
||||
onChunk,
|
||||
@ -231,10 +281,11 @@ async function executeToolUseResponses(
|
||||
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||
},
|
||||
model,
|
||||
mcpTools
|
||||
mcpTools,
|
||||
ctx._internal?.flowControl?.abortSignal
|
||||
)
|
||||
|
||||
return toolResults
|
||||
return { toolResults }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -245,7 +296,7 @@ function buildParamsWithToolResults(
|
||||
currentParams: CompletionsParams,
|
||||
output: SdkRawOutput | string | undefined,
|
||||
toolResults: SdkMessageParam[],
|
||||
toolCalls: SdkToolCall[]
|
||||
confirmedToolCalls: SdkToolCall[]
|
||||
): CompletionsParams {
|
||||
// 获取当前已经转换好的reqMessages,如果没有则使用原始messages
|
||||
const currentReqMessages = getCurrentReqMessages(ctx)
|
||||
@ -253,7 +304,7 @@ function buildParamsWithToolResults(
|
||||
const apiClient = ctx.apiClientInstance
|
||||
|
||||
// 从回复中构建助手消息
|
||||
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
|
||||
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, confirmedToolCalls)
|
||||
|
||||
if (output && ctx._internal.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState.output = undefined
|
||||
|
||||
@ -22,7 +22,8 @@ const TOOL_USE_TAG_CONFIG: TagConfig = {
|
||||
* 1. 从文本流中检测并提取 <tool_use></tool_use> 标签
|
||||
* 2. 解析工具调用信息并转换为 ToolUseResponse 格式
|
||||
* 3. 生成 MCP_TOOL_CREATED chunk 供 McpToolChunkMiddleware 处理
|
||||
* 4. 清理文本流,移除工具使用标签但保留正常文本
|
||||
* 4. 丢弃 tool_use 之后的所有内容(助手幻觉)
|
||||
* 5. 清理文本流,移除工具使用标签但保留正常文本
|
||||
*
|
||||
* 注意:此中间件只负责提取和转换,实际工具调用由 McpToolChunkMiddleware 处理
|
||||
*/
|
||||
@ -32,13 +33,10 @@ export const ToolUseExtractionMiddleware: CompletionsMiddleware =
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const mcpTools = params.mcpTools || []
|
||||
|
||||
// 如果没有工具,直接调用下一个中间件
|
||||
if (!mcpTools || mcpTools.length === 0) return next(ctx, params)
|
||||
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:处理工具使用标签提取
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
|
||||
@ -60,7 +58,9 @@ function createToolUseExtractionTransform(
|
||||
_ctx: CompletionsContext,
|
||||
mcpTools: MCPTool[]
|
||||
): TransformStream<GenericChunk, GenericChunk> {
|
||||
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
||||
const toolUseExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
||||
let hasAnyToolUse = false
|
||||
let toolCounter = 0
|
||||
|
||||
return new TransformStream({
|
||||
async transform(chunk: GenericChunk, controller) {
|
||||
@ -68,30 +68,37 @@ function createToolUseExtractionTransform(
|
||||
// 处理文本内容,检测工具使用标签
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
const textChunk = chunk as TextDeltaChunk
|
||||
const extractionResults = tagExtractor.processText(textChunk.text)
|
||||
|
||||
for (const result of extractionResults) {
|
||||
// 处理 tool_use 标签
|
||||
const toolUseResults = toolUseExtractor.processText(textChunk.text)
|
||||
|
||||
for (const result of toolUseResults) {
|
||||
if (result.complete && result.tagContentExtracted) {
|
||||
// 提取到完整的工具使用内容,解析并转换为 SDK ToolCall 格式
|
||||
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools)
|
||||
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools, toolCounter)
|
||||
toolCounter += toolUseResponses.length
|
||||
|
||||
if (toolUseResponses.length > 0) {
|
||||
// 生成 MCP_TOOL_CREATED chunk,复用现有的处理流程
|
||||
// 生成 MCP_TOOL_CREATED chunk
|
||||
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_use_responses: toolUseResponses
|
||||
}
|
||||
controller.enqueue(mcpToolCreatedChunk)
|
||||
|
||||
// 标记已有工具调用
|
||||
hasAnyToolUse = true
|
||||
}
|
||||
} else if (!result.isTagContent && result.content) {
|
||||
// 发送标签外的正常文本内容
|
||||
const cleanTextChunk: TextDeltaChunk = {
|
||||
...textChunk,
|
||||
text: result.content
|
||||
if (!hasAnyToolUse) {
|
||||
const cleanTextChunk: TextDeltaChunk = {
|
||||
...textChunk,
|
||||
text: result.content
|
||||
}
|
||||
controller.enqueue(cleanTextChunk)
|
||||
}
|
||||
controller.enqueue(cleanTextChunk)
|
||||
}
|
||||
// 注意:标签内的内容不会作为TEXT_DELTA转发,避免重复显示
|
||||
// tool_use 标签内的内容不转发,避免重复显示
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -105,16 +112,17 @@ function createToolUseExtractionTransform(
|
||||
},
|
||||
|
||||
async flush(controller) {
|
||||
// 检查是否有未完成的标签内容
|
||||
const finalResult = tagExtractor.finalize()
|
||||
if (finalResult && finalResult.tagContentExtracted) {
|
||||
const toolUseResponses = parseToolUse(finalResult.tagContentExtracted, mcpTools)
|
||||
// 检查是否有未完成的 tool_use 标签内容
|
||||
const finalToolUseResult = toolUseExtractor.finalize()
|
||||
if (finalToolUseResult && finalToolUseResult.tagContentExtracted) {
|
||||
const toolUseResponses = parseToolUse(finalToolUseResult.tagContentExtracted, mcpTools, toolCounter)
|
||||
if (toolUseResponses.length > 0) {
|
||||
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_use_responses: toolUseResponses
|
||||
}
|
||||
controller.enqueue(mcpToolCreatedChunk)
|
||||
hasAnyToolUse = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -706,8 +706,12 @@
|
||||
"success.yuque.export": "Successfully exported to Yuque",
|
||||
"switch.disabled": "Please wait for the current reply to complete",
|
||||
"tools": {
|
||||
"pending": "Pending",
|
||||
"cancelled": "Cancelled",
|
||||
"completed": "Completed",
|
||||
"invoking": "Invoking",
|
||||
"aborted": "Tool call aborted",
|
||||
"abort_failed": "Tool call abort failed",
|
||||
"error": "Error occurred",
|
||||
"raw": "Raw",
|
||||
"preview": "Preview"
|
||||
|
||||
@ -706,9 +706,13 @@
|
||||
"tools": {
|
||||
"completed": "完了",
|
||||
"invoking": "呼び出し中",
|
||||
"aborted": "ツール呼び出し中断",
|
||||
"abort_failed": "ツール呼び出し中断失敗",
|
||||
"error": "エラーが発生しました",
|
||||
"raw": "生データ",
|
||||
"preview": "プレビュー"
|
||||
"preview": "プレビュー",
|
||||
"pending": "保留中",
|
||||
"cancelled": "キャンセル"
|
||||
},
|
||||
"topic.added": "新しいトピックが追加されました",
|
||||
"upgrade.success.button": "再起動",
|
||||
|
||||
@ -705,11 +705,15 @@
|
||||
"success.yuque.export": "Успешный экспорт в Yuque",
|
||||
"switch.disabled": "Пожалуйста, дождитесь завершения текущего ответа",
|
||||
"tools": {
|
||||
"aborted": "Вызов инструмента прерван",
|
||||
"abort_failed": "Вызов инструмента прерван",
|
||||
"completed": "Завершено",
|
||||
"invoking": "Вызов",
|
||||
"error": "Произошла ошибка",
|
||||
"raw": "Исходный",
|
||||
"preview": "Предпросмотр"
|
||||
"preview": "Предпросмотр",
|
||||
"pending": "Ожидание",
|
||||
"cancelled": "Отменено"
|
||||
},
|
||||
"topic.added": "Новый топик добавлен",
|
||||
"upgrade.success.button": "Перезапустить",
|
||||
|
||||
@ -706,8 +706,12 @@
|
||||
"success.yuque.export": "成功导出到语雀",
|
||||
"switch.disabled": "请等待当前回复完成后操作",
|
||||
"tools": {
|
||||
"pending": "等待中",
|
||||
"cancelled": "已取消",
|
||||
"completed": "已完成",
|
||||
"invoking": "调用中",
|
||||
"aborted": "工具调用已中断",
|
||||
"abort_failed": "工具调用中断失败",
|
||||
"error": "发生错误",
|
||||
"raw": "原始",
|
||||
"preview": "预览"
|
||||
|
||||
@ -706,11 +706,15 @@
|
||||
"success.yuque.export": "成功匯出到語雀",
|
||||
"switch.disabled": "請等待當前回覆完成",
|
||||
"tools": {
|
||||
"aborted": "工具調用已中斷",
|
||||
"abort_failed": "工具調用中斷失敗",
|
||||
"completed": "已完成",
|
||||
"invoking": "調用中",
|
||||
"error": "發生錯誤",
|
||||
"raw": "原始碼",
|
||||
"preview": "預覽"
|
||||
"preview": "預覽",
|
||||
"pending": "等待中",
|
||||
"cancelled": "已取消"
|
||||
},
|
||||
"topic.added": "新話題已新增",
|
||||
"upgrade.success.button": "重新啟動",
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
import { CheckOutlined, ExpandOutlined, LoadingOutlined, WarningOutlined } from '@ant-design/icons'
|
||||
import { CheckOutlined, CloseOutlined, LoadingOutlined, WarningOutlined } from '@ant-design/icons'
|
||||
import { useCodeStyle } from '@renderer/context/CodeStyleProvider'
|
||||
import { useSettings } from '@renderer/hooks/useSettings'
|
||||
import type { ToolMessageBlock } from '@renderer/types/newMessage'
|
||||
import { Collapse, message as antdMessage, Modal, Tabs, Tooltip } from 'antd'
|
||||
import { cancelToolAction, confirmToolAction } from '@renderer/utils/userConfirmation'
|
||||
import { Collapse, message as antdMessage, Tooltip } from 'antd'
|
||||
import { message } from 'antd'
|
||||
import Logger from 'electron-log/renderer'
|
||||
import { PauseCircle } from 'lucide-react'
|
||||
import { FC, memo, useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
@ -14,12 +18,24 @@ interface Props {
|
||||
const MessageTools: FC<Props> = ({ block }) => {
|
||||
const [activeKeys, setActiveKeys] = useState<string[]>([])
|
||||
const [copiedMap, setCopiedMap] = useState<Record<string, boolean>>({})
|
||||
const [expandedResponse, setExpandedResponse] = useState<{ content: string; title: string } | null>(null)
|
||||
const { t } = useTranslation()
|
||||
const { messageFont, fontSize } = useSettings()
|
||||
|
||||
const toolResponse = block.metadata?.rawMcpToolResponse
|
||||
|
||||
const { id, tool, status, response } = toolResponse!
|
||||
|
||||
const isPending = status === 'pending'
|
||||
const isInvoking = status === 'invoking'
|
||||
const isDone = status === 'done'
|
||||
|
||||
const argsString = useMemo(() => {
|
||||
if (toolResponse?.arguments) {
|
||||
return JSON.stringify(toolResponse.arguments, null, 2)
|
||||
}
|
||||
return 'No arguments'
|
||||
}, [toolResponse])
|
||||
|
||||
const resultString = useMemo(() => {
|
||||
try {
|
||||
return JSON.stringify(
|
||||
@ -50,13 +66,34 @@ const MessageTools: FC<Props> = ({ block }) => {
|
||||
setActiveKeys(Array.isArray(keys) ? keys : [keys])
|
||||
}
|
||||
|
||||
const handleConfirmTool = () => {
|
||||
confirmToolAction(id)
|
||||
}
|
||||
|
||||
const handleCancelTool = () => {
|
||||
cancelToolAction(id)
|
||||
}
|
||||
|
||||
const handleAbortTool = async () => {
|
||||
if (toolResponse?.id) {
|
||||
try {
|
||||
const success = await window.api.mcp.abortTool(toolResponse.id)
|
||||
if (success) {
|
||||
message.success({ content: t('message.tools.aborted'), key: 'abort-tool' })
|
||||
} else {
|
||||
message.error({ content: t('message.tools.abort_failed'), key: 'abort-tool' })
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error('Failed to abort tool:', error)
|
||||
message.error({ content: t('message.tools.abort_failed'), key: 'abort-tool' })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Format tool responses for collapse items
|
||||
const getCollapseItems = () => {
|
||||
const items: { key: string; label: React.ReactNode; children: React.ReactNode }[] = []
|
||||
const { id, tool, status, response } = toolResponse
|
||||
const isInvoking = status === 'invoking'
|
||||
const isDone = status === 'done'
|
||||
const hasError = isDone && response?.isError === true
|
||||
const hasError = response?.isError === true
|
||||
const result = {
|
||||
params: toolResponse.arguments,
|
||||
response: toolResponse.response
|
||||
@ -68,34 +105,93 @@ const MessageTools: FC<Props> = ({ block }) => {
|
||||
<MessageTitleLabel>
|
||||
<TitleContent>
|
||||
<ToolName>{tool.name}</ToolName>
|
||||
<StatusIndicator $isInvoking={isInvoking} $hasError={hasError}>
|
||||
{isInvoking
|
||||
? t('message.tools.invoking')
|
||||
: hasError
|
||||
? t('message.tools.error')
|
||||
: t('message.tools.completed')}
|
||||
{isInvoking && <LoadingOutlined spin style={{ marginLeft: 6 }} />}
|
||||
{isDone && !hasError && <CheckOutlined style={{ marginLeft: 6 }} />}
|
||||
{hasError && <WarningOutlined style={{ marginLeft: 6 }} />}
|
||||
<StatusIndicator status={status} hasError={hasError}>
|
||||
{(() => {
|
||||
switch (status) {
|
||||
case 'pending':
|
||||
return (
|
||||
<>
|
||||
{t('message.tools.pending')}
|
||||
<LoadingOutlined spin style={{ marginLeft: 6 }} />
|
||||
</>
|
||||
)
|
||||
case 'invoking':
|
||||
return (
|
||||
<>
|
||||
{t('message.tools.invoking')}
|
||||
<LoadingOutlined spin style={{ marginLeft: 6 }} />
|
||||
</>
|
||||
)
|
||||
case 'cancelled':
|
||||
return (
|
||||
<>
|
||||
{t('message.tools.cancelled')}
|
||||
<CloseOutlined style={{ marginLeft: 6 }} />
|
||||
</>
|
||||
)
|
||||
case 'done':
|
||||
if (hasError) {
|
||||
return (
|
||||
<>
|
||||
{t('message.tools.error')}
|
||||
<WarningOutlined style={{ marginLeft: 6 }} />
|
||||
</>
|
||||
)
|
||||
} else {
|
||||
return (
|
||||
<>
|
||||
{t('message.tools.completed')}
|
||||
<CheckOutlined style={{ marginLeft: 6 }} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
default:
|
||||
return ''
|
||||
}
|
||||
})()}
|
||||
</StatusIndicator>
|
||||
</TitleContent>
|
||||
<ActionButtonsContainer>
|
||||
{isDone && response && (
|
||||
{isPending && (
|
||||
<>
|
||||
<Tooltip title={t('common.expand')} mouseEnterDelay={0.5}>
|
||||
<Tooltip title={t('common.cancel')} mouseEnterDelay={0.3}>
|
||||
<ActionButton
|
||||
className="message-action-button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
setExpandedResponse({
|
||||
content: JSON.stringify(response, null, 2),
|
||||
title: tool.name
|
||||
})
|
||||
handleCancelTool()
|
||||
}}
|
||||
aria-label={t('common.expand')}>
|
||||
<ExpandOutlined />
|
||||
aria-label={t('common.cancel')}>
|
||||
<CloseOutlined style={{ fontSize: '14px' }} />
|
||||
</ActionButton>
|
||||
</Tooltip>
|
||||
<Tooltip title={t('common.confirm')} mouseEnterDelay={0.3}>
|
||||
<ActionButton
|
||||
className="confirm-button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
handleConfirmTool()
|
||||
}}
|
||||
aria-label={t('common.confirm')}>
|
||||
<CheckOutlined style={{ fontSize: '14px' }} />
|
||||
</ActionButton>
|
||||
</Tooltip>
|
||||
</>
|
||||
)}
|
||||
{isInvoking && toolResponse?.id && (
|
||||
<Tooltip title={t('chat.input.pause')} mouseEnterDelay={0.3}>
|
||||
<ActionButton
|
||||
className="abort-button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
handleAbortTool()
|
||||
}}
|
||||
aria-label={t('chat.input.pause')}>
|
||||
<PauseCircle color="var(--color-error)" size={14} />
|
||||
</ActionButton>
|
||||
</Tooltip>
|
||||
)}
|
||||
{isDone && response && (
|
||||
<>
|
||||
<Tooltip title={t('common.copy')} mouseEnterDelay={0.5}>
|
||||
<ActionButton
|
||||
className="message-action-button"
|
||||
@ -113,98 +209,38 @@ const MessageTools: FC<Props> = ({ block }) => {
|
||||
</ActionButtonsContainer>
|
||||
</MessageTitleLabel>
|
||||
),
|
||||
children: isDone && result && (
|
||||
<ToolResponseContainer
|
||||
style={{
|
||||
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
|
||||
fontSize: '12px'
|
||||
}}>
|
||||
<CollapsedContent isExpanded={activeKeys.includes(id)} resultString={resultString} />
|
||||
</ToolResponseContainer>
|
||||
)
|
||||
children:
|
||||
isDone && result ? (
|
||||
<ToolResponseContainer
|
||||
style={{
|
||||
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
|
||||
fontSize
|
||||
}}>
|
||||
<CollapsedContent isExpanded={activeKeys.includes(id)} resultString={resultString} />
|
||||
</ToolResponseContainer>
|
||||
) : argsString ? (
|
||||
<>
|
||||
<ToolResponseContainer>
|
||||
<CollapsedContent isExpanded={activeKeys.includes(id)} resultString={argsString} />
|
||||
</ToolResponseContainer>
|
||||
</>
|
||||
) : null
|
||||
})
|
||||
|
||||
return items
|
||||
}
|
||||
|
||||
const renderPreview = (content: string) => {
|
||||
if (!content) return null
|
||||
|
||||
try {
|
||||
const parsedResult = JSON.parse(content)
|
||||
switch (parsedResult.content[0]?.type) {
|
||||
case 'text':
|
||||
return <PreviewBlock>{parsedResult.content[0].text}</PreviewBlock>
|
||||
default:
|
||||
return <PreviewBlock>{content}</PreviewBlock>
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('failed to render the preview of mcp results:', e)
|
||||
return <PreviewBlock>{content}</PreviewBlock>
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<ToolContainer>
|
||||
<CollapseContainer
|
||||
activeKey={activeKeys}
|
||||
size="small"
|
||||
onChange={handleCollapseChange}
|
||||
className="message-tools-container"
|
||||
items={getCollapseItems()}
|
||||
expandIcon={({ isActive }) => (
|
||||
<CollapsibleIcon className={`iconfont ${isActive ? 'icon-chevron-down' : 'icon-chevron-right'}`} />
|
||||
)}
|
||||
expandIconPosition="end"
|
||||
/>
|
||||
|
||||
<Modal
|
||||
title={expandedResponse?.title}
|
||||
open={!!expandedResponse}
|
||||
onCancel={() => setExpandedResponse(null)}
|
||||
footer={null}
|
||||
width="80%"
|
||||
centered
|
||||
transitionName="animation-move-down"
|
||||
styles={{ body: { maxHeight: '80vh', overflow: 'auto' } }}>
|
||||
{expandedResponse && (
|
||||
<ExpandedResponseContainer
|
||||
style={{
|
||||
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
|
||||
fontSize
|
||||
}}>
|
||||
<Tabs
|
||||
tabBarExtraContent={
|
||||
<ActionButton
|
||||
className="copy-expanded-button"
|
||||
onClick={() => {
|
||||
navigator.clipboard.writeText(
|
||||
typeof expandedResponse.content === 'string'
|
||||
? expandedResponse.content
|
||||
: JSON.stringify(expandedResponse.content, null, 2)
|
||||
)
|
||||
antdMessage.success({ content: t('message.copied'), key: 'copy-expanded' })
|
||||
}}
|
||||
aria-label={t('common.copy')}>
|
||||
<i className="iconfont icon-copy"></i>
|
||||
</ActionButton>
|
||||
}
|
||||
items={[
|
||||
{
|
||||
key: 'preview',
|
||||
label: t('message.tools.preview'),
|
||||
children: <CollapsedContent isExpanded={true} resultString={resultString} />
|
||||
},
|
||||
{
|
||||
key: 'raw',
|
||||
label: t('message.tools.raw'),
|
||||
children: renderPreview(expandedResponse.content)
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</ExpandedResponseContainer>
|
||||
)}
|
||||
</Modal>
|
||||
</>
|
||||
</ToolContainer>
|
||||
)
|
||||
}
|
||||
|
||||
@ -230,15 +266,25 @@ const CollapsedContent: FC<{ isExpanded: boolean; resultString: string }> = ({ i
|
||||
}
|
||||
|
||||
const CollapseContainer = styled(Collapse)`
|
||||
margin-top: 10px;
|
||||
margin-bottom: 12px;
|
||||
border-radius: 8px;
|
||||
border: none;
|
||||
overflow: hidden;
|
||||
|
||||
.ant-collapse-header {
|
||||
background-color: var(--color-bg-2);
|
||||
transition: background-color 0.2s;
|
||||
|
||||
display: flex;
|
||||
align-items: center;
|
||||
.ant-collapse-expand-icon {
|
||||
height: 100% !important;
|
||||
}
|
||||
.ant-collapse-arrow {
|
||||
height: 28px !important;
|
||||
svg {
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
}
|
||||
}
|
||||
&:hover {
|
||||
background-color: var(--color-bg-3);
|
||||
}
|
||||
@ -249,6 +295,15 @@ const CollapseContainer = styled(Collapse)`
|
||||
}
|
||||
`
|
||||
|
||||
const ToolContainer = styled.div`
|
||||
margin-top: 10px;
|
||||
margin-bottom: 12px;
|
||||
border: 1px solid var(--color-border);
|
||||
background-color: var(--color-bg-2);
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
`
|
||||
|
||||
const MarkdownContainer = styled.div`
|
||||
& pre {
|
||||
background: transparent !important;
|
||||
@ -267,6 +322,7 @@ const MessageTitleLabel = styled.div`
|
||||
min-height: 26px;
|
||||
gap: 10px;
|
||||
padding: 0;
|
||||
margin-left: 4px;
|
||||
`
|
||||
|
||||
const TitleContent = styled.div`
|
||||
@ -282,18 +338,27 @@ const ToolName = styled.span`
|
||||
font-size: 13px;
|
||||
`
|
||||
|
||||
const StatusIndicator = styled.span<{ $isInvoking: boolean; $hasError?: boolean }>`
|
||||
const StatusIndicator = styled.span<{ status: string; hasError?: boolean }>`
|
||||
color: ${(props) => {
|
||||
if (props.$hasError) return 'var(--color-error, #ff4d4f)'
|
||||
if (props.$isInvoking) return 'var(--color-primary)'
|
||||
return 'var(--color-success, #52c41a)'
|
||||
switch (props.status) {
|
||||
case 'pending':
|
||||
return 'var(--color-text-2)'
|
||||
case 'invoking':
|
||||
return 'var(--color-primary)'
|
||||
case 'cancelled':
|
||||
return 'var(--color-error, #ff4d4f)' // Assuming cancelled should also be an error color
|
||||
case 'done':
|
||||
return props.hasError ? 'var(--color-error, #ff4d4f)' : 'var(--color-success, #52c41a)'
|
||||
default:
|
||||
return 'var(--color-text)'
|
||||
}
|
||||
}};
|
||||
font-size: 11px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
opacity: 0.85;
|
||||
border-left: 1px solid var(--color-border);
|
||||
padding-left: 8px;
|
||||
padding-left: 12px;
|
||||
`
|
||||
|
||||
const ActionButtonsContainer = styled.div`
|
||||
@ -307,18 +372,30 @@ const ActionButton = styled.button`
|
||||
border: none;
|
||||
color: var(--color-text-2);
|
||||
cursor: pointer;
|
||||
padding: 4px 8px;
|
||||
padding: 4px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
opacity: 0.7;
|
||||
transition: all 0.2s;
|
||||
border-radius: 4px;
|
||||
gap: 4px;
|
||||
min-width: 28px;
|
||||
height: 28px;
|
||||
|
||||
&:hover {
|
||||
opacity: 1;
|
||||
color: var(--color-text);
|
||||
background-color: var(--color-bg-1);
|
||||
background-color: var(--color-bg-3);
|
||||
}
|
||||
|
||||
&.confirm-button {
|
||||
color: var(--color-primary);
|
||||
|
||||
&:hover {
|
||||
background-color: var(--color-primary-bg);
|
||||
color: var(--color-primary);
|
||||
}
|
||||
}
|
||||
|
||||
&:focus-visible {
|
||||
@ -332,12 +409,6 @@ const ActionButton = styled.button`
|
||||
}
|
||||
`
|
||||
|
||||
const CollapsibleIcon = styled.i`
|
||||
color: var(--color-text-2);
|
||||
font-size: 12px;
|
||||
transition: transform 0.2s;
|
||||
`
|
||||
|
||||
const ToolResponseContainer = styled.div`
|
||||
border-radius: 0 0 4px 4px;
|
||||
overflow: auto;
|
||||
@ -346,35 +417,4 @@ const ToolResponseContainer = styled.div`
|
||||
position: relative;
|
||||
`
|
||||
|
||||
const PreviewBlock = styled.div`
|
||||
margin: 0;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
color: var(--color-text);
|
||||
user-select: text;
|
||||
`
|
||||
|
||||
const ExpandedResponseContainer = styled.div`
|
||||
background: var(--color-bg-1);
|
||||
border-radius: 8px;
|
||||
padding: 16px;
|
||||
position: relative;
|
||||
|
||||
.copy-expanded-button {
|
||||
position: absolute;
|
||||
top: 10px;
|
||||
right: 10px;
|
||||
background-color: var(--color-bg-2);
|
||||
border-radius: 4px;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
pre {
|
||||
margin: 0;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
color: var(--color-text);
|
||||
}
|
||||
`
|
||||
|
||||
export default memo(MessageTools)
|
||||
|
||||
@ -16,6 +16,7 @@ export interface StreamProcessorCallbacks {
|
||||
onThinkingChunk?: (text: string, thinking_millsec?: number) => void
|
||||
onThinkingComplete?: (text: string, thinking_millsec?: number) => void
|
||||
// A tool call response chunk (from MCP)
|
||||
onToolCallPending?: (toolResponse: MCPToolResponse) => void
|
||||
onToolCallInProgress?: (toolResponse: MCPToolResponse) => void
|
||||
onToolCallComplete?: (toolResponse: MCPToolResponse) => void
|
||||
// External tool call in progress
|
||||
@ -69,6 +70,10 @@ export function createStreamProcessor(callbacks: StreamProcessorCallbacks = {})
|
||||
if (callbacks.onThinkingComplete) callbacks.onThinkingComplete(data.text, data.thinking_millsec)
|
||||
break
|
||||
}
|
||||
case ChunkType.MCP_TOOL_PENDING: {
|
||||
if (callbacks.onToolCallPending) data.responses.forEach((toolResp) => callbacks.onToolCallPending!(toolResp))
|
||||
break
|
||||
}
|
||||
case ChunkType.MCP_TOOL_IN_PROGRESS: {
|
||||
if (callbacks.onToolCallInProgress)
|
||||
data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp))
|
||||
|
||||
@ -529,12 +529,13 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
}
|
||||
thinkingBlockId = null
|
||||
},
|
||||
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
|
||||
onToolCallPending: (toolResponse: MCPToolResponse) => {
|
||||
if (initialPlaceholderBlockId) {
|
||||
lastBlockType = MessageBlockType.TOOL
|
||||
const changes = {
|
||||
type: MessageBlockType.TOOL,
|
||||
status: MessageBlockStatus.PROCESSING,
|
||||
status: MessageBlockStatus.PENDING,
|
||||
toolName: toolResponse.tool.name,
|
||||
metadata: { rawMcpToolResponse: toolResponse }
|
||||
}
|
||||
toolBlockId = initialPlaceholderBlockId
|
||||
@ -542,14 +543,37 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
dispatch(updateOneBlock({ id: toolBlockId, changes }))
|
||||
saveUpdatedBlockToDB(toolBlockId, assistantMsgId, topicId, getState)
|
||||
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId)
|
||||
} else if (toolResponse.status === 'invoking') {
|
||||
} else if (toolResponse.status === 'pending') {
|
||||
const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, {
|
||||
toolName: toolResponse.tool.name,
|
||||
status: MessageBlockStatus.PROCESSING,
|
||||
status: MessageBlockStatus.PENDING,
|
||||
metadata: { rawMcpToolResponse: toolResponse }
|
||||
})
|
||||
toolBlockId = toolBlock.id
|
||||
handleBlockTransition(toolBlock, MessageBlockType.TOOL)
|
||||
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlock.id)
|
||||
} else {
|
||||
console.warn(
|
||||
`[onToolCallPending] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||
)
|
||||
}
|
||||
},
|
||||
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
|
||||
// 根据 toolResponse.id 查找对应的块ID
|
||||
const targetBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
||||
|
||||
if (targetBlockId && toolResponse.status === 'invoking') {
|
||||
const changes = {
|
||||
status: MessageBlockStatus.PROCESSING,
|
||||
metadata: { rawMcpToolResponse: toolResponse }
|
||||
}
|
||||
dispatch(updateOneBlock({ id: targetBlockId, changes }))
|
||||
saveUpdatedBlockToDB(targetBlockId, assistantMsgId, topicId, getState)
|
||||
} else if (!targetBlockId) {
|
||||
console.warn(
|
||||
`[onToolCallInProgress] No block ID found for tool ID: ${toolResponse.id}. Available mappings:`,
|
||||
Array.from(toolCallIdToBlockIdMap.entries())
|
||||
)
|
||||
} else {
|
||||
console.warn(
|
||||
`[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||
@ -559,14 +583,17 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
onToolCallComplete: (toolResponse: MCPToolResponse) => {
|
||||
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
||||
toolCallIdToBlockIdMap.delete(toolResponse.id)
|
||||
if (toolResponse.status === 'done' || toolResponse.status === 'error') {
|
||||
if (toolResponse.status === 'done' || toolResponse.status === 'error' || toolResponse.status === 'cancelled') {
|
||||
if (!existingBlockId) {
|
||||
console.error(
|
||||
`[onToolCallComplete] No existing block found for completed/error tool call ID: ${toolResponse.id}. Cannot update.`
|
||||
)
|
||||
return
|
||||
}
|
||||
const finalStatus = toolResponse.status === 'done' ? MessageBlockStatus.SUCCESS : MessageBlockStatus.ERROR
|
||||
const finalStatus =
|
||||
toolResponse.status === 'done' || toolResponse.status === 'cancelled'
|
||||
? MessageBlockStatus.SUCCESS
|
||||
: MessageBlockStatus.ERROR
|
||||
const changes: Partial<ToolMessageBlock> = {
|
||||
content: toolResponse.response,
|
||||
status: finalStatus,
|
||||
@ -583,6 +610,7 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
`[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||
)
|
||||
}
|
||||
toolBlockId = null
|
||||
},
|
||||
onExternalToolInProgress: async () => {
|
||||
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
|
||||
@ -762,7 +790,14 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
})
|
||||
}
|
||||
const possibleBlockId =
|
||||
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
|
||||
mainTextBlockId ||
|
||||
thinkingBlockId ||
|
||||
toolBlockId ||
|
||||
imageBlockId ||
|
||||
citationBlockId ||
|
||||
initialPlaceholderBlockId ||
|
||||
lastBlockId
|
||||
|
||||
if (possibleBlockId) {
|
||||
// 更改上一个block的状态为ERROR
|
||||
const changes: Partial<MessageBlock> = {
|
||||
@ -801,7 +836,13 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg]
|
||||
|
||||
const possibleBlockId =
|
||||
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
|
||||
mainTextBlockId ||
|
||||
thinkingBlockId ||
|
||||
toolBlockId ||
|
||||
imageBlockId ||
|
||||
citationBlockId ||
|
||||
initialPlaceholderBlockId ||
|
||||
lastBlockId
|
||||
if (possibleBlockId) {
|
||||
const changes: Partial<MessageBlock> = {
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
@ -1109,7 +1150,6 @@ export const resendMessageThunk =
|
||||
// 没有相关的助手消息就创建一个或多个
|
||||
|
||||
if (userMessageToResend?.mentions?.length) {
|
||||
console.log('userMessageToResend.mentions', userMessageToResend.mentions)
|
||||
for (const mention of userMessageToResend.mentions) {
|
||||
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
||||
askId: userMessageToResend.id,
|
||||
|
||||
@ -13,6 +13,7 @@ export enum ChunkType {
|
||||
KNOWLEDGE_SEARCH_IN_PROGRESS = 'knowledge_search_in_progress',
|
||||
KNOWLEDGE_SEARCH_COMPLETE = 'knowledge_search_complete',
|
||||
MCP_TOOL_CREATED = 'mcp_tool_created',
|
||||
MCP_TOOL_PENDING = 'mcp_tool_pending',
|
||||
MCP_TOOL_IN_PROGRESS = 'mcp_tool_in_progress',
|
||||
MCP_TOOL_COMPLETE = 'mcp_tool_complete',
|
||||
EXTERNEL_TOOL_COMPLETE = 'externel_tool_complete',
|
||||
@ -260,6 +261,11 @@ export interface MCPToolCreatedChunk {
|
||||
tool_use_responses?: ToolUseResponse[] // 工具使用响应
|
||||
}
|
||||
|
||||
export interface MCPToolPendingChunk {
|
||||
type: ChunkType.MCP_TOOL_PENDING
|
||||
responses: MCPToolResponse[]
|
||||
}
|
||||
|
||||
export interface MCPToolInProgressChunk {
|
||||
/**
|
||||
* The type of the chunk
|
||||
@ -353,6 +359,7 @@ export type Chunk =
|
||||
| KnowledgeSearchInProgressChunk // 知识库搜索进行中
|
||||
| KnowledgeSearchCompleteChunk // 知识库搜索完成
|
||||
| MCPToolCreatedChunk // MCP工具被大模型创建
|
||||
| MCPToolPendingChunk // MCP工具调用等待中
|
||||
| MCPToolInProgressChunk // MCP工具调用中
|
||||
| MCPToolCompleteChunk // MCP工具调用完成
|
||||
| ExternalToolCompleteChunk // 外部工具调用完成,外部工具包含搜索互联网,知识库,MCP服务器
|
||||
|
||||
@ -683,11 +683,13 @@ export interface MCPConfig {
|
||||
isBunInstalled: boolean
|
||||
}
|
||||
|
||||
export type MCPToolResponseStatus = 'pending' | 'cancelled' | 'invoking' | 'done' | 'error'
|
||||
|
||||
interface BaseToolResponse {
|
||||
id: string // unique id
|
||||
tool: MCPTool
|
||||
arguments: Record<string, unknown> | undefined
|
||||
status: string // 'invoking' | 'done'
|
||||
status: MCPToolResponseStatus
|
||||
response?: any
|
||||
}
|
||||
|
||||
|
||||
693
src/renderer/src/utils/__tests__/tagExtraction.test.ts
Normal file
693
src/renderer/src/utils/__tests__/tagExtraction.test.ts
Normal file
@ -0,0 +1,693 @@
|
||||
import { describe, expect, test } from 'vitest'
|
||||
|
||||
import { TagConfig, TagExtractor } from '../tagExtraction'
|
||||
|
||||
describe('TagExtractor', () => {
|
||||
describe('基本标签提取', () => {
|
||||
test('应该正确提取简单的标签内容', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('<think>Hello World</think>')
|
||||
|
||||
expect(results).toHaveLength(2)
|
||||
expect(results[0]).toEqual({
|
||||
content: 'Hello World',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: 'Hello World'
|
||||
})
|
||||
})
|
||||
|
||||
test('应该处理标签前后的普通文本', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('前文<think>思考内容</think>后文')
|
||||
|
||||
expect(results).toHaveLength(4)
|
||||
expect(results[0]).toEqual({
|
||||
content: '前文',
|
||||
isTagContent: false,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '思考内容',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[2]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '思考内容'
|
||||
})
|
||||
expect(results[3]).toEqual({
|
||||
content: '后文',
|
||||
isTagContent: false,
|
||||
complete: false
|
||||
})
|
||||
})
|
||||
|
||||
test('应该处理空标签', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('<think></think>')
|
||||
|
||||
expect(results).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('分块处理', () => {
|
||||
test('应该正确处理分块的标签内容', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
let results = extractor.processText('<think>第一')
|
||||
expect(results).toHaveLength(1)
|
||||
expect(results[0]).toEqual({
|
||||
content: '第一',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
|
||||
results = extractor.processText('部分内容')
|
||||
expect(results).toHaveLength(1)
|
||||
expect(results[0]).toEqual({
|
||||
content: '部分内容',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
|
||||
results = extractor.processText('</think>')
|
||||
expect(results).toHaveLength(1)
|
||||
expect(results[0]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '第一部分内容'
|
||||
})
|
||||
})
|
||||
|
||||
test('应该处理分块的开始标签', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
let results = extractor.processText('<thi')
|
||||
expect(results).toHaveLength(0)
|
||||
|
||||
results = extractor.processText('nk>内容</think>')
|
||||
expect(results).toHaveLength(2)
|
||||
expect(results[0]).toEqual({
|
||||
content: '内容',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '内容'
|
||||
})
|
||||
})
|
||||
|
||||
test('应该处理模拟可读流的分块数据', async () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
// 模拟流式数据块
|
||||
const streamChunks = [
|
||||
'这是普通文本',
|
||||
'<thi',
|
||||
'nk>这是第一个',
|
||||
'思考内容',
|
||||
'</think>',
|
||||
'中间的一些文本',
|
||||
'<think>第二',
|
||||
'个思考内容',
|
||||
'</thi',
|
||||
'nk>',
|
||||
'结束文本'
|
||||
]
|
||||
|
||||
const allResults: any[] = []
|
||||
|
||||
// 模拟异步流式处理
|
||||
for (const chunk of streamChunks) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 10)) // 模拟异步延迟
|
||||
const results = extractor.processText(chunk)
|
||||
allResults.push(...results)
|
||||
}
|
||||
|
||||
// 验证结果
|
||||
expect(allResults).toHaveLength(9)
|
||||
|
||||
// 第一个普通文本
|
||||
expect(allResults[0]).toEqual({
|
||||
content: '这是普通文本',
|
||||
isTagContent: false,
|
||||
complete: false
|
||||
})
|
||||
|
||||
// 第一个思考标签内容
|
||||
expect(allResults[1]).toEqual({
|
||||
content: '这是第一个',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
|
||||
expect(allResults[2]).toEqual({
|
||||
content: '思考内容',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
|
||||
// 第一个完整的标签内容提取
|
||||
expect(allResults[3]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '这是第一个思考内容'
|
||||
})
|
||||
|
||||
// 中间文本
|
||||
expect(allResults[4]).toEqual({
|
||||
content: '中间的一些文本',
|
||||
isTagContent: false,
|
||||
complete: false
|
||||
})
|
||||
|
||||
// 第二个思考标签内容
|
||||
expect(allResults[5]).toEqual({
|
||||
content: '第二',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
|
||||
// 第二个完整的标签内容提取和结束文本
|
||||
expect(allResults[6]).toEqual({
|
||||
content: '个思考内容',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
|
||||
expect(allResults[7]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '第二个思考内容'
|
||||
})
|
||||
|
||||
expect(allResults[8]).toEqual({
|
||||
content: '结束文本',
|
||||
isTagContent: false,
|
||||
complete: false
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('多个标签处理', () => {
|
||||
test('应该处理连续的多个标签', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('<think>第一个</think><think>第二个</think>')
|
||||
|
||||
expect(results).toHaveLength(4)
|
||||
expect(results[0]).toEqual({
|
||||
content: '第一个',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '第一个'
|
||||
})
|
||||
expect(results[2]).toEqual({
|
||||
content: '第二个',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[3]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '第二个'
|
||||
})
|
||||
})
|
||||
|
||||
test('应该处理标签间的文本内容', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('<think>思考1</think>中间文本<think>思考2</think>')
|
||||
|
||||
expect(results).toHaveLength(5)
|
||||
expect(results[0]).toEqual({
|
||||
content: '思考1',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '思考1'
|
||||
})
|
||||
expect(results[2]).toEqual({
|
||||
content: '中间文本',
|
||||
isTagContent: false,
|
||||
complete: false
|
||||
})
|
||||
expect(results[3]).toEqual({
|
||||
content: '思考2',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[4]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '思考2'
|
||||
})
|
||||
})
|
||||
|
||||
test('应该处理三个连续标签的分次输出', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
// 第一次输入:包含两个完整标签和第三个标签的开始
|
||||
let results = extractor.processText('<think>第一个</think><think>第二个</think><think>第三个开始')
|
||||
|
||||
expect(results).toHaveLength(5)
|
||||
expect(results[0]).toEqual({
|
||||
content: '第一个',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '第一个'
|
||||
})
|
||||
expect(results[2]).toEqual({
|
||||
content: '第二个',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[3]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '第二个'
|
||||
})
|
||||
expect(results[4]).toEqual({
|
||||
content: '第三个开始',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
|
||||
// 第二次输入:继续第三个标签的内容
|
||||
results = extractor.processText('继续内容')
|
||||
|
||||
expect(results).toHaveLength(1)
|
||||
expect(results[0]).toEqual({
|
||||
content: '继续内容',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
|
||||
// 第三次输入:完成第三个标签
|
||||
results = extractor.processText('结束</think>')
|
||||
|
||||
expect(results).toHaveLength(2)
|
||||
expect(results[0]).toEqual({
|
||||
content: '结束',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '第三个开始继续内容结束'
|
||||
})
|
||||
})
|
||||
|
||||
test('应该处理三个连续标签的另一种分次输出模式', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
// 第一次输入:第一个完整标签
|
||||
let results = extractor.processText('<think>第一个思考</think>')
|
||||
|
||||
expect(results).toHaveLength(2)
|
||||
expect(results[0]).toEqual({
|
||||
content: '第一个思考',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '第一个思考'
|
||||
})
|
||||
|
||||
// 第二次输入:第二个完整标签和第三个标签的部分内容
|
||||
results = extractor.processText('<think>第二个思考</think><think>第三个开')
|
||||
|
||||
expect(results).toHaveLength(3)
|
||||
expect(results[0]).toEqual({
|
||||
content: '第二个思考',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '第二个思考'
|
||||
})
|
||||
expect(results[2]).toEqual({
|
||||
content: '第三个开',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
|
||||
// 第三次输入:完成第三个标签
|
||||
results = extractor.processText('始部分</think>')
|
||||
|
||||
expect(results).toHaveLength(2)
|
||||
expect(results[0]).toEqual({
|
||||
content: '始部分',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '第三个开始部分'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('不完整标签处理', () => {
|
||||
test('应该处理只有开始标签的情况', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('<think>未完成的思考')
|
||||
|
||||
expect(results).toHaveLength(1)
|
||||
expect(results[0]).toEqual({
|
||||
content: '未完成的思考',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
})
|
||||
|
||||
test('应该处理文本中间截断的标签', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('正常文本<thi')
|
||||
|
||||
expect(results).toHaveLength(1)
|
||||
expect(results[0]).toEqual({
|
||||
content: '正常文本',
|
||||
isTagContent: false,
|
||||
complete: false
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('finalize 方法', () => {
|
||||
test('应该返回未完成的标签内容', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
extractor.processText('<think>未完成的内容')
|
||||
const result = extractor.finalize()
|
||||
|
||||
expect(result).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '未完成的内容'
|
||||
})
|
||||
})
|
||||
|
||||
test('当没有未完成内容时应该返回 null', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
extractor.processText('<think>完整内容</think>')
|
||||
const result = extractor.finalize()
|
||||
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
test('对于普通文本应该返回 null', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
extractor.processText('只是普通文本')
|
||||
const result = extractor.finalize()
|
||||
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('reset 方法', () => {
|
||||
test('应该重置所有内部状态', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
// 处理一些文本以改变内部状态
|
||||
extractor.processText('<think>一些内容')
|
||||
|
||||
// 重置
|
||||
extractor.reset()
|
||||
|
||||
// 重置后应该能正常处理新的文本
|
||||
const results = extractor.processText('<think>新内容</think>')
|
||||
|
||||
expect(results).toHaveLength(2)
|
||||
expect(results[0]).toEqual({
|
||||
content: '新内容',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '新内容'
|
||||
})
|
||||
})
|
||||
|
||||
test('重置后 finalize 应该返回 null', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
extractor.processText('<think>未完成')
|
||||
extractor.reset()
|
||||
|
||||
const result = extractor.finalize()
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('不同标签配置', () => {
|
||||
test('应该处理工具使用标签', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<tool_use>',
|
||||
closingTag: '</tool_use>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('<tool_use>{"name": "search"}</tool_use>')
|
||||
|
||||
expect(results).toHaveLength(2)
|
||||
expect(results[0]).toEqual({
|
||||
content: '{"name": "search"}',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '{"name": "search"}'
|
||||
})
|
||||
})
|
||||
|
||||
test('应该处理自定义标签', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '[START]',
|
||||
closingTag: '[END]'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('前文[START]中间内容[END]后文')
|
||||
|
||||
expect(results).toHaveLength(4)
|
||||
expect(results[0]).toEqual({
|
||||
content: '前文',
|
||||
isTagContent: false,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '中间内容',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[2]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '中间内容'
|
||||
})
|
||||
expect(results[3]).toEqual({
|
||||
content: '后文',
|
||||
isTagContent: false,
|
||||
complete: false
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('边界情况', () => {
|
||||
test('应该处理空字符串输入', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('')
|
||||
|
||||
expect(results).toHaveLength(0)
|
||||
})
|
||||
|
||||
test('应该处理只包含标签的输入', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('<think></think>')
|
||||
|
||||
expect(results).toHaveLength(0)
|
||||
})
|
||||
|
||||
test('应该处理标签内容包含相似文本的情况', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const results = extractor.processText('<think>我在<thinking>思考</think>')
|
||||
|
||||
expect(results).toHaveLength(2)
|
||||
expect(results[0]).toEqual({
|
||||
content: '我在<thinking>思考',
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: '我在<thinking>思考'
|
||||
})
|
||||
})
|
||||
|
||||
test('应该处理换行符和特殊字符', () => {
|
||||
const config: TagConfig = {
|
||||
openingTag: '<think>',
|
||||
closingTag: '</think>'
|
||||
}
|
||||
const extractor = new TagExtractor(config)
|
||||
|
||||
const content = '多行\n内容\t带制表符'
|
||||
const results = extractor.processText(`<think>${content}</think>`)
|
||||
|
||||
expect(results).toHaveLength(2)
|
||||
expect(results[0]).toEqual({
|
||||
content: content,
|
||||
isTagContent: true,
|
||||
complete: false
|
||||
})
|
||||
expect(results[1]).toEqual({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: content
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -14,9 +14,8 @@ import {
|
||||
Model,
|
||||
ToolUseResponse
|
||||
} from '@renderer/types'
|
||||
import type { MCPToolCompleteChunk, MCPToolInProgressChunk } from '@renderer/types/chunk'
|
||||
import type { MCPToolCompleteChunk, MCPToolInProgressChunk, MCPToolPendingChunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { SdkMessageParam } from '@renderer/types/sdk'
|
||||
import { isArray, isObject, pull, transform } from 'lodash'
|
||||
import { nanoid } from 'nanoid'
|
||||
import OpenAI from 'openai'
|
||||
@ -28,6 +27,7 @@ import {
|
||||
} from 'openai/resources'
|
||||
|
||||
import { CompletionsParams } from '../aiCore/middleware/schemas'
|
||||
import { requestToolConfirmation } from './userConfirmation'
|
||||
|
||||
const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
|
||||
const EXTRA_SCHEMA_KEYS = ['schema', 'headers']
|
||||
@ -278,7 +278,8 @@ export async function callMCPTool(toolResponse: MCPToolResponse): Promise<MCPCal
|
||||
const resp = await window.api.mcp.callTool({
|
||||
server,
|
||||
name: toolResponse.tool.name,
|
||||
args: toolResponse.arguments
|
||||
args: toolResponse.arguments,
|
||||
callId: toolResponse.id
|
||||
})
|
||||
if (toolResponse.tool.serverName === MCP_AUTO_INSTALL_SERVER_NAME) {
|
||||
if (resp.data) {
|
||||
@ -400,7 +401,7 @@ export function geminiFunctionCallToMcpTool(
|
||||
export function upsertMCPToolResponse(
|
||||
results: MCPToolResponse[],
|
||||
resp: MCPToolResponse,
|
||||
onChunk: (chunk: MCPToolInProgressChunk | MCPToolCompleteChunk) => void
|
||||
onChunk: (chunk: MCPToolPendingChunk | MCPToolInProgressChunk | MCPToolCompleteChunk) => void
|
||||
) {
|
||||
const index = results.findIndex((ret) => ret.id === resp.id)
|
||||
let result = resp
|
||||
@ -416,10 +417,29 @@ export function upsertMCPToolResponse(
|
||||
} else {
|
||||
results.push(resp)
|
||||
}
|
||||
onChunk({
|
||||
type: resp.status === 'invoking' ? ChunkType.MCP_TOOL_IN_PROGRESS : ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [result]
|
||||
})
|
||||
switch (resp.status) {
|
||||
case 'pending':
|
||||
onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
responses: [result]
|
||||
})
|
||||
break
|
||||
case 'invoking':
|
||||
onChunk({
|
||||
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
||||
responses: [result]
|
||||
})
|
||||
break
|
||||
case 'cancelled':
|
||||
case 'done':
|
||||
onChunk({
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [result]
|
||||
})
|
||||
break
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
export function filterMCPTools(
|
||||
@ -441,7 +461,7 @@ export function getMcpServerByTool(tool: MCPTool) {
|
||||
return servers.find((s) => s.id === tool.serverId)
|
||||
}
|
||||
|
||||
export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseResponse[] {
|
||||
export function parseToolUse(content: string, mcpTools: MCPTool[], startIdx: number = 0): ToolUseResponse[] {
|
||||
if (!content || !mcpTools || mcpTools.length === 0) {
|
||||
return []
|
||||
}
|
||||
@ -461,7 +481,7 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseRespo
|
||||
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
|
||||
const tools: ToolUseResponse[] = []
|
||||
let match
|
||||
let idx = 0
|
||||
let idx = startIdx
|
||||
// Find all tool use blocks
|
||||
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
|
||||
// const fullMatch = match[0]
|
||||
@ -505,8 +525,9 @@ export async function parseAndCallTools<R>(
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
||||
model: Model,
|
||||
mcpTools?: MCPTool[]
|
||||
): Promise<SdkMessageParam[]>
|
||||
mcpTools?: MCPTool[],
|
||||
abortSignal?: AbortSignal
|
||||
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
|
||||
|
||||
export async function parseAndCallTools<R>(
|
||||
content: string,
|
||||
@ -514,8 +535,9 @@ export async function parseAndCallTools<R>(
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
||||
model: Model,
|
||||
mcpTools?: MCPTool[]
|
||||
): Promise<SdkMessageParam[]>
|
||||
mcpTools?: MCPTool[],
|
||||
abortSignal?: AbortSignal
|
||||
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
|
||||
|
||||
export async function parseAndCallTools<R>(
|
||||
content: string | MCPToolResponse[],
|
||||
@ -523,68 +545,172 @@ export async function parseAndCallTools<R>(
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
||||
model: Model,
|
||||
mcpTools?: MCPTool[]
|
||||
): Promise<R[]> {
|
||||
mcpTools?: MCPTool[],
|
||||
abortSignal?: AbortSignal
|
||||
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }> {
|
||||
const toolResults: R[] = []
|
||||
let curToolResponses: MCPToolResponse[] = []
|
||||
if (Array.isArray(content)) {
|
||||
curToolResponses = content
|
||||
} else {
|
||||
// process tool use
|
||||
curToolResponses = parseToolUse(content, mcpTools || [])
|
||||
curToolResponses = parseToolUse(content, mcpTools || [], 0)
|
||||
}
|
||||
if (!curToolResponses || curToolResponses.length === 0) {
|
||||
return toolResults
|
||||
return { toolResults, confirmedToolResponses: [] }
|
||||
}
|
||||
for (let i = 0; i < curToolResponses.length; i++) {
|
||||
const toolResponse = curToolResponses[i]
|
||||
|
||||
for (const toolResponse of curToolResponses) {
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'invoking'
|
||||
status: 'pending'
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
}
|
||||
|
||||
const toolPromises = curToolResponses.map(async (toolResponse) => {
|
||||
const images: string[] = []
|
||||
const toolCallResponse = await callMCPTool(toolResponse)
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'done',
|
||||
response: toolCallResponse
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
// 创建工具确认Promise映射,并立即处理每个确认
|
||||
const confirmedTools: MCPToolResponse[] = []
|
||||
const pendingPromises: Promise<void>[] = []
|
||||
|
||||
for (const content of toolCallResponse.content) {
|
||||
if (content.type === 'image' && content.data) {
|
||||
images.push(`data:${content.mimeType};base64,${content.data}`)
|
||||
}
|
||||
}
|
||||
curToolResponses.forEach((toolResponse) => {
|
||||
const confirmationPromise = requestToolConfirmation(toolResponse.id, abortSignal)
|
||||
|
||||
if (images.length) {
|
||||
onChunk?.({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
onChunk?.({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: images
|
||||
const processingPromise = confirmationPromise
|
||||
.then(async (confirmed) => {
|
||||
if (confirmed) {
|
||||
// 立即更新为invoking状态
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'invoking'
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
|
||||
// 执行工具调用
|
||||
try {
|
||||
const images: string[] = []
|
||||
const toolCallResponse = await callMCPTool(toolResponse)
|
||||
|
||||
// 立即更新为done状态
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'done',
|
||||
response: toolCallResponse
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
|
||||
// 处理图片
|
||||
for (const content of toolCallResponse.content) {
|
||||
if (content.type === 'image' && content.data) {
|
||||
images.push(`data:${content.mimeType};base64,${content.data}`)
|
||||
}
|
||||
}
|
||||
|
||||
if (images.length) {
|
||||
onChunk?.({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
onChunk?.({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: images
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 转换消息并添加到结果
|
||||
const convertedMessage = convertToMessage(toolResponse, toolCallResponse, model)
|
||||
if (convertedMessage) {
|
||||
confirmedTools.push(toolResponse)
|
||||
toolResults.push(convertedMessage)
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error(`🔧 [MCP] Error executing tool ${toolResponse.id}:`, error)
|
||||
// 更新为错误状态
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'done',
|
||||
response: {
|
||||
isError: true,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `Error executing tool: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// 立即更新为cancelled状态
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'cancelled',
|
||||
response: {
|
||||
isError: false,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Tool call cancelled by user.'
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
.catch((error) => {
|
||||
Logger.error(`🔧 [MCP] Error waiting for tool confirmation ${toolResponse.id}:`, error)
|
||||
// 立即更新为cancelled状态
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'cancelled',
|
||||
response: {
|
||||
isError: true,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `Error in confirmation process: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
})
|
||||
|
||||
return convertToMessage(toolResponse, toolCallResponse, model)
|
||||
pendingPromises.push(processingPromise)
|
||||
})
|
||||
|
||||
toolResults.push(...(await Promise.all(toolPromises)).filter((t) => typeof t !== 'undefined'))
|
||||
return toolResults
|
||||
Logger.info(
|
||||
`🔧 [MCP] Waiting for tool confirmations:`,
|
||||
curToolResponses.map((t) => t.id)
|
||||
)
|
||||
|
||||
// 等待所有工具处理完成(但每个工具的状态已经实时更新)
|
||||
await Promise.all(pendingPromises)
|
||||
|
||||
Logger.info(`🔧 [MCP] All tools processed. Confirmed tools: ${confirmedTools.length}`)
|
||||
|
||||
return { toolResults, confirmedToolResponses: confirmedTools }
|
||||
}
|
||||
|
||||
export function mcpToolCallResponseToOpenAICompatibleMessage(
|
||||
|
||||
@ -2,7 +2,7 @@ import store from '@renderer/store'
|
||||
import { Assistant, MCPTool } from '@renderer/types'
|
||||
|
||||
export const SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \
|
||||
You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
||||
You can use one or more tools per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
||||
|
||||
## Tool Use Formatting
|
||||
|
||||
|
||||
86
src/renderer/src/utils/userConfirmation.ts
Normal file
86
src/renderer/src/utils/userConfirmation.ts
Normal file
@ -0,0 +1,86 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
|
||||
// 存储每个工具的确认Promise的resolve函数
|
||||
const toolConfirmResolvers = new Map<string, (value: boolean) => void>()
|
||||
// 存储每个工具的abort监听器清理函数
|
||||
const abortListeners = new Map<string, () => void>()
|
||||
|
||||
export function requestUserConfirmation(): Promise<boolean> {
|
||||
return new Promise((resolve) => {
|
||||
const globalKey = '_global'
|
||||
toolConfirmResolvers.set(globalKey, resolve)
|
||||
})
|
||||
}
|
||||
|
||||
export function requestToolConfirmation(toolId: string, abortSignal?: AbortSignal): Promise<boolean> {
|
||||
return new Promise((resolve) => {
|
||||
if (abortSignal?.aborted) {
|
||||
resolve(false)
|
||||
return
|
||||
}
|
||||
|
||||
toolConfirmResolvers.set(toolId, resolve)
|
||||
|
||||
if (abortSignal) {
|
||||
const abortListener = () => {
|
||||
const resolver = toolConfirmResolvers.get(toolId)
|
||||
if (resolver) {
|
||||
resolver(false)
|
||||
toolConfirmResolvers.delete(toolId)
|
||||
abortListeners.delete(toolId)
|
||||
}
|
||||
}
|
||||
|
||||
abortSignal.addEventListener('abort', abortListener)
|
||||
|
||||
// 存储清理函数
|
||||
const cleanup = () => {
|
||||
abortSignal.removeEventListener('abort', abortListener)
|
||||
abortListeners.delete(toolId)
|
||||
}
|
||||
abortListeners.set(toolId, cleanup)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
export function confirmToolAction(toolId: string) {
|
||||
const resolve = toolConfirmResolvers.get(toolId)
|
||||
if (resolve) {
|
||||
resolve(true)
|
||||
toolConfirmResolvers.delete(toolId)
|
||||
|
||||
// 清理abort监听器
|
||||
const cleanup = abortListeners.get(toolId)
|
||||
if (cleanup) {
|
||||
cleanup()
|
||||
}
|
||||
} else {
|
||||
Logger.warn(`🔧 [userConfirmation] No resolver found for tool: ${toolId}`)
|
||||
}
|
||||
}
|
||||
|
||||
export function cancelToolAction(toolId: string) {
|
||||
const resolve = toolConfirmResolvers.get(toolId)
|
||||
if (resolve) {
|
||||
resolve(false)
|
||||
toolConfirmResolvers.delete(toolId)
|
||||
|
||||
// 清理abort监听器
|
||||
const cleanup = abortListeners.get(toolId)
|
||||
if (cleanup) {
|
||||
cleanup()
|
||||
}
|
||||
} else {
|
||||
Logger.warn(`🔧 [userConfirmation] No resolver found for tool: ${toolId}`)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有待确认的工具ID
|
||||
export function getPendingToolIds(): string[] {
|
||||
return Array.from(toolConfirmResolvers.keys()).filter((id) => id !== '_global')
|
||||
}
|
||||
|
||||
// 检查某个工具是否在等待确认
|
||||
export function isToolPending(toolId: string): boolean {
|
||||
return toolConfirmResolvers.has(toolId)
|
||||
}
|
||||
16
yarn.lock
16
yarn.lock
@ -3756,11 +3756,11 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@modelcontextprotocol/sdk@npm:^1.11.4":
|
||||
version: 1.11.4
|
||||
resolution: "@modelcontextprotocol/sdk@npm:1.11.4"
|
||||
"@modelcontextprotocol/sdk@npm:^1.12.3":
|
||||
version: 1.12.3
|
||||
resolution: "@modelcontextprotocol/sdk@npm:1.12.3"
|
||||
dependencies:
|
||||
ajv: "npm:^8.17.1"
|
||||
ajv: "npm:^6.12.6"
|
||||
content-type: "npm:^1.0.5"
|
||||
cors: "npm:^2.8.5"
|
||||
cross-spawn: "npm:^7.0.5"
|
||||
@ -3771,7 +3771,7 @@ __metadata:
|
||||
raw-body: "npm:^3.0.0"
|
||||
zod: "npm:^3.23.8"
|
||||
zod-to-json-schema: "npm:^3.24.1"
|
||||
checksum: 10c0/797694937e65ccc02e8dc63db711d9d96fbc49b49e6d246e6fed95d8d2bfe98ef203207224e39c9fc3b54da182da865a5d311ea06ef939c5c57ce0cd27c0f546
|
||||
checksum: 10c0/8bc0b91e596ec886efc64d68ae8474247647405f1a5ae407e02439c74c2a03528b3fbdce8f9352d9c2df54aa4548411e1aa1816ab3b09e045c2ff4202e2fd374
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
@ -7092,7 +7092,7 @@ __metadata:
|
||||
"@libsql/client": "npm:0.14.0"
|
||||
"@libsql/win32-x64-msvc": "npm:^0.4.7"
|
||||
"@mistralai/mistralai": "npm:^1.6.0"
|
||||
"@modelcontextprotocol/sdk": "npm:^1.11.4"
|
||||
"@modelcontextprotocol/sdk": "npm:^1.12.3"
|
||||
"@mozilla/readability": "npm:^0.6.0"
|
||||
"@notionhq/client": "npm:^2.2.15"
|
||||
"@playwright/test": "npm:^1.52.0"
|
||||
@ -7349,7 +7349,7 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"ajv@npm:^6.10.0, ajv@npm:^6.12.0, ajv@npm:^6.12.4":
|
||||
"ajv@npm:^6.10.0, ajv@npm:^6.12.0, ajv@npm:^6.12.4, ajv@npm:^6.12.6":
|
||||
version: 6.12.6
|
||||
resolution: "ajv@npm:6.12.6"
|
||||
dependencies:
|
||||
@ -7361,7 +7361,7 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"ajv@npm:^8.0.0, ajv@npm:^8.17.1, ajv@npm:^8.6.3":
|
||||
"ajv@npm:^8.0.0, ajv@npm:^8.6.3":
|
||||
version: 8.17.1
|
||||
resolution: "ajv@npm:8.17.1"
|
||||
dependencies:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user