Mirror of https://github.com/CherryHQ/cherry-studio.git (synced 2025-12-28 05:11:24 +08:00)
refactor: update RuntimeExecutor and introduce MCP Prompt Plugin
- Changed `pluginClient` to `pluginEngine` in `RuntimeExecutor` for clarity and consistency.
- Updated method calls in `RuntimeExecutor` to use the new `pluginEngine`.
- Enhanced `AiSdkMiddlewareBuilder` to include `mcpTools` in the middleware configuration.
- Added `MCPPromptPlugin` to support tool calls within prompts, enabling recursive processing and improved handling of tool interactions.
- Updated `ApiService` to pass `mcpTools` during chat completion requests, enhancing integration with the new plugin system.
This commit is contained in:
parent f23a026a28 · commit 8b67a45804
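The pieces fit together as follows: `fetchChatCompletion` now passes `mcpTools` into the middleware config, and `RuntimeExecutor` exposes its `pluginEngine` so plugins can be registered on it. A minimal sketch of that wiring, uncommented from the registration that the diff below still leaves disabled in `ModernAiProvider` (here `modernExecutor`, `middlewareConfig`, and `params` are assumed to be in scope):

import { createMCPPromptPlugin } from '@renderer/aiCore/plugins/mcpPromptPlugin'

// Sketch only - mirrors the commented-out registration in ModernAiProvider below.
// modernExecutor, middlewareConfig and params are assumed to be in scope.
modernExecutor.pluginEngine.use(
  createMCPPromptPlugin({
    mcpTools: middlewareConfig.mcpTools || [],
    assistant: params.assistant,
    onChunk: middlewareConfig.onChunk,
    recursiveCall: modernExecutor.streamText,
    recursionDepth: 0,
    maxRecursionDepth: 20
  })
)

The `recursiveCall` handed to the plugin is what lets a detected tool call feed its results back into another `streamText` pass.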
@@ -12,7 +12,7 @@ import { PluginEngine } from './pluginEngine'
  import { type RuntimeConfig } from './types'

  export class RuntimeExecutor<T extends ProviderId = ProviderId> {
-   private pluginClient: PluginEngine<T>
+   public pluginEngine: PluginEngine<T>
    // private options: ProviderSettingsMap[T]
    private config: RuntimeConfig<T>

@@ -25,7 +25,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
      // this.options = config.options
      this.config = config
      // Create the plugin client
-     this.pluginClient = new PluginEngine(config.providerId, config.plugins || [])
+     this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
    }

    // === High-level overloads: use a model directly ===
@@ -62,7 +62,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
      const model = await this.resolveModel(modelOrId, options?.middlewares)

      // 2. Run plugin processing
-     return this.pluginClient.executeStreamWithPlugins(
+     return this.pluginEngine.executeStreamWithPlugins(
        'streamText',
        typeof modelOrId === 'string' ? modelOrId : model.modelId,
        params,
@@ -112,7 +112,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
    ): Promise<ReturnType<typeof generateText>> {
      const model = await this.resolveModel(modelOrId, options?.middlewares)

-     return this.pluginClient.executeWithPlugins(
+     return this.pluginEngine.executeWithPlugins(
        'generateText',
        typeof modelOrId === 'string' ? modelOrId : model.modelId,
        params,
@@ -153,7 +153,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
    ): Promise<ReturnType<typeof generateObject>> {
      const model = await this.resolveModel(modelOrId, options?.middlewares)

-     return this.pluginClient.executeWithPlugins(
+     return this.pluginEngine.executeWithPlugins(
        'generateObject',
        typeof modelOrId === 'string' ? modelOrId : model.modelId,
        params,
@@ -194,7 +194,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
    ): Promise<ReturnType<typeof streamObject>> {
      const model = await this.resolveModel(modelOrId, options?.middlewares)

-     return this.pluginClient.executeWithPlugins(
+     return this.pluginEngine.executeWithPlugins(
        'streamObject',
        typeof modelOrId === 'string' ? modelOrId : model.modelId,
        params,
@@ -53,8 +53,9 @@ function providerToAiSdkConfig(provider: Provider): {

    if (aiSdkProviderId !== 'openai-compatible') {
      const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, {
-       ...actualProvider,
-       baseURL: actualProvider.apiHost
+       ...actualProvider
+       // Use the baseURL built into ai-sdk
+       // baseURL: actualProvider.apiHost
      })

      return {
@@ -173,6 +174,16 @@ export default class ModernAiProvider {
      if (middlewareConfig.onChunk) {
        // Streaming path - use the adapter
        const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
+       // this.modernExecutor.pluginEngine.use(
+       //   createMCPPromptPlugin({
+       //     mcpTools: middlewareConfig.mcpTools || [],
+       //     assistant: params.assistant,
+       //     onChunk: middlewareConfig.onChunk,
+       //     recursiveCall: this.modernExecutor.streamText,
+       //     recursionDepth: 0,
+       //     maxRecursionDepth: 20
+       //   })
+       // )
        const streamResult = await this.modernExecutor.streamText(
          modelId,
          params,
@@ -4,7 +4,7 @@ import {
  simulateStreamingMiddleware
} from '@cherrystudio/ai-core'
import { isReasoningModel } from '@renderer/config/models'
-import type { Model, Provider } from '@renderer/types'
+import type { MCPTool, Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'

import thinkingTimeMiddleware from './ThinkingTimeMiddleware'
@@ -20,6 +20,7 @@ export interface AiSdkMiddlewareConfig {
  enableReasoning?: boolean
  enableTool?: boolean
  enableWebSearch?: boolean
+  mcpTools?: MCPTool[]
}

/**
src/renderer/src/aiCore/plugins/mcpPromptPlugin.ts (new file, 257 lines)
@@ -0,0 +1,257 @@
import { definePlugin, StreamTextParams } from '@cherrystudio/ai-core'
import { buildStreamTextParams, buildSystemPromptWithTools } from '@renderer/aiCore/transformParameters'
import { Assistant, MCPTool, MCPToolResponse } from '@renderer/types'
import { parseAndCallTools } from '@renderer/utils/mcp-tools'

/**
 * MCP Prompt plugin configuration
 */
export interface MCPPromptPluginConfig {
  mcpTools: MCPTool[]
  assistant: Assistant
  onChunk: (chunk: any) => void
  recursiveCall: (params: StreamTextParams) => Promise<{ stream?: ReadableStream; getText?: () => string }>
  recursionDepth?: number // current recursion depth, defaults to 0
  maxRecursionDepth?: number // maximum recursion depth, defaults to 20
}

/**
 * Create the MCP prompt-mode plugin.
 * In prompt mode it parses tool calls out of the response text and executes them.
 */
export const createMCPPromptPlugin = (config: MCPPromptPluginConfig) => {
  return definePlugin({
    name: 'mcp-prompt-plugin',

    // 1. Parameter transform - inject tool descriptions into the system prompt
    transformParams: async (params: StreamTextParams) => {
      const { mcpTools, assistant } = config

      if (mcpTools.length === 0) {
        return params
      }

      try {
        // Reuse the existing system-prompt building logic
        const enhancedSystemPrompt = await buildSystemPromptWithTools(params.system || '', mcpTools, assistant)

        return {
          ...params,
          system: enhancedSystemPrompt,
          // Prompt mode does not use function calling
          tools: undefined
        }
      } catch (error) {
        console.error('Failed to build system prompt:', error)
        return params
      }
    },

    // 2. Stream transform - detect tool calls and execute them
    transformStream: () => {
      let fullResponseText = ''
      let hasProcessedTools = false

      return new TransformStream({
        async transform(chunk, controller) {
          try {
            // Collect the full text response
            if (chunk.type === 'text-delta') {
              fullResponseText += chunk.textDelta
            }

            // At the end of the stream, check for and process tool calls
            if (chunk.type === 'finish' && !hasProcessedTools) {
              hasProcessedTools = true

              if (containsToolCallPattern(fullResponseText)) {
                await processToolCallsAndRecurse(fullResponseText, config, controller)
                return // do not forward the finish chunk; let the recursive call handle it
              }
            }

            // Forward all other chunk types unchanged
            controller.enqueue(chunk)
          } catch (error) {
            console.error('MCP Prompt Plugin Transform Error:', error)
            controller.error(error)
          }
        },

        async flush(controller) {
          // Final check when the stream ends
          if (!hasProcessedTools && containsToolCallPattern(fullResponseText)) {
            await processToolCallsAndRecurse(fullResponseText, config, controller)
          }
        }
      })
    }
  })
}

/**
 * Process tool calls and recurse
 */
async function processToolCallsAndRecurse(
  responseText: string,
  config: MCPPromptPluginConfig,
  controller: TransformStreamDefaultController
) {
  const { mcpTools, assistant, onChunk, recursionDepth = 0, maxRecursionDepth = 20 } = config

  // Stop if the maximum recursion depth has been reached
  if (recursionDepth >= maxRecursionDepth) {
    console.log(`Maximum recursion depth ${maxRecursionDepth} reached, stopping tool call processing`)
    controller.enqueue({
      type: 'text-delta',
      textDelta: `\n\n[Maximum tool call depth ${maxRecursionDepth} reached, stopping further calls]`
    })
    return
  }

  try {
    console.log(`Tool call detected, processing... (recursion depth: ${recursionDepth}/${maxRecursionDepth})`)

    const allToolResponses: MCPToolResponse[] = []

    // Use the existing parseAndCallTools function directly.
    // It parses the tool calls out of the text, executes the tools and fires onChunk.
    const toolResults = await parseAndCallTools(
      responseText, // pass the full response text and let parseAndCallTools do the parsing
      allToolResponses,
      onChunk, // pass through the onChunk from the config
      (mcpToolResponse, resp) => {
        // Reuse the existing message conversion logic
        return convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp)
      },
      assistant.model!,
      mcpTools
    )

    console.log('Tool execution finished, result count:', toolResults.length)

    // If there are tool results, build new messages and recurse
    if (toolResults.length > 0) {
      await performRecursiveCall(responseText, toolResults, config, controller)
    }
  } catch (error) {
    console.error('Tool call processing failed:', error)

    // Emit the error message as a text chunk
    controller.enqueue({
      type: 'text-delta',
      textDelta: `\n\n[Tool call error: ${error instanceof Error ? error.message : String(error)}]`
    })
  }
}

/**
 * Perform the recursive call
 */
async function performRecursiveCall(
  originalResponse: string,
  toolResults: any[],
  config: MCPPromptPluginConfig,
  controller: TransformStreamDefaultController
) {
  const { assistant, recursiveCall, recursionDepth = 0 } = config

  try {
    // Get the current message history (should come from context; empty array for now)
    // TODO: implement fetching the current messages from context
    const currentMessages = getCurrentMessagesFromContext()

    // Build the new message history
    const newMessages = [
      ...currentMessages,
      {
        role: 'assistant' as const,
        content: originalResponse
      },
      ...toolResults // toolResults is already in the correct message format
    ]

    console.log(`New message history built, message count: ${newMessages.length}, recursion depth: ${recursionDepth}`)

    // Reuse the existing parameter building logic
    const { params: recursiveParams } = await buildStreamTextParams(newMessages, assistant, {
      mcpTools: config.mcpTools,
      enableTools: true
    })

    console.log(`Starting recursive call... (depth: ${recursionDepth + 1})`)

    // Recursive call with incremented recursion depth
    const recursiveResult = await recursiveCall(recursiveParams)

    // Forward the stream of the recursive result
    if (recursiveResult.stream) {
      const reader = recursiveResult.stream.getReader()
      try {
        while (true) {
          const { done, value } = await reader.read()
          if (done) break
          controller.enqueue(value)
        }
      } finally {
        reader.releaseLock()
      }
    } else if (recursiveResult.getText) {
      // No stream, but a text result is available
      const finalText = recursiveResult.getText()
      controller.enqueue({
        type: 'text-delta',
        textDelta: finalText
      })
    }

    console.log(`Recursive call finished (depth: ${recursionDepth + 1})`)
  } catch (error) {
    console.error('Recursive call failed:', error)

    controller.enqueue({
      type: 'text-delta',
      textDelta: `\n\n[Recursive call error: ${error instanceof Error ? error.message : String(error)}]`
    })
  }
}

/**
 * Check whether the text contains a tool call pattern
 */
function containsToolCallPattern(text: string): boolean {
  const patterns = [
    /<tool_use>/i,
    /<tool_call>/i
    // more patterns can be added depending on the formats actually in use
  ]

  return patterns.some((pattern) => pattern.test(text))
}

/**
 * Get the current message history from context
 * TODO: implement fetching messages from the real context
 */
function getCurrentMessagesFromContext(): any[] {
  // Fetching the current message history from context is not implemented yet.
  // Return an empty array for now and fill this in later.
  console.warn('getCurrentMessagesFromContext is not implemented yet, returning an empty array')
  return []
}

/**
 * Convert an MCP tool response into an SDK message parameter.
 * Reuses the existing conversion logic.
 */
function convertMcpToolResponseToSdkMessageParam(mcpToolResponse: MCPToolResponse, resp: any): any {
  // This should follow the actual conversion logic.
  // For now, return a basic user-message format.
  return {
    role: 'user',
    content: `Tool ${mcpToolResponse.tool.name} result: ${JSON.stringify(resp)}`
  }
}

export default createMCPPromptPlugin
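For reference, a self-contained sketch of the detection technique the plugin's `transformStream` uses: accumulate `text-delta` chunks and, when the `finish` chunk arrives, test the collected text against the tool-call markers. The chunk shapes mirror the plugin code above; the snippet itself is illustrative and not part of the commit.

// Minimal, standalone sketch of the detection step in transformStream above.
type StreamChunk =
  | { type: 'text-delta'; textDelta: string }
  | { type: 'finish' }

function createToolCallDetector(onToolCallText: (text: string) => void) {
  let fullText = ''
  return new TransformStream<StreamChunk, StreamChunk>({
    transform(chunk, controller) {
      if (chunk.type === 'text-delta') {
        fullText += chunk.textDelta
      }
      if (chunk.type === 'finish' && /<tool_use>|<tool_call>/i.test(fullText)) {
        onToolCallText(fullText) // in the real plugin this triggers parseAndCallTools and the recursive call
        return // swallow the finish chunk, as the plugin does
      }
      controller.enqueue(chunk)
    }
  })
}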
@@ -310,7 +310,8 @@ export async function fetchChatCompletion({
    onChunk: onChunkReceived,
    model: assistant.model,
    provider: provider,
-   enableReasoning: assistant.settings?.reasoning_effort !== undefined
+   enableReasoning: assistant.settings?.reasoning_effort !== undefined,
+   mcpTools
  }

  // --- Call AI Completions ---