diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json
index 5f3b09a030..2bcbce362a 100644
--- a/packages/aiCore/package.json
+++ b/packages/aiCore/package.json
@@ -46,7 +46,8 @@
     "@openrouter/ai-sdk-provider": "^0.7.2",
     "ai": "5.0.0-beta.7",
     "anthropic-vertex-ai": "^1.0.2",
-    "ollama-ai-provider": "^1.2.0"
+    "ollama-ai-provider": "^1.2.0",
+    "zod": "^3.25.0"
   },
   "peerDependenciesMeta": {
     "@ai-sdk/amazon-bedrock": {
diff --git a/packages/aiCore/src/core/index.ts b/packages/aiCore/src/core/index.ts
index a50a5a382d..748b2303de 100644
--- a/packages/aiCore/src/core/index.ts
+++ b/packages/aiCore/src/core/index.ts
@@ -19,5 +19,6 @@ export {
 } from './models'
 
 // Execution management
+export type { MCPRequestContext } from './plugins/built-in/mcpPromptPlugin'
 export type { ExecutionOptions, ExecutorConfig } from './runtime'
 export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
diff --git a/packages/aiCore/src/core/options/types.ts b/packages/aiCore/src/core/options/types.ts
index b1770bdd87..5e78323ce5 100644
--- a/packages/aiCore/src/core/options/types.ts
+++ b/packages/aiCore/src/core/options/types.ts
@@ -1,11 +1,11 @@
 import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
 import { type GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
 import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
-import { type LanguageModelV1ProviderMetadata } from '@ai-sdk/provider'
+import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'
 
 import { type OpenRouterProviderOptions } from './openrouter'
 
-export type ProviderOptions<T extends keyof LanguageModelV1ProviderMetadata> = LanguageModelV1ProviderMetadata[T]
+export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]
 
 /**
  * Provider options type; if a provider is missing from the map, it is unconstrained
 */
@@ -28,4 +28,4 @@ export type TypedProviderOptions = {
   [K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]
 } & {
   [K in string]?: Record<string, any>
-} & LanguageModelV1ProviderMetadata
+} & SharedV2ProviderMetadata
diff --git a/packages/aiCore/src/core/plugins/built-in/mcpPromptPlugin.ts b/packages/aiCore/src/core/plugins/built-in/mcpPromptPlugin.ts
index 7f3fb2a7f6..2b93c553d9 100644
--- a/packages/aiCore/src/core/plugins/built-in/mcpPromptPlugin.ts
+++ b/packages/aiCore/src/core/plugins/built-in/mcpPromptPlugin.ts
@@ -3,8 +3,7 @@
  * Provides prompt-based tool calling for models without native Function Call support
  * Ships built-in default logic that can be overridden
  */
-import type { ToolSet } from 'ai'
-import { ToolExecutionError } from 'ai'
+import type { ModelMessage, TextStreamPart, ToolErrorUnion, ToolSet } from 'ai'
 
 import { definePlugin } from '../index'
 import type { AiRequestContext } from '../types'
@@ -44,17 +43,17 @@ export interface MCPPromptConfig {
   // Whether enabled (runtime switch)
   enabled?: boolean
   // Custom system prompt builder (optional, default implementation provided)
-  buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => Promise<string>
+  buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
   // Custom tool-use parser (optional, default implementation provided)
   parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
-  createSystemMessage?: (systemPrompt: string, originalParams: any, context: MCPRequestContext) => string | null
+  createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
 }
 
 /**
  * Extended AI request context with MCP tool storage
 */
-interface MCPRequestContext extends AiRequestContext {
-  mcpTools?: ToolSet
+export interface MCPRequestContext extends AiRequestContext {
+  mcpTools: ToolSet
 }
@@ -201,7 +200,7 @@ function buildAvailableTools(tools: ToolSet): string {
 <tool>
   <name>${toolName}</name>
   <description>${tool.description || ''}</description>
-  <arguments>${tool.parameters ? JSON.stringify(tool.parameters) : ''}</arguments>
+  <arguments>${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}</arguments>
 </tool>
 `
@@ -215,7 +214,7 @@
 ${availableTools}
 
 /**
  * Default system prompt builder (extracted from Cherry Studio)
  */
-async function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): Promise<string> {
+function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string {
   const availableTools = buildAvailableTools(tools)
 
   const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
@@ -291,8 +290,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
 
   return definePlugin({
     name: 'built-in:mcp-prompt',
-
-    transformParams: async (params: any, context: MCPRequestContext) => {
+    transformParams: (params: any, context: AiRequestContext) => {
       if (!enabled || !params.tools || typeof params.tools !== 'object') {
         return params
       }
@@ -303,7 +301,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
 
       // Build the system prompt
      const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
-      const systemPrompt = await buildSystemPrompt(userSystemPrompt, params.tools)
+      const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
       let systemMessage: string | null = systemPrompt
       console.log('config.context', context)
       if (config.createSystemMessage) {
@@ -320,25 +318,30 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
       console.log('transformedParams', transformedParams)
       return transformedParams
     },
-
-    // Streaming: listen for step-finish events and handle tool calls
-    transformStream: (_, context: MCPRequestContext) => () => {
+    transformStream: (_: any, context: AiRequestContext) => () => {
       let textBuffer = ''
+      let stepId = ''
      let executedResults: { toolCallId: string; toolName: string; result: any; isError?: boolean }[] = []
-
-      return new TransformStream({
-        async transform(chunk, controller) {
+      if (!context.mcpTools) {
+        throw new Error('No tools available')
+      }
+      type TOOLS = NonNullable<typeof context.mcpTools>
+      return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
+        async transform(
+          chunk: TextStreamPart<TOOLS>,
+          controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
+        ) {
           // console.log('chunk', chunk)
           // Collect text content
-          if (chunk.type === 'text-delta') {
-            textBuffer += chunk.textDelta || ''
+          if (chunk.type === 'text') {
+            textBuffer += chunk.text || ''
+            stepId = chunk.id || ''
             // console.log('textBuffer', textBuffer)
             controller.enqueue(chunk)
             return
           }
 
-          // Listen for step-finish events
-          if (chunk.type === 'step-finish' || chunk.type === 'finish') {
+          if (chunk.type === 'finish-step') {
            // console.log('[MCP Prompt Stream] Received step-finish, checking for tool use...')
 
            // Get tool info from the context
@@ -364,17 +367,11 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
 
             // console.log('[MCP Prompt Stream] Found valid tool uses:', validToolUses.length)
 
-            // Rewrite the step-finish event to mark it as a tool call
-            if (chunk.type !== 'finish') {
-              controller.enqueue({
-                ...chunk,
-                finishReason: 'tool-call'
-              })
-            }
-
             // Emit a step-start event (a tool-call step begins)
             controller.enqueue({
-              type: 'step-start'
+              type: 'start-step',
+              request: {},
+              warnings: []
             })
 
             // Execute the tool calls
@@ -392,7 +389,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
                   type: 'tool-call',
                   toolCallId: toolUse.id,
                   toolName: toolUse.toolName,
-                  args: toolUse.arguments
+                  input: toolUse.arguments
                 })
 
                 const result = await tool.execute(toolUse.arguments, {
@@ -406,8 +403,8 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
                   type: 'tool-result',
                   toolCallId: toolUse.id,
                   toolName: toolUse.toolName,
-                  args: toolUse.arguments,
-                  result
+                  input: toolUse.arguments,
+                  output: result
                 })
 
                 executedResults.push({
@@ -420,39 +417,36 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
                 console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
 
                 // Use the AI SDK's standard error format
-                const toolError = new ToolExecutionError({
-                  toolName: toolUse.toolName,
-                  toolArgs: toolUse.arguments,
+                const toolError: ToolErrorUnion<TOOLS> = {
+                  type: 'tool-error',
                   toolCallId: toolUse.id,
-                  message: `Tool execution failed: ${error instanceof Error ? error.message : String(error)}`,
-                  cause: error instanceof Error ? error : undefined
-                })
+                  toolName: toolUse.toolName,
+                  input: toolUse.arguments,
+                  error: error instanceof Error ? error.message : String(error)
+                }
+
+                controller.enqueue(toolError)
 
                 // Emit a standard error event
                 controller.enqueue({
                   type: 'error',
-                  error: {
-                    message: toolError.message,
-                    name: toolError.name,
-                    toolName: toolError.toolName,
-                    toolCallId: toolError.toolCallId
-                  }
+                  error: toolError.error
                 })
 
-                // Emit a tool-result error event
-                controller.enqueue({
-                  type: 'tool-result',
-                  toolCallId: toolUse.id,
-                  toolName: toolUse.toolName,
-                  args: toolUse.arguments,
-                  isError: true,
-                  result: toolError.message
-                })
+                // // Emit a tool-result error event
+                // controller.enqueue({
+                //   type: 'tool-result',
+                //   toolCallId: toolUse.id,
+                //   toolName: toolUse.toolName,
+                //   args: toolUse.arguments,
+                //   isError: true,
+                //   result: toolError.message
+                // })
 
                 executedResults.push({
                   toolCallId: toolUse.id,
                   toolName: toolUse.toolName,
-                  result: toolError.message,
+                  result: toolError.error,
                   isError: true
                 })
               }
@@ -460,8 +454,11 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
 
             // Emit the final step-finish event
             controller.enqueue({
-              type: 'step-finish',
-              finishReason: 'tool-call'
+              type: 'finish-step',
+              finishReason: 'tool-calls',
+              response: chunk.response,
+              usage: chunk.usage,
+              providerMetadata: chunk.providerMetadata
             })
 
             // Recursive-call logic
@@ -481,7 +478,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
               .join('\n\n')
             // console.log('context.originalParams.messages', context.originalParams.messages)
             // Build the new conversation messages
-            const newMessages = [
+            const newMessages: ModelMessage[] = [
               ...(context.originalParams.messages || []),
               {
                 role: 'assistant',
@@ -540,8 +537,9 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
 
               // Keep emitting text deltas to preserve stream continuity
               controller.enqueue({
-                type: 'text-delta',
-                textDelta: '\n\n[Recursive call after tool execution failed, continuing the conversation...]'
+                type: 'text',
+                id: stepId,
+                text: '\n\n[Recursive call after tool execution failed, continuing the conversation...]'
               })
             }
           }
diff --git a/packages/aiCore/src/core/plugins/index.ts b/packages/aiCore/src/core/plugins/index.ts
index 06833fbc99..8b64b5b0d8 100644
--- a/packages/aiCore/src/core/plugins/index.ts
+++ b/packages/aiCore/src/core/plugins/index.ts
@@ -1,12 +1,13 @@
 // Core types and interfaces
 export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './types'
 
+import { ProviderId } from '../providers/registry'
 import type { AiPlugin, AiRequestContext } from './types'
 
 // Plugin manager
 export { PluginManager } from './manager'
 
 // Utility functions
-export function createContext(providerId: string, modelId: string, originalParams: any): AiRequestContext {
+export function createContext(providerId: ProviderId, modelId: string, originalParams: any): AiRequestContext {
   return {
     providerId,
     modelId,
diff --git a/packages/aiCore/src/core/plugins/types.ts b/packages/aiCore/src/core/plugins/types.ts
index ffa7a66714..4aa9001526 100644
--- a/packages/aiCore/src/core/plugins/types.ts
+++ b/packages/aiCore/src/core/plugins/types.ts
@@ -20,6 +20,7 @@ export interface AiRequestContext {
   requestId: string
   recursiveCall: RecursiveCallFn
   isRecursiveCall?: boolean
+  mcpTools?: ToolSet
   [key: string]: any
 }
 
@@ -47,7 +48,7 @@ export interface AiPlugin {
   transformStream?: (
     params: any,
     context: AiRequestContext
-  ) => (options: {
+  ) => (options?: {
     tools: TOOLS
     stopStream: () => void
   }) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
diff --git a/packages/aiCore/src/core/runtime/index.ts b/packages/aiCore/src/core/runtime/index.ts
index 8cb9285c3c..e6a03a4340 100644
--- a/packages/aiCore/src/core/runtime/index.ts
+++ b/packages/aiCore/src/core/runtime/index.ts
@@ -16,7 +16,7 @@ export type {
 
 // === Convenience factory functions ===
 
-import { LanguageModelV1Middleware } from 'ai'
+import { LanguageModelV2Middleware } from '@ai-sdk/provider'
 
 import { type ProviderId, type ProviderSettingsMap } from '../../types'
 import { type AiPlugin } from '../plugins'
@@ -54,7 +54,7 @@ export async function streamText(
   modelId: string,
   params: Parameters<ReturnType<typeof createExecutor>['streamText']>[1],
   plugins?: AiPlugin[],
-  middlewares?: LanguageModelV1Middleware[]
+  middlewares?: LanguageModelV2Middleware[]
 ): Promise<ReturnType<ReturnType<typeof createExecutor>['streamText']>> {
   const executor = createExecutor(providerId, options, plugins)
   return executor.streamText(modelId, params, { middlewares })
@@ -69,7 +69,7 @@ export async function generateText(
   modelId: string,
   params: Parameters<ReturnType<typeof createExecutor>['generateText']>[1],
   plugins?: AiPlugin[],
-  middlewares?: LanguageModelV1Middleware[]
+  middlewares?: LanguageModelV2Middleware[]
 ): Promise<ReturnType<ReturnType<typeof createExecutor>['generateText']>> {
   const executor = createExecutor(providerId, options, plugins)
   return executor.generateText(modelId, params, { middlewares })
@@ -84,7 +84,7 @@ export async function generateObject(
   modelId: string,
   params: Parameters<ReturnType<typeof createExecutor>['generateObject']>[1],
   plugins?: AiPlugin[],
-  middlewares?: LanguageModelV1Middleware[]
+  middlewares?: LanguageModelV2Middleware[]
 ): Promise<ReturnType<ReturnType<typeof createExecutor>['generateObject']>> {
   const executor = createExecutor(providerId, options, plugins)
   return executor.generateObject(modelId, params, { middlewares })
@@ -99,7 +99,7 @@ export async function streamObject(
   modelId: string,
   params: Parameters<ReturnType<typeof createExecutor>['streamObject']>[1],
   plugins?: AiPlugin[],
-  middlewares?: LanguageModelV1Middleware[]
+  middlewares?: LanguageModelV2Middleware[]
 ): Promise<ReturnType<ReturnType<typeof createExecutor>['streamObject']>> {
   const executor = createExecutor(providerId, options, plugins)
   return executor.streamObject(modelId, params, { middlewares })
diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts
index 9c4e621911..1cb02afb94 100644
--- a/packages/aiCore/src/index.ts
+++ b/packages/aiCore/src/index.ts
@@ -68,8 +68,10 @@ export type {
   TextStreamPart,
   // Tool-related types
   Tool,
+  ToolCallUnion,
   ToolModelMessage,
   ToolResultPart,
+  ToolSet,
   UserModelMessage
 } from 'ai'
 export { defaultSettingsMiddleware, extractReasoningMiddleware, simulateStreamingMiddleware, smoothStream } from 'ai'
diff --git a/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts
index 2d258ba1d3..494f69effe 100644
--- a/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts
+++ b/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts
@@ -3,7 +3,8 @@
  * Converts the AI SDK fullStream into Cherry Studio's chunk format
  */
-import { TextStreamPart } from '@cherrystudio/ai-core'
+import { TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
+import { MCPTool, WebSearchSource } from '@renderer/types'
 import { Chunk, ChunkType } from '@renderer/types/chunk'
 
 import { ToolCallChunkHandler } from './chunk/handleTooCallChunk'
@@ -24,8 +25,11 @@ export interface CherryStudioChunk {
  */
 export class AiSdkToChunkAdapter {
   toolCallHandler: ToolCallChunkHandler
-  constructor(private onChunk: (chunk: Chunk) => void) {
-    this.toolCallHandler = new ToolCallChunkHandler(onChunk)
+  constructor(
+    private onChunk: (chunk: Chunk) => void,
+    private mcpTools: MCPTool[] = []
+  ) {
+    this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
   }
 
   /**
@@ -47,7 +51,7 @@ export class AiSdkToChunkAdapter {
    * Reads the fullStream and converts it into Cherry Studio chunks
    * @param fullStream the AI SDK fullStream (ReadableStream)
    */
-  private async readFullStream(fullStream: ReadableStream<TextStreamPart<any>>) {
+  private async readFullStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>) {
     const reader = fullStream.getReader()
     const final = {
       text: '',
@@ -73,84 +77,39 @@ export class AiSdkToChunkAdapter {
    * Converts an AI SDK chunk into a Cherry Studio chunk and invokes the callback
    * @param chunk the AI SDK chunk data
    */
-  private convertAndEmitChunk(chunk: any, final: { text: string; reasoning_content: string }) {
+  private convertAndEmitChunk(chunk: TextStreamPart<ToolSet>, final: { text: string; reasoning_content: string }) {
     console.log('AI SDK chunk type:', chunk.type, chunk)
     switch (chunk.type) {
       // === Text events ===
-      case 'text-delta':
-        final.text += chunk.textDelta || ''
+      case 'text':
+        final.text += chunk.text || ''
         this.onChunk({
           type: ChunkType.TEXT_DELTA,
-          text: chunk.textDelta || ''
+          text: chunk.text || ''
+        })
+        break
+      case 'text-end':
+        this.onChunk({
+          type: ChunkType.TEXT_COMPLETE,
+          text: final.text || ''
         })
         break
-
       case 'reasoning':
         this.onChunk({
           type: ChunkType.THINKING_DELTA,
-          text: chunk.textDelta || '',
-          // custom field
-          thinking_millsec: chunk.thinking_millsec || 0
+          text: chunk.text || '',
+          thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
         })
         break
-      case 'redacted-reasoning':
-        this.onChunk({
-          type: ChunkType.THINKING_DELTA,
-          text: chunk.data || ''
-        })
-        break
-      case 'reasoning-signature':
+      case 'reasoning-end':
         this.onChunk({
           type: ChunkType.THINKING_COMPLETE,
-          text: chunk.text || '',
-          thinking_millsec: chunk.thinking_millsec || 0
+          text: (chunk.providerMetadata?.metadata?.thinking_content as string) || '',
+          thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
         })
         break
 
       // === Tool-call events (raw AI SDK events, if not handled by middleware) ===
-      case 'tool-call-streaming-start':
-        // Start of a streaming tool call
-        this.onChunk({
-          type: ChunkType.MCP_TOOL_CREATED,
-          tool_calls: [
-            {
-              id: chunk.toolCallId,
-              name: chunk.toolName,
-              args: {}
-            }
-          ]
-        })
-        break
-
-      case 'tool-call-delta':
-        // Incremental update of tool-call arguments
-        this.onChunk({
-          type: ChunkType.MCP_TOOL_IN_PROGRESS,
-          responses: [
-            {
-              id: chunk.toolCallId,
-              tool: {
-                id: chunk.toolName,
-                // TODO: serverId, serverName
-                serverId: 'ai-sdk',
-                serverName: 'AI SDK',
-                name: chunk.toolName,
-                description: '',
-                inputSchema: {
-                  type: 'object',
-                  title: chunk.toolName,
-                  properties: {}
-                }
-              },
-              arguments: {},
-              status: 'invoking',
-              response: chunk.argsTextDelta,
-              toolCallId: chunk.toolCallId
-            }
-          ]
-        })
-        break
-
       case 'tool-call':
         // Raw tool call (not handled by middleware)
         this.toolCallHandler.handleToolCall(chunk)
 
       case 'tool-result':
         // Raw tool-call result (not handled by middleware)
         this.toolCallHandler.handleToolResult(chunk)
         break
+      // case 'start':
+      //   this.onChunk({
+      //     type: ChunkType.LLM_RESPONSE_CREATED
+      //   })
+      //   break
 
       // === Step events ===
       // TODO: need to distinguish request start from step start
       // case 'step-start':
      //   this.onChunk({
      //     type: ChunkType.LLM_RESPONSE_CREATED
      //   })
      //   break
-      case 'step-finish':
-        this.onChunk({
-          type: ChunkType.TEXT_COMPLETE,
-          text: final.text || '' // TEXT_COMPLETE needs a text field
-        })
-        final.text = ''
-        break
+      // case 'step-finish':
+      //   this.onChunk({
+      //     type: ChunkType.TEXT_COMPLETE,
+      //     text: final.text || '' // TEXT_COMPLETE needs a text field
+      //   })
+      //   final.text = ''
+      //   break
+
+      // case 'finish-step': {
+      //   const { totalUsage, finishReason, providerMetadata } = chunk
+      // }
 
       case 'finish':
         this.onChunk({
           text: final.text || '',
           reasoning_content: final.reasoning_content || '',
           usage: {
-            completion_tokens: chunk.usage.completionTokens || 0,
-            prompt_tokens: chunk.usage.promptTokens || 0,
-            total_tokens: chunk.usage.totalTokens || 0
+            completion_tokens: chunk.totalUsage.outputTokens || 0,
+            prompt_tokens: chunk.totalUsage.inputTokens || 0,
+            total_tokens: chunk.totalUsage.totalTokens || 0
           },
-          metrics: chunk.usage
+          metrics: chunk.totalUsage
             ? {
-                completion_tokens: chunk.usage.completionTokens || 0,
+                completion_tokens: chunk.totalUsage.outputTokens || 0,
                 time_completion_millsec: 0
               }
             : undefined
 
           text: final.text || '',
           reasoning_content: final.reasoning_content || '',
           usage: {
-            completion_tokens: chunk.usage.completionTokens || 0,
-            prompt_tokens: chunk.usage.promptTokens || 0,
-            total_tokens: chunk.usage.totalTokens || 0
+            completion_tokens: chunk.totalUsage.outputTokens || 0,
+            prompt_tokens: chunk.totalUsage.inputTokens || 0,
+            total_tokens: chunk.totalUsage.totalTokens || 0
           },
-          metrics: chunk.usage
+          metrics: chunk.totalUsage
            ? {
-                completion_tokens: chunk.usage.completionTokens || 0,
+                completion_tokens: chunk.totalUsage.outputTokens || 0,
                 time_completion_millsec: 0
               }
             : undefined
         })
        break
 
       // === Source and file events ===
       case 'source':
-        // Source info; can be mapped to knowledge-search complete
         this.onChunk({
-          type: ChunkType.KNOWLEDGE_SEARCH_COMPLETE,
-          knowledge: [
-            {
-              id: Number(chunk.source.id) || Date.now(),
-              content: chunk.source.title || '',
-              sourceUrl: chunk.source.url || '',
-              type: 'url'
-            }
-          ]
-        })
-        break
-
-      case 'file':
-        // File event, possibly image generation
-        this.onChunk({
-          type: ChunkType.IMAGE_COMPLETE,
-          image: {
-            type: 'base64',
-            images: [chunk.base64]
+          type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
+          llm_web_search: {
+            source: WebSearchSource.AISDK,
+            results: [{}]
           }
         })
         break
+      // case 'file':
+      //   // File event, possibly image generation
+      //   this.onChunk({
+      //     type: ChunkType.IMAGE_COMPLETE,
+      //     image: {
+      //       type: 'base64',
+      //       images: [chunk.base64]
+      //     }
+      //   })
+      //   break
 
       case 'error':
         this.onChunk({
           type: ChunkType.ERROR,
diff --git a/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts b/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts
index 1943067673..d550dc569a 100644
--- a/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts
+++ b/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts
@@ -4,8 +4,9 @@
  * Provides tool-call handling APIs; each interaction uses a new instance
  */
+import { ToolCallUnion, ToolSet } from '@cherrystudio/ai-core/index'
 import Logger from '@renderer/config/logger'
-import { MCPToolResponse } from '@renderer/types'
+import { MCPTool, MCPToolResponse } from '@renderer/types'
 import { Chunk, ChunkType } from '@renderer/types/chunk'
 
 /**
@@ -19,10 +20,13 @@ export class ToolCallChunkHandler {
       toolCallId: string
       toolName: string
       args: any
-      // mcpTool: MCPTool
+      mcpTool: MCPTool
     }
   >()
-  constructor(private onChunk: (chunk: Chunk) => void) {}
+  constructor(
+    private onChunk: (chunk: Chunk) => void,
+    private mcpTools: MCPTool[]
+  ) {}
 
   // /**
   //  * Set the onChunk callback
   //  */
   // public setOnChunk(callback: (chunk: Chunk) => void): void {
   //   this.onChunk = callback
   // }
 
   /**
    * Handle tool-call events
    */
-  public handleToolCall(chunk: any): void {
+  public handleToolCall(
+    chunk: {
+      type: 'tool-call'
+    } & ToolCallUnion<ToolSet>
+  ): void {
     const toolCallId = chunk.toolCallId
     const toolName = chunk.toolName
-    const args = chunk.args || {}
+    const args = chunk.input || {}
 
     if (!toolCallId || !toolName) {
       Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
@@ -51,17 +59,14 @@
     this.activeToolCalls.set(toolCallId, {
       toolCallId,
       toolName,
-      args
-      // mcpTool
+      args,
+      mcpTool: this.mcpTools.find((tool) => tool.name === toolName)!
     })
 
     // Create the MCPToolResponse format
     const toolResponse: MCPToolResponse = {
       id: toolCallId,
-      tool: {
-        id: toolCallId,
-        name: toolName
-      },
+      tool: this.activeToolCalls.get(toolCallId)!.mcpTool,
       arguments: args,
       status: 'invoking',
       toolCallId: toolCallId
@@ -98,10 +103,7 @@
     // Create the tool-result MCPToolResponse format
     const toolResponse: MCPToolResponse = {
       id: toolCallId,
-      tool: {
-        id: toolCallId,
-        name: toolCallInfo.toolName
-      },
+      tool: toolCallInfo.mcpTool,
       arguments: toolCallInfo.args,
       status: 'done',
       response: {
diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts
index 17ff9076f8..ac5e2b1448 100644
--- a/src/renderer/src/aiCore/index_new.ts
+++ b/src/renderer/src/aiCore/index_new.ts
@@ -195,7 +195,7 @@ export default class ModernAiProvider {
     // Create an executor with middlewares
     if (middlewareConfig.onChunk) {
       // Streaming - use the adapter
-      const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
+      const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk, middlewareConfig.mcpTools)
       console.log('final params', params)
       const streamResult = await executor.streamText(
         modelId,
diff --git a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts
index 303c7f80a8..de42d1ab45 100644
--- a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts
+++ b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts
@@ -1,8 +1,4 @@
-import {
-  extractReasoningMiddleware,
-  LanguageModelV1Middleware,
-  simulateStreamingMiddleware
-} from '@cherrystudio/ai-core'
+import { LanguageModelV2Middleware, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
 import type { MCPTool, Model, Provider } from '@renderer/types'
 import type { Chunk } from '@renderer/types/chunk'
 
@@ -26,7 +22,7 @@ export interface AiSdkMiddlewareConfig {
  */
 export interface NamedAiSdkMiddleware {
   name: string
-  middleware: LanguageModelV1Middleware
+  middleware: LanguageModelV2Middleware
 }
 
 /**
@@ -75,7 +71,7 @@ export class AiSdkMiddlewareBuilder {
   /**
    * Build the final middleware array
   */
-  public build(): LanguageModelV1Middleware[] {
+  public build(): LanguageModelV2Middleware[] {
     return this.middlewares.map((m) => m.middleware)
   }
 
@@ -106,7 +102,7 @@
 /**
  * Factory that builds AI SDK middlewares from the given config
  * Note the build order here: some middlewares depend on the results of others
  */
-export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV1Middleware[] {
+export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV2Middleware[] {
   const builder = new AiSdkMiddlewareBuilder()
 
   // 1. Add provider-specific middlewares
@@ -143,10 +139,10 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
       // Anthropic-specific middleware
       break
     case 'openai':
-      builder.add({
-        name: 'thinking-tag-extraction',
-        middleware: extractReasoningMiddleware({ tagName: 'think' })
-      })
+      // builder.add({
+      //   name: 'thinking-tag-extraction',
+      //   middleware: extractReasoningMiddleware({ tagName: 'think' })
+      // })
       break
     case 'gemini':
       // Gemini-specific middleware
diff --git a/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts b/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts
index 99e731e56f..26a019b192 100644
--- a/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts
+++ b/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts
@@ -1,4 +1,4 @@
-import { definePlugin } from '@cherrystudio/ai-core'
+import { definePlugin, TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
 
 export default definePlugin({
   name: 'reasoningTimePlugin',
@@ -8,57 +8,62 @@ export default definePlugin({
     let thinkingStartTime = 0
     let hasStartedThinking = false
     let accumulatedThinkingContent = ''
+    let reasoningBlockId = ''
 
-    return new TransformStream({
-      transform(chunk, controller) {
-        if (chunk.type !== 'reasoning') {
-          // === Handle the end of reasoning ===
-          if (hasStartedThinking) {
-            console.log(`[ReasoningPlugin] Ending reasoning.`)
-
-            // Emit a reasoning-signature
-            controller.enqueue({
-              type: 'reasoning-signature',
-              text: accumulatedThinkingContent,
-              thinking_millsec: performance.now() - thinkingStartTime
-            })
-
-            // Reset state
-            accumulatedThinkingContent = ''
-            hasStartedThinking = false
-            thinkingStartTime = 0
-          }
-
-          controller.enqueue(chunk)
-          return
-        }
-
+    return new TransformStream<TextStreamPart<ToolSet>, TextStreamPart<ToolSet>>({
+      transform(chunk: TextStreamPart<ToolSet>, controller: TransformStreamDefaultController<TextStreamPart<ToolSet>>) {
         // === Handle 'reasoning' chunks ===
+        if (chunk.type === 'reasoning') {
+          if (!hasStartedThinking) {
+            hasStartedThinking = true
+            thinkingStartTime = performance.now()
+            reasoningBlockId = chunk.id
+          }
+          accumulatedThinkingContent += chunk.text
 
-        // 1. Time-tracking logic
-        if (!hasStartedThinking) {
-          hasStartedThinking = true
-          thinkingStartTime = performance.now()
-          console.log(`[ReasoningPlugin] Starting reasoning session`)
+          controller.enqueue({
+            ...chunk,
+            providerMetadata: {
+              ...chunk.providerMetadata,
+              metadata: {
+                ...chunk.providerMetadata?.metadata,
+                thinking_millsec: performance.now() - thinkingStartTime,
+                thinking_content: accumulatedThinkingContent
+              }
+            }
+          })
+        } else if (hasStartedThinking) {
+          controller.enqueue({
+            type: 'reasoning-end',
+            id: reasoningBlockId,
+            providerMetadata: {
+              metadata: {
+                thinking_millsec: performance.now() - thinkingStartTime,
+                thinking_content: accumulatedThinkingContent
+              }
+            }
+          })
+          accumulatedThinkingContent = ''
+          hasStartedThinking = false
+          thinkingStartTime = 0
+          reasoningBlockId = ''
+          controller.enqueue(chunk)
+        } else {
+          controller.enqueue(chunk)
         }
-        accumulatedThinkingContent += chunk.textDelta
-
-        // 2. Pass the chunk straight through with timing attached
-        console.log(`[ReasoningPlugin] Forwarding reasoning chunk: "${chunk.textDelta}"`)
-        controller.enqueue({
-          ...chunk,
-          thinking_millsec: performance.now() - thinkingStartTime
-        })
       },
-      // === flush handles streams that end while still in the reasoning state ===
       flush(controller) {
         if (hasStartedThinking) {
-          console.log(`[ReasoningPlugin] Final flush for reasoning-signature.`)
           controller.enqueue({
-            type: 'reasoning-signature',
-            text: accumulatedThinkingContent,
-            thinking_millsec: performance.now() - thinkingStartTime
+            type: 'reasoning-end',
+            id: reasoningBlockId,
+            providerMetadata: {
+              metadata: {
+                thinking_millsec: performance.now() - thinkingStartTime,
+                thinking_content: accumulatedThinkingContent
+              }
+            }
           })
         }
       }
diff --git a/src/renderer/src/aiCore/utils/mcp.ts b/src/renderer/src/aiCore/utils/mcp.ts
index 09153698a3..39c36deb4f 100644
--- a/src/renderer/src/aiCore/utils/mcp.ts
+++ b/src/renderer/src/aiCore/utils/mcp.ts
@@ -48,7 +48,7 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record
     ({
       description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
-      parameters: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
+      inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
       execute: async (params): Promise => {
         console.log('execute_params', params)
         // Create an adapted MCPToolResponse object
diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts
index 1ade4ed824..8a9bc7cbc8 100644
--- a/src/renderer/src/types/index.ts
+++ b/src/renderer/src/types/index.ts
@@ -545,7 +545,8 @@ export enum WebSearchSource {
   QWEN = 'qwen',
   HUNYUAN = 'hunyuan',
   ZHIPU = 'zhipu',
-  GROK = 'grok'
+  GROK = 'grok',
+  AISDK = 'ai-sdk'
 }
 
 export type WebSearchResponse = {
diff --git a/yarn.lock b/yarn.lock
index e83eb05c31..e7be1eba5f 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -956,6 +956,7 @@ __metadata:
     ollama-ai-provider: "npm:^1.2.0"
     tsdown: "npm:^0.12.9"
     typescript: "npm:^5.0.0"
+    zod: "npm:^3.25.0"
   peerDependenciesMeta:
     "@ai-sdk/amazon-bedrock":
       optional: true
@@ -20140,6 +20141,13 @@ __metadata:
   languageName: node
   linkType: hard
 
+"zod@npm:^3.25.0":
+  version: 3.25.74
+  resolution: "zod@npm:3.25.74"
+  checksum: 10c0/59e38b046ac333b5bd1ba325a83b6798721227cbfb1e69dfc7159bd7824b904241ab923026edb714fafefec3624265ae374a70aee9a5a45b365bd31781ffa105
+  languageName: node
+  linkType: hard
+
 "zustand@npm:^4.4.0":
   version: 4.5.6
   resolution: "zustand@npm:4.5.6"
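
Note for reviewers: the sketch below is not part of the patch. It summarizes the AI SDK v5 stream-part shapes this migration targets, trimmed to the fields the handlers in AiSdkToChunkAdapter.ts and mcpPromptPlugin.ts actually read, and assumes the `ai@5.0.0-beta.7` pin from package.json. The type and function names here are illustrative stand-ins, not the SDK's real exports.

// Trimmed stand-in for the v5 TextStreamPart union (illustrative only)
type UsageSketch = { inputTokens?: number; outputTokens?: number; totalTokens?: number }

type StreamPartSketch =
  | { type: 'text'; id: string; text: string } // was 'text-delta' with a `textDelta` payload
  | { type: 'text-end'; id: string }
  | { type: 'reasoning'; id: string; text: string } // was `textDelta` payload
  | { type: 'reasoning-end'; id: string } // replaces 'reasoning-signature'
  | { type: 'start-step' } // was 'step-start'
  | { type: 'finish-step'; finishReason: string } // was 'step-finish'
  | { type: 'tool-call'; toolCallId: string; toolName: string; input: unknown } // args -> input
  | { type: 'tool-result'; toolCallId: string; toolName: string; input: unknown; output: unknown } // result -> output
  | { type: 'finish'; totalUsage: UsageSketch } // usage -> totalUsage

function describePart(part: StreamPartSketch): string {
  switch (part.type) {
    case 'text':
      return `text delta ${part.id}: ${part.text}`
    case 'tool-call':
      return `call ${part.toolName}(${JSON.stringify(part.input)})`
    case 'finish':
      // promptTokens/completionTokens are now inputTokens/outputTokens
      return `finished after ${part.totalUsage.totalTokens ?? 0} tokens`
    default:
      return part.type
  }
}

console.log(describePart({ type: 'finish', totalUsage: { totalTokens: 42 } }))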
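
A second sketch, grounded in the same hunks: the usage-field mapping that both `finish` handlers in AiSdkToChunkAdapter.ts now perform inline, pulled out as a pure function. The helper name is hypothetical; the field names come straight from the diff (v5 reports `totalUsage.inputTokens`/`outputTokens`, while Cherry Studio chunks keep the OpenAI-style names).

type TotalUsage = { inputTokens?: number; outputTokens?: number; totalTokens?: number }

// Map v5 totalUsage back to Cherry Studio's prompt_/completion_tokens fields
function toCherryUsage(totalUsage: TotalUsage) {
  return {
    completion_tokens: totalUsage.outputTokens || 0,
    prompt_tokens: totalUsage.inputTokens || 0,
    total_tokens: totalUsage.totalTokens || 0
  }
}

console.log(toCherryUsage({ inputTokens: 12, outputTokens: 34, totalTokens: 46 }))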
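
Finally, a minimal sketch of what the `parameters` → `inputSchema` rename in utils/mcp.ts means for a converted tool. `tool` and `jsonSchema` are real exports of the `ai` package (as used via `aiSdk.jsonSchema` in the diff); the weather tool itself is made up for illustration.

import { jsonSchema, tool } from 'ai'

// Under v5, a tool declares its JSON-schema input as `inputSchema` (formerly `parameters`)
const weather = tool({
  description: 'Get the weather for a city',
  inputSchema: jsonSchema<{ city: string }>({
    type: 'object',
    properties: { city: { type: 'string' } },
    required: ['city']
  }),
  execute: async ({ city }) => ({ city, forecast: 'sunny' })
})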