feat(toolUsePlugin): refactor tool execution and event management

- Extracted `StreamEventManager` and `ToolExecutor` classes from `promptToolUsePlugin.ts` to improve code organization and reduce complexity.
- Enhanced tool execution logic with better error handling and event management.
- Updated the `createPromptToolUsePlugin` function to utilize the new classes for cleaner implementation.
- Improved recursive call handling and result formatting for tool executions.
- Streamlined the overall flow of tool calls and event emissions within the plugin.
This commit is contained in:
MyPrototypeWhat 2025-08-19 14:28:04 +08:00
parent 5d0ab0a9a1
commit 179b7af9bd
7 changed files with 354 additions and 230 deletions

View File

@ -300,10 +300,7 @@ export interface AiPlugin {
- 函数式设计,简化使用
```typescript
export function wrapModelWithMiddlewares(
model: LanguageModel,
middlewares: LanguageModelV1Middleware[]
): LanguageModel
export function wrapModelWithMiddlewares(model: LanguageModel, middlewares: LanguageModelV1Middleware[]): LanguageModel
```
### 4.5 Provider System (提供商系统)
@ -403,7 +400,7 @@ const model = await createModel({
import { Agent, run } from '@openai/agents'
const agent = new Agent({
model, // ✅ 直接兼容 LanguageModel 接口
model, // ✅ 直接兼容 LanguageModel 接口
name: 'Assistant',
instructions: '...',
tools: [tool1, tool2]

View File

@ -5,28 +5,33 @@ Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口
## ✨ 核心亮点
### 🏗️ 优雅的架构设计
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
- **函数式优先**:避免过度抽象,提供简洁直观的 API
- **类型安全**:完整的 TypeScript 支持,直接复用 AI SDK 类型系统
- **最小包装**:直接使用 AI SDK 的接口,避免重复定义和性能损耗
### 🔌 强大的插件系统
- **生命周期钩子**:支持请求全生命周期的扩展点
- **流转换支持**:基于 AI SDK 的 `experimental_transform` 实现流处理
- **插件分类**:First、Sequential、Parallel 三种钩子类型,满足不同场景
- **内置插件**:webSearch、logging、toolUse 等开箱即用的功能
### 🌐 统一多 Provider 接口
- **扩展注册**:支持自定义 Provider 注册,无限扩展能力
- **配置统一**:统一的配置接口,简化多 Provider 管理
### 🚀 多种使用方式
- **函数式调用**:适合简单场景的直接函数调用
- **执行器实例**:适合复杂场景的可复用执行器
- **静态工厂**:便捷的静态创建方法
- **原生兼容**:完全兼容 AI SDK 原生 Provider Registry
### 🔮 面向未来
- **Agent 就绪**:为 OpenAI Agents SDK 集成预留架构空间
- **模块化设计**:独立包结构,支持跨项目复用
- **渐进式迁移**:可以逐步从现有 AI SDK 代码迁移
@ -181,6 +186,7 @@ AI Core 提供了强大的插件系统,支持请求全生命周期的扩展。
### 内置插件
#### webSearchPlugin - 网络搜索插件
为不同 AI Provider 提供统一的网络搜索能力:
```typescript
@ -188,9 +194,13 @@ import { webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
webSearchPlugin({
openai: { /* OpenAI 搜索配置 */ },
openai: {
/* OpenAI 搜索配置 */
},
anthropic: { maxUses: 5 },
google: { /* Google 搜索配置 */ },
google: {
/* Google 搜索配置 */
},
xai: {
mode: 'on',
returnCitations: true,
@ -202,6 +212,7 @@ const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
```
#### loggingPlugin - 日志插件
提供详细的请求日志记录:
```typescript
@ -217,24 +228,29 @@ const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
```
#### promptToolUsePlugin - 提示工具使用插件
为不支持原生 Function Call 的模型提供 prompt 方式的工具调用:
```typescript
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/built-in/plugins'
// 对于不支持 function call 的模型
const executor = AiCore.create('providerId', {
apiKey: 'your-key',
baseURL: 'https://your-model-endpoint'
}, [
createPromptToolUsePlugin({
enabled: true,
// 可选:自定义系统提示符构建
buildSystemPrompt: (userPrompt, tools) => {
return `${userPrompt}\n\nAvailable tools: ${Object.keys(tools).join(', ')}`
}
})
])
const executor = AiCore.create(
'providerId',
{
apiKey: 'your-key',
baseURL: 'https://your-model-endpoint'
},
[
createPromptToolUsePlugin({
enabled: true,
// 可选:自定义系统提示符构建
buildSystemPrompt: (userPrompt, tools) => {
return `${userPrompt}\n\nAvailable tools: ${Object.keys(tools).join(', ')}`
}
})
]
)
```
### 自定义插件
@ -402,6 +418,7 @@ await client.streamObject({
- [AI SDK Providers](https://ai-sdk.dev/providers/ai-sdk-providers)
## 未来版本
- 🔮 多 Agent 编排
- 🔮 可视化插件配置
- 🔮 实时监控和分析

View File

@ -0,0 +1,139 @@
/**
 * StreamEventManager — stream event helpers extracted from promptToolUsePlugin.ts.
 * Emits AI SDK step lifecycle events and pipes recursive-call result streams
 * back into the current stream.
 */
import type { ModelMessage } from 'ai'
import type { AiRequestContext } from '../../types'
import type { StreamController } from './ToolExecutor'
/**
 * Manages stream events for prompt-based tool use: step start/finish events,
 * recursive follow-up calls after tool execution, and their error handling.
 */
export class StreamEventManager {
  /** Emit the AI SDK `start-step` event that opens a tool-call step. */
  sendStepStartEvent(controller: StreamController): void {
    controller.enqueue({
      type: 'start-step',
      request: {},
      warnings: []
    })
  }

  /**
   * Emit the AI SDK `finish-step` event, carrying over response metadata,
   * usage and provider metadata from the triggering chunk.
   */
  sendStepFinishEvent(controller: StreamController, chunk: any): void {
    controller.enqueue({
      type: 'finish-step',
      finishReason: 'stop',
      response: chunk.response,
      usage: chunk.usage,
      providerMetadata: chunk.providerMetadata
    })
  }

  /**
   * Run a recursive model call after tool execution and forward its stream
   * parts into the current stream. Failures are reported on the stream
   * without interrupting it.
   */
  async handleRecursiveCall(
    controller: StreamController,
    recursiveParams: any,
    context: AiRequestContext,
    stepId: string
  ): Promise<void> {
    try {
      console.log('[MCP Prompt] Starting recursive call after tool execution...')
      const result = await context.recursiveCall(recursiveParams)
      if (!result || !result.fullStream) {
        console.warn('[MCP Prompt] No fullstream found in recursive result:', result)
        return
      }
      await this.pipeRecursiveStream(controller, result.fullStream)
    } catch (error) {
      this.handleRecursiveCallError(controller, error, stepId)
    }
  }

  /**
   * Forward every part of the recursive result stream into the current
   * stream, stopping before the nested `finish` part — the outer stream
   * owns the final finish event.
   */
  private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
    const reader = recursiveStream.getReader()
    try {
      for (;;) {
        const next = await reader.read()
        if (next.done) {
          break
        }
        const part = next.value
        if (part.type === 'finish') {
          // 迭代的流不发finish
          break
        }
        // 将递归流的数据传递到当前流
        controller.enqueue(part)
      }
    } finally {
      reader.releaseLock()
    }
  }

  /**
   * Report a failed recursive call on the stream (standard error part plus a
   * text-delta notice) while keeping the stream alive.
   */
  private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
    console.error('[MCP Prompt] Recursive call failed:', error)
    const isErr = error instanceof Error
    // 使用 AI SDK 标准错误格式,但不中断流
    controller.enqueue({
      type: 'error',
      error: {
        message: isErr ? error.message : String(error),
        name: isErr ? error.name : 'RecursiveCallError'
      }
    })
    // 继续发送文本增量,保持流的连续性
    controller.enqueue({
      type: 'text-delta',
      id: stepId,
      text: '\n\n[工具执行后递归调用失败,继续对话...]'
    })
  }

  /**
   * Build the parameters for the follow-up call: the prior conversation plus
   * the assistant's buffered text and the tool results as a user message.
   * Also syncs the extended message list back onto the shared context so
   * later steps see the full conversation.
   */
  buildRecursiveParams(context: AiRequestContext, textBuffer: string, toolResultsText: string, tools: any): any {
    const priorMessages = context.originalParams.messages || []
    const newMessages: ModelMessage[] = [
      ...priorMessages,
      { role: 'assistant', content: textBuffer },
      { role: 'user', content: toolResultsText }
    ]
    // 更新上下文中的消息
    context.originalParams.messages = newMessages
    // 递归调用,继续对话,重新传递 tools
    return {
      ...context.originalParams,
      messages: newMessages,
      tools
    }
  }
}

View File

@ -0,0 +1,156 @@
/**
 * ToolExecutor — tool execution logic extracted from promptToolUsePlugin.ts.
 * Runs parsed tool calls, emits the corresponding AI SDK stream events and
 * formats results for feeding back to the model.
 */
import type { ToolSet } from 'ai'
import type { ToolUseResult } from './type'
/**
 * Outcome of a single tool invocation, later rendered into the
 * `<tool_use_result>` text that is fed back to the model.
 */
export interface ExecutedResult {
// ID of the tool call this result belongs to
toolCallId: string
// Name of the executed tool
toolName: string
// Tool output on success; the error message string when isError is true
result: any
// True when execution threw
isError?: boolean
}
/**
 * Minimal writable view of an AI SDK TransformStream controller —
 * only `enqueue` is needed to emit stream parts.
 */
export interface StreamController {
enqueue(chunk: any): void
}
/**
 * Executes parsed prompt-based tool calls and formats their results.
 */
export class ToolExecutor {
  /**
   * Execute each parsed tool call sequentially, emitting AI SDK stream events
   * (`tool-input-start`, `tool-call`, `tool-result` — or `error` on failure)
   * on the given controller.
   *
   * A failing tool does not abort the batch: its error is emitted on the
   * stream and recorded as an ExecutedResult with `isError: true`.
   *
   * @param toolUses   Parsed tool calls (only pending ones should be passed)
   * @param tools      Tool set to look the calls up in
   * @param controller Stream controller the events are enqueued on
   * @returns One ExecutedResult per tool call, in input order
   */
  async executeTools(
    toolUses: ToolUseResult[],
    tools: ToolSet,
    controller: StreamController
  ): Promise<ExecutedResult[]> {
    const executedResults: ExecutedResult[] = []
    for (const toolUse of toolUses) {
      try {
        const tool = tools[toolUse.toolName]
        if (!tool || typeof tool.execute !== 'function') {
          throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
        }
        // 发送工具调用开始事件
        this.sendToolStartEvents(controller, toolUse)
        console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
        // 发送 tool-call 事件
        // Bug fix: `input` must carry the parsed call arguments
        // (previously this sent `tool.inputSchema`, i.e. the schema
        // definition, not the actual input of this call).
        controller.enqueue({
          type: 'tool-call',
          toolCallId: toolUse.id,
          toolName: toolUse.toolName,
          input: toolUse.arguments
        })
        const result = await tool.execute(toolUse.arguments, {
          toolCallId: toolUse.id,
          messages: [],
          abortSignal: new AbortController().signal
        })
        // 发送 tool-result 事件
        controller.enqueue({
          type: 'tool-result',
          toolCallId: toolUse.id,
          toolName: toolUse.toolName,
          input: toolUse.arguments,
          output: result
        })
        executedResults.push({
          toolCallId: toolUse.id,
          toolName: toolUse.toolName,
          result,
          isError: false
        })
      } catch (error) {
        console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
        // 处理错误情况
        executedResults.push(this.handleToolError(toolUse, error, controller))
      }
    }
    return executedResults
  }

  /**
   * Render executed results in the Cherry Studio `<tool_use_result>` text
   * format expected by the follow-up prompt.
   */
  formatToolResults(executedResults: ExecutedResult[]): string {
    return executedResults
      .map((tr) => {
        if (!tr.isError) {
          return `<tool_use_result>\n  <name>${tr.toolName}</name>\n  <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
        } else {
          const error = tr.result || 'Unknown error'
          return `<tool_use_result>\n  <name>${tr.toolName}</name>\n  <error>${error}</error>\n</tool_use_result>`
        }
      })
      .join('\n\n')
  }

  /**
   * Emit the `tool-input-start` event announcing a tool call is being built.
   */
  private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
    controller.enqueue({
      type: 'tool-input-start',
      id: toolUse.id,
      toolName: toolUse.toolName
    })
  }

  /**
   * Emit a standard `error` event for a failed tool call and return the
   * matching ExecutedResult (with the error message as `result`).
   */
  private handleToolError(toolUse: ToolUseResult, error: unknown, controller: StreamController): ExecutedResult {
    const message = error instanceof Error ? error.message : String(error)
    // 发送标准错误事件
    controller.enqueue({
      type: 'error',
      error: message
    })
    return {
      toolCallId: toolUse.id,
      toolName: toolUse.toolName,
      result: message,
      isError: true
    }
  }
}

View File

@ -3,29 +3,14 @@
* Function Call prompt
*
*/
import type { ModelMessage, TextStreamPart, ToolSet, TypedToolError } from 'ai'
import type { TextStreamPart, ToolSet } from 'ai'
import { definePlugin } from '../../index'
import type { AiRequestContext } from '../../types'
import { StreamEventManager } from './StreamEventManager'
import { ToolExecutor } from './ToolExecutor'
import { PromptToolUseConfig, ToolUseResult } from './type'
/**
* 使 AI SDK Tool
*/
// export interface Tool {
// type: 'function'
// function: {
// name: string
// description?: string
// parameters?: {
// type: 'object'
// properties: Record<string, any>
// required?: string[]
// additionalProperties?: boolean
// }
// }
// }
/**
* Cherry Studio
*/
@ -289,34 +274,32 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
transformStream: (_: any, context: AiRequestContext) => () => {
let textBuffer = ''
let stepId = ''
let executedResults: { toolCallId: string; toolName: string; result: any; isError?: boolean }[] = []
if (!context.mcpTools) {
throw new Error('No tools available')
}
// 创建工具执行器和流事件管理器
const toolExecutor = new ToolExecutor()
const streamEventManager = new StreamEventManager()
type TOOLS = NonNullable<typeof context.mcpTools>
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
async transform(
chunk: TextStreamPart<TOOLS>,
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
) {
// console.log('chunk', chunk)
// 收集文本内容
if (chunk.type === 'text-delta') {
textBuffer += chunk.text || ''
stepId = chunk.id || ''
// console.log('textBuffer', textBuffer)
controller.enqueue(chunk)
return
}
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
// console.log('[MCP Prompt Stream] Received step-finish, checking for tool use...')
// 从 context 获取工具信息
const tools = context.mcpTools
// console.log('tools from context', tools)
if (!tools || Object.keys(tools).length === 0) {
// console.log('[MCP Prompt Stream] No tools available, passing through')
controller.enqueue(chunk)
return
}
@ -324,14 +307,13 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
// 解析工具调用
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
// console.log('parsedTools', parsedTools)
// 如果没有有效的工具调用,直接传递原始事件
if (validToolUses.length === 0) {
// console.log('[MCP Prompt Stream] No valid tool uses found, passing through')
controller.enqueue(chunk)
return
}
if (chunk.type === 'text-end') {
controller.enqueue({
type: 'text-end',
@ -349,195 +331,31 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
...chunk,
finishReason: 'tool-calls'
})
// console.log('[MCP Prompt Stream] Found valid tool uses:', validToolUses.length)
// 发送 step-start 事件(工具调用步骤开始)
controller.enqueue({
type: 'start-step',
request: {},
warnings: []
})
// 发送步骤开始事件
streamEventManager.sendStepStartEvent(controller)
// 执行工具调用
executedResults = []
for (const toolUse of validToolUses) {
try {
const tool = tools[toolUse.toolName]
if (!tool || typeof tool.execute !== 'function') {
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
}
// 发送 tool-input-start 事件
controller.enqueue({
type: 'tool-input-start',
id: toolUse.id,
toolName: toolUse.toolName
})
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
console.log('toolUse,toolUse', toolUse)
// 发送 tool-call 事件
controller.enqueue({
type: 'tool-call',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: tool.inputSchema
})
// 发送步骤完成事件
streamEventManager.sendStepFinishEvent(controller, chunk)
const result = await tool.execute(toolUse.arguments, {
toolCallId: toolUse.id,
messages: [],
abortSignal: new AbortController().signal
})
// 发送 tool-result 事件
controller.enqueue({
type: 'tool-result',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: toolUse.arguments,
output: result
})
executedResults.push({
toolCallId: toolUse.id,
toolName: toolUse.toolName,
result,
isError: false
})
} catch (error) {
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
// 使用 AI SDK 标准错误格式
const toolError: TypedToolError<typeof context.mcpTools> = {
type: 'tool-error',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: toolUse.arguments,
error: error instanceof Error ? error.message : String(error)
}
controller.enqueue(toolError)
// 发送标准错误事件
controller.enqueue({
type: 'error',
error: toolError.error
})
// // 发送 tool-result 错误事件
// controller.enqueue({
// type: 'tool-result',
// toolCallId: toolUse.id,
// toolName: toolUse.toolName,
// args: toolUse.arguments,
// isError: true,
// result: toolError.message
// })
executedResults.push({
toolCallId: toolUse.id,
toolName: toolUse.toolName,
result: toolError.error,
isError: true
})
}
}
// 发送最终的 step-finish 事件
controller.enqueue({
type: 'finish-step',
finishReason: 'stop',
response: chunk.response,
usage: chunk.usage,
providerMetadata: chunk.providerMetadata
})
// 递归调用逻辑
// 处理递归调用
if (validToolUses.length > 0) {
// console.log('[MCP Prompt] Starting recursive call after tool execution...')
const toolResultsText = toolExecutor.formatToolResults(executedResults)
const recursiveParams = streamEventManager.buildRecursiveParams(
context,
textBuffer,
toolResultsText,
tools
)
// 构建工具结果的文本表示使用Cherry Studio标准格式
const toolResultsText = executedResults
.map((tr) => {
if (!tr.isError) {
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
} else {
const error = tr.result || 'Unknown error'
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <error>${error}</error>\n</tool_use_result>`
}
})
.join('\n\n')
// console.log('context.originalParams.messages', context.originalParams.messages)
// 构建新的对话消息
const newMessages: ModelMessage[] = [
...(context.originalParams.messages || []),
{
role: 'assistant',
content: textBuffer
},
{
role: 'user',
content: toolResultsText
}
]
// 递归调用,继续对话,重新传递 tools
const recursiveParams = {
...context.originalParams,
messages: newMessages,
tools: tools
}
context.originalParams.messages = newMessages
try {
const recursiveResult = await context.recursiveCall(recursiveParams)
// 将递归调用的结果流接入当前流
if (recursiveResult && recursiveResult.fullStream) {
const reader = recursiveResult.fullStream.getReader()
try {
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
if (value.type === 'finish') {
// 迭代的流不发finish
break
}
// 将递归流的数据传递到当前流
controller.enqueue(value)
}
} finally {
reader.releaseLock()
}
} else {
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
}
} catch (error) {
console.error('[MCP Prompt] Recursive call failed:', error)
// 使用 AI SDK 标准错误格式,但不中断流
controller.enqueue({
type: 'error',
error: {
message: error instanceof Error ? error.message : String(error),
name: error instanceof Error ? error.name : 'RecursiveCallError'
}
})
// 继续发送文本增量,保持流的连续性
controller.enqueue({
type: 'text-delta',
id: stepId,
text: '\n\n[工具执行后递归调用失败,继续对话...]'
})
}
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
}
// 清理状态
textBuffer = ''
executedResults = []
return
}

View File

@ -83,7 +83,6 @@ export class AiSdkToChunkAdapter {
chunk: TextStreamPart<any>,
final: { text: string; reasoningContent: string; webSearchResults: any[]; reasoningId: string }
) {
console.log('AI SDK chunk type:', chunk.type, chunk)
switch (chunk.type) {
// === 文本相关事件 ===
case 'text-start':
@ -101,7 +100,7 @@ export class AiSdkToChunkAdapter {
case 'text-end':
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
text: (chunk.providerMetadata?.text?.value as string) || final.text || ''
text: (chunk.providerMetadata?.text?.value as string) ?? final.text ?? ''
})
final.text = ''
break

View File

@ -15,9 +15,7 @@ const MessageContent: React.FC<Props> = ({ message }) => {
<>
{!isEmpty(message.mentions) && (
<Flex gap="8px" wrap style={{ marginBottom: '10px' }}>
{message.mentions?.map((model) => (
<MentionTag key={getModelUniqId(model)}>{'@' + model.name}</MentionTag>
))}
{message.mentions?.map((model) => <MentionTag key={getModelUniqId(model)}>{'@' + model.name}</MentionTag>)}
</Flex>
)}
<MessageBlockRenderer blocks={message.blocks} message={message} />