diff --git a/packages/aiCore/src/core/plugins/built-in/mcpPrompt.ts b/packages/aiCore/src/core/plugins/built-in/mcpPrompt.ts index 6ab295b001..e78ff02b5e 100644 --- a/packages/aiCore/src/core/plugins/built-in/mcpPrompt.ts +++ b/packages/aiCore/src/core/plugins/built-in/mcpPrompt.ts @@ -3,7 +3,8 @@ * 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用 * 内置默认逻辑,支持自定义覆盖 */ -import { ToolExecutionError, ToolSet } from 'ai' +import type { ToolSet } from 'ai' +import { ToolExecutionError } from 'ai' import { definePlugin } from '../index' import type { AiRequestContext } from '../types' @@ -46,6 +47,7 @@ export interface MCPPromptConfig { buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => Promise // 自定义工具解析函数(可选,有默认实现) parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[] + createSystemMessage?: (systemPrompt: string, originalParams: any, context: MCPRequestContext) => string | null } /** @@ -302,11 +304,17 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {}) // 构建系统提示符 const userSystemPrompt = typeof params.system === 'string' ? params.system : '' const systemPrompt = await buildSystemPrompt(userSystemPrompt, params.tools) + let systemMessage: string | null = systemPrompt + console.log('config.context', context) + if (config.createSystemMessage) { + // 🎯 如果用户提供了自定义处理函数,使用它 + systemMessage = config.createSystemMessage(systemPrompt, params, context) + } // 移除 tools,改为 prompt 模式 const transformedParams = { ...params, - system: systemPrompt, + ...(systemMessage ? { system: systemMessage } : {}), tools: undefined } console.log('transformedParams', transformedParams) @@ -457,7 +465,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {}) }) // 递归调用逻辑 - if (context.recursiveCall && validToolUses.length > 0) { + if (validToolUses.length > 0) { console.log('[MCP Prompt] Starting recursive call after tool execution...') // 构建工具结果的文本表示,使用Cherry Studio标准格式 @@ -471,7 +479,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {}) } }) .join('\n\n') - + console.log('context.originalParams.messages', context.originalParams.messages) // 构建新的对话消息 const newMessages = [ ...(context.originalParams.messages || []), @@ -491,6 +499,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {}) messages: newMessages, tools: tools } + context.originalParams.messages = newMessages try { const recursiveResult = await context.recursiveCall(recursiveParams) diff --git a/packages/aiCore/src/core/plugins/index.ts b/packages/aiCore/src/core/plugins/index.ts index e78b57a4eb..06833fbc99 100644 --- a/packages/aiCore/src/core/plugins/index.ts +++ b/packages/aiCore/src/core/plugins/index.ts @@ -14,7 +14,8 @@ export function createContext(providerId: string, modelId: string, originalParam metadata: {}, startTime: Date.now(), requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`, - recursiveCall: undefined + // 占位 + recursiveCall: () => Promise.resolve(null) } } diff --git a/packages/aiCore/src/core/plugins/types.ts b/packages/aiCore/src/core/plugins/types.ts index 4ad3490f1b..2e4e18b413 100644 --- a/packages/aiCore/src/core/plugins/types.ts +++ b/packages/aiCore/src/core/plugins/types.ts @@ -16,7 +16,8 @@ export interface AiRequestContext { metadata: Record startTime: number requestId: string - recursiveCall?: RecursiveCallFn + recursiveCall: RecursiveCallFn + isRecursiveCall?: boolean [key: string]: any } diff --git a/packages/aiCore/src/core/runtime/pluginEngine.ts b/packages/aiCore/src/core/runtime/pluginEngine.ts index ca21ef3686..dd7f0b7d9b 100644 --- a/packages/aiCore/src/core/runtime/pluginEngine.ts +++ b/packages/aiCore/src/core/runtime/pluginEngine.ts @@ -70,9 +70,12 @@ export class PluginEngine { const context = _context ? _context : createContext(this.providerId, modelId, params) // 🔥 为上下文添加递归调用能力 - context.recursiveCall = (newParams: any): Promise => { + context.recursiveCall = async (newParams: any): Promise => { // 递归调用自身,重新走完整的插件流程 - return this.executeWithPlugins(methodName, modelId, newParams, executor, context) + context.isRecursiveCall = true + const result = await this.executeWithPlugins(methodName, modelId, newParams, executor, context) + context.isRecursiveCall = false + return result } try { @@ -111,15 +114,19 @@ export class PluginEngine { methodName: string, modelId: string, params: TParams, - executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise + executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise, + _context?: ReturnType ): Promise { // 创建请求上下文 - const context = createContext(this.providerId, modelId, params) + const context = _context ? _context : createContext(this.providerId, modelId, params) // 🔥 为上下文添加递归调用能力 - context.recursiveCall = (newParams: any): Promise => { + context.recursiveCall = async (newParams: any): Promise => { // 递归调用自身,重新走完整的插件流程 - return this.executeStreamWithPlugins(methodName, modelId, newParams, executor) + context.isRecursiveCall = true + const result = await this.executeStreamWithPlugins(methodName, modelId, newParams, executor, context) + context.isRecursiveCall = false + return result } try { diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index e1486400c6..edb6d93900 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -115,15 +115,9 @@ export default class ModernAiProvider { // 初始化时不构建中间件,等到需要时再构建 const config = providerToAiSdkConfig(provider) - // 创建MCP Prompt插件 - const mcpPromptPlugin = createMCPPromptPlugin({ - enabled: true - }) - console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled') this.modernExecutor = createExecutor(config.providerId, config.options, [ - mcpPromptPlugin, reasonPlugin({ delayInMs: 80, chunkingRegex: /([\u4E00-\u9FFF]{3})|\S+\s+/ @@ -184,6 +178,28 @@ export default class ModernAiProvider { if (middlewareConfig.onChunk) { // 流式处理 - 使用适配器 const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk) + // 创建MCP Prompt插件 + const mcpPromptPlugin = createMCPPromptPlugin({ + enabled: true, + createSystemMessage: (systemPrompt, params, context) => { + console.log('createSystemMessage_context', context.isRecursiveCall) + if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) { + if (context.isRecursiveCall) { + return null + } + params.messages = [ + { + role: 'assistant', + content: systemPrompt + }, + ...params.messages + ] + return null + } + return systemPrompt + } + }) + this.modernExecutor.pluginEngine.use(mcpPromptPlugin) const streamResult = await this.modernExecutor.streamText( modelId, params,