feat: enhance AI SDK chunk handling and tool call processing

- Introduced ToolCallChunkHandler for managing tool call events and results, improving the handling of tool interactions.
- Updated AiSdkToChunkAdapter to utilize the new handler, streamlining the processing of tool call chunks.
- Refactored transformParameters to support dynamic tool integration and improved parameter handling.
- Adjusted provider mapping in factory.ts to include new provider types, enhancing compatibility with various AI services.
- Removed obsolete cherryStudioTransformPlugin to clean up the codebase and focus on more relevant functionality.
This commit is contained in:
lizhixuan 2025-06-21 23:26:38 +08:00
parent ebe85ba24a
commit f6c3794ac9
7 changed files with 333 additions and 193 deletions

View File

@ -2,7 +2,7 @@
"compilerOptions": {
"target": "ES2020",
"module": "ESNext",
"moduleResolution": "node",
"moduleResolution": "bundler",
"declaration": true,
"outDir": "./dist",
"rootDir": "./src",

View File

@ -6,6 +6,8 @@
import { TextStreamPart } from '@cherrystudio/ai-core'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { ToolCallChunkHandler } from './chunk/handleTooCallChunk'
export interface CherryStudioChunk {
type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
text?: string
@ -21,7 +23,10 @@ export interface CherryStudioChunk {
* fullStream Cherry Studio chunk
*/
export class AiSdkToChunkAdapter {
constructor(private onChunk: (chunk: Chunk) => void) {}
toolCallHandler: ToolCallChunkHandler
constructor(private onChunk: (chunk: Chunk) => void) {
this.toolCallHandler = new ToolCallChunkHandler(onChunk)
}
/**
* AI SDK
@ -100,7 +105,7 @@ export class AiSdkToChunkAdapter {
})
break
// === 工具调用相关事件 ===
// === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) ===
case 'tool-call-streaming-start':
// 开始流式工具调用
this.onChunk({
@ -145,55 +150,31 @@ export class AiSdkToChunkAdapter {
break
case 'tool-call':
// 完整的工具调用
this.onChunk({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [
{
id: chunk.toolCallId,
name: chunk.toolName,
args: chunk.args
}
]
})
// 原始的工具调用(未被中间件处理)
this.toolCallHandler.handleToolCall(chunk)
break
case 'tool-result':
// 工具调用结果
this.onChunk({
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [
{
id: chunk.toolCallId,
tool: {
id: chunk.toolName,
// TODO: serverId,serverName
serverId: 'ai-sdk',
serverName: 'AI SDK',
name: chunk.toolName,
description: '',
inputSchema: {
type: 'object',
title: chunk.toolName,
properties: {}
}
},
arguments: chunk.args || {},
status: 'done',
response: chunk.result,
toolCallId: chunk.toolCallId
}
]
})
// 原始的工具调用结果(未被中间件处理)
this.toolCallHandler.handleToolResult(chunk)
break
// === 步骤相关事件 ===
// case 'step-start':
// this.onChunk({
// type: ChunkType.LLM_RESPONSE_CREATED
// })
// break
// TODO: 需要区分接口开始和步骤开始
// case 'step-start':
// this.onChunk({
// type: ChunkType.LLM_RESPONSE_CREATED
// })
// break
case 'step-finish':
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
text: final.text || '' // TEXT_COMPLETE 需要 text 字段
})
final.text = ''
break
case 'finish':
this.onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
@ -212,13 +193,6 @@ export class AiSdkToChunkAdapter {
: undefined
}
})
break
case 'finish':
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
text: final.text || '' // TEXT_COMPLETE 需要 text 字段
})
this.onChunk({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {

View File

@ -0,0 +1,130 @@
/**
 * Tool-call chunk handling module.
 *
 * Bridges raw AI SDK tool-call / tool-result stream events to Cherry Studio
 * chunk events so the renderer can display tool usage consistently.
 */
import Logger from '@renderer/config/logger'
import { MCPToolResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
/**
 * Translates raw AI SDK tool-call / tool-result stream chunks into Cherry
 * Studio MCP chunk events.
 *
 * Lifecycle: a `tool-call` chunk registers the call as active and emits
 * MCP_TOOL_IN_PROGRESS; the matching `tool-result` chunk emits
 * MCP_TOOL_COMPLETE and removes the entry. The instance is discarded after
 * one interaction, so no further cleanup of the map is needed.
 */
export class ToolCallChunkHandler {
  // Calls that have started but not yet produced a result, keyed by toolCallId.
  private activeToolCalls = new Map<
    string,
    {
      toolCallId: string
      toolName: string
      args: any
    }
  >()

  constructor(private onChunk: (chunk: Chunk) => void) {}

  /**
   * Handle a raw AI SDK `tool-call` chunk (one not consumed by middleware).
   *
   * Emits an MCP_TOOL_IN_PROGRESS chunk with status 'invoking'. Chunks
   * missing toolCallId or toolName are logged and dropped.
   */
  public handleToolCall(chunk: any): void {
    const toolCallId = chunk.toolCallId
    const toolName = chunk.toolName
    const args = chunk.args || {}

    if (!toolCallId || !toolName) {
      Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
      return
    }

    // Remember the call so the later tool-result chunk can recover name/args.
    this.activeToolCalls.set(toolCallId, {
      toolCallId,
      toolName,
      args
    })

    // NOTE(review): `tool.id` is set to the call id, not a stable tool
    // identifier, and serverId/serverName are omitted — confirm downstream
    // consumers only need a unique identifier here.
    const toolResponse: MCPToolResponse = {
      id: toolCallId,
      tool: {
        id: toolCallId,
        name: toolName
      },
      arguments: args,
      status: 'invoking',
      toolCallId: toolCallId
    }

    this.onChunk({
      type: ChunkType.MCP_TOOL_IN_PROGRESS,
      responses: [toolResponse]
    })
  }

  /**
   * Handle a raw AI SDK `tool-result` chunk.
   *
   * Looks up the matching active call, emits MCP_TOOL_COMPLETE with the
   * result serialized as text content, and removes the call from the map.
   * Results without a matching active call are logged and dropped.
   */
  public handleToolResult(chunk: any): void {
    const toolCallId = chunk.toolCallId
    const result = chunk.result

    if (!toolCallId) {
      Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
      return
    }

    const toolCallInfo = this.activeToolCalls.get(toolCallId)
    if (!toolCallInfo) {
      Logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
      return
    }

    const toolResponse: MCPToolResponse = {
      id: toolCallId,
      tool: {
        id: toolCallId,
        name: toolCallInfo.toolName
      },
      arguments: toolCallInfo.args,
      status: 'done',
      response: {
        content: [
          {
            // Non-string results are JSON-stringified into a single text part.
            type: 'text',
            text: typeof result === 'string' ? result : JSON.stringify(result)
          }
        ],
        // TODO(review): always reported as success — propagate tool errors
        // here if the AI SDK result can carry an error flag.
        isError: false
      },
      toolCallId: toolCallId
    }

    // The call is finished; drop it from the active map.
    this.activeToolCalls.delete(toolCallId)

    this.onChunk({
      type: ChunkType.MCP_TOOL_COMPLETE,
      responses: [toolResponse]
    })
  }
}

View File

@ -155,7 +155,7 @@ export default class ModernAiProvider {
...middlewareConfig,
provider: this.provider,
// 工具相关信息从 params 中获取
enableTool: params.tools !== undefined && Array.isArray(params.tools) && params.tools.length > 0
enableTool: !!Object.keys(params.tools || {}).length
}
// 动态构建中间件数组

View File

@ -1,110 +0,0 @@
/**
* Cherry Studio
* Cherry Studio Assistant
*/
import { definePlugin } from '@cherrystudio/ai-core'
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
import {
buildStreamTextParams,
convertMessagesToSdkMessages,
getCustomParameters,
getTemperature,
getTopP
} from '../transformParameters'
/**
 * Cherry Studio parameter-transform plugin.
 *
 * Converts Cherry Studio's Assistant + Messages payload (attached under
 * `context.metadata.cherryStudio`) into AI SDK parameters. Calls without
 * that metadata pass through untouched, as do calls whose conversion fails.
 */
export const cherryStudioTransformPlugin = definePlugin({
  name: 'cherry-studio-transform',

  /**
   * Transform Cherry Studio Assistant + Messages into AI SDK params.
   */
  transformParams: async (params: any, context) => {
    const cherryData = context.metadata?.cherryStudio

    // Not a Cherry Studio call — pass the params through unchanged.
    if (!cherryData) {
      return params
    }

    const { assistant, messages, mcpTools, enableTools } = cherryData

    try {
      // Step 1: convert Cherry Studio messages to AI SDK core messages.
      const coreMessages = await convertMessagesToSdkMessages(messages as Message[], assistant.model as Model)

      // Step 2: build the full AI SDK parameter set from the assistant config.
      const built = await buildStreamTextParams(coreMessages, assistant as Assistant, {
        mcpTools: mcpTools as MCPTool[],
        enableTools,
        requestOptions: {
          signal: params.abortSignal,
          headers: params.headers
        }
      })

      // Step 3: merge, keeping the caller's abort signal and headers authoritative.
      return {
        ...params,
        ...built.params,
        abortSignal: params.abortSignal,
        headers: params.headers
      }
    } catch (error) {
      console.error('Cherry Studio 参数转换失败:', error)
      // Best effort: fall back to the untransformed parameters.
      return params
    }
  }
})
/**
 * Cherry Studio settings plugin.
 *
 * Applies assistant-level sampling settings (temperature, topP) and any
 * assistant custom parameters onto the request params. Calls without a
 * Cherry Studio assistant in the metadata pass through unchanged.
 */
export const cherryStudioSettingsPlugin = definePlugin({
  name: 'cherry-studio-settings',

  transformParams: async (params: any, context) => {
    const assistant = context.metadata?.cherryStudio?.assistant
    if (!assistant) {
      return params
    }

    const model = assistant.model as Model

    // Custom parameters are spread last so they can override the defaults.
    return {
      ...params,
      temperature: getTemperature(assistant as Assistant, model),
      topP: getTopP(assistant as Assistant, model),
      ...getCustomParameters(assistant as Assistant)
    }
  }
})
/**
 * Build the metadata context object that marks an AI call as coming from
 * Cherry Studio, for consumption by the Cherry Studio transform plugins.
 *
 * @param assistant - the Cherry Studio assistant issuing the call
 * @param messages - conversation messages to convert downstream
 * @param options - optional MCP tools and the enable-tools flag
 * @returns an object with a `cherryStudio` key carrying all call data
 */
export function createCherryStudioContext(
  assistant: Assistant,
  messages: Message[],
  options: {
    mcpTools?: MCPTool[]
    enableTools?: boolean
  } = {}
) {
  const { mcpTools, enableTools } = options
  return {
    cherryStudio: {
      assistant,
      messages,
      mcpTools,
      enableTools
    }
  }
}

View File

@ -2,15 +2,16 @@ import { AiCore, ProviderId } from '@cherrystudio/ai-core'
import { Provider } from '@renderer/types'
const PROVIDER_MAPPING: Record<string, ProviderId> = {
anthropic: 'anthropic',
// anthropic: 'anthropic',
gemini: 'google',
vertexai: 'google-vertex',
'azure-openai': 'azure',
'openai-response': 'openai'
'openai-response': 'openai',
grok: 'xai'
}
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
const providerId = PROVIDER_MAPPING[provider.type]
const providerId = PROVIDER_MAPPING[provider.id]
if (providerId) {
return providerId

View File

@ -4,6 +4,8 @@
*/
import { type CoreMessage, type StreamTextParams } from '@cherrystudio/ai-core'
import { aiSdk } from '@cherrystudio/ai-core'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
isGenerateImageModel,
isNotSupportTemperatureAndTopP,
@ -16,14 +18,18 @@ import {
isWebSearchModel
} from '@renderer/config/models'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
import type { Assistant, MCPTool, MCPToolResponse, Message, Model } from '@renderer/types'
import { FileTypes } from '@renderer/types'
import { callMCPTool } from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { defaultTimeout } from '@shared/config/constant'
// import { jsonSchemaToZod } from 'json-schema-to-zod'
import { jsonSchema } from 'ai'
import { buildProviderOptions } from './utils/reasoning'
const { tool } = aiSdk
/**
*
*/
@ -59,22 +65,6 @@ export async function buildSystemPromptWithTools(
return await buildSystemPrompt(prompt, mcpTools, assistant)
}
// /**
// * 转换 MCP 工具为 AI SDK 工具格式
// * 注意:这里返回通用格式,实际使用时需要根据具体 provider 转换
// TODO: 需要使用ai-sdk的mcp
// */
// export function convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Pick<StreamTextParams, 'tools'> {
// return mcpTools.map((tool) => ({
// type: 'function',
// function: {
// name: tool.id,
// description: tool.description,
// parameters: tool.inputSchema || {}
// }
// }))
// }
/**
*
*/
@ -200,6 +190,7 @@ export async function buildStreamTextParams(
assistant: Assistant,
options: {
mcpTools?: MCPTool[]
// FIXME: 上游没传
enableTools?: boolean
requestOptions?: {
signal?: AbortSignal
@ -208,7 +199,7 @@ export async function buildStreamTextParams(
}
} = {}
): Promise<{ params: StreamTextParams; modelId: string }> {
const { mcpTools, enableTools = false } = options
const { mcpTools } = options
const model = assistant.model || getDefaultModel()
@ -230,10 +221,11 @@ export async function buildStreamTextParams(
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
// 构建系统提示
let systemPrompt = assistant.prompt || ''
if (mcpTools && mcpTools.length > 0) {
systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant)
}
const systemPrompt = assistant.prompt || ''
// TODO:根据调用类型判断是否添加systemPrompt
// if (mcpTools && mcpTools.length > 0) {
// systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant)
// }
// 构建真正的 providerOptions
const providerOptions = buildProviderOptions(assistant, model, {
@ -245,19 +237,23 @@ export async function buildStreamTextParams(
// 构建基础参数
const params: StreamTextParams = {
messages: sdkMessages,
maxTokens: maxTokens || 1000,
maxTokens: maxTokens || DEFAULT_MAX_TOKENS,
temperature: getTemperature(assistant, model),
topP: getTopP(assistant, model),
system: systemPrompt || undefined,
abortSignal: options.requestOptions?.signal,
headers: options.requestOptions?.headers,
providerOptions
providerOptions,
maxSteps: 10
}
const tools = mcpTools ? convertMcpToolsToAiSdkTools(mcpTools) : {}
console.log('tools', tools)
console.log('enableTools', assistant?.mcpServers?.length)
// console.log('tools.length > 0', tools.length > 0)
// 添加工具(如果启用且有工具)
if (enableTools && mcpTools && mcpTools.length > 0) {
// TODO: 暂时注释掉工具支持,等类型问题解决后再启用
// params.tools = convertMcpToolsToSdkTools(mcpTools)
if (!!assistant?.mcpServers?.length && Object.keys(tools).length > 0) {
params.tools = tools
}
return { params, modelId: model.id }
@ -277,3 +273,152 @@ export async function buildGenerateTextParams(
// 复用流式参数的构建逻辑
return await buildStreamTextParams(messages, assistant, options)
}
/**
* JSON Schema Zod Schema
* string, number, boolean, array, object
*/
// function jsonSchemaToZod(schema: MCPToolInputSchema): z.ZodObject<any> {
// const properties: Record<string, z.ZodTypeAny> = {}
// const required = schema.required || []
// // 处理每个属性
// for (const [key, propSchema] of Object.entries(schema.properties)) {
// let zodSchema: z.ZodTypeAny
// // 根据 JSON Schema 类型创建对应的 Zod Schema
// const schemaType = (propSchema as any).type
// switch (schemaType) {
// case 'string': {
// let stringSchema = z.string()
// if ((propSchema as any).description) {
// stringSchema = stringSchema.describe((propSchema as any).description)
// }
// if ((propSchema as any).enum) {
// zodSchema = z.enum((propSchema as any).enum)
// } else {
// zodSchema = stringSchema
// }
// break
// }
// case 'number':
// case 'integer': {
// let numberSchema = z.number()
// if (schemaType === 'integer') {
// numberSchema = numberSchema.int()
// }
// if ((propSchema as any).minimum !== undefined) {
// numberSchema = numberSchema.min((propSchema as any).minimum)
// }
// if ((propSchema as any).maximum !== undefined) {
// numberSchema = numberSchema.max((propSchema as any).maximum)
// }
// if ((propSchema as any).description) {
// numberSchema = numberSchema.describe((propSchema as any).description)
// }
// zodSchema = numberSchema
// break
// }
// case 'boolean': {
// let booleanSchema = z.boolean()
// if ((propSchema as any).description) {
// booleanSchema = booleanSchema.describe((propSchema as any).description)
// }
// zodSchema = booleanSchema
// break
// }
// case 'array': {
// let itemSchema: z.ZodTypeAny = z.any()
// const itemsType = (propSchema as any).items?.type
// if (itemsType === 'string') {
// itemSchema = z.string()
// } else if (itemsType === 'number') {
// itemSchema = z.number()
// }
// let arraySchema = z.array(itemSchema)
// if ((propSchema as any).description) {
// arraySchema = arraySchema.describe((propSchema as any).description)
// }
// zodSchema = arraySchema
// break
// }
// case 'object': {
// // 对于嵌套对象,简单处理为 z.record
// let objectSchema = z.record(z.any())
// if ((propSchema as any).description) {
// objectSchema = objectSchema.describe((propSchema as any).description)
// }
// zodSchema = objectSchema
// break
// }
// default: {
// // 默认为 any
// zodSchema = z.any()
// break
// }
// }
// // 如果不是必需字段,添加 optional()
// if (!required.includes(key)) {
// zodSchema = zodSchema.optional()
// }
// properties[key] = zodSchema
// }
// return z.object(properties)
// }
/**
 * Convert MCP tools to AI SDK tool definitions.
 *
 * Each MCP tool becomes an AI SDK `tool` keyed by its name: the MCP JSON
 * input schema is passed through unchanged via `jsonSchema`, and `execute`
 * delegates to the existing `callMCPTool` bridge.
 *
 * @param mcpTools - MCP tools to expose to the model
 * @returns a tools record suitable for `streamText({ tools })`
 */
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, any> {
  const tools: Record<string, any> = {}

  for (const mcpTool of mcpTools) {
    tools[mcpTool.name] = tool({
      description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
      // Reuse the MCP JSON schema directly instead of converting to Zod.
      parameters: jsonSchema<Record<string, object>>(mcpTool.inputSchema),
      execute: async (params) => {
        // Wrap the call in the MCPToolResponse envelope callMCPTool expects.
        const toolResponse: MCPToolResponse = {
          id: `tool_${Date.now()}_${Math.random().toString(36).slice(2, 11)}`,
          tool: mcpTool,
          arguments: params,
          status: 'invoking',
          toolCallId: `call_${Date.now()}_${Math.random().toString(36).slice(2, 11)}`
        }

        try {
          const result = await callMCPTool(toolResponse)

          // Surface tool-reported failures as exceptions so the AI SDK can
          // feed the error back to the model.
          if (result.isError) {
            throw new Error(result.content?.[0]?.text || 'Tool execution failed')
          }

          // The AI SDK handles serialization of the returned value.
          return {
            success: true,
            data: result
          }
        } catch (error) {
          console.error(`MCP Tool execution failed: ${mcpTool.name}`, error)
          throw new Error(
            `Tool ${mcpTool.name} execution failed: ${error instanceof Error ? error.message : String(error)}`
          )
        }
      }
    })
  }

  return tools
}