mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-11 08:19:01 +08:00
<type>: <subject>
<body> <footer> 用來簡要描述影響本次變動,概述即可
This commit is contained in:
parent
9293f26612
commit
ba121d04b4
@ -3,7 +3,7 @@
|
|||||||
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
||||||
* 内置默认逻辑,支持自定义覆盖
|
* 内置默认逻辑,支持自定义覆盖
|
||||||
*/
|
*/
|
||||||
import { ToolSet } from 'ai'
|
import { ToolExecutionError, ToolSet } from 'ai'
|
||||||
|
|
||||||
import { definePlugin } from '../index'
|
import { definePlugin } from '../index'
|
||||||
import type { AiRequestContext } from '../types'
|
import type { AiRequestContext } from '../types'
|
||||||
@ -48,39 +48,11 @@ export interface MCPPromptConfig {
|
|||||||
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
|
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
|
||||||
}
|
}
|
||||||
|
|
||||||
// 全局存储,解决 transformStream 中无 context 的问题
|
|
||||||
const globalToolsStorage = new Map<string, ToolSet>()
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 生成唯一的执行ID
|
* 扩展的 AI 请求上下文,支持 MCP 工具存储
|
||||||
*/
|
*/
|
||||||
function generateExecutionId(): string {
|
interface MCPRequestContext extends AiRequestContext {
|
||||||
return `mcp_${Date.now()}_${Math.random().toString(36).slice(2)}`
|
mcpTools?: ToolSet
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 存储工具信息
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 全局存储工具信息
|
|
||||||
*/
|
|
||||||
function storeGlobalTools(executionId: string, tools: ToolSet) {
|
|
||||||
globalToolsStorage.set(executionId, tools)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 获取全局存储的工具信息
|
|
||||||
*/
|
|
||||||
function getGlobalTools(executionId: string): ToolSet | undefined {
|
|
||||||
return globalToolsStorage.get(executionId)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 清理全局存储
|
|
||||||
*/
|
|
||||||
function clearGlobalTools(executionId: string) {
|
|
||||||
globalToolsStorage.delete(executionId)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -315,30 +287,26 @@ function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
|
|||||||
export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {}) => {
|
export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {}) => {
|
||||||
const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config
|
const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config
|
||||||
|
|
||||||
// 为每个插件实例生成唯一ID
|
|
||||||
const executionId = generateExecutionId()
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
name: 'built-in:mcp-prompt',
|
name: 'built-in:mcp-prompt',
|
||||||
|
|
||||||
transformParams: async (params: any, context: AiRequestContext) => {
|
transformParams: async (params: any, context: MCPRequestContext) => {
|
||||||
if (!enabled || !params.tools) return params
|
if (!enabled || !params.tools || typeof params.tools !== 'object') {
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
// 保存原始工具信息到 WeakMap 和全局存储中
|
// 直接存储工具信息到 context 上,利用改进的插件引擎
|
||||||
const tools: ToolSet = params.tools
|
context.mcpTools = params.tools
|
||||||
console.log('tools', tools)
|
console.log('tools stored in context', params.tools)
|
||||||
// storeTools(context, tools)
|
|
||||||
storeGlobalTools(executionId, tools)
|
|
||||||
|
|
||||||
// 构建系统提示符
|
// 构建系统提示符
|
||||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||||
const systemPrompt = await buildSystemPrompt(userSystemPrompt, tools)
|
const systemPrompt = await buildSystemPrompt(userSystemPrompt, params.tools)
|
||||||
|
|
||||||
// 将工具信息保存到参数中(用于后续解析)
|
// 移除 tools,改为 prompt 模式
|
||||||
const transformedParams = {
|
const transformedParams = {
|
||||||
...params,
|
...params,
|
||||||
system: systemPrompt,
|
system: systemPrompt,
|
||||||
// 移除 tools,改为 prompt 模式
|
|
||||||
tools: undefined
|
tools: undefined
|
||||||
}
|
}
|
||||||
console.log('transformedParams', transformedParams)
|
console.log('transformedParams', transformedParams)
|
||||||
@ -346,7 +314,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
},
|
},
|
||||||
|
|
||||||
// 流式处理:监听 step-finish 事件并处理工具调用
|
// 流式处理:监听 step-finish 事件并处理工具调用
|
||||||
transformStream: (_, context: AiRequestContext) => () => {
|
transformStream: (_, context: MCPRequestContext) => () => {
|
||||||
let textBuffer = ''
|
let textBuffer = ''
|
||||||
let executedResults: { toolCallId: string; toolName: string; result: any; isError?: boolean }[] = []
|
let executedResults: { toolCallId: string; toolName: string; result: any; isError?: boolean }[] = []
|
||||||
|
|
||||||
@ -365,10 +333,10 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
if (chunk.type === 'step-finish' || chunk.type === 'finish') {
|
if (chunk.type === 'step-finish' || chunk.type === 'finish') {
|
||||||
console.log('[MCP Prompt Stream] Received step-finish, checking for tool use...')
|
console.log('[MCP Prompt Stream] Received step-finish, checking for tool use...')
|
||||||
|
|
||||||
// 获取工具信息
|
// 从 context 获取工具信息
|
||||||
const tools = getGlobalTools(executionId)
|
const tools = context.mcpTools
|
||||||
console.log('tools', tools)
|
console.log('tools from context', tools)
|
||||||
if (!tools) {
|
if (!tools || Object.keys(tools).length === 0) {
|
||||||
console.log('[MCP Prompt Stream] No tools available, passing through')
|
console.log('[MCP Prompt Stream] No tools available, passing through')
|
||||||
controller.enqueue(chunk)
|
controller.enqueue(chunk)
|
||||||
return
|
return
|
||||||
@ -376,7 +344,6 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
|
|
||||||
// 解析工具调用
|
// 解析工具调用
|
||||||
const parsedTools = parseToolUse(textBuffer, tools)
|
const parsedTools = parseToolUse(textBuffer, tools)
|
||||||
// console.log('textBuffer', textBuffer)
|
|
||||||
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||||
console.log('parsedTools', parsedTools)
|
console.log('parsedTools', parsedTools)
|
||||||
|
|
||||||
@ -444,19 +411,40 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
|
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
|
||||||
|
|
||||||
|
// 使用 AI SDK 标准错误格式
|
||||||
|
const toolError = new ToolExecutionError({
|
||||||
|
toolName: toolUse.toolName,
|
||||||
|
toolArgs: toolUse.arguments,
|
||||||
|
toolCallId: toolUse.id,
|
||||||
|
message: `Tool execution failed: ${error instanceof Error ? error.message : String(error)}`,
|
||||||
|
cause: error instanceof Error ? error : undefined
|
||||||
|
})
|
||||||
|
|
||||||
|
// 发送标准错误事件
|
||||||
|
controller.enqueue({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
message: toolError.message,
|
||||||
|
name: toolError.name,
|
||||||
|
toolName: toolError.toolName,
|
||||||
|
toolCallId: toolError.toolCallId
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 发送 tool-result 错误事件
|
||||||
controller.enqueue({
|
controller.enqueue({
|
||||||
type: 'tool-result',
|
type: 'tool-result',
|
||||||
toolCallId: toolUse.id,
|
toolCallId: toolUse.id,
|
||||||
toolName: toolUse.toolName,
|
toolName: toolUse.toolName,
|
||||||
args: toolUse.arguments,
|
args: toolUse.arguments,
|
||||||
isError: true,
|
isError: true,
|
||||||
result: error instanceof Error ? error.message : String(error)
|
result: toolError.message
|
||||||
})
|
})
|
||||||
|
|
||||||
executedResults.push({
|
executedResults.push({
|
||||||
toolCallId: toolUse.id,
|
toolCallId: toolUse.id,
|
||||||
toolName: toolUse.toolName,
|
toolName: toolUse.toolName,
|
||||||
result: error instanceof Error ? error.message : String(error),
|
result: toolError.message,
|
||||||
isError: true
|
isError: true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -466,7 +454,6 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
controller.enqueue({
|
controller.enqueue({
|
||||||
type: 'step-finish',
|
type: 'step-finish',
|
||||||
finishReason: 'tool-call'
|
finishReason: 'tool-call'
|
||||||
// usage: { completionTokens: 0, promptTokens: 0, totalTokens: 0 }
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// 递归调用逻辑
|
// 递归调用逻辑
|
||||||
@ -498,11 +485,11 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
// 递归调用,继续对话
|
// 递归调用,继续对话,重新传递 tools
|
||||||
const recursiveParams = {
|
const recursiveParams = {
|
||||||
...context.originalParams,
|
...context.originalParams,
|
||||||
messages: newMessages,
|
messages: newMessages,
|
||||||
tools: tools // 重新传递 tools
|
tools: tools
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -529,22 +516,25 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('[MCP Prompt] Recursive call failed:', error)
|
console.error('[MCP Prompt] Recursive call failed:', error)
|
||||||
// 发送错误信息后也要确保流不会中断
|
|
||||||
|
// 使用 AI SDK 标准错误格式,但不中断流
|
||||||
controller.enqueue({
|
controller.enqueue({
|
||||||
type: 'text-delta',
|
type: 'error',
|
||||||
textDelta: `\n\n[Error: Recursive call failed: ${error instanceof Error ? error.message : String(error)}]`
|
error: {
|
||||||
|
message: error instanceof Error ? error.message : String(error),
|
||||||
|
name: error instanceof Error ? error.name : 'RecursiveCallError'
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// 发送一个错误后的结束信号
|
// 继续发送文本增量,保持流的连续性
|
||||||
controller.enqueue({
|
controller.enqueue({
|
||||||
type: 'finish',
|
type: 'text-delta',
|
||||||
finishReason: 'error'
|
textDelta: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清理状态
|
// 清理状态
|
||||||
// clearGlobalTools(executionId)
|
|
||||||
textBuffer = ''
|
textBuffer = ''
|
||||||
executedResults = []
|
executedResults = []
|
||||||
return
|
return
|
||||||
@ -555,116 +545,10 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
},
|
},
|
||||||
|
|
||||||
flush() {
|
flush() {
|
||||||
// 清理全局存储
|
// 流结束时的清理工作
|
||||||
clearGlobalTools(executionId)
|
console.log('[MCP Prompt] Stream ended, cleaning up...')
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// transformResult: async (result: any, context: AiRequestContext) => {
|
|
||||||
// // 这个方法现在主要用于非流式场景
|
|
||||||
// if (!enabled || !result || typeof result.text !== 'string') return result
|
|
||||||
|
|
||||||
// console.log('[MCP Prompt] transformResult called - likely non-streaming mode')
|
|
||||||
|
|
||||||
// // 从 WeakMap 中获取工具信息
|
|
||||||
// const tools: ToolSet | undefined = getStoredTools(context)
|
|
||||||
// if (!tools || typeof tools !== 'object') return result
|
|
||||||
|
|
||||||
// // 使用工具解析函数(默认或自定义)
|
|
||||||
// const parsedTools = parseToolUse(result.text, tools)
|
|
||||||
// if (!parsedTools || parsedTools.length === 0) return result
|
|
||||||
|
|
||||||
// // 过滤掉解析失败的工具调用
|
|
||||||
// const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
|
||||||
// if (validToolUses.length === 0) {
|
|
||||||
// console.warn('[MCP Prompt] No valid tool uses found:', parsedTools)
|
|
||||||
// return result
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // 只在非流式模式下执行工具调用并递归
|
|
||||||
// if (context.recursiveCall) {
|
|
||||||
// console.log('[MCP Prompt] Non-streaming: Executing tools and continuing conversation...')
|
|
||||||
|
|
||||||
// // 执行工具调用
|
|
||||||
// const toolResults = await Promise.all(
|
|
||||||
// validToolUses.map(async (toolUse) => {
|
|
||||||
// try {
|
|
||||||
// const tool = tools[toolUse.toolName]
|
|
||||||
// if (!tool || typeof tool.execute !== 'function') {
|
|
||||||
// throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// console.log(`[MCP Prompt] Non-streaming: Executing tool: ${toolUse.toolName}`, toolUse.arguments)
|
|
||||||
|
|
||||||
// const result = await tool.execute(toolUse.arguments, {
|
|
||||||
// toolCallId: toolUse.id,
|
|
||||||
// messages: [],
|
|
||||||
// abortSignal: new AbortController().signal
|
|
||||||
// })
|
|
||||||
|
|
||||||
// return {
|
|
||||||
// id: toolUse.id,
|
|
||||||
// name: toolUse.toolName,
|
|
||||||
// arguments: toolUse.arguments,
|
|
||||||
// result,
|
|
||||||
// success: true
|
|
||||||
// }
|
|
||||||
// } catch (error) {
|
|
||||||
// console.error(`[MCP Prompt] Non-streaming: Tool execution failed: ${toolUse.toolName}`, error)
|
|
||||||
// return {
|
|
||||||
// id: toolUse.id,
|
|
||||||
// name: toolUse.toolName,
|
|
||||||
// arguments: toolUse.arguments,
|
|
||||||
// error: error instanceof Error ? error.message : String(error),
|
|
||||||
// success: false
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// )
|
|
||||||
|
|
||||||
// // 构建工具结果的文本表示
|
|
||||||
// const toolResultsText = toolResults
|
|
||||||
// .map((tr) => {
|
|
||||||
// if (tr.success) {
|
|
||||||
// return `<tool_use_result>\n <name>${tr.name}</name>\n <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
|
|
||||||
// } else {
|
|
||||||
// return `<tool_use_result>\n <name>${tr.name}</name>\n <error>${tr.error}</error>\n</tool_use_result>`
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// .join('\n\n')
|
|
||||||
|
|
||||||
// // 构建新的对话消息
|
|
||||||
// const newMessages = [
|
|
||||||
// ...(context.originalParams.messages || []),
|
|
||||||
// {
|
|
||||||
// role: 'assistant',
|
|
||||||
// content: result.text
|
|
||||||
// },
|
|
||||||
// {
|
|
||||||
// role: 'user',
|
|
||||||
// content: toolResultsText
|
|
||||||
// }
|
|
||||||
// ]
|
|
||||||
|
|
||||||
// // 递归调用,继续对话
|
|
||||||
// const recursiveParams = {
|
|
||||||
// ...context.originalParams,
|
|
||||||
// messages: newMessages,
|
|
||||||
// tools: tools // 重新传递 tools,在新的 context 中会重新存储
|
|
||||||
// }
|
|
||||||
|
|
||||||
// try {
|
|
||||||
// console.log('[MCP Prompt] Non-streaming: Starting recursive call...')
|
|
||||||
// const recursiveResult = await context.recursiveCall(recursiveParams)
|
|
||||||
// return recursiveResult
|
|
||||||
// } catch (error) {
|
|
||||||
// console.error('[MCP Prompt] Non-streaming: Recursive call failed:', error)
|
|
||||||
// return result
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return result
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@ -63,15 +63,16 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
methodName: string,
|
methodName: string,
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: TParams,
|
params: TParams,
|
||||||
executor: (finalModelId: string, transformedParams: TParams) => Promise<TResult>
|
executor: (finalModelId: string, transformedParams: TParams) => Promise<TResult>,
|
||||||
|
_context?: ReturnType<typeof createContext>
|
||||||
): Promise<TResult> {
|
): Promise<TResult> {
|
||||||
// 使用正确的createContext创建请求上下文
|
// 使用正确的createContext创建请求上下文
|
||||||
const context = createContext(this.providerId, modelId, params)
|
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||||
|
|
||||||
// 🔥 为上下文添加递归调用能力
|
// 🔥 为上下文添加递归调用能力
|
||||||
context.recursiveCall = (newParams: any): Promise<TResult> => {
|
context.recursiveCall = (newParams: any): Promise<TResult> => {
|
||||||
// 递归调用自身,重新走完整的插件流程
|
// 递归调用自身,重新走完整的插件流程
|
||||||
return this.executeWithPlugins(methodName, modelId, newParams, executor)
|
return this.executeWithPlugins(methodName, modelId, newParams, executor, context)
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user