mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-09 06:49:02 +08:00
feat: implement tool call progress handling and status updates (#7303)
* feat: implement tool call progress handling and status updates - Update MCP tool response handling to include 'pending' and 'cancelled' statuses. - Introduce new IPC channel for progress updates. - Enhance UI components to reflect tool call statuses, including pending and cancelled states. - Add localization for new status messages in multiple languages. - Refactor message handling logic to accommodate new tool response types. * fix: adjust alignment of action tool container in MessageTools component - Change justify-content from flex-end to flex-start to improve layout consistency. * feat: enhance tool confirmation handling and update related components - Introduced a new tool confirmation mechanism in userConfirmation.ts, allowing for individual tool confirmations. - Updated GeminiAPIClient and OpenAIResponseAPIClient to include tool configuration options. - Refactored MessageTools component to utilize new confirmation functions and improved styling. - Enhanced mcp-tools.ts to manage tool invocation and confirmation processes more effectively, ensuring real-time status updates. * refactor(McpToolChunkMiddleware): enhance tool execution handling and confirmation tracking - Updated createToolHandlingTransform to manage confirmed tool calls and results more effectively. - Refactored executeToolCalls and executeToolUseResponses to return both tool results and confirmed tool calls. - Adjusted buildParamsWithToolResults to utilize confirmed tool calls for building new request messages. - Improved error handling in messageThunk for tool call status updates, ensuring accurate block ID mapping. * feat(McpToolChunkMiddleware, ToolUseExtractionMiddleware, mcp-tools, userConfirmation): enhance tool execution and confirmation handling - Updated McpToolChunkMiddleware to execute tool calls and responses asynchronously, improving performance and response handling. - Enhanced ToolUseExtractionMiddleware to generate unique tool IDs for better tracking. - Modified parseToolUse function to accept a starting index for tool extraction. - Improved user confirmation handling with abort signal support to manage tool action confirmations more effectively. - Updated SYSTEM_PROMPT to clarify the use of multiple tools per message. * fix(tagExtraction): update test expectations for tag extraction results - Adjusted expected length of results from 7 to 9 to reflect changes in tag extraction logic. - Modified content assertions for specific tag contents to ensure accurate validation of extracted tags. * refactor(GeminiAPIClient, OpenAIResponseAPIClient): remove unused function calling configurations - Removed the unused FunctionCallingConfigMode from GeminiAPIClient to streamline the code. - Eliminated the parallel_tool_calls property from OpenAIResponseAPIClient, simplifying the tool call configuration. * feat(McpToolChunkMiddleware): enhance LLM response handling and tool call confirmation - Added notification to UI for new LLM response processing before recursive calls in createToolHandlingTransform. - Improved tool call confirmation logic in executeToolCalls to match tool IDs more accurately, enhancing response validation. * refactor(McpToolChunkMiddleware, ToolUseExtractionMiddleware, messageThunk): remove unnecessary console logs - Eliminated redundant console log statements in McpToolChunkMiddleware, ToolUseExtractionMiddleware, and messageThunk to clean up the code and improve performance. - Focused on enhancing readability and maintainability by reducing clutter in the logging output. * refactor(McpToolChunkMiddleware): remove redundant logging statements - Eliminated unnecessary logging in createToolHandlingTransform to streamline the code and enhance readability. - Focused on reducing clutter in the logging output while maintaining error handling functionality. * feat: enhance action button functionality with cancel and confirm options * refactor(AbortHandlerMiddleware, McpToolChunkMiddleware, ToolUseExtractionMiddleware, messageThunk): improve error handling and code clarity - Updated AbortHandlerMiddleware to skip abort status checks if an error chunk is received, enhancing error handling logic. - Replaced console.error with Logger.error in McpToolChunkMiddleware for consistent logging practices. - Refined ToolUseExtractionMiddleware to improve tool use extraction logic and ensure proper handling of tool_use tags. - Enhanced messageThunk to include initialPlaceholderBlockId in block ID checks, improving error state management. * refactor(ToolUseExtractionMiddleware): enhance tool use parsing logic with counter - Introduced a toolCounter to track the number of tool use responses processed. - Updated parseToolUse function calls to include the toolCounter, improving the extraction logic and ensuring accurate response handling. * feat(McpService, IpcChannel, MessageTools): implement tool abort functionality - Added Mcp_AbortTool channel to handle tool abortion requests. - Implemented abortTool method in McpService to manage active tool calls and provide logging. - Updated MessageTools component to include an abort button for ongoing tool calls, enhancing user control. - Modified API calls to support optional callId for better tracking of tool executions. - Added localization strings for tool abort messages in multiple languages. --------- Co-authored-by: Vaayne <liu.vaayne@gmail.com>
This commit is contained in:
parent
4f7ca3ede8
commit
fba6c1642d
@ -107,7 +107,7 @@
|
|||||||
"@langchain/community": "^0.3.36",
|
"@langchain/community": "^0.3.36",
|
||||||
"@langchain/ollama": "^0.2.1",
|
"@langchain/ollama": "^0.2.1",
|
||||||
"@mistralai/mistralai": "^1.6.0",
|
"@mistralai/mistralai": "^1.6.0",
|
||||||
"@modelcontextprotocol/sdk": "^1.11.4",
|
"@modelcontextprotocol/sdk": "^1.12.3",
|
||||||
"@mozilla/readability": "^0.6.0",
|
"@mozilla/readability": "^0.6.0",
|
||||||
"@notionhq/client": "^2.2.15",
|
"@notionhq/client": "^2.2.15",
|
||||||
"@playwright/test": "^1.52.0",
|
"@playwright/test": "^1.52.0",
|
||||||
|
|||||||
@ -74,6 +74,8 @@ export enum IpcChannel {
|
|||||||
Mcp_ServersChanged = 'mcp:servers-changed',
|
Mcp_ServersChanged = 'mcp:servers-changed',
|
||||||
Mcp_ServersUpdated = 'mcp:servers-updated',
|
Mcp_ServersUpdated = 'mcp:servers-updated',
|
||||||
Mcp_CheckConnectivity = 'mcp:check-connectivity',
|
Mcp_CheckConnectivity = 'mcp:check-connectivity',
|
||||||
|
Mcp_SetProgress = 'mcp:set-progress',
|
||||||
|
Mcp_AbortTool = 'mcp:abort-tool',
|
||||||
|
|
||||||
// Python
|
// Python
|
||||||
Python_Execute = 'python:execute',
|
Python_Execute = 'python:execute',
|
||||||
|
|||||||
@ -501,6 +501,10 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource)
|
ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource)
|
||||||
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo)
|
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo)
|
||||||
ipcMain.handle(IpcChannel.Mcp_CheckConnectivity, mcpService.checkMcpConnectivity)
|
ipcMain.handle(IpcChannel.Mcp_CheckConnectivity, mcpService.checkMcpConnectivity)
|
||||||
|
ipcMain.handle(IpcChannel.Mcp_AbortTool, mcpService.abortTool)
|
||||||
|
ipcMain.handle(IpcChannel.Mcp_SetProgress, (_, progress: number) => {
|
||||||
|
mainWindow.webContents.send('mcp-progress', progress)
|
||||||
|
})
|
||||||
|
|
||||||
// Register Python execution handler
|
// Register Python execution handler
|
||||||
ipcMain.handle(
|
ipcMain.handle(
|
||||||
|
|||||||
@ -28,6 +28,7 @@ import { app } from 'electron'
|
|||||||
import Logger from 'electron-log'
|
import Logger from 'electron-log'
|
||||||
import { EventEmitter } from 'events'
|
import { EventEmitter } from 'events'
|
||||||
import { memoize } from 'lodash'
|
import { memoize } from 'lodash'
|
||||||
|
import { v4 as uuidv4 } from 'uuid'
|
||||||
|
|
||||||
import { CacheService } from './CacheService'
|
import { CacheService } from './CacheService'
|
||||||
import { CallBackServer } from './mcp/oauth/callback'
|
import { CallBackServer } from './mcp/oauth/callback'
|
||||||
@ -71,6 +72,7 @@ function withCache<T extends unknown[], R>(
|
|||||||
class McpService {
|
class McpService {
|
||||||
private clients: Map<string, Client> = new Map()
|
private clients: Map<string, Client> = new Map()
|
||||||
private pendingClients: Map<string, Promise<Client>> = new Map()
|
private pendingClients: Map<string, Promise<Client>> = new Map()
|
||||||
|
private activeToolCalls: Map<string, AbortController> = new Map()
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
this.initClient = this.initClient.bind(this)
|
this.initClient = this.initClient.bind(this)
|
||||||
@ -84,6 +86,7 @@ class McpService {
|
|||||||
this.removeServer = this.removeServer.bind(this)
|
this.removeServer = this.removeServer.bind(this)
|
||||||
this.restartServer = this.restartServer.bind(this)
|
this.restartServer = this.restartServer.bind(this)
|
||||||
this.stopServer = this.stopServer.bind(this)
|
this.stopServer = this.stopServer.bind(this)
|
||||||
|
this.abortTool = this.abortTool.bind(this)
|
||||||
this.cleanup = this.cleanup.bind(this)
|
this.cleanup = this.cleanup.bind(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -455,10 +458,14 @@ class McpService {
|
|||||||
*/
|
*/
|
||||||
public async callTool(
|
public async callTool(
|
||||||
_: Electron.IpcMainInvokeEvent,
|
_: Electron.IpcMainInvokeEvent,
|
||||||
{ server, name, args }: { server: MCPServer; name: string; args: any }
|
{ server, name, args, callId }: { server: MCPServer; name: string; args: any; callId?: string }
|
||||||
): Promise<MCPCallToolResponse> {
|
): Promise<MCPCallToolResponse> {
|
||||||
|
const toolCallId = callId || uuidv4()
|
||||||
|
const abortController = new AbortController()
|
||||||
|
this.activeToolCalls.set(toolCallId, abortController)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Logger.info('[MCP] Calling:', server.name, name, args)
|
Logger.info('[MCP] Calling:', server.name, name, args, 'callId:', toolCallId)
|
||||||
if (typeof args === 'string') {
|
if (typeof args === 'string') {
|
||||||
try {
|
try {
|
||||||
args = JSON.parse(args)
|
args = JSON.parse(args)
|
||||||
@ -468,12 +475,19 @@ class McpService {
|
|||||||
}
|
}
|
||||||
const client = await this.initClient(server)
|
const client = await this.initClient(server)
|
||||||
const result = await client.callTool({ name, arguments: args }, undefined, {
|
const result = await client.callTool({ name, arguments: args }, undefined, {
|
||||||
timeout: server.timeout ? server.timeout * 1000 : 60000 // Default timeout of 1 minute
|
onprogress: (process) => {
|
||||||
|
console.log('[MCP] Progress:', process.progress / (process.total || 1))
|
||||||
|
window.api.mcp.setProgress(process.progress / (process.total || 1))
|
||||||
|
},
|
||||||
|
timeout: server.timeout ? server.timeout * 1000 : 60000, // Default timeout of 1 minute
|
||||||
|
signal: this.activeToolCalls.get(toolCallId)?.signal
|
||||||
})
|
})
|
||||||
return result as MCPCallToolResponse
|
return result as MCPCallToolResponse
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
Logger.error(`[MCP] Error calling tool ${name} on ${server.name}:`, error)
|
Logger.error(`[MCP] Error calling tool ${name} on ${server.name}:`, error)
|
||||||
throw error
|
throw error
|
||||||
|
} finally {
|
||||||
|
this.activeToolCalls.delete(toolCallId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -664,6 +678,20 @@ class McpService {
|
|||||||
delete env.http_proxy
|
delete env.http_proxy
|
||||||
delete env.https_proxy
|
delete env.https_proxy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 实现 abortTool 方法
|
||||||
|
public async abortTool(_: Electron.IpcMainInvokeEvent, callId: string) {
|
||||||
|
const activeToolCall = this.activeToolCalls.get(callId)
|
||||||
|
if (activeToolCall) {
|
||||||
|
activeToolCall.abort()
|
||||||
|
this.activeToolCalls.delete(callId)
|
||||||
|
Logger.info(`[MCP] Aborted tool call: ${callId}`)
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
Logger.warn(`[MCP] No active tool call found for callId: ${callId}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export default new McpService()
|
export default new McpService()
|
||||||
|
|||||||
@ -228,8 +228,8 @@ const api = {
|
|||||||
restartServer: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_RestartServer, server),
|
restartServer: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_RestartServer, server),
|
||||||
stopServer: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_StopServer, server),
|
stopServer: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_StopServer, server),
|
||||||
listTools: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_ListTools, server),
|
listTools: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_ListTools, server),
|
||||||
callTool: ({ server, name, args }: { server: MCPServer; name: string; args: any }) =>
|
callTool: ({ server, name, args, callId }: { server: MCPServer; name: string; args: any; callId?: string }) =>
|
||||||
ipcRenderer.invoke(IpcChannel.Mcp_CallTool, { server, name, args }),
|
ipcRenderer.invoke(IpcChannel.Mcp_CallTool, { server, name, args, callId }),
|
||||||
listPrompts: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_ListPrompts, server),
|
listPrompts: (server: MCPServer) => ipcRenderer.invoke(IpcChannel.Mcp_ListPrompts, server),
|
||||||
getPrompt: ({ server, name, args }: { server: MCPServer; name: string; args?: Record<string, any> }) =>
|
getPrompt: ({ server, name, args }: { server: MCPServer; name: string; args?: Record<string, any> }) =>
|
||||||
ipcRenderer.invoke(IpcChannel.Mcp_GetPrompt, { server, name, args }),
|
ipcRenderer.invoke(IpcChannel.Mcp_GetPrompt, { server, name, args }),
|
||||||
@ -237,7 +237,9 @@ const api = {
|
|||||||
getResource: ({ server, uri }: { server: MCPServer; uri: string }) =>
|
getResource: ({ server, uri }: { server: MCPServer; uri: string }) =>
|
||||||
ipcRenderer.invoke(IpcChannel.Mcp_GetResource, { server, uri }),
|
ipcRenderer.invoke(IpcChannel.Mcp_GetResource, { server, uri }),
|
||||||
getInstallInfo: () => ipcRenderer.invoke(IpcChannel.Mcp_GetInstallInfo),
|
getInstallInfo: () => ipcRenderer.invoke(IpcChannel.Mcp_GetInstallInfo),
|
||||||
checkMcpConnectivity: (server: any) => ipcRenderer.invoke(IpcChannel.Mcp_CheckConnectivity, server)
|
checkMcpConnectivity: (server: any) => ipcRenderer.invoke(IpcChannel.Mcp_CheckConnectivity, server),
|
||||||
|
abortTool: (callId: string) => ipcRenderer.invoke(IpcChannel.Mcp_AbortTool, callId),
|
||||||
|
setProgress: (progress: number) => ipcRenderer.invoke(IpcChannel.Mcp_SetProgress, progress)
|
||||||
},
|
},
|
||||||
python: {
|
python: {
|
||||||
execute: (script: string, context?: Record<string, any>, timeout?: number) =>
|
execute: (script: string, context?: Record<string, any>, timeout?: number) =>
|
||||||
|
|||||||
@ -67,7 +67,12 @@ export const AbortHandlerMiddleware: CompletionsMiddleware =
|
|||||||
const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
|
const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
|
||||||
new TransformStream<Chunk, Chunk | ErrorChunk>({
|
new TransformStream<Chunk, Chunk | ErrorChunk>({
|
||||||
transform(chunk, controller) {
|
transform(chunk, controller) {
|
||||||
// 检查 abort 状态
|
// 如果已经收到错误块,不再检查 abort 状态
|
||||||
|
if (chunk.type === ChunkType.ERROR) {
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if (abortSignal?.aborted) {
|
if (abortSignal?.aborted) {
|
||||||
// 转换为 ErrorChunk
|
// 转换为 ErrorChunk
|
||||||
const errorChunk: ErrorChunk = {
|
const errorChunk: ErrorChunk = {
|
||||||
|
|||||||
@ -136,7 +136,6 @@ function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: Generi
|
|||||||
Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
|
Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
|
||||||
}
|
}
|
||||||
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
|
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||||
Logger.debug(`[${MIDDLEWARE_NAME}] LLM_RESPONSE_COMPLETE chunk received:`, ctx._internal)
|
|
||||||
// 从LLM_RESPONSE_COMPLETE chunk中提取usage数据
|
// 从LLM_RESPONSE_COMPLETE chunk中提取usage数据
|
||||||
if (chunk.response?.usage) {
|
if (chunk.response?.usage) {
|
||||||
accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)
|
accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)
|
||||||
|
|||||||
@ -89,6 +89,11 @@ function createToolHandlingTransform(
|
|||||||
let hasToolUseResponses = false
|
let hasToolUseResponses = false
|
||||||
let streamEnded = false
|
let streamEnded = false
|
||||||
|
|
||||||
|
// 存储已执行的工具结果
|
||||||
|
const executedToolResults: SdkMessageParam[] = []
|
||||||
|
const executedToolCalls: SdkToolCall[] = []
|
||||||
|
const executionPromises: Promise<void>[] = []
|
||||||
|
|
||||||
return new TransformStream({
|
return new TransformStream({
|
||||||
async transform(chunk: GenericChunk, controller) {
|
async transform(chunk: GenericChunk, controller) {
|
||||||
try {
|
try {
|
||||||
@ -98,22 +103,64 @@ function createToolHandlingTransform(
|
|||||||
|
|
||||||
// 1. 处理Function Call方式的工具调用
|
// 1. 处理Function Call方式的工具调用
|
||||||
if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
|
if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
|
||||||
toolCalls.push(...createdChunk.tool_calls)
|
|
||||||
hasToolCalls = true
|
hasToolCalls = true
|
||||||
|
|
||||||
|
for (const toolCall of createdChunk.tool_calls) {
|
||||||
|
toolCalls.push(toolCall)
|
||||||
|
|
||||||
|
const executionPromise = (async () => {
|
||||||
|
try {
|
||||||
|
const result = await executeToolCalls(
|
||||||
|
ctx,
|
||||||
|
[toolCall],
|
||||||
|
mcpTools,
|
||||||
|
allToolResponses,
|
||||||
|
currentParams.onChunk,
|
||||||
|
currentParams.assistant.model!
|
||||||
|
)
|
||||||
|
|
||||||
|
// 缓存执行结果
|
||||||
|
executedToolResults.push(...result.toolResults)
|
||||||
|
executedToolCalls.push(...result.confirmedToolCalls)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`🔧 [${MIDDLEWARE_NAME}] Error executing tool call asynchronously:`, error)
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
executionPromises.push(executionPromise)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 处理Tool Use方式的工具调用
|
// 2. 处理Tool Use方式的工具调用
|
||||||
if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
|
if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
|
||||||
toolUseResponses.push(...createdChunk.tool_use_responses)
|
|
||||||
hasToolUseResponses = true
|
hasToolUseResponses = true
|
||||||
|
for (const toolUseResponse of createdChunk.tool_use_responses) {
|
||||||
|
toolUseResponses.push(toolUseResponse)
|
||||||
|
const executionPromise = (async () => {
|
||||||
|
try {
|
||||||
|
const result = await executeToolUseResponses(
|
||||||
|
ctx,
|
||||||
|
[toolUseResponse], // 单个执行
|
||||||
|
mcpTools,
|
||||||
|
allToolResponses,
|
||||||
|
currentParams.onChunk,
|
||||||
|
currentParams.assistant.model!
|
||||||
|
)
|
||||||
|
|
||||||
|
// 缓存执行结果
|
||||||
|
executedToolResults.push(...result.toolResults)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`🔧 [${MIDDLEWARE_NAME}] Error executing tool use response asynchronously:`, error)
|
||||||
|
// 错误时不影响其他工具的执行
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
executionPromises.push(executionPromise)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
// 不转发MCP工具进展chunks,避免重复处理
|
controller.enqueue(chunk)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转发其他所有chunk
|
|
||||||
controller.enqueue(chunk)
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
|
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
|
||||||
controller.error(error)
|
controller.error(error)
|
||||||
@ -121,43 +168,33 @@ function createToolHandlingTransform(
|
|||||||
},
|
},
|
||||||
|
|
||||||
async flush(controller) {
|
async flush(controller) {
|
||||||
const shouldExecuteToolCalls = hasToolCalls && toolCalls.length > 0
|
// 在流结束时等待所有异步工具执行完成,然后进行递归调用
|
||||||
const shouldExecuteToolUseResponses = hasToolUseResponses && toolUseResponses.length > 0
|
if (!streamEnded && (hasToolCalls || hasToolUseResponses)) {
|
||||||
|
|
||||||
if (!streamEnded && (shouldExecuteToolCalls || shouldExecuteToolUseResponses)) {
|
|
||||||
streamEnded = true
|
streamEnded = true
|
||||||
|
|
||||||
try {
|
try {
|
||||||
let toolResult: SdkMessageParam[] = []
|
await Promise.all(executionPromises)
|
||||||
|
if (executedToolResults.length > 0) {
|
||||||
if (shouldExecuteToolCalls) {
|
|
||||||
toolResult = await executeToolCalls(
|
|
||||||
ctx,
|
|
||||||
toolCalls,
|
|
||||||
mcpTools,
|
|
||||||
allToolResponses,
|
|
||||||
currentParams.onChunk,
|
|
||||||
currentParams.assistant.model!
|
|
||||||
)
|
|
||||||
} else if (shouldExecuteToolUseResponses) {
|
|
||||||
toolResult = await executeToolUseResponses(
|
|
||||||
ctx,
|
|
||||||
toolUseResponses,
|
|
||||||
mcpTools,
|
|
||||||
allToolResponses,
|
|
||||||
currentParams.onChunk,
|
|
||||||
currentParams.assistant.model!
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (toolResult.length > 0) {
|
|
||||||
const output = ctx._internal.toolProcessingState?.output
|
const output = ctx._internal.toolProcessingState?.output
|
||||||
|
const newParams = buildParamsWithToolResults(
|
||||||
|
ctx,
|
||||||
|
currentParams,
|
||||||
|
output,
|
||||||
|
executedToolResults,
|
||||||
|
executedToolCalls
|
||||||
|
)
|
||||||
|
|
||||||
|
// 在递归调用前通知UI开始新的LLM响应处理
|
||||||
|
if (currentParams.onChunk) {
|
||||||
|
currentParams.onChunk({
|
||||||
|
type: ChunkType.LLM_RESPONSE_CREATED
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
const newParams = buildParamsWithToolResults(ctx, currentParams, output, toolResult, toolCalls)
|
|
||||||
await executeWithToolHandling(newParams, depth + 1)
|
await executeWithToolHandling(newParams, depth + 1)
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
|
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
|
||||||
controller.error(error)
|
controller.error(error)
|
||||||
} finally {
|
} finally {
|
||||||
hasToolCalls = false
|
hasToolCalls = false
|
||||||
@ -178,8 +215,7 @@ async function executeToolCalls(
|
|||||||
allToolResponses: MCPToolResponse[],
|
allToolResponses: MCPToolResponse[],
|
||||||
onChunk: CompletionsParams['onChunk'],
|
onChunk: CompletionsParams['onChunk'],
|
||||||
model: Model
|
model: Model
|
||||||
): Promise<SdkMessageParam[]> {
|
): Promise<{ toolResults: SdkMessageParam[]; confirmedToolCalls: SdkToolCall[] }> {
|
||||||
// 转换为MCPToolResponse格式
|
|
||||||
const mcpToolResponses: ToolCallResponse[] = toolCalls
|
const mcpToolResponses: ToolCallResponse[] = toolCalls
|
||||||
.map((toolCall) => {
|
.map((toolCall) => {
|
||||||
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
|
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||||
@ -192,11 +228,11 @@ async function executeToolCalls(
|
|||||||
|
|
||||||
if (mcpToolResponses.length === 0) {
|
if (mcpToolResponses.length === 0) {
|
||||||
console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`)
|
console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`)
|
||||||
return []
|
return { toolResults: [], confirmedToolCalls: [] }
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用现有的parseAndCallTools函数执行工具
|
// 使用现有的parseAndCallTools函数执行工具
|
||||||
const toolResults = await parseAndCallTools(
|
const { toolResults, confirmedToolResponses } = await parseAndCallTools(
|
||||||
mcpToolResponses,
|
mcpToolResponses,
|
||||||
allToolResponses,
|
allToolResponses,
|
||||||
onChunk,
|
onChunk,
|
||||||
@ -204,10 +240,24 @@ async function executeToolCalls(
|
|||||||
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||||
},
|
},
|
||||||
model,
|
model,
|
||||||
mcpTools
|
mcpTools,
|
||||||
|
ctx._internal?.flowControl?.abortSignal
|
||||||
)
|
)
|
||||||
|
|
||||||
return toolResults
|
// 找出已确认工具对应的原始toolCalls
|
||||||
|
const confirmedToolCalls = toolCalls.filter((toolCall) => {
|
||||||
|
return confirmedToolResponses.find((confirmed) => {
|
||||||
|
// 根据不同的ID字段匹配原始toolCall
|
||||||
|
return (
|
||||||
|
('name' in toolCall &&
|
||||||
|
(toolCall.name?.includes(confirmed.tool.name) || toolCall.name?.includes(confirmed.tool.id))) ||
|
||||||
|
confirmed.tool.name === toolCall.id ||
|
||||||
|
confirmed.tool.id === toolCall.id
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return { toolResults, confirmedToolCalls }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -221,9 +271,9 @@ async function executeToolUseResponses(
|
|||||||
allToolResponses: MCPToolResponse[],
|
allToolResponses: MCPToolResponse[],
|
||||||
onChunk: CompletionsParams['onChunk'],
|
onChunk: CompletionsParams['onChunk'],
|
||||||
model: Model
|
model: Model
|
||||||
): Promise<SdkMessageParam[]> {
|
): Promise<{ toolResults: SdkMessageParam[] }> {
|
||||||
// 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse
|
// 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse
|
||||||
const toolResults = await parseAndCallTools(
|
const { toolResults } = await parseAndCallTools(
|
||||||
toolUseResponses,
|
toolUseResponses,
|
||||||
allToolResponses,
|
allToolResponses,
|
||||||
onChunk,
|
onChunk,
|
||||||
@ -231,10 +281,11 @@ async function executeToolUseResponses(
|
|||||||
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||||
},
|
},
|
||||||
model,
|
model,
|
||||||
mcpTools
|
mcpTools,
|
||||||
|
ctx._internal?.flowControl?.abortSignal
|
||||||
)
|
)
|
||||||
|
|
||||||
return toolResults
|
return { toolResults }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -245,7 +296,7 @@ function buildParamsWithToolResults(
|
|||||||
currentParams: CompletionsParams,
|
currentParams: CompletionsParams,
|
||||||
output: SdkRawOutput | string | undefined,
|
output: SdkRawOutput | string | undefined,
|
||||||
toolResults: SdkMessageParam[],
|
toolResults: SdkMessageParam[],
|
||||||
toolCalls: SdkToolCall[]
|
confirmedToolCalls: SdkToolCall[]
|
||||||
): CompletionsParams {
|
): CompletionsParams {
|
||||||
// 获取当前已经转换好的reqMessages,如果没有则使用原始messages
|
// 获取当前已经转换好的reqMessages,如果没有则使用原始messages
|
||||||
const currentReqMessages = getCurrentReqMessages(ctx)
|
const currentReqMessages = getCurrentReqMessages(ctx)
|
||||||
@ -253,7 +304,7 @@ function buildParamsWithToolResults(
|
|||||||
const apiClient = ctx.apiClientInstance
|
const apiClient = ctx.apiClientInstance
|
||||||
|
|
||||||
// 从回复中构建助手消息
|
// 从回复中构建助手消息
|
||||||
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
|
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, confirmedToolCalls)
|
||||||
|
|
||||||
if (output && ctx._internal.toolProcessingState) {
|
if (output && ctx._internal.toolProcessingState) {
|
||||||
ctx._internal.toolProcessingState.output = undefined
|
ctx._internal.toolProcessingState.output = undefined
|
||||||
|
|||||||
@ -22,7 +22,8 @@ const TOOL_USE_TAG_CONFIG: TagConfig = {
|
|||||||
* 1. 从文本流中检测并提取 <tool_use></tool_use> 标签
|
* 1. 从文本流中检测并提取 <tool_use></tool_use> 标签
|
||||||
* 2. 解析工具调用信息并转换为 ToolUseResponse 格式
|
* 2. 解析工具调用信息并转换为 ToolUseResponse 格式
|
||||||
* 3. 生成 MCP_TOOL_CREATED chunk 供 McpToolChunkMiddleware 处理
|
* 3. 生成 MCP_TOOL_CREATED chunk 供 McpToolChunkMiddleware 处理
|
||||||
* 4. 清理文本流,移除工具使用标签但保留正常文本
|
* 4. 丢弃 tool_use 之后的所有内容(助手幻觉)
|
||||||
|
* 5. 清理文本流,移除工具使用标签但保留正常文本
|
||||||
*
|
*
|
||||||
* 注意:此中间件只负责提取和转换,实际工具调用由 McpToolChunkMiddleware 处理
|
* 注意:此中间件只负责提取和转换,实际工具调用由 McpToolChunkMiddleware 处理
|
||||||
*/
|
*/
|
||||||
@ -32,13 +33,10 @@ export const ToolUseExtractionMiddleware: CompletionsMiddleware =
|
|||||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
const mcpTools = params.mcpTools || []
|
const mcpTools = params.mcpTools || []
|
||||||
|
|
||||||
// 如果没有工具,直接调用下一个中间件
|
|
||||||
if (!mcpTools || mcpTools.length === 0) return next(ctx, params)
|
if (!mcpTools || mcpTools.length === 0) return next(ctx, params)
|
||||||
|
|
||||||
// 调用下游中间件
|
|
||||||
const result = await next(ctx, params)
|
const result = await next(ctx, params)
|
||||||
|
|
||||||
// 响应后处理:处理工具使用标签提取
|
|
||||||
if (result.stream) {
|
if (result.stream) {
|
||||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||||
|
|
||||||
@ -60,7 +58,9 @@ function createToolUseExtractionTransform(
|
|||||||
_ctx: CompletionsContext,
|
_ctx: CompletionsContext,
|
||||||
mcpTools: MCPTool[]
|
mcpTools: MCPTool[]
|
||||||
): TransformStream<GenericChunk, GenericChunk> {
|
): TransformStream<GenericChunk, GenericChunk> {
|
||||||
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
const toolUseExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
||||||
|
let hasAnyToolUse = false
|
||||||
|
let toolCounter = 0
|
||||||
|
|
||||||
return new TransformStream({
|
return new TransformStream({
|
||||||
async transform(chunk: GenericChunk, controller) {
|
async transform(chunk: GenericChunk, controller) {
|
||||||
@ -68,30 +68,37 @@ function createToolUseExtractionTransform(
|
|||||||
// 处理文本内容,检测工具使用标签
|
// 处理文本内容,检测工具使用标签
|
||||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||||
const textChunk = chunk as TextDeltaChunk
|
const textChunk = chunk as TextDeltaChunk
|
||||||
const extractionResults = tagExtractor.processText(textChunk.text)
|
|
||||||
|
|
||||||
for (const result of extractionResults) {
|
// 处理 tool_use 标签
|
||||||
|
const toolUseResults = toolUseExtractor.processText(textChunk.text)
|
||||||
|
|
||||||
|
for (const result of toolUseResults) {
|
||||||
if (result.complete && result.tagContentExtracted) {
|
if (result.complete && result.tagContentExtracted) {
|
||||||
// 提取到完整的工具使用内容,解析并转换为 SDK ToolCall 格式
|
// 提取到完整的工具使用内容,解析并转换为 SDK ToolCall 格式
|
||||||
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools)
|
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools, toolCounter)
|
||||||
|
toolCounter += toolUseResponses.length
|
||||||
|
|
||||||
if (toolUseResponses.length > 0) {
|
if (toolUseResponses.length > 0) {
|
||||||
// 生成 MCP_TOOL_CREATED chunk,复用现有的处理流程
|
// 生成 MCP_TOOL_CREATED chunk
|
||||||
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
||||||
type: ChunkType.MCP_TOOL_CREATED,
|
type: ChunkType.MCP_TOOL_CREATED,
|
||||||
tool_use_responses: toolUseResponses
|
tool_use_responses: toolUseResponses
|
||||||
}
|
}
|
||||||
controller.enqueue(mcpToolCreatedChunk)
|
controller.enqueue(mcpToolCreatedChunk)
|
||||||
|
|
||||||
|
// 标记已有工具调用
|
||||||
|
hasAnyToolUse = true
|
||||||
}
|
}
|
||||||
} else if (!result.isTagContent && result.content) {
|
} else if (!result.isTagContent && result.content) {
|
||||||
// 发送标签外的正常文本内容
|
if (!hasAnyToolUse) {
|
||||||
const cleanTextChunk: TextDeltaChunk = {
|
const cleanTextChunk: TextDeltaChunk = {
|
||||||
...textChunk,
|
...textChunk,
|
||||||
text: result.content
|
text: result.content
|
||||||
|
}
|
||||||
|
controller.enqueue(cleanTextChunk)
|
||||||
}
|
}
|
||||||
controller.enqueue(cleanTextChunk)
|
|
||||||
}
|
}
|
||||||
// 注意:标签内的内容不会作为TEXT_DELTA转发,避免重复显示
|
// tool_use 标签内的内容不转发,避免重复显示
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -105,16 +112,17 @@ function createToolUseExtractionTransform(
|
|||||||
},
|
},
|
||||||
|
|
||||||
async flush(controller) {
|
async flush(controller) {
|
||||||
// 检查是否有未完成的标签内容
|
// 检查是否有未完成的 tool_use 标签内容
|
||||||
const finalResult = tagExtractor.finalize()
|
const finalToolUseResult = toolUseExtractor.finalize()
|
||||||
if (finalResult && finalResult.tagContentExtracted) {
|
if (finalToolUseResult && finalToolUseResult.tagContentExtracted) {
|
||||||
const toolUseResponses = parseToolUse(finalResult.tagContentExtracted, mcpTools)
|
const toolUseResponses = parseToolUse(finalToolUseResult.tagContentExtracted, mcpTools, toolCounter)
|
||||||
if (toolUseResponses.length > 0) {
|
if (toolUseResponses.length > 0) {
|
||||||
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
||||||
type: ChunkType.MCP_TOOL_CREATED,
|
type: ChunkType.MCP_TOOL_CREATED,
|
||||||
tool_use_responses: toolUseResponses
|
tool_use_responses: toolUseResponses
|
||||||
}
|
}
|
||||||
controller.enqueue(mcpToolCreatedChunk)
|
controller.enqueue(mcpToolCreatedChunk)
|
||||||
|
hasAnyToolUse = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -706,8 +706,12 @@
|
|||||||
"success.yuque.export": "Successfully exported to Yuque",
|
"success.yuque.export": "Successfully exported to Yuque",
|
||||||
"switch.disabled": "Please wait for the current reply to complete",
|
"switch.disabled": "Please wait for the current reply to complete",
|
||||||
"tools": {
|
"tools": {
|
||||||
|
"pending": "Pending",
|
||||||
|
"cancelled": "Cancelled",
|
||||||
"completed": "Completed",
|
"completed": "Completed",
|
||||||
"invoking": "Invoking",
|
"invoking": "Invoking",
|
||||||
|
"aborted": "Tool call aborted",
|
||||||
|
"abort_failed": "Tool call abort failed",
|
||||||
"error": "Error occurred",
|
"error": "Error occurred",
|
||||||
"raw": "Raw",
|
"raw": "Raw",
|
||||||
"preview": "Preview"
|
"preview": "Preview"
|
||||||
|
|||||||
@ -706,9 +706,13 @@
|
|||||||
"tools": {
|
"tools": {
|
||||||
"completed": "完了",
|
"completed": "完了",
|
||||||
"invoking": "呼び出し中",
|
"invoking": "呼び出し中",
|
||||||
|
"aborted": "ツール呼び出し中断",
|
||||||
|
"abort_failed": "ツール呼び出し中断失敗",
|
||||||
"error": "エラーが発生しました",
|
"error": "エラーが発生しました",
|
||||||
"raw": "生データ",
|
"raw": "生データ",
|
||||||
"preview": "プレビュー"
|
"preview": "プレビュー",
|
||||||
|
"pending": "保留中",
|
||||||
|
"cancelled": "キャンセル"
|
||||||
},
|
},
|
||||||
"topic.added": "新しいトピックが追加されました",
|
"topic.added": "新しいトピックが追加されました",
|
||||||
"upgrade.success.button": "再起動",
|
"upgrade.success.button": "再起動",
|
||||||
|
|||||||
@ -705,11 +705,15 @@
|
|||||||
"success.yuque.export": "Успешный экспорт в Yuque",
|
"success.yuque.export": "Успешный экспорт в Yuque",
|
||||||
"switch.disabled": "Пожалуйста, дождитесь завершения текущего ответа",
|
"switch.disabled": "Пожалуйста, дождитесь завершения текущего ответа",
|
||||||
"tools": {
|
"tools": {
|
||||||
|
"aborted": "Вызов инструмента прерван",
|
||||||
|
"abort_failed": "Вызов инструмента прерван",
|
||||||
"completed": "Завершено",
|
"completed": "Завершено",
|
||||||
"invoking": "Вызов",
|
"invoking": "Вызов",
|
||||||
"error": "Произошла ошибка",
|
"error": "Произошла ошибка",
|
||||||
"raw": "Исходный",
|
"raw": "Исходный",
|
||||||
"preview": "Предпросмотр"
|
"preview": "Предпросмотр",
|
||||||
|
"pending": "Ожидание",
|
||||||
|
"cancelled": "Отменено"
|
||||||
},
|
},
|
||||||
"topic.added": "Новый топик добавлен",
|
"topic.added": "Новый топик добавлен",
|
||||||
"upgrade.success.button": "Перезапустить",
|
"upgrade.success.button": "Перезапустить",
|
||||||
|
|||||||
@ -706,8 +706,12 @@
|
|||||||
"success.yuque.export": "成功导出到语雀",
|
"success.yuque.export": "成功导出到语雀",
|
||||||
"switch.disabled": "请等待当前回复完成后操作",
|
"switch.disabled": "请等待当前回复完成后操作",
|
||||||
"tools": {
|
"tools": {
|
||||||
|
"pending": "等待中",
|
||||||
|
"cancelled": "已取消",
|
||||||
"completed": "已完成",
|
"completed": "已完成",
|
||||||
"invoking": "调用中",
|
"invoking": "调用中",
|
||||||
|
"aborted": "工具调用已中断",
|
||||||
|
"abort_failed": "工具调用中断失败",
|
||||||
"error": "发生错误",
|
"error": "发生错误",
|
||||||
"raw": "原始",
|
"raw": "原始",
|
||||||
"preview": "预览"
|
"preview": "预览"
|
||||||
|
|||||||
@ -706,11 +706,15 @@
|
|||||||
"success.yuque.export": "成功匯出到語雀",
|
"success.yuque.export": "成功匯出到語雀",
|
||||||
"switch.disabled": "請等待當前回覆完成",
|
"switch.disabled": "請等待當前回覆完成",
|
||||||
"tools": {
|
"tools": {
|
||||||
|
"aborted": "工具調用已中斷",
|
||||||
|
"abort_failed": "工具調用中斷失敗",
|
||||||
"completed": "已完成",
|
"completed": "已完成",
|
||||||
"invoking": "調用中",
|
"invoking": "調用中",
|
||||||
"error": "發生錯誤",
|
"error": "發生錯誤",
|
||||||
"raw": "原始碼",
|
"raw": "原始碼",
|
||||||
"preview": "預覽"
|
"preview": "預覽",
|
||||||
|
"pending": "等待中",
|
||||||
|
"cancelled": "已取消"
|
||||||
},
|
},
|
||||||
"topic.added": "新話題已新增",
|
"topic.added": "新話題已新增",
|
||||||
"upgrade.success.button": "重新啟動",
|
"upgrade.success.button": "重新啟動",
|
||||||
|
|||||||
@ -1,8 +1,12 @@
|
|||||||
import { CheckOutlined, ExpandOutlined, LoadingOutlined, WarningOutlined } from '@ant-design/icons'
|
import { CheckOutlined, CloseOutlined, LoadingOutlined, WarningOutlined } from '@ant-design/icons'
|
||||||
import { useCodeStyle } from '@renderer/context/CodeStyleProvider'
|
import { useCodeStyle } from '@renderer/context/CodeStyleProvider'
|
||||||
import { useSettings } from '@renderer/hooks/useSettings'
|
import { useSettings } from '@renderer/hooks/useSettings'
|
||||||
import type { ToolMessageBlock } from '@renderer/types/newMessage'
|
import type { ToolMessageBlock } from '@renderer/types/newMessage'
|
||||||
import { Collapse, message as antdMessage, Modal, Tabs, Tooltip } from 'antd'
|
import { cancelToolAction, confirmToolAction } from '@renderer/utils/userConfirmation'
|
||||||
|
import { Collapse, message as antdMessage, Tooltip } from 'antd'
|
||||||
|
import { message } from 'antd'
|
||||||
|
import Logger from 'electron-log/renderer'
|
||||||
|
import { PauseCircle } from 'lucide-react'
|
||||||
import { FC, memo, useEffect, useMemo, useState } from 'react'
|
import { FC, memo, useEffect, useMemo, useState } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import styled from 'styled-components'
|
import styled from 'styled-components'
|
||||||
@ -14,12 +18,24 @@ interface Props {
|
|||||||
const MessageTools: FC<Props> = ({ block }) => {
|
const MessageTools: FC<Props> = ({ block }) => {
|
||||||
const [activeKeys, setActiveKeys] = useState<string[]>([])
|
const [activeKeys, setActiveKeys] = useState<string[]>([])
|
||||||
const [copiedMap, setCopiedMap] = useState<Record<string, boolean>>({})
|
const [copiedMap, setCopiedMap] = useState<Record<string, boolean>>({})
|
||||||
const [expandedResponse, setExpandedResponse] = useState<{ content: string; title: string } | null>(null)
|
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const { messageFont, fontSize } = useSettings()
|
const { messageFont, fontSize } = useSettings()
|
||||||
|
|
||||||
const toolResponse = block.metadata?.rawMcpToolResponse
|
const toolResponse = block.metadata?.rawMcpToolResponse
|
||||||
|
|
||||||
|
const { id, tool, status, response } = toolResponse!
|
||||||
|
|
||||||
|
const isPending = status === 'pending'
|
||||||
|
const isInvoking = status === 'invoking'
|
||||||
|
const isDone = status === 'done'
|
||||||
|
|
||||||
|
const argsString = useMemo(() => {
|
||||||
|
if (toolResponse?.arguments) {
|
||||||
|
return JSON.stringify(toolResponse.arguments, null, 2)
|
||||||
|
}
|
||||||
|
return 'No arguments'
|
||||||
|
}, [toolResponse])
|
||||||
|
|
||||||
const resultString = useMemo(() => {
|
const resultString = useMemo(() => {
|
||||||
try {
|
try {
|
||||||
return JSON.stringify(
|
return JSON.stringify(
|
||||||
@ -50,13 +66,34 @@ const MessageTools: FC<Props> = ({ block }) => {
|
|||||||
setActiveKeys(Array.isArray(keys) ? keys : [keys])
|
setActiveKeys(Array.isArray(keys) ? keys : [keys])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleConfirmTool = () => {
|
||||||
|
confirmToolAction(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleCancelTool = () => {
|
||||||
|
cancelToolAction(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleAbortTool = async () => {
|
||||||
|
if (toolResponse?.id) {
|
||||||
|
try {
|
||||||
|
const success = await window.api.mcp.abortTool(toolResponse.id)
|
||||||
|
if (success) {
|
||||||
|
message.success({ content: t('message.tools.aborted'), key: 'abort-tool' })
|
||||||
|
} else {
|
||||||
|
message.error({ content: t('message.tools.abort_failed'), key: 'abort-tool' })
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error('Failed to abort tool:', error)
|
||||||
|
message.error({ content: t('message.tools.abort_failed'), key: 'abort-tool' })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Format tool responses for collapse items
|
// Format tool responses for collapse items
|
||||||
const getCollapseItems = () => {
|
const getCollapseItems = () => {
|
||||||
const items: { key: string; label: React.ReactNode; children: React.ReactNode }[] = []
|
const items: { key: string; label: React.ReactNode; children: React.ReactNode }[] = []
|
||||||
const { id, tool, status, response } = toolResponse
|
const hasError = response?.isError === true
|
||||||
const isInvoking = status === 'invoking'
|
|
||||||
const isDone = status === 'done'
|
|
||||||
const hasError = isDone && response?.isError === true
|
|
||||||
const result = {
|
const result = {
|
||||||
params: toolResponse.arguments,
|
params: toolResponse.arguments,
|
||||||
response: toolResponse.response
|
response: toolResponse.response
|
||||||
@ -68,34 +105,93 @@ const MessageTools: FC<Props> = ({ block }) => {
|
|||||||
<MessageTitleLabel>
|
<MessageTitleLabel>
|
||||||
<TitleContent>
|
<TitleContent>
|
||||||
<ToolName>{tool.name}</ToolName>
|
<ToolName>{tool.name}</ToolName>
|
||||||
<StatusIndicator $isInvoking={isInvoking} $hasError={hasError}>
|
<StatusIndicator status={status} hasError={hasError}>
|
||||||
{isInvoking
|
{(() => {
|
||||||
? t('message.tools.invoking')
|
switch (status) {
|
||||||
: hasError
|
case 'pending':
|
||||||
? t('message.tools.error')
|
return (
|
||||||
: t('message.tools.completed')}
|
<>
|
||||||
{isInvoking && <LoadingOutlined spin style={{ marginLeft: 6 }} />}
|
{t('message.tools.pending')}
|
||||||
{isDone && !hasError && <CheckOutlined style={{ marginLeft: 6 }} />}
|
<LoadingOutlined spin style={{ marginLeft: 6 }} />
|
||||||
{hasError && <WarningOutlined style={{ marginLeft: 6 }} />}
|
</>
|
||||||
|
)
|
||||||
|
case 'invoking':
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{t('message.tools.invoking')}
|
||||||
|
<LoadingOutlined spin style={{ marginLeft: 6 }} />
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
case 'cancelled':
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{t('message.tools.cancelled')}
|
||||||
|
<CloseOutlined style={{ marginLeft: 6 }} />
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
case 'done':
|
||||||
|
if (hasError) {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{t('message.tools.error')}
|
||||||
|
<WarningOutlined style={{ marginLeft: 6 }} />
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{t('message.tools.completed')}
|
||||||
|
<CheckOutlined style={{ marginLeft: 6 }} />
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return ''
|
||||||
|
}
|
||||||
|
})()}
|
||||||
</StatusIndicator>
|
</StatusIndicator>
|
||||||
</TitleContent>
|
</TitleContent>
|
||||||
<ActionButtonsContainer>
|
<ActionButtonsContainer>
|
||||||
{isDone && response && (
|
{isPending && (
|
||||||
<>
|
<>
|
||||||
<Tooltip title={t('common.expand')} mouseEnterDelay={0.5}>
|
<Tooltip title={t('common.cancel')} mouseEnterDelay={0.3}>
|
||||||
<ActionButton
|
<ActionButton
|
||||||
className="message-action-button"
|
|
||||||
onClick={(e) => {
|
onClick={(e) => {
|
||||||
e.stopPropagation()
|
e.stopPropagation()
|
||||||
setExpandedResponse({
|
handleCancelTool()
|
||||||
content: JSON.stringify(response, null, 2),
|
|
||||||
title: tool.name
|
|
||||||
})
|
|
||||||
}}
|
}}
|
||||||
aria-label={t('common.expand')}>
|
aria-label={t('common.cancel')}>
|
||||||
<ExpandOutlined />
|
<CloseOutlined style={{ fontSize: '14px' }} />
|
||||||
</ActionButton>
|
</ActionButton>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
|
<Tooltip title={t('common.confirm')} mouseEnterDelay={0.3}>
|
||||||
|
<ActionButton
|
||||||
|
className="confirm-button"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation()
|
||||||
|
handleConfirmTool()
|
||||||
|
}}
|
||||||
|
aria-label={t('common.confirm')}>
|
||||||
|
<CheckOutlined style={{ fontSize: '14px' }} />
|
||||||
|
</ActionButton>
|
||||||
|
</Tooltip>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{isInvoking && toolResponse?.id && (
|
||||||
|
<Tooltip title={t('chat.input.pause')} mouseEnterDelay={0.3}>
|
||||||
|
<ActionButton
|
||||||
|
className="abort-button"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation()
|
||||||
|
handleAbortTool()
|
||||||
|
}}
|
||||||
|
aria-label={t('chat.input.pause')}>
|
||||||
|
<PauseCircle color="var(--color-error)" size={14} />
|
||||||
|
</ActionButton>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
{isDone && response && (
|
||||||
|
<>
|
||||||
<Tooltip title={t('common.copy')} mouseEnterDelay={0.5}>
|
<Tooltip title={t('common.copy')} mouseEnterDelay={0.5}>
|
||||||
<ActionButton
|
<ActionButton
|
||||||
className="message-action-button"
|
className="message-action-button"
|
||||||
@ -113,98 +209,38 @@ const MessageTools: FC<Props> = ({ block }) => {
|
|||||||
</ActionButtonsContainer>
|
</ActionButtonsContainer>
|
||||||
</MessageTitleLabel>
|
</MessageTitleLabel>
|
||||||
),
|
),
|
||||||
children: isDone && result && (
|
children:
|
||||||
<ToolResponseContainer
|
isDone && result ? (
|
||||||
style={{
|
<ToolResponseContainer
|
||||||
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
|
style={{
|
||||||
fontSize: '12px'
|
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
|
||||||
}}>
|
fontSize
|
||||||
<CollapsedContent isExpanded={activeKeys.includes(id)} resultString={resultString} />
|
}}>
|
||||||
</ToolResponseContainer>
|
<CollapsedContent isExpanded={activeKeys.includes(id)} resultString={resultString} />
|
||||||
)
|
</ToolResponseContainer>
|
||||||
|
) : argsString ? (
|
||||||
|
<>
|
||||||
|
<ToolResponseContainer>
|
||||||
|
<CollapsedContent isExpanded={activeKeys.includes(id)} resultString={argsString} />
|
||||||
|
</ToolResponseContainer>
|
||||||
|
</>
|
||||||
|
) : null
|
||||||
})
|
})
|
||||||
|
|
||||||
return items
|
return items
|
||||||
}
|
}
|
||||||
|
|
||||||
const renderPreview = (content: string) => {
|
|
||||||
if (!content) return null
|
|
||||||
|
|
||||||
try {
|
|
||||||
const parsedResult = JSON.parse(content)
|
|
||||||
switch (parsedResult.content[0]?.type) {
|
|
||||||
case 'text':
|
|
||||||
return <PreviewBlock>{parsedResult.content[0].text}</PreviewBlock>
|
|
||||||
default:
|
|
||||||
return <PreviewBlock>{content}</PreviewBlock>
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
console.error('failed to render the preview of mcp results:', e)
|
|
||||||
return <PreviewBlock>{content}</PreviewBlock>
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<ToolContainer>
|
||||||
<CollapseContainer
|
<CollapseContainer
|
||||||
activeKey={activeKeys}
|
activeKey={activeKeys}
|
||||||
size="small"
|
size="small"
|
||||||
onChange={handleCollapseChange}
|
onChange={handleCollapseChange}
|
||||||
className="message-tools-container"
|
className="message-tools-container"
|
||||||
items={getCollapseItems()}
|
items={getCollapseItems()}
|
||||||
expandIcon={({ isActive }) => (
|
expandIconPosition="end"
|
||||||
<CollapsibleIcon className={`iconfont ${isActive ? 'icon-chevron-down' : 'icon-chevron-right'}`} />
|
|
||||||
)}
|
|
||||||
/>
|
/>
|
||||||
|
</ToolContainer>
|
||||||
<Modal
|
|
||||||
title={expandedResponse?.title}
|
|
||||||
open={!!expandedResponse}
|
|
||||||
onCancel={() => setExpandedResponse(null)}
|
|
||||||
footer={null}
|
|
||||||
width="80%"
|
|
||||||
centered
|
|
||||||
transitionName="animation-move-down"
|
|
||||||
styles={{ body: { maxHeight: '80vh', overflow: 'auto' } }}>
|
|
||||||
{expandedResponse && (
|
|
||||||
<ExpandedResponseContainer
|
|
||||||
style={{
|
|
||||||
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
|
|
||||||
fontSize
|
|
||||||
}}>
|
|
||||||
<Tabs
|
|
||||||
tabBarExtraContent={
|
|
||||||
<ActionButton
|
|
||||||
className="copy-expanded-button"
|
|
||||||
onClick={() => {
|
|
||||||
navigator.clipboard.writeText(
|
|
||||||
typeof expandedResponse.content === 'string'
|
|
||||||
? expandedResponse.content
|
|
||||||
: JSON.stringify(expandedResponse.content, null, 2)
|
|
||||||
)
|
|
||||||
antdMessage.success({ content: t('message.copied'), key: 'copy-expanded' })
|
|
||||||
}}
|
|
||||||
aria-label={t('common.copy')}>
|
|
||||||
<i className="iconfont icon-copy"></i>
|
|
||||||
</ActionButton>
|
|
||||||
}
|
|
||||||
items={[
|
|
||||||
{
|
|
||||||
key: 'preview',
|
|
||||||
label: t('message.tools.preview'),
|
|
||||||
children: <CollapsedContent isExpanded={true} resultString={resultString} />
|
|
||||||
},
|
|
||||||
{
|
|
||||||
key: 'raw',
|
|
||||||
label: t('message.tools.raw'),
|
|
||||||
children: renderPreview(expandedResponse.content)
|
|
||||||
}
|
|
||||||
]}
|
|
||||||
/>
|
|
||||||
</ExpandedResponseContainer>
|
|
||||||
)}
|
|
||||||
</Modal>
|
|
||||||
</>
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,15 +266,25 @@ const CollapsedContent: FC<{ isExpanded: boolean; resultString: string }> = ({ i
|
|||||||
}
|
}
|
||||||
|
|
||||||
const CollapseContainer = styled(Collapse)`
|
const CollapseContainer = styled(Collapse)`
|
||||||
margin-top: 10px;
|
|
||||||
margin-bottom: 12px;
|
|
||||||
border-radius: 8px;
|
border-radius: 8px;
|
||||||
|
border: none;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
|
|
||||||
.ant-collapse-header {
|
.ant-collapse-header {
|
||||||
background-color: var(--color-bg-2);
|
background-color: var(--color-bg-2);
|
||||||
transition: background-color 0.2s;
|
transition: background-color 0.2s;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
.ant-collapse-expand-icon {
|
||||||
|
height: 100% !important;
|
||||||
|
}
|
||||||
|
.ant-collapse-arrow {
|
||||||
|
height: 28px !important;
|
||||||
|
svg {
|
||||||
|
width: 14px;
|
||||||
|
height: 14px;
|
||||||
|
}
|
||||||
|
}
|
||||||
&:hover {
|
&:hover {
|
||||||
background-color: var(--color-bg-3);
|
background-color: var(--color-bg-3);
|
||||||
}
|
}
|
||||||
@ -249,6 +295,15 @@ const CollapseContainer = styled(Collapse)`
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
|
const ToolContainer = styled.div`
|
||||||
|
margin-top: 10px;
|
||||||
|
margin-bottom: 12px;
|
||||||
|
border: 1px solid var(--color-border);
|
||||||
|
background-color: var(--color-bg-2);
|
||||||
|
border-radius: 8px;
|
||||||
|
overflow: hidden;
|
||||||
|
`
|
||||||
|
|
||||||
const MarkdownContainer = styled.div`
|
const MarkdownContainer = styled.div`
|
||||||
& pre {
|
& pre {
|
||||||
background: transparent !important;
|
background: transparent !important;
|
||||||
@ -267,6 +322,7 @@ const MessageTitleLabel = styled.div`
|
|||||||
min-height: 26px;
|
min-height: 26px;
|
||||||
gap: 10px;
|
gap: 10px;
|
||||||
padding: 0;
|
padding: 0;
|
||||||
|
margin-left: 4px;
|
||||||
`
|
`
|
||||||
|
|
||||||
const TitleContent = styled.div`
|
const TitleContent = styled.div`
|
||||||
@ -282,18 +338,27 @@ const ToolName = styled.span`
|
|||||||
font-size: 13px;
|
font-size: 13px;
|
||||||
`
|
`
|
||||||
|
|
||||||
const StatusIndicator = styled.span<{ $isInvoking: boolean; $hasError?: boolean }>`
|
const StatusIndicator = styled.span<{ status: string; hasError?: boolean }>`
|
||||||
color: ${(props) => {
|
color: ${(props) => {
|
||||||
if (props.$hasError) return 'var(--color-error, #ff4d4f)'
|
switch (props.status) {
|
||||||
if (props.$isInvoking) return 'var(--color-primary)'
|
case 'pending':
|
||||||
return 'var(--color-success, #52c41a)'
|
return 'var(--color-text-2)'
|
||||||
|
case 'invoking':
|
||||||
|
return 'var(--color-primary)'
|
||||||
|
case 'cancelled':
|
||||||
|
return 'var(--color-error, #ff4d4f)' // Assuming cancelled should also be an error color
|
||||||
|
case 'done':
|
||||||
|
return props.hasError ? 'var(--color-error, #ff4d4f)' : 'var(--color-success, #52c41a)'
|
||||||
|
default:
|
||||||
|
return 'var(--color-text)'
|
||||||
|
}
|
||||||
}};
|
}};
|
||||||
font-size: 11px;
|
font-size: 11px;
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
opacity: 0.85;
|
opacity: 0.85;
|
||||||
border-left: 1px solid var(--color-border);
|
border-left: 1px solid var(--color-border);
|
||||||
padding-left: 8px;
|
padding-left: 12px;
|
||||||
`
|
`
|
||||||
|
|
||||||
const ActionButtonsContainer = styled.div`
|
const ActionButtonsContainer = styled.div`
|
||||||
@ -307,18 +372,30 @@ const ActionButton = styled.button`
|
|||||||
border: none;
|
border: none;
|
||||||
color: var(--color-text-2);
|
color: var(--color-text-2);
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
padding: 4px 8px;
|
padding: 4px;
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
justify-content: center;
|
justify-content: center;
|
||||||
opacity: 0.7;
|
opacity: 0.7;
|
||||||
transition: all 0.2s;
|
transition: all 0.2s;
|
||||||
border-radius: 4px;
|
border-radius: 4px;
|
||||||
|
gap: 4px;
|
||||||
|
min-width: 28px;
|
||||||
|
height: 28px;
|
||||||
|
|
||||||
&:hover {
|
&:hover {
|
||||||
opacity: 1;
|
opacity: 1;
|
||||||
color: var(--color-text);
|
color: var(--color-text);
|
||||||
background-color: var(--color-bg-1);
|
background-color: var(--color-bg-3);
|
||||||
|
}
|
||||||
|
|
||||||
|
&.confirm-button {
|
||||||
|
color: var(--color-primary);
|
||||||
|
|
||||||
|
&:hover {
|
||||||
|
background-color: var(--color-primary-bg);
|
||||||
|
color: var(--color-primary);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
&:focus-visible {
|
&:focus-visible {
|
||||||
@ -332,12 +409,6 @@ const ActionButton = styled.button`
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
const CollapsibleIcon = styled.i`
|
|
||||||
color: var(--color-text-2);
|
|
||||||
font-size: 12px;
|
|
||||||
transition: transform 0.2s;
|
|
||||||
`
|
|
||||||
|
|
||||||
const ToolResponseContainer = styled.div`
|
const ToolResponseContainer = styled.div`
|
||||||
border-radius: 0 0 4px 4px;
|
border-radius: 0 0 4px 4px;
|
||||||
overflow: auto;
|
overflow: auto;
|
||||||
@ -346,35 +417,4 @@ const ToolResponseContainer = styled.div`
|
|||||||
position: relative;
|
position: relative;
|
||||||
`
|
`
|
||||||
|
|
||||||
const PreviewBlock = styled.div`
|
|
||||||
margin: 0;
|
|
||||||
white-space: pre-wrap;
|
|
||||||
word-break: break-word;
|
|
||||||
color: var(--color-text);
|
|
||||||
user-select: text;
|
|
||||||
`
|
|
||||||
|
|
||||||
const ExpandedResponseContainer = styled.div`
|
|
||||||
background: var(--color-bg-1);
|
|
||||||
border-radius: 8px;
|
|
||||||
padding: 16px;
|
|
||||||
position: relative;
|
|
||||||
|
|
||||||
.copy-expanded-button {
|
|
||||||
position: absolute;
|
|
||||||
top: 10px;
|
|
||||||
right: 10px;
|
|
||||||
background-color: var(--color-bg-2);
|
|
||||||
border-radius: 4px;
|
|
||||||
z-index: 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
pre {
|
|
||||||
margin: 0;
|
|
||||||
white-space: pre-wrap;
|
|
||||||
word-break: break-word;
|
|
||||||
color: var(--color-text);
|
|
||||||
}
|
|
||||||
`
|
|
||||||
|
|
||||||
export default memo(MessageTools)
|
export default memo(MessageTools)
|
||||||
|
|||||||
@ -16,6 +16,7 @@ export interface StreamProcessorCallbacks {
|
|||||||
onThinkingChunk?: (text: string, thinking_millsec?: number) => void
|
onThinkingChunk?: (text: string, thinking_millsec?: number) => void
|
||||||
onThinkingComplete?: (text: string, thinking_millsec?: number) => void
|
onThinkingComplete?: (text: string, thinking_millsec?: number) => void
|
||||||
// A tool call response chunk (from MCP)
|
// A tool call response chunk (from MCP)
|
||||||
|
onToolCallPending?: (toolResponse: MCPToolResponse) => void
|
||||||
onToolCallInProgress?: (toolResponse: MCPToolResponse) => void
|
onToolCallInProgress?: (toolResponse: MCPToolResponse) => void
|
||||||
onToolCallComplete?: (toolResponse: MCPToolResponse) => void
|
onToolCallComplete?: (toolResponse: MCPToolResponse) => void
|
||||||
// External tool call in progress
|
// External tool call in progress
|
||||||
@ -69,6 +70,10 @@ export function createStreamProcessor(callbacks: StreamProcessorCallbacks = {})
|
|||||||
if (callbacks.onThinkingComplete) callbacks.onThinkingComplete(data.text, data.thinking_millsec)
|
if (callbacks.onThinkingComplete) callbacks.onThinkingComplete(data.text, data.thinking_millsec)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
case ChunkType.MCP_TOOL_PENDING: {
|
||||||
|
if (callbacks.onToolCallPending) data.responses.forEach((toolResp) => callbacks.onToolCallPending!(toolResp))
|
||||||
|
break
|
||||||
|
}
|
||||||
case ChunkType.MCP_TOOL_IN_PROGRESS: {
|
case ChunkType.MCP_TOOL_IN_PROGRESS: {
|
||||||
if (callbacks.onToolCallInProgress)
|
if (callbacks.onToolCallInProgress)
|
||||||
data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp))
|
data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp))
|
||||||
|
|||||||
@ -529,12 +529,13 @@ const fetchAndProcessAssistantResponseImpl = async (
|
|||||||
}
|
}
|
||||||
thinkingBlockId = null
|
thinkingBlockId = null
|
||||||
},
|
},
|
||||||
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
|
onToolCallPending: (toolResponse: MCPToolResponse) => {
|
||||||
if (initialPlaceholderBlockId) {
|
if (initialPlaceholderBlockId) {
|
||||||
lastBlockType = MessageBlockType.TOOL
|
lastBlockType = MessageBlockType.TOOL
|
||||||
const changes = {
|
const changes = {
|
||||||
type: MessageBlockType.TOOL,
|
type: MessageBlockType.TOOL,
|
||||||
status: MessageBlockStatus.PROCESSING,
|
status: MessageBlockStatus.PENDING,
|
||||||
|
toolName: toolResponse.tool.name,
|
||||||
metadata: { rawMcpToolResponse: toolResponse }
|
metadata: { rawMcpToolResponse: toolResponse }
|
||||||
}
|
}
|
||||||
toolBlockId = initialPlaceholderBlockId
|
toolBlockId = initialPlaceholderBlockId
|
||||||
@ -542,14 +543,37 @@ const fetchAndProcessAssistantResponseImpl = async (
|
|||||||
dispatch(updateOneBlock({ id: toolBlockId, changes }))
|
dispatch(updateOneBlock({ id: toolBlockId, changes }))
|
||||||
saveUpdatedBlockToDB(toolBlockId, assistantMsgId, topicId, getState)
|
saveUpdatedBlockToDB(toolBlockId, assistantMsgId, topicId, getState)
|
||||||
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId)
|
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId)
|
||||||
} else if (toolResponse.status === 'invoking') {
|
} else if (toolResponse.status === 'pending') {
|
||||||
const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, {
|
const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, {
|
||||||
toolName: toolResponse.tool.name,
|
toolName: toolResponse.tool.name,
|
||||||
status: MessageBlockStatus.PROCESSING,
|
status: MessageBlockStatus.PENDING,
|
||||||
metadata: { rawMcpToolResponse: toolResponse }
|
metadata: { rawMcpToolResponse: toolResponse }
|
||||||
})
|
})
|
||||||
|
toolBlockId = toolBlock.id
|
||||||
handleBlockTransition(toolBlock, MessageBlockType.TOOL)
|
handleBlockTransition(toolBlock, MessageBlockType.TOOL)
|
||||||
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlock.id)
|
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlock.id)
|
||||||
|
} else {
|
||||||
|
console.warn(
|
||||||
|
`[onToolCallPending] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
|
||||||
|
// 根据 toolResponse.id 查找对应的块ID
|
||||||
|
const targetBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
||||||
|
|
||||||
|
if (targetBlockId && toolResponse.status === 'invoking') {
|
||||||
|
const changes = {
|
||||||
|
status: MessageBlockStatus.PROCESSING,
|
||||||
|
metadata: { rawMcpToolResponse: toolResponse }
|
||||||
|
}
|
||||||
|
dispatch(updateOneBlock({ id: targetBlockId, changes }))
|
||||||
|
saveUpdatedBlockToDB(targetBlockId, assistantMsgId, topicId, getState)
|
||||||
|
} else if (!targetBlockId) {
|
||||||
|
console.warn(
|
||||||
|
`[onToolCallInProgress] No block ID found for tool ID: ${toolResponse.id}. Available mappings:`,
|
||||||
|
Array.from(toolCallIdToBlockIdMap.entries())
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
console.warn(
|
console.warn(
|
||||||
`[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
`[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||||
@ -559,14 +583,17 @@ const fetchAndProcessAssistantResponseImpl = async (
|
|||||||
onToolCallComplete: (toolResponse: MCPToolResponse) => {
|
onToolCallComplete: (toolResponse: MCPToolResponse) => {
|
||||||
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
||||||
toolCallIdToBlockIdMap.delete(toolResponse.id)
|
toolCallIdToBlockIdMap.delete(toolResponse.id)
|
||||||
if (toolResponse.status === 'done' || toolResponse.status === 'error') {
|
if (toolResponse.status === 'done' || toolResponse.status === 'error' || toolResponse.status === 'cancelled') {
|
||||||
if (!existingBlockId) {
|
if (!existingBlockId) {
|
||||||
console.error(
|
console.error(
|
||||||
`[onToolCallComplete] No existing block found for completed/error tool call ID: ${toolResponse.id}. Cannot update.`
|
`[onToolCallComplete] No existing block found for completed/error tool call ID: ${toolResponse.id}. Cannot update.`
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
const finalStatus = toolResponse.status === 'done' ? MessageBlockStatus.SUCCESS : MessageBlockStatus.ERROR
|
const finalStatus =
|
||||||
|
toolResponse.status === 'done' || toolResponse.status === 'cancelled'
|
||||||
|
? MessageBlockStatus.SUCCESS
|
||||||
|
: MessageBlockStatus.ERROR
|
||||||
const changes: Partial<ToolMessageBlock> = {
|
const changes: Partial<ToolMessageBlock> = {
|
||||||
content: toolResponse.response,
|
content: toolResponse.response,
|
||||||
status: finalStatus,
|
status: finalStatus,
|
||||||
@ -583,6 +610,7 @@ const fetchAndProcessAssistantResponseImpl = async (
|
|||||||
`[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
`[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
toolBlockId = null
|
||||||
},
|
},
|
||||||
onExternalToolInProgress: async () => {
|
onExternalToolInProgress: async () => {
|
||||||
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
|
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
|
||||||
@ -762,7 +790,14 @@ const fetchAndProcessAssistantResponseImpl = async (
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
const possibleBlockId =
|
const possibleBlockId =
|
||||||
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
|
mainTextBlockId ||
|
||||||
|
thinkingBlockId ||
|
||||||
|
toolBlockId ||
|
||||||
|
imageBlockId ||
|
||||||
|
citationBlockId ||
|
||||||
|
initialPlaceholderBlockId ||
|
||||||
|
lastBlockId
|
||||||
|
|
||||||
if (possibleBlockId) {
|
if (possibleBlockId) {
|
||||||
// 更改上一个block的状态为ERROR
|
// 更改上一个block的状态为ERROR
|
||||||
const changes: Partial<MessageBlock> = {
|
const changes: Partial<MessageBlock> = {
|
||||||
@ -801,7 +836,13 @@ const fetchAndProcessAssistantResponseImpl = async (
|
|||||||
const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg]
|
const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg]
|
||||||
|
|
||||||
const possibleBlockId =
|
const possibleBlockId =
|
||||||
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
|
mainTextBlockId ||
|
||||||
|
thinkingBlockId ||
|
||||||
|
toolBlockId ||
|
||||||
|
imageBlockId ||
|
||||||
|
citationBlockId ||
|
||||||
|
initialPlaceholderBlockId ||
|
||||||
|
lastBlockId
|
||||||
if (possibleBlockId) {
|
if (possibleBlockId) {
|
||||||
const changes: Partial<MessageBlock> = {
|
const changes: Partial<MessageBlock> = {
|
||||||
status: MessageBlockStatus.SUCCESS
|
status: MessageBlockStatus.SUCCESS
|
||||||
@ -1109,7 +1150,6 @@ export const resendMessageThunk =
|
|||||||
// 没有相关的助手消息就创建一个或多个
|
// 没有相关的助手消息就创建一个或多个
|
||||||
|
|
||||||
if (userMessageToResend?.mentions?.length) {
|
if (userMessageToResend?.mentions?.length) {
|
||||||
console.log('userMessageToResend.mentions', userMessageToResend.mentions)
|
|
||||||
for (const mention of userMessageToResend.mentions) {
|
for (const mention of userMessageToResend.mentions) {
|
||||||
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
||||||
askId: userMessageToResend.id,
|
askId: userMessageToResend.id,
|
||||||
|
|||||||
@ -13,6 +13,7 @@ export enum ChunkType {
|
|||||||
KNOWLEDGE_SEARCH_IN_PROGRESS = 'knowledge_search_in_progress',
|
KNOWLEDGE_SEARCH_IN_PROGRESS = 'knowledge_search_in_progress',
|
||||||
KNOWLEDGE_SEARCH_COMPLETE = 'knowledge_search_complete',
|
KNOWLEDGE_SEARCH_COMPLETE = 'knowledge_search_complete',
|
||||||
MCP_TOOL_CREATED = 'mcp_tool_created',
|
MCP_TOOL_CREATED = 'mcp_tool_created',
|
||||||
|
MCP_TOOL_PENDING = 'mcp_tool_pending',
|
||||||
MCP_TOOL_IN_PROGRESS = 'mcp_tool_in_progress',
|
MCP_TOOL_IN_PROGRESS = 'mcp_tool_in_progress',
|
||||||
MCP_TOOL_COMPLETE = 'mcp_tool_complete',
|
MCP_TOOL_COMPLETE = 'mcp_tool_complete',
|
||||||
EXTERNEL_TOOL_COMPLETE = 'externel_tool_complete',
|
EXTERNEL_TOOL_COMPLETE = 'externel_tool_complete',
|
||||||
@ -260,6 +261,11 @@ export interface MCPToolCreatedChunk {
|
|||||||
tool_use_responses?: ToolUseResponse[] // 工具使用响应
|
tool_use_responses?: ToolUseResponse[] // 工具使用响应
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface MCPToolPendingChunk {
|
||||||
|
type: ChunkType.MCP_TOOL_PENDING
|
||||||
|
responses: MCPToolResponse[]
|
||||||
|
}
|
||||||
|
|
||||||
export interface MCPToolInProgressChunk {
|
export interface MCPToolInProgressChunk {
|
||||||
/**
|
/**
|
||||||
* The type of the chunk
|
* The type of the chunk
|
||||||
@ -353,6 +359,7 @@ export type Chunk =
|
|||||||
| KnowledgeSearchInProgressChunk // 知识库搜索进行中
|
| KnowledgeSearchInProgressChunk // 知识库搜索进行中
|
||||||
| KnowledgeSearchCompleteChunk // 知识库搜索完成
|
| KnowledgeSearchCompleteChunk // 知识库搜索完成
|
||||||
| MCPToolCreatedChunk // MCP工具被大模型创建
|
| MCPToolCreatedChunk // MCP工具被大模型创建
|
||||||
|
| MCPToolPendingChunk // MCP工具调用等待中
|
||||||
| MCPToolInProgressChunk // MCP工具调用中
|
| MCPToolInProgressChunk // MCP工具调用中
|
||||||
| MCPToolCompleteChunk // MCP工具调用完成
|
| MCPToolCompleteChunk // MCP工具调用完成
|
||||||
| ExternalToolCompleteChunk // 外部工具调用完成,外部工具包含搜索互联网,知识库,MCP服务器
|
| ExternalToolCompleteChunk // 外部工具调用完成,外部工具包含搜索互联网,知识库,MCP服务器
|
||||||
|
|||||||
@ -683,11 +683,13 @@ export interface MCPConfig {
|
|||||||
isBunInstalled: boolean
|
isBunInstalled: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type MCPToolResponseStatus = 'pending' | 'cancelled' | 'invoking' | 'done' | 'error'
|
||||||
|
|
||||||
interface BaseToolResponse {
|
interface BaseToolResponse {
|
||||||
id: string // unique id
|
id: string // unique id
|
||||||
tool: MCPTool
|
tool: MCPTool
|
||||||
arguments: Record<string, unknown> | undefined
|
arguments: Record<string, unknown> | undefined
|
||||||
status: string // 'invoking' | 'done'
|
status: MCPToolResponseStatus
|
||||||
response?: any
|
response?: any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
693
src/renderer/src/utils/__tests__/tagExtraction.test.ts
Normal file
693
src/renderer/src/utils/__tests__/tagExtraction.test.ts
Normal file
@ -0,0 +1,693 @@
|
|||||||
|
import { describe, expect, test } from 'vitest'
|
||||||
|
|
||||||
|
import { TagConfig, TagExtractor } from '../tagExtraction'
|
||||||
|
|
||||||
|
describe('TagExtractor', () => {
|
||||||
|
describe('基本标签提取', () => {
|
||||||
|
test('应该正确提取简单的标签内容', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('<think>Hello World</think>')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(2)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: 'Hello World',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: 'Hello World'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理标签前后的普通文本', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('前文<think>思考内容</think>后文')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(4)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '前文',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '思考内容',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[2]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '思考内容'
|
||||||
|
})
|
||||||
|
expect(results[3]).toEqual({
|
||||||
|
content: '后文',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理空标签', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('<think></think>')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('分块处理', () => {
|
||||||
|
test('应该正确处理分块的标签内容', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
let results = extractor.processText('<think>第一')
|
||||||
|
expect(results).toHaveLength(1)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '第一',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
results = extractor.processText('部分内容')
|
||||||
|
expect(results).toHaveLength(1)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '部分内容',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
results = extractor.processText('</think>')
|
||||||
|
expect(results).toHaveLength(1)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '第一部分内容'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理分块的开始标签', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
let results = extractor.processText('<thi')
|
||||||
|
expect(results).toHaveLength(0)
|
||||||
|
|
||||||
|
results = extractor.processText('nk>内容</think>')
|
||||||
|
expect(results).toHaveLength(2)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '内容',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '内容'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理模拟可读流的分块数据', async () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
// 模拟流式数据块
|
||||||
|
const streamChunks = [
|
||||||
|
'这是普通文本',
|
||||||
|
'<thi',
|
||||||
|
'nk>这是第一个',
|
||||||
|
'思考内容',
|
||||||
|
'</think>',
|
||||||
|
'中间的一些文本',
|
||||||
|
'<think>第二',
|
||||||
|
'个思考内容',
|
||||||
|
'</thi',
|
||||||
|
'nk>',
|
||||||
|
'结束文本'
|
||||||
|
]
|
||||||
|
|
||||||
|
const allResults: any[] = []
|
||||||
|
|
||||||
|
// 模拟异步流式处理
|
||||||
|
for (const chunk of streamChunks) {
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 10)) // 模拟异步延迟
|
||||||
|
const results = extractor.processText(chunk)
|
||||||
|
allResults.push(...results)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证结果
|
||||||
|
expect(allResults).toHaveLength(9)
|
||||||
|
|
||||||
|
// 第一个普通文本
|
||||||
|
expect(allResults[0]).toEqual({
|
||||||
|
content: '这是普通文本',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
// 第一个思考标签内容
|
||||||
|
expect(allResults[1]).toEqual({
|
||||||
|
content: '这是第一个',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(allResults[2]).toEqual({
|
||||||
|
content: '思考内容',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
// 第一个完整的标签内容提取
|
||||||
|
expect(allResults[3]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '这是第一个思考内容'
|
||||||
|
})
|
||||||
|
|
||||||
|
// 中间文本
|
||||||
|
expect(allResults[4]).toEqual({
|
||||||
|
content: '中间的一些文本',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
// 第二个思考标签内容
|
||||||
|
expect(allResults[5]).toEqual({
|
||||||
|
content: '第二',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
// 第二个完整的标签内容提取和结束文本
|
||||||
|
expect(allResults[6]).toEqual({
|
||||||
|
content: '个思考内容',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(allResults[7]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '第二个思考内容'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(allResults[8]).toEqual({
|
||||||
|
content: '结束文本',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('多个标签处理', () => {
|
||||||
|
test('应该处理连续的多个标签', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('<think>第一个</think><think>第二个</think>')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(4)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '第一个',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '第一个'
|
||||||
|
})
|
||||||
|
expect(results[2]).toEqual({
|
||||||
|
content: '第二个',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[3]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '第二个'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理标签间的文本内容', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('<think>思考1</think>中间文本<think>思考2</think>')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(5)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '思考1',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '思考1'
|
||||||
|
})
|
||||||
|
expect(results[2]).toEqual({
|
||||||
|
content: '中间文本',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[3]).toEqual({
|
||||||
|
content: '思考2',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[4]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '思考2'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理三个连续标签的分次输出', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
// 第一次输入:包含两个完整标签和第三个标签的开始
|
||||||
|
let results = extractor.processText('<think>第一个</think><think>第二个</think><think>第三个开始')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(5)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '第一个',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '第一个'
|
||||||
|
})
|
||||||
|
expect(results[2]).toEqual({
|
||||||
|
content: '第二个',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[3]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '第二个'
|
||||||
|
})
|
||||||
|
expect(results[4]).toEqual({
|
||||||
|
content: '第三个开始',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
// 第二次输入:继续第三个标签的内容
|
||||||
|
results = extractor.processText('继续内容')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(1)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '继续内容',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
// 第三次输入:完成第三个标签
|
||||||
|
results = extractor.processText('结束</think>')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(2)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '结束',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '第三个开始继续内容结束'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理三个连续标签的另一种分次输出模式', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
// 第一次输入:第一个完整标签
|
||||||
|
let results = extractor.processText('<think>第一个思考</think>')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(2)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '第一个思考',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '第一个思考'
|
||||||
|
})
|
||||||
|
|
||||||
|
// 第二次输入:第二个完整标签和第三个标签的部分内容
|
||||||
|
results = extractor.processText('<think>第二个思考</think><think>第三个开')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(3)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '第二个思考',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '第二个思考'
|
||||||
|
})
|
||||||
|
expect(results[2]).toEqual({
|
||||||
|
content: '第三个开',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
// 第三次输入:完成第三个标签
|
||||||
|
results = extractor.processText('始部分</think>')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(2)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '始部分',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '第三个开始部分'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('不完整标签处理', () => {
|
||||||
|
test('应该处理只有开始标签的情况', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('<think>未完成的思考')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(1)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '未完成的思考',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理文本中间截断的标签', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('正常文本<thi')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(1)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '正常文本',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('finalize 方法', () => {
|
||||||
|
test('应该返回未完成的标签内容', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
extractor.processText('<think>未完成的内容')
|
||||||
|
const result = extractor.finalize()
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '未完成的内容'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('当没有未完成内容时应该返回 null', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
extractor.processText('<think>完整内容</think>')
|
||||||
|
const result = extractor.finalize()
|
||||||
|
|
||||||
|
expect(result).toBeNull()
|
||||||
|
})
|
||||||
|
|
||||||
|
test('对于普通文本应该返回 null', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
extractor.processText('只是普通文本')
|
||||||
|
const result = extractor.finalize()
|
||||||
|
|
||||||
|
expect(result).toBeNull()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('reset 方法', () => {
|
||||||
|
test('应该重置所有内部状态', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
// 处理一些文本以改变内部状态
|
||||||
|
extractor.processText('<think>一些内容')
|
||||||
|
|
||||||
|
// 重置
|
||||||
|
extractor.reset()
|
||||||
|
|
||||||
|
// 重置后应该能正常处理新的文本
|
||||||
|
const results = extractor.processText('<think>新内容</think>')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(2)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '新内容',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '新内容'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('重置后 finalize 应该返回 null', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
extractor.processText('<think>未完成')
|
||||||
|
extractor.reset()
|
||||||
|
|
||||||
|
const result = extractor.finalize()
|
||||||
|
expect(result).toBeNull()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('不同标签配置', () => {
|
||||||
|
test('应该处理工具使用标签', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<tool_use>',
|
||||||
|
closingTag: '</tool_use>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('<tool_use>{"name": "search"}</tool_use>')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(2)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '{"name": "search"}',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '{"name": "search"}'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理自定义标签', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '[START]',
|
||||||
|
closingTag: '[END]'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('前文[START]中间内容[END]后文')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(4)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '前文',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '中间内容',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[2]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '中间内容'
|
||||||
|
})
|
||||||
|
expect(results[3]).toEqual({
|
||||||
|
content: '后文',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('边界情况', () => {
|
||||||
|
test('应该处理空字符串输入', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理只包含标签的输入', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('<think></think>')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理标签内容包含相似文本的情况', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const results = extractor.processText('<think>我在<thinking>思考</think>')
|
||||||
|
|
||||||
|
expect(results).toHaveLength(2)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: '我在<thinking>思考',
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: '我在<thinking>思考'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('应该处理换行符和特殊字符', () => {
|
||||||
|
const config: TagConfig = {
|
||||||
|
openingTag: '<think>',
|
||||||
|
closingTag: '</think>'
|
||||||
|
}
|
||||||
|
const extractor = new TagExtractor(config)
|
||||||
|
|
||||||
|
const content = '多行\n内容\t带制表符'
|
||||||
|
const results = extractor.processText(`<think>${content}</think>`)
|
||||||
|
|
||||||
|
expect(results).toHaveLength(2)
|
||||||
|
expect(results[0]).toEqual({
|
||||||
|
content: content,
|
||||||
|
isTagContent: true,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
expect(results[1]).toEqual({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: content
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -14,9 +14,8 @@ import {
|
|||||||
Model,
|
Model,
|
||||||
ToolUseResponse
|
ToolUseResponse
|
||||||
} from '@renderer/types'
|
} from '@renderer/types'
|
||||||
import type { MCPToolCompleteChunk, MCPToolInProgressChunk } from '@renderer/types/chunk'
|
import type { MCPToolCompleteChunk, MCPToolInProgressChunk, MCPToolPendingChunk } from '@renderer/types/chunk'
|
||||||
import { ChunkType } from '@renderer/types/chunk'
|
import { ChunkType } from '@renderer/types/chunk'
|
||||||
import { SdkMessageParam } from '@renderer/types/sdk'
|
|
||||||
import { isArray, isObject, pull, transform } from 'lodash'
|
import { isArray, isObject, pull, transform } from 'lodash'
|
||||||
import { nanoid } from 'nanoid'
|
import { nanoid } from 'nanoid'
|
||||||
import OpenAI from 'openai'
|
import OpenAI from 'openai'
|
||||||
@ -28,6 +27,7 @@ import {
|
|||||||
} from 'openai/resources'
|
} from 'openai/resources'
|
||||||
|
|
||||||
import { CompletionsParams } from '../aiCore/middleware/schemas'
|
import { CompletionsParams } from '../aiCore/middleware/schemas'
|
||||||
|
import { requestToolConfirmation } from './userConfirmation'
|
||||||
|
|
||||||
const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
|
const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
|
||||||
const EXTRA_SCHEMA_KEYS = ['schema', 'headers']
|
const EXTRA_SCHEMA_KEYS = ['schema', 'headers']
|
||||||
@ -278,7 +278,8 @@ export async function callMCPTool(toolResponse: MCPToolResponse): Promise<MCPCal
|
|||||||
const resp = await window.api.mcp.callTool({
|
const resp = await window.api.mcp.callTool({
|
||||||
server,
|
server,
|
||||||
name: toolResponse.tool.name,
|
name: toolResponse.tool.name,
|
||||||
args: toolResponse.arguments
|
args: toolResponse.arguments,
|
||||||
|
callId: toolResponse.id
|
||||||
})
|
})
|
||||||
if (toolResponse.tool.serverName === MCP_AUTO_INSTALL_SERVER_NAME) {
|
if (toolResponse.tool.serverName === MCP_AUTO_INSTALL_SERVER_NAME) {
|
||||||
if (resp.data) {
|
if (resp.data) {
|
||||||
@ -400,7 +401,7 @@ export function geminiFunctionCallToMcpTool(
|
|||||||
export function upsertMCPToolResponse(
|
export function upsertMCPToolResponse(
|
||||||
results: MCPToolResponse[],
|
results: MCPToolResponse[],
|
||||||
resp: MCPToolResponse,
|
resp: MCPToolResponse,
|
||||||
onChunk: (chunk: MCPToolInProgressChunk | MCPToolCompleteChunk) => void
|
onChunk: (chunk: MCPToolPendingChunk | MCPToolInProgressChunk | MCPToolCompleteChunk) => void
|
||||||
) {
|
) {
|
||||||
const index = results.findIndex((ret) => ret.id === resp.id)
|
const index = results.findIndex((ret) => ret.id === resp.id)
|
||||||
let result = resp
|
let result = resp
|
||||||
@ -416,10 +417,29 @@ export function upsertMCPToolResponse(
|
|||||||
} else {
|
} else {
|
||||||
results.push(resp)
|
results.push(resp)
|
||||||
}
|
}
|
||||||
onChunk({
|
switch (resp.status) {
|
||||||
type: resp.status === 'invoking' ? ChunkType.MCP_TOOL_IN_PROGRESS : ChunkType.MCP_TOOL_COMPLETE,
|
case 'pending':
|
||||||
responses: [result]
|
onChunk({
|
||||||
})
|
type: ChunkType.MCP_TOOL_PENDING,
|
||||||
|
responses: [result]
|
||||||
|
})
|
||||||
|
break
|
||||||
|
case 'invoking':
|
||||||
|
onChunk({
|
||||||
|
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
||||||
|
responses: [result]
|
||||||
|
})
|
||||||
|
break
|
||||||
|
case 'cancelled':
|
||||||
|
case 'done':
|
||||||
|
onChunk({
|
||||||
|
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||||
|
responses: [result]
|
||||||
|
})
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function filterMCPTools(
|
export function filterMCPTools(
|
||||||
@ -441,7 +461,7 @@ export function getMcpServerByTool(tool: MCPTool) {
|
|||||||
return servers.find((s) => s.id === tool.serverId)
|
return servers.find((s) => s.id === tool.serverId)
|
||||||
}
|
}
|
||||||
|
|
||||||
export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseResponse[] {
|
export function parseToolUse(content: string, mcpTools: MCPTool[], startIdx: number = 0): ToolUseResponse[] {
|
||||||
if (!content || !mcpTools || mcpTools.length === 0) {
|
if (!content || !mcpTools || mcpTools.length === 0) {
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
@ -461,7 +481,7 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseRespo
|
|||||||
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
|
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
|
||||||
const tools: ToolUseResponse[] = []
|
const tools: ToolUseResponse[] = []
|
||||||
let match
|
let match
|
||||||
let idx = 0
|
let idx = startIdx
|
||||||
// Find all tool use blocks
|
// Find all tool use blocks
|
||||||
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
|
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
|
||||||
// const fullMatch = match[0]
|
// const fullMatch = match[0]
|
||||||
@ -505,8 +525,9 @@ export async function parseAndCallTools<R>(
|
|||||||
onChunk: CompletionsParams['onChunk'],
|
onChunk: CompletionsParams['onChunk'],
|
||||||
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
||||||
model: Model,
|
model: Model,
|
||||||
mcpTools?: MCPTool[]
|
mcpTools?: MCPTool[],
|
||||||
): Promise<SdkMessageParam[]>
|
abortSignal?: AbortSignal
|
||||||
|
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
|
||||||
|
|
||||||
export async function parseAndCallTools<R>(
|
export async function parseAndCallTools<R>(
|
||||||
content: string,
|
content: string,
|
||||||
@ -514,8 +535,9 @@ export async function parseAndCallTools<R>(
|
|||||||
onChunk: CompletionsParams['onChunk'],
|
onChunk: CompletionsParams['onChunk'],
|
||||||
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
||||||
model: Model,
|
model: Model,
|
||||||
mcpTools?: MCPTool[]
|
mcpTools?: MCPTool[],
|
||||||
): Promise<SdkMessageParam[]>
|
abortSignal?: AbortSignal
|
||||||
|
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
|
||||||
|
|
||||||
export async function parseAndCallTools<R>(
|
export async function parseAndCallTools<R>(
|
||||||
content: string | MCPToolResponse[],
|
content: string | MCPToolResponse[],
|
||||||
@ -523,68 +545,172 @@ export async function parseAndCallTools<R>(
|
|||||||
onChunk: CompletionsParams['onChunk'],
|
onChunk: CompletionsParams['onChunk'],
|
||||||
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
||||||
model: Model,
|
model: Model,
|
||||||
mcpTools?: MCPTool[]
|
mcpTools?: MCPTool[],
|
||||||
): Promise<R[]> {
|
abortSignal?: AbortSignal
|
||||||
|
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }> {
|
||||||
const toolResults: R[] = []
|
const toolResults: R[] = []
|
||||||
let curToolResponses: MCPToolResponse[] = []
|
let curToolResponses: MCPToolResponse[] = []
|
||||||
if (Array.isArray(content)) {
|
if (Array.isArray(content)) {
|
||||||
curToolResponses = content
|
curToolResponses = content
|
||||||
} else {
|
} else {
|
||||||
// process tool use
|
// process tool use
|
||||||
curToolResponses = parseToolUse(content, mcpTools || [])
|
curToolResponses = parseToolUse(content, mcpTools || [], 0)
|
||||||
}
|
}
|
||||||
if (!curToolResponses || curToolResponses.length === 0) {
|
if (!curToolResponses || curToolResponses.length === 0) {
|
||||||
return toolResults
|
return { toolResults, confirmedToolResponses: [] }
|
||||||
}
|
}
|
||||||
for (let i = 0; i < curToolResponses.length; i++) {
|
|
||||||
const toolResponse = curToolResponses[i]
|
for (const toolResponse of curToolResponses) {
|
||||||
upsertMCPToolResponse(
|
upsertMCPToolResponse(
|
||||||
allToolResponses,
|
allToolResponses,
|
||||||
{
|
{
|
||||||
...toolResponse,
|
...toolResponse,
|
||||||
status: 'invoking'
|
status: 'pending'
|
||||||
},
|
},
|
||||||
onChunk!
|
onChunk!
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const toolPromises = curToolResponses.map(async (toolResponse) => {
|
// 创建工具确认Promise映射,并立即处理每个确认
|
||||||
const images: string[] = []
|
const confirmedTools: MCPToolResponse[] = []
|
||||||
const toolCallResponse = await callMCPTool(toolResponse)
|
const pendingPromises: Promise<void>[] = []
|
||||||
upsertMCPToolResponse(
|
|
||||||
allToolResponses,
|
|
||||||
{
|
|
||||||
...toolResponse,
|
|
||||||
status: 'done',
|
|
||||||
response: toolCallResponse
|
|
||||||
},
|
|
||||||
onChunk!
|
|
||||||
)
|
|
||||||
|
|
||||||
for (const content of toolCallResponse.content) {
|
curToolResponses.forEach((toolResponse) => {
|
||||||
if (content.type === 'image' && content.data) {
|
const confirmationPromise = requestToolConfirmation(toolResponse.id, abortSignal)
|
||||||
images.push(`data:${content.mimeType};base64,${content.data}`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (images.length) {
|
const processingPromise = confirmationPromise
|
||||||
onChunk?.({
|
.then(async (confirmed) => {
|
||||||
type: ChunkType.IMAGE_CREATED
|
if (confirmed) {
|
||||||
})
|
// 立即更新为invoking状态
|
||||||
onChunk?.({
|
upsertMCPToolResponse(
|
||||||
type: ChunkType.IMAGE_COMPLETE,
|
allToolResponses,
|
||||||
image: {
|
{
|
||||||
type: 'base64',
|
...toolResponse,
|
||||||
images: images
|
status: 'invoking'
|
||||||
|
},
|
||||||
|
onChunk!
|
||||||
|
)
|
||||||
|
|
||||||
|
// 执行工具调用
|
||||||
|
try {
|
||||||
|
const images: string[] = []
|
||||||
|
const toolCallResponse = await callMCPTool(toolResponse)
|
||||||
|
|
||||||
|
// 立即更新为done状态
|
||||||
|
upsertMCPToolResponse(
|
||||||
|
allToolResponses,
|
||||||
|
{
|
||||||
|
...toolResponse,
|
||||||
|
status: 'done',
|
||||||
|
response: toolCallResponse
|
||||||
|
},
|
||||||
|
onChunk!
|
||||||
|
)
|
||||||
|
|
||||||
|
// 处理图片
|
||||||
|
for (const content of toolCallResponse.content) {
|
||||||
|
if (content.type === 'image' && content.data) {
|
||||||
|
images.push(`data:${content.mimeType};base64,${content.data}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (images.length) {
|
||||||
|
onChunk?.({
|
||||||
|
type: ChunkType.IMAGE_CREATED
|
||||||
|
})
|
||||||
|
onChunk?.({
|
||||||
|
type: ChunkType.IMAGE_COMPLETE,
|
||||||
|
image: {
|
||||||
|
type: 'base64',
|
||||||
|
images: images
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换消息并添加到结果
|
||||||
|
const convertedMessage = convertToMessage(toolResponse, toolCallResponse, model)
|
||||||
|
if (convertedMessage) {
|
||||||
|
confirmedTools.push(toolResponse)
|
||||||
|
toolResults.push(convertedMessage)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error(`🔧 [MCP] Error executing tool ${toolResponse.id}:`, error)
|
||||||
|
// 更新为错误状态
|
||||||
|
upsertMCPToolResponse(
|
||||||
|
allToolResponses,
|
||||||
|
{
|
||||||
|
...toolResponse,
|
||||||
|
status: 'done',
|
||||||
|
response: {
|
||||||
|
isError: true,
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: `Error executing tool: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
onChunk!
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 立即更新为cancelled状态
|
||||||
|
upsertMCPToolResponse(
|
||||||
|
allToolResponses,
|
||||||
|
{
|
||||||
|
...toolResponse,
|
||||||
|
status: 'cancelled',
|
||||||
|
response: {
|
||||||
|
isError: false,
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: 'Tool call cancelled by user.'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
onChunk!
|
||||||
|
)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
.catch((error) => {
|
||||||
|
Logger.error(`🔧 [MCP] Error waiting for tool confirmation ${toolResponse.id}:`, error)
|
||||||
|
// 立即更新为cancelled状态
|
||||||
|
upsertMCPToolResponse(
|
||||||
|
allToolResponses,
|
||||||
|
{
|
||||||
|
...toolResponse,
|
||||||
|
status: 'cancelled',
|
||||||
|
response: {
|
||||||
|
isError: true,
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: `Error in confirmation process: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
onChunk!
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
return convertToMessage(toolResponse, toolCallResponse, model)
|
pendingPromises.push(processingPromise)
|
||||||
})
|
})
|
||||||
|
|
||||||
toolResults.push(...(await Promise.all(toolPromises)).filter((t) => typeof t !== 'undefined'))
|
Logger.info(
|
||||||
return toolResults
|
`🔧 [MCP] Waiting for tool confirmations:`,
|
||||||
|
curToolResponses.map((t) => t.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
// 等待所有工具处理完成(但每个工具的状态已经实时更新)
|
||||||
|
await Promise.all(pendingPromises)
|
||||||
|
|
||||||
|
Logger.info(`🔧 [MCP] All tools processed. Confirmed tools: ${confirmedTools.length}`)
|
||||||
|
|
||||||
|
return { toolResults, confirmedToolResponses: confirmedTools }
|
||||||
}
|
}
|
||||||
|
|
||||||
export function mcpToolCallResponseToOpenAICompatibleMessage(
|
export function mcpToolCallResponseToOpenAICompatibleMessage(
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import store from '@renderer/store'
|
|||||||
import { Assistant, MCPTool } from '@renderer/types'
|
import { Assistant, MCPTool } from '@renderer/types'
|
||||||
|
|
||||||
export const SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \
|
export const SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \
|
||||||
You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
You can use one or more tools per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
||||||
|
|
||||||
## Tool Use Formatting
|
## Tool Use Formatting
|
||||||
|
|
||||||
|
|||||||
86
src/renderer/src/utils/userConfirmation.ts
Normal file
86
src/renderer/src/utils/userConfirmation.ts
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
import Logger from '@renderer/config/logger'
|
||||||
|
|
||||||
|
// 存储每个工具的确认Promise的resolve函数
|
||||||
|
const toolConfirmResolvers = new Map<string, (value: boolean) => void>()
|
||||||
|
// 存储每个工具的abort监听器清理函数
|
||||||
|
const abortListeners = new Map<string, () => void>()
|
||||||
|
|
||||||
|
export function requestUserConfirmation(): Promise<boolean> {
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
const globalKey = '_global'
|
||||||
|
toolConfirmResolvers.set(globalKey, resolve)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export function requestToolConfirmation(toolId: string, abortSignal?: AbortSignal): Promise<boolean> {
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
if (abortSignal?.aborted) {
|
||||||
|
resolve(false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
toolConfirmResolvers.set(toolId, resolve)
|
||||||
|
|
||||||
|
if (abortSignal) {
|
||||||
|
const abortListener = () => {
|
||||||
|
const resolver = toolConfirmResolvers.get(toolId)
|
||||||
|
if (resolver) {
|
||||||
|
resolver(false)
|
||||||
|
toolConfirmResolvers.delete(toolId)
|
||||||
|
abortListeners.delete(toolId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
abortSignal.addEventListener('abort', abortListener)
|
||||||
|
|
||||||
|
// 存储清理函数
|
||||||
|
const cleanup = () => {
|
||||||
|
abortSignal.removeEventListener('abort', abortListener)
|
||||||
|
abortListeners.delete(toolId)
|
||||||
|
}
|
||||||
|
abortListeners.set(toolId, cleanup)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export function confirmToolAction(toolId: string) {
|
||||||
|
const resolve = toolConfirmResolvers.get(toolId)
|
||||||
|
if (resolve) {
|
||||||
|
resolve(true)
|
||||||
|
toolConfirmResolvers.delete(toolId)
|
||||||
|
|
||||||
|
// 清理abort监听器
|
||||||
|
const cleanup = abortListeners.get(toolId)
|
||||||
|
if (cleanup) {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Logger.warn(`🔧 [userConfirmation] No resolver found for tool: ${toolId}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function cancelToolAction(toolId: string) {
|
||||||
|
const resolve = toolConfirmResolvers.get(toolId)
|
||||||
|
if (resolve) {
|
||||||
|
resolve(false)
|
||||||
|
toolConfirmResolvers.delete(toolId)
|
||||||
|
|
||||||
|
// 清理abort监听器
|
||||||
|
const cleanup = abortListeners.get(toolId)
|
||||||
|
if (cleanup) {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Logger.warn(`🔧 [userConfirmation] No resolver found for tool: ${toolId}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取所有待确认的工具ID
|
||||||
|
export function getPendingToolIds(): string[] {
|
||||||
|
return Array.from(toolConfirmResolvers.keys()).filter((id) => id !== '_global')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查某个工具是否在等待确认
|
||||||
|
export function isToolPending(toolId: string): boolean {
|
||||||
|
return toolConfirmResolvers.has(toolId)
|
||||||
|
}
|
||||||
16
yarn.lock
16
yarn.lock
@ -3756,11 +3756,11 @@ __metadata:
|
|||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
"@modelcontextprotocol/sdk@npm:^1.11.4":
|
"@modelcontextprotocol/sdk@npm:^1.12.3":
|
||||||
version: 1.11.4
|
version: 1.12.3
|
||||||
resolution: "@modelcontextprotocol/sdk@npm:1.11.4"
|
resolution: "@modelcontextprotocol/sdk@npm:1.12.3"
|
||||||
dependencies:
|
dependencies:
|
||||||
ajv: "npm:^8.17.1"
|
ajv: "npm:^6.12.6"
|
||||||
content-type: "npm:^1.0.5"
|
content-type: "npm:^1.0.5"
|
||||||
cors: "npm:^2.8.5"
|
cors: "npm:^2.8.5"
|
||||||
cross-spawn: "npm:^7.0.5"
|
cross-spawn: "npm:^7.0.5"
|
||||||
@ -3771,7 +3771,7 @@ __metadata:
|
|||||||
raw-body: "npm:^3.0.0"
|
raw-body: "npm:^3.0.0"
|
||||||
zod: "npm:^3.23.8"
|
zod: "npm:^3.23.8"
|
||||||
zod-to-json-schema: "npm:^3.24.1"
|
zod-to-json-schema: "npm:^3.24.1"
|
||||||
checksum: 10c0/797694937e65ccc02e8dc63db711d9d96fbc49b49e6d246e6fed95d8d2bfe98ef203207224e39c9fc3b54da182da865a5d311ea06ef939c5c57ce0cd27c0f546
|
checksum: 10c0/8bc0b91e596ec886efc64d68ae8474247647405f1a5ae407e02439c74c2a03528b3fbdce8f9352d9c2df54aa4548411e1aa1816ab3b09e045c2ff4202e2fd374
|
||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
@ -7092,7 +7092,7 @@ __metadata:
|
|||||||
"@libsql/client": "npm:0.14.0"
|
"@libsql/client": "npm:0.14.0"
|
||||||
"@libsql/win32-x64-msvc": "npm:^0.4.7"
|
"@libsql/win32-x64-msvc": "npm:^0.4.7"
|
||||||
"@mistralai/mistralai": "npm:^1.6.0"
|
"@mistralai/mistralai": "npm:^1.6.0"
|
||||||
"@modelcontextprotocol/sdk": "npm:^1.11.4"
|
"@modelcontextprotocol/sdk": "npm:^1.12.3"
|
||||||
"@mozilla/readability": "npm:^0.6.0"
|
"@mozilla/readability": "npm:^0.6.0"
|
||||||
"@notionhq/client": "npm:^2.2.15"
|
"@notionhq/client": "npm:^2.2.15"
|
||||||
"@playwright/test": "npm:^1.52.0"
|
"@playwright/test": "npm:^1.52.0"
|
||||||
@ -7349,7 +7349,7 @@ __metadata:
|
|||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
"ajv@npm:^6.10.0, ajv@npm:^6.12.0, ajv@npm:^6.12.4":
|
"ajv@npm:^6.10.0, ajv@npm:^6.12.0, ajv@npm:^6.12.4, ajv@npm:^6.12.6":
|
||||||
version: 6.12.6
|
version: 6.12.6
|
||||||
resolution: "ajv@npm:6.12.6"
|
resolution: "ajv@npm:6.12.6"
|
||||||
dependencies:
|
dependencies:
|
||||||
@ -7361,7 +7361,7 @@ __metadata:
|
|||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
"ajv@npm:^8.0.0, ajv@npm:^8.17.1, ajv@npm:^8.6.3":
|
"ajv@npm:^8.0.0, ajv@npm:^8.6.3":
|
||||||
version: 8.17.1
|
version: 8.17.1
|
||||||
resolution: "ajv@npm:8.17.1"
|
resolution: "ajv@npm:8.17.1"
|
||||||
dependencies:
|
dependencies:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user