feat: enhance MCP Prompt plugin with recursive call support and context handling

- Updated `AiRequestContext` to enforce `recursiveCall` and added `isRecursiveCall` for better state management.
- Modified `createContext` to initialize `recursiveCall` with a placeholder function.
- Enhanced `MCPPromptPlugin` to utilize a custom `createSystemMessage` function for improved message handling during recursive calls.
- Refactored `PluginEngine` to manage recursive call states, ensuring proper execution flow and context integrity.
This commit is contained in:
lizhixuan 2025-06-26 23:48:06 +08:00
parent ba121d04b4
commit c934b45c09
5 changed files with 52 additions and 18 deletions

View File

@ -3,7 +3,8 @@
* Function Call prompt
*
*/
import { ToolExecutionError, ToolSet } from 'ai'
import type { ToolSet } from 'ai'
import { ToolExecutionError } from 'ai'
import { definePlugin } from '../index'
import type { AiRequestContext } from '../types'
@ -46,6 +47,7 @@ export interface MCPPromptConfig {
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => Promise<string>
// 自定义工具解析函数(可选,有默认实现)
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
createSystemMessage?: (systemPrompt: string, originalParams: any, context: MCPRequestContext) => string | null
}
/**
@ -302,11 +304,17 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
// 构建系统提示符
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
const systemPrompt = await buildSystemPrompt(userSystemPrompt, params.tools)
let systemMessage: string | null = systemPrompt
console.log('config.context', context)
if (config.createSystemMessage) {
// 🎯 如果用户提供了自定义处理函数,使用它
systemMessage = config.createSystemMessage(systemPrompt, params, context)
}
// 移除 tools改为 prompt 模式
const transformedParams = {
...params,
system: systemPrompt,
...(systemMessage ? { system: systemMessage } : {}),
tools: undefined
}
console.log('transformedParams', transformedParams)
@ -457,7 +465,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
})
// 递归调用逻辑
if (context.recursiveCall && validToolUses.length > 0) {
if (validToolUses.length > 0) {
console.log('[MCP Prompt] Starting recursive call after tool execution...')
// 构建工具结果的文本表示使用Cherry Studio标准格式
@ -471,7 +479,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
}
})
.join('\n\n')
console.log('context.originalParams.messages', context.originalParams.messages)
// 构建新的对话消息
const newMessages = [
...(context.originalParams.messages || []),
@ -491,6 +499,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
messages: newMessages,
tools: tools
}
context.originalParams.messages = newMessages
try {
const recursiveResult = await context.recursiveCall(recursiveParams)

View File

@ -14,7 +14,8 @@ export function createContext(providerId: string, modelId: string, originalParam
metadata: {},
startTime: Date.now(),
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
recursiveCall: undefined
// 占位
recursiveCall: () => Promise.resolve(null)
}
}

View File

@ -16,7 +16,8 @@ export interface AiRequestContext {
metadata: Record<string, any>
startTime: number
requestId: string
recursiveCall?: RecursiveCallFn
recursiveCall: RecursiveCallFn
isRecursiveCall?: boolean
[key: string]: any
}

View File

@ -70,9 +70,12 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
const context = _context ? _context : createContext(this.providerId, modelId, params)
// 🔥 为上下文添加递归调用能力
context.recursiveCall = (newParams: any): Promise<TResult> => {
context.recursiveCall = async (newParams: any): Promise<TResult> => {
// 递归调用自身,重新走完整的插件流程
return this.executeWithPlugins(methodName, modelId, newParams, executor, context)
context.isRecursiveCall = true
const result = await this.executeWithPlugins(methodName, modelId, newParams, executor, context)
context.isRecursiveCall = false
return result
}
try {
@ -111,15 +114,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
methodName: string,
modelId: string,
params: TParams,
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// 创建请求上下文
const context = createContext(this.providerId, modelId, params)
const context = _context ? _context : createContext(this.providerId, modelId, params)
// 🔥 为上下文添加递归调用能力
context.recursiveCall = (newParams: any): Promise<TResult> => {
context.recursiveCall = async (newParams: any): Promise<TResult> => {
// 递归调用自身,重新走完整的插件流程
return this.executeStreamWithPlugins(methodName, modelId, newParams, executor)
context.isRecursiveCall = true
const result = await this.executeStreamWithPlugins(methodName, modelId, newParams, executor, context)
context.isRecursiveCall = false
return result
}
try {

View File

@ -115,15 +115,9 @@ export default class ModernAiProvider {
// 初始化时不构建中间件,等到需要时再构建
const config = providerToAiSdkConfig(provider)
// 创建MCP Prompt插件
const mcpPromptPlugin = createMCPPromptPlugin({
enabled: true
})
console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled')
this.modernExecutor = createExecutor(config.providerId, config.options, [
mcpPromptPlugin,
reasonPlugin({
delayInMs: 80,
chunkingRegex: /([\u4E00-\u9FFF]{3})|\S+\s+/
@ -184,6 +178,28 @@ export default class ModernAiProvider {
if (middlewareConfig.onChunk) {
// 流式处理 - 使用适配器
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
// 创建MCP Prompt插件
const mcpPromptPlugin = createMCPPromptPlugin({
enabled: true,
createSystemMessage: (systemPrompt, params, context) => {
console.log('createSystemMessage_context', context.isRecursiveCall)
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
if (context.isRecursiveCall) {
return null
}
params.messages = [
{
role: 'assistant',
content: systemPrompt
},
...params.messages
]
return null
}
return systemPrompt
}
})
this.modernExecutor.pluginEngine.use(mcpPromptPlugin)
const streamResult = await this.modernExecutor.streamText(
modelId,
params,