refactor: update RuntimeExecutor and introduce MCP Prompt Plugin

- Changed `pluginClient` to `pluginEngine` in `RuntimeExecutor` for clarity and consistency.
- Updated method calls in `RuntimeExecutor` to use the new `pluginEngine`.
- Enhanced `AiSdkMiddlewareBuilder` to include `mcpTools` in the middleware configuration.
- Added `MCPPromptPlugin` to support tool calls within prompts, enabling recursive processing and improved handling of tool interactions.
- Updated `ApiService` to pass `mcpTools` during chat completion requests, enhancing integration with the new plugin system.
This commit is contained in:
lizhixuan 2025-06-26 00:10:39 +08:00
parent f23a026a28
commit 8b67a45804
5 changed files with 280 additions and 10 deletions

View File

@ -12,7 +12,7 @@ import { PluginEngine } from './pluginEngine'
import { type RuntimeConfig } from './types'
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
private pluginClient: PluginEngine<T>
public pluginEngine: PluginEngine<T>
// private options: ProviderSettingsMap[T]
private config: RuntimeConfig<T>
@ -25,7 +25,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
// this.options = config.options
this.config = config
// 创建插件客户端
this.pluginClient = new PluginEngine(config.providerId, config.plugins || [])
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
}
// === 高阶重载:直接使用模型 ===
@ -62,7 +62,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
const model = await this.resolveModel(modelOrId, options?.middlewares)
// 2. 执行插件处理
return this.pluginClient.executeStreamWithPlugins(
return this.pluginEngine.executeStreamWithPlugins(
'streamText',
typeof modelOrId === 'string' ? modelOrId : model.modelId,
params,
@ -112,7 +112,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
): Promise<ReturnType<typeof generateText>> {
const model = await this.resolveModel(modelOrId, options?.middlewares)
return this.pluginClient.executeWithPlugins(
return this.pluginEngine.executeWithPlugins(
'generateText',
typeof modelOrId === 'string' ? modelOrId : model.modelId,
params,
@ -153,7 +153,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
): Promise<ReturnType<typeof generateObject>> {
const model = await this.resolveModel(modelOrId, options?.middlewares)
return this.pluginClient.executeWithPlugins(
return this.pluginEngine.executeWithPlugins(
'generateObject',
typeof modelOrId === 'string' ? modelOrId : model.modelId,
params,
@ -194,7 +194,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
): Promise<ReturnType<typeof streamObject>> {
const model = await this.resolveModel(modelOrId, options?.middlewares)
return this.pluginClient.executeWithPlugins(
return this.pluginEngine.executeWithPlugins(
'streamObject',
typeof modelOrId === 'string' ? modelOrId : model.modelId,
params,

View File

@ -53,8 +53,9 @@ function providerToAiSdkConfig(provider: Provider): {
if (aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, {
...actualProvider,
baseURL: actualProvider.apiHost
...actualProvider
// 使用ai-sdk内置的baseURL
// baseURL: actualProvider.apiHost
})
return {
@ -173,6 +174,16 @@ export default class ModernAiProvider {
if (middlewareConfig.onChunk) {
// 流式处理 - 使用适配器
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
// this.modernExecutor.pluginEngine.use(
// createMCPPromptPlugin({
// mcpTools: middlewareConfig.mcpTools || [],
// assistant: params.assistant,
// onChunk: middlewareConfig.onChunk,
// recursiveCall: this.modernExecutor.streamText,
// recursionDepth: 0,
// maxRecursionDepth: 20
// })
// )
const streamResult = await this.modernExecutor.streamText(
modelId,
params,

View File

@ -4,7 +4,7 @@ import {
simulateStreamingMiddleware
} from '@cherrystudio/ai-core'
import { isReasoningModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'
import type { MCPTool, Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import thinkingTimeMiddleware from './ThinkingTimeMiddleware'
@ -20,6 +20,7 @@ export interface AiSdkMiddlewareConfig {
enableReasoning?: boolean
enableTool?: boolean
enableWebSearch?: boolean
mcpTools?: MCPTool[]
}
/**

View File

@ -0,0 +1,257 @@
import { definePlugin, StreamTextParams } from '@cherrystudio/ai-core'
import { buildSystemPromptWithTools } from '@renderer/aiCore/transformParameters'
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
import { Assistant, MCPTool, MCPToolResponse } from '@renderer/types'
import { parseAndCallTools } from '@renderer/utils/mcp-tools'
/**
 * Configuration for the MCP prompt-mode tool-calling plugin.
 *
 * In prompt mode, tool descriptions are injected into the system prompt
 * (via `buildSystemPromptWithTools`) instead of being passed as
 * function-calling tools; tool calls are then parsed out of the model's
 * text output and executed, with the model re-invoked recursively.
 */
export interface MCPPromptPluginConfig {
  /** MCP tools whose descriptions are injected into the system prompt. */
  mcpTools: MCPTool[]
  /** Assistant whose model/settings drive prompt building and tool execution. */
  assistant: Assistant
  /** Receives streaming chunks; also forwarded to tool execution progress. */
  onChunk: (chunk: any) => void
  /** Re-invokes the model with new params after tool results are appended. */
  recursiveCall: (params: StreamTextParams) => Promise<{ stream?: ReadableStream; getText?: () => string }>
  recursionDepth?: number // current recursion depth, defaults to 0
  maxRecursionDepth?: number // maximum recursion depth, defaults to 20
}
/**
 * Creates the MCP prompt-mode plugin.
 *
 * Two hooks are installed:
 * 1. `transformParams` injects the MCP tool descriptions into the system
 *    prompt and disables function calling.
 * 2. `transformStream` accumulates the streamed text; when a tool-call
 *    pattern is detected at stream end, it executes the tools and recurses
 *    instead of forwarding the finish chunk.
 */
export const createMCPPromptPlugin = (config: MCPPromptPluginConfig) => {
  const { mcpTools, assistant } = config

  return definePlugin({
    name: 'mcp-prompt-plugin',

    // 1. Parameter transform — inject tool descriptions into the system prompt.
    transformParams: async (params: StreamTextParams) => {
      if (mcpTools.length === 0) {
        return params
      }
      try {
        // Reuse the existing system-prompt builder.
        const system = await buildSystemPromptWithTools(params.system || '', mcpTools, assistant)
        // Prompt mode does not use function calling, so tools are cleared.
        return { ...params, system, tools: undefined }
      } catch (error) {
        console.error('构建系统提示失败:', error)
        return params
      }
    },

    // 2. Stream transform — detect tool calls in the text and execute them.
    transformStream: () => {
      let collectedText = ''
      let toolsHandled = false

      return new TransformStream({
        async transform(chunk, controller) {
          try {
            // Accumulate the full text response for pattern detection.
            if (chunk.type === 'text-delta') {
              collectedText += chunk.textDelta
            }

            if (chunk.type === 'finish' && !toolsHandled) {
              toolsHandled = true
              if (containsToolCallPattern(collectedText)) {
                // Swallow the finish chunk — the recursive call produces its own.
                await processToolCallsAndRecurse(collectedText, config, controller)
                return
              }
            }

            // Forward every other chunk unchanged.
            controller.enqueue(chunk)
          } catch (error) {
            console.error('MCP Prompt Plugin Transform Error:', error)
            controller.error(error)
          }
        },
        async flush(controller) {
          // Final check in case the stream ended without a finish chunk.
          if (!toolsHandled && containsToolCallPattern(collectedText)) {
            await processToolCallsAndRecurse(collectedText, config, controller)
          }
        }
      })
    }
  })
}
/**
 * Parses tool calls out of the model's full text response, executes them via
 * `parseAndCallTools`, and — when results exist — recursively re-invokes the
 * model through `performRecursiveCall`.
 */
async function processToolCallsAndRecurse(
  responseText: string,
  config: MCPPromptPluginConfig,
  controller: TransformStreamDefaultController
) {
  const { mcpTools, assistant, onChunk, recursionDepth = 0, maxRecursionDepth = 20 } = config

  // Stop once the maximum recursion depth has been reached.
  if (recursionDepth >= maxRecursionDepth) {
    console.log(`已达到最大递归深度 ${maxRecursionDepth},停止工具调用处理`)
    controller.enqueue({
      type: 'text-delta',
      textDelta: `\n\n[已达到最大工具调用深度 ${maxRecursionDepth},停止继续调用]`
    })
    return
  }

  try {
    console.log(`检测到工具调用,开始处理... (递归深度: ${recursionDepth}/${maxRecursionDepth})`)

    const collectedResponses: MCPToolResponse[] = []

    // parseAndCallTools extracts the tool calls from the raw response text,
    // executes them, and reports progress through the provided onChunk.
    const toolResults = await parseAndCallTools(
      responseText,
      collectedResponses,
      onChunk,
      (mcpToolResponse, resp) => convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp),
      assistant.model!,
      mcpTools
    )

    console.log('工具执行完成,结果数量:', toolResults.length)

    // With tool results in hand, build a new message history and recurse.
    if (toolResults.length > 0) {
      await performRecursiveCall(responseText, toolResults, config, controller)
    }
  } catch (error) {
    console.error('工具调用处理失败:', error)
    // Surface the failure to the user as a plain text chunk.
    controller.enqueue({
      type: 'text-delta',
      textDelta: `\n\n[工具调用错误: ${error instanceof Error ? error.message : String(error)}]`
    })
  }
}
/**
 * Builds a new message history (prior messages + assistant response + tool
 * results) and recursively re-invokes the model, forwarding the resulting
 * stream or text into the current stream's controller.
 *
 * Fix: the recursion depth is now persisted on the shared `config` object
 * before recursing. Previously `recursionDepth + 1` was only logged and never
 * stored, so the depth guard in `processToolCallsAndRecurse` always saw the
 * initial depth and `maxRecursionDepth` could never stop the recursion.
 */
async function performRecursiveCall(
  originalResponse: string,
  toolResults: any[],
  config: MCPPromptPluginConfig,
  controller: TransformStreamDefaultController
) {
  const { assistant, recursiveCall, recursionDepth = 0 } = config

  try {
    // TODO: obtain the real message history from context (stub returns []).
    const currentMessages = getCurrentMessagesFromContext()

    // Append the assistant's response and the tool results to the history.
    const newMessages = [
      ...currentMessages,
      {
        role: 'assistant' as const,
        content: originalResponse
      },
      ...toolResults // toolResults are already in SDK message-param format
    ]

    console.log(`构建新消息历史完成,消息数量: ${newMessages.length},递归深度: ${recursionDepth}`)

    // Reuse the existing param-building logic.
    const { params: recursiveParams } = await buildStreamTextParams(newMessages, assistant, {
      mcpTools: config.mcpTools,
      enableTools: true
    })

    console.log(`开始递归调用... (深度: ${recursionDepth + 1})`)

    // Persist the incremented depth so the re-entered plugin (which closes
    // over this same config object) sees the advanced depth and the
    // maxRecursionDepth guard can eventually fire.
    config.recursionDepth = recursionDepth + 1

    const recursiveResult = await recursiveCall(recursiveParams)

    // Forward the recursive call's output into the current stream.
    if (recursiveResult.stream) {
      const reader = recursiveResult.stream.getReader()
      try {
        while (true) {
          const { done, value } = await reader.read()
          if (done) break
          controller.enqueue(value)
        }
      } finally {
        reader.releaseLock()
      }
    } else if (recursiveResult.getText) {
      // No stream available — emit the final text as a single delta.
      controller.enqueue({
        type: 'text-delta',
        textDelta: recursiveResult.getText()
      })
    }

    console.log(`递归调用完成 (深度: ${recursionDepth + 1})`)
  } catch (error) {
    console.error('递归调用失败:', error)
    controller.enqueue({
      type: 'text-delta',
      textDelta: `\n\n[递归调用错误: ${error instanceof Error ? error.message : String(error)}]`
    })
  }
}
/**
 * Returns true when the text contains a recognizable tool-call marker
 * (case-insensitive `<tool_use>` or `<tool_call>` tags).
 */
function containsToolCallPattern(text: string): boolean {
  // Extend this list if additional tool-call formats come into use.
  const toolCallMarkers = [/<tool_use>/i, /<tool_call>/i]
  return toolCallMarkers.some((marker) => marker.test(text))
}
/**
 * Placeholder for retrieving the current conversation message history.
 * TODO: wire this up to the actual context source.
 */
function getCurrentMessagesFromContext(): any[] {
  // Not implemented yet — an empty history is returned for now.
  console.warn('getCurrentMessagesFromContext 尚未实现,返回空数组')
  return []
}
/**
 * Converts an MCP tool response into an SDK message param.
 * Placeholder implementation: emits a plain user message that describes the
 * tool result as JSON text.
 */
function convertMcpToolResponseToSdkMessageParam(mcpToolResponse: MCPToolResponse, resp: any): any {
  const content = `工具 ${mcpToolResponse.tool.name} 执行结果: ${JSON.stringify(resp)}`
  return { role: 'user', content }
}
export default createMCPPromptPlugin

View File

@ -310,7 +310,8 @@ export async function fetchChatCompletion({
onChunk: onChunkReceived,
model: assistant.model,
provider: provider,
enableReasoning: assistant.settings?.reasoning_effort !== undefined
enableReasoning: assistant.settings?.reasoning_effort !== undefined,
mcpTools
}
// --- Call AI Completions ---