mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
feat: enhance AI SDK chunk handling and tool call processing
- Introduced ToolCallChunkHandler for managing tool call events and results, improving the handling of tool interactions. - Updated AiSdkToChunkAdapter to utilize the new handler, streamlining the processing of tool call chunks. - Refactored transformParameters to support dynamic tool integration and improved parameter handling. - Adjusted provider mapping in factory.ts to include new provider types, enhancing compatibility with various AI services. - Removed obsolete cherryStudioTransformPlugin to clean up the codebase and focus on more relevant functionality.
This commit is contained in:
parent
ebe85ba24a
commit
f6c3794ac9
@ -2,7 +2,7 @@
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "node",
|
||||
"moduleResolution": "bundler",
|
||||
"declaration": true,
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src",
|
||||
|
||||
@ -6,6 +6,8 @@
|
||||
import { TextStreamPart } from '@cherrystudio/ai-core'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import { ToolCallChunkHandler } from './chunk/handleTooCallChunk'
|
||||
|
||||
export interface CherryStudioChunk {
|
||||
type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
|
||||
text?: string
|
||||
@ -21,7 +23,10 @@ export interface CherryStudioChunk {
|
||||
* 处理 fullStream 到 Cherry Studio chunk 的转换
|
||||
*/
|
||||
export class AiSdkToChunkAdapter {
|
||||
constructor(private onChunk: (chunk: Chunk) => void) {}
|
||||
toolCallHandler: ToolCallChunkHandler
|
||||
constructor(private onChunk: (chunk: Chunk) => void) {
|
||||
this.toolCallHandler = new ToolCallChunkHandler(onChunk)
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 AI SDK 流结果
|
||||
@ -100,7 +105,7 @@ export class AiSdkToChunkAdapter {
|
||||
})
|
||||
break
|
||||
|
||||
// === 工具调用相关事件 ===
|
||||
// === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) ===
|
||||
case 'tool-call-streaming-start':
|
||||
// 开始流式工具调用
|
||||
this.onChunk({
|
||||
@ -145,55 +150,31 @@ export class AiSdkToChunkAdapter {
|
||||
break
|
||||
|
||||
case 'tool-call':
|
||||
// 完整的工具调用
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: [
|
||||
{
|
||||
id: chunk.toolCallId,
|
||||
name: chunk.toolName,
|
||||
args: chunk.args
|
||||
}
|
||||
]
|
||||
})
|
||||
// 原始的工具调用(未被中间件处理)
|
||||
this.toolCallHandler.handleToolCall(chunk)
|
||||
break
|
||||
|
||||
case 'tool-result':
|
||||
// 工具调用结果
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [
|
||||
{
|
||||
id: chunk.toolCallId,
|
||||
tool: {
|
||||
id: chunk.toolName,
|
||||
// TODO: serverId,serverName
|
||||
serverId: 'ai-sdk',
|
||||
serverName: 'AI SDK',
|
||||
name: chunk.toolName,
|
||||
description: '',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
title: chunk.toolName,
|
||||
properties: {}
|
||||
}
|
||||
},
|
||||
arguments: chunk.args || {},
|
||||
status: 'done',
|
||||
response: chunk.result,
|
||||
toolCallId: chunk.toolCallId
|
||||
}
|
||||
]
|
||||
})
|
||||
// 原始的工具调用结果(未被中间件处理)
|
||||
this.toolCallHandler.handleToolResult(chunk)
|
||||
break
|
||||
|
||||
// === 步骤相关事件 ===
|
||||
// case 'step-start':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
// TODO: 需要区分接口开始和步骤开始
|
||||
// case 'step-start':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
case 'step-finish':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: final.text || '' // TEXT_COMPLETE 需要 text 字段
|
||||
})
|
||||
final.text = ''
|
||||
break
|
||||
|
||||
case 'finish':
|
||||
this.onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
@ -212,13 +193,6 @@ export class AiSdkToChunkAdapter {
|
||||
: undefined
|
||||
}
|
||||
})
|
||||
break
|
||||
|
||||
case 'finish':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: final.text || '' // TEXT_COMPLETE 需要 text 字段
|
||||
})
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
|
||||
130
src/renderer/src/aiCore/chunk/handleTooCallChunk.ts
Normal file
130
src/renderer/src/aiCore/chunk/handleTooCallChunk.ts
Normal file
@ -0,0 +1,130 @@
|
||||
/**
|
||||
* 工具调用 Chunk 处理模块
|
||||
*
|
||||
* 提供工具调用相关的处理API,每个交互使用一个新的实例
|
||||
*/
|
||||
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { MCPToolResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
/**
|
||||
* 工具调用处理器类
|
||||
*/
|
||||
export class ToolCallChunkHandler {
|
||||
// private onChunk: (chunk: Chunk) => void
|
||||
private activeToolCalls = new Map<
|
||||
string,
|
||||
{
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
args: any
|
||||
// mcpTool: MCPTool
|
||||
}
|
||||
>()
|
||||
constructor(private onChunk: (chunk: Chunk) => void) {}
|
||||
|
||||
// /**
|
||||
// * 设置 onChunk 回调
|
||||
// */
|
||||
// public setOnChunk(callback: (chunk: Chunk) => void): void {
|
||||
// this.onChunk = callback
|
||||
// }
|
||||
|
||||
/**
|
||||
* 处理工具调用事件
|
||||
*/
|
||||
public handleToolCall(chunk: any): void {
|
||||
const toolCallId = chunk.toolCallId
|
||||
const toolName = chunk.toolName
|
||||
const args = chunk.args || {}
|
||||
|
||||
if (!toolCallId || !toolName) {
|
||||
Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
|
||||
return
|
||||
}
|
||||
|
||||
// 从 chunk 信息构造 MCPTool
|
||||
// const mcpTool = this.createMcpToolFromChunk(chunk)
|
||||
|
||||
// 记录活跃的工具调用
|
||||
this.activeToolCalls.set(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args
|
||||
// mcpTool
|
||||
})
|
||||
|
||||
// 创建 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: {
|
||||
id: toolCallId,
|
||||
name: toolName
|
||||
},
|
||||
arguments: args,
|
||||
status: 'invoking',
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
|
||||
// 调用 onChunk
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具调用结果事件
|
||||
*/
|
||||
public handleToolResult(chunk: any): void {
|
||||
const toolCallId = chunk.toolCallId
|
||||
const result = chunk.result
|
||||
|
||||
if (!toolCallId) {
|
||||
Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
|
||||
return
|
||||
}
|
||||
|
||||
// 查找对应的工具调用信息
|
||||
const toolCallInfo = this.activeToolCalls.get(toolCallId)
|
||||
if (!toolCallInfo) {
|
||||
Logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建工具调用结果的 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: {
|
||||
id: toolCallId,
|
||||
name: toolCallInfo.toolName
|
||||
},
|
||||
arguments: toolCallInfo.args,
|
||||
status: 'done',
|
||||
response: {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: typeof result === 'string' ? result : JSON.stringify(result)
|
||||
}
|
||||
],
|
||||
isError: false
|
||||
},
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
|
||||
// 从活跃调用中移除(交互结束后整个实例会被丢弃)
|
||||
this.activeToolCalls.delete(toolCallId)
|
||||
|
||||
// 调用 onChunk
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -155,7 +155,7 @@ export default class ModernAiProvider {
|
||||
...middlewareConfig,
|
||||
provider: this.provider,
|
||||
// 工具相关信息从 params 中获取
|
||||
enableTool: params.tools !== undefined && Array.isArray(params.tools) && params.tools.length > 0
|
||||
enableTool: !!Object.keys(params.tools || {}).length
|
||||
}
|
||||
|
||||
// 动态构建中间件数组
|
||||
|
||||
@ -1,110 +0,0 @@
|
||||
/**
|
||||
* Cherry Studio 参数转换插件
|
||||
* 专门处理 Cherry Studio 特有的消息格式、文件处理、Assistant 设置等
|
||||
*/
|
||||
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
|
||||
|
||||
import {
|
||||
buildStreamTextParams,
|
||||
convertMessagesToSdkMessages,
|
||||
getCustomParameters,
|
||||
getTemperature,
|
||||
getTopP
|
||||
} from '../transformParameters'
|
||||
|
||||
/**
|
||||
* Cherry Studio 核心转换插件
|
||||
* 负责将 Cherry Studio 的数据结构转换为 AI SDK 兼容格式
|
||||
*/
|
||||
export const cherryStudioTransformPlugin = definePlugin({
|
||||
name: 'cherry-studio-transform',
|
||||
|
||||
/**
|
||||
* 转换请求参数
|
||||
* 将 Cherry Studio 的 Assistant + Messages 转换为 AI SDK 格式
|
||||
*/
|
||||
transformParams: async (params: any, context) => {
|
||||
// 检查是否有 Cherry Studio 特有的数据结构
|
||||
const cherryData = context.metadata?.cherryStudio
|
||||
if (!cherryData) {
|
||||
return params // 不是 Cherry Studio 调用,直接返回
|
||||
}
|
||||
|
||||
const { assistant, messages, mcpTools, enableTools } = cherryData
|
||||
|
||||
try {
|
||||
// 1. 转换 Cherry Studio 消息为 AI SDK 消息
|
||||
const sdkMessages = await convertMessagesToSdkMessages(messages as Message[], assistant.model as Model)
|
||||
|
||||
// 2. 构建完整的 AI SDK 参数
|
||||
const { params: transformedParams } = await buildStreamTextParams(sdkMessages, assistant as Assistant, {
|
||||
mcpTools: mcpTools as MCPTool[],
|
||||
enableTools,
|
||||
requestOptions: {
|
||||
signal: params.abortSignal,
|
||||
headers: params.headers
|
||||
}
|
||||
})
|
||||
|
||||
// 3. 合并原始参数和转换后的参数
|
||||
return {
|
||||
...params,
|
||||
...transformedParams,
|
||||
// 保留原始的一些关键参数
|
||||
abortSignal: params.abortSignal,
|
||||
headers: params.headers
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Cherry Studio 参数转换失败:', error)
|
||||
return params // 转换失败时返回原始参数
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
/**
|
||||
* Cherry Studio Assistant 设置插件
|
||||
* 专门处理 Assistant 的温度、TopP、自定义参数等设置
|
||||
*/
|
||||
export const cherryStudioSettingsPlugin = definePlugin({
|
||||
name: 'cherry-studio-settings',
|
||||
|
||||
transformParams: async (params: any, context) => {
|
||||
const cherryData = context.metadata?.cherryStudio
|
||||
if (!cherryData?.assistant) {
|
||||
return params
|
||||
}
|
||||
|
||||
const { assistant } = cherryData
|
||||
const model = assistant.model as Model
|
||||
|
||||
return {
|
||||
...params,
|
||||
temperature: getTemperature(assistant as Assistant, model),
|
||||
topP: getTopP(assistant as Assistant, model),
|
||||
...getCustomParameters(assistant as Assistant)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
/**
|
||||
* 便捷函数:为 Cherry Studio 调用准备上下文元数据
|
||||
*/
|
||||
export function createCherryStudioContext(
|
||||
assistant: Assistant,
|
||||
messages: Message[],
|
||||
options: {
|
||||
mcpTools?: MCPTool[]
|
||||
enableTools?: boolean
|
||||
} = {}
|
||||
) {
|
||||
return {
|
||||
cherryStudio: {
|
||||
assistant,
|
||||
messages,
|
||||
mcpTools: options.mcpTools,
|
||||
enableTools: options.enableTools
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2,15 +2,16 @@ import { AiCore, ProviderId } from '@cherrystudio/ai-core'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
const PROVIDER_MAPPING: Record<string, ProviderId> = {
|
||||
anthropic: 'anthropic',
|
||||
// anthropic: 'anthropic',
|
||||
gemini: 'google',
|
||||
vertexai: 'google-vertex',
|
||||
'azure-openai': 'azure',
|
||||
'openai-response': 'openai'
|
||||
'openai-response': 'openai',
|
||||
grok: 'xai'
|
||||
}
|
||||
|
||||
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
|
||||
const providerId = PROVIDER_MAPPING[provider.type]
|
||||
const providerId = PROVIDER_MAPPING[provider.id]
|
||||
|
||||
if (providerId) {
|
||||
return providerId
|
||||
|
||||
@ -4,6 +4,8 @@
|
||||
*/
|
||||
|
||||
import { type CoreMessage, type StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { aiSdk } from '@cherrystudio/ai-core'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
isGenerateImageModel,
|
||||
isNotSupportTemperatureAndTopP,
|
||||
@ -16,14 +18,18 @@ import {
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
|
||||
import type { Assistant, MCPTool, MCPToolResponse, Message, Model } from '@renderer/types'
|
||||
import { FileTypes } from '@renderer/types'
|
||||
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
||||
import { jsonSchema } from 'ai'
|
||||
|
||||
import { buildProviderOptions } from './utils/reasoning'
|
||||
|
||||
const { tool } = aiSdk
|
||||
/**
|
||||
* 获取温度参数
|
||||
*/
|
||||
@ -59,22 +65,6 @@ export async function buildSystemPromptWithTools(
|
||||
return await buildSystemPrompt(prompt, mcpTools, assistant)
|
||||
}
|
||||
|
||||
// /**
|
||||
// * 转换 MCP 工具为 AI SDK 工具格式
|
||||
// * 注意:这里返回通用格式,实际使用时需要根据具体 provider 转换
|
||||
// TODO: 需要使用ai-sdk的mcp
|
||||
// */
|
||||
// export function convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Pick<StreamTextParams, 'tools'> {
|
||||
// return mcpTools.map((tool) => ({
|
||||
// type: 'function',
|
||||
// function: {
|
||||
// name: tool.id,
|
||||
// description: tool.description,
|
||||
// parameters: tool.inputSchema || {}
|
||||
// }
|
||||
// }))
|
||||
// }
|
||||
|
||||
/**
|
||||
* 提取文件内容
|
||||
*/
|
||||
@ -200,6 +190,7 @@ export async function buildStreamTextParams(
|
||||
assistant: Assistant,
|
||||
options: {
|
||||
mcpTools?: MCPTool[]
|
||||
// FIXME: 上游没传
|
||||
enableTools?: boolean
|
||||
requestOptions?: {
|
||||
signal?: AbortSignal
|
||||
@ -208,7 +199,7 @@ export async function buildStreamTextParams(
|
||||
}
|
||||
} = {}
|
||||
): Promise<{ params: StreamTextParams; modelId: string }> {
|
||||
const { mcpTools, enableTools = false } = options
|
||||
const { mcpTools } = options
|
||||
|
||||
const model = assistant.model || getDefaultModel()
|
||||
|
||||
@ -230,10 +221,11 @@ export async function buildStreamTextParams(
|
||||
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
|
||||
|
||||
// 构建系统提示
|
||||
let systemPrompt = assistant.prompt || ''
|
||||
if (mcpTools && mcpTools.length > 0) {
|
||||
systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant)
|
||||
}
|
||||
const systemPrompt = assistant.prompt || ''
|
||||
// TODO:根据调用类型判断是否添加systemPrompt
|
||||
// if (mcpTools && mcpTools.length > 0) {
|
||||
// systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant)
|
||||
// }
|
||||
|
||||
// 构建真正的 providerOptions
|
||||
const providerOptions = buildProviderOptions(assistant, model, {
|
||||
@ -245,19 +237,23 @@ export async function buildStreamTextParams(
|
||||
// 构建基础参数
|
||||
const params: StreamTextParams = {
|
||||
messages: sdkMessages,
|
||||
maxTokens: maxTokens || 1000,
|
||||
maxTokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: getTemperature(assistant, model),
|
||||
topP: getTopP(assistant, model),
|
||||
system: systemPrompt || undefined,
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers: options.requestOptions?.headers,
|
||||
providerOptions
|
||||
providerOptions,
|
||||
maxSteps: 10
|
||||
}
|
||||
|
||||
const tools = mcpTools ? convertMcpToolsToAiSdkTools(mcpTools) : {}
|
||||
console.log('tools', tools)
|
||||
console.log('enableTools', assistant?.mcpServers?.length)
|
||||
// console.log('tools.length > 0', tools.length > 0)
|
||||
// 添加工具(如果启用且有工具)
|
||||
if (enableTools && mcpTools && mcpTools.length > 0) {
|
||||
// TODO: 暂时注释掉工具支持,等类型问题解决后再启用
|
||||
// params.tools = convertMcpToolsToSdkTools(mcpTools)
|
||||
if (!!assistant?.mcpServers?.length && Object.keys(tools).length > 0) {
|
||||
params.tools = tools
|
||||
}
|
||||
|
||||
return { params, modelId: model.id }
|
||||
@ -277,3 +273,152 @@ export async function buildGenerateTextParams(
|
||||
// 复用流式参数的构建逻辑
|
||||
return await buildStreamTextParams(messages, assistant, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 简单的 JSON Schema 到 Zod Schema 转换
|
||||
* 支持基本类型:string, number, boolean, array, object
|
||||
*/
|
||||
// function jsonSchemaToZod(schema: MCPToolInputSchema): z.ZodObject<any> {
|
||||
// const properties: Record<string, z.ZodTypeAny> = {}
|
||||
// const required = schema.required || []
|
||||
|
||||
// // 处理每个属性
|
||||
// for (const [key, propSchema] of Object.entries(schema.properties)) {
|
||||
// let zodSchema: z.ZodTypeAny
|
||||
|
||||
// // 根据 JSON Schema 类型创建对应的 Zod Schema
|
||||
// const schemaType = (propSchema as any).type
|
||||
// switch (schemaType) {
|
||||
// case 'string': {
|
||||
// let stringSchema = z.string()
|
||||
// if ((propSchema as any).description) {
|
||||
// stringSchema = stringSchema.describe((propSchema as any).description)
|
||||
// }
|
||||
// if ((propSchema as any).enum) {
|
||||
// zodSchema = z.enum((propSchema as any).enum)
|
||||
// } else {
|
||||
// zodSchema = stringSchema
|
||||
// }
|
||||
// break
|
||||
// }
|
||||
|
||||
// case 'number':
|
||||
// case 'integer': {
|
||||
// let numberSchema = z.number()
|
||||
// if (schemaType === 'integer') {
|
||||
// numberSchema = numberSchema.int()
|
||||
// }
|
||||
// if ((propSchema as any).minimum !== undefined) {
|
||||
// numberSchema = numberSchema.min((propSchema as any).minimum)
|
||||
// }
|
||||
// if ((propSchema as any).maximum !== undefined) {
|
||||
// numberSchema = numberSchema.max((propSchema as any).maximum)
|
||||
// }
|
||||
// if ((propSchema as any).description) {
|
||||
// numberSchema = numberSchema.describe((propSchema as any).description)
|
||||
// }
|
||||
// zodSchema = numberSchema
|
||||
// break
|
||||
// }
|
||||
|
||||
// case 'boolean': {
|
||||
// let booleanSchema = z.boolean()
|
||||
// if ((propSchema as any).description) {
|
||||
// booleanSchema = booleanSchema.describe((propSchema as any).description)
|
||||
// }
|
||||
// zodSchema = booleanSchema
|
||||
// break
|
||||
// }
|
||||
|
||||
// case 'array': {
|
||||
// let itemSchema: z.ZodTypeAny = z.any()
|
||||
// const itemsType = (propSchema as any).items?.type
|
||||
// if (itemsType === 'string') {
|
||||
// itemSchema = z.string()
|
||||
// } else if (itemsType === 'number') {
|
||||
// itemSchema = z.number()
|
||||
// }
|
||||
// let arraySchema = z.array(itemSchema)
|
||||
// if ((propSchema as any).description) {
|
||||
// arraySchema = arraySchema.describe((propSchema as any).description)
|
||||
// }
|
||||
// zodSchema = arraySchema
|
||||
// break
|
||||
// }
|
||||
|
||||
// case 'object': {
|
||||
// // 对于嵌套对象,简单处理为 z.record
|
||||
// let objectSchema = z.record(z.any())
|
||||
// if ((propSchema as any).description) {
|
||||
// objectSchema = objectSchema.describe((propSchema as any).description)
|
||||
// }
|
||||
// zodSchema = objectSchema
|
||||
// break
|
||||
// }
|
||||
|
||||
// default: {
|
||||
// // 默认为 any
|
||||
// zodSchema = z.any()
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
|
||||
// // 如果不是必需字段,添加 optional()
|
||||
// if (!required.includes(key)) {
|
||||
// zodSchema = zodSchema.optional()
|
||||
// }
|
||||
|
||||
// properties[key] = zodSchema
|
||||
// }
|
||||
|
||||
// return z.object(properties)
|
||||
// }
|
||||
|
||||
/**
|
||||
* 将 MCPTool 转换为 AI SDK 工具格式
|
||||
*/
|
||||
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, any> {
|
||||
const tools: Record<string, any> = {}
|
||||
|
||||
for (const mcpTool of mcpTools) {
|
||||
console.log('mcpTool', mcpTool.inputSchema)
|
||||
tools[mcpTool.name] = tool({
|
||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||
parameters: jsonSchema<Record<string, object>>(mcpTool.inputSchema),
|
||||
execute: async (params) => {
|
||||
console.log('execute_params', params)
|
||||
// 创建适配的 MCPToolResponse 对象
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: `tool_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
||||
tool: mcpTool,
|
||||
arguments: params,
|
||||
status: 'invoking',
|
||||
toolCallId: `call_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||
}
|
||||
|
||||
try {
|
||||
// 复用现有的 callMCPTool 函数
|
||||
const result = await callMCPTool(toolResponse)
|
||||
|
||||
// 返回结果,AI SDK 会处理序列化
|
||||
if (result.isError) {
|
||||
throw new Error(result.content?.[0]?.text || 'Tool execution failed')
|
||||
}
|
||||
console.log('result', result)
|
||||
// 返回工具执行结果
|
||||
return {
|
||||
success: true,
|
||||
data: result
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`MCP Tool execution failed: ${mcpTool.name}`, error)
|
||||
throw new Error(
|
||||
`Tool ${mcpTool.name} execution failed: ${error instanceof Error ? error.message : String(error)}`
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user