mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 13:31:32 +08:00
feat: enhance MCP Prompt plugin with recursive call support and context handling
- Updated `AiRequestContext` to enforce `recursiveCall` and added `isRecursiveCall` for better state management. - Modified `createContext` to initialize `recursiveCall` with a placeholder function. - Enhanced `MCPPromptPlugin` to utilize a custom `createSystemMessage` function for improved message handling during recursive calls. - Refactored `PluginEngine` to manage recursive call states, ensuring proper execution flow and context integrity.
This commit is contained in:
parent
ba121d04b4
commit
c934b45c09
@ -3,7 +3,8 @@
|
||||
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
||||
* 内置默认逻辑,支持自定义覆盖
|
||||
*/
|
||||
import { ToolExecutionError, ToolSet } from 'ai'
|
||||
import type { ToolSet } from 'ai'
|
||||
import { ToolExecutionError } from 'ai'
|
||||
|
||||
import { definePlugin } from '../index'
|
||||
import type { AiRequestContext } from '../types'
|
||||
@ -46,6 +47,7 @@ export interface MCPPromptConfig {
|
||||
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => Promise<string>
|
||||
// 自定义工具解析函数(可选,有默认实现)
|
||||
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
|
||||
createSystemMessage?: (systemPrompt: string, originalParams: any, context: MCPRequestContext) => string | null
|
||||
}
|
||||
|
||||
/**
|
||||
@ -302,11 +304,17 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
// 构建系统提示符
|
||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||
const systemPrompt = await buildSystemPrompt(userSystemPrompt, params.tools)
|
||||
let systemMessage: string | null = systemPrompt
|
||||
console.log('config.context', context)
|
||||
if (config.createSystemMessage) {
|
||||
// 🎯 如果用户提供了自定义处理函数,使用它
|
||||
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
||||
}
|
||||
|
||||
// 移除 tools,改为 prompt 模式
|
||||
const transformedParams = {
|
||||
...params,
|
||||
system: systemPrompt,
|
||||
...(systemMessage ? { system: systemMessage } : {}),
|
||||
tools: undefined
|
||||
}
|
||||
console.log('transformedParams', transformedParams)
|
||||
@ -457,7 +465,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
})
|
||||
|
||||
// 递归调用逻辑
|
||||
if (context.recursiveCall && validToolUses.length > 0) {
|
||||
if (validToolUses.length > 0) {
|
||||
console.log('[MCP Prompt] Starting recursive call after tool execution...')
|
||||
|
||||
// 构建工具结果的文本表示,使用Cherry Studio标准格式
|
||||
@ -471,7 +479,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
}
|
||||
})
|
||||
.join('\n\n')
|
||||
|
||||
console.log('context.originalParams.messages', context.originalParams.messages)
|
||||
// 构建新的对话消息
|
||||
const newMessages = [
|
||||
...(context.originalParams.messages || []),
|
||||
@ -491,6 +499,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
||||
messages: newMessages,
|
||||
tools: tools
|
||||
}
|
||||
context.originalParams.messages = newMessages
|
||||
|
||||
try {
|
||||
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||
|
||||
@ -14,7 +14,8 @@ export function createContext(providerId: string, modelId: string, originalParam
|
||||
metadata: {},
|
||||
startTime: Date.now(),
|
||||
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
recursiveCall: undefined
|
||||
// 占位
|
||||
recursiveCall: () => Promise.resolve(null)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -16,7 +16,8 @@ export interface AiRequestContext {
|
||||
metadata: Record<string, any>
|
||||
startTime: number
|
||||
requestId: string
|
||||
recursiveCall?: RecursiveCallFn
|
||||
recursiveCall: RecursiveCallFn
|
||||
isRecursiveCall?: boolean
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
|
||||
@ -70,9 +70,12 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = (newParams: any): Promise<TResult> => {
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
return this.executeWithPlugins(methodName, modelId, newParams, executor, context)
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeWithPlugins(methodName, modelId, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
@ -111,15 +114,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
methodName: string,
|
||||
modelId: string,
|
||||
params: TParams,
|
||||
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>
|
||||
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 创建请求上下文
|
||||
const context = createContext(this.providerId, modelId, params)
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = (newParams: any): Promise<TResult> => {
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
return this.executeStreamWithPlugins(methodName, modelId, newParams, executor)
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeStreamWithPlugins(methodName, modelId, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
@ -115,15 +115,9 @@ export default class ModernAiProvider {
|
||||
// 初始化时不构建中间件,等到需要时再构建
|
||||
const config = providerToAiSdkConfig(provider)
|
||||
|
||||
// 创建MCP Prompt插件
|
||||
const mcpPromptPlugin = createMCPPromptPlugin({
|
||||
enabled: true
|
||||
})
|
||||
|
||||
console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled')
|
||||
|
||||
this.modernExecutor = createExecutor(config.providerId, config.options, [
|
||||
mcpPromptPlugin,
|
||||
reasonPlugin({
|
||||
delayInMs: 80,
|
||||
chunkingRegex: /([\u4E00-\u9FFF]{3})|\S+\s+/
|
||||
@ -184,6 +178,28 @@ export default class ModernAiProvider {
|
||||
if (middlewareConfig.onChunk) {
|
||||
// 流式处理 - 使用适配器
|
||||
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
|
||||
// 创建MCP Prompt插件
|
||||
const mcpPromptPlugin = createMCPPromptPlugin({
|
||||
enabled: true,
|
||||
createSystemMessage: (systemPrompt, params, context) => {
|
||||
console.log('createSystemMessage_context', context.isRecursiveCall)
|
||||
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
||||
if (context.isRecursiveCall) {
|
||||
return null
|
||||
}
|
||||
params.messages = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: systemPrompt
|
||||
},
|
||||
...params.messages
|
||||
]
|
||||
return null
|
||||
}
|
||||
return systemPrompt
|
||||
}
|
||||
})
|
||||
this.modernExecutor.pluginEngine.use(mcpPromptPlugin)
|
||||
const streamResult = await this.modernExecutor.streamText(
|
||||
modelId,
|
||||
params,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user