From 9293f26612ac6fa8976f8fc422d3b4c4d89c3042 Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat Date: Thu, 26 Jun 2025 19:42:04 +0800 Subject: [PATCH] feat: enhance MCP Prompt plugin and recursive call capabilities - Updated `tsconfig.web.json` to support wildcard imports for `@cherrystudio/ai-core`. - Enhanced `package.json` to include type definitions and imports for built-in plugins. - Introduced recursive call functionality in `PluginManager` and `PluginEngine`, allowing for improved handling of tool interactions. - Added `MCPPromptPlugin` to facilitate tool calls within prompts, enabling recursive processing of tool results. - Refactored `transformStream` methods across plugins to accommodate new parameters and improve type safety. --- packages/aiCore/package.json | 5 + .../aiCore/src/core/plugins/built-in/index.ts | 2 + .../src/core/plugins/built-in/mcpPrompt.ts | 670 ++++++++++++++++++ packages/aiCore/src/core/plugins/index.ts | 3 +- packages/aiCore/src/core/plugins/manager.ts | 16 +- packages/aiCore/src/core/plugins/types.ts | 13 +- .../aiCore/src/core/runtime/pluginEngine.ts | 14 +- src/renderer/src/aiCore/index_new.ts | 20 +- .../src/aiCore/plugins/mcpPromptPlugin.ts | 10 +- .../src/aiCore/plugins/reasonPlugin.ts | 2 +- src/renderer/src/aiCore/plugins/textPlugin.ts | 11 +- tsconfig.web.json | 2 +- 12 files changed, 729 insertions(+), 39 deletions(-) create mode 100644 packages/aiCore/src/core/plugins/built-in/mcpPrompt.ts diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json index 853753a176..df031597a0 100644 --- a/packages/aiCore/package.json +++ b/packages/aiCore/package.json @@ -118,6 +118,11 @@ "types": "./src/index.ts", "import": "./src/index.ts", "require": "./src/index.ts" + }, + "./core/plugins/built-in": { + "types": "./src/core/plugins/built-in/index.ts", + "import": "./src/core/plugins/built-in/index.ts", + "require": "./src/core/plugins/built-in/index.ts" } } } diff --git a/packages/aiCore/src/core/plugins/built-in/index.ts b/packages/aiCore/src/core/plugins/built-in/index.ts index 5de58f2175..2510b756fa 100644 --- a/packages/aiCore/src/core/plugins/built-in/index.ts +++ b/packages/aiCore/src/core/plugins/built-in/index.ts @@ -5,3 +5,5 @@ export const BUILT_IN_PLUGIN_PREFIX = 'built-in:' export { createLoggingPlugin } from './logging' +export type { MCPPromptConfig, ToolUseResult } from './mcpPrompt' +export { createMCPPromptPlugin } from './mcpPrompt' diff --git a/packages/aiCore/src/core/plugins/built-in/mcpPrompt.ts b/packages/aiCore/src/core/plugins/built-in/mcpPrompt.ts new file mode 100644 index 0000000000..a752f94f2e --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/mcpPrompt.ts @@ -0,0 +1,670 @@ +/** + * 内置插件:MCP Prompt 模式 + * 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用 + * 内置默认逻辑,支持自定义覆盖 + */ +import { ToolSet } from 'ai' + +import { definePlugin } from '../index' +import type { AiRequestContext } from '../types' + +/** + * 使用 AI SDK 的 Tool 类型,更通用 + */ +// export interface Tool { +// type: 'function' +// function: { +// name: string +// description?: string +// parameters?: { +// type: 'object' +// properties: Record +// required?: string[] +// additionalProperties?: boolean +// } +// } +// } + +/** + * 解析结果类型 + * 表示从AI响应中解析出的工具使用意图 + */ +export interface ToolUseResult { + id: string + toolName: string + arguments: any + status: 'pending' | 'invoking' | 'done' | 'error' +} + +/** + * MCP Prompt 插件配置 + */ +export interface MCPPromptConfig { + // 是否启用(用于运行时开关) + enabled?: boolean + // 自定义系统提示符构建函数(可选,有默认实现) + buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => Promise + // 自定义工具解析函数(可选,有默认实现) + parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[] +} + +// 全局存储,解决 transformStream 中无 context 的问题 +const globalToolsStorage = new Map() + +/** + * 生成唯一的执行ID + */ +function generateExecutionId(): string { + return `mcp_${Date.now()}_${Math.random().toString(36).slice(2)}` +} + +/** + * 存储工具信息 + */ + +/** + * 全局存储工具信息 + */ +function storeGlobalTools(executionId: string, tools: ToolSet) { + globalToolsStorage.set(executionId, tools) +} + +/** + * 获取全局存储的工具信息 + */ +function getGlobalTools(executionId: string): ToolSet | undefined { + return globalToolsStorage.get(executionId) +} + +/** + * 清理全局存储 + */ +function clearGlobalTools(executionId: string) { + globalToolsStorage.delete(executionId) +} + +/** + * 默认系统提示符模板(提取自 Cherry Studio) + */ +const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \\ +You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use. + +## Tool Use Formatting + +Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure: + + + {tool_name} + {json_arguments} + + +The tool name should be the exact name of the tool you are using, and the arguments should be a JSON object containing the parameters required by that tool. For example: + + python_interpreter + {"code": "5 + 3 + 1294.678"} + + +The user will respond with the result of the tool use, which should be formatted as follows: + + + {tool_name} + {result} + + +The result should be a string, which can represent a file or any other output type. You can use this result as input for the next action. +For example, if the result of the tool use is an image file, you can use it in the next action like this: + + + image_transformer + {"image": "image_1.jpg"} + + +Always adhere to this format for the tool use to ensure proper parsing and execution. + +## Tool Use Examples +{{ TOOL_USE_EXAMPLES }} + +## Tool Use Available Tools +Above example were using notional tools that might not exist for you. You only have access to these tools: +{{ AVAILABLE_TOOLS }} + +## Tool Use Rules +Here are the rules you should always follow to solve your task: +1. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead. +2. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself. +3. If no tool call is needed, just answer the question directly. +4. Never re-do a tool call that you previously did with the exact same parameters. +5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format. + +# User Instructions +{{ USER_SYSTEM_PROMPT }} + +Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.` + +/** + * 默认工具使用示例(提取自 Cherry Studio) + */ +const DEFAULT_TOOL_USE_EXAMPLES = ` +Here are a few examples using notional tools: +--- +User: Generate an image of the oldest person in this document. + +A: I can use the document_qa tool to find out who the oldest person is in the document. + + document_qa + {"document": "document.pdf", "question": "Who is the oldest person mentioned?"} + + +User: + document_qa + John Doe, a 55 year old lumberjack living in Newfoundland. + + +A: I can use the image_generator tool to create a portrait of John Doe. + + image_generator + {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} + + +User: + image_generator + image.png + + +A: the image is generated as image.png + +--- +User: "What is the result of the following operation: 5 + 3 + 1294.678?" + +A: I can use the python_interpreter tool to calculate the result of the operation. + + python_interpreter + {"code": "5 + 3 + 1294.678"} + + +User: + python_interpreter + 1302.678 + + +A: The result of the operation is 1302.678. + +--- +User: "Which city has the highest population , Guangzhou or Shanghai?" + +A: I can use the search tool to find the population of Guangzhou. + + search + {"query": "Population Guangzhou"} + + +User: + search + Guangzhou has a population of 15 million inhabitants as of 2021. + + +A: I can use the search tool to find the population of Shanghai. + + search + {"query": "Population Shanghai"} + + +User: + search + 26 million (2019) + +Assistant: The population of Shanghai is 26 million, while Guangzhou has a population of 15 million. Therefore, Shanghai has the highest population.` + +/** + * 构建可用工具部分(提取自 Cherry Studio) + */ +function buildAvailableTools(tools: ToolSet): string { + const availableTools = Object.keys(tools) + .map((toolName: string) => { + const tool = tools[toolName] + return ` + + ${toolName} + ${tool.description || ''} + + ${tool.parameters ? JSON.stringify(tool.parameters) : ''} + + +` + }) + .join('\n') + return ` +${availableTools} +` +} + +/** + * 默认的系统提示符构建函数(提取自 Cherry Studio) + */ +async function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): Promise { + const availableTools = buildAvailableTools(tools) + + const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES) + .replace('{{ AVAILABLE_TOOLS }}', availableTools) + .replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt || '') + + return fullPrompt +} + +/** + * 默认工具解析函数(提取自 Cherry Studio) + * 解析 XML 格式的工具调用 + */ +function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] { + if (!content || !tools || Object.keys(tools).length === 0) { + return [] + } + + // 支持两种格式: + // 1. 完整的 标签包围的内容 + // 2. 只有内部内容(从 TagExtractor 提取出来的) + + let contentToProcess = content + + // 如果内容不包含 标签,说明是从 TagExtractor 提取的内部内容,需要包装 + if (!content.includes('')) { + contentToProcess = `\n${content}\n` + } + + const toolUsePattern = + /([\s\S]*?)([\s\S]*?)<\/name>([\s\S]*?)([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g + const results: ToolUseResult[] = [] + let match + let idx = 0 + + // Find all tool use blocks + while ((match = toolUsePattern.exec(contentToProcess)) !== null) { + const toolName = match[2].trim() + const toolArgs = match[4].trim() + + // Try to parse the arguments as JSON + let parsedArgs + try { + parsedArgs = JSON.parse(toolArgs) + } catch (error) { + // If parsing fails, use the string as is + parsedArgs = toolArgs + } + + // Find the corresponding tool + const tool = tools[toolName] + if (!tool) { + console.warn(`Tool "${toolName}" not found in available tools`) + continue + } + + // Add to results array + results.push({ + id: `${toolName}-${idx++}`, // Unique ID for each tool use + toolName: toolName, + arguments: parsedArgs, + status: 'pending' + }) + } + return results +} + +/** + * 创建 MCP Prompt 插件 + */ +export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {}) => { + const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config + + // 为每个插件实例生成唯一ID + const executionId = generateExecutionId() + + return { + name: 'built-in:mcp-prompt', + + transformParams: async (params: any, context: AiRequestContext) => { + if (!enabled || !params.tools) return params + + // 保存原始工具信息到 WeakMap 和全局存储中 + const tools: ToolSet = params.tools + console.log('tools', tools) + // storeTools(context, tools) + storeGlobalTools(executionId, tools) + + // 构建系统提示符 + const userSystemPrompt = typeof params.system === 'string' ? params.system : '' + const systemPrompt = await buildSystemPrompt(userSystemPrompt, tools) + + // 将工具信息保存到参数中(用于后续解析) + const transformedParams = { + ...params, + system: systemPrompt, + // 移除 tools,改为 prompt 模式 + tools: undefined + } + console.log('transformedParams', transformedParams) + return transformedParams + }, + + // 流式处理:监听 step-finish 事件并处理工具调用 + transformStream: (_, context: AiRequestContext) => () => { + let textBuffer = '' + let executedResults: { toolCallId: string; toolName: string; result: any; isError?: boolean }[] = [] + + return new TransformStream({ + async transform(chunk, controller) { + console.log('chunk', chunk) + // 收集文本内容 + if (chunk.type === 'text-delta') { + textBuffer += chunk.textDelta || '' + console.log('textBuffer', textBuffer) + controller.enqueue(chunk) + return + } + + // 监听 step-finish 事件 + if (chunk.type === 'step-finish' || chunk.type === 'finish') { + console.log('[MCP Prompt Stream] Received step-finish, checking for tool use...') + + // 获取工具信息 + const tools = getGlobalTools(executionId) + console.log('tools', tools) + if (!tools) { + console.log('[MCP Prompt Stream] No tools available, passing through') + controller.enqueue(chunk) + return + } + + // 解析工具调用 + const parsedTools = parseToolUse(textBuffer, tools) + // console.log('textBuffer', textBuffer) + const validToolUses = parsedTools.filter((t) => t.status === 'pending') + console.log('parsedTools', parsedTools) + + // 如果没有有效的工具调用,直接传递原始事件 + if (validToolUses.length === 0) { + console.log('[MCP Prompt Stream] No valid tool uses found, passing through') + controller.enqueue(chunk) + return + } + + console.log('[MCP Prompt Stream] Found valid tool uses:', validToolUses.length) + + // 修改 step-finish 事件,标记为工具调用 + if (chunk.type !== 'finish') { + controller.enqueue({ + ...chunk, + finishReason: 'tool-call' + }) + } + + // 发送 step-start 事件(工具调用步骤开始) + controller.enqueue({ + type: 'step-start' + }) + + // 执行工具调用 + executedResults = [] + for (const toolUse of validToolUses) { + try { + const tool = tools[toolUse.toolName] + if (!tool || typeof tool.execute !== 'function') { + throw new Error(`Tool "${toolUse.toolName}" has no execute method`) + } + + console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments) + // 发送 tool-call 事件 + controller.enqueue({ + type: 'tool-call', + toolCallId: toolUse.id, + toolName: toolUse.toolName, + args: toolUse.arguments + }) + + const result = await tool.execute(toolUse.arguments, { + toolCallId: toolUse.id, + messages: [], + abortSignal: new AbortController().signal + }) + + // 发送 tool-result 事件 + controller.enqueue({ + type: 'tool-result', + toolCallId: toolUse.id, + toolName: toolUse.toolName, + args: toolUse.arguments, + result + }) + + executedResults.push({ + toolCallId: toolUse.id, + toolName: toolUse.toolName, + result, + isError: false + }) + } catch (error) { + console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error) + + controller.enqueue({ + type: 'tool-result', + toolCallId: toolUse.id, + toolName: toolUse.toolName, + args: toolUse.arguments, + isError: true, + result: error instanceof Error ? error.message : String(error) + }) + + executedResults.push({ + toolCallId: toolUse.id, + toolName: toolUse.toolName, + result: error instanceof Error ? error.message : String(error), + isError: true + }) + } + } + + // 发送最终的 step-finish 事件 + controller.enqueue({ + type: 'step-finish', + finishReason: 'tool-call' + // usage: { completionTokens: 0, promptTokens: 0, totalTokens: 0 } + }) + + // 递归调用逻辑 + if (context.recursiveCall && validToolUses.length > 0) { + console.log('[MCP Prompt] Starting recursive call after tool execution...') + + // 构建工具结果的文本表示,使用Cherry Studio标准格式 + const toolResultsText = executedResults + .map((tr) => { + if (!tr.isError) { + return `\n ${tr.toolName}\n ${JSON.stringify(tr.result)}\n` + } else { + const error = tr.result || 'Unknown error' + return `\n ${tr.toolName}\n ${error}\n` + } + }) + .join('\n\n') + + // 构建新的对话消息 + const newMessages = [ + ...(context.originalParams.messages || []), + { + role: 'assistant', + content: textBuffer + }, + { + role: 'user', + content: toolResultsText + } + ] + + // 递归调用,继续对话 + const recursiveParams = { + ...context.originalParams, + messages: newMessages, + tools: tools // 重新传递 tools + } + + try { + const recursiveResult = await context.recursiveCall(recursiveParams) + + // 将递归调用的结果流接入当前流 + if (recursiveResult && recursiveResult.fullStream) { + const reader = recursiveResult.fullStream.getReader() + try { + while (true) { + const { done, value } = await reader.read() + if (done) { + break + } + + // 将递归流的数据传递到当前流 + controller.enqueue(value) + } + } finally { + reader.releaseLock() + } + } else { + console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult) + } + } catch (error) { + console.error('[MCP Prompt] Recursive call failed:', error) + // 发送错误信息后也要确保流不会中断 + controller.enqueue({ + type: 'text-delta', + textDelta: `\n\n[Error: Recursive call failed: ${error instanceof Error ? error.message : String(error)}]` + }) + + // 发送一个错误后的结束信号 + controller.enqueue({ + type: 'finish', + finishReason: 'error' + }) + } + } + + // 清理状态 + // clearGlobalTools(executionId) + textBuffer = '' + executedResults = [] + return + } + + // 对于其他类型的事件,直接传递 + controller.enqueue(chunk) + }, + + flush() { + // 清理全局存储 + clearGlobalTools(executionId) + } + }) + } + + // transformResult: async (result: any, context: AiRequestContext) => { + // // 这个方法现在主要用于非流式场景 + // if (!enabled || !result || typeof result.text !== 'string') return result + + // console.log('[MCP Prompt] transformResult called - likely non-streaming mode') + + // // 从 WeakMap 中获取工具信息 + // const tools: ToolSet | undefined = getStoredTools(context) + // if (!tools || typeof tools !== 'object') return result + + // // 使用工具解析函数(默认或自定义) + // const parsedTools = parseToolUse(result.text, tools) + // if (!parsedTools || parsedTools.length === 0) return result + + // // 过滤掉解析失败的工具调用 + // const validToolUses = parsedTools.filter((t) => t.status === 'pending') + // if (validToolUses.length === 0) { + // console.warn('[MCP Prompt] No valid tool uses found:', parsedTools) + // return result + // } + + // // 只在非流式模式下执行工具调用并递归 + // if (context.recursiveCall) { + // console.log('[MCP Prompt] Non-streaming: Executing tools and continuing conversation...') + + // // 执行工具调用 + // const toolResults = await Promise.all( + // validToolUses.map(async (toolUse) => { + // try { + // const tool = tools[toolUse.toolName] + // if (!tool || typeof tool.execute !== 'function') { + // throw new Error(`Tool "${toolUse.toolName}" has no execute method`) + // } + + // console.log(`[MCP Prompt] Non-streaming: Executing tool: ${toolUse.toolName}`, toolUse.arguments) + + // const result = await tool.execute(toolUse.arguments, { + // toolCallId: toolUse.id, + // messages: [], + // abortSignal: new AbortController().signal + // }) + + // return { + // id: toolUse.id, + // name: toolUse.toolName, + // arguments: toolUse.arguments, + // result, + // success: true + // } + // } catch (error) { + // console.error(`[MCP Prompt] Non-streaming: Tool execution failed: ${toolUse.toolName}`, error) + // return { + // id: toolUse.id, + // name: toolUse.toolName, + // arguments: toolUse.arguments, + // error: error instanceof Error ? error.message : String(error), + // success: false + // } + // } + // }) + // ) + + // // 构建工具结果的文本表示 + // const toolResultsText = toolResults + // .map((tr) => { + // if (tr.success) { + // return `\n ${tr.name}\n ${JSON.stringify(tr.result)}\n` + // } else { + // return `\n ${tr.name}\n ${tr.error}\n` + // } + // }) + // .join('\n\n') + + // // 构建新的对话消息 + // const newMessages = [ + // ...(context.originalParams.messages || []), + // { + // role: 'assistant', + // content: result.text + // }, + // { + // role: 'user', + // content: toolResultsText + // } + // ] + + // // 递归调用,继续对话 + // const recursiveParams = { + // ...context.originalParams, + // messages: newMessages, + // tools: tools // 重新传递 tools,在新的 context 中会重新存储 + // } + + // try { + // console.log('[MCP Prompt] Non-streaming: Starting recursive call...') + // const recursiveResult = await context.recursiveCall(recursiveParams) + // return recursiveResult + // } catch (error) { + // console.error('[MCP Prompt] Non-streaming: Recursive call failed:', error) + // return result + // } + // } + + // return result + // } + } +}) diff --git a/packages/aiCore/src/core/plugins/index.ts b/packages/aiCore/src/core/plugins/index.ts index 1b7c00eeb9..e78b57a4eb 100644 --- a/packages/aiCore/src/core/plugins/index.ts +++ b/packages/aiCore/src/core/plugins/index.ts @@ -13,7 +13,8 @@ export function createContext(providerId: string, modelId: string, originalParam originalParams, metadata: {}, startTime: Date.now(), - requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}` + requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`, + recursiveCall: undefined } } diff --git a/packages/aiCore/src/core/plugins/manager.ts b/packages/aiCore/src/core/plugins/manager.ts index 2ac3983f50..d37976b288 100644 --- a/packages/aiCore/src/core/plugins/manager.ts +++ b/packages/aiCore/src/core/plugins/manager.ts @@ -1,5 +1,3 @@ -import type { TextStreamPart, ToolSet } from 'ai' - import { AiPlugin, AiRequestContext } from './types' /** @@ -121,18 +119,8 @@ export class PluginManager { /** * 收集所有流转换器(返回数组,AI SDK 原生支持) */ - collectStreamTransforms(): Array< - (options: { - tools?: TOOLS - stopStream: () => void - }) => TransformStream, TextStreamPart> - > { - return this.plugins.map((plugin) => plugin.transformStream).filter(Boolean) as Array< - (options: { - tools?: TOOLS - stopStream: () => void - }) => TransformStream, TextStreamPart> - > + collectStreamTransforms(params: any, context: AiRequestContext) { + return this.plugins.map((plugin) => plugin.transformStream?.(params, context)) } /** diff --git a/packages/aiCore/src/core/plugins/types.ts b/packages/aiCore/src/core/plugins/types.ts index 4eee48c5f9..4ad3490f1b 100644 --- a/packages/aiCore/src/core/plugins/types.ts +++ b/packages/aiCore/src/core/plugins/types.ts @@ -1,5 +1,11 @@ import type { TextStreamPart, ToolSet } from 'ai' +/** + * 递归调用函数类型 + * 使用 any 是因为递归调用时参数和返回类型可能完全不同 + */ +export type RecursiveCallFn = (newParams: any) => Promise + /** * AI 请求上下文 */ @@ -10,6 +16,8 @@ export interface AiRequestContext { metadata: Record startTime: number requestId: string + recursiveCall?: RecursiveCallFn + [key: string]: any } /** @@ -33,7 +41,10 @@ export interface AiPlugin { onError?: (error: Error, context: AiRequestContext) => void | Promise // 【Stream】流处理 - 直接使用 AI SDK - transformStream?: (options: { + transformStream?: ( + params: any, + context: AiRequestContext + ) => (options: { tools: TOOLS stopStream: () => void }) => TransformStream, TextStreamPart> diff --git a/packages/aiCore/src/core/runtime/pluginEngine.ts b/packages/aiCore/src/core/runtime/pluginEngine.ts index f63bb32eac..6afb0435d6 100644 --- a/packages/aiCore/src/core/runtime/pluginEngine.ts +++ b/packages/aiCore/src/core/runtime/pluginEngine.ts @@ -68,6 +68,12 @@ export class PluginEngine { // 使用正确的createContext创建请求上下文 const context = createContext(this.providerId, modelId, params) + // 🔥 为上下文添加递归调用能力 + context.recursiveCall = (newParams: any): Promise => { + // 递归调用自身,重新走完整的插件流程 + return this.executeWithPlugins(methodName, modelId, newParams, executor) + } + try { // 1. 触发请求开始事件 await this.pluginManager.executeParallel('onRequestStart', context) @@ -109,6 +115,12 @@ export class PluginEngine { // 创建请求上下文 const context = createContext(this.providerId, modelId, params) + // 🔥 为上下文添加递归调用能力 + context.recursiveCall = (newParams: any): Promise => { + // 递归调用自身,重新走完整的插件流程 + return this.executeStreamWithPlugins(methodName, modelId, newParams, executor) + } + try { // 1. 触发请求开始事件 await this.pluginManager.executeParallel('onRequestStart', context) @@ -121,7 +133,7 @@ export class PluginEngine { const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) // 4. 收集流转换器 - const streamTransforms = this.pluginManager.collectStreamTransforms() + const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context) // 5. 执行流式 API 调用 const result = await executor(finalModelId, transformedParams, streamTransforms) diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index c768cace77..e1486400c6 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -15,6 +15,7 @@ import { type ProviderSettingsMap, StreamTextParams } from '@cherrystudio/ai-core' +import { createMCPPromptPlugin } from '@cherrystudio/ai-core/core/plugins/built-in' import { isDedicatedImageGenerationModel } from '@renderer/config/models' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import type { GenerateImageParams, Model, Provider } from '@renderer/types' @@ -113,7 +114,16 @@ export default class ModernAiProvider { // TODO:如果后续在调用completions时需要切换provider的话, // 初始化时不构建中间件,等到需要时再构建 const config = providerToAiSdkConfig(provider) + + // 创建MCP Prompt插件 + const mcpPromptPlugin = createMCPPromptPlugin({ + enabled: true + }) + + console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled') + this.modernExecutor = createExecutor(config.providerId, config.options, [ + mcpPromptPlugin, reasonPlugin({ delayInMs: 80, chunkingRegex: /([\u4E00-\u9FFF]{3})|\S+\s+/ @@ -174,16 +184,6 @@ export default class ModernAiProvider { if (middlewareConfig.onChunk) { // 流式处理 - 使用适配器 const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk) - // this.modernExecutor.pluginEngine.use( - // createMCPPromptPlugin({ - // mcpTools: middlewareConfig.mcpTools || [], - // assistant: params.assistant, - // onChunk: middlewareConfig.onChunk, - // recursiveCall: this.modernExecutor.streamText, - // recursionDepth: 0, - // maxRecursionDepth: 20 - // }) - // ) const streamResult = await this.modernExecutor.streamText( modelId, params, diff --git a/src/renderer/src/aiCore/plugins/mcpPromptPlugin.ts b/src/renderer/src/aiCore/plugins/mcpPromptPlugin.ts index c3a553ad63..a8b97b8788 100644 --- a/src/renderer/src/aiCore/plugins/mcpPromptPlugin.ts +++ b/src/renderer/src/aiCore/plugins/mcpPromptPlugin.ts @@ -20,8 +20,8 @@ export interface MCPPromptPluginConfig { * 创建 MCP Prompt 模式插件 * 支持在 prompt 模式下解析文本中的工具调用并执行 */ -export const createMCPPromptPlugin = (config: MCPPromptPluginConfig) => { - return definePlugin({ +export const createMCPPromptPlugin = definePlugin((config: MCPPromptPluginConfig) => { + return { name: 'mcp-prompt-plugin', // 1. 参数转换 - 注入工具描述到系统提示 @@ -49,7 +49,7 @@ export const createMCPPromptPlugin = (config: MCPPromptPluginConfig) => { }, // 2. 流处理 - 检测工具调用并执行 - transformStream: () => { + transformStream: () => () => { let fullResponseText = '' let hasProcessedTools = false @@ -87,8 +87,8 @@ export const createMCPPromptPlugin = (config: MCPPromptPluginConfig) => { } }) } - }) -} + } +}) /** * 处理工具调用并执行递归 diff --git a/src/renderer/src/aiCore/plugins/reasonPlugin.ts b/src/renderer/src/aiCore/plugins/reasonPlugin.ts index 9c709d24eb..a4f7293a74 100644 --- a/src/renderer/src/aiCore/plugins/reasonPlugin.ts +++ b/src/renderer/src/aiCore/plugins/reasonPlugin.ts @@ -3,7 +3,7 @@ import { definePlugin } from '@cherrystudio/ai-core' export default definePlugin(({ delayInMs, chunkingRegex }: { delayInMs: number; chunkingRegex: RegExp }) => ({ name: 'reasonPlugin', - transformStream: () => { + transformStream: () => () => { let buffer = '' const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)) const detectChunk = (buffer: string) => { diff --git a/src/renderer/src/aiCore/plugins/textPlugin.ts b/src/renderer/src/aiCore/plugins/textPlugin.ts index 2ec6019ffb..aabcee56fe 100644 --- a/src/renderer/src/aiCore/plugins/textPlugin.ts +++ b/src/renderer/src/aiCore/plugins/textPlugin.ts @@ -2,9 +2,10 @@ import { definePlugin, smoothStream } from '@cherrystudio/ai-core' export default definePlugin({ name: 'textPlugin', - transformStream: smoothStream({ - delayInMs: 80, - // 中文3个字符一个chunk,英文一个单词一个chunk - chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/ - }) + transformStream: () => + smoothStream({ + delayInMs: 80, + // 中文3个字符一个chunk,英文一个单词一个chunk + chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/ + }) }) diff --git a/tsconfig.web.json b/tsconfig.web.json index e7d6f2daef..048864a036 100644 --- a/tsconfig.web.json +++ b/tsconfig.web.json @@ -16,7 +16,7 @@ "@renderer/*": ["src/renderer/src/*"], "@shared/*": ["packages/shared/*"], "@types": ["src/renderer/src/types/index.ts"], - "@cherrystudio/ai-core": ["packages/aiCore/src/"] + "@cherrystudio/ai-core/*": ["packages/aiCore/src/*"] } } }