From 1bccfd31708a178111529716b7eccab00893d907 Mon Sep 17 00:00:00 2001 From: suyao Date: Fri, 20 Jun 2025 05:44:44 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E6=88=90api=E5=B1=82=EF=BC=8C?= =?UTF-8?q?=E4=B8=9A=E5=8A=A1=E9=80=BB=E8=BE=91=E5=B1=82=EF=BC=8C=E7=BC=96?= =?UTF-8?q?=E6=8E=92=E5=B1=82=E7=9A=84=E5=88=86=E7=A6=BB=20feat:=20?= =?UTF-8?q?=E4=B8=BA=E6=8F=92=E4=BB=B6=E7=B3=BB=E7=BB=9F=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E4=B8=AD=E9=97=B4=E4=BB=B6=20feat:=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E8=87=AA=E5=AE=9A=E4=B9=89=E7=9A=84=E6=80=9D=E8=80=83=E4=B8=AD?= =?UTF-8?q?=E9=97=B4=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Updated package.json and related files to reflect the correct naming convention for the @cherrystudio/ai-core package. - Adjusted import paths in various files to ensure consistency with the new package name. - Enhanced type resolution in tsconfig.web.json to align with the updated package structure. --- package.json | 2 +- packages/aiCore/AI_SDK_ARCHITECTURE.md | 6 +- packages/aiCore/README.md | 10 +- packages/aiCore/package.json | 2 +- .../aiCore/src/clients/ApiClientFactory.ts | 19 +- .../src/clients/PluginEnabledAiClient.ts | 50 +- .../src/clients/UniversalAiSdkClient.ts | 2 +- packages/aiCore/src/index.ts | 4 +- packages/aiCore/src/plugins/README.md | 4 +- packages/aiCore/src/plugins/manager.ts | 9 +- packages/aiCore/src/providers/registry.ts | 4 - .../src/aiCore/AiSdkToChunkAdapter.ts | 30 +- src/renderer/src/aiCore/index_new.ts | 88 +-- .../aisdk/ThinkingTimeMiddleware.ts | 70 +++ .../src/aiCore/transformParameters.ts | 70 ++- src/renderer/src/services/ApiService.ts | 580 ++++++++---------- .../src/services/ConversationService.ts | 34 + .../src/services/OrchestrateService.ts | 54 ++ src/renderer/src/store/thunk/messageThunk.ts | 29 +- tsconfig.web.json | 2 +- yarn.lock | 6 +- 21 files changed, 610 insertions(+), 465 deletions(-) create mode 100644 src/renderer/src/aiCore/middleware/aisdk/ThinkingTimeMiddleware.ts create mode 100644 src/renderer/src/services/ConversationService.ts create mode 100644 src/renderer/src/services/OrchestrateService.ts diff --git a/package.json b/package.json index 0be12b6185..3fb89dcdba 100644 --- a/package.json +++ b/package.json @@ -73,7 +73,7 @@ "@agentic/tavily": "^7.3.3", "@ant-design/v5-patch-for-react-19": "^1.0.3", "@anthropic-ai/sdk": "^0.41.0", - "@cherry-studio/ai-core": "workspace:*", + "@cherrystudio/ai-core": "workspace:*", "@cherrystudio/embedjs": "^0.1.31", "@cherrystudio/embedjs-libsql": "^0.1.31", "@cherrystudio/embedjs-loader-csv": "^0.1.31", diff --git a/packages/aiCore/AI_SDK_ARCHITECTURE.md b/packages/aiCore/AI_SDK_ARCHITECTURE.md index af644f71f1..b091a57a41 100644 --- a/packages/aiCore/AI_SDK_ARCHITECTURE.md +++ b/packages/aiCore/AI_SDK_ARCHITECTURE.md @@ -317,7 +317,7 @@ export class AiCoreService { ### 5.1 多 Provider 支持 ```typescript -import { createAiSdkClient, AiCore } from '@cherry-studio/ai-core' +import { createAiSdkClient, AiCore } from '@cherrystudio/ai-core' // 检查支持的 providers const providers = AiCore.getSupportedProviders() @@ -339,7 +339,7 @@ const xai = await createAiSdkClient('xai', { apiKey: 'xai-key' }) // const anthropicClient = new AnthropicApiClient(config) // 现在: -import { createAiSdkClient } from '@cherry-studio/ai-core' +import { createAiSdkClient } from '@cherrystudio/ai-core' const createProviderClient = async (provider: CherryProvider) => { return await createAiSdkClient(provider.id, { @@ -359,7 +359,7 @@ import { PreRequestMiddleware, StreamProcessingMiddleware, PostResponseMiddleware -} from '@cherry-studio/ai-core' +} from '@cherrystudio/ai-core' // 创建完整的工作流 const createEnhancedAiService = async () => { diff --git a/packages/aiCore/README.md b/packages/aiCore/README.md index ddccffdda2..7e10c998f5 100644 --- a/packages/aiCore/README.md +++ b/packages/aiCore/README.md @@ -1,4 +1,4 @@ -# @cherry-studio/ai-core +# @cherrystudio/ai-core Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口包。 @@ -42,7 +42,7 @@ Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口 ## 安装 ```bash -npm install @cherry-studio/ai-core ai +npm install @cherrystudio/ai-core ai ``` 还需要安装你要使用的 AI SDK provider: @@ -56,7 +56,7 @@ npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google ### 基础用法 ```typescript -import { createAiSdkClient } from '@cherry-studio/ai-core' +import { createAiSdkClient } from '@cherrystudio/ai-core' // 创建 OpenAI 客户端 const client = await createAiSdkClient('openai', { @@ -79,7 +79,7 @@ const response = await client.generate({ ### 便捷函数 ```typescript -import { createOpenAIClient, streamGeneration } from '@cherry-studio/ai-core' +import { createOpenAIClient, streamGeneration } from '@cherrystudio/ai-core' // 快速创建 OpenAI 客户端 const client = await createOpenAIClient({ @@ -95,7 +95,7 @@ const result = await streamGeneration('openai', 'gpt-4', [{ role: 'user', conten ### 多 Provider 支持 ```typescript -import { createAiSdkClient } from '@cherry-studio/ai-core' +import { createAiSdkClient } from '@cherrystudio/ai-core' // 支持多种 AI providers const openaiClient = await createAiSdkClient('openai', { apiKey: 'openai-key' }) diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json index 69f73c341b..aec2760321 100644 --- a/packages/aiCore/package.json +++ b/packages/aiCore/package.json @@ -1,5 +1,5 @@ { - "name": "@cherry-studio/ai-core", + "name": "@cherrystudio/ai-core", "version": "1.0.0", "description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK", "main": "src/index.ts", diff --git a/packages/aiCore/src/clients/ApiClientFactory.ts b/packages/aiCore/src/clients/ApiClientFactory.ts index ed2b3c514b..6c5e67edf6 100644 --- a/packages/aiCore/src/clients/ApiClientFactory.ts +++ b/packages/aiCore/src/clients/ApiClientFactory.ts @@ -4,7 +4,7 @@ */ import type { ImageModelV1 } from '@ai-sdk/provider' -import { type LanguageModelV1, wrapLanguageModel } from 'ai' +import { type LanguageModelV1, LanguageModelV1Middleware, wrapLanguageModel } from 'ai' import { aiProviderRegistry } from '../providers/registry' import { type ProviderId, type ProviderSettingsMap } from './types' @@ -39,16 +39,23 @@ export class ApiClientFactory { static async createClient( providerId: T, modelId: string, - options: ProviderSettingsMap[T] + options: ProviderSettingsMap[T], + middlewares?: LanguageModelV1Middleware[] ): Promise static async createClient( providerId: string, modelId: string, - options: ProviderSettingsMap['openai-compatible'] + options: ProviderSettingsMap['openai-compatible'], + middlewares?: LanguageModelV1Middleware[] ): Promise - static async createClient(providerId: string, modelId: string = 'default', options: any): Promise { + static async createClient( + providerId: string, + modelId: string = 'default', + options: any, + middlewares?: LanguageModelV1Middleware[] + ): Promise { try { // 对于不在注册表中的 provider,默认使用 openai-compatible const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible' @@ -78,10 +85,10 @@ export class ApiClientFactory { let model = provider(modelId) // 应用 AI SDK 中间件 - if (providerConfig.aiSdkMiddlewares) { + if (middlewares && middlewares.length > 0) { model = wrapLanguageModel({ model: model, - middleware: providerConfig.aiSdkMiddlewares + middleware: middlewares }) } diff --git a/packages/aiCore/src/clients/PluginEnabledAiClient.ts b/packages/aiCore/src/clients/PluginEnabledAiClient.ts index 62707e5f5e..0eeffea7b3 100644 --- a/packages/aiCore/src/clients/PluginEnabledAiClient.ts +++ b/packages/aiCore/src/clients/PluginEnabledAiClient.ts @@ -5,7 +5,7 @@ * ## 使用方式 * * ```typescript - * import { AiClient } from '@cherry-studio/ai-core' + * import { AiClient } from '@cherrystudio/ai-core' * * // 创建客户端(默认带插件系统) * const client = AiClient.create('openai', { @@ -19,7 +19,14 @@ * }) * ``` */ -import { generateObject, generateText, streamObject, streamText } from 'ai' +import { + generateObject, + generateText, + LanguageModelV1Middleware, + simulateStreamingMiddleware, + streamObject, + streamText +} from 'ai' import { AiPlugin, createContext, PluginManager } from '../plugins' import { isProviderSupported } from '../providers/registry' @@ -34,6 +41,7 @@ import { UniversalAiSdkClient } from './UniversalAiSdkClient' export class PluginEnabledAiClient { private pluginManager: PluginManager private baseClient: UniversalAiSdkClient + private middlewares: LanguageModelV1Middleware[] = [] constructor( private readonly providerId: T, @@ -42,6 +50,7 @@ export class PluginEnabledAiClient { ) { this.pluginManager = new PluginManager(plugins) this.baseClient = UniversalAiSdkClient.create(providerId, options) + this.updateMiddlewares() } /** @@ -49,6 +58,7 @@ export class PluginEnabledAiClient { */ use(plugin: AiPlugin): this { this.pluginManager.use(plugin) + this.updateMiddlewares() return this } @@ -57,6 +67,7 @@ export class PluginEnabledAiClient { */ usePlugins(plugins: AiPlugin[]): this { plugins.forEach((plugin) => this.pluginManager.use(plugin)) + this.updateMiddlewares() return this } @@ -65,9 +76,19 @@ export class PluginEnabledAiClient { */ removePlugin(pluginName: string): this { this.pluginManager.remove(pluginName) + this.updateMiddlewares() return this } + /** + * 重新计算并更新中间件列表 + * 这是一个原子操作,以确保中间件列表总是最新的 + */ + private updateMiddlewares(): void { + const pluginMiddlewares = this.pluginManager.collectAiSdkMiddlewares() + this.middlewares = pluginMiddlewares + } + /** * 获取插件统计信息 */ @@ -164,6 +185,21 @@ export class PluginEnabledAiClient { } } + /** + * 获取注入了中间件的 AI SDK 模型实例 + * 这是应用原生中间件的关键 + */ + private async getModelWithMiddlewares(modelId: string) { + const middlewares = this.middlewares + // 3. 如果有中间件,创建一个新的、注入了中间件的客户端实例 + return await ApiClientFactory.createClient( + this.providerId, + modelId, + this.options, + middlewares.length > 0 ? middlewares : [simulateStreamingMiddleware()] //TODO: 这里需要改成非流时调用simulateStreamingMiddleware(),这里先随便传一个 + ) + } + /** * 流式文本生成 - 集成插件系统 */ @@ -176,8 +212,7 @@ export class PluginEnabledAiClient { modelId, params, async (finalModelId, transformedParams, streamTransforms) => { - // 对于流式调用,需要直接调用 AI SDK 以支持流转换器 - const model = await ApiClientFactory.createClient(this.providerId, finalModelId, this.options) + const model = await this.getModelWithMiddlewares(finalModelId) return await streamText({ model, ...transformedParams, @@ -189,13 +224,15 @@ export class PluginEnabledAiClient { /** * 生成文本 - 集成插件系统 + * 可能不需要了,因为内置模拟非流中间件 */ async generateText( modelId: string, params: Omit[0], 'model'> ): Promise> { return this.executeWithPlugins('generateText', modelId, params, async (finalModelId, transformedParams) => { - return await this.baseClient.generateText(finalModelId, transformedParams) + const model = await this.getModelWithMiddlewares(finalModelId) + return await generateText({ model, ...transformedParams }) }) } @@ -207,7 +244,8 @@ export class PluginEnabledAiClient { params: Omit[0], 'model'> ): Promise> { return this.executeWithPlugins('generateObject', modelId, params, async (finalModelId, transformedParams) => { - return await this.baseClient.generateObject(finalModelId, transformedParams) + const model = await this.getModelWithMiddlewares(finalModelId) + return await generateObject({ model, ...transformedParams }) }) } diff --git a/packages/aiCore/src/clients/UniversalAiSdkClient.ts b/packages/aiCore/src/clients/UniversalAiSdkClient.ts index 700a52f3ac..7b1f80beb7 100644 --- a/packages/aiCore/src/clients/UniversalAiSdkClient.ts +++ b/packages/aiCore/src/clients/UniversalAiSdkClient.ts @@ -6,7 +6,7 @@ * * ### 1. 官方提供商 * ```typescript - * import { UniversalAiSdkClient } from '@cherry-studio/ai-core' + * import { UniversalAiSdkClient } from '@cherrystudio/ai-core' * * // OpenAI * const openai = UniversalAiSdkClient.create('openai', { diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts index 7968645d68..0165214613 100644 --- a/packages/aiCore/src/index.ts +++ b/packages/aiCore/src/index.ts @@ -66,6 +66,8 @@ export type { GenerateTextResult, InvalidToolArgumentsError, LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage + LanguageModelV1Middleware, + LanguageModelV1StreamPart, // 错误类型 NoSuchToolError, StreamTextResult, @@ -115,7 +117,7 @@ export { getAllProviders, getProvider, isProviderSupported, registerProvider } f // ==================== 包信息 ==================== export const AI_CORE_VERSION = '1.0.0' -export const AI_CORE_NAME = '@cherry-studio/ai-core' +export const AI_CORE_NAME = '@cherrystudio/ai-core' // ==================== 便捷 API ==================== // 主要的便捷工厂类 diff --git a/packages/aiCore/src/plugins/README.md b/packages/aiCore/src/plugins/README.md index 0ef1a2fd1b..bc0e28954e 100644 --- a/packages/aiCore/src/plugins/README.md +++ b/packages/aiCore/src/plugins/README.md @@ -50,7 +50,7 @@ transformStream?: () => (options) => TransformStream } + /** + * 收集所有 AI SDK 原生中间件 + */ + collectAiSdkMiddlewares(): LanguageModelV1Middleware[] { + return this.plugins.flatMap((plugin) => plugin.aiSdkMiddlewares || []) + } + /** * 获取所有插件信息 */ diff --git a/packages/aiCore/src/providers/registry.ts b/packages/aiCore/src/providers/registry.ts index 863e979348..53504b4b07 100644 --- a/packages/aiCore/src/providers/registry.ts +++ b/packages/aiCore/src/providers/registry.ts @@ -1,5 +1,3 @@ -import type { LanguageModelV1Middleware } from 'ai' - /** * AI Provider 注册表 * 静态类型 + 动态导入模式:所有类型静态导入,所有实现动态导入 @@ -73,8 +71,6 @@ export interface ProviderConfig { creatorFunctionName: string // 是否支持图片生成 supportsImageGeneration?: boolean - // AI SDK 原生中间件 - aiSdkMiddlewares?: LanguageModelV1Middleware[] } /** diff --git a/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts index c15ebdebeb..ae6ba25987 100644 --- a/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts +++ b/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts @@ -3,7 +3,7 @@ * 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式 */ -import { TextStreamPart } from '@cherry-studio/ai-core' +import { TextStreamPart } from '@cherrystudio/ai-core' import { Chunk, ChunkType } from '@renderer/types/chunk' export interface CherryStudioChunk { @@ -78,38 +78,14 @@ export class AiSdkToChunkAdapter { type: ChunkType.TEXT_DELTA, text: chunk.textDelta || '' }) - if (final.reasoning_content) { - this.onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: final.reasoning_content || '' - }) - final.reasoning_content = '' - } - break - - // === 推理相关事件 === - case 'reasoning': - final.reasoning_content += chunk.textDelta || '' - this.onChunk({ - type: ChunkType.THINKING_DELTA, - text: chunk.textDelta || '' - }) break case 'reasoning-signature': - // 推理签名,可以映射到思考完成 - this.onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: chunk.signature || '' - }) + // 不再需要处理,中间件会发出 THINKING_COMPLETE break case 'redacted-reasoning': - // 被编辑的推理内容,也映射到思考 - this.onChunk({ - type: ChunkType.THINKING_DELTA, - text: chunk.data || '' - }) + // 不再需要处理 break // === 工具调用相关事件 === diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index 12ec392cec..559f9bb5f7 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -1,6 +1,6 @@ /** * Cherry Studio AI Core - 新版本入口 - * 集成 @cherry-studio/ai-core 库的渐进式重构方案 + * 集成 @cherrystudio/ai-core 库的渐进式重构方案 * * 融合方案:简化实现,专注于核心功能 * 1. 优先使用新AI SDK @@ -13,20 +13,20 @@ import { AiCore, createClient, type OpenAICompatibleProviderSettings, - type ProviderId -} from '@cherry-studio/ai-core' + type ProviderId, + StreamTextParams +} from '@cherrystudio/ai-core' import { isDedicatedImageGenerationModel } from '@renderer/config/models' import type { GenerateImageParams, Model, Provider } from '@renderer/types' -import { Chunk, ChunkType } from '@renderer/types/chunk' -import { RequestOptions } from '@renderer/types/sdk' +import { Chunk } from '@renderer/types/chunk' // 引入适配器 import AiSdkToChunkAdapter from './AiSdkToChunkAdapter' // 引入原有的AiProvider作为fallback import LegacyAiProvider from './index' -import { CompletionsParams, CompletionsResult } from './middleware/schemas' +import thinkingTimeMiddleware from './middleware/aisdk/ThinkingTimeMiddleware' +import { CompletionsResult } from './middleware/schemas' // 引入参数转换模块 -import { buildStreamTextParams } from './transformParameters' /** * 将现有 Provider 类型映射到 AI SDK 的 Provider ID @@ -108,21 +108,30 @@ export default class ModernAiProvider { private legacyProvider: LegacyAiProvider private provider: Provider - constructor(provider: Provider) { + constructor(provider: Provider, onChunk?: (chunk: Chunk) => void) { this.provider = provider this.legacyProvider = new LegacyAiProvider(provider) const config = providerToAiSdkConfig(provider) - this.modernClient = createClient(config.providerId, config.options) + this.modernClient = createClient( + config.providerId, + config.options, + onChunk ? [{ name: 'thinking-time', aiSdkMiddlewares: [thinkingTimeMiddleware(onChunk)] }] : undefined + ) } - public async completions(params: CompletionsParams, options?: RequestOptions): Promise { + public async completions( + modelId: string, + params: StreamTextParams, + onChunk?: (chunk: Chunk) => void + ): Promise { // const model = params.assistant.model // 检查是否应该使用现代化客户端 // if (this.modernClient && model && isModernSdkSupported(this.provider, model)) { // try { - return await this.modernCompletions(params, options) + console.log('completions', modelId, params, onChunk) + return await this.modernCompletions(modelId, params, onChunk) // } catch (error) { // console.warn('Modern client failed, falling back to legacy:', error) // fallback到原有实现 @@ -137,66 +146,33 @@ export default class ModernAiProvider { * 使用现代化AI SDK的completions实现 * 使用 AiSdkUtils 工具模块进行参数构建 */ - private async modernCompletions(params: CompletionsParams, options?: RequestOptions): Promise { - if (!this.modernClient || !params.assistant.model) { - throw new Error('Modern client not available') + private async modernCompletions( + modelId: string, + params: StreamTextParams, + onChunk?: (chunk: Chunk) => void + ): Promise { + if (!this.modernClient) { + throw new Error('Modern AI SDK client not initialized') } - console.log('Modern completions with params:', params, 'options:', options) - - const model = params.assistant.model - const assistant = params.assistant - - // 检查 messages 类型并转换 - const messages = Array.isArray(params.messages) ? params.messages : [] - if (typeof params.messages === 'string') { - console.warn('Messages is string, using empty array') - } - - // 使用 transformParameters 模块构建参数 - const aiSdkParams = await buildStreamTextParams(messages, assistant, model, { - maxTokens: params.maxTokens, - mcpTools: params.mcpTools - }) - - console.log('Built AI SDK params:', aiSdkParams) - const chunks: Chunk[] = [] - try { - if (params.streamOutput && params.onChunk) { + if (onChunk) { // 流式处理 - 使用适配器 - const adapter = new AiSdkToChunkAdapter(params.onChunk) - const streamResult = await this.modernClient.streamText(model.id, aiSdkParams) + const adapter = new AiSdkToChunkAdapter(onChunk) + const streamResult = await this.modernClient.streamText(modelId, params) const finalText = await adapter.processStream(streamResult) return { getText: () => finalText } - } else if (params.streamOutput) { + } else { // 流式处理但没有 onChunk 回调 - const streamResult = await this.modernClient.streamText(model.id, aiSdkParams) + const streamResult = await this.modernClient.streamText(modelId, params) const finalText = await streamResult.text return { getText: () => finalText } - } else { - // 非流式处理 - const result = await this.modernClient.generateText(model.id, aiSdkParams) - - const cherryChunk: Chunk = { - type: ChunkType.TEXT_COMPLETE, - text: result.text || '' - } - chunks.push(cherryChunk) - - if (params.onChunk) { - params.onChunk(cherryChunk) - } - - return { - getText: () => result.text || '' - } } } catch (error) { console.error('Modern AI SDK error:', error) diff --git a/src/renderer/src/aiCore/middleware/aisdk/ThinkingTimeMiddleware.ts b/src/renderer/src/aiCore/middleware/aisdk/ThinkingTimeMiddleware.ts new file mode 100644 index 0000000000..66db710874 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/aisdk/ThinkingTimeMiddleware.ts @@ -0,0 +1,70 @@ +import { LanguageModelV1Middleware, LanguageModelV1StreamPart } from '@cherrystudio/ai-core' +import { Chunk, ChunkType, ThinkingCompleteChunk } from '@renderer/types/chunk' + +/** + * 一个用于统计 LLM "思考时间"(Time to First Token)的 AI SDK 中间件。 + * + * 工作原理: + * 1. 在 `stream` 方法被调用时,记录一个起始时间。 + * 2. 它会创建一个新的 `TransformStream` 来代理原始的流。 + * 3. 当第一个数据块 (chunk) 从原始流中到达时,记录结束时间。 + * 4. 计算两者之差,即为 "思考时间" + */ +export default function thinkingTimeMiddleware(onChunkReceived: (chunk: Chunk) => void): LanguageModelV1Middleware { + return { + wrapStream: async ({ doStream }) => { + let hasThinkingContent = false + let thinkingStartTime = 0 + let accumulatedThinkingContent = '' + const { stream, ...reset } = await doStream() + const transformStream = new TransformStream({ + transform(chunk, controller) { + if (chunk.type === 'reasoning' || chunk.type === 'redacted-reasoning') { + if (!hasThinkingContent) { + hasThinkingContent = true + thinkingStartTime = Date.now() + } + accumulatedThinkingContent += chunk.textDelta || '' + onChunkReceived({ + type: ChunkType.THINKING_DELTA, + text: chunk.textDelta || '' + }) + } else { + if (hasThinkingContent && thinkingStartTime > 0) { + const thinkingTime = Date.now() - thinkingStartTime + const thinkingCompleteChunk: ThinkingCompleteChunk = { + type: ChunkType.THINKING_COMPLETE, + text: accumulatedThinkingContent, + thinking_millsec: thinkingTime + } + onChunkReceived(thinkingCompleteChunk) + hasThinkingContent = false + thinkingStartTime = 0 + accumulatedThinkingContent = '' + } + } + // 将所有 chunk 原样传递下去 + controller.enqueue(chunk) + }, + flush(controller) { + // 如果流的末尾都是 reasoning,也需要发送 complete 事件 + if (hasThinkingContent && thinkingStartTime > 0) { + const thinkingTime = Date.now() - thinkingStartTime + const thinkingCompleteChunk: ThinkingCompleteChunk = { + type: ChunkType.THINKING_COMPLETE, + text: accumulatedThinkingContent, + thinking_millsec: thinkingTime + } + onChunkReceived(thinkingCompleteChunk) + } + controller.terminate() + } + }) + + return { + stream: stream.pipeThrough(transformStream), + ...reset + } + } + } +} diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts index 25c5a400c6..128df9442d 100644 --- a/src/renderer/src/aiCore/transformParameters.ts +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -3,8 +3,19 @@ * 统一管理从各个 apiClient 提取的参数处理和转换功能 */ -import type { StreamTextParams } from '@cherry-studio/ai-core' -import { isNotSupportTemperatureAndTopP, isSupportedFlexServiceTier } from '@renderer/config/models' +import type { CoreMessage, StreamTextParams } from '@cherrystudio/ai-core' +import { + isGenerateImageModel, + isNotSupportTemperatureAndTopP, + isOpenRouterBuiltInWebSearchModel, + isReasoningModel, + isSupportedDisableGenerationModel, + isSupportedFlexServiceTier, + isSupportedReasoningEffortModel, + isSupportedThinkingTokenModel, + isWebSearchModel +} from '@renderer/config/models' +import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService' import type { Assistant, MCPTool, Message, Model } from '@renderer/types' import { FileTypes } from '@renderer/types' import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' @@ -183,19 +194,38 @@ export async function convertMessagesToSdkMessages( * 这是主要的参数构建函数,整合所有转换逻辑 */ export async function buildStreamTextParams( - messages: Message[], + sdkMessages: StreamTextParams['messages'], assistant: Assistant, - model: Model, options: { - maxTokens?: number mcpTools?: MCPTool[] enableTools?: boolean + requestOptions?: { + signal?: AbortSignal + timeout?: number + headers?: Record + } } = {} -): Promise { - const { maxTokens, mcpTools, enableTools = false } = options +): Promise<{ params: StreamTextParams; modelId: string }> { + const { mcpTools, enableTools = false } = options - // 转换消息 - const sdkMessages = await convertMessagesToSdkMessages(messages, model) + const model = assistant.model || getDefaultModel() + + const { maxTokens, reasoning_effort } = getAssistantSettings(assistant) + + const enableReasoning = + ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) && + reasoning_effort !== undefined) || + (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model))) + + const enableWebSearch = + (assistant.enableWebSearch && isWebSearchModel(model)) || + isOpenRouterBuiltInWebSearchModel(model) || + model.id.includes('sonar') || + false + + const enableGenerateImage = + isGenerateImageModel(model) && + (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true) // 构建系统提示 let systemPrompt = assistant.prompt || '' @@ -210,6 +240,20 @@ export async function buildStreamTextParams( temperature: getTemperature(assistant, model), topP: getTopP(assistant, model), system: systemPrompt || undefined, + abortSignal: options.requestOptions?.signal, + headers: options.requestOptions?.headers, + // 随便填着,后面再改 + providerOptions: { + reasoning: { + enabled: enableReasoning + }, + webSearch: { + enabled: enableWebSearch + }, + generateImage: { + enabled: enableGenerateImage + } + }, ...getCustomParameters(assistant) } @@ -219,24 +263,22 @@ export async function buildStreamTextParams( // params.tools = convertMcpToolsToSdkTools(mcpTools) } - return params + return { params, modelId: model.id } } /** * 构建非流式的 generateText 参数 */ export async function buildGenerateTextParams( - messages: Message[], + messages: CoreMessage[], assistant: Assistant, - model: Model, options: { - maxTokens?: number mcpTools?: MCPTool[] enableTools?: boolean } = {} ): Promise { // 复用流式参数的构建逻辑 - return await buildStreamTextParams(messages, assistant, model, options) + return await buildStreamTextParams(messages, assistant, options) } /** diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 587434a4f2..d8fe261497 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -1,379 +1,311 @@ +/** + * 职责:提供原子化的、无状态的API调用函数 + */ + +import { StreamTextParams } from '@cherrystudio/ai-core' import { CompletionsParams } from '@renderer/aiCore/middleware/schemas' -import Logger from '@renderer/config/logger' +import { buildStreamTextParams } from '@renderer/aiCore/transformParameters' import { isEmbeddingModel, - isGenerateImageModel, - isOpenRouterBuiltInWebSearchModel, isReasoningModel, - isSupportedDisableGenerationModel, isSupportedReasoningEffortModel, - isSupportedThinkingTokenModel, - isWebSearchModel + isSupportedThinkingTokenModel } from '@renderer/config/models' -import { - SEARCH_SUMMARY_PROMPT, - SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY, - SEARCH_SUMMARY_PROMPT_WEB_ONLY -} from '@renderer/config/prompts' import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' -import { - Assistant, - ExternalToolResult, - KnowledgeReference, - MCPTool, - Model, - Provider, - WebSearchResponse, - WebSearchSource -} from '@renderer/types' +import { Assistant, MCPTool, Model, Provider } from '@renderer/types' import { type Chunk, ChunkType } from '@renderer/types/chunk' import { Message } from '@renderer/types/newMessage' import { SdkModel } from '@renderer/types/sdk' import { removeSpecialCharactersForTopicName } from '@renderer/utils' -import { isAbortError } from '@renderer/utils/error' -import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract' import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' -import { findLast, isEmpty, takeRight } from 'lodash' +import { isEmpty, takeRight } from 'lodash' import AiProvider from '../aiCore' import AiProviderNew from '../aiCore/index_new' import { getAssistantProvider, - getAssistantSettings, getDefaultModel, getProviderByModel, getTopNamingModel, getTranslateModel } from './AssistantService' import { getDefaultAssistant } from './AssistantService' -import { processKnowledgeSearch } from './KnowledgeService' -import { - filterContextMessages, - filterEmptyMessages, - filterUsefulMessages, - filterUserRoleStartMessages -} from './MessagesService' -import WebSearchService from './WebSearchService' -// TODO:考虑拆开 -async function fetchExternalTool( - lastUserMessage: Message, - assistant: Assistant, - onChunkReceived: (chunk: Chunk) => void, - lastAnswer?: Message -): Promise { - // 可能会有重复? - const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) - const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) - const knowledgeRecognition = assistant.knowledgeRecognition || 'on' - const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId) +// // TODO:考虑拆开 +// async function fetchExternalTool( +// lastUserMessage: Message, +// assistant: Assistant, +// onChunkReceived: (chunk: Chunk) => void, +// lastAnswer?: Message +// ) { +// // 可能会有重复? +// const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) +// const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) +// const knowledgeRecognition = assistant.knowledgeRecognition || 'on' +// const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId) - // 使用外部搜索工具 - const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null - const shouldKnowledgeSearch = hasKnowledgeBase +// // 使用外部搜索工具 +// const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null +// const shouldKnowledgeSearch = hasKnowledgeBase - // 在工具链开始时发送进度通知 - const willUseTools = shouldWebSearch || shouldKnowledgeSearch - if (willUseTools) { - onChunkReceived({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS }) - } +// // 在工具链开始时发送进度通知 +// const willUseTools = shouldWebSearch || shouldKnowledgeSearch +// if (willUseTools) { +// onChunkReceived({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS }) +// } - // --- Keyword/Question Extraction Function --- - const extract = async (): Promise => { - if (!lastUserMessage) return undefined +// // --- Keyword/Question Extraction Function --- +// const extract = async (): Promise => { +// if (!lastUserMessage) return undefined - // 根据配置决定是否需要提取 - const needWebExtract = shouldWebSearch - const needKnowledgeExtract = hasKnowledgeBase && knowledgeRecognition === 'on' +// // 根据配置决定是否需要提取 +// const needWebExtract = shouldWebSearch +// const needKnowledgeExtract = hasKnowledgeBase && knowledgeRecognition === 'on' - if (!needWebExtract && !needKnowledgeExtract) return undefined +// if (!needWebExtract && !needKnowledgeExtract) return undefined - let prompt: string - if (needWebExtract && !needKnowledgeExtract) { - prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY - } else if (!needWebExtract && needKnowledgeExtract) { - prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY - } else { - prompt = SEARCH_SUMMARY_PROMPT - } +// let prompt: string +// if (needWebExtract && !needKnowledgeExtract) { +// prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY +// } else if (!needWebExtract && needKnowledgeExtract) { +// prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY +// } else { +// prompt = SEARCH_SUMMARY_PROMPT +// } - const summaryAssistant = getDefaultAssistant() - summaryAssistant.model = assistant.model || getDefaultModel() - summaryAssistant.prompt = prompt +// const summaryAssistant = getDefaultAssistant() +// summaryAssistant.model = assistant.model || getDefaultModel() +// summaryAssistant.prompt = prompt +// try { +// const result = await fetchSearchSummary({ +// messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage], +// assistant: summaryAssistant +// }) + +// if (!result) return getFallbackResult() + +// const extracted = extractInfoFromXML(result.getText()) +// // 根据需求过滤结果 +// return { +// websearch: needWebExtract ? extracted?.websearch : undefined, +// knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined +// } +// } catch (e: any) { +// console.error('extract error', e) +// if (isAbortError(e)) throw e +// return getFallbackResult() +// } +// } + +// const getFallbackResult = (): ExtractResults => { +// const fallbackContent = getMainTextContent(lastUserMessage) +// return { +// websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined, +// knowledge: shouldKnowledgeSearch +// ? { +// question: [fallbackContent || 'search'], +// rewrite: fallbackContent +// } +// : undefined +// } +// } + +// // --- Web Search Function --- +// const searchTheWeb = async (extractResults: ExtractResults | undefined): Promise => { +// if (!shouldWebSearch) return + +// // Add check for extractResults existence early +// if (!extractResults?.websearch) { +// console.warn('searchTheWeb called without valid extractResults.websearch') +// return +// } + +// if (extractResults.websearch.question[0] === 'not_needed') return + +// // Add check for assistant.model before using it +// if (!assistant.model) { +// console.warn('searchTheWeb called without assistant.model') +// return undefined +// } + +// try { +// // Use the consolidated processWebsearch function +// WebSearchService.createAbortSignal(lastUserMessage.id) +// return { +// results: await WebSearchService.processWebsearch(webSearchProvider!, extractResults), +// source: WebSearchSource.WEBSEARCH +// } +// } catch (error) { +// if (isAbortError(error)) throw error +// console.error('Web search failed:', error) +// return +// } +// } + +// // --- Knowledge Base Search Function --- +// const searchKnowledgeBase = async ( +// extractResults: ExtractResults | undefined +// ): Promise => { +// if (!hasKnowledgeBase) return + +// // 知识库搜索条件 +// let searchCriteria: { question: string[]; rewrite: string } +// if (knowledgeRecognition === 'off') { +// const directContent = getMainTextContent(lastUserMessage) +// searchCriteria = { question: [directContent || 'search'], rewrite: directContent } +// } else { +// // auto mode +// if (!extractResults?.knowledge) { +// console.warn('searchKnowledgeBase: No valid search criteria in auto mode') +// return +// } +// searchCriteria = extractResults.knowledge +// } + +// if (searchCriteria.question[0] === 'not_needed') return + +// try { +// const tempExtractResults: ExtractResults = { +// websearch: undefined, +// knowledge: searchCriteria +// } +// // Attempt to get knowledgeBaseIds from the main text block +// // NOTE: This assumes knowledgeBaseIds are ONLY on the main text block +// // NOTE: processKnowledgeSearch needs to handle undefined ids gracefully +// // const mainTextBlock = mainTextBlocks +// // ?.map((blockId) => store.getState().messageBlocks.entities[blockId]) +// // .find((block) => block?.type === MessageBlockType.MAIN_TEXT) as MainTextMessageBlock | undefined +// return await processKnowledgeSearch(tempExtractResults, knowledgeBaseIds) +// } catch (error) { +// console.error('Knowledge base search failed:', error) +// return +// } +// } + +// // --- Execute Extraction and Searches --- +// let extractResults: ExtractResults | undefined + +// try { +// // 根据配置决定是否需要提取 +// if (shouldWebSearch || hasKnowledgeBase) { +// extractResults = await extract() +// Logger.log('[fetchExternalTool] Extraction results:', extractResults) +// } + +// let webSearchResponseFromSearch: WebSearchResponse | undefined +// let knowledgeReferencesFromSearch: KnowledgeReference[] | undefined + +// // 并行执行搜索 +// if (shouldWebSearch || shouldKnowledgeSearch) { +// ;[webSearchResponseFromSearch, knowledgeReferencesFromSearch] = await Promise.all([ +// searchTheWeb(extractResults), +// searchKnowledgeBase(extractResults) +// ]) +// } + +// // 存储搜索结果 +// if (lastUserMessage) { +// if (webSearchResponseFromSearch) { +// window.keyv.set(`web-search-${lastUserMessage.id}`, webSearchResponseFromSearch) +// } +// if (knowledgeReferencesFromSearch) { +// window.keyv.set(`knowledge-search-${lastUserMessage.id}`, knowledgeReferencesFromSearch) +// } +// } + +// // 发送工具执行完成通知 +// if (willUseTools) { +// onChunkReceived({ +// type: ChunkType.EXTERNEL_TOOL_COMPLETE, +// external_tool: { +// webSearch: webSearchResponseFromSearch, +// knowledge: knowledgeReferencesFromSearch +// } +// }) +// } +// } catch (error) { +// if (isAbortError(error)) throw error +// console.error('Tool execution failed:', error) + +// // 发送错误状态 +// if (willUseTools) { +// onChunkReceived({ +// type: ChunkType.EXTERNEL_TOOL_COMPLETE, +// external_tool: { +// webSearch: undefined, +// knowledge: undefined +// } +// }) +// } + +// return { mcpTools: [] } +// } +// } + +export async function fetchMcpTools(assistant: Assistant) { + // Get MCP tools (Fix duplicate declaration) + let mcpTools: MCPTool[] = [] // Initialize as empty array + const allMcpServers = store.getState().mcp.servers || [] + const activedMcpServers = allMcpServers.filter((s) => s.isActive) + const assistantMcpServers = assistant.mcpServers || [] + + const enabledMCPs = activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id)) + + if (enabledMCPs && enabledMCPs.length > 0) { try { - const result = await fetchSearchSummary({ - messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage], - assistant: summaryAssistant - }) - - if (!result) return getFallbackResult() - - const extracted = extractInfoFromXML(result.getText()) - // 根据需求过滤结果 - return { - websearch: needWebExtract ? extracted?.websearch : undefined, - knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined - } - } catch (e: any) { - console.error('extract error', e) - if (isAbortError(e)) throw e - return getFallbackResult() - } - } - - const getFallbackResult = (): ExtractResults => { - const fallbackContent = getMainTextContent(lastUserMessage) - return { - websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined, - knowledge: shouldKnowledgeSearch - ? { - question: [fallbackContent || 'search'], - rewrite: fallbackContent - } - : undefined - } - } - - // --- Web Search Function --- - const searchTheWeb = async (extractResults: ExtractResults | undefined): Promise => { - if (!shouldWebSearch) return - - // Add check for extractResults existence early - if (!extractResults?.websearch) { - console.warn('searchTheWeb called without valid extractResults.websearch') - return - } - - if (extractResults.websearch.question[0] === 'not_needed') return - - // Add check for assistant.model before using it - if (!assistant.model) { - console.warn('searchTheWeb called without assistant.model') - return undefined - } - - try { - // Use the consolidated processWebsearch function - WebSearchService.createAbortSignal(lastUserMessage.id) - return { - results: await WebSearchService.processWebsearch(webSearchProvider!, extractResults), - source: WebSearchSource.WEBSEARCH - } - } catch (error) { - if (isAbortError(error)) throw error - console.error('Web search failed:', error) - return - } - } - - // --- Knowledge Base Search Function --- - const searchKnowledgeBase = async ( - extractResults: ExtractResults | undefined - ): Promise => { - if (!hasKnowledgeBase) return - - // 知识库搜索条件 - let searchCriteria: { question: string[]; rewrite: string } - if (knowledgeRecognition === 'off') { - const directContent = getMainTextContent(lastUserMessage) - searchCriteria = { question: [directContent || 'search'], rewrite: directContent } - } else { - // auto mode - if (!extractResults?.knowledge) { - console.warn('searchKnowledgeBase: No valid search criteria in auto mode') - return - } - searchCriteria = extractResults.knowledge - } - - if (searchCriteria.question[0] === 'not_needed') return - - try { - const tempExtractResults: ExtractResults = { - websearch: undefined, - knowledge: searchCriteria - } - // Attempt to get knowledgeBaseIds from the main text block - // NOTE: This assumes knowledgeBaseIds are ONLY on the main text block - // NOTE: processKnowledgeSearch needs to handle undefined ids gracefully - // const mainTextBlock = mainTextBlocks - // ?.map((blockId) => store.getState().messageBlocks.entities[blockId]) - // .find((block) => block?.type === MessageBlockType.MAIN_TEXT) as MainTextMessageBlock | undefined - return await processKnowledgeSearch(tempExtractResults, knowledgeBaseIds) - } catch (error) { - console.error('Knowledge base search failed:', error) - return - } - } - - // --- Execute Extraction and Searches --- - let extractResults: ExtractResults | undefined - - try { - // 根据配置决定是否需要提取 - if (shouldWebSearch || hasKnowledgeBase) { - extractResults = await extract() - Logger.log('[fetchExternalTool] Extraction results:', extractResults) - } - - let webSearchResponseFromSearch: WebSearchResponse | undefined - let knowledgeReferencesFromSearch: KnowledgeReference[] | undefined - - // 并行执行搜索 - if (shouldWebSearch || shouldKnowledgeSearch) { - ;[webSearchResponseFromSearch, knowledgeReferencesFromSearch] = await Promise.all([ - searchTheWeb(extractResults), - searchKnowledgeBase(extractResults) - ]) - } - - // 存储搜索结果 - if (lastUserMessage) { - if (webSearchResponseFromSearch) { - window.keyv.set(`web-search-${lastUserMessage.id}`, webSearchResponseFromSearch) - } - if (knowledgeReferencesFromSearch) { - window.keyv.set(`knowledge-search-${lastUserMessage.id}`, knowledgeReferencesFromSearch) - } - } - - // 发送工具执行完成通知 - if (willUseTools) { - onChunkReceived({ - type: ChunkType.EXTERNEL_TOOL_COMPLETE, - external_tool: { - webSearch: webSearchResponseFromSearch, - knowledge: knowledgeReferencesFromSearch + const toolPromises = enabledMCPs.map>(async (mcpServer) => { + try { + const tools = await window.api.mcp.listTools(mcpServer) + return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name)) + } catch (error) { + console.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error) + return [] } }) + const results = await Promise.allSettled(toolPromises) + mcpTools = results + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value) + .flat() + } catch (toolError) { + console.error('Error fetching MCP tools:', toolError) } - - // Get MCP tools (Fix duplicate declaration) - let mcpTools: MCPTool[] = [] // Initialize as empty array - const allMcpServers = store.getState().mcp.servers || [] - const activedMcpServers = allMcpServers.filter((s) => s.isActive) - const assistantMcpServers = assistant.mcpServers || [] - - const enabledMCPs = activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id)) - - if (enabledMCPs && enabledMCPs.length > 0) { - try { - const toolPromises = enabledMCPs.map>(async (mcpServer) => { - try { - const tools = await window.api.mcp.listTools(mcpServer) - return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name)) - } catch (error) { - console.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error) - return [] - } - }) - const results = await Promise.allSettled(toolPromises) - mcpTools = results - .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') - .map((result) => result.value) - .flat() - } catch (toolError) { - console.error('Error fetching MCP tools:', toolError) - } - } - - return { mcpTools } - } catch (error) { - if (isAbortError(error)) throw error - console.error('Tool execution failed:', error) - - // 发送错误状态 - if (willUseTools) { - onChunkReceived({ - type: ChunkType.EXTERNEL_TOOL_COMPLETE, - external_tool: { - webSearch: undefined, - knowledge: undefined - } - }) - } - - return { mcpTools: [] } } + + return mcpTools } export async function fetchChatCompletion({ messages, assistant, + options, onChunkReceived }: { - messages: Message[] + messages: StreamTextParams['messages'] assistant: Assistant - onChunkReceived: (chunk: Chunk) => void - // TODO - // onChunkStatus: (status: 'searching' | 'processing' | 'success' | 'error') => void -}) { - console.log('fetchChatCompletion', messages, assistant) - - const provider = getAssistantProvider(assistant) - const AI = new AiProviderNew(provider) - - // Make sure that 'Clear Context' works for all scenarios including external tool and normal chat. - messages = filterContextMessages(messages) - - const lastUserMessage = findLast(messages, (m) => m.role === 'user') - const lastAnswer = findLast(messages, (m) => m.role === 'assistant') - if (!lastUserMessage) { - console.error('fetchChatCompletion returning early: Missing lastUserMessage or lastAnswer') - return + options: { + signal?: AbortSignal + timeout?: number + headers?: Record } - // try { - // NOTE: The search results are NOT added to the messages sent to the AI here. - // They will be retrieved and used by the messageThunk later to create CitationBlocks. - const { mcpTools } = await fetchExternalTool(lastUserMessage, assistant, onChunkReceived, lastAnswer) - const model = assistant.model || getDefaultModel() + onChunkReceived: (chunk: Chunk) => void +}) { + const provider = getAssistantProvider(assistant) + const AI = new AiProviderNew(provider, onChunkReceived) - const { maxTokens, contextCount } = getAssistantSettings(assistant) + const mcpTools = await fetchMcpTools(assistant) - const filteredMessages = filterUsefulMessages(messages) - - const _messages = filterUserRoleStartMessages( - filterEmptyMessages(filterContextMessages(takeRight(filteredMessages, contextCount + 2))) // 取原来几个provider的最大值 - ) - - const enableReasoning = - ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) && - assistant.settings?.reasoning_effort !== undefined) || - (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model))) - - const enableWebSearch = - (assistant.enableWebSearch && isWebSearchModel(model)) || - isOpenRouterBuiltInWebSearchModel(model) || - model.id.includes('sonar') || - false - - const enableGenerateImage = - isGenerateImageModel(model) && (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage : true) + // 使用 transformParameters 模块构建参数 + const { params: aiSdkParams, modelId } = await buildStreamTextParams(messages, assistant, { + mcpTools: mcpTools, + requestOptions: options + }) // --- Call AI Completions --- onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED }) - if (enableWebSearch) { - onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS }) - } - await AI.completions( - { - callType: 'chat', - messages: _messages, - assistant, - onChunk: onChunkReceived, - mcpTools: mcpTools, - maxTokens, - streamOutput: assistant.settings?.streamOutput || false, - enableReasoning, - enableWebSearch, - enableGenerateImage - }, - { - streamOutput: assistant.settings?.streamOutput || false - } - ) + await AI.completions(modelId, aiSdkParams, onChunkReceived) } interface FetchTranslateProps { diff --git a/src/renderer/src/services/ConversationService.ts b/src/renderer/src/services/ConversationService.ts new file mode 100644 index 0000000000..031f164c89 --- /dev/null +++ b/src/renderer/src/services/ConversationService.ts @@ -0,0 +1,34 @@ +import { StreamTextParams } from '@cherrystudio/ai-core' +import { convertMessagesToSdkMessages } from '@renderer/aiCore/transformParameters' +import { Assistant, Message } from '@renderer/types' +import { isEmpty, takeRight } from 'lodash' + +import { getAssistantSettings, getDefaultModel } from './AssistantService' +import { + filterContextMessages, + filterEmptyMessages, + filterUsefulMessages, + filterUserRoleStartMessages +} from './MessagesService' + +export class ConversationService { + static async prepareMessagesForLlm(messages: Message[], assistant: Assistant): Promise { + const { contextCount } = getAssistantSettings(assistant) + // This logic is extracted from the original ApiService.fetchChatCompletion + const contextMessages = filterContextMessages(messages) + const filteredMessages = filterUsefulMessages(contextMessages) + // Take the last `contextCount` messages, plus 2 to allow for a final user/assistant exchange. + const finalMessages = filterUserRoleStartMessages( + filterEmptyMessages(takeRight(filteredMessages, contextCount + 2)) + ) + return await convertMessagesToSdkMessages(finalMessages, assistant.model || getDefaultModel()) + } + + static needsWebSearch(assistant: Assistant): boolean { + return !!assistant.webSearchProviderId + } + + static needsKnowledgeSearch(assistant: Assistant): boolean { + return !isEmpty(assistant.knowledge_bases) + } +} diff --git a/src/renderer/src/services/OrchestrateService.ts b/src/renderer/src/services/OrchestrateService.ts new file mode 100644 index 0000000000..60a5339151 --- /dev/null +++ b/src/renderer/src/services/OrchestrateService.ts @@ -0,0 +1,54 @@ +import { Assistant, Message } from '@renderer/types' +import { Chunk, ChunkType } from '@renderer/types/chunk' + +import { fetchChatCompletion } from './ApiService' +import { ConversationService } from './ConversationService' + +/** + * The request object for handling a user message. + */ +export interface OrchestrationRequest { + messages: Message[] + assistant: Assistant + options: { + signal?: AbortSignal + timeout?: number + headers?: Record + } +} + +/** + * The OrchestrationService is responsible for orchestrating the different services + * to handle a user's message. It contains the core logic of the application. + */ +export class OrchestrationService { + constructor() { + // In the future, this could be a singleton, but for now, a new instance is fine. + // this.conversationService = new ConversationService() + } + + /** + * This is the core method to handle user messages. + * It takes the message context and an events object for callbacks, + * and orchestrates the call to the LLM. + * The logic is moved from `messageThunk.ts`. + * @param request The orchestration request containing messages and assistant info. + * @param events A set of callbacks to report progress and results to the UI layer. + */ + async handleUserMessage(request: OrchestrationRequest, onChunkReceived: (chunk: Chunk) => void) { + const { messages, assistant } = request + + try { + const llmMessages = await ConversationService.prepareMessagesForLlm(messages, assistant) + + await fetchChatCompletion({ + messages: llmMessages, + assistant: assistant, + options: request.options, + onChunkReceived + }) + } catch (error: any) { + onChunkReceived({ type: ChunkType.ERROR, error }) + } + } +} diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts index afab876195..f02aef5b95 100644 --- a/src/renderer/src/store/thunk/messageThunk.ts +++ b/src/renderer/src/store/thunk/messageThunk.ts @@ -1,9 +1,9 @@ import db from '@renderer/databases' import { autoRenameTopic } from '@renderer/hooks/useTopic' -import { fetchChatCompletion } from '@renderer/services/ApiService' import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' import FileManager from '@renderer/services/FileManager' import { NotificationService } from '@renderer/services/NotificationService' +import { OrchestrationService } from '@renderer/services/OrchestrateService' import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/services/StreamProcessingService' import { estimateMessagesUsage } from '@renderer/services/TokenService' import store from '@renderer/store' @@ -829,15 +829,26 @@ const fetchAndProcessAssistantResponseImpl = async ( const streamProcessorCallbacks = createStreamProcessor(callbacks) const startTime = Date.now() - await fetchChatCompletion({ - messages: messagesForContext, - assistant: assistant, - onChunkReceived: streamProcessorCallbacks - }) + const orchestrationService = new OrchestrationService() + await orchestrationService.handleUserMessage( + { + messages: messagesForContext, + assistant, + options: { + timeout: 30000 + } + }, + streamProcessorCallbacks + ) } catch (error: any) { - console.error('Error fetching chat completion:', error) - if (assistantMessage) { - callbacks.onError?.(error) + console.error('Error in fetchAndProcessAssistantResponseImpl:', error) + // The main error handling is now delegated to OrchestrationService, + // which calls the `onError` callback. This catch block is for + // any errors that might occur outside of that orchestration flow. + if (assistantMessage && callbacks.onError) { + callbacks.onError(error) + } else { + // Fallback if callbacks are not even defined yet throw error } } diff --git a/tsconfig.web.json b/tsconfig.web.json index 11b6799d61..e7d6f2daef 100644 --- a/tsconfig.web.json +++ b/tsconfig.web.json @@ -16,7 +16,7 @@ "@renderer/*": ["src/renderer/src/*"], "@shared/*": ["packages/shared/*"], "@types": ["src/renderer/src/types/index.ts"], - "@cherry-studio/ai-core": ["packages/aiCore/src/"] + "@cherrystudio/ai-core": ["packages/aiCore/src/"] } } } diff --git a/yarn.lock b/yarn.lock index 335fcdb547..75aa86312b 100644 --- a/yarn.lock +++ b/yarn.lock @@ -960,9 +960,9 @@ __metadata: languageName: node linkType: hard -"@cherry-studio/ai-core@workspace:*, @cherry-studio/ai-core@workspace:packages/aiCore": +"@cherrystudio/ai-core@workspace:*, @cherrystudio/ai-core@workspace:packages/aiCore": version: 0.0.0-use.local - resolution: "@cherry-studio/ai-core@workspace:packages/aiCore" + resolution: "@cherrystudio/ai-core@workspace:packages/aiCore" dependencies: "@ai-sdk/amazon-bedrock": "npm:^2.2.10" "@ai-sdk/anthropic": "npm:^1.2.12" @@ -6392,7 +6392,7 @@ __metadata: "@agentic/tavily": "npm:^7.3.3" "@ant-design/v5-patch-for-react-19": "npm:^1.0.3" "@anthropic-ai/sdk": "npm:^0.41.0" - "@cherry-studio/ai-core": "workspace:*" + "@cherrystudio/ai-core": "workspace:*" "@cherrystudio/embedjs": "npm:^0.1.31" "@cherrystudio/embedjs-libsql": "npm:^0.1.31" "@cherrystudio/embedjs-loader-csv": "npm:^0.1.31"