mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-07 22:10:21 +08:00
feat: enhance MCP Prompt plugin with recursive call support and context handling
- Updated `AiRequestContext` to enforce `recursiveCall` and added `isRecursiveCall` for better state management. - Modified `createContext` to initialize `recursiveCall` with a placeholder function. - Enhanced `MCPPromptPlugin` to utilize a custom `createSystemMessage` function for improved message handling during recursive calls. - Refactored `PluginEngine` to manage recursive call states, ensuring proper execution flow and context integrity.
This commit is contained in:
parent
ba121d04b4
commit
c934b45c09
@ -3,7 +3,8 @@
|
|||||||
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
||||||
* 内置默认逻辑,支持自定义覆盖
|
* 内置默认逻辑,支持自定义覆盖
|
||||||
*/
|
*/
|
||||||
import { ToolExecutionError, ToolSet } from 'ai'
|
import type { ToolSet } from 'ai'
|
||||||
|
import { ToolExecutionError } from 'ai'
|
||||||
|
|
||||||
import { definePlugin } from '../index'
|
import { definePlugin } from '../index'
|
||||||
import type { AiRequestContext } from '../types'
|
import type { AiRequestContext } from '../types'
|
||||||
@ -46,6 +47,7 @@ export interface MCPPromptConfig {
|
|||||||
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => Promise<string>
|
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => Promise<string>
|
||||||
// 自定义工具解析函数(可选,有默认实现)
|
// 自定义工具解析函数(可选,有默认实现)
|
||||||
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
|
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
|
||||||
|
createSystemMessage?: (systemPrompt: string, originalParams: any, context: MCPRequestContext) => string | null
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -302,11 +304,17 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
// 构建系统提示符
|
// 构建系统提示符
|
||||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||||
const systemPrompt = await buildSystemPrompt(userSystemPrompt, params.tools)
|
const systemPrompt = await buildSystemPrompt(userSystemPrompt, params.tools)
|
||||||
|
let systemMessage: string | null = systemPrompt
|
||||||
|
console.log('config.context', context)
|
||||||
|
if (config.createSystemMessage) {
|
||||||
|
// 🎯 如果用户提供了自定义处理函数,使用它
|
||||||
|
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
||||||
|
}
|
||||||
|
|
||||||
// 移除 tools,改为 prompt 模式
|
// 移除 tools,改为 prompt 模式
|
||||||
const transformedParams = {
|
const transformedParams = {
|
||||||
...params,
|
...params,
|
||||||
system: systemPrompt,
|
...(systemMessage ? { system: systemMessage } : {}),
|
||||||
tools: undefined
|
tools: undefined
|
||||||
}
|
}
|
||||||
console.log('transformedParams', transformedParams)
|
console.log('transformedParams', transformedParams)
|
||||||
@ -457,7 +465,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
})
|
})
|
||||||
|
|
||||||
// 递归调用逻辑
|
// 递归调用逻辑
|
||||||
if (context.recursiveCall && validToolUses.length > 0) {
|
if (validToolUses.length > 0) {
|
||||||
console.log('[MCP Prompt] Starting recursive call after tool execution...')
|
console.log('[MCP Prompt] Starting recursive call after tool execution...')
|
||||||
|
|
||||||
// 构建工具结果的文本表示,使用Cherry Studio标准格式
|
// 构建工具结果的文本表示,使用Cherry Studio标准格式
|
||||||
@ -471,7 +479,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
.join('\n\n')
|
.join('\n\n')
|
||||||
|
console.log('context.originalParams.messages', context.originalParams.messages)
|
||||||
// 构建新的对话消息
|
// 构建新的对话消息
|
||||||
const newMessages = [
|
const newMessages = [
|
||||||
...(context.originalParams.messages || []),
|
...(context.originalParams.messages || []),
|
||||||
@ -491,6 +499,7 @@ export const createMCPPromptPlugin = definePlugin((config: MCPPromptConfig = {})
|
|||||||
messages: newMessages,
|
messages: newMessages,
|
||||||
tools: tools
|
tools: tools
|
||||||
}
|
}
|
||||||
|
context.originalParams.messages = newMessages
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const recursiveResult = await context.recursiveCall(recursiveParams)
|
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||||
|
|||||||
@ -14,7 +14,8 @@ export function createContext(providerId: string, modelId: string, originalParam
|
|||||||
metadata: {},
|
metadata: {},
|
||||||
startTime: Date.now(),
|
startTime: Date.now(),
|
||||||
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||||
recursiveCall: undefined
|
// 占位
|
||||||
|
recursiveCall: () => Promise.resolve(null)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,8 @@ export interface AiRequestContext {
|
|||||||
metadata: Record<string, any>
|
metadata: Record<string, any>
|
||||||
startTime: number
|
startTime: number
|
||||||
requestId: string
|
requestId: string
|
||||||
recursiveCall?: RecursiveCallFn
|
recursiveCall: RecursiveCallFn
|
||||||
|
isRecursiveCall?: boolean
|
||||||
[key: string]: any
|
[key: string]: any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -70,9 +70,12 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||||
|
|
||||||
// 🔥 为上下文添加递归调用能力
|
// 🔥 为上下文添加递归调用能力
|
||||||
context.recursiveCall = (newParams: any): Promise<TResult> => {
|
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||||
// 递归调用自身,重新走完整的插件流程
|
// 递归调用自身,重新走完整的插件流程
|
||||||
return this.executeWithPlugins(methodName, modelId, newParams, executor, context)
|
context.isRecursiveCall = true
|
||||||
|
const result = await this.executeWithPlugins(methodName, modelId, newParams, executor, context)
|
||||||
|
context.isRecursiveCall = false
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -111,15 +114,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
methodName: string,
|
methodName: string,
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: TParams,
|
params: TParams,
|
||||||
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>
|
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||||
|
_context?: ReturnType<typeof createContext>
|
||||||
): Promise<TResult> {
|
): Promise<TResult> {
|
||||||
// 创建请求上下文
|
// 创建请求上下文
|
||||||
const context = createContext(this.providerId, modelId, params)
|
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||||
|
|
||||||
// 🔥 为上下文添加递归调用能力
|
// 🔥 为上下文添加递归调用能力
|
||||||
context.recursiveCall = (newParams: any): Promise<TResult> => {
|
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||||
// 递归调用自身,重新走完整的插件流程
|
// 递归调用自身,重新走完整的插件流程
|
||||||
return this.executeStreamWithPlugins(methodName, modelId, newParams, executor)
|
context.isRecursiveCall = true
|
||||||
|
const result = await this.executeStreamWithPlugins(methodName, modelId, newParams, executor, context)
|
||||||
|
context.isRecursiveCall = false
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|||||||
@ -115,15 +115,9 @@ export default class ModernAiProvider {
|
|||||||
// 初始化时不构建中间件,等到需要时再构建
|
// 初始化时不构建中间件,等到需要时再构建
|
||||||
const config = providerToAiSdkConfig(provider)
|
const config = providerToAiSdkConfig(provider)
|
||||||
|
|
||||||
// 创建MCP Prompt插件
|
|
||||||
const mcpPromptPlugin = createMCPPromptPlugin({
|
|
||||||
enabled: true
|
|
||||||
})
|
|
||||||
|
|
||||||
console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled')
|
console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled')
|
||||||
|
|
||||||
this.modernExecutor = createExecutor(config.providerId, config.options, [
|
this.modernExecutor = createExecutor(config.providerId, config.options, [
|
||||||
mcpPromptPlugin,
|
|
||||||
reasonPlugin({
|
reasonPlugin({
|
||||||
delayInMs: 80,
|
delayInMs: 80,
|
||||||
chunkingRegex: /([\u4E00-\u9FFF]{3})|\S+\s+/
|
chunkingRegex: /([\u4E00-\u9FFF]{3})|\S+\s+/
|
||||||
@ -184,6 +178,28 @@ export default class ModernAiProvider {
|
|||||||
if (middlewareConfig.onChunk) {
|
if (middlewareConfig.onChunk) {
|
||||||
// 流式处理 - 使用适配器
|
// 流式处理 - 使用适配器
|
||||||
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
|
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
|
||||||
|
// 创建MCP Prompt插件
|
||||||
|
const mcpPromptPlugin = createMCPPromptPlugin({
|
||||||
|
enabled: true,
|
||||||
|
createSystemMessage: (systemPrompt, params, context) => {
|
||||||
|
console.log('createSystemMessage_context', context.isRecursiveCall)
|
||||||
|
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
||||||
|
if (context.isRecursiveCall) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
params.messages = [
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: systemPrompt
|
||||||
|
},
|
||||||
|
...params.messages
|
||||||
|
]
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
return systemPrompt
|
||||||
|
}
|
||||||
|
})
|
||||||
|
this.modernExecutor.pluginEngine.use(mcpPromptPlugin)
|
||||||
const streamResult = await this.modernExecutor.streamText(
|
const streamResult = await this.modernExecutor.streamText(
|
||||||
modelId,
|
modelId,
|
||||||
params,
|
params,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user