mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
<type>: <subject>
<body> <footer> 用來簡要描述影響本次變動,概述即可
This commit is contained in:
parent
9293f26612
commit
ba121d04b4
@ -3,7 +3,7 @@
|
||||
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
||||
* 内置默认逻辑,支持自定义覆盖
|
||||
*/
|
||||
import { ToolSet } from 'ai'
|
||||
import { ToolExecutionError, ToolSet } from 'ai'
|
||||
|
||||
import { definePlugin } from '../index'
|
||||
import type { AiRequestContext } from '../types'
|
||||
@ -48,39 +48,11 @@ export interface MCPPromptConfig {
|
||||
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
|
||||
}
|
||||
|
||||
// 全局存储,解决 transformStream 中无 context 的问题
|
||||
const globalToolsStorage = new Map<string, ToolSet>()
|
||||
|
||||
/**
|
||||
* 生成唯一的执行ID
|
||||
* 扩展的 AI 请求上下文,支持 MCP 工具存储
|
||||
*/
|
||||
function generateExecutionId(): string {
|
||||
return `mcp_${Date.now()}_${Math.random().toString(36).slice(2)}`
|
||||
}
|
||||
|
||||
/**
|
||||
* 存储工具信息
|
||||
*/
|
||||
|
||||
/**
|
||||
* 全局存储工具信息
|
||||
*/
|
||||
function storeGlobalTools(executionId: string, tools: ToolSet) {
|
||||
globalToolsStorage.set(executionId, tools)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取全局存储的工具信息
|
||||
*/
|
||||
function getGlobalTools(executionId: string): ToolSet | undefined {
|
||||
return globalToolsStorage.get(executionId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理全局存储
|
||||
*/
|
||||
function clearGlobalTools(executionId: string) {
|
||||
globalToolsStorage.delete(executionId)
|
||||
interface MCPRequestContext extends AiRequestContext {
|
||||
mcpTools?: ToolSet
|
||||
}
|
||||
|
||||
/**
|
||||
@ -315,30 +287,26 @@ function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
|
||||
export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {}) => {
|
||||
const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config
|
||||
|
||||
// 为每个插件实例生成唯一ID
|
||||
const executionId = generateExecutionId()
|
||||
|
||||
return {
|
||||
name: 'built-in:mcp-prompt',
|
||||
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
if (!enabled || !params.tools) return params
|
||||
transformParams: async (params: any, context: MCPRequestContext) => {
|
||||
if (!enabled || !params.tools || typeof params.tools !== 'object') {
|
||||
return params
|
||||
}
|
||||
|
||||
// 保存原始工具信息到 WeakMap 和全局存储中
|
||||
const tools: ToolSet = params.tools
|
||||
console.log('tools', tools)
|
||||
// storeTools(context, tools)
|
||||
storeGlobalTools(executionId, tools)
|
||||
// 直接存储工具信息到 context 上,利用改进的插件引擎
|
||||
context.mcpTools = params.tools
|
||||
console.log('tools stored in context', params.tools)
|
||||
|
||||
// 构建系统提示符
|
||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||
const systemPrompt = await buildSystemPrompt(userSystemPrompt, tools)
|
||||
const systemPrompt = await buildSystemPrompt(userSystemPrompt, params.tools)
|
||||
|
||||
// 将工具信息保存到参数中(用于后续解析)
|
||||
// 移除 tools,改为 prompt 模式
|
||||
const transformedParams = {
|
||||
...params,
|
||||
system: systemPrompt,
|
||||
// 移除 tools,改为 prompt 模式
|
||||
tools: undefined
|
||||
}
|
||||
console.log('transformedParams', transformedParams)
|
||||
@ -346,7 +314,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
},
|
||||
|
||||
// 流式处理:监听 step-finish 事件并处理工具调用
|
||||
transformStream: (_, context: AiRequestContext) => () => {
|
||||
transformStream: (_, context: MCPRequestContext) => () => {
|
||||
let textBuffer = ''
|
||||
let executedResults: { toolCallId: string; toolName: string; result: any; isError?: boolean }[] = []
|
||||
|
||||
@ -365,10 +333,10 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
if (chunk.type === 'step-finish' || chunk.type === 'finish') {
|
||||
console.log('[MCP Prompt Stream] Received step-finish, checking for tool use...')
|
||||
|
||||
// 获取工具信息
|
||||
const tools = getGlobalTools(executionId)
|
||||
console.log('tools', tools)
|
||||
if (!tools) {
|
||||
// 从 context 获取工具信息
|
||||
const tools = context.mcpTools
|
||||
console.log('tools from context', tools)
|
||||
if (!tools || Object.keys(tools).length === 0) {
|
||||
console.log('[MCP Prompt Stream] No tools available, passing through')
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
@ -376,7 +344,6 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
|
||||
// 解析工具调用
|
||||
const parsedTools = parseToolUse(textBuffer, tools)
|
||||
// console.log('textBuffer', textBuffer)
|
||||
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||
console.log('parsedTools', parsedTools)
|
||||
|
||||
@ -444,19 +411,40 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
} catch (error) {
|
||||
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
|
||||
|
||||
// 使用 AI SDK 标准错误格式
|
||||
const toolError = new ToolExecutionError({
|
||||
toolName: toolUse.toolName,
|
||||
toolArgs: toolUse.arguments,
|
||||
toolCallId: toolUse.id,
|
||||
message: `Tool execution failed: ${error instanceof Error ? error.message : String(error)}`,
|
||||
cause: error instanceof Error ? error : undefined
|
||||
})
|
||||
|
||||
// 发送标准错误事件
|
||||
controller.enqueue({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: toolError.message,
|
||||
name: toolError.name,
|
||||
toolName: toolError.toolName,
|
||||
toolCallId: toolError.toolCallId
|
||||
}
|
||||
})
|
||||
|
||||
// 发送 tool-result 错误事件
|
||||
controller.enqueue({
|
||||
type: 'tool-result',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
args: toolUse.arguments,
|
||||
isError: true,
|
||||
result: error instanceof Error ? error.message : String(error)
|
||||
result: toolError.message
|
||||
})
|
||||
|
||||
executedResults.push({
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result: error instanceof Error ? error.message : String(error),
|
||||
result: toolError.message,
|
||||
isError: true
|
||||
})
|
||||
}
|
||||
@ -466,7 +454,6 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
controller.enqueue({
|
||||
type: 'step-finish',
|
||||
finishReason: 'tool-call'
|
||||
// usage: { completionTokens: 0, promptTokens: 0, totalTokens: 0 }
|
||||
})
|
||||
|
||||
// 递归调用逻辑
|
||||
@ -498,11 +485,11 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
}
|
||||
]
|
||||
|
||||
// 递归调用,继续对话
|
||||
// 递归调用,继续对话,重新传递 tools
|
||||
const recursiveParams = {
|
||||
...context.originalParams,
|
||||
messages: newMessages,
|
||||
tools: tools // 重新传递 tools
|
||||
tools: tools
|
||||
}
|
||||
|
||||
try {
|
||||
@ -529,22 +516,25 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[MCP Prompt] Recursive call failed:', error)
|
||||
// 发送错误信息后也要确保流不会中断
|
||||
|
||||
// 使用 AI SDK 标准错误格式,但不中断流
|
||||
controller.enqueue({
|
||||
type: 'text-delta',
|
||||
textDelta: `\n\n[Error: Recursive call failed: ${error instanceof Error ? error.message : String(error)}]`
|
||||
type: 'error',
|
||||
error: {
|
||||
message: error instanceof Error ? error.message : String(error),
|
||||
name: error instanceof Error ? error.name : 'RecursiveCallError'
|
||||
}
|
||||
})
|
||||
|
||||
// 发送一个错误后的结束信号
|
||||
// 继续发送文本增量,保持流的连续性
|
||||
controller.enqueue({
|
||||
type: 'finish',
|
||||
finishReason: 'error'
|
||||
type: 'text-delta',
|
||||
textDelta: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 清理状态
|
||||
// clearGlobalTools(executionId)
|
||||
textBuffer = ''
|
||||
executedResults = []
|
||||
return
|
||||
@ -555,116 +545,10 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
},
|
||||
|
||||
flush() {
|
||||
// 清理全局存储
|
||||
clearGlobalTools(executionId)
|
||||
// 流结束时的清理工作
|
||||
console.log('[MCP Prompt] Stream ended, cleaning up...')
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// transformResult: async (result: any, context: AiRequestContext) => {
|
||||
// // 这个方法现在主要用于非流式场景
|
||||
// if (!enabled || !result || typeof result.text !== 'string') return result
|
||||
|
||||
// console.log('[MCP Prompt] transformResult called - likely non-streaming mode')
|
||||
|
||||
// // 从 WeakMap 中获取工具信息
|
||||
// const tools: ToolSet | undefined = getStoredTools(context)
|
||||
// if (!tools || typeof tools !== 'object') return result
|
||||
|
||||
// // 使用工具解析函数(默认或自定义)
|
||||
// const parsedTools = parseToolUse(result.text, tools)
|
||||
// if (!parsedTools || parsedTools.length === 0) return result
|
||||
|
||||
// // 过滤掉解析失败的工具调用
|
||||
// const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||
// if (validToolUses.length === 0) {
|
||||
// console.warn('[MCP Prompt] No valid tool uses found:', parsedTools)
|
||||
// return result
|
||||
// }
|
||||
|
||||
// // 只在非流式模式下执行工具调用并递归
|
||||
// if (context.recursiveCall) {
|
||||
// console.log('[MCP Prompt] Non-streaming: Executing tools and continuing conversation...')
|
||||
|
||||
// // 执行工具调用
|
||||
// const toolResults = await Promise.all(
|
||||
// validToolUses.map(async (toolUse) => {
|
||||
// try {
|
||||
// const tool = tools[toolUse.toolName]
|
||||
// if (!tool || typeof tool.execute !== 'function') {
|
||||
// throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
||||
// }
|
||||
|
||||
// console.log(`[MCP Prompt] Non-streaming: Executing tool: ${toolUse.toolName}`, toolUse.arguments)
|
||||
|
||||
// const result = await tool.execute(toolUse.arguments, {
|
||||
// toolCallId: toolUse.id,
|
||||
// messages: [],
|
||||
// abortSignal: new AbortController().signal
|
||||
// })
|
||||
|
||||
// return {
|
||||
// id: toolUse.id,
|
||||
// name: toolUse.toolName,
|
||||
// arguments: toolUse.arguments,
|
||||
// result,
|
||||
// success: true
|
||||
// }
|
||||
// } catch (error) {
|
||||
// console.error(`[MCP Prompt] Non-streaming: Tool execution failed: ${toolUse.toolName}`, error)
|
||||
// return {
|
||||
// id: toolUse.id,
|
||||
// name: toolUse.toolName,
|
||||
// arguments: toolUse.arguments,
|
||||
// error: error instanceof Error ? error.message : String(error),
|
||||
// success: false
|
||||
// }
|
||||
// }
|
||||
// })
|
||||
// )
|
||||
|
||||
// // 构建工具结果的文本表示
|
||||
// const toolResultsText = toolResults
|
||||
// .map((tr) => {
|
||||
// if (tr.success) {
|
||||
// return `<tool_use_result>\n <name>${tr.name}</name>\n <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
|
||||
// } else {
|
||||
// return `<tool_use_result>\n <name>${tr.name}</name>\n <error>${tr.error}</error>\n</tool_use_result>`
|
||||
// }
|
||||
// })
|
||||
// .join('\n\n')
|
||||
|
||||
// // 构建新的对话消息
|
||||
// const newMessages = [
|
||||
// ...(context.originalParams.messages || []),
|
||||
// {
|
||||
// role: 'assistant',
|
||||
// content: result.text
|
||||
// },
|
||||
// {
|
||||
// role: 'user',
|
||||
// content: toolResultsText
|
||||
// }
|
||||
// ]
|
||||
|
||||
// // 递归调用,继续对话
|
||||
// const recursiveParams = {
|
||||
// ...context.originalParams,
|
||||
// messages: newMessages,
|
||||
// tools: tools // 重新传递 tools,在新的 context 中会重新存储
|
||||
// }
|
||||
|
||||
// try {
|
||||
// console.log('[MCP Prompt] Non-streaming: Starting recursive call...')
|
||||
// const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||
// return recursiveResult
|
||||
// } catch (error) {
|
||||
// console.error('[MCP Prompt] Non-streaming: Recursive call failed:', error)
|
||||
// return result
|
||||
// }
|
||||
// }
|
||||
|
||||
// return result
|
||||
// }
|
||||
}
|
||||
})
|
||||
|
||||
@ -63,15 +63,16 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
methodName: string,
|
||||
modelId: string,
|
||||
params: TParams,
|
||||
executor: (finalModelId: string, transformedParams: TParams) => Promise<TResult>
|
||||
executor: (finalModelId: string, transformedParams: TParams) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = createContext(this.providerId, modelId, params)
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
return this.executeWithPlugins(methodName, modelId, newParams, executor)
|
||||
return this.executeWithPlugins(methodName, modelId, newParams, executor, context)
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user