From 8b67a4580407ab13d4a95b3ccf84407728905fa3 Mon Sep 17 00:00:00 2001 From: lizhixuan Date: Thu, 26 Jun 2025 00:10:39 +0800 Subject: [PATCH] refactor: update RuntimeExecutor and introduce MCP Prompt Plugin - Changed `pluginClient` to `pluginEngine` in `RuntimeExecutor` for clarity and consistency. - Updated method calls in `RuntimeExecutor` to use the new `pluginEngine`. - Enhanced `AiSdkMiddlewareBuilder` to include `mcpTools` in the middleware configuration. - Added `MCPPromptPlugin` to support tool calls within prompts, enabling recursive processing and improved handling of tool interactions. - Updated `ApiService` to pass `mcpTools` during chat completion requests, enhancing integration with the new plugin system. --- packages/aiCore/src/core/runtime/executor.ts | 12 +- src/renderer/src/aiCore/index_new.ts | 15 +- .../aisdk/AiSdkMiddlewareBuilder.ts | 3 +- .../src/aiCore/plugins/mcpPromptPlugin.ts | 257 ++++++++++++++++++ src/renderer/src/services/ApiService.ts | 3 +- 5 files changed, 280 insertions(+), 10 deletions(-) create mode 100644 src/renderer/src/aiCore/plugins/mcpPromptPlugin.ts diff --git a/packages/aiCore/src/core/runtime/executor.ts b/packages/aiCore/src/core/runtime/executor.ts index 1df259a9ba..7a67d99a79 100644 --- a/packages/aiCore/src/core/runtime/executor.ts +++ b/packages/aiCore/src/core/runtime/executor.ts @@ -12,7 +12,7 @@ import { PluginEngine } from './pluginEngine' import { type RuntimeConfig } from './types' export class RuntimeExecutor { - private pluginClient: PluginEngine + public pluginEngine: PluginEngine // private options: ProviderSettingsMap[T] private config: RuntimeConfig @@ -25,7 +25,7 @@ export class RuntimeExecutor { // this.options = config.options this.config = config // 创建插件客户端 - this.pluginClient = new PluginEngine(config.providerId, config.plugins || []) + this.pluginEngine = new PluginEngine(config.providerId, config.plugins || []) } // === 高阶重载:直接使用模型 === @@ -62,7 +62,7 @@ export class RuntimeExecutor { const model = await this.resolveModel(modelOrId, options?.middlewares) // 2. 执行插件处理 - return this.pluginClient.executeStreamWithPlugins( + return this.pluginEngine.executeStreamWithPlugins( 'streamText', typeof modelOrId === 'string' ? modelOrId : model.modelId, params, @@ -112,7 +112,7 @@ export class RuntimeExecutor { ): Promise> { const model = await this.resolveModel(modelOrId, options?.middlewares) - return this.pluginClient.executeWithPlugins( + return this.pluginEngine.executeWithPlugins( 'generateText', typeof modelOrId === 'string' ? modelOrId : model.modelId, params, @@ -153,7 +153,7 @@ export class RuntimeExecutor { ): Promise> { const model = await this.resolveModel(modelOrId, options?.middlewares) - return this.pluginClient.executeWithPlugins( + return this.pluginEngine.executeWithPlugins( 'generateObject', typeof modelOrId === 'string' ? modelOrId : model.modelId, params, @@ -194,7 +194,7 @@ export class RuntimeExecutor { ): Promise> { const model = await this.resolveModel(modelOrId, options?.middlewares) - return this.pluginClient.executeWithPlugins( + return this.pluginEngine.executeWithPlugins( 'streamObject', typeof modelOrId === 'string' ? modelOrId : model.modelId, params, diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index c99b71c71d..c768cace77 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -53,8 +53,9 @@ function providerToAiSdkConfig(provider: Provider): { if (aiSdkProviderId !== 'openai-compatible') { const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, { - ...actualProvider, - baseURL: actualProvider.apiHost + ...actualProvider + // 使用ai-sdk内置的baseURL + // baseURL: actualProvider.apiHost }) return { @@ -173,6 +174,16 @@ export default class ModernAiProvider { if (middlewareConfig.onChunk) { // 流式处理 - 使用适配器 const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk) + // this.modernExecutor.pluginEngine.use( + // createMCPPromptPlugin({ + // mcpTools: middlewareConfig.mcpTools || [], + // assistant: params.assistant, + // onChunk: middlewareConfig.onChunk, + // recursiveCall: this.modernExecutor.streamText, + // recursionDepth: 0, + // maxRecursionDepth: 20 + // }) + // ) const streamResult = await this.modernExecutor.streamText( modelId, params, diff --git a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts index 775ccf3451..77efe9a7e6 100644 --- a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts @@ -4,7 +4,7 @@ import { simulateStreamingMiddleware } from '@cherrystudio/ai-core' import { isReasoningModel } from '@renderer/config/models' -import type { Model, Provider } from '@renderer/types' +import type { MCPTool, Model, Provider } from '@renderer/types' import type { Chunk } from '@renderer/types/chunk' import thinkingTimeMiddleware from './ThinkingTimeMiddleware' @@ -20,6 +20,7 @@ export interface AiSdkMiddlewareConfig { enableReasoning?: boolean enableTool?: boolean enableWebSearch?: boolean + mcpTools?: MCPTool[] } /** diff --git a/src/renderer/src/aiCore/plugins/mcpPromptPlugin.ts b/src/renderer/src/aiCore/plugins/mcpPromptPlugin.ts new file mode 100644 index 0000000000..c3a553ad63 --- /dev/null +++ b/src/renderer/src/aiCore/plugins/mcpPromptPlugin.ts @@ -0,0 +1,257 @@ +import { definePlugin, StreamTextParams } from '@cherrystudio/ai-core' +import { buildSystemPromptWithTools } from '@renderer/aiCore/transformParameters' +import { buildStreamTextParams } from '@renderer/aiCore/transformParameters' +import { Assistant, MCPTool, MCPToolResponse } from '@renderer/types' +import { parseAndCallTools } from '@renderer/utils/mcp-tools' + +/** + * MCP Prompt 插件配置 + */ +export interface MCPPromptPluginConfig { + mcpTools: MCPTool[] + assistant: Assistant + onChunk: (chunk: any) => void + recursiveCall: (params: StreamTextParams) => Promise<{ stream?: ReadableStream; getText?: () => string }> + recursionDepth?: number // 当前递归深度,默认为 0 + maxRecursionDepth?: number // 最大递归深度,默认为 20 +} + +/** + * 创建 MCP Prompt 模式插件 + * 支持在 prompt 模式下解析文本中的工具调用并执行 + */ +export const createMCPPromptPlugin = (config: MCPPromptPluginConfig) => { + return definePlugin({ + name: 'mcp-prompt-plugin', + + // 1. 参数转换 - 注入工具描述到系统提示 + transformParams: async (params: StreamTextParams) => { + const { mcpTools, assistant } = config + + if (mcpTools.length === 0) { + return params + } + + try { + // 复用现有的系统提示构建逻辑 + const enhancedSystemPrompt = await buildSystemPromptWithTools(params.system || '', mcpTools, assistant) + + return { + ...params, + system: enhancedSystemPrompt, + // Prompt 模式不使用 function calling + tools: undefined + } + } catch (error) { + console.error('构建系统提示失败:', error) + return params + } + }, + + // 2. 流处理 - 检测工具调用并执行 + transformStream: () => { + let fullResponseText = '' + let hasProcessedTools = false + + return new TransformStream({ + async transform(chunk, controller) { + try { + // 收集完整的文本响应 + if (chunk.type === 'text-delta') { + fullResponseText += chunk.textDelta + } + + // 在流结束时检查并处理工具调用 + if (chunk.type === 'finish' && !hasProcessedTools) { + hasProcessedTools = true + + if (containsToolCallPattern(fullResponseText)) { + await processToolCallsAndRecurse(fullResponseText, config, controller) + return // 不转发 finish chunk,让递归调用处理 + } + } + + // 正常转发其他类型的 chunk + controller.enqueue(chunk) + } catch (error) { + console.error('MCP Prompt Plugin Transform Error:', error) + controller.error(error) + } + }, + + async flush(controller) { + // 流结束时的最后检查 + if (!hasProcessedTools && containsToolCallPattern(fullResponseText)) { + await processToolCallsAndRecurse(fullResponseText, config, controller) + } + } + }) + } + }) +} + +/** + * 处理工具调用并执行递归 + */ +async function processToolCallsAndRecurse( + responseText: string, + config: MCPPromptPluginConfig, + controller: TransformStreamDefaultController +) { + const { mcpTools, assistant, onChunk, recursionDepth = 0, maxRecursionDepth = 20 } = config + + // 检查是否超过最大递归深度 + if (recursionDepth >= maxRecursionDepth) { + console.log(`已达到最大递归深度 ${maxRecursionDepth},停止工具调用处理`) + controller.enqueue({ + type: 'text-delta', + textDelta: `\n\n[已达到最大工具调用深度 ${maxRecursionDepth},停止继续调用]` + }) + return + } + + try { + console.log(`检测到工具调用,开始处理... (递归深度: ${recursionDepth}/${maxRecursionDepth})`) + + const allToolResponses: MCPToolResponse[] = [] + + // 直接使用现有的 parseAndCallTools 函数 + // 它会自动解析文本中的工具调用、执行工具、触发 onChunk + const toolResults = await parseAndCallTools( + responseText, // 传入完整响应文本,让 parseAndCallTools 自己解析 + allToolResponses, + onChunk, // 直接传入来自配置的 onChunk + (mcpToolResponse, resp) => { + // 复用现有的消息转换逻辑 + return convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp) + }, + assistant.model!, + mcpTools + ) + + console.log('工具执行完成,结果数量:', toolResults.length) + + // 如果有工具结果,构建新消息并递归调用 + if (toolResults.length > 0) { + await performRecursiveCall(responseText, toolResults, config, controller) + } + } catch (error) { + console.error('工具调用处理失败:', error) + + // 发送错误信息作为文本 chunk + controller.enqueue({ + type: 'text-delta', + textDelta: `\n\n[工具调用错误: ${error instanceof Error ? error.message : String(error)}]` + }) + } +} + +/** + * 执行递归调用 + */ +async function performRecursiveCall( + originalResponse: string, + toolResults: any[], + config: MCPPromptPluginConfig, + controller: TransformStreamDefaultController +) { + const { assistant, recursiveCall, recursionDepth = 0 } = config + + try { + // 获取当前的消息历史(需要从上下文获取,这里暂时用空数组) + // TODO: 实现从上下文获取当前消息的逻辑 + const currentMessages = getCurrentMessagesFromContext() + + // 构建新的消息历史 + const newMessages = [ + ...currentMessages, + { + role: 'assistant' as const, + content: originalResponse + }, + ...toolResults // toolResults 已经是正确的消息格式 + ] + + console.log(`构建新消息历史完成,消息数量: ${newMessages.length},递归深度: ${recursionDepth}`) + + // 复用现有的参数构建逻辑 + const { params: recursiveParams } = await buildStreamTextParams(newMessages, assistant, { + mcpTools: config.mcpTools, + enableTools: true + }) + + console.log(`开始递归调用... (深度: ${recursionDepth + 1})`) + + // 递归调用,递增递归深度 + const recursiveResult = await recursiveCall(recursiveParams) + + // 转发递归结果的流 + if (recursiveResult.stream) { + const reader = recursiveResult.stream.getReader() + try { + while (true) { + const { done, value } = await reader.read() + if (done) break + controller.enqueue(value) + } + } finally { + reader.releaseLock() + } + } else if (recursiveResult.getText) { + // 如果没有流,但有文本结果 + const finalText = recursiveResult.getText() + controller.enqueue({ + type: 'text-delta', + textDelta: finalText + }) + } + + console.log(`递归调用完成 (深度: ${recursionDepth + 1})`) + } catch (error) { + console.error('递归调用失败:', error) + + controller.enqueue({ + type: 'text-delta', + textDelta: `\n\n[递归调用错误: ${error instanceof Error ? error.message : String(error)}]` + }) + } +} + +/** + * 检查文本是否包含工具调用模式 + */ +function containsToolCallPattern(text: string): boolean { + const patterns = [ + //i, + //i + // 可以根据实际使用的格式添加更多模式 + ] + + return patterns.some((pattern) => pattern.test(text)) +} + +/** + * 从上下文获取当前消息历史 + * TODO: 实现从实际上下文获取消息的逻辑 + */ +function getCurrentMessagesFromContext(): any[] { + // 这里需要实现从上下文获取当前消息历史的逻辑 + // 暂时返回空数组,后续根据实际情况补充 + console.warn('getCurrentMessagesFromContext 尚未实现,返回空数组') + return [] +} + +/** + * 转换 MCP 工具响应为 SDK 消息参数 + * 复用现有的转换逻辑 + */ +function convertMcpToolResponseToSdkMessageParam(mcpToolResponse: MCPToolResponse, resp: any): any { + // 这里需要根据实际的转换逻辑来实现 + // 暂时返回一个基础的用户消息格式 + return { + role: 'user', + content: `工具 ${mcpToolResponse.tool.name} 执行结果: ${JSON.stringify(resp)}` + } +} + +export default createMCPPromptPlugin diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 79ead1d949..8d4308cf97 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -310,7 +310,8 @@ export async function fetchChatCompletion({ onChunk: onChunkReceived, model: assistant.model, provider: provider, - enableReasoning: assistant.settings?.reasoning_effort !== undefined + enableReasoning: assistant.settings?.reasoning_effort !== undefined, + mcpTools } // --- Call AI Completions ---