diff --git a/.vscode/launch.json b/.vscode/launch.json index efacfda6f8..0b6b9a6499 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -7,7 +7,6 @@ "request": "launch", "cwd": "${workspaceRoot}", "runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite", - "runtimeVersion": "20", "windows": { "runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite.cmd" }, diff --git a/docs/technical/how-to-write-middlewares.md b/docs/technical/how-to-write-middlewares.md new file mode 100644 index 0000000000..9f3b691309 --- /dev/null +++ b/docs/technical/how-to-write-middlewares.md @@ -0,0 +1,214 @@
+# How to Write Middlewares for AI Providers
+
+This document shows developers how to create and integrate custom middlewares for our AI Provider framework. Middlewares provide a powerful, flexible way to enhance, modify, or observe Provider method calls, for example logging, caching, request/response transformation, and error handling.
+
+## Architecture Overview
+
+Our middleware architecture borrows the three-stage (curried) design of Redux and combines it with a JavaScript Proxy that dynamically applies middlewares to Provider methods.
+
+- **Proxy**: Intercepts calls to Provider methods and routes them into the middleware chain.
+- **Middleware chain**: A series of middleware functions executed in order. Each middleware can process the request/response and then pass control to the next middleware in the chain, or in some cases terminate the chain early.
+- **Context**: An object passed between middlewares that carries information about the current call (method name, original arguments, the Provider instance, and any custom data added by middlewares).
+
+## Middleware Types
+
+Two kinds of middleware are currently supported. They share a similar structure but target different scenarios:
+
+1. **`CompletionsMiddleware`**: Designed specifically for the `completions` method. This is the most commonly used type because it allows fine-grained control over the model's core chat/text-generation behavior.
+2. **`ProviderMethodMiddleware`**: A generic middleware that can be applied to any other Provider method (for example `translate` or `summarize`, if those methods are also wrapped by the middleware system).
+
+## Writing a `CompletionsMiddleware`
+
+The basic signature of a `CompletionsMiddleware` (as a TypeScript type) is:
+
+```typescript
+import { AiProviderMiddlewareCompletionsContext, CompletionsParams, MiddlewareAPI } from './AiProviderMiddlewareTypes' // adjust to the actual path of the type definitions
+
+export type CompletionsMiddleware = (
+  api: MiddlewareAPI
+) => (
+  next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any> // next resolves with the raw SDK response or the downstream middleware's result
+) => (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<void> // the innermost function usually returns Promise<void>, because results are delivered through onChunk or context side effects
+```
+
+Let's break down this three-stage structure:
+
+1. **First layer `(api) => { ... }`**:
+
+   - Receives an `api` object.
+   - The `api` object provides:
+     - `api.getContext()`: Returns the context object for the current call (`AiProviderMiddlewareCompletionsContext`).
+     - `api.getOriginalArgs()`: Returns the original argument array passed to `completions` (i.e. `[CompletionsParams]`).
+     - `api.getProviderId()`: Returns the ID of the current Provider.
+     - `api.getProviderInstance()`: Returns the original Provider instance.
+   - This layer is typically used for one-time setup or for obtaining required services/configuration. It returns the second-layer function.
+
+2. **Second layer `(next) => { ... }`**:
+
+   - Receives a `next` function.
+   - `next` represents the next link in the middleware chain. Calling `next(context, params)` passes control to the next middleware or, if the current middleware is the last one, invokes the core Provider logic (e.g. the actual SDK call).
+   - `next` receives the current `context` and `params` (both may already have been modified by upstream middlewares).
+   - **Important**: `next` usually returns a `Promise`. For `completions`, if `next` ends up calling the actual SDK, it resolves with the raw SDK response (for example OpenAI's stream object or a JSON object), and you are responsible for handling it.
+   - This layer returns the third (and most important) function.
+
+3. **Third layer `(context, params) => { ... }`**:
+   - This is where the middleware's main logic runs.
+   - It receives the current `context` (`AiProviderMiddlewareCompletionsContext`) and `params` (`CompletionsParams`).
+   - Inside this function you can:
+     - **Before calling `next`**:
+       - Read or modify `params`, e.g. add default parameters or convert message formats.
+       - Read or modify `context`, e.g. record a timestamp for computing latency later.
+       - Run checks and, if they fail, return or throw without calling `next` (e.g. parameter validation failure).
+     - **Call `await next(context, params)`**:
+       - This is the key step that hands control downstream.
+       - The return value of `next` is the raw SDK response or the downstream middleware's result; handle it as appropriate (e.g. start consuming it if it is a stream).
+     - **After calling `next`**:
+       - Process the result of `next`. For example, if `next` returned a stream, you can start iterating over it here and emit chunks through `context.onChunk`.
+       - Act on changes to `context` or on `next`'s result, e.g. compute total duration, write logs.
+       - Modify the final result (although for `completions` the result is usually emitted through the `onChunk` side effect).
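+
+Before looking at a fuller example, here is the smallest middleware that follows this three-stage shape: it does nothing except pass the call through. This is a minimal sketch for orientation only; the import path and the exact `Promise` generics are assumptions, so adjust them to the real type definitions.
+
+```typescript
+import {
+  AiProviderMiddlewareCompletionsContext,
+  CompletionsMiddleware,
+  CompletionsParams,
+  MiddlewareAPI
+} from './AiProviderMiddlewareTypes' // assumed path
+
+export const createPassThroughMiddleware = (): CompletionsMiddleware => {
+  return (_api: MiddlewareAPI) => {
+    // Layer 1: one-time setup (nothing to do here)
+    return (next) => {
+      // Layer 2: receive the next link in the chain
+      return async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise<void> => {
+        // Layer 3: hand the unchanged context/params straight downstream
+        await next(context, params)
+      }
+    }
+  }
+}
+```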
+
+### Example: A Simple Logging Middleware
+
+```typescript
+import {
+  AiProviderMiddlewareCompletionsContext,
+  CompletionsMiddleware,
+  CompletionsParams,
+  MiddlewareAPI,
+  OnChunkFunction // assuming the OnChunkFunction type is exported
+} from './AiProviderMiddlewareTypes' // adjust path
+import { ChunkType } from '@renderer/types' // adjust path
+
+export const createSimpleLoggingMiddleware = (): CompletionsMiddleware => {
+  return (api: MiddlewareAPI) => {
+    // console.log(`[LoggingMiddleware] Initialized for provider: ${api.getProviderId()}`);
+
+    return (next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any>) => {
+      return async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise<void> => {
+        const startTime = Date.now()
+        // Get onChunk from the context (it originally comes from params.onChunk)
+        const onChunk: OnChunkFunction | undefined = context.onChunk
+
+        console.log(
+          `[LoggingMiddleware] Request for ${context.methodName} with params:`,
+          params.messages?.[params.messages.length - 1]?.content
+        )
+
+        try {
+          // Call the next middleware or the core logic.
+          // `rawSdkResponse` is the raw response from downstream (e.g. an OpenAIStream or a ChatCompletion object)
+          const rawSdkResponse = await next(context, params)
+
+          // This simple example does not process rawSdkResponse; it assumes a downstream middleware
+          // (such as StreamingResponseHandler) handles it and emits data through onChunk.
+          // If this logging middleware runs after StreamingResponseHandler, the stream has already been consumed.
+          // If it runs before, it would have to handle rawSdkResponse itself or make sure downstream does.
+
+          const duration = Date.now() - startTime
+          console.log(`[LoggingMiddleware] Request for ${context.methodName} completed in ${duration}ms.`)
+
+          // Assume downstream has already emitted all data through onChunk.
+          // If this middleware were the end of the chain and had to guarantee that BLOCK_COMPLETE is sent,
+          // it would need more elaborate logic to track when all data has been emitted.
+        } catch (error) {
+          const duration = Date.now() - startTime
+          console.error(`[LoggingMiddleware] Request for ${context.methodName} failed after ${duration}ms:`, error)
+
+          // If onChunk is available, try to emit an error chunk
+          if (onChunk) {
+            onChunk({
+              type: ChunkType.ERROR,
+              error: { message: (error as Error).message, name: (error as Error).name, stack: (error as Error).stack }
+            })
+            // Consider whether a BLOCK_COMPLETE chunk is also needed to close the stream
+            onChunk({ type: ChunkType.BLOCK_COMPLETE, response: {} })
+          }
+          throw error // rethrow so an upstream or global error handler can catch it
+        }
+      }
+    }
+  }
+}
+```
+
+### Why `AiProviderMiddlewareCompletionsContext` Matters
+
+`AiProviderMiddlewareCompletionsContext` is the core vehicle for passing state and data between middlewares. It typically contains:
+
+- `methodName`: The name of the method being called (always `'completions'`).
+- `originalArgs`: The original argument array passed to `completions`.
+- `providerId`: The Provider's ID.
+- `_providerInstance`: The Provider instance.
+- `onChunk`: The callback passed in through the original `CompletionsParams`, used to emit data chunks while streaming. **All middlewares should emit data through `context.onChunk`.**
+- `messages`, `model`, `assistant`, `mcpTools`: Frequently used fields extracted from the original `CompletionsParams` for convenient access.
+- **Custom fields**: Middlewares can add their own fields to the context for later middlewares to use. For example, a caching middleware might set `context.cacheHit = true`.
+
+**Key point**: When you modify `params` or `context` inside a middleware, those modifications propagate to downstream middlewares (as long as they are made before `next` is called).
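+
+The sketch below illustrates that propagation, assuming the context type tolerates extra fields (the `_requestStartTime` field name is made up for this example): one middleware writes a value into the context before calling `next`, and a middleware later in the chain reads it after its downstream work has finished.
+
+```typescript
+import {
+  AiProviderMiddlewareCompletionsContext,
+  CompletionsMiddleware,
+  CompletionsParams,
+  MiddlewareAPI
+} from './AiProviderMiddlewareTypes' // assumed path
+
+// Runs early in the chain: stamps the context before handing the call downstream.
+export const createTimestampMiddleware = (): CompletionsMiddleware => {
+  return (_api: MiddlewareAPI) => (next) =>
+    async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise<void> => {
+      const ctx = context as any
+      ctx._requestStartTime = Date.now() // custom field, visible to every later middleware
+      await next(context, params)
+    }
+}
+
+// Runs later in the chain: reads the custom field once the downstream work is done.
+export const createDurationReportMiddleware = (): CompletionsMiddleware => {
+  return (_api: MiddlewareAPI) => (next) =>
+    async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise<void> => {
+      await next(context, params)
+      const startedAt = (context as any)._requestStartTime as number | undefined
+      if (startedAt !== undefined) {
+        console.log(`[DurationReport] completions took ${Date.now() - startedAt}ms`)
+      }
+    }
+}
+```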
+
+### Middleware Order
+
+The order in which middlewares execute matters a great deal. The order in which they are listed in the `AiProviderMiddlewareConfig` array is the order in which they run.
+
+- The request passes through the first middleware first, then the second, and so on.
+- The response (or the result of calling `next`) "bubbles" back in the reverse order.
+
+For example, if the chain is `[AuthMiddleware, CacheMiddleware, LoggingMiddleware]`:
+
+1. `AuthMiddleware` runs its "before `next`" logic first.
+2. Then `CacheMiddleware` runs its "before `next`" logic.
+3. Then `LoggingMiddleware` runs its "before `next`" logic.
+4. The core SDK call (or the end of the chain) executes.
+5. `LoggingMiddleware` receives the result first and runs its "after `next`" logic.
+6. Then `CacheMiddleware` receives the result (with a context possibly modified by LoggingMiddleware) and runs its "after `next`" logic (e.g. storing the result).
+7. Finally `AuthMiddleware` receives the result and runs its "after `next`" logic.
+
+### Registering Middlewares
+
+Middlewares are registered in `src/renderer/src/providers/middleware/register.ts` (or a similar configuration file).
+
+```typescript
+// register.ts
+import { AiProviderMiddlewareConfig } from './AiProviderMiddlewareTypes'
+import { createSimpleLoggingMiddleware } from './common/SimpleLoggingMiddleware' // assuming you created this file
+import { createCompletionsLoggingMiddleware } from './common/CompletionsLoggingMiddleware' // already exists
+
+const middlewareConfig: AiProviderMiddlewareConfig = {
+  completions: [
+    createSimpleLoggingMiddleware(), // your new middleware
+    createCompletionsLoggingMiddleware() // the existing logging middleware
+    // ... other completions middlewares
+  ],
+  methods: {
+    // translate: [createGenericLoggingMiddleware()],
+    // ... middlewares for other methods
+  }
+}
+
+export default middlewareConfig
+```
+
+### Best Practices
+
+1. **Single responsibility**: Each middleware should focus on one specific concern (logging, caching, transforming a particular piece of data, ...).
+2. **No side effects (as far as possible)**: Apart from the explicit side effects expressed through `context` or `onChunk`, avoid mutating global state or producing other hidden side effects.
+3. **Error handling**:
+   - Use `try...catch` inside the middleware to handle errors that may occur.
+   - Decide whether to handle an error yourself (e.g. emit an error chunk through `onChunk`) or rethrow it to the upstream.
+   - If you rethrow, make sure the error object carries enough information.
+4. **Performance**: Middlewares add overhead to every request. Avoid expensive synchronous work inside them, and make IO-heavy operations asynchronous.
+5. **Configurability**: Make the middleware's behavior adjustable through parameters or configuration. For example, a logging middleware can accept a log-level parameter.
+6. **Context management**:
+   - Add data to the `context` sparingly. Do not pollute it or attach very large objects.
+   - Be explicit about the purpose and lifetime of any field you add to the `context`.
+7. **Calling `next`**:
+   - Unless you have a good reason to terminate the request early (e.g. a cache hit or an authorization failure; see the sketch after this list), **always make sure to call `await next(context, params)`**. Otherwise the downstream middlewares and the core logic will not run.
+   - Understand what `next` returns and handle it correctly, especially when it is a stream: you are responsible for consuming it or handing it to another component/middleware that can.
+8. **Clear naming**: Give your middlewares and the functions they create descriptive names.
+9. **Documentation and comments**: Comment complex middleware logic, explaining how it works and why it exists.
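+
+As an illustration of early termination (best practice 7), here is a sketch of a cache middleware that skips `next` entirely on a hit. It is not the project's real caching logic: the in-memory `Map`, the key derivation, and the idea of replaying a cached final text through `onChunk` are all assumptions made for the example.
+
+```typescript
+import {
+  AiProviderMiddlewareCompletionsContext,
+  CompletionsMiddleware,
+  CompletionsParams,
+  MiddlewareAPI
+} from './AiProviderMiddlewareTypes' // assumed path
+import { ChunkType } from '@renderer/types' // adjust path
+
+const completionsCache = new Map<string, string>() // stand-in store: prompt key -> final text
+
+export const createNaiveCacheMiddleware = (): CompletionsMiddleware => {
+  return (_api: MiddlewareAPI) => (next) =>
+    async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise<void> => {
+      const key = JSON.stringify(params.messages ?? []) // naive cache key, for illustration only
+      const cached = completionsCache.get(key)
+
+      if (cached !== undefined && context.onChunk) {
+        // Cache hit: emit the stored result and end the chain early (next is never called).
+        const ctx = context as any
+        ctx.cacheHit = true
+        context.onChunk({ type: ChunkType.TEXT_DELTA, text: cached })
+        context.onChunk({ type: ChunkType.BLOCK_COMPLETE, response: {} })
+        return
+      }
+
+      // Cache miss: let the rest of the chain run. A real implementation would also
+      // accumulate the streamed text here and write it back into the cache afterwards.
+      await next(context, params)
+    }
+}
+```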
+
+### Debugging Tips
+
+- Use `console.log` or a debugger at key points in the middleware to inspect the state of `params` and `context` and the return value of `next`.
+- Temporarily simplify the middleware chain, keeping only the middleware you are debugging plus the simplest core logic, to isolate the problem.
+- Write unit tests to verify each middleware's behavior in isolation.
+
+By following these guidelines you should be able to create powerful, maintainable middlewares for our system. If you have questions or need further help, ask the team.
 diff --git a/packages/shared/config/constant.ts b/packages/shared/config/constant.ts index cfba46df70..5a3465f648 100644 --- a/packages/shared/config/constant.ts +++ b/packages/shared/config/constant.ts @@ -408,3 +408,4 @@ export enum FeedUrl { PRODUCTION = 'https://releases.cherry-ai.com', EARLY_ACCESS = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download' } +export const defaultTimeout = 5 * 1000 * 60 diff --git a/src/main/ipc.ts b/src/main/ipc.ts index a8f02fe543..466c5f35a8 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -4,6 +4,7 @@ import { arch } from 'node:os' import { isMac, isWin } from '@main/constant' import { getBinaryPath, isBinaryExists, runInstallScript } from '@main/utils/process' import { handleZoomFactor } from '@main/utils/zoom' +import { FeedUrl } from '@shared/config/constant' import { IpcChannel } from '@shared/IpcChannel' import { Shortcut, ThemeMode } from '@types' import { BrowserWindow, ipcMain, session, shell } from 'electron' @@ -34,7 +35,6 @@ import { calculateDirectorySize, getResourcePath } from './utils' import { decrypt, encrypt } from './utils/aes' import { getCacheDir, getConfigDir, getFilesDir } from './utils/file' import { compress, decompress } from './utils/zip' -import { FeedUrl } from '@shared/config/constant' const fileManager = new FileStorage() const backupManager = new BackupManager() diff --git a/src/renderer/src/aiCore/AI_CORE_DESIGN.md b/src/renderer/src/aiCore/AI_CORE_DESIGN.md new file mode 100644 index 0000000000..611c582d83 --- /dev/null +++ b/src/renderer/src/aiCore/AI_CORE_DESIGN.md @@ -0,0 +1,223 @@
+# Cherry Studio AI Provider Technical Architecture (New Design)
+
+## 1. Core Design Philosophy and Goals
+
+This architecture restructures Cherry Studio's AI Provider layer (now called `aiCore`) to achieve the following goals:
+
+- **Clear responsibilities**: Clearly delimit each component's job and reduce coupling.
+- **High reuse**: Maximize reuse of business logic and shared processing logic, and eliminate duplicated code.
+- **Easy to extend**: Make it quick and easy to integrate new AI Providers (LLM vendors) and to add new AI features (translation, summarization, image generation, ...).
+- **Easy to maintain**: Keep individual components simple, readable, and maintainable.
+- **Standardization**: Unify internal data flow and interfaces so that differences between Providers are easier to handle.
+
+The core idea is to cleanly separate a pure **SDK adaptation layer (`XxxApiClient`)**, a **shared-logic processing and parsing layer (the middlewares)**, and a **unified business entry layer (`AiCoreService`)**.
+
+## 2. Core Components
+
+### 2.1. `aiCore` (formerly the `AiProvider` folder)
+
+This is the core module for all AI functionality.
+
+#### 2.1.1. `XxxApiClient` (e.g. `aiCore/clients/openai/OpenAIApiClient.ts`)
+
+- **Responsibility**: A pure adaptation layer over a specific AI Provider SDK.
+  - **Parameter adaptation**: Converts the application's internal, unified `CoreRequest` object (see below) into the request parameter format required by the specific SDK.
+  - **Basic response conversion**: Converts the raw chunks returned by the SDK (`RawSdkChunk`, e.g. `OpenAI.Chat.Completions.ChatCompletionChunk`) into the most basic, direct application-level `Chunk` objects (defined in `src/renderer/src/types/chunk.ts`).
+    - For example: the SDK's `delta.content` -> `TextDeltaChunk`; the SDK's `delta.reasoning_content` -> `ThinkingDeltaChunk`; the SDK's `delta.tool_calls` -> `RawToolCallChunk` (carrying the raw tool-call data).
+  - **Important**: `XxxApiClient` does **not** handle complex structures embedded in the text content, such as `<think>` or `<tool_use>` tags.
+- **Characteristics**: Extremely lightweight and small, which makes it easy to implement and maintain adapters for new Providers.
+
+#### 2.1.2. `ApiClient.ts` (or the core interface of `BaseApiClient.ts`)
+
+- Defines the interface every `XxxApiClient` must implement, for example:
+  - `getSdkInstance(): Promise<TSdkInstance> | TSdkInstance`
+  - `getRequestTransformer(): RequestTransformer`
+  - `getResponseChunkTransformer(): ResponseChunkTransformer`
+  - Other optional, Provider-specific helper methods (such as tool-call conversion).
+
+#### 2.1.3. `ApiClientFactory.ts`
+
+- Dynamically creates and returns the appropriate `XxxApiClient` instance based on the Provider configuration.
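+
+Here is a minimal sketch of the `ApiClient` contract described in 2.1.2, with the SDK-specific type parameters collapsed into placeholder aliases for readability. The real `BaseApiClient` in `aiCore/clients/BaseApiClient.ts` is generic over the SDK instance, params, chunk, and tool types, so treat the shapes below as illustrative assumptions.
+
+```typescript
+// Placeholder aliases: the real code uses per-provider generics (SdkInstance, SdkParams, SdkRawChunk, ...).
+type TSdkInstance = unknown
+type TSdkParams = unknown
+type TRawChunk = unknown
+
+interface RequestTransformer {
+  // Turns the app-internal CoreRequest into SDK-specific request params.
+  transform(coreRequest: unknown, ...rest: unknown[]): Promise<{ payload: TSdkParams }> | { payload: TSdkParams }
+}
+
+type ResponseChunkTransformer = () => {
+  // Converts one raw SDK chunk into zero or more app-level Chunk objects (enqueued on the controller).
+  transform(rawChunk: TRawChunk, controller: TransformStreamDefaultController<unknown>): void | Promise<void>
+}
+
+interface ApiClient {
+  getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
+  getRequestTransformer(): RequestTransformer
+  getResponseChunkTransformer(): ResponseChunkTransformer
+}
+```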
+
+#### 2.1.4. `AiCoreService.ts` (`aiCore/index.ts`)
+
+- **Responsibility**: The single entry point for all AI-related business features.
+  - Exposes high-level, application-facing methods, for example:
+    - `executeCompletions(params: CompletionsParams): Promise`
+    - `translateText(params: TranslateParams): Promise`
+    - `summarizeText(params: SummarizeParams): Promise`
+    - and possibly `generateImage(prompt: string): Promise` and others in the future.
+  - **Returns a `Promise`**: Each service method returns a `Promise` that resolves, once the whole (possibly streaming) operation has finished, with an object containing all aggregated results (the full text, tool-call details, the final `usage`/`metrics`, and so on).
+  - **Supports streaming callbacks**: The method parameters (e.g. `CompletionsParams`) still include an `onChunk` callback used to push `Chunk` data to the caller in real time for streaming UI updates.
+  - **Encapsulates task-specific prompt engineering**:
+    - For example, `translateText` internally builds a `CoreRequest` containing the translation instructions.
+  - **Orchestrates and invokes the middleware chain**: Through an internal `MiddlewareBuilder` (see `middleware/BUILDER_USAGE.md`), it dynamically assembles the right middleware sequence for the business method and parameters being called, then executes it through composition functions such as `applyCompletionsMiddlewares`.
+  - Obtains the `ApiClient` instance and injects it into the `Context` at the top of the middleware chain.
+  - **Passes the `Promise`'s `resolve` and `reject` functions into the middleware chain** (via the `Context`), so that `FinalChunkConsumerAndNotifierMiddleware` can settle the `Promise` when the operation completes or fails.
+- **Benefits**:
+  - Business logic (prompt construction and flow control for translation, summarization, ...) is implemented once and works for every Provider integrated through an `ApiClient`.
+  - **Supports external orchestration**: Callers can `await` a service method to get the final aggregated result and feed it into a follow-up operation, making multi-step workflows straightforward.
+  - **Supports internal composition**: A service can itself `await` other atomic service methods to build more complex, composite features.
+
+#### 2.1.5. `coreRequestTypes.ts` (or `types.ts`)
+
+- Defines the core, Provider-agnostic internal request structures, for example:
+  - `CoreCompletionsRequest`: the normalized message list, model configuration, tool list, max tokens, whether to stream, etc.
+  - `CoreTranslateRequest`, `CoreSummarizeRequest`, and similar (if they differ substantially from `CoreCompletionsRequest`; otherwise reuse it with a task-type marker).
+
+### 2.2. `middleware`
+
+The middleware layer handles the shared logic and specific features in the request and response flow. Its design and usage follow the conventions defined in `middleware/BUILDER_USAGE.md`.
+
+**The key pieces are:**
+
+- **`MiddlewareBuilder`**: A generic class with a fluent API for dynamically building middleware chains. Starting from a base chain, it can conditionally add, insert, replace, or remove middlewares.
+- **`applyCompletionsMiddlewares`**: Takes the chain built by `MiddlewareBuilder` and executes it in order; used specifically for the completions flow.
+- **`MiddlewareRegistry`**: A central registry of all available middlewares, providing a unified way to access them.
+- **The individual middleware modules** (under the `common/`, `core/`, and `feat/` subdirectories).
+
+#### 2.2.1. `middlewareTypes.ts`
+
+- Defines the core middleware types, such as `AiProviderMiddlewareContext` (extended to include `_apiClientInstance` and `_coreRequest`), `MiddlewareAPI`, and `CompletionsMiddleware`.
+
+#### 2.2.2. Core middlewares (`middleware/core/`)
+
+- **`TransformCoreToSdkParamsMiddleware.ts`**: Calls `ApiClient.getRequestTransformer()` to convert the `CoreRequest` into SDK-specific parameters and stores them in the context.
+- **`RequestExecutionMiddleware.ts`**: Calls `ApiClient.getSdkInstance()` to obtain the SDK instance and performs the actual API call with the converted parameters, returning the raw SDK stream.
+- **`StreamAdapterMiddleware.ts`**: Adapts the various shapes of raw SDK stream (e.g. async iterators) into a `ReadableStream`.
+  - **`RawSdkChunk`**: The raw chunk format a specific provider SDK returns in a streaming response, before any application-level normalization (e.g. OpenAI's `ChatCompletionChunk`, parts of Gemini's `GenerateContentResponse`, etc.).
+- **`RawSdkChunkToAppChunkMiddleware.ts`**: (new) Consumes the `ReadableStream`, internally calls `ApiClient.getResponseChunkTransformer()` on every `RawSdkChunk` to convert it into one or more basic application-level `Chunk` objects, and outputs a `ReadableStream` of those.
+
+#### 2.2.3. Feature middlewares (`middleware/feat/`)
+
+These middlewares consume the relatively normalized `Chunk` stream produced by `ResponseTransformMiddleware` and handle the more complex logic.
+
+- **`ThinkingTagExtractionMiddleware.ts`**: Inspects `TextDeltaChunk`s, parses any `<think>...</think>` tags embedded in the text, and produces `ThinkingDeltaChunk` and `ThinkingCompleteChunk`.
+- **`ToolUseExtractionMiddleware.ts`**: Inspects `TextDeltaChunk`s, parses any embedded `<tool_use>...</tool_use>` tags, and produces the tool-call related chunks. If the `ApiClient` emitted native tool-call data, this middleware also converts it into the standard format.
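+
+As a rough illustration of what such a feature middleware does internally, here is a sketch of the tag-extraction step written as a `TransformStream` over application-level chunks. It is deliberately simplified and assumes the `<think>` and `</think>` markers each arrive unsplit inside a single text delta; the real `ThinkingTagExtractionMiddleware` also has to handle markers split across deltas and emit a `ThinkingCompleteChunk`.
+
+```typescript
+import { ChunkType } from '@renderer/types' // adjust path
+
+type AppChunk = { type: ChunkType; text?: string; [key: string]: unknown } // simplified stand-in for the real Chunk union
+
+export function createThinkTagExtractor(): TransformStream<AppChunk, AppChunk> {
+  let insideThink = false
+
+  return new TransformStream<AppChunk, AppChunk>({
+    transform(chunk, controller) {
+      // Pass every non-text chunk through untouched.
+      if (chunk.type !== ChunkType.TEXT_DELTA || typeof chunk.text !== 'string') {
+        controller.enqueue(chunk)
+        return
+      }
+
+      let remaining = chunk.text
+      while (remaining.length > 0) {
+        const marker = insideThink ? '</think>' : '<think>'
+        const idx = remaining.indexOf(marker)
+        const segment = idx === -1 ? remaining : remaining.slice(0, idx)
+
+        if (segment.length > 0) {
+          // Text before a <think> tag stays text; text before </think> is thinking content.
+          controller.enqueue({
+            type: insideThink ? ChunkType.THINKING_DELTA : ChunkType.TEXT_DELTA,
+            text: segment
+          })
+        }
+
+        if (idx === -1) break
+        insideThink = !insideThink
+        remaining = remaining.slice(idx + marker.length)
+      }
+    }
+  })
+}
+```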
+
+#### 2.2.4. Core processing middlewares (`middleware/core/`)
+
+- **`TransformCoreToSdkParamsMiddleware.ts`**: Calls `ApiClient.getRequestTransformer()` to convert the `CoreRequest` into SDK-specific parameters and stores them in the context.
+- **`SdkCallMiddleware.ts`**: Calls `ApiClient.getSdkInstance()` to obtain the SDK instance and performs the actual API call with the converted parameters, returning the raw SDK stream.
+- **`StreamAdapterMiddleware.ts`**: Adapts the various shapes of raw SDK stream into a standard stream format.
+- **`ResponseTransformMiddleware.ts`**: Converts the raw SDK response into standard application-level `Chunk` objects.
+- **`TextChunkMiddleware.ts`**: Handles the text-related chunk stream.
+- **`ThinkChunkMiddleware.ts`**: Handles the thinking-related chunk stream.
+- **`McpToolChunkMiddleware.ts`**: Handles the tool-call-related chunk stream.
+- **`WebSearchMiddleware.ts`**: Handles web-search-related logic.
+
+#### 2.2.5. Common middlewares (`middleware/common/`)
+
+- **`LoggingMiddleware.ts`**: Request and response logging.
+- **`AbortHandlerMiddleware.ts`**: Handles request abortion.
+- **`FinalChunkConsumerMiddleware.ts`**: Consumes the final `Chunk` stream and notifies the application layer in real time through the `context.onChunk` callback.
+  - **Accumulates data**: While streaming, it accumulates the key data such as text fragments, tool-call information, and `usage`/`metrics`.
+  - **Settles the `Promise`**: When the input stream ends, it completes the overall operation with the accumulated, aggregated result.
+  - When the stream ends, it emits a completion signal carrying the final accumulated information.
+
+### 2.3. `types/chunk.ts`
+
+- Defines the application-wide `Chunk` type and all of its variants. This includes the basic types (e.g. `TextDeltaChunk`, `ThinkingDeltaChunk`), the types carrying native SDK data (e.g. `RawToolCallChunk`, `RawFinishChunk`, intermediate products of the `ApiClient` conversion), and the feature-level types (e.g. `McpToolCallRequestChunk`, `WebSearchCompleteChunk`).
+
+## 3. Core Execution Flow (using `AiCoreService.executeCompletions` as the example)
+
+```markdown
+**Application layer (e.g. a UI component)**
+||
+\\/
+**`AiProvider.completions` (`aiCore/index.ts`)**
+(1. prepare ApiClient instance. 2. use `CompletionsMiddlewareBuilder.withDefaults()` to build middleware chain. 3. call `applyCompletionsMiddlewares`)
+||
+\\/
+**`applyCompletionsMiddlewares` (`middleware/composer.ts`)**
+(receives the built chain, the ApiClient instance, and the original SDK method, then executes the middlewares in order)
+||
+\\/
+**[ Pre-processing middlewares ]**
+(e.g. `FinalChunkConsumerMiddleware`, `TransformCoreToSdkParamsMiddleware`, `AbortHandlerMiddleware`)
+|| (the SDK request parameters are now ready in the Context)
+\\/
+**[ Processing middlewares ]**
+(e.g. `McpToolChunkMiddleware`, `WebSearchMiddleware`, `TextChunkMiddleware`, `ThinkingTagExtractionMiddleware`)
+|| (handle the various features and chunk types)
+\\/
+**[ SDK-call middlewares ]**
+(e.g. `ResponseTransformMiddleware`, `StreamAdapterMiddleware`, `SdkCallMiddleware`)
+|| (output: a normalized application-level Chunk stream)
+\\/
+**`FinalChunkConsumerMiddleware` (core)**
+(consumes the final `Chunk` stream, notifies the application layer through `context.onChunk`, and completes the operation when the stream ends)
+||
+\\/
+**`AiProvider.completions` returns a `Promise`**
+```
+
+## 4. Proposed File/Directory Structure
+
+```
+src/renderer/src/
+└── aiCore/
+    ├── clients/
+    │   ├── openai/
+    │   ├── gemini/
+    │   ├── anthropic/
+    │   ├── BaseApiClient.ts
+    │   ├── ApiClientFactory.ts
+    │   ├── AihubmixAPIClient.ts
+    │   ├── index.ts
+    │   └── types.ts
+    ├── middleware/
+    │   ├── common/
+    │   ├── core/
+    │   ├── feat/
+    │   ├── builder.ts
+    │   ├── composer.ts
+    │   ├── index.ts
+    │   ├── register.ts
+    │   ├── schemas.ts
+    │   ├── types.ts
+    │   └── utils.ts
+    ├── types/
+    │   ├── chunk.ts
+    │   └── ...
+    └── index.ts
+```
+
+## 5. Migration and Implementation Suggestions
+
+- **Small steps, iterate**: Refactor the core flow first (e.g. `completions`), then migrate the other features (`translate`, ...) and the other Providers step by step.
+- **Define the core types first**: The `CoreRequest`, `Chunk`, and `ApiClient` interfaces are the foundation of the whole architecture.
+- **Slim down the `ApiClient`s**: Move the complex logic in the existing `XxxProvider`s into the new middlewares or into `AiCoreService`.
+- **Strengthen the middlewares**: Let the middlewares take on more of the parsing and feature handling.
+- **Write unit and integration tests**: Verify each component and the overall flow.
+
+This architecture is meant to provide a more robust, flexible, and maintainable core for AI features, supporting Cherry Studio's future growth.
+
+## 6. Migration Strategy and Implementation Plan
+
+This section is distilled from the earlier `migrate.md` document and adjusted to match the latest architecture discussion.
+
+**Recap of the target architecture's core components:**
+
+Consistent with the components described in section 2: `XxxApiClient`, `AiCoreService`, the middleware chain, the `CoreRequest` types, and the normalized `Chunk` types.
+
+**Migration steps:**
+
+**Phase 0: Preparation and type definitions**
+
+1. **Define the core data structures (TypeScript types):**
+   - `CoreCompletionsRequest` (type): the application's unified, internal conversation request structure.
+   - `Chunk` (type; review and extend the existing `src/renderer/src/types/chunk.ts` as needed): all of the possible generic chunk types.
+   - Define similar `CoreXxxRequest` types for the other APIs (translation, summarization); see the sketch after this list.
+2. **Define the `ApiClient` interface:** pin down the core methods `getRequestTransformer`, `getResponseChunkTransformer`, `getSdkInstance`, and so on.
+3. **Adjust `AiProviderMiddlewareContext`:**
+   - Make sure it contains `_apiClientInstance: ApiClient`.
+   - Make sure it contains `_coreRequest: CoreRequestType`.
+   - Consider adding `resolvePromise: (value: AggregatedResultType) => void` and `rejectPromise: (reason?: any) => void` for `AiCoreService`'s Promise-based return.
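+
+A minimal sketch of what such a request type could look like is shown below. The field names are illustrative assumptions rather than the final definition; the real type should be derived from the existing `CompletionsParams` and from what the request transformers actually need.
+
+```typescript
+import type { Assistant, MCPTool, Model } from '@renderer/types'
+import type { Message } from '@renderer/types/newMessage'
+
+// Illustrative only: a provider-agnostic completions request as described in Phase 0.
+export interface CoreCompletionsRequest {
+  messages: Message[] | string
+  model: Model
+  assistant: Assistant
+  mcpTools?: MCPTool[]
+  maxTokens?: number
+  streamOutput: boolean
+  enableWebSearch?: boolean
+}
+```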
+
+**Phase 1: Implement the first `ApiClient` (using `OpenAIApiClient` as the example)**
+
+1. **Create the `OpenAIApiClient` class:** implement the `ApiClient` interface.
+2. **Migrate the SDK instance and configuration.**
+3. **Implement `getRequestTransformer()`:** convert a `CoreCompletionsRequest` into OpenAI SDK parameters.
+4. **Implement `getResponseChunkTransformer()`:** convert `OpenAI.Chat.Completions.ChatCompletionChunk` into the basic `
diff --git a/src/renderer/src/aiCore/clients/AihubmixAPIClient.ts b/src/renderer/src/aiCore/clients/AihubmixAPIClient.ts new file mode 100644 index 0000000000..3aa0bc4263 --- /dev/null +++ b/src/renderer/src/aiCore/clients/AihubmixAPIClient.ts @@ -0,0 +1,207 @@ +import { isOpenAILLMModel } from '@renderer/config/models' +import { + GenerateImageParams, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse +} from '@renderer/types' +import { + RequestOptions, + SdkInstance, + SdkMessageParam, + SdkModel, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' + +import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient' +import { BaseApiClient } from './BaseApiClient' +import { GeminiAPIClient } from './gemini/GeminiAPIClient' +import { OpenAIAPIClient } from './openai/OpenAIApiClient' +import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient' +import { RequestTransformer, ResponseChunkTransformer } from './types' + +/** + * AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient + * 使用装饰器模式实现,在ApiClient层面进行模型路由 + */ +export class AihubmixAPIClient extends BaseApiClient { + // 使用联合类型而不是any,保持类型安全 + private clients: Map = + new Map() + private defaultClient: OpenAIAPIClient + private currentClient: BaseApiClient + + constructor(provider: Provider) { + super(provider) + + // 初始化各个client - 现在有类型安全 + const claudeClient = new AnthropicAPIClient(provider) + const geminiClient = new GeminiAPIClient({ ...provider, apiHost: 'https://aihubmix.com/gemini' }) + const openaiClient = new OpenAIResponseAPIClient(provider) + const defaultClient = new OpenAIAPIClient(provider) + + this.clients.set('claude', claudeClient) + this.clients.set('gemini', geminiClient) + this.clients.set('openai', openaiClient) + this.clients.set('default', defaultClient) + + // 设置默认client + this.defaultClient = defaultClient + this.currentClient = this.defaultClient as BaseApiClient + } + + /** + * 类型守卫:确保client是BaseApiClient的实例 + */ + private isValidClient(client: unknown): client is BaseApiClient { + return ( + client !== null && + client !== undefined && + typeof client === 'object' && + 'createCompletions' in client && + 'getRequestTransformer' in client && + 'getResponseChunkTransformer' in client + ) + } + + /** + * 根据模型获取合适的client + */ + private getClient(model: Model): BaseApiClient { + const id = model.id.toLowerCase() + + // claude开头 + if (id.startsWith('claude')) { + const client = this.clients.get('claude') + if (!client || !this.isValidClient(client)) { + throw new Error('Claude client not properly initialized') + } + return client +
} + + // gemini开头 且不以-nothink、-search结尾 + if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) { + const client = this.clients.get('gemini') + if (!client || !this.isValidClient(client)) { + throw new Error('Gemini client not properly initialized') + } + return client + } + + // OpenAI系列模型 + if (isOpenAILLMModel(model)) { + const client = this.clients.get('openai') + if (!client || !this.isValidClient(client)) { + throw new Error('OpenAI client not properly initialized') + } + return client + } + + return this.defaultClient as BaseApiClient + } + + /** + * 根据模型选择合适的client并委托调用 + */ + public getClientForModel(model: Model): BaseApiClient { + this.currentClient = this.getClient(model) + return this.currentClient + } + + // ============ BaseApiClient 抽象方法实现 ============ + + async createCompletions(payload: SdkParams, options?: RequestOptions): Promise { + // 尝试从payload中提取模型信息来选择client + const modelId = this.extractModelFromPayload(payload) + if (modelId) { + const modelObj = { id: modelId } as Model + const targetClient = this.getClient(modelObj) + return targetClient.createCompletions(payload, options) + } + + // 如果无法从payload中提取模型,使用当前设置的client + return this.currentClient.createCompletions(payload, options) + } + + /** + * 从SDK payload中提取模型ID + */ + private extractModelFromPayload(payload: SdkParams): string | null { + // 不同的SDK可能有不同的字段名 + if ('model' in payload && typeof payload.model === 'string') { + return payload.model + } + return null + } + + async generateImage(params: GenerateImageParams): Promise { + return this.currentClient.generateImage(params) + } + + async getEmbeddingDimensions(model?: Model): Promise { + const client = model ? this.getClient(model) : this.currentClient + return client.getEmbeddingDimensions(model) + } + + async listModels(): Promise { + // 可以聚合所有client的模型,或者使用默认client + return this.defaultClient.listModels() + } + + async getSdkInstance(): Promise { + return this.currentClient.getSdkInstance() + } + + getRequestTransformer(): RequestTransformer { + return this.currentClient.getRequestTransformer() + } + + getResponseChunkTransformer(): ResponseChunkTransformer { + return this.currentClient.getResponseChunkTransformer() + } + + convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] { + return this.currentClient.convertMcpToolsToSdkTools(mcpTools) + } + + convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { + return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools) + } + + convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse { + return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool) + } + + buildSdkMessages( + currentReqMessages: SdkMessageParam[], + output: SdkRawOutput | string, + toolResults: SdkMessageParam[], + toolCalls?: SdkToolCall[] + ): SdkMessageParam[] { + return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls) + } + + convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): SdkMessageParam | undefined { + const client = this.getClient(model) + return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model) + } + + extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] { + return this.currentClient.extractMessagesFromSdkPayload(sdkPayload) + } + + estimateMessageTokens(message: SdkMessageParam): number { + return 
this.currentClient.estimateMessageTokens(message) + } +} diff --git a/src/renderer/src/aiCore/clients/ApiClientFactory.ts b/src/renderer/src/aiCore/clients/ApiClientFactory.ts new file mode 100644 index 0000000000..0d8a97ed51 --- /dev/null +++ b/src/renderer/src/aiCore/clients/ApiClientFactory.ts @@ -0,0 +1,62 @@ +import { Provider } from '@renderer/types' + +import { AihubmixAPIClient } from './AihubmixAPIClient' +import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient' +import { BaseApiClient } from './BaseApiClient' +import { GeminiAPIClient } from './gemini/GeminiAPIClient' +import { OpenAIAPIClient } from './openai/OpenAIApiClient' +import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient' + +/** + * Factory for creating ApiClient instances based on provider configuration + * 根据提供者配置创建ApiClient实例的工厂 + */ +export class ApiClientFactory { + /** + * Create an ApiClient instance for the given provider + * 为给定的提供者创建ApiClient实例 + */ + static create(provider: Provider): BaseApiClient { + console.log(`[ApiClientFactory] Creating ApiClient for provider:`, { + id: provider.id, + type: provider.type + }) + + let instance: BaseApiClient + + // 首先检查特殊的provider id + if (provider.id === 'aihubmix') { + console.log(`[ApiClientFactory] Creating AihubmixAPIClient for provider: ${provider.id}`) + instance = new AihubmixAPIClient(provider) as BaseApiClient + return instance + } + + // 然后检查标准的provider type + switch (provider.type) { + case 'openai': + case 'azure-openai': + console.log(`[ApiClientFactory] Creating OpenAIApiClient for provider: ${provider.id}`) + instance = new OpenAIAPIClient(provider) as BaseApiClient + break + case 'openai-response': + instance = new OpenAIResponseAPIClient(provider) as BaseApiClient + break + case 'gemini': + instance = new GeminiAPIClient(provider) as BaseApiClient + break + case 'anthropic': + instance = new AnthropicAPIClient(provider) as BaseApiClient + break + default: + console.log(`[ApiClientFactory] Using default OpenAIApiClient for provider: ${provider.id}`) + instance = new OpenAIAPIClient(provider) as BaseApiClient + break + } + + return instance + } +} + +export function isOpenAIProvider(provider: Provider) { + return !['anthropic', 'gemini'].includes(provider.type) +} diff --git a/src/renderer/src/providers/AiProvider/BaseProvider.ts b/src/renderer/src/aiCore/clients/BaseApiClient.ts similarity index 53% rename from src/renderer/src/providers/AiProvider/BaseProvider.ts rename to src/renderer/src/aiCore/clients/BaseApiClient.ts index f1beba64d4..19e455026d 100644 --- a/src/renderer/src/providers/AiProvider/BaseProvider.ts +++ b/src/renderer/src/aiCore/clients/BaseApiClient.ts @@ -1,40 +1,69 @@ -import Logger from '@renderer/config/logger' -import { isFunctionCallingModel, isNotSupportTemperatureAndTopP } from '@renderer/config/models' +import { + isFunctionCallingModel, + isNotSupportTemperatureAndTopP, + isOpenAIModel, + isSupportedFlexServiceTier +} from '@renderer/config/models' import { REFERENCE_PROMPT } from '@renderer/config/prompts' import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio' -import type { +import { getStoreSetting } from '@renderer/hooks/useSettings' +import { SettingsState } from '@renderer/store/settings' +import { Assistant, + FileTypes, GenerateImageParams, KnowledgeReference, MCPCallToolResponse, MCPTool, MCPToolResponse, Model, + OpenAIServiceTier, Provider, - Suggestion, + ToolCallResponse, WebSearchProviderResponse, WebSearchResponse } from '@renderer/types' -import { ChunkType } from 
'@renderer/types/chunk' -import type { Message } from '@renderer/types/newMessage' -import { delay, isJSON, parseJSON } from '@renderer/utils' +import { Message } from '@renderer/types/newMessage' +import { + RequestOptions, + SdkInstance, + SdkMessageParam, + SdkModel, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' +import { isJSON, parseJSON } from '@renderer/utils' import { addAbortController, removeAbortController } from '@renderer/utils/abortController' -import { formatApiHost } from '@renderer/utils/api' -import { getMainTextContent } from '@renderer/utils/messageUtils/find' +import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' +import { defaultTimeout } from '@shared/config/constant' +import Logger from 'electron-log/renderer' import { isEmpty } from 'lodash' -import type OpenAI from 'openai' -import type { CompletionsParams } from '.' +import { ApiClient, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from './types' -export default abstract class BaseProvider { - // Threshold for determining whether to use system prompt for tools +/** + * Abstract base class for API clients. + * Provides common functionality and structure for specific client implementations. + */ +export abstract class BaseApiClient< + TSdkInstance extends SdkInstance = SdkInstance, + TSdkParams extends SdkParams = SdkParams, + TRawOutput extends SdkRawOutput = SdkRawOutput, + TRawChunk extends SdkRawChunk = SdkRawChunk, + TMessageParam extends SdkMessageParam = SdkMessageParam, + TToolCall extends SdkToolCall = SdkToolCall, + TSdkSpecificTool extends SdkTool = SdkTool +> implements ApiClient +{ private static readonly SYSTEM_PROMPT_THRESHOLD: number = 128 - - protected provider: Provider + public provider: Provider protected host: string protected apiKey: string - - protected useSystemPromptForTools: boolean = true + protected sdkInstance?: TSdkInstance + public useSystemPromptForTools: boolean = true constructor(provider: Provider) { this.provider = provider @@ -42,32 +71,81 @@ export default abstract class BaseProvider { this.apiKey = this.getApiKey() } - abstract completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise - abstract translate( - content: string, - assistant: Assistant, - onResponse?: (text: string, isComplete: boolean) => void - ): Promise - abstract summaries(messages: Message[], assistant: Assistant): Promise - abstract summaryForSearch(messages: Message[], assistant: Assistant): Promise - abstract suggestions(messages: Message[], assistant: Assistant): Promise - abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise - abstract check(model: Model, stream: boolean): Promise<{ valid: boolean; error: Error | null }> - abstract models(): Promise - abstract generateImage(params: GenerateImageParams): Promise - abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise - // 由于现在出现了一些能够选择嵌入维度的嵌入模型,这个不考虑dimensions参数的方法将只能应用于那些不支持dimensions的模型 - abstract getEmbeddingDimensions(model: Model): Promise - public abstract convertMcpTools(mcpTools: MCPTool[]): T[] - public abstract mcpToolCallResponseToMessage( + // // 核心的completions方法 - 在中间件架构中,这通常只是一个占位符 + // abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise + + /** + * 核心API Endpoint + **/ + + abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise + + abstract 
generateImage(generateImageParams: GenerateImageParams): Promise + + abstract getEmbeddingDimensions(model?: Model): Promise + + abstract listModels(): Promise + + abstract getSdkInstance(): Promise | TSdkInstance + + /** + * 中间件 + **/ + + // 在 CoreRequestToSdkParamsMiddleware中使用 + abstract getRequestTransformer(): RequestTransformer + // 在RawSdkChunkToGenericChunkMiddleware中使用 + abstract getResponseChunkTransformer(): ResponseChunkTransformer + + /** + * 工具转换 + **/ + + // Optional tool conversion methods - implement if needed by the specific provider + abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[] + + abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined + + abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse + + abstract buildSdkMessages( + currentReqMessages: TMessageParam[], + output: TRawOutput | string, + toolResults: TMessageParam[], + toolCalls?: TToolCall[] + ): TMessageParam[] + + abstract estimateMessageTokens(message: TMessageParam): number + + abstract convertMcpToolResponseToSdkMessageParam( mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model - ): any + ): TMessageParam | undefined + + /** + * 从SDK载荷中提取消息数组(用于中间件中的类型安全访问) + * 不同的提供商可能使用不同的字段名(如messages、history等) + */ + abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[] + + /** + * 附加原始流监听器 + */ + public attachRawStreamListener>( + rawOutput: TRawOutput, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + _listener: TListener + ): TRawOutput { + return rawOutput + } + + /** + * 通用函数 + **/ public getBaseURL(): string { - const host = this.provider.apiHost - return formatApiHost(host) + return this.provider.apiHost } public getApiKey() { @@ -112,14 +190,32 @@ export default abstract class BaseProvider { return isNotSupportTemperatureAndTopP(model) ? 
undefined : assistant.settings?.topP } - public async fakeCompletions({ onChunk }: CompletionsParams) { - for (let i = 0; i < 100; i++) { - await delay(0.01) - onChunk({ - response: { text: i + '\n', usage: { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0 } }, - type: ChunkType.BLOCK_COMPLETE - }) + protected getServiceTier(model: Model) { + if (!isOpenAIModel(model) || model.provider === 'github' || model.provider === 'copilot') { + return undefined } + + const openAI = getStoreSetting('openAI') as SettingsState['openAI'] + let serviceTier = 'auto' as OpenAIServiceTier + + if (openAI && openAI?.serviceTier === 'flex') { + if (isSupportedFlexServiceTier(model)) { + serviceTier = 'flex' + } else { + serviceTier = 'auto' + } + } else { + serviceTier = openAI.serviceTier + } + + return serviceTier + } + + protected getTimeout(model: Model) { + if (isSupportedFlexServiceTier(model)) { + return 15 * 1000 * 60 + } + return defaultTimeout } public async getMessageContent(message: Message): Promise { @@ -149,6 +245,36 @@ export default abstract class BaseProvider { return content } + /** + * Extract the file content from the message + * @param message - The message + * @returns The file content + */ + protected async extractFileContent(message: Message) { + const fileBlocks = findFileBlocks(message) + if (fileBlocks.length > 0) { + const textFileBlocks = fileBlocks.filter( + (fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type) + ) + + if (textFileBlocks.length > 0) { + let text = '' + const divider = '\n\n---\n\n' + + for (const fileBlock of textFileBlocks) { + const file = fileBlock.file + const fileContent = (await window.api.file.read(file.id + file.ext)).trim() + const fileNameRow = 'file: ' + file.origin_name + '\n\n' + text = text + fileNameRow + fileContent + divider + } + + return text + } + } + + return '' + } + private async getWebSearchReferencesFromCache(message: Message) { const content = getMainTextContent(message) if (isEmpty(content)) { @@ -210,7 +336,7 @@ export default abstract class BaseProvider { ) } - protected createAbortController(messageId?: string, isAddEventListener?: boolean) { + public createAbortController(messageId?: string, isAddEventListener?: boolean) { const abortController = new AbortController() const abortFn = () => abortController.abort() @@ -256,11 +382,11 @@ export default abstract class BaseProvider { } // Setup tools configuration based on provided parameters - protected setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): { - tools: T[] + public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): { + tools: TSdkSpecificTool[] } { const { mcpTools, model, enableToolUse } = params - let tools: T[] = [] + let tools: TSdkSpecificTool[] = [] // If there are no tools, return an empty array if (!mcpTools?.length) { @@ -268,14 +394,14 @@ export default abstract class BaseProvider { } // If the number of tools exceeds the threshold, use the system prompt - if (mcpTools.length > BaseProvider.SYSTEM_PROMPT_THRESHOLD) { + if (mcpTools.length > BaseApiClient.SYSTEM_PROMPT_THRESHOLD) { this.useSystemPromptForTools = true return { tools } } // If the model supports function calling and tool usage is enabled if (isFunctionCallingModel(model) && enableToolUse) { - tools = this.convertMcpTools(mcpTools) + tools = this.convertMcpToolsToSdkTools(mcpTools) this.useSystemPromptForTools = false } diff --git 
a/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts b/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts new file mode 100644 index 0000000000..ffbda737c6 --- /dev/null +++ b/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts @@ -0,0 +1,714 @@ +import Anthropic from '@anthropic-ai/sdk' +import { + Base64ImageSource, + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlock, + WebSearchTool20250305 +} from '@anthropic-ai/sdk/resources' +import { + ContentBlock, + ContentBlockParam, + MessageCreateParams, + MessageCreateParamsBase, + RedactedThinkingBlockParam, + ServerToolUseBlockParam, + ThinkingBlockParam, + ThinkingConfigParam, + ToolUnion, + ToolUseBlockParam, + WebSearchResultBlock, + WebSearchToolResultBlockParam, + WebSearchToolResultError +} from '@anthropic-ai/sdk/resources/messages' +import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages' +import { GenericChunk } from '@renderer/aiCore/middleware/schemas' +import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' +import Logger from '@renderer/config/logger' +import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models' +import { getAssistantSettings } from '@renderer/services/AssistantService' +import FileManager from '@renderer/services/FileManager' +import { estimateTextTokens } from '@renderer/services/TokenService' +import { + Assistant, + EFFORT_RATIO, + FileTypes, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse, + WebSearchSource +} from '@renderer/types' +import { + ChunkType, + ErrorChunk, + LLMWebSearchCompleteChunk, + LLMWebSearchInProgressChunk, + MCPToolCreatedChunk, + TextDeltaChunk, + ThinkingDeltaChunk +} from '@renderer/types/chunk' +import type { Message } from '@renderer/types/newMessage' +import { + AnthropicSdkMessageParam, + AnthropicSdkParams, + AnthropicSdkRawChunk, + AnthropicSdkRawOutput +} from '@renderer/types/sdk' +import { addImageFileToContents } from '@renderer/utils/formats' +import { + anthropicToolUseToMcpTool, + isEnabledToolUse, + mcpToolCallResponseToAnthropicMessage, + mcpToolsToAnthropicTools +} from '@renderer/utils/mcp-tools' +import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' +import { buildSystemPrompt } from '@renderer/utils/prompt' + +import { BaseApiClient } from '../BaseApiClient' +import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types' + +export class AnthropicAPIClient extends BaseApiClient< + Anthropic, + AnthropicSdkParams, + AnthropicSdkRawOutput, + AnthropicSdkRawChunk, + AnthropicSdkMessageParam, + ToolUseBlock, + ToolUnion +> { + constructor(provider: Provider) { + super(provider) + } + + async getSdkInstance(): Promise { + if (this.sdkInstance) { + return this.sdkInstance + } + this.sdkInstance = new Anthropic({ + apiKey: this.getApiKey(), + baseURL: this.getBaseURL(), + dangerouslyAllowBrowser: true, + defaultHeaders: { + 'anthropic-beta': 'output-128k-2025-02-19' + } + }) + return this.sdkInstance + } + + override async createCompletions( + payload: AnthropicSdkParams, + options?: Anthropic.RequestOptions + ): Promise { + const sdk = await this.getSdkInstance() + if (payload.stream) { + return sdk.messages.stream(payload, options) + } + return await sdk.messages.create(payload, options) + } + + // @ts-ignore sdk未提供 + // eslint-disable-next-line 
@typescript-eslint/no-unused-vars + override async generateImage(generateImageParams: GenerateImageParams): Promise { + return [] + } + + override async listModels(): Promise { + const sdk = await this.getSdkInstance() + const response = await sdk.models.list() + return response.data + } + + // @ts-ignore sdk未提供 + override async getEmbeddingDimensions(): Promise { + return 0 + } + + override getTemperature(assistant: Assistant, model: Model): number | undefined { + if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { + return undefined + } + return assistant.settings?.temperature + } + + override getTopP(assistant: Assistant, model: Model): number | undefined { + if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { + return undefined + } + return assistant.settings?.topP + } + + /** + * Get the reasoning effort + * @param assistant - The assistant + * @param model - The model + * @returns The reasoning effort + */ + private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined { + if (!isReasoningModel(model)) { + return undefined + } + const { maxTokens } = getAssistantSettings(assistant) + + const reasoningEffort = assistant?.settings?.reasoning_effort + + if (reasoningEffort === undefined) { + return { + type: 'disabled' + } + } + + const effortRatio = EFFORT_RATIO[reasoningEffort] + + const budgetTokens = Math.max( + 1024, + Math.floor( + Math.min( + (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + + findTokenLimit(model.id)?.min!, + (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio + ) + ) + ) + + return { + type: 'enabled', + budget_tokens: budgetTokens + } + } + + /** + * Get the message parameter + * @param message - The message + * @param model - The model + * @returns The message parameter + */ + public async convertMessageToSdkParam(message: Message): Promise { + const parts: MessageParam['content'] = [ + { + type: 'text', + text: getMainTextContent(message) + } + ] + + // Get and process image blocks + const imageBlocks = findImageBlocks(message) + for (const imageBlock of imageBlocks) { + if (imageBlock.file) { + // Handle uploaded file + const file = imageBlock.file + const base64Data = await window.api.file.base64Image(file.id + file.ext) + parts.push({ + type: 'image', + source: { + data: base64Data.base64, + media_type: base64Data.mime.replace('jpg', 'jpeg') as any, + type: 'base64' + } + }) + } + } + // Get and process file blocks + const fileBlocks = findFileBlocks(message) + for (const fileBlock of fileBlocks) { + const { file } = fileBlock + if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { + if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) { + const base64Data = await FileManager.readBase64File(file) + parts.push({ + type: 'document', + source: { + type: 'base64', + media_type: 'application/pdf', + data: base64Data + } + }) + } else { + const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() + parts.push({ + type: 'text', + text: file.origin_name + '\n' + fileContent + }) + } + } + } + + return { + role: message.role === 'system' ? 
'user' : message.role, + content: parts + } + } + + public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] { + return mcpToolsToAnthropicTools(mcpTools) + } + + public convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): AnthropicSdkMessageParam | undefined { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model) + } else if ('toolCallId' in mcpToolResponse) { + return { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: mcpToolResponse.toolCallId!, + content: resp.content + .map((item) => { + if (item.type === 'text') { + return { + type: 'text', + text: item.text || '' + } satisfies TextBlockParam + } + if (item.type === 'image') { + return { + type: 'image', + source: { + data: item.data || '', + media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'], + type: 'base64' + } + } satisfies ImageBlockParam + } + return + }) + .filter((n) => typeof n !== 'undefined'), + is_error: resp.isError + } satisfies ToolResultBlockParam + ] + } + } + return + } + + // Implementing abstract methods from BaseApiClient + convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined { + // Based on anthropicToolUseToMcpTool logic in AnthropicProvider + // This might need adjustment based on how tool calls are specifically handled in the new structure + const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall) + return mcpTool + } + + convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse { + return { + id: toolCall.id, + toolCallId: toolCall.id, + tool: mcpTool, + arguments: toolCall.input as Record, + status: 'pending' + } as ToolCallResponse + } + + override buildSdkMessages( + currentReqMessages: AnthropicSdkMessageParam[], + output: Anthropic.Message, + toolResults: AnthropicSdkMessageParam[] + ): AnthropicSdkMessageParam[] { + const assistantMessage: AnthropicSdkMessageParam = { + role: output.role, + content: convertContentBlocksToParams(output.content) + } + + const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage] + if (toolResults && toolResults.length > 0) { + newMessages.push(...toolResults) + } + return newMessages + } + + override estimateMessageTokens(message: AnthropicSdkMessageParam): number { + if (typeof message.content === 'string') { + return estimateTextTokens(message.content) + } + return message.content + .map((content) => { + switch (content.type) { + case 'text': + return estimateTextTokens(content.text) + case 'image': + if (content.source.type === 'base64') { + return estimateTextTokens(content.source.data) + } else { + return estimateTextTokens(content.source.url) + } + case 'tool_use': + return estimateTextTokens(JSON.stringify(content.input)) + case 'tool_result': + return estimateTextTokens(JSON.stringify(content.content)) + default: + return 0 + } + }) + .reduce((acc, curr) => acc + curr, 0) + } + + public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam { + const messageParam: AnthropicSdkMessageParam = { + role: message.role, + content: convertContentBlocksToParams(message.content) + } + return messageParam + } + + public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] { + return sdkPayload.messages || [] + } + + /** + * Anthropic专用的原始流监听器 + * 处理MessageStream对象的特定事件 + */ + override 
attachRawStreamListener( + rawOutput: AnthropicSdkRawOutput, + listener: RawStreamListener + ): AnthropicSdkRawOutput { + console.log(`[AnthropicApiClient] 附加流监听器到原始输出`) + + // 检查是否为MessageStream + if (rawOutput instanceof MessageStream) { + console.log(`[AnthropicApiClient] 检测到 Anthropic MessageStream,附加专用监听器`) + + if (listener.onStart) { + listener.onStart() + } + + if (listener.onChunk) { + rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => { + listener.onChunk!(event) + }) + } + + // 专用的Anthropic事件处理 + const anthropicListener = listener as AnthropicStreamListener + + if (anthropicListener.onContentBlock) { + rawOutput.on('contentBlock', anthropicListener.onContentBlock) + } + + if (anthropicListener.onMessage) { + rawOutput.on('finalMessage', anthropicListener.onMessage) + } + + if (listener.onEnd) { + rawOutput.on('end', () => { + listener.onEnd!() + }) + } + + if (listener.onError) { + rawOutput.on('error', (error: Error) => { + listener.onError!(error) + }) + } + + return rawOutput + } + + // 对于非MessageStream响应 + return rawOutput + } + + private async getWebSearchParams(model: Model): Promise { + if (!isWebSearchModel(model)) { + return undefined + } + return { + type: 'web_search_20250305', + name: 'web_search', + max_uses: 5 + } as WebSearchTool20250305 + } + + getRequestTransformer(): RequestTransformer { + return { + transform: async ( + coreRequest, + assistant, + model, + isRecursiveCall, + recursiveSdkMessages + ): Promise<{ + payload: AnthropicSdkParams + messages: AnthropicSdkMessageParam[] + metadata: Record + }> => { + const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest + // 1. 处理系统消息 + let systemPrompt = assistant.prompt + + // 2. 设置工具 + const { tools } = this.setupToolsConfig({ + mcpTools: mcpTools, + model, + enableToolUse: isEnabledToolUse(assistant) + }) + + if (this.useSystemPromptForTools) { + systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools) + } + + const systemMessage: TextBlockParam | undefined = systemPrompt + ? { type: 'text', text: systemPrompt } + : undefined + + // 3. 处理用户消息 + const sdkMessages: AnthropicSdkMessageParam[] = [] + if (typeof messages === 'string') { + sdkMessages.push({ role: 'user', content: messages }) + } else { + const processedMessages = addImageFileToContents(messages) + for (const message of processedMessages) { + sdkMessages.push(await this.convertMessageToSdkParam(message)) + } + } + + if (enableWebSearch) { + const webSearchTool = await this.getWebSearchParams(model) + if (webSearchTool) { + tools.push(webSearchTool) + } + } + + const commonParams: MessageCreateParamsBase = { + model: model.id, + messages: + isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 + ? recursiveSdkMessages + : sdkMessages, + max_tokens: maxTokens || DEFAULT_MAX_TOKENS, + temperature: this.getTemperature(assistant, model), + top_p: this.getTopP(assistant, model), + system: systemMessage ? [systemMessage] : undefined, + thinking: this.getBudgetToken(assistant, model), + tools: tools.length > 0 ? tools : undefined, + ...this.getCustomParameters(assistant) + } + + const finalParams: MessageCreateParams = streamOutput + ? 
{ + ...commonParams, + stream: true + } + : { + ...commonParams, + stream: false + } + + const timeout = this.getTimeout(model) + return { payload: finalParams, messages: sdkMessages, metadata: { timeout } } + } + } + } + + getResponseChunkTransformer(): ResponseChunkTransformer { + return () => { + let accumulatedJson = '' + const toolCalls: Record = {} + + return { + async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController) { + switch (rawChunk.type) { + case 'message': { + for (const content of rawChunk.content) { + switch (content.type) { + case 'text': { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: content.text + } as TextDeltaChunk) + break + } + case 'tool_use': { + toolCalls[0] = content + break + } + case 'thinking': { + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: content.thinking + } as ThinkingDeltaChunk) + break + } + case 'web_search_tool_result': { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: content.content, + source: WebSearchSource.ANTHROPIC + } + } as LLMWebSearchCompleteChunk) + break + } + } + } + break + } + case 'content_block_start': { + const contentBlock = rawChunk.content_block + switch (contentBlock.type) { + case 'server_tool_use': { + if (contentBlock.name === 'web_search') { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS + } as LLMWebSearchInProgressChunk) + } + break + } + case 'web_search_tool_result': { + if ( + contentBlock.content && + (contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error' + ) { + controller.enqueue({ + type: ChunkType.ERROR, + error: { + code: (contentBlock.content as WebSearchToolResultError).error_code, + message: (contentBlock.content as WebSearchToolResultError).error_code + } + } as ErrorChunk) + } else { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: contentBlock.content as Array, + source: WebSearchSource.ANTHROPIC + } + } as LLMWebSearchCompleteChunk) + } + break + } + case 'tool_use': { + toolCalls[rawChunk.index] = contentBlock + break + } + } + break + } + case 'content_block_delta': { + const messageDelta = rawChunk.delta + switch (messageDelta.type) { + case 'text_delta': { + if (messageDelta.text) { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: messageDelta.text + } as TextDeltaChunk) + } + break + } + case 'thinking_delta': { + if (messageDelta.thinking) { + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: messageDelta.thinking + } as ThinkingDeltaChunk) + } + break + } + case 'input_json_delta': { + if (messageDelta.partial_json) { + accumulatedJson += messageDelta.partial_json + } + break + } + } + break + } + case 'content_block_stop': { + const toolCall = toolCalls[rawChunk.index] + if (toolCall) { + try { + toolCall.input = JSON.parse(accumulatedJson) + Logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`) + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: [toolCall] + } as MCPToolCreatedChunk) + } catch (error) { + Logger.error(`Error parsing tool call input: ${error}`) + } + } + break + } + case 'message_delta': { + controller.enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: rawChunk.usage.input_tokens || 0, + completion_tokens: rawChunk.usage.output_tokens || 0, + total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0) + } + } + }) + } + 
} + } + } + } + } +} + +/** + * 将 ContentBlock 数组转换为 ContentBlockParam 数组 + * 去除服务器生成的额外字段,只保留发送给API所需的字段 + */ +function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] { + return contentBlocks.map((block): ContentBlockParam => { + switch (block.type) { + case 'text': + // TextBlock -> TextBlockParam,去除 citations 等服务器字段 + return { + type: 'text', + text: block.text + } satisfies TextBlockParam + case 'tool_use': + // ToolUseBlock -> ToolUseBlockParam + return { + type: 'tool_use', + id: block.id, + name: block.name, + input: block.input + } satisfies ToolUseBlockParam + case 'thinking': + // ThinkingBlock -> ThinkingBlockParam + return { + type: 'thinking', + thinking: block.thinking, + signature: block.signature + } satisfies ThinkingBlockParam + case 'redacted_thinking': + // RedactedThinkingBlock -> RedactedThinkingBlockParam + return { + type: 'redacted_thinking', + data: block.data + } satisfies RedactedThinkingBlockParam + case 'server_tool_use': + // ServerToolUseBlock -> ServerToolUseBlockParam + return { + type: 'server_tool_use', + id: block.id, + name: block.name, + input: block.input + } satisfies ServerToolUseBlockParam + case 'web_search_tool_result': + // WebSearchToolResultBlock -> WebSearchToolResultBlockParam + return { + type: 'web_search_tool_result', + tool_use_id: block.tool_use_id, + content: block.content + } satisfies WebSearchToolResultBlockParam + default: + return block as ContentBlockParam + } + }) +} diff --git a/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts b/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts new file mode 100644 index 0000000000..b40aff9182 --- /dev/null +++ b/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts @@ -0,0 +1,781 @@ +import { + Content, + File, + FileState, + FunctionCall, + GenerateContentConfig, + GenerateImagesConfig, + GoogleGenAI, + HarmBlockThreshold, + HarmCategory, + Modality, + Model as GeminiModel, + Pager, + Part, + SafetySetting, + SendMessageParameters, + ThinkingConfig, + Tool +} from '@google/genai' +import { nanoid } from '@reduxjs/toolkit' +import { GenericChunk } from '@renderer/aiCore/middleware/schemas' +import { findTokenLimit, isGeminiReasoningModel, isGemmaModel, isVisionModel } from '@renderer/config/models' +import { CacheService } from '@renderer/services/CacheService' +import { estimateTextTokens } from '@renderer/services/TokenService' +import { + Assistant, + EFFORT_RATIO, + FileType, + FileTypes, + GenerateImageParams, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse, + WebSearchSource +} from '@renderer/types' +import { ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk' +import { Message } from '@renderer/types/newMessage' +import { + GeminiOptions, + GeminiSdkMessageParam, + GeminiSdkParams, + GeminiSdkRawChunk, + GeminiSdkRawOutput, + GeminiSdkToolCall +} from '@renderer/types/sdk' +import { + geminiFunctionCallToMcpTool, + isEnabledToolUse, + mcpToolCallResponseToGeminiMessage, + mcpToolsToGeminiTools +} from '@renderer/utils/mcp-tools' +import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' +import { buildSystemPrompt } from '@renderer/utils/prompt' +import { MB } from '@shared/config/constant' + +import { BaseApiClient } from '../BaseApiClient' +import { RequestTransformer, ResponseChunkTransformer } from '../types' + +export class GeminiAPIClient extends BaseApiClient< + GoogleGenAI, + GeminiSdkParams, + GeminiSdkRawOutput, + 
GeminiSdkRawChunk, + GeminiSdkMessageParam, + GeminiSdkToolCall, + Tool +> { + constructor(provider: Provider) { + super(provider) + } + + override async createCompletions(payload: GeminiSdkParams, options?: GeminiOptions): Promise { + const sdk = await this.getSdkInstance() + const { model, history, ...rest } = payload + const realPayload: Omit = { + ...rest, + config: { + ...rest.config, + abortSignal: options?.abortSignal, + httpOptions: { + ...rest.config?.httpOptions, + timeout: options?.timeout + } + } + } satisfies SendMessageParameters + + const streamOutput = options?.streamOutput + + const chat = sdk.chats.create({ + model: model, + history: history + }) + + if (streamOutput) { + const stream = chat.sendMessageStream(realPayload) + return stream + } else { + const response = await chat.sendMessage(realPayload) + return response + } + } + + override async generateImage(generateImageParams: GenerateImageParams): Promise { + const sdk = await this.getSdkInstance() + try { + const { model, prompt, imageSize, batchSize, signal } = generateImageParams + const config: GenerateImagesConfig = { + numberOfImages: batchSize, + aspectRatio: imageSize, + abortSignal: signal, + httpOptions: { + timeout: 5 * 60 * 1000 + } + } + const response = await sdk.models.generateImages({ + model: model, + prompt, + config + }) + + if (!response.generatedImages || response.generatedImages.length === 0) { + return [] + } + + const images = response.generatedImages + .filter((image) => image.image?.imageBytes) + .map((image) => { + const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,` + return dataPrefix + image.image?.imageBytes + }) + // console.log(response?.generatedImages?.[0]?.image?.imageBytes); + return images + } catch (error) { + console.error('[generateImage] error:', error) + throw error + } + } + + override async getEmbeddingDimensions(model: Model): Promise { + const sdk = await this.getSdkInstance() + try { + const data = await sdk.models.embedContent({ + model: model.id, + contents: [{ role: 'user', parts: [{ text: 'hi' }] }] + }) + return data.embeddings?.[0]?.values?.length || 0 + } catch (e) { + return 0 + } + } + + override async listModels(): Promise { + const sdk = await this.getSdkInstance() + const response = await sdk.models.list() + const models: GeminiModel[] = [] + for await (const model of response) { + models.push(model) + } + return models + } + + override async getSdkInstance() { + if (this.sdkInstance) { + return this.sdkInstance + } + + this.sdkInstance = new GoogleGenAI({ + vertexai: false, + apiKey: this.apiKey, + httpOptions: { baseUrl: this.getBaseURL() } + }) + + return this.sdkInstance + } + + /** + * Handle a PDF file + * @param file - The file + * @returns The part + */ + private async handlePdfFile(file: FileType): Promise { + const smallFileSize = 20 * MB + const isSmallFile = file.size < smallFileSize + + if (isSmallFile) { + const { data, mimeType } = await this.base64File(file) + return { + inlineData: { + data, + mimeType + } as Part['inlineData'] + } + } + + // Retrieve file from Gemini uploaded files + const fileMetadata: File | undefined = await this.retrieveFile(file) + + if (fileMetadata) { + return { + fileData: { + fileUri: fileMetadata.uri, + mimeType: fileMetadata.mimeType + } as Part['fileData'] + } + } + + // If file is not found, upload it to Gemini + const result = await this.uploadFile(file) + + return { + fileData: { + fileUri: result.uri, + mimeType: result.mimeType + } as Part['fileData'] + } + } + + /** + * Get the message 
contents + * @param message - The message + * @returns The message contents + */ + private async convertMessageToSdkParam(message: Message): Promise { + const role = message.role === 'user' ? 'user' : 'model' + const parts: Part[] = [{ text: await this.getMessageContent(message) }] + // Add any generated images from previous responses + const imageBlocks = findImageBlocks(message) + for (const imageBlock of imageBlocks) { + if ( + imageBlock.metadata?.generateImageResponse?.images && + imageBlock.metadata.generateImageResponse.images.length > 0 + ) { + for (const imageUrl of imageBlock.metadata.generateImageResponse.images) { + if (imageUrl && imageUrl.startsWith('data:')) { + // Extract base64 data and mime type from the data URL + const matches = imageUrl.match(/^data:(.+);base64,(.*)$/) + if (matches && matches.length === 3) { + const mimeType = matches[1] + const base64Data = matches[2] + parts.push({ + inlineData: { + data: base64Data, + mimeType: mimeType + } as Part['inlineData'] + }) + } + } + } + } + const file = imageBlock.file + if (file) { + const base64Data = await window.api.file.base64Image(file.id + file.ext) + parts.push({ + inlineData: { + data: base64Data.base64, + mimeType: base64Data.mime + } as Part['inlineData'] + }) + } + } + + const fileBlocks = findFileBlocks(message) + for (const fileBlock of fileBlocks) { + const file = fileBlock.file + if (file.type === FileTypes.IMAGE) { + const base64Data = await window.api.file.base64Image(file.id + file.ext) + parts.push({ + inlineData: { + data: base64Data.base64, + mimeType: base64Data.mime + } as Part['inlineData'] + }) + } + + if (file.ext === '.pdf') { + parts.push(await this.handlePdfFile(file)) + continue + } + if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { + const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() + parts.push({ + text: file.origin_name + '\n' + fileContent + }) + } + } + + return { + role, + parts: parts + } + } + + // @ts-ignore unused + private async getImageFileContents(message: Message): Promise { + const role = message.role === 'user' ? 
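+      // Gemini Content only accepts the 'user' and 'model' roles, so any non-user message is sent as 'model'.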
'user' : 'model' + const content = getMainTextContent(message) + const parts: Part[] = [{ text: content }] + const imageBlocks = findImageBlocks(message) + for (const imageBlock of imageBlocks) { + if ( + imageBlock.metadata?.generateImageResponse?.images && + imageBlock.metadata.generateImageResponse.images.length > 0 + ) { + for (const imageUrl of imageBlock.metadata.generateImageResponse.images) { + if (imageUrl && imageUrl.startsWith('data:')) { + // Extract base64 data and mime type from the data URL + const matches = imageUrl.match(/^data:(.+);base64,(.*)$/) + if (matches && matches.length === 3) { + const mimeType = matches[1] + const base64Data = matches[2] + parts.push({ + inlineData: { + data: base64Data, + mimeType: mimeType + } as Part['inlineData'] + }) + } + } + } + } + const file = imageBlock.file + if (file) { + const base64Data = await window.api.file.base64Image(file.id + file.ext) + parts.push({ + inlineData: { + data: base64Data.base64, + mimeType: base64Data.mime + } as Part['inlineData'] + }) + } + } + return { + role, + parts: parts + } + } + + /** + * Get the safety settings + * @returns The safety settings + */ + private getSafetySettings(): SafetySetting[] { + const safetyThreshold = 'OFF' as HarmBlockThreshold + + return [ + { + category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold: safetyThreshold + }, + { + category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold: safetyThreshold + }, + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold: safetyThreshold + }, + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: safetyThreshold + }, + { + category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, + threshold: HarmBlockThreshold.BLOCK_NONE + } + ] + } + + /** + * Get the reasoning effort for the assistant + * @param assistant - The assistant + * @param model - The model + * @returns The reasoning effort + */ + private getBudgetToken(assistant: Assistant, model: Model) { + if (isGeminiReasoningModel(model)) { + const reasoningEffort = assistant?.settings?.reasoning_effort + const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini-.*-flash.*$') + + // 如果thinking_budget是undefined,不思考 + if (reasoningEffort === undefined) { + return { + thinkingConfig: { + includeThoughts: false, + ...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {}) + } as ThinkingConfig + } + } + + const effortRatio = EFFORT_RATIO[reasoningEffort] + + if (effortRatio > 1) { + return { + thinkingConfig: { + includeThoughts: true + } + } + } + + const { max } = findTokenLimit(model.id) || { max: 0 } + const budget = Math.floor(max * effortRatio) + + return { + thinkingConfig: { + ...(budget > 0 ? { thinkingBudget: budget } : {}), + includeThoughts: true + } as ThinkingConfig + } + } + + return {} + } + + private getGenerateImageParameter(): Partial { + return { + systemInstruction: undefined, + responseModalities: [Modality.TEXT, Modality.IMAGE], + responseMimeType: 'text/plain' + } + } + + getRequestTransformer(): RequestTransformer { + return { + transform: async ( + coreRequest, + assistant, + model, + isRecursiveCall, + recursiveSdkMessages + ): Promise<{ + payload: GeminiSdkParams + messages: GeminiSdkMessageParam[] + metadata: Record + }> => { + const { messages, mcpTools, maxTokens, enableWebSearch, enableGenerateImage } = coreRequest + // 1. 处理系统消息 + let systemInstruction = assistant.prompt + + // 2. 
设置工具 + const { tools } = this.setupToolsConfig({ + mcpTools, + model, + enableToolUse: isEnabledToolUse(assistant) + }) + + if (this.useSystemPromptForTools) { + systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools) + } + + let messageContents: Content + const history: Content[] = [] + // 3. 处理用户消息 + if (typeof messages === 'string') { + messageContents = { + role: 'user', + parts: [{ text: messages }] + } + } else { + const userLastMessage = messages.pop()! + messageContents = await this.convertMessageToSdkParam(userLastMessage) + for (const message of messages) { + history.push(await this.convertMessageToSdkParam(message)) + } + } + + if (enableWebSearch) { + tools.push({ + googleSearch: {} + }) + } + + if (isGemmaModel(model) && assistant.prompt) { + const isFirstMessage = history.length === 0 + if (isFirstMessage && messageContents) { + const systemMessage = [ + { + text: + 'user\n' + + systemInstruction + + '\n' + + 'user\n' + + (messageContents?.parts?.[0] as Part).text + + '' + } + ] as Part[] + if (messageContents && messageContents.parts) { + messageContents.parts[0] = systemMessage[0] + } + } + } + + const newHistory = + isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 + ? recursiveSdkMessages.slice(0, recursiveSdkMessages.length - 1) + : history + + const newMessageContents = + isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 + ? { + ...messageContents, + parts: [ + ...(messageContents.parts || []), + ...(recursiveSdkMessages[recursiveSdkMessages.length - 1].parts || []) + ] + } + : messageContents + + const generateContentConfig: GenerateContentConfig = { + safetySettings: this.getSafetySettings(), + systemInstruction: isGemmaModel(model) ? undefined : systemInstruction, + temperature: this.getTemperature(assistant, model), + topP: this.getTopP(assistant, model), + maxOutputTokens: maxTokens, + tools: tools, + ...(enableGenerateImage ? this.getGenerateImageParameter() : {}), + ...this.getBudgetToken(assistant, model), + ...this.getCustomParameters(assistant) + } + + const param: GeminiSdkParams = { + model: model.id, + config: generateContentConfig, + history: newHistory, + message: newMessageContents.parts! + } + + return { + payload: param, + messages: [messageContents], + metadata: {} + } + } + } + } + + getResponseChunkTransformer(): ResponseChunkTransformer { + return () => ({ + async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController) { + let toolCalls: FunctionCall[] = [] + if (chunk.candidates && chunk.candidates.length > 0) { + for (const candidate of chunk.candidates) { + if (candidate.content) { + candidate.content.parts?.forEach((part) => { + const text = part.text || '' + if (part.thought) { + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: text + }) + } else if (part.text) { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: text + }) + } else if (part.inlineData) { + controller.enqueue({ + type: ChunkType.IMAGE_COMPLETE, + image: { + type: 'base64', + images: [ + part.inlineData?.data?.startsWith('data:') + ? 
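+                          // inlineData may already be a full data URL; otherwise prefix the raw base64 with its mime type.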
part.inlineData?.data + : `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}` + ] + } + }) + } + }) + } + + if (candidate.finishReason) { + if (candidate.groundingMetadata) { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: candidate.groundingMetadata, + source: WebSearchSource.GEMINI + } + } as LLMWebSearchCompleteChunk) + } + if (chunk.functionCalls) { + toolCalls = toolCalls.concat(chunk.functionCalls) + } + controller.enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0, + completion_tokens: + (chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0), + total_tokens: chunk.usageMetadata?.totalTokenCount || 0 + } + } + }) + } + } + } + + if (toolCalls.length > 0) { + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: toolCalls + }) + } + } + }) + } + + public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Tool[] { + return mcpToolsToGeminiTools(mcpTools) + } + + public convertSdkToolCallToMcp(toolCall: GeminiSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { + return geminiFunctionCallToMcpTool(mcpTools, toolCall) + } + + public convertSdkToolCallToMcpToolResponse(toolCall: GeminiSdkToolCall, mcpTool: MCPTool): ToolCallResponse { + const parsedArgs = (() => { + try { + return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args + } catch { + return toolCall.args + } + })() + + return { + id: toolCall.id || nanoid(), + toolCallId: toolCall.id, + tool: mcpTool, + arguments: parsedArgs, + status: 'pending' + } as ToolCallResponse + } + + public convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): GeminiSdkMessageParam | undefined { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model)) + } else if ('toolCallId' in mcpToolResponse) { + return { + role: 'user', + parts: [ + { + functionResponse: { + id: mcpToolResponse.toolCallId, + name: mcpToolResponse.tool.id, + response: { + output: !resp.isError ? resp.content : undefined, + error: resp.isError ? 
resp.content : undefined + } + } + } + ] + } satisfies Content + } + return + } + + public buildSdkMessages( + currentReqMessages: Content[], + output: string, + toolResults: Content[], + toolCalls: FunctionCall[] + ): Content[] { + const parts: Part[] = [] + if (output) { + parts.push({ + text: output + }) + } + toolCalls.forEach((toolCall) => { + parts.push({ + functionCall: toolCall + }) + }) + parts.push( + ...toolResults + .map((ts) => ts.parts) + .flat() + .filter((p) => p !== undefined) + ) + + const userMessage: Content = { + role: 'user', + parts: parts + } + + return [...currentReqMessages, userMessage] + } + + override estimateMessageTokens(message: GeminiSdkMessageParam): number { + return ( + message.parts?.reduce((acc, part) => { + if (part.text) { + return acc + estimateTextTokens(part.text) + } + if (part.functionCall) { + return acc + estimateTextTokens(JSON.stringify(part.functionCall)) + } + if (part.functionResponse) { + return acc + estimateTextTokens(JSON.stringify(part.functionResponse.response)) + } + if (part.inlineData) { + return acc + estimateTextTokens(part.inlineData.data || '') + } + if (part.fileData) { + return acc + estimateTextTokens(part.fileData.fileUri || '') + } + return acc + }, 0) || 0 + ) + } + + public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] { + return sdkPayload.history || [] + } + + private async uploadFile(file: FileType): Promise { + return await this.sdkInstance!.files.upload({ + file: file.path, + config: { + mimeType: 'application/pdf', + name: file.id, + displayName: file.origin_name + } + }) + } + + private async base64File(file: FileType) { + const { data } = await window.api.file.base64File(file.id + file.ext) + return { + data, + mimeType: 'application/pdf' + } + } + + private async retrieveFile(file: FileType): Promise { + const cachedResponse = CacheService.get('gemini_file_list') + + if (cachedResponse) { + return this.processResponse(cachedResponse, file) + } + + const response = await this.sdkInstance!.files.list() + CacheService.set('gemini_file_list', response, 3000) + + return this.processResponse(response, file) + } + + private async processResponse(response: Pager, file: FileType) { + for await (const f of response) { + if (f.state === FileState.ACTIVE) { + if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) { + return f + } + } + } + + return undefined + } + + // @ts-ignore unused + private async listFiles(): Promise { + const files: File[] = [] + for await (const f of await this.sdkInstance!.files.list()) { + files.push(f) + } + return files + } + + // @ts-ignore unused + private async deleteFile(fileId: string) { + await this.sdkInstance!.files.delete({ name: fileId }) + } +} diff --git a/src/renderer/src/aiCore/clients/index.ts b/src/renderer/src/aiCore/clients/index.ts new file mode 100644 index 0000000000..ec7f9d9d7e --- /dev/null +++ b/src/renderer/src/aiCore/clients/index.ts @@ -0,0 +1,6 @@ +export * from './ApiClientFactory' +export * from './BaseApiClient' +export * from './types' + +// Export specific clients from subdirectories +export * from './openai/OpenAIApiClient' diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts new file mode 100644 index 0000000000..1a78aea0f8 --- /dev/null +++ b/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts @@ -0,0 +1,646 @@ +import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' +import Logger from 
'@renderer/config/logger' +import { + findTokenLimit, + getOpenAIWebSearchParams, + isReasoningModel, + isSupportedReasoningEffortGrokModel, + isSupportedReasoningEffortModel, + isSupportedReasoningEffortOpenAIModel, + isSupportedThinkingTokenClaudeModel, + isSupportedThinkingTokenGeminiModel, + isSupportedThinkingTokenModel, + isSupportedThinkingTokenQwenModel, + isVisionModel +} from '@renderer/config/models' +import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService' +import { estimateTextTokens } from '@renderer/services/TokenService' +// For Copilot token +import { + Assistant, + EFFORT_RATIO, + FileTypes, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse, + WebSearchSource +} from '@renderer/types' +import { ChunkType } from '@renderer/types/chunk' +import { Message } from '@renderer/types/newMessage' +import { + OpenAISdkMessageParam, + OpenAISdkParams, + OpenAISdkRawChunk, + OpenAISdkRawContentSource, + OpenAISdkRawOutput, + ReasoningEffortOptionalParams +} from '@renderer/types/sdk' +import { addImageFileToContents } from '@renderer/utils/formats' +import { + isEnabledToolUse, + mcpToolCallResponseToOpenAICompatibleMessage, + mcpToolsToOpenAIChatTools, + openAIToolsToMcpTool +} from '@renderer/utils/mcp-tools' +import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' +import { buildSystemPrompt } from '@renderer/utils/prompt' +import OpenAI, { AzureOpenAI } from 'openai' +import { ChatCompletionContentPart, ChatCompletionContentPartRefusal, ChatCompletionTool } from 'openai/resources' + +import { GenericChunk } from '../../middleware/schemas' +import { RequestTransformer, ResponseChunkTransformer, ResponseChunkTransformerContext } from '../types' +import { OpenAIBaseClient } from './OpenAIBaseClient' + +export class OpenAIAPIClient extends OpenAIBaseClient< + OpenAI | AzureOpenAI, + OpenAISdkParams, + OpenAISdkRawOutput, + OpenAISdkRawChunk, + OpenAISdkMessageParam, + OpenAI.Chat.Completions.ChatCompletionMessageToolCall, + ChatCompletionTool +> { + constructor(provider: Provider) { + super(provider) + } + + override async createCompletions( + payload: OpenAISdkParams, + options?: OpenAI.RequestOptions + ): Promise { + const sdk = await this.getSdkInstance() + // @ts-ignore - SDK参数可能有额外的字段 + return await sdk.chat.completions.create(payload, options) + } + + /** + * Get the reasoning effort for the assistant + * @param assistant - The assistant + * @param model - The model + * @returns The reasoning effort + */ + // Method for reasoning effort, moved from OpenAIProvider + override getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams { + if (this.provider.id === 'groq') { + return {} + } + + if (!isReasoningModel(model)) { + return {} + } + const reasoningEffort = assistant?.settings?.reasoning_effort + if (!reasoningEffort) { + if (isSupportedThinkingTokenQwenModel(model)) { + return { enable_thinking: false } + } + + if (isSupportedThinkingTokenClaudeModel(model)) { + return {} + } + + if (isSupportedThinkingTokenGeminiModel(model)) { + // openrouter没有提供一个不推理的选项,先隐藏 + if (this.provider.id === 'openrouter') { + return { reasoning: { max_tokens: 0, exclude: true } } + } + return { + reasoning_effort: 'none' + } + } + + return {} + } + const effortRatio = EFFORT_RATIO[reasoningEffort] + const budgetTokens = Math.floor( + (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min! 
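+        // Interpolates the thinking budget between the model's token limits: budget = (max - min) * effortRatio + min.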
+ ) + + // OpenRouter models + if (model.provider === 'openrouter') { + if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) { + return { + reasoning: { + effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort + } + } + } + } + + // Qwen models + if (isSupportedThinkingTokenQwenModel(model)) { + return { + enable_thinking: true, + thinking_budget: budgetTokens + } + } + + // Grok models + if (isSupportedReasoningEffortGrokModel(model)) { + return { + reasoning_effort: reasoningEffort + } + } + + // OpenAI models + if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) { + return { + reasoning_effort: reasoningEffort + } + } + + // Claude models + if (isSupportedThinkingTokenClaudeModel(model)) { + const maxTokens = assistant.settings?.maxTokens + return { + thinking: { + type: 'enabled', + budget_tokens: Math.floor( + Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)) + ) + } + } + } + + // Default case: no special thinking settings + return {} + } + + /** + * Check if the provider does not support files + * @returns True if the provider does not support files, false otherwise + */ + private get isNotSupportFiles() { + if (this.provider?.isNotSupportArrayContent) { + return true + } + + const providers = ['deepseek', 'baichuan', 'minimax', 'xirang'] + + return providers.includes(this.provider.id) + } + + /** + * Get the message parameter + * @param message - The message + * @param model - The model + * @returns The message parameter + */ + public async convertMessageToSdkParam(message: Message, model: Model): Promise { + const isVision = isVisionModel(model) + const content = await this.getMessageContent(message) + const fileBlocks = findFileBlocks(message) + const imageBlocks = findImageBlocks(message) + + if (fileBlocks.length === 0 && imageBlocks.length === 0) { + return { + role: message.role === 'system' ? 'user' : message.role, + content + } as OpenAISdkMessageParam + } + + // If the model does not support files, extract the file content + if (this.isNotSupportFiles) { + const fileContent = await this.extractFileContent(message) + + return { + role: message.role === 'system' ? 'user' : message.role, + content: content + '\n\n---\n\n' + fileContent + } as OpenAISdkMessageParam + } + + // If the model supports files, add the file content to the message + const parts: ChatCompletionContentPart[] = [] + + if (content) { + parts.push({ type: 'text', text: content }) + } + + for (const imageBlock of imageBlocks) { + if (isVision) { + if (imageBlock.file) { + const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) + parts.push({ type: 'image_url', image_url: { url: image.data } }) + } else if (imageBlock.url && imageBlock.url.startsWith('data:')) { + parts.push({ type: 'image_url', image_url: { url: imageBlock.url } }) + } + } + } + + for (const fileBlock of fileBlocks) { + const file = fileBlock.file + if (!file) { + continue + } + + if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { + const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() + parts.push({ + type: 'text', + text: file.origin_name + '\n' + fileContent + }) + } + } + + return { + role: message.role === 'system' ? 
'user' : message.role, + content: parts + } as OpenAISdkMessageParam + } + + public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ChatCompletionTool[] { + return mcpToolsToOpenAIChatTools(mcpTools) + } + + public convertSdkToolCallToMcp( + toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall, + mcpTools: MCPTool[] + ): MCPTool | undefined { + return openAIToolsToMcpTool(mcpTools, toolCall) + } + + public convertSdkToolCallToMcpToolResponse( + toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall, + mcpTool: MCPTool + ): ToolCallResponse { + let parsedArgs: any + try { + parsedArgs = JSON.parse(toolCall.function.arguments) + } catch { + parsedArgs = toolCall.function.arguments + } + return { + id: toolCall.id, + toolCallId: toolCall.id, + tool: mcpTool, + arguments: parsedArgs, + status: 'pending' + } as ToolCallResponse + } + + public convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): OpenAISdkMessageParam | undefined { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + // This case is for Anthropic/Claude like tool usage, OpenAI uses tool_call_id + // For OpenAI, we primarily expect toolCallId. This might need adjustment if mixing provider concepts. + return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model)) + } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) { + return { + role: 'tool', + tool_call_id: mcpToolResponse.toolCallId, + content: JSON.stringify(resp.content) + } as OpenAI.Chat.Completions.ChatCompletionToolMessageParam + } + return undefined + } + + public buildSdkMessages( + currentReqMessages: OpenAISdkMessageParam[], + output: string, + toolResults: OpenAISdkMessageParam[], + toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] + ): OpenAISdkMessageParam[] { + const assistantMessage: OpenAISdkMessageParam = { + role: 'assistant', + content: output, + tool_calls: toolCalls.length > 0 ? 
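+      // When the turn produced no tool calls, leave tool_calls undefined so the field is effectively omitted from the request payload.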
toolCalls : undefined + } + const newReqMessages = [...currentReqMessages, assistantMessage, ...toolResults] + return newReqMessages + } + + override estimateMessageTokens(message: OpenAISdkMessageParam): number { + let sum = 0 + if (typeof message.content === 'string') { + sum += estimateTextTokens(message.content) + } else if (Array.isArray(message.content)) { + sum += (message.content || []) + .map((part: ChatCompletionContentPart | ChatCompletionContentPartRefusal) => { + switch (part.type) { + case 'text': + return estimateTextTokens(part.text) + case 'image_url': + return estimateTextTokens(part.image_url.url) + case 'input_audio': + return estimateTextTokens(part.input_audio.data) + case 'file': + return estimateTextTokens(part.file.file_data || '') + default: + return 0 + } + }) + .reduce((acc, curr) => acc + curr, 0) + } + if ('tool_calls' in message && message.tool_calls) { + sum += message.tool_calls.reduce((acc, toolCall) => { + return acc + estimateTextTokens(JSON.stringify(toolCall.function.arguments)) + }, 0) + } + return sum + } + + public extractMessagesFromSdkPayload(sdkPayload: OpenAISdkParams): OpenAISdkMessageParam[] { + return sdkPayload.messages || [] + } + + getRequestTransformer(): RequestTransformer { + return { + transform: async ( + coreRequest, + assistant, + model, + isRecursiveCall, + recursiveSdkMessages + ): Promise<{ + payload: OpenAISdkParams + messages: OpenAISdkMessageParam[] + metadata: Record + }> => { + const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest + // 1. 处理系统消息 + let systemMessage = { role: 'system', content: assistant.prompt || '' } + + if (isSupportedReasoningEffortOpenAIModel(model)) { + systemMessage = { + role: 'developer', + content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}` + } + } + + if (model.id.includes('o1-mini') || model.id.includes('o1-preview')) { + systemMessage.role = 'assistant' + } + + // 2. 设置工具(必须在this.usesystemPromptForTools前面) + const { tools } = this.setupToolsConfig({ + mcpTools: mcpTools, + model, + enableToolUse: isEnabledToolUse(assistant) + }) + + if (this.useSystemPromptForTools) { + systemMessage.content = await buildSystemPrompt(systemMessage.content || '', mcpTools) + } + + // 3. 处理用户消息 + const userMessages: OpenAISdkMessageParam[] = [] + if (typeof messages === 'string') { + userMessages.push({ role: 'user', content: messages }) + } else { + const processedMessages = addImageFileToContents(messages) + for (const message of processedMessages) { + userMessages.push(await this.convertMessageToSdkParam(message, model)) + } + } + + const lastUserMsg = userMessages.findLast((m) => m.role === 'user') + if (lastUserMsg && isSupportedThinkingTokenQwenModel(model)) { + const postsuffix = '/no_think' + const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true + const currentContent = lastUserMsg.content + + lastUserMsg.content = processPostsuffixQwen3Model(currentContent, postsuffix, qwenThinkModeEnabled) as any + } + + // 4. 最终请求消息 + let reqMessages: OpenAISdkMessageParam[] + if (!systemMessage.content) { + reqMessages = [...userMessages] + } else { + reqMessages = [systemMessage, ...userMessages].filter(Boolean) as OpenAISdkMessageParam[] + } + + reqMessages = processReqMessages(model, reqMessages) + + // 5. 创建通用参数 + const commonParams = { + model: model.id, + messages: + isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 + ? 
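+            // Recursive calls (e.g. follow-up turns after tool execution) reuse the SDK messages accumulated so far instead of rebuilding them from the UI messages.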
recursiveSdkMessages + : reqMessages, + temperature: this.getTemperature(assistant, model), + top_p: this.getTopP(assistant, model), + max_tokens: maxTokens, + tools: tools.length > 0 ? tools : undefined, + service_tier: this.getServiceTier(model), + ...this.getProviderSpecificParameters(assistant, model), + ...this.getReasoningEffort(assistant, model), + ...getOpenAIWebSearchParams(model, enableWebSearch), + ...this.getCustomParameters(assistant) + } + + // Create the appropriate parameters object based on whether streaming is enabled + const sdkParams: OpenAISdkParams = streamOutput + ? { + ...commonParams, + stream: true + } + : { + ...commonParams, + stream: false + } + + const timeout = this.getTimeout(model) + + return { payload: sdkParams, messages: reqMessages, metadata: { timeout } } + } + } + } + + // 在RawSdkChunkToGenericChunkMiddleware中使用 + getResponseChunkTransformer = (): ResponseChunkTransformer => { + let hasBeenCollectedWebSearch = false + const collectWebSearchData = ( + chunk: OpenAISdkRawChunk, + contentSource: OpenAISdkRawContentSource, + context: ResponseChunkTransformerContext + ) => { + if (hasBeenCollectedWebSearch) { + return + } + // OpenAI annotations + // @ts-ignore - annotations may not be in standard type definitions + const annotations = contentSource.annotations || chunk.annotations + if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') { + hasBeenCollectedWebSearch = true + return { + results: annotations, + source: WebSearchSource.OPENAI + } + } + + // Grok citations + // @ts-ignore - citations may not be in standard type definitions + if (context.provider?.id === 'grok' && chunk.citations) { + hasBeenCollectedWebSearch = true + return { + // @ts-ignore - citations may not be in standard type definitions + results: chunk.citations, + source: WebSearchSource.GROK + } + } + + // Perplexity citations + // @ts-ignore - citations may not be in standard type definitions + if (context.provider?.id === 'perplexity' && chunk.citations && chunk.citations.length > 0) { + hasBeenCollectedWebSearch = true + return { + // @ts-ignore - citations may not be in standard type definitions + results: chunk.citations, + source: WebSearchSource.PERPLEXITY + } + } + + // OpenRouter citations + // @ts-ignore - citations may not be in standard type definitions + if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) { + hasBeenCollectedWebSearch = true + return { + // @ts-ignore - citations may not be in standard type definitions + results: chunk.citations, + source: WebSearchSource.OPENROUTER + } + } + + // Zhipu web search + // @ts-ignore - web_search may not be in standard type definitions + if (context.provider?.id === 'zhipu' && chunk.web_search) { + hasBeenCollectedWebSearch = true + return { + // @ts-ignore - web_search may not be in standard type definitions + results: chunk.web_search, + source: WebSearchSource.ZHIPU + } + } + + // Hunyuan web search + // @ts-ignore - search_info may not be in standard type definitions + if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) { + hasBeenCollectedWebSearch = true + return { + // @ts-ignore - search_info may not be in standard type definitions + results: chunk.search_info.search_results, + source: WebSearchSource.HUNYUAN + } + } + + // TODO: 放到AnthropicApiClient中 + // // Other providers... 
+ // // @ts-ignore - web_search may not be in standard type definitions + // if (chunk.web_search) { + // const sourceMap: Record = { + // openai: 'openai', + // anthropic: 'anthropic', + // qwenlm: 'qwen' + // } + // const source = sourceMap[context.provider?.id] || 'openai_response' + // return { + // results: chunk.web_search, + // source: source as const + // } + // } + + return null + } + const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = [] + return (context: ResponseChunkTransformerContext) => ({ + async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController) { + // 处理chunk + if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) { + const choice = chunk.choices[0] + + if (!choice) return + + // 对于流式响应,使用delta;对于非流式响应,使用message + const contentSource: OpenAISdkRawContentSource | null = + 'delta' in choice ? choice.delta : 'message' in choice ? choice.message : null + + if (!contentSource) return + + const webSearchData = collectWebSearchData(chunk, contentSource, context) + if (webSearchData) { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: webSearchData + }) + } + + // 处理推理内容 (e.g. from OpenRouter DeepSeek-R1) + // @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it + const reasoningText = contentSource.reasoning_content || contentSource.reasoning + if (reasoningText) { + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: reasoningText + }) + } + + // 处理文本内容 + if (contentSource.content) { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: contentSource.content + }) + } + + // 处理工具调用 + if (contentSource.tool_calls) { + for (const toolCall of contentSource.tool_calls) { + if ('index' in toolCall) { + const { id, index, function: fun } = toolCall + if (fun?.name) { + toolCalls[index] = { + id: id || '', + function: { + name: fun.name, + arguments: fun.arguments || '' + }, + type: 'function' + } + } else if (fun?.arguments) { + toolCalls[index].function.arguments += fun.arguments + } + } else { + toolCalls.push(toolCall) + } + } + } + + // 处理finish_reason,发送流结束信号 + if ('finish_reason' in choice && choice.finish_reason) { + Logger.debug(`[OpenAIApiClient] Stream finished with reason: ${choice.finish_reason}`) + if (toolCalls.length > 0) { + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: toolCalls + }) + } + const webSearchData = collectWebSearchData(chunk, contentSource, context) + if (webSearchData) { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: webSearchData + }) + } + controller.enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: chunk.usage?.prompt_tokens || 0, + completion_tokens: chunk.usage?.completion_tokens || 0, + total_tokens: (chunk.usage?.prompt_tokens || 0) + (chunk.usage?.completion_tokens || 0) + } + } + }) + } + } + } + }) + } +} diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts new file mode 100644 index 0000000000..a44f25def8 --- /dev/null +++ b/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts @@ -0,0 +1,258 @@ +import { + isClaudeReasoningModel, + isNotSupportTemperatureAndTopP, + isOpenAIReasoningModel, + isSupportedModel, + isSupportedReasoningEffortOpenAIModel +} from '@renderer/config/models' +import { getStoreSetting } from '@renderer/hooks/useSettings' +import { getAssistantSettings } from 
'@renderer/services/AssistantService' +import store from '@renderer/store' +import { SettingsState } from '@renderer/store/settings' +import { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types' +import { + OpenAIResponseSdkMessageParam, + OpenAIResponseSdkParams, + OpenAIResponseSdkRawChunk, + OpenAIResponseSdkRawOutput, + OpenAIResponseSdkTool, + OpenAIResponseSdkToolCall, + OpenAISdkMessageParam, + OpenAISdkParams, + OpenAISdkRawChunk, + OpenAISdkRawOutput, + ReasoningEffortOptionalParams +} from '@renderer/types/sdk' +import { formatApiHost } from '@renderer/utils/api' +import OpenAI, { AzureOpenAI } from 'openai' + +import { BaseApiClient } from '../BaseApiClient' + +/** + * 抽象的OpenAI基础客户端类,包含两个OpenAI客户端之间的共享功能 + */ +export abstract class OpenAIBaseClient< + TSdkInstance extends OpenAI | AzureOpenAI, + TSdkParams extends OpenAISdkParams | OpenAIResponseSdkParams, + TRawOutput extends OpenAISdkRawOutput | OpenAIResponseSdkRawOutput, + TRawChunk extends OpenAISdkRawChunk | OpenAIResponseSdkRawChunk, + TMessageParam extends OpenAISdkMessageParam | OpenAIResponseSdkMessageParam, + TToolCall extends OpenAI.Chat.Completions.ChatCompletionMessageToolCall | OpenAIResponseSdkToolCall, + TSdkSpecificTool extends OpenAI.Chat.Completions.ChatCompletionTool | OpenAIResponseSdkTool +> extends BaseApiClient { + constructor(provider: Provider) { + super(provider) + } + + // 仅适用于openai + override getBaseURL(): string { + const host = this.provider.apiHost + return formatApiHost(host) + } + + override async generateImage({ + model, + prompt, + negativePrompt, + imageSize, + batchSize, + seed, + numInferenceSteps, + guidanceScale, + signal, + promptEnhancement + }: GenerateImageParams): Promise { + const sdk = await this.getSdkInstance() + const response = (await sdk.request({ + method: 'post', + path: '/images/generations', + signal, + body: { + model, + prompt, + negative_prompt: negativePrompt, + image_size: imageSize, + batch_size: batchSize, + seed: seed ? parseInt(seed) : undefined, + num_inference_steps: numInferenceSteps, + guidance_scale: guidanceScale, + prompt_enhancement: promptEnhancement + } + })) as { data: Array<{ url: string }> } + + return response.data.map((item) => item.url) + } + + override async getEmbeddingDimensions(model: Model): Promise { + const sdk = await this.getSdkInstance() + try { + const data = await sdk.embeddings.create({ + model: model.id, + input: model?.provider === 'baidu-cloud' ? 
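+          // Minimal probe input used only to read back the embedding vector length; baidu-cloud receives it wrapped in an array.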
['hi'] : 'hi', + encoding_format: 'float' + }) + return data.data[0].embedding.length + } catch (e) { + return 0 + } + } + + override async listModels(): Promise { + try { + const sdk = await this.getSdkInstance() + const response = await sdk.models.list() + if (this.provider.id === 'github') { + // @ts-ignore key is not typed + return response?.body + .map((model) => ({ + id: model.name, + description: model.summary, + object: 'model', + owned_by: model.publisher + })) + .filter(isSupportedModel) + } + if (this.provider.id === 'together') { + // @ts-ignore key is not typed + return response?.body.map((model) => ({ + id: model.id, + description: model.display_name, + object: 'model', + owned_by: model.organization + })) + } + const models = response.data || [] + models.forEach((model) => { + model.id = model.id.trim() + }) + + return models.filter(isSupportedModel) + } catch (error) { + console.error('Error listing models:', error) + return [] + } + } + + override async getSdkInstance() { + if (this.sdkInstance) { + return this.sdkInstance + } + + let apiKeyForSdkInstance = this.provider.apiKey + + if (this.provider.id === 'copilot') { + const defaultHeaders = store.getState().copilot.defaultHeaders + const { token } = await window.api.copilot.getToken(defaultHeaders) + // this.provider.apiKey不允许修改 + // this.provider.apiKey = token + apiKeyForSdkInstance = token + } + + if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') { + this.sdkInstance = new AzureOpenAI({ + dangerouslyAllowBrowser: true, + apiKey: apiKeyForSdkInstance, + apiVersion: this.provider.apiVersion, + endpoint: this.provider.apiHost + }) as TSdkInstance + } else { + this.sdkInstance = new OpenAI({ + dangerouslyAllowBrowser: true, + apiKey: apiKeyForSdkInstance, + baseURL: this.getBaseURL(), + defaultHeaders: { + ...this.defaultHeaders(), + ...(this.provider.id === 'copilot' ? { 'editor-version': 'vscode/1.97.2' } : {}), + ...(this.provider.id === 'copilot' ? 
{ 'copilot-vision-request': 'true' } : {}) + } + }) as TSdkInstance + } + return this.sdkInstance + } + + override getTemperature(assistant: Assistant, model: Model): number | undefined { + if ( + isNotSupportTemperatureAndTopP(model) || + (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) + ) { + return undefined + } + return assistant.settings?.temperature + } + + override getTopP(assistant: Assistant, model: Model): number | undefined { + if ( + isNotSupportTemperatureAndTopP(model) || + (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) + ) { + return undefined + } + return assistant.settings?.topP + } + + /** + * Get the provider specific parameters for the assistant + * @param assistant - The assistant + * @param model - The model + * @returns The provider specific parameters + */ + protected getProviderSpecificParameters(assistant: Assistant, model: Model) { + const { maxTokens } = getAssistantSettings(assistant) + + if (this.provider.id === 'openrouter') { + if (model.id.includes('deepseek-r1')) { + return { + include_reasoning: true + } + } + } + + if (isOpenAIReasoningModel(model)) { + return { + max_tokens: undefined, + max_completion_tokens: maxTokens + } + } + + return {} + } + + /** + * Get the reasoning effort for the assistant + * @param assistant - The assistant + * @param model - The model + * @returns The reasoning effort + */ + protected getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams { + if (!isSupportedReasoningEffortOpenAIModel(model)) { + return {} + } + + const openAI = getStoreSetting('openAI') as SettingsState['openAI'] + const summaryText = openAI?.summaryText || 'off' + + let summary: string | undefined = undefined + + if (summaryText === 'off' || model.id.includes('o1-pro')) { + summary = undefined + } else { + summary = summaryText + } + + const reasoningEffort = assistant?.settings?.reasoning_effort + if (!reasoningEffort) { + return {} + } + + if (isSupportedReasoningEffortOpenAIModel(model)) { + return { + reasoning: { + effort: reasoningEffort as OpenAI.ReasoningEffort, + summary: summary + } as OpenAI.Reasoning + } + } + + return {} + } +} diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts new file mode 100644 index 0000000000..0fdd65f709 --- /dev/null +++ b/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts @@ -0,0 +1,532 @@ +import { GenericChunk } from '@renderer/aiCore/middleware/schemas' +import { + isOpenAIChatCompletionOnlyModel, + isSupportedReasoningEffortOpenAIModel, + isVisionModel +} from '@renderer/config/models' +import { estimateTextTokens } from '@renderer/services/TokenService' +import { + FileTypes, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse, + WebSearchSource +} from '@renderer/types' +import { ChunkType } from '@renderer/types/chunk' +import { Message } from '@renderer/types/newMessage' +import { + OpenAIResponseSdkMessageParam, + OpenAIResponseSdkParams, + OpenAIResponseSdkRawChunk, + OpenAIResponseSdkRawOutput, + OpenAIResponseSdkTool, + OpenAIResponseSdkToolCall +} from '@renderer/types/sdk' +import { addImageFileToContents } from '@renderer/utils/formats' +import { + isEnabledToolUse, + mcpToolCallResponseToOpenAIMessage, + mcpToolsToOpenAIResponseTools, + openAIToolsToMcpTool +} from '@renderer/utils/mcp-tools' +import { findFileBlocks, findImageBlocks } from 
'@renderer/utils/messageUtils/find' +import { buildSystemPrompt } from '@renderer/utils/prompt' +import { isEmpty } from 'lodash' +import OpenAI from 'openai' + +import { RequestTransformer, ResponseChunkTransformer } from '../types' +import { OpenAIAPIClient } from './OpenAIApiClient' +import { OpenAIBaseClient } from './OpenAIBaseClient' + +export class OpenAIResponseAPIClient extends OpenAIBaseClient< + OpenAI, + OpenAIResponseSdkParams, + OpenAIResponseSdkRawOutput, + OpenAIResponseSdkRawChunk, + OpenAIResponseSdkMessageParam, + OpenAIResponseSdkToolCall, + OpenAIResponseSdkTool +> { + private client: OpenAIAPIClient + constructor(provider: Provider) { + super(provider) + this.client = new OpenAIAPIClient(provider) + } + + /** + * 根据模型特征选择合适的客户端 + */ + public getClient(model: Model) { + if (isOpenAIChatCompletionOnlyModel(model)) { + return this.client + } else { + return this + } + } + + override async getSdkInstance() { + if (this.sdkInstance) { + return this.sdkInstance + } + + return new OpenAI({ + dangerouslyAllowBrowser: true, + apiKey: this.provider.apiKey, + baseURL: this.getBaseURL(), + defaultHeaders: { + ...this.defaultHeaders() + } + }) + } + + override async createCompletions( + payload: OpenAIResponseSdkParams, + options?: OpenAI.RequestOptions + ): Promise { + const sdk = await this.getSdkInstance() + return await sdk.responses.create(payload, options) + } + + public async convertMessageToSdkParam(message: Message, model: Model): Promise { + const isVision = isVisionModel(model) + const content = await this.getMessageContent(message) + const fileBlocks = findFileBlocks(message) + const imageBlocks = findImageBlocks(message) + + if (fileBlocks.length === 0 && imageBlocks.length === 0) { + if (message.role === 'assistant') { + return { + role: 'assistant', + content: content + } + } else { + return { + role: message.role === 'system' ? 'user' : message.role, + content: content ? [{ type: 'input_text', text: content }] : [] + } as OpenAI.Responses.EasyInputMessage + } + } + + const parts: OpenAI.Responses.ResponseInputContent[] = [] + if (content) { + parts.push({ + type: 'input_text', + text: content + }) + } + + for (const imageBlock of imageBlocks) { + if (isVision) { + if (imageBlock.file) { + const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) + parts.push({ + detail: 'auto', + type: 'input_image', + image_url: image.data as string + }) + } else if (imageBlock.url && imageBlock.url.startsWith('data:')) { + parts.push({ + detail: 'auto', + type: 'input_image', + image_url: imageBlock.url + }) + } + } + } + + for (const fileBlock of fileBlocks) { + const file = fileBlock.file + if (!file) continue + + if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { + const fileContent = (await window.api.file.read(file.id + file.ext)).trim() + parts.push({ + type: 'input_text', + text: file.origin_name + '\n' + fileContent + }) + } + } + + return { + role: message.role === 'system' ? 
'user' : message.role, + content: parts + } + } + + public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] { + return mcpToolsToOpenAIResponseTools(mcpTools) + } + + public convertSdkToolCallToMcp(toolCall: OpenAIResponseSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { + return openAIToolsToMcpTool(mcpTools, toolCall) + } + public convertSdkToolCallToMcpToolResponse(toolCall: OpenAIResponseSdkToolCall, mcpTool: MCPTool): ToolCallResponse { + const parsedArgs = (() => { + try { + return JSON.parse(toolCall.arguments) + } catch { + return toolCall.arguments + } + })() + + return { + id: toolCall.call_id, + toolCallId: toolCall.call_id, + tool: mcpTool, + arguments: parsedArgs, + status: 'pending' + } + } + + public convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): OpenAIResponseSdkMessageParam | undefined { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model)) + } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) { + return { + type: 'function_call_output', + call_id: mcpToolResponse.toolCallId, + output: JSON.stringify(resp.content) + } + } + return + } + + public buildSdkMessages( + currentReqMessages: OpenAIResponseSdkMessageParam[], + output: string, + toolResults: OpenAIResponseSdkMessageParam[], + toolCalls: OpenAIResponseSdkToolCall[] + ): OpenAIResponseSdkMessageParam[] { + const assistantMessage: OpenAIResponseSdkMessageParam = { + role: 'assistant', + content: [{ type: 'input_text', text: output }] + } + const newReqMessages = [...currentReqMessages, assistantMessage, ...(toolCalls || []), ...(toolResults || [])] + return newReqMessages + } + + override estimateMessageTokens(message: OpenAIResponseSdkMessageParam): number { + let sum = 0 + if ('content' in message) { + if (typeof message.content === 'string') { + sum += estimateTextTokens(message.content) + } else if (Array.isArray(message.content)) { + for (const part of message.content) { + switch (part.type) { + case 'input_text': + sum += estimateTextTokens(part.text) + break + case 'input_image': + sum += estimateTextTokens(part.image_url || '') + break + default: + break + } + } + } + } + switch (message.type) { + case 'function_call_output': + sum += estimateTextTokens(message.output) + break + case 'function_call': + sum += estimateTextTokens(message.arguments) + break + default: + break + } + return sum + } + + public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] { + if (typeof sdkPayload.input === 'string') { + return [{ role: 'user', content: sdkPayload.input }] + } + return sdkPayload.input + } + + getRequestTransformer(): RequestTransformer { + return { + transform: async ( + coreRequest, + assistant, + model, + isRecursiveCall, + recursiveSdkMessages + ): Promise<{ + payload: OpenAIResponseSdkParams + messages: OpenAIResponseSdkMessageParam[] + metadata: Record + }> => { + const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch, enableGenerateImage } = coreRequest + // 1. 
处理系统消息 + const systemMessage: OpenAI.Responses.EasyInputMessage = { + role: 'system', + content: [] + } + + const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = [] + const systemMessageInput: OpenAI.Responses.ResponseInputText = { + text: assistant.prompt || '', + type: 'input_text' + } + if (isSupportedReasoningEffortOpenAIModel(model)) { + systemMessage.role = 'developer' + } + + // 2. 设置工具 + let tools: OpenAI.Responses.Tool[] = [] + const { tools: extraTools } = this.setupToolsConfig({ + mcpTools: mcpTools, + model, + enableToolUse: isEnabledToolUse(assistant) + }) + + if (this.useSystemPromptForTools) { + systemMessageInput.text = await buildSystemPrompt(systemMessageInput.text || '', mcpTools) + } + systemMessageContent.push(systemMessageInput) + systemMessage.content = systemMessageContent + + // 3. 处理用户消息 + let userMessage: OpenAI.Responses.ResponseInputItem[] = [] + if (typeof messages === 'string') { + userMessage.push({ role: 'user', content: messages }) + } else { + const processedMessages = addImageFileToContents(messages) + for (const message of processedMessages) { + userMessage.push(await this.convertMessageToSdkParam(message, model)) + } + } + // FIXME: 最好还是直接使用previous_response_id来处理(或者在数据库中存储image_generation_call的id) + if (enableGenerateImage) { + const finalAssistantMessage = userMessage.findLast( + (m) => (m as OpenAI.Responses.EasyInputMessage).role === 'assistant' + ) as OpenAI.Responses.EasyInputMessage + const finalUserMessage = userMessage.pop() as OpenAI.Responses.EasyInputMessage + if ( + finalAssistantMessage && + Array.isArray(finalAssistantMessage.content) && + finalUserMessage && + Array.isArray(finalUserMessage.content) + ) { + finalAssistantMessage.content = [...finalAssistantMessage.content, ...finalUserMessage.content] + } + // 这里是故意将上条助手消息的内容(包含图片和文件)作为用户消息发送 + userMessage = [{ ...finalAssistantMessage, role: 'user' } as OpenAI.Responses.EasyInputMessage] + } + + // 4. 最终请求消息 + let reqMessages: OpenAI.Responses.ResponseInput + if (!systemMessage.content) { + reqMessages = [...userMessage] + } else { + reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[] + } + + if (enableWebSearch) { + tools.push({ + type: 'web_search_preview' + }) + } + + if (enableGenerateImage) { + tools.push({ + type: 'image_generation', + partial_images: streamOutput ? 2 : undefined + }) + } + + const toolChoices: OpenAI.Responses.ToolChoiceTypes = { + type: 'web_search_preview' + } + + tools = tools.concat(extraTools) + const commonParams = { + model: model.id, + input: + isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 + ? recursiveSdkMessages + : reqMessages, + temperature: this.getTemperature(assistant, model), + top_p: this.getTopP(assistant, model), + max_output_tokens: maxTokens, + stream: streamOutput, + tools: !isEmpty(tools) ? tools : undefined, + tool_choice: enableWebSearch ? toolChoices : undefined, + service_tier: this.getServiceTier(model), + ...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning), + ...this.getCustomParameters(assistant) + } + const sdkParams: OpenAIResponseSdkParams = streamOutput + ? 
{ + ...commonParams, + stream: true + } + : { + ...commonParams, + stream: false + } + const timeout = this.getTimeout(model) + return { payload: sdkParams, messages: reqMessages, metadata: { timeout } } + } + } + } + + getResponseChunkTransformer(): ResponseChunkTransformer { + const toolCalls: OpenAIResponseSdkToolCall[] = [] + const outputItems: OpenAI.Responses.ResponseOutputItem[] = [] + return () => ({ + async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController) { + // 处理chunk + if ('output' in chunk) { + for (const output of chunk.output) { + switch (output.type) { + case 'message': + if (output.content[0].type === 'output_text') { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: output.content[0].text + }) + if (output.content[0].annotations && output.content[0].annotations.length > 0) { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + source: WebSearchSource.OPENAI_RESPONSE, + results: output.content[0].annotations + } + }) + } + } + break + case 'reasoning': + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: output.summary.map((s) => s.text).join('\n') + }) + break + case 'function_call': + toolCalls.push(output) + break + case 'image_generation_call': + controller.enqueue({ + type: ChunkType.IMAGE_CREATED + }) + controller.enqueue({ + type: ChunkType.IMAGE_COMPLETE, + image: { + type: 'base64', + images: [`data:image/png;base64,${output.result}`] + } + }) + } + } + } else { + switch (chunk.type) { + case 'response.output_item.added': + if (chunk.item.type === 'function_call') { + outputItems.push(chunk.item) + } + break + case 'response.reasoning_summary_text.delta': + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: chunk.delta + }) + break + case 'response.image_generation_call.generating': + controller.enqueue({ + type: ChunkType.IMAGE_CREATED + }) + break + case 'response.image_generation_call.partial_image': + controller.enqueue({ + type: ChunkType.IMAGE_DELTA, + image: { + type: 'base64', + images: [`data:image/png;base64,${chunk.partial_image_b64}`] + } + }) + break + case 'response.image_generation_call.completed': + controller.enqueue({ + type: ChunkType.IMAGE_COMPLETE + }) + break + case 'response.output_text.delta': { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: chunk.delta + }) + break + } + case 'response.function_call_arguments.done': { + const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find( + (item) => item.id === chunk.item_id + ) + if (outputItem) { + if (outputItem.type === 'function_call') { + toolCalls.push({ + ...outputItem, + arguments: chunk.arguments + }) + } + } + break + } + case 'response.content_part.done': { + if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + source: WebSearchSource.OPENAI_RESPONSE, + results: chunk.part.annotations + } + }) + } + if (toolCalls.length > 0) { + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: toolCalls + }) + } + break + } + case 'response.completed': { + const completion_tokens = chunk.response.usage?.output_tokens || 0 + const total_tokens = chunk.response.usage?.total_tokens || 0 + controller.enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: chunk.response.usage?.input_tokens || 0, + completion_tokens: completion_tokens, + total_tokens: total_tokens + 
} + } + }) + break + } + case 'error': { + controller.enqueue({ + type: ChunkType.ERROR, + error: { + message: chunk.message, + code: chunk.code + } + }) + break + } + } + } + } + }) + } +} diff --git a/src/renderer/src/aiCore/clients/types.ts b/src/renderer/src/aiCore/clients/types.ts new file mode 100644 index 0000000000..84562a13e9 --- /dev/null +++ b/src/renderer/src/aiCore/clients/types.ts @@ -0,0 +1,129 @@ +import Anthropic from '@anthropic-ai/sdk' +import { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types' +import { Provider } from '@renderer/types' +import { + AnthropicSdkRawChunk, + OpenAISdkRawChunk, + SdkMessageParam, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' +import OpenAI from 'openai' + +import { CompletionsParams, GenericChunk } from '../middleware/schemas' + +/** + * 原始流监听器接口 + */ +export interface RawStreamListener { + onChunk?: (chunk: TRawChunk) => void + onStart?: () => void + onEnd?: () => void + onError?: (error: Error) => void +} + +/** + * OpenAI 专用的流监听器 + */ +export interface OpenAIStreamListener extends RawStreamListener { + onChoice?: (choice: OpenAI.Chat.Completions.ChatCompletionChunk.Choice) => void + onFinishReason?: (reason: string) => void +} + +/** + * Anthropic 专用的流监听器 + */ +export interface AnthropicStreamListener + extends RawStreamListener { + onContentBlock?: (contentBlock: Anthropic.Messages.ContentBlock) => void + onMessage?: (message: Anthropic.Messages.Message) => void +} + +/** + * 请求转换器接口 + */ +export interface RequestTransformer< + TSdkParams extends SdkParams = SdkParams, + TMessageParam extends SdkMessageParam = SdkMessageParam +> { + transform( + completionsParams: CompletionsParams, + assistant: Assistant, + model: Model, + isRecursiveCall?: boolean, + recursiveSdkMessages?: TMessageParam[] + ): Promise<{ + payload: TSdkParams + messages: TMessageParam[] + metadata?: Record + }> +} + +/** + * 响应块转换器接口 + */ +export type ResponseChunkTransformer = ( + context?: TContext +) => Transformer + +export interface ResponseChunkTransformerContext { + isStreaming: boolean + isEnabledToolCalling: boolean + isEnabledWebSearch: boolean + isEnabledReasoning: boolean + mcpTools: MCPTool[] + provider: Provider +} + +/** + * API客户端接口 + */ +export interface ApiClient< + TSdkInstance = any, + TSdkParams extends SdkParams = SdkParams, + TRawOutput extends SdkRawOutput = SdkRawOutput, + TRawChunk extends SdkRawChunk = SdkRawChunk, + TMessageParam extends SdkMessageParam = SdkMessageParam, + TToolCall extends SdkToolCall = SdkToolCall, + TSdkSpecificTool extends SdkTool = SdkTool +> { + provider: Provider + + // 核心方法 - 在中间件架构中,这个方法可能只是一个占位符 + // 实际的SDK调用由SdkCallMiddleware处理 + // completions(params: CompletionsParams): Promise + + createCompletions(payload: TSdkParams): Promise + + // SDK相关方法 + getSdkInstance(): Promise | TSdkInstance + getRequestTransformer(): RequestTransformer + getResponseChunkTransformer(): ResponseChunkTransformer + + // 原始流监听方法 + attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener): TRawOutput + + // 工具转换相关方法 (保持可选,因为不是所有Provider都支持工具) + convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[] + convertMcpToolResponseToSdkMessageParam?( + mcpToolResponse: MCPToolResponse, + resp: any, + model: Model + ): TMessageParam | undefined + convertSdkToolCallToMcp?(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined + convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse + + 
// 构建SDK特定的消息列表,用于工具调用后的递归调用 + buildSdkMessages( + currentReqMessages: TMessageParam[], + output: TRawOutput | string, + toolResults: TMessageParam[], + toolCalls?: TToolCall[] + ): TMessageParam[] + + // 从SDK载荷中提取消息数组(用于中间件中的类型安全访问) + extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[] +} diff --git a/src/renderer/src/aiCore/index.ts b/src/renderer/src/aiCore/index.ts new file mode 100644 index 0000000000..c7bffa17de --- /dev/null +++ b/src/renderer/src/aiCore/index.ts @@ -0,0 +1,130 @@ +import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory' +import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' +import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models' +import type { GenerateImageParams, Model, Provider } from '@renderer/types' +import { RequestOptions, SdkModel } from '@renderer/types/sdk' +import { isEnabledToolUse } from '@renderer/utils/mcp-tools' + +import { OpenAIAPIClient } from './clients' +import { AihubmixAPIClient } from './clients/AihubmixAPIClient' +import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient' +import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient' +import { CompletionsMiddlewareBuilder } from './middleware/builder' +import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware' +import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware' +import { applyCompletionsMiddlewares } from './middleware/composer' +import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware' +import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware' +import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware' +import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware' +import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware' +import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware' +import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware' +import { MiddlewareRegistry } from './middleware/register' +import { CompletionsParams, CompletionsResult } from './middleware/schemas' + +export default class AiProvider { + private apiClient: BaseApiClient + + constructor(provider: Provider) { + // Use the new ApiClientFactory to get a BaseApiClient instance + this.apiClient = ApiClientFactory.create(provider) + } + + public async completions(params: CompletionsParams, options?: RequestOptions): Promise { + // 1. 根据模型识别正确的客户端 + const model = params.assistant.model + if (!model) { + return Promise.reject(new Error('Model is required')) + } + + // 根据client类型选择合适的处理方式 + let client: BaseApiClient + + if (this.apiClient instanceof AihubmixAPIClient) { + // AihubmixAPIClient: 根据模型选择合适的子client + client = this.apiClient.getClientForModel(model) + if (client instanceof OpenAIResponseAPIClient) { + client = client.getClient(model) as BaseApiClient + } + } else if (this.apiClient instanceof OpenAIResponseAPIClient) { + // OpenAIResponseAPIClient: 根据模型特征选择API类型 + client = this.apiClient.getClient(model) as BaseApiClient + } else { + // 其他client直接使用 + client = this.apiClient + } + + // 2. 
构建中间件链 + const builder = CompletionsMiddlewareBuilder.withDefaults() + // images api + if (isDedicatedImageGenerationModel(model)) { + builder.clear() + builder + .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName]) + .add(MiddlewareRegistry[AbortHandlerMiddlewareName]) + .add(MiddlewareRegistry[ImageGenerationMiddlewareName]) + } else { + // Existing logic for other models + if (!params.enableReasoning) { + builder.remove(ThinkingTagExtractionMiddlewareName) + builder.remove(ThinkChunkMiddlewareName) + } + // 注意:用client判断会导致typescript类型收窄 + if (!(this.apiClient instanceof OpenAIAPIClient)) { + builder.remove(ThinkingTagExtractionMiddlewareName) + } + if (!(this.apiClient instanceof AnthropicAPIClient)) { + builder.remove(RawStreamListenerMiddlewareName) + } + if (!params.enableWebSearch) { + builder.remove(WebSearchMiddlewareName) + } + if (!params.mcpTools?.length) { + builder.remove(ToolUseExtractionMiddlewareName) + builder.remove(McpToolChunkMiddlewareName) + } + if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) { + builder.remove(ToolUseExtractionMiddlewareName) + } + if (params.callType !== 'chat') { + builder.remove(AbortHandlerMiddlewareName) + } + } + + const middlewares = builder.build() + + // 3. Create the wrapped SDK method with middlewares + const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares) + + // 4. Execute the wrapped method with the original params + return wrappedCompletionMethod(params, options) + } + + public async models(): Promise { + return this.apiClient.listModels() + } + + public async getEmbeddingDimensions(model: Model): Promise { + try { + // Use the SDK instance to test embedding capabilities + const dimensions = await this.apiClient.getEmbeddingDimensions(model) + return dimensions + } catch (error) { + console.error('Error getting embedding dimensions:', error) + return 0 + } + } + + public async generateImage(params: GenerateImageParams): Promise { + return this.apiClient.generateImage(params) + } + + public getBaseURL(): string { + return this.apiClient.getBaseURL() + } + + public getApiKey(): string { + return this.apiClient.getApiKey() + } +} diff --git a/src/renderer/src/aiCore/middleware/BUILDER_USAGE.md b/src/renderer/src/aiCore/middleware/BUILDER_USAGE.md new file mode 100644 index 0000000000..27d9e32136 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/BUILDER_USAGE.md @@ -0,0 +1,182 @@ +# MiddlewareBuilder 使用指南 + +`MiddlewareBuilder` 是一个用于动态构建和管理中间件链的工具,提供灵活的中间件组织和配置能力。 + +## 主要特性 + +### 1. 统一的中间件命名 + +所有中间件都通过导出的 `MIDDLEWARE_NAME` 常量标识: + +```typescript +// 中间件文件示例 +export const MIDDLEWARE_NAME = 'SdkCallMiddleware' +export const SdkCallMiddleware: CompletionsMiddleware = ... +``` + +### 2. NamedMiddleware 接口 + +中间件使用统一的 `NamedMiddleware` 接口格式: + +```typescript +interface NamedMiddleware { + name: string + middleware: TMiddleware +} +``` + +### 3. 中间件注册表 + +通过 `MiddlewareRegistry` 集中管理所有可用中间件: + +```typescript +import { MiddlewareRegistry } from './register' + +// 通过名称获取中间件 +const sdkCallMiddleware = MiddlewareRegistry['SdkCallMiddleware'] +``` + +## 基本用法 + +### 1. 使用默认中间件链 + +```typescript +import { CompletionsMiddlewareBuilder } from './builder' + +const builder = CompletionsMiddlewareBuilder.withDefaults() +const middlewares = builder.build() +``` + +### 2. 
自定义中间件链 + +```typescript +import { createCompletionsBuilder, MiddlewareRegistry } from './builder' + +const builder = createCompletionsBuilder([ + MiddlewareRegistry['AbortHandlerMiddleware'], + MiddlewareRegistry['TextChunkMiddleware'] +]) + +const middlewares = builder.build() +``` + +### 3. 动态调整中间件链 + +```typescript +const builder = CompletionsMiddlewareBuilder.withDefaults() + +// 根据条件添加、移除、替换中间件 +if (needsLogging) { + builder.prepend(MiddlewareRegistry['GenericLoggingMiddleware']) +} + +if (disableTools) { + builder.remove('McpToolChunkMiddleware') +} + +if (customThinking) { + builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware) +} + +const middlewares = builder.build() +``` + +### 4. 链式操作 + +```typescript +const middlewares = CompletionsMiddlewareBuilder.withDefaults() + .add(MiddlewareRegistry['CustomMiddleware']) + .insertBefore('SdkCallMiddleware', MiddlewareRegistry['SecurityCheckMiddleware']) + .remove('WebSearchMiddleware') + .build() +``` + +## API 参考 + +### CompletionsMiddlewareBuilder + +**静态方法:** + +- `static withDefaults()`: 创建带有默认中间件链的构建器 + +**实例方法:** + +- `add(middleware: NamedMiddleware)`: 在链末尾添加中间件 +- `prepend(middleware: NamedMiddleware)`: 在链开头添加中间件 +- `insertAfter(targetName: string, middleware: NamedMiddleware)`: 在指定中间件后插入 +- `insertBefore(targetName: string, middleware: NamedMiddleware)`: 在指定中间件前插入 +- `replace(targetName: string, middleware: NamedMiddleware)`: 替换指定中间件 +- `remove(targetName: string)`: 移除指定中间件 +- `has(name: string)`: 检查是否包含指定中间件 +- `build()`: 构建最终的中间件数组 +- `getChain()`: 获取当前链(包含名称信息) +- `clear()`: 清空中间件链 +- `execute(context, params, middlewareExecutor)`: 直接执行构建好的中间件链 + +### 工厂函数 + +- `createCompletionsBuilder(baseChain?)`: 创建 Completions 中间件构建器 +- `createMethodBuilder(baseChain?)`: 创建通用方法中间件构建器 +- `addMiddlewareName(middleware, name)`: 为中间件添加名称属性的辅助函数 + +### 中间件注册表 + +- `MiddlewareRegistry`: 所有注册中间件的集中访问点 +- `getMiddleware(name)`: 根据名称获取中间件 +- `getRegisteredMiddlewareNames()`: 获取所有注册的中间件名称 +- `DefaultCompletionsNamedMiddlewares`: 默认的 Completions 中间件链(NamedMiddleware 格式) + +## 类型安全 + +构建器提供完整的 TypeScript 类型支持: + +- `CompletionsMiddlewareBuilder` 专门用于 `CompletionsMiddleware` 类型 +- `MethodMiddlewareBuilder` 用于通用的 `MethodMiddleware` 类型 +- 所有中间件操作都基于 `NamedMiddleware` 接口 + +## 默认中间件链 + +默认的 Completions 中间件执行顺序: + +1. `FinalChunkConsumerMiddleware` - 最终消费者 +2. `TransformCoreToSdkParamsMiddleware` - 参数转换 +3. `AbortHandlerMiddleware` - 中止处理 +4. `McpToolChunkMiddleware` - 工具处理 +5. `WebSearchMiddleware` - Web搜索处理 +6. `TextChunkMiddleware` - 文本处理 +7. `ThinkingTagExtractionMiddleware` - 思考标签提取处理 +8. `ThinkChunkMiddleware` - 思考处理 +9. `ResponseTransformMiddleware` - 响应转换 +10. `StreamAdapterMiddleware` - 流适配器 +11. `SdkCallMiddleware` - SDK调用 + +## 在 AiProvider 中的使用 + +```typescript +export default class AiProvider { + public async completions(params: CompletionsParams): Promise { + // 1. 构建中间件链 + const builder = CompletionsMiddlewareBuilder.withDefaults() + + // 2. 根据参数动态调整 + if (params.enableCustomFeature) { + builder.insertAfter('StreamAdapterMiddleware', customFeatureMiddleware) + } + + // 3. 应用中间件 + const middlewares = builder.build() + const wrappedMethod = applyCompletionsMiddlewares(this.apiClient, this.apiClient.createCompletions, middlewares) + + return wrappedMethod(params) + } +} +``` + +## 注意事项 + +1. **类型兼容性**:`MethodMiddleware` 和 `CompletionsMiddleware` 不兼容,需要使用对应的构建器 +2. **中间件名称**:所有中间件必须导出 `MIDDLEWARE_NAME` 常量用于标识 +3. **注册表管理**:新增中间件需要在 `register.ts` 中注册 +4. 
**默认链**:默认链通过 `DefaultCompletionsNamedMiddlewares` 提供,支持延迟加载避免循环依赖 + +这种设计使得中间件链的构建既灵活又类型安全,同时保持了简洁的 API 接口。 diff --git a/src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md b/src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md new file mode 100644 index 0000000000..6437282ff2 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md @@ -0,0 +1,175 @@ +# Cherry Studio 中间件规范 + +本文档定义了 Cherry Studio `aiCore` 模块中中间件的设计、实现和使用规范。目标是建立一个灵活、可维护且易于扩展的中间件系统。 + +## 1. 核心概念 + +### 1.1. 中间件 (Middleware) + +中间件是一个函数或对象,它在 AI 请求的处理流程中的特定阶段执行,可以访问和修改请求上下文 (`AiProviderMiddlewareContext`)、请求参数 (`Params`),并控制是否将请求传递给下一个中间件或终止流程。 + +每个中间件应该专注于一个单一的横切关注点,例如日志记录、错误处理、流适配、特性解析等。 + +### 1.2. `AiProviderMiddlewareContext` (上下文对象) + +这是一个在整个中间件链执行过程中传递的对象,包含以下核心信息: + +- `_apiClientInstance: ApiClient`: 当前选定的、已实例化的 AI Provider 客户端。 +- `_coreRequest: CoreRequestType`: 标准化的内部核心请求对象。 +- `resolvePromise: (value: AggregatedResultType) => void`: 用于在整个操作成功完成时解析 `AiCoreService` 返回的 Promise。 +- `rejectPromise: (reason?: any) => void`: 用于在发生错误时拒绝 `AiCoreService` 返回的 Promise。 +- `onChunk?: (chunk: Chunk) => void`: 应用层提供的流式数据块回调。 +- `abortController?: AbortController`: 用于中止请求的控制器。 +- 其他中间件可能读写的、与当前请求相关的动态数据。 + +### 1.3. `MiddlewareName` (中间件名称) + +为了方便动态操作(如插入、替换、移除)中间件,每个重要的、可能被其他逻辑引用的中间件都应该有一个唯一的、可识别的名称。推荐使用 TypeScript 的 `enum` 来定义: + +```typescript +// example +export enum MiddlewareName { + LOGGING_START = 'LoggingStartMiddleware', + LOGGING_END = 'LoggingEndMiddleware', + ERROR_HANDLING = 'ErrorHandlingMiddleware', + ABORT_HANDLER = 'AbortHandlerMiddleware', + // Core Flow + TRANSFORM_CORE_TO_SDK_PARAMS = 'TransformCoreToSdkParamsMiddleware', + REQUEST_EXECUTION = 'RequestExecutionMiddleware', + STREAM_ADAPTER = 'StreamAdapterMiddleware', + RAW_SDK_CHUNK_TO_APP_CHUNK = 'RawSdkChunkToAppChunkMiddleware', + // Features + THINKING_TAG_EXTRACTION = 'ThinkingTagExtractionMiddleware', + TOOL_USE_TAG_EXTRACTION = 'ToolUseTagExtractionMiddleware', + MCP_TOOL_HANDLER = 'McpToolHandlerMiddleware', + // Finalization + FINAL_CHUNK_CONSUMER = 'FinalChunkConsumerAndNotifierMiddleware' + // Add more as needed +} +``` + +中间件实例需要某种方式暴露其 `MiddlewareName`,例如通过一个 `name` 属性。 + +### 1.4. 中间件执行结构 + +我们采用一种灵活的中间件执行结构。一个中间件通常是一个函数,它接收 `Context`、`Params`,以及一个 `next` 函数(用于调用链中的下一个中间件)。 + +```typescript +// 简化形式的中间件函数签名 +type MiddlewareFunction = ( + context: AiProviderMiddlewareContext, + params: any, // e.g., CompletionsParams + next: () => Promise // next 通常返回 Promise 以支持异步操作 +) => Promise // 中间件自身也可能返回 Promise + +// 或者更经典的 Koa/Express 风格 (三段式) +// type MiddlewareFactory = (api?: MiddlewareApi) => +// (nextMiddleware: (ctx: AiProviderMiddlewareContext, params: any) => Promise) => +// (context: AiProviderMiddlewareContext, params: any) => Promise; +// 当前设计更倾向于上述简化的 MiddlewareFunction,由 MiddlewareExecutor 负责 next 的编排。 +``` + +`MiddlewareExecutor` (或 `applyMiddlewares`) 会负责管理 `next` 的调用。 + +## 2. `MiddlewareBuilder` (通用中间件构建器) + +为了动态构建和管理中间件链,我们引入一个通用的 `MiddlewareBuilder` 类。 + +### 2.1. 设计理念 + +`MiddlewareBuilder` 提供了一个流式 API,用于以声明式的方式构建中间件链。它允许从一个基础链开始,然后根据特定条件添加、插入、替换或移除中间件。 + +### 2.2. 
API 概览
+
+```typescript
+class MiddlewareBuilder {
+  constructor(baseChain?: Middleware[])
+
+  add(middleware: Middleware): this
+  prepend(middleware: Middleware): this
+  insertAfter(targetName: MiddlewareName, middlewareToInsert: Middleware): this
+  insertBefore(targetName: MiddlewareName, middlewareToInsert: Middleware): this
+  replace(targetName: MiddlewareName, newMiddleware: Middleware): this
+  remove(targetName: MiddlewareName): this
+
+  build(): Middleware[] // 返回构建好的中间件数组
+
+  // 可选:直接执行链
+  execute(
+    context: AiProviderMiddlewareContext,
+    params: any,
+    middlewareExecutor: (chain: Middleware[], context: AiProviderMiddlewareContext, params: any) => void
+  ): void
+}
+```
+
+### 2.3. 使用示例
+
+```typescript
+// 1. 定义一些中间件实例 (假设它们有 .name 属性)
+const loggingStart = { name: MiddlewareName.LOGGING_START, fn: loggingStartFn }
+const requestExec = { name: MiddlewareName.REQUEST_EXECUTION, fn: requestExecFn }
+const streamAdapter = { name: MiddlewareName.STREAM_ADAPTER, fn: streamAdapterFn }
+const customFeature = { name: MiddlewareName.CUSTOM_FEATURE, fn: customFeatureFn } // 假设自定义
+
+// 2. 定义一个基础链 (可选)
+const BASE_CHAIN: Middleware[] = [loggingStart, requestExec, streamAdapter]
+
+// 3. 使用 MiddlewareBuilder
+const builder = new MiddlewareBuilder(BASE_CHAIN)
+
+if (params.needsCustomFeature) {
+  builder.insertAfter(MiddlewareName.STREAM_ADAPTER, customFeature)
+}
+
+if (params.isHighSecurityContext) {
+  builder.insertBefore(MiddlewareName.REQUEST_EXECUTION, highSecurityCheckMiddleware)
+}
+
+if (params.overrideLogging) {
+  builder.replace(MiddlewareName.LOGGING_START, newSpecialLoggingMiddleware)
+}
+
+// 4. 获取最终链
+const finalChain = builder.build()
+
+// 5. 执行 (通过外部执行器)
+// middlewareExecutor(finalChain, context, params);
+// 或者 builder.execute(context, params, middlewareExecutor);
+```
+
+## 3. `MiddlewareExecutor` / `applyMiddlewares` (中间件执行器)
+
+这是负责接收 `MiddlewareBuilder` 构建的中间件链并实际执行它们的组件。
+
+### 3.1. 职责
+
+- 接收 `Middleware[]`, `AiProviderMiddlewareContext`, `Params`。
+- 按顺序迭代中间件。
+- 为每个中间件提供正确的 `next` 函数,该函数在被调用时会执行链中的下一个中间件。
+- 处理中间件执行过程中的Promise(如果中间件是异步的)。
+- 基础的错误捕获(具体错误处理应由链内的 `ErrorHandlingMiddleware` 负责)。
+
+## 4. 在 `AiCoreService` 中使用
+
+`AiCoreService` 中的每个核心业务方法 (如 `executeCompletions`) 将负责:
+
+1. 准备基础数据:实例化 `ApiClient`,转换 `Params` 为 `CoreRequest`。
+2. 实例化 `MiddlewareBuilder`,可能会传入一个特定于该业务方法的基础中间件链。
+3. 根据 `Params` 和 `CoreRequest` 中的条件,调用 `MiddlewareBuilder` 的方法来动态调整中间件链。
+4. 调用 `MiddlewareBuilder.build()` 获取最终的中间件链。
+5. 创建完整的 `AiProviderMiddlewareContext` (包含 `resolvePromise`, `rejectPromise` 等)。
+6. 调用 `MiddlewareExecutor` (或 `applyMiddlewares`) 来执行构建好的链。
+
+## 5. 组合功能
+
+对于组合功能(例如 "Completions then Translate"):
+
+- 不推荐创建一个单一、庞大的 `MiddlewareBuilder` 来处理整个组合流程。
+- 推荐在 `AiCoreService` 中创建一个新的方法,该方法按顺序 `await` 调用底层的原子 `AiCoreService` 方法(例如,先 `await this.executeCompletions(...)`,然后用其结果 `await this.translateText(...)`)。
+- 每个被调用的原子方法内部会使用其自身的 `MiddlewareBuilder` 实例来构建和执行其特定阶段的中间件链。
+- 这种方式最大化了复用,并保持了各部分职责的清晰。
+
+## 6. 
中间件命名和发现 + +为中间件赋予唯一的 `MiddlewareName` 对于 `MiddlewareBuilder` 的 `insertAfter`, `insertBefore`, `replace`, `remove` 等操作至关重要。确保中间件实例能够以某种方式暴露其名称(例如,一个 `name` 属性)。 diff --git a/src/renderer/src/aiCore/middleware/builder.ts b/src/renderer/src/aiCore/middleware/builder.ts new file mode 100644 index 0000000000..e76b59c2bd --- /dev/null +++ b/src/renderer/src/aiCore/middleware/builder.ts @@ -0,0 +1,241 @@ +import { DefaultCompletionsNamedMiddlewares } from './register' +import { BaseContext, CompletionsMiddleware, MethodMiddleware } from './types' + +/** + * 带有名称标识的中间件接口 + */ +export interface NamedMiddleware { + name: string + middleware: TMiddleware +} + +/** + * 中间件执行器函数类型 + */ +export type MiddlewareExecutor = ( + chain: any[], + context: TContext, + params: any +) => Promise + +/** + * 通用中间件构建器类 + * 提供流式 API 用于动态构建和管理中间件链 + * + * 注意:所有中间件都通过 MiddlewareRegistry 管理,使用 NamedMiddleware 格式 + */ +export class MiddlewareBuilder { + private middlewares: NamedMiddleware[] + + /** + * 构造函数 + * @param baseChain - 可选的基础中间件链(NamedMiddleware 格式) + */ + constructor(baseChain?: NamedMiddleware[]) { + this.middlewares = baseChain ? [...baseChain] : [] + } + + /** + * 在链的末尾添加中间件 + * @param middleware - 要添加的具名中间件 + * @returns this,支持链式调用 + */ + add(middleware: NamedMiddleware): this { + this.middlewares.push(middleware) + return this + } + + /** + * 在链的开头添加中间件 + * @param middleware - 要添加的具名中间件 + * @returns this,支持链式调用 + */ + prepend(middleware: NamedMiddleware): this { + this.middlewares.unshift(middleware) + return this + } + + /** + * 在指定中间件之后插入新中间件 + * @param targetName - 目标中间件名称 + * @param middlewareToInsert - 要插入的具名中间件 + * @returns this,支持链式调用 + */ + insertAfter(targetName: string, middlewareToInsert: NamedMiddleware): this { + const index = this.findMiddlewareIndex(targetName) + if (index !== -1) { + this.middlewares.splice(index + 1, 0, middlewareToInsert) + } else { + console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`) + } + return this + } + + /** + * 在指定中间件之前插入新中间件 + * @param targetName - 目标中间件名称 + * @param middlewareToInsert - 要插入的具名中间件 + * @returns this,支持链式调用 + */ + insertBefore(targetName: string, middlewareToInsert: NamedMiddleware): this { + const index = this.findMiddlewareIndex(targetName) + if (index !== -1) { + this.middlewares.splice(index, 0, middlewareToInsert) + } else { + console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`) + } + return this + } + + /** + * 替换指定的中间件 + * @param targetName - 要替换的中间件名称 + * @param newMiddleware - 新的具名中间件 + * @returns this,支持链式调用 + */ + replace(targetName: string, newMiddleware: NamedMiddleware): this { + const index = this.findMiddlewareIndex(targetName) + if (index !== -1) { + this.middlewares[index] = newMiddleware + } else { + console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法替换`) + } + return this + } + + /** + * 移除指定的中间件 + * @param targetName - 要移除的中间件名称 + * @returns this,支持链式调用 + */ + remove(targetName: string): this { + const index = this.findMiddlewareIndex(targetName) + if (index !== -1) { + this.middlewares.splice(index, 1) + } + return this + } + + /** + * 构建最终的中间件数组 + * @returns 构建好的中间件数组 + */ + build(): TMiddleware[] { + return this.middlewares.map((item) => item.middleware) + } + + /** + * 获取当前中间件链的副本(包含名称信息) + * @returns 当前中间件链的副本 + */ + getChain(): NamedMiddleware[] { + return [...this.middlewares] + } + + /** + * 检查是否包含指定名称的中间件 + * @param name - 中间件名称 + * @returns 是否包含该中间件 + */ + has(name: string): boolean { + return this.findMiddlewareIndex(name) !== -1 + } + + /** + * 获取中间件链的长度 + * @returns 中间件数量 
+ */ + get length(): number { + return this.middlewares.length + } + + /** + * 清空中间件链 + * @returns this,支持链式调用 + */ + clear(): this { + this.middlewares = [] + return this + } + + /** + * 直接执行构建好的中间件链 + * @param context - 中间件上下文 + * @param params - 参数 + * @param middlewareExecutor - 中间件执行器 + * @returns 执行结果 + */ + execute( + context: TContext, + params: any, + middlewareExecutor: MiddlewareExecutor + ): Promise { + const chain = this.build() + return middlewareExecutor(chain, context, params) + } + + /** + * 查找中间件在链中的索引 + * @param name - 中间件名称 + * @returns 索引,如果未找到返回 -1 + */ + private findMiddlewareIndex(name: string): number { + return this.middlewares.findIndex((item) => item.name === name) + } +} + +/** + * Completions 中间件构建器 + */ +export class CompletionsMiddlewareBuilder extends MiddlewareBuilder { + constructor(baseChain?: NamedMiddleware[]) { + super(baseChain) + } + + /** + * 使用默认的 Completions 中间件链 + * @returns CompletionsMiddlewareBuilder 实例 + */ + static withDefaults(): CompletionsMiddlewareBuilder { + return new CompletionsMiddlewareBuilder(DefaultCompletionsNamedMiddlewares) + } +} + +/** + * 通用方法中间件构建器 + */ +export class MethodMiddlewareBuilder extends MiddlewareBuilder { + constructor(baseChain?: NamedMiddleware[]) { + super(baseChain) + } +} + +// 便捷的工厂函数 + +/** + * 创建 Completions 中间件构建器 + * @param baseChain - 可选的基础链 + * @returns Completions 中间件构建器实例 + */ +export function createCompletionsBuilder( + baseChain?: NamedMiddleware[] +): CompletionsMiddlewareBuilder { + return new CompletionsMiddlewareBuilder(baseChain) +} + +/** + * 创建通用方法中间件构建器 + * @param baseChain - 可选的基础链 + * @returns 通用方法中间件构建器实例 + */ +export function createMethodBuilder(baseChain?: NamedMiddleware[]): MethodMiddlewareBuilder { + return new MethodMiddlewareBuilder(baseChain) +} + +/** + * 为中间件添加名称属性的辅助函数 + * 可以用于给现有的中间件添加名称属性 + */ +export function addMiddlewareName(middleware: T, name: string): T & { MIDDLEWARE_NAME: string } { + return Object.assign(middleware, { MIDDLEWARE_NAME: name }) +} diff --git a/src/renderer/src/aiCore/middleware/common/AbortHandlerMiddleware.ts b/src/renderer/src/aiCore/middleware/common/AbortHandlerMiddleware.ts new file mode 100644 index 0000000000..7186cec12f --- /dev/null +++ b/src/renderer/src/aiCore/middleware/common/AbortHandlerMiddleware.ts @@ -0,0 +1,106 @@ +import { Chunk, ChunkType, ErrorChunk } from '@renderer/types/chunk' +import { addAbortController, removeAbortController } from '@renderer/utils/abortController' + +import { CompletionsParams, CompletionsResult } from '../schemas' +import type { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'AbortHandlerMiddleware' + +export const AbortHandlerMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + const isRecursiveCall = ctx._internal?.toolProcessingState?.isRecursiveCall || false + + // 在递归调用中,跳过 AbortController 的创建,直接使用已有的 + if (isRecursiveCall) { + const result = await next(ctx, params) + return result + } + + // 获取当前消息的ID用于abort管理 + // 优先使用处理过的消息,如果没有则使用原始消息 + let messageId: string | undefined + + if (typeof params.messages === 'string') { + messageId = `message-${Date.now()}-${Math.random().toString(36).substring(2, 9)}` + } else { + const processedMessages = params.messages + const lastUserMessage = processedMessages.findLast((m) => m.role === 'user') + messageId = lastUserMessage?.id + } + + if (!messageId) { + console.warn(`[${MIDDLEWARE_NAME}] No messageId found, abort functionality will 
not be available.`) + return next(ctx, params) + } + + const abortController = new AbortController() + const abortFn = (): void => abortController.abort() + + addAbortController(messageId, abortFn) + + let abortSignal: AbortSignal | null = abortController.signal + + const cleanup = (): void => { + removeAbortController(messageId as string, abortFn) + if (ctx._internal?.flowControl) { + ctx._internal.flowControl.abortController = undefined + ctx._internal.flowControl.abortSignal = undefined + ctx._internal.flowControl.cleanup = undefined + } + abortSignal = null + } + + // 将controller添加到_internal中的flowControl状态 + if (!ctx._internal.flowControl) { + ctx._internal.flowControl = {} + } + ctx._internal.flowControl.abortController = abortController + ctx._internal.flowControl.abortSignal = abortSignal + ctx._internal.flowControl.cleanup = cleanup + + const result = await next(ctx, params) + + const error = new DOMException('Request was aborted', 'AbortError') + + const streamWithAbortHandler = (result.stream as ReadableStream).pipeThrough( + new TransformStream({ + transform(chunk, controller) { + // 检查 abort 状态 + if (abortSignal?.aborted) { + // 转换为 ErrorChunk + const errorChunk: ErrorChunk = { + type: ChunkType.ERROR, + error + } + + controller.enqueue(errorChunk) + cleanup() + return + } + + // 正常传递 chunk + controller.enqueue(chunk) + }, + + flush(controller) { + // 在流结束时再次检查 abort 状态 + if (abortSignal?.aborted) { + const errorChunk: ErrorChunk = { + type: ChunkType.ERROR, + error + } + controller.enqueue(errorChunk) + } + // 在流完全处理完成后清理 AbortController + cleanup() + } + }) + ) + + return { + ...result, + stream: streamWithAbortHandler + } + } diff --git a/src/renderer/src/aiCore/middleware/common/ErrorHandlerMiddleware.ts b/src/renderer/src/aiCore/middleware/common/ErrorHandlerMiddleware.ts new file mode 100644 index 0000000000..2dd5aa9833 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/common/ErrorHandlerMiddleware.ts @@ -0,0 +1,60 @@ +import { Chunk } from '@renderer/types/chunk' +import { isAbortError } from '@renderer/utils/error' + +import { CompletionsResult } from '../schemas' +import { CompletionsContext } from '../types' +import { createErrorChunk } from '../utils' + +export const MIDDLEWARE_NAME = 'ErrorHandlerMiddleware' + +/** + * 创建一个错误处理中间件。 + * + * 这是一个高阶函数,它接收配置并返回一个标准的中间件。 + * 它的主要职责是捕获下游中间件或API调用中发生的任何错误。 + * + * @param config - 中间件的配置。 + * @returns 一个配置好的CompletionsMiddleware。 + */ +export const ErrorHandlerMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params): Promise => { + const { shouldThrow } = params + + try { + // 尝试执行下一个中间件 + return await next(ctx, params) + } catch (error: any) { + let errorStream: ReadableStream | undefined + // 有些sdk的abort error 是直接抛出的 + if (!isAbortError(error)) { + // 1. 使用通用的工具函数将错误解析为标准格式 + const errorChunk = createErrorChunk(error) + // 2. 调用从外部传入的 onError 回调 + if (params.onError) { + params.onError(error) + } + + // 3. 
根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递 + if (shouldThrow) { + throw error + } + + // 如果不抛出,则创建一个只包含该错误块的流并向下传递 + errorStream = new ReadableStream({ + start(controller) { + controller.enqueue(errorChunk) + controller.close() + } + }) + } + + return { + rawOutput: undefined, + stream: errorStream, // 将包含错误的流传递下去 + controller: undefined, + getText: () => '' // 错误情况下没有文本结果 + } + } + } diff --git a/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts b/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts new file mode 100644 index 0000000000..b0b9bd7ce6 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts @@ -0,0 +1,183 @@ +import Logger from '@renderer/config/logger' +import { Usage } from '@renderer/types' +import type { Chunk } from '@renderer/types/chunk' +import { ChunkType } from '@renderer/types/chunk' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'FinalChunkConsumerAndNotifierMiddleware' + +/** + * 最终Chunk消费和通知中间件 + * + * 职责: + * 1. 消费所有GenericChunk流中的chunks并转发给onChunk回调 + * 2. 累加usage/metrics数据(从原始SDK chunks或GenericChunk中提取) + * 3. 在检测到LLM_RESPONSE_COMPLETE时发送包含累计数据的BLOCK_COMPLETE + * 4. 处理MCP工具调用的多轮请求中的数据累加 + */ +const FinalChunkConsumerMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + const isRecursiveCall = + params._internal?.toolProcessingState?.isRecursiveCall || + ctx._internal?.toolProcessingState?.isRecursiveCall || + false + + // 初始化累计数据(只在顶层调用时初始化) + if (!isRecursiveCall) { + if (!ctx._internal.customState) { + ctx._internal.customState = {} + } + ctx._internal.observer = { + usage: { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0 + }, + metrics: { + completion_tokens: 0, + time_completion_millsec: 0, + time_first_token_millsec: 0, + time_thinking_millsec: 0 + } + } + // 初始化文本累积器 + ctx._internal.customState.accumulatedText = '' + ctx._internal.customState.startTimestamp = Date.now() + } + + // 调用下游中间件 + const result = await next(ctx, params) + + // 响应后处理:处理GenericChunk流式响应 + if (result.stream) { + const resultFromUpstream = result.stream + + if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { + const reader = resultFromUpstream.getReader() + + try { + while (true) { + const { done, value: chunk } = await reader.read() + if (done) { + Logger.debug(`[${MIDDLEWARE_NAME}] Input stream finished.`) + break + } + + if (chunk) { + const genericChunk = chunk as GenericChunk + // 提取并累加usage/metrics数据 + extractAndAccumulateUsageMetrics(ctx, genericChunk) + + const shouldSkipChunk = + isRecursiveCall && + (genericChunk.type === ChunkType.BLOCK_COMPLETE || + genericChunk.type === ChunkType.LLM_RESPONSE_COMPLETE) + + if (!shouldSkipChunk) params.onChunk?.(genericChunk) + } else { + Logger.warn(`[${MIDDLEWARE_NAME}] Received undefined chunk before stream was done.`) + } + } + } catch (error) { + Logger.error(`[${MIDDLEWARE_NAME}] Error consuming stream:`, error) + throw error + } finally { + if (params.onChunk && !isRecursiveCall) { + params.onChunk({ + type: ChunkType.BLOCK_COMPLETE, + response: { + usage: ctx._internal.observer?.usage ? { ...ctx._internal.observer.usage } : undefined, + metrics: ctx._internal.observer?.metrics ? 
{ ...ctx._internal.observer.metrics } : undefined + } + } as Chunk) + if (ctx._internal.toolProcessingState) { + ctx._internal.toolProcessingState = {} + } + } + } + + // 为流式输出添加getText方法 + const modifiedResult = { + ...result, + stream: new ReadableStream({ + start(controller) { + controller.close() + } + }), + getText: () => { + return ctx._internal.customState?.accumulatedText || '' + } + } + + return modifiedResult + } else { + Logger.debug(`[${MIDDLEWARE_NAME}] No GenericChunk stream to process.`) + } + } + + return result + } + +/** + * 从GenericChunk或原始SDK chunks中提取usage/metrics数据并累加 + */ +function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: GenericChunk): void { + if (!ctx._internal.observer?.usage || !ctx._internal.observer?.metrics) { + return + } + + try { + if (ctx._internal.customState && !ctx._internal.customState?.firstTokenTimestamp) { + ctx._internal.customState.firstTokenTimestamp = Date.now() + Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`) + } + if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) { + Logger.debug(`[${MIDDLEWARE_NAME}] LLM_RESPONSE_COMPLETE chunk received:`, ctx._internal) + // 从LLM_RESPONSE_COMPLETE chunk中提取usage数据 + if (chunk.response?.usage) { + accumulateUsage(ctx._internal.observer.usage, chunk.response.usage) + } + + if (ctx._internal.customState && ctx._internal.customState?.firstTokenTimestamp) { + ctx._internal.observer.metrics.time_first_token_millsec = + ctx._internal.customState.firstTokenTimestamp - ctx._internal.customState.startTimestamp + ctx._internal.observer.metrics.time_completion_millsec += + Date.now() - ctx._internal.customState.firstTokenTimestamp + } + } + + // 也可以从其他chunk类型中提取metrics数据 + if (chunk.type === ChunkType.THINKING_COMPLETE && chunk.thinking_millsec && ctx._internal.observer?.metrics) { + ctx._internal.observer.metrics.time_thinking_millsec = Math.max( + ctx._internal.observer.metrics.time_thinking_millsec || 0, + chunk.thinking_millsec + ) + } + } catch (error) { + console.error(`[${MIDDLEWARE_NAME}] Error extracting usage/metrics from chunk:`, error) + } +} + +/** + * 累加usage数据 + */ +function accumulateUsage(accumulated: Usage, newUsage: Usage): void { + if (newUsage.prompt_tokens !== undefined) { + accumulated.prompt_tokens += newUsage.prompt_tokens + } + if (newUsage.completion_tokens !== undefined) { + accumulated.completion_tokens += newUsage.completion_tokens + } + if (newUsage.total_tokens !== undefined) { + accumulated.total_tokens += newUsage.total_tokens + } + if (newUsage.thoughts_tokens !== undefined) { + accumulated.thoughts_tokens = (accumulated.thoughts_tokens || 0) + newUsage.thoughts_tokens + } +} + +export default FinalChunkConsumerMiddleware diff --git a/src/renderer/src/aiCore/middleware/common/LoggingMiddleware.ts b/src/renderer/src/aiCore/middleware/common/LoggingMiddleware.ts new file mode 100644 index 0000000000..361eea3119 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/common/LoggingMiddleware.ts @@ -0,0 +1,64 @@ +import { BaseContext, MethodMiddleware, MiddlewareAPI } from '../types' + +export const MIDDLEWARE_NAME = 'GenericLoggingMiddlewares' + +/** + * Helper function to safely stringify arguments for logging, handling circular references and large objects. + * 安全地字符串化日志参数的辅助函数,处理循环引用和大型对象。 + * @param args - The arguments array to stringify. 要字符串化的参数数组。 + * @returns A string representation of the arguments. 
参数的字符串表示形式。 + */ +const stringifyArgsForLogging = (args: any[]): string => { + try { + return args + .map((arg) => { + if (typeof arg === 'function') return '[Function]' + if (typeof arg === 'object' && arg !== null && arg.constructor === Object && Object.keys(arg).length > 20) { + return '[Object with >20 keys]' + } + // Truncate long strings to avoid flooding logs 截断长字符串以避免日志泛滥 + const stringifiedArg = JSON.stringify(arg, null, 2) + return stringifiedArg && stringifiedArg.length > 200 ? stringifiedArg.substring(0, 200) + '...' : stringifiedArg + }) + .join(', ') + } catch (e) { + return '[Error serializing arguments]' // Handle potential errors during stringification 处理字符串化期间的潜在错误 + } +} + +/** + * Generic logging middleware for provider methods. + * 为提供者方法创建一个通用的日志中间件。 + * This middleware logs the initiation, success/failure, and duration of a method call. + * 此中间件记录方法调用的启动、成功/失败以及持续时间。 + */ + +/** + * Creates a generic logging middleware for provider methods. + * 为提供者方法创建一个通用的日志中间件。 + * @returns A `MethodMiddleware` instance. 一个 `MethodMiddleware` 实例。 + */ +export const createGenericLoggingMiddleware: () => MethodMiddleware = () => { + const middlewareName = 'GenericLoggingMiddleware' + // eslint-disable-next-line @typescript-eslint/no-unused-vars + return (_: MiddlewareAPI) => (next) => async (ctx, args) => { + const methodName = ctx.methodName + const logPrefix = `[${middlewareName} (${methodName})]` + console.log(`${logPrefix} Initiating. Args:`, stringifyArgsForLogging(args)) + const startTime = Date.now() + try { + const result = await next(ctx, args) + const duration = Date.now() - startTime + // Log successful completion of the method call with duration. / + // 记录方法调用成功完成及其持续时间。 + console.log(`${logPrefix} Successful. Duration: ${duration}ms`) + return result + } catch (error) { + const duration = Date.now() - startTime + // Log failure of the method call with duration and error information. / + // 记录方法调用失败及其持续时间和错误信息。 + console.error(`${logPrefix} Failed. Duration: ${duration}ms`, error) + throw error // Re-throw the error to be handled by subsequent layers or the caller / 重新抛出错误,由后续层或调用者处理 + } + } +} diff --git a/src/renderer/src/aiCore/middleware/composer.ts b/src/renderer/src/aiCore/middleware/composer.ts new file mode 100644 index 0000000000..8b93b8015a --- /dev/null +++ b/src/renderer/src/aiCore/middleware/composer.ts @@ -0,0 +1,285 @@ +import { + RequestOptions, + SdkInstance, + SdkMessageParam, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' + +import { BaseApiClient } from '../clients' +import { CompletionsParams, CompletionsResult } from './schemas' +import { + BaseContext, + CompletionsContext, + CompletionsMiddleware, + MethodMiddleware, + MIDDLEWARE_CONTEXT_SYMBOL, + MiddlewareAPI +} from './types' + +/** + * Creates the initial context for a method call, populating method-specific fields. / + * 为方法调用创建初始上下文,并填充特定于该方法的字段。 + * @param methodName - The name of the method being called. / 被调用的方法名。 + * @param originalCallArgs - The actual arguments array from the proxy/method call. / 代理/方法调用的实际参数数组。 + * @param providerId - The ID of the provider, if available. / 提供者的ID(如果可用)。 + * @param providerInstance - The instance of the provider. / 提供者实例。 + * @param specificContextFactory - An optional factory function to create a specific context type from the base context and original call arguments. / 一个可选的工厂函数,用于从基础上下文和原始调用参数创建特定的上下文类型。 + * @returns The created context object. 
/ 创建的上下文对象。 + */ +function createInitialCallContext( + methodName: string, + originalCallArgs: TCallArgs, // Renamed from originalArgs to avoid confusion with context.originalArgs + // Factory to create specific context from base and the *original call arguments array* + specificContextFactory?: (base: BaseContext, callArgs: TCallArgs) => TContext +): TContext { + const baseContext: BaseContext = { + [MIDDLEWARE_CONTEXT_SYMBOL]: true, + methodName, + originalArgs: originalCallArgs // Store the full original arguments array in the context + } + + if (specificContextFactory) { + return specificContextFactory(baseContext, originalCallArgs) + } + return baseContext as TContext // Fallback to base context if no specific factory +} + +/** + * Composes an array of functions from right to left. / + * 从右到左组合一个函数数组。 + * `compose(f, g, h)` is `(...args) => f(g(h(...args)))`. / + * `compose(f, g, h)` 等同于 `(...args) => f(g(h(...args)))`。 + * Each function in funcs is expected to take the result of the next function + * (or the initial value for the rightmost function) as its argument. / + * `funcs` 中的每个函数都期望接收下一个函数的结果(或最右侧函数的初始值)作为其参数。 + * @param funcs - Array of functions to compose. / 要组合的函数数组。 + * @returns The composed function. / 组合后的函数。 + */ +function compose(...funcs: Array<(...args: any[]) => any>): (...args: any[]) => any { + if (funcs.length === 0) { + // If no functions to compose, return a function that returns its first argument, or undefined if no args. / + // 如果没有要组合的函数,则返回一个函数,该函数返回其第一个参数,如果没有参数则返回undefined。 + return (...args: any[]) => (args.length > 0 ? args[0] : undefined) + } + if (funcs.length === 1) { + return funcs[0] + } + return funcs.reduce( + (a, b) => + (...args: any[]) => + a(b(...args)) + ) +} + +/** + * Applies an array of Redux-style middlewares to a generic provider method. / + * 将一组Redux风格的中间件应用于一个通用的提供者方法。 + * This version keeps arguments as an array throughout the middleware chain. / + * 此版本在整个中间件链中将参数保持为数组形式。 + * @param originalProviderInstance - The original provider instance. / 原始提供者实例。 + * @param methodName - The name of the method to be enhanced. / 需要增强的方法名。 + * @param originalMethod - The original method to be wrapped. / 需要包装的原始方法。 + * @param middlewares - An array of `ProviderMethodMiddleware` to apply. / 要应用的 `ProviderMethodMiddleware` 数组。 + * @param specificContextFactory - An optional factory to create a specific context for this method. / 可选的工厂函数,用于为此方法创建特定的上下文。 + * @returns An enhanced method with the middlewares applied. / 应用了中间件的增强方法。 + */ +export function applyMethodMiddlewares< + TArgs extends unknown[] = unknown[], // Original method's arguments array type / 原始方法的参数数组类型 + TResult = unknown, + TContext extends BaseContext = BaseContext +>( + methodName: string, + originalMethod: (...args: TArgs) => Promise, + middlewares: MethodMiddleware[], // Expects generic middlewares / 期望通用中间件 + specificContextFactory?: (base: BaseContext, callArgs: TArgs) => TContext +): (...args: TArgs) => Promise { + // Returns a function matching the original method signature. 
/
+  // 返回一个与原始方法签名匹配的函数。
+  return async function enhancedMethod(...methodCallArgs: TArgs): Promise {
+    const ctx = createInitialCallContext(
+      methodName,
+      methodCallArgs, // Pass the actual call arguments array / 传递实际的调用参数数组
+      specificContextFactory
+    )
+
+    const api: MiddlewareAPI = {
+      getContext: () => ctx,
+      getOriginalArgs: () => methodCallArgs // API provides the original arguments array / API提供原始参数数组
+    }
+
+    // `finalDispatch` is the function that will ultimately call the original provider method. /
+    // `finalDispatch` 是最终将调用原始提供者方法的函数。
+    // It receives the current context and arguments, which may have been transformed by middlewares. /
+    // 它接收当前的上下文和参数,这些参数可能已被中间件转换。
+    const finalDispatch = async (
+      _: TContext,
+      currentArgs: TArgs // Generic final dispatch expects args array / 通用finalDispatch期望参数数组
+    ): Promise => {
+      // Spread the current argument array into the original method call / 将当前参数数组展开传入原始方法
+      return originalMethod(...currentArgs)
+    }
+
+    const chain = middlewares.map((middleware) => middleware(api)) // Cast API if TContext/TArgs mismatch general ProviderMethodMiddleware / 如果TContext/TArgs与通用的ProviderMethodMiddleware不匹配,则转换API
+    const composedMiddlewareLogic = compose(...chain)
+    const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
+
+    return enhancedDispatch(ctx, methodCallArgs) // Pass context and original args array / 传递上下文和原始参数数组
+  }
+}
+
+/**
+ * Applies an array of `CompletionsMiddleware` to the `completions` method. /
+ * 将一组 `CompletionsMiddleware` 应用于 `completions` 方法。
+ * This version adapts for `CompletionsMiddleware` expecting a single `params` object. /
+ * 此版本适配了期望单个 `params` 对象的 `CompletionsMiddleware`。
+ * @param originalProviderInstance - The original provider instance. / 原始提供者实例。
+ * @param originalCompletionsMethod - The original SDK `createCompletions` method. / 原始的 SDK `createCompletions` 方法。
+ * @param middlewares - An array of `CompletionsMiddleware` to apply. / 要应用的 `CompletionsMiddleware` 数组。
+ * @returns An enhanced `completions` method with the middlewares applied. / 应用了中间件的增强版 `completions` 方法。
+ */
+export function applyCompletionsMiddlewares<
+  TSdkInstance extends SdkInstance = SdkInstance,
+  TSdkParams extends SdkParams = SdkParams,
+  TRawOutput extends SdkRawOutput = SdkRawOutput,
+  TRawChunk extends SdkRawChunk = SdkRawChunk,
+  TMessageParam extends SdkMessageParam = SdkMessageParam,
+  TToolCall extends SdkToolCall = SdkToolCall,
+  TSdkSpecificTool extends SdkTool = SdkTool
+>(
+  originalApiClientInstance: BaseApiClient<
+    TSdkInstance,
+    TSdkParams,
+    TRawOutput,
+    TRawChunk,
+    TMessageParam,
+    TToolCall,
+    TSdkSpecificTool
+  >,
+  originalCompletionsMethod: (payload: TSdkParams, options?: RequestOptions) => Promise,
+  middlewares: CompletionsMiddleware<
+    TSdkParams,
+    TMessageParam,
+    TToolCall,
+    TSdkInstance,
+    TRawOutput,
+    TRawChunk,
+    TSdkSpecificTool
+  >[]
+): (params: CompletionsParams, options?: RequestOptions) => Promise {
+  // Returns a function matching the original method signature. /
+  // 返回一个与原始方法签名匹配的函数。
+
+  const methodName = 'completions'
+
+  // Factory to create AiProviderMiddlewareCompletionsContext. 
/ + // 用于创建 AiProviderMiddlewareCompletionsContext 的工厂函数。 + const completionsContextFactory = ( + base: BaseContext, + callArgs: [CompletionsParams] + ): CompletionsContext< + TSdkParams, + TMessageParam, + TToolCall, + TSdkInstance, + TRawOutput, + TRawChunk, + TSdkSpecificTool + > => { + return { + ...base, + methodName, + apiClientInstance: originalApiClientInstance, + originalArgs: callArgs, + _internal: { + toolProcessingState: { + recursionDepth: 0, + isRecursiveCall: false + }, + observer: {} + } + } + } + + return async function enhancedCompletionsMethod( + params: CompletionsParams, + options?: RequestOptions + ): Promise { + // `originalCallArgs` for context creation is `[params]`. / + // 用于上下文创建的 `originalCallArgs` 是 `[params]`。 + const originalCallArgs: [CompletionsParams] = [params] + const baseContext: BaseContext = { + [MIDDLEWARE_CONTEXT_SYMBOL]: true, + methodName, + originalArgs: originalCallArgs + } + const ctx = completionsContextFactory(baseContext, originalCallArgs) + + const api: MiddlewareAPI< + CompletionsContext, + [CompletionsParams] + > = { + getContext: () => ctx, + getOriginalArgs: () => originalCallArgs // API provides [CompletionsParams] / API提供 `[CompletionsParams]` + } + + // `finalDispatch` for CompletionsMiddleware: expects (context, params) not (context, args_array). / + // `CompletionsMiddleware` 的 `finalDispatch`:期望 (context, params) 而不是 (context, args_array)。 + const finalDispatch = async ( + context: CompletionsContext< + TSdkParams, + TMessageParam, + TToolCall, + TSdkInstance, + TRawOutput, + TRawChunk, + TSdkSpecificTool + > // Context passed through / 上下文透传 + // _currentParams: CompletionsParams // Directly takes params / 直接接收参数 (unused but required for middleware signature) + ): Promise => { + // At this point, middleware should have transformed CompletionsParams to SDK params + // and stored them in context. If no transformation happened, we need to handle it. + // 此时,中间件应该已经将 CompletionsParams 转换为 SDK 参数并存储在上下文中。 + // 如果没有进行转换,我们需要处理它。 + + const sdkPayload = context._internal?.sdkPayload + if (!sdkPayload) { + throw new Error('SDK payload not found in context. Middleware chain should have transformed parameters.') + } + + const abortSignal = context._internal.flowControl?.abortSignal + const timeout = context._internal.customState?.sdkMetadata?.timeout + + // Call the original SDK method with transformed parameters + // 使用转换后的参数调用原始 SDK 方法 + const rawOutput = await originalCompletionsMethod.call(originalApiClientInstance, sdkPayload, { + ...options, + signal: abortSignal, + timeout + }) + + // Return result wrapped in CompletionsResult format + // 以 CompletionsResult 格式返回包装的结果 + return { + rawOutput + } as CompletionsResult + } + + const chain = middlewares.map((middleware) => middleware(api)) + const composedMiddlewareLogic = compose(...chain) + + // `enhancedDispatch` has the signature `(context, params) => Promise`. / + // `enhancedDispatch` 的签名为 `(context, params) => Promise`。 + const enhancedDispatch = composedMiddlewareLogic(finalDispatch) + + // 将 enhancedDispatch 保存到 context 中,供中间件进行递归调用 + // 这样可以避免重复执行整个中间件链 + ctx._internal.enhancedDispatch = enhancedDispatch + + // Execute with context and the single params object. 
/ + // 使用上下文和单个参数对象执行。 + return enhancedDispatch(ctx, params) + } +} diff --git a/src/renderer/src/aiCore/middleware/core/McpToolChunkMiddleware.ts b/src/renderer/src/aiCore/middleware/core/McpToolChunkMiddleware.ts new file mode 100644 index 0000000000..3dd046c12e --- /dev/null +++ b/src/renderer/src/aiCore/middleware/core/McpToolChunkMiddleware.ts @@ -0,0 +1,306 @@ +import Logger from '@renderer/config/logger' +import { MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types' +import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk' +import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk' +import { parseAndCallTools } from '@renderer/utils/mcp-tools' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'McpToolChunkMiddleware' +const MAX_TOOL_RECURSION_DEPTH = 20 // 防止无限递归 + +/** + * MCP工具处理中间件 + * + * 职责: + * 1. 检测并拦截MCP工具进展chunk(Function Call方式和Tool Use方式) + * 2. 执行工具调用 + * 3. 递归处理工具结果 + * 4. 管理工具调用状态和递归深度 + */ +export const McpToolChunkMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + const mcpTools = params.mcpTools || [] + + // 如果没有工具,直接调用下一个中间件 + if (!mcpTools || mcpTools.length === 0) { + return next(ctx, params) + } + + const executeWithToolHandling = async (currentParams: CompletionsParams, depth = 0): Promise => { + if (depth >= MAX_TOOL_RECURSION_DEPTH) { + Logger.error(`🔧 [${MIDDLEWARE_NAME}] Maximum recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`) + throw new Error(`Maximum tool recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`) + } + + let result: CompletionsResult + + if (depth === 0) { + result = await next(ctx, currentParams) + } else { + const enhancedCompletions = ctx._internal.enhancedDispatch + if (!enhancedCompletions) { + Logger.error(`🔧 [${MIDDLEWARE_NAME}] Enhanced completions method not found, cannot perform recursive call`) + throw new Error('Enhanced completions method not found') + } + + ctx._internal.toolProcessingState!.isRecursiveCall = true + ctx._internal.toolProcessingState!.recursionDepth = depth + + result = await enhancedCompletions(ctx, currentParams) + } + + if (!result.stream) { + Logger.error(`🔧 [${MIDDLEWARE_NAME}] No stream returned from enhanced completions`) + throw new Error('No stream returned from enhanced completions') + } + + const resultFromUpstream = result.stream as ReadableStream + const toolHandlingStream = resultFromUpstream.pipeThrough( + createToolHandlingTransform(ctx, currentParams, mcpTools, depth, executeWithToolHandling) + ) + + return { + ...result, + stream: toolHandlingStream + } + } + + return executeWithToolHandling(params, 0) + } + +/** + * 创建工具处理的 TransformStream + */ +function createToolHandlingTransform( + ctx: CompletionsContext, + currentParams: CompletionsParams, + mcpTools: MCPTool[], + depth: number, + executeWithToolHandling: (params: CompletionsParams, depth: number) => Promise +): TransformStream { + const toolCalls: SdkToolCall[] = [] + const toolUseResponses: MCPToolResponse[] = [] + const allToolResponses: MCPToolResponse[] = [] // 统一的工具响应状态管理数组 + let hasToolCalls = false + let hasToolUseResponses = false + let streamEnded = false + + return new TransformStream({ + async transform(chunk: GenericChunk, controller) { + try { + // 处理MCP工具进展chunk + if (chunk.type === ChunkType.MCP_TOOL_CREATED) { + const createdChunk = chunk as 
MCPToolCreatedChunk + + // 1. 处理Function Call方式的工具调用 + if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) { + toolCalls.push(...createdChunk.tool_calls) + hasToolCalls = true + } + + // 2. 处理Tool Use方式的工具调用 + if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) { + toolUseResponses.push(...createdChunk.tool_use_responses) + hasToolUseResponses = true + } + + // 不转发MCP工具进展chunks,避免重复处理 + return + } + + // 转发其他所有chunk + controller.enqueue(chunk) + } catch (error) { + console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error) + controller.error(error) + } + }, + + async flush(controller) { + const shouldExecuteToolCalls = hasToolCalls && toolCalls.length > 0 + const shouldExecuteToolUseResponses = hasToolUseResponses && toolUseResponses.length > 0 + + if (!streamEnded && (shouldExecuteToolCalls || shouldExecuteToolUseResponses)) { + streamEnded = true + + try { + let toolResult: SdkMessageParam[] = [] + + if (shouldExecuteToolCalls) { + toolResult = await executeToolCalls( + ctx, + toolCalls, + mcpTools, + allToolResponses, + currentParams.onChunk, + currentParams.assistant.model! + ) + } else if (shouldExecuteToolUseResponses) { + toolResult = await executeToolUseResponses( + ctx, + toolUseResponses, + mcpTools, + allToolResponses, + currentParams.onChunk, + currentParams.assistant.model! + ) + } + + if (toolResult.length > 0) { + const output = ctx._internal.toolProcessingState?.output + + const newParams = buildParamsWithToolResults(ctx, currentParams, output!, toolResult, toolCalls) + await executeWithToolHandling(newParams, depth + 1) + } + } catch (error) { + console.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error) + controller.error(error) + } finally { + hasToolCalls = false + hasToolUseResponses = false + } + } + } + }) +} + +/** + * 执行工具调用(Function Call 方式) + */ +async function executeToolCalls( + ctx: CompletionsContext, + toolCalls: SdkToolCall[], + mcpTools: MCPTool[], + allToolResponses: MCPToolResponse[], + onChunk: CompletionsParams['onChunk'], + model: Model +): Promise { + // 转换为MCPToolResponse格式 + const mcpToolResponses: ToolCallResponse[] = toolCalls + .map((toolCall) => { + const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools) + if (!mcpTool) { + return undefined + } + return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool) + }) + .filter((t): t is ToolCallResponse => typeof t !== 'undefined') + + if (mcpToolResponses.length === 0) { + console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`) + return [] + } + + // 使用现有的parseAndCallTools函数执行工具 + const toolResults = await parseAndCallTools( + mcpToolResponses, + allToolResponses, + onChunk, + (mcpToolResponse, resp, model) => { + return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model) + }, + model, + mcpTools + ) + + return toolResults +} + +/** + * 执行工具使用响应(Tool Use Response 方式) + * 处理已经解析好的 ToolUseResponse[],不需要重新解析字符串 + */ +async function executeToolUseResponses( + ctx: CompletionsContext, + toolUseResponses: MCPToolResponse[], + mcpTools: MCPTool[], + allToolResponses: MCPToolResponse[], + onChunk: CompletionsParams['onChunk'], + model: Model +): Promise { + // 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse + const toolResults = await parseAndCallTools( + toolUseResponses, + allToolResponses, + onChunk, + (mcpToolResponse, resp, model) => { + return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, 
resp, model) + }, + model, + mcpTools + ) + + return toolResults +} + +/** + * 构建包含工具结果的新参数 + */ +function buildParamsWithToolResults( + ctx: CompletionsContext, + currentParams: CompletionsParams, + output: SdkRawOutput | string, + toolResults: SdkMessageParam[], + toolCalls: SdkToolCall[] +): CompletionsParams { + // 获取当前已经转换好的reqMessages,如果没有则使用原始messages + const currentReqMessages = getCurrentReqMessages(ctx) + + const apiClient = ctx.apiClientInstance + + // 从回复中构建助手消息 + const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls) + + // 估算新增消息的 token 消耗并累加到 usage 中 + if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) { + try { + const newMessages = newReqMessages.slice(currentReqMessages.length) + const additionalTokens = newMessages.reduce((acc, message) => { + return acc + ctx.apiClientInstance.estimateMessageTokens(message) + }, 0) + + if (additionalTokens > 0) { + ctx._internal.observer.usage.prompt_tokens += additionalTokens + ctx._internal.observer.usage.total_tokens += additionalTokens + } + } catch (error) { + Logger.error(`🔧 [${MIDDLEWARE_NAME}] Error estimating token usage for new messages:`, error) + } + } + + // 更新递归状态 + if (!ctx._internal.toolProcessingState) { + ctx._internal.toolProcessingState = {} + } + ctx._internal.toolProcessingState.isRecursiveCall = true + ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState?.recursionDepth || 0) + 1 + + return { + ...currentParams, + _internal: { + ...ctx._internal, + sdkPayload: ctx._internal.sdkPayload, + newReqMessages: newReqMessages + } + } +} + +/** + * 类型安全地获取当前请求消息 + * 使用API客户端提供的抽象方法,保持中间件的provider无关性 + */ +function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] { + const sdkPayload = ctx._internal.sdkPayload + if (!sdkPayload) { + return [] + } + + // 使用API客户端的抽象方法来提取消息,保持provider无关性 + return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload) +} + +export default McpToolChunkMiddleware diff --git a/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts b/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts new file mode 100644 index 0000000000..36c1693b3a --- /dev/null +++ b/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts @@ -0,0 +1,48 @@ +import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient' +import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk' + +import { AnthropicStreamListener } from '../../clients/types' +import { CompletionsParams, CompletionsResult } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'RawStreamListenerMiddleware' + +export const RawStreamListenerMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + const result = await next(ctx, params) + + // 在这里可以监听到从SDK返回的最原始流 + if (result.rawOutput) { + console.log(`[${MIDDLEWARE_NAME}] 检测到原始SDK输出,准备附加监听器`) + + const providerType = ctx.apiClientInstance.provider.type + // TODO: 后面下放到AnthropicAPIClient + if (providerType === 'anthropic') { + const anthropicListener: AnthropicStreamListener = { + onMessage: (message) => { + if (ctx._internal?.toolProcessingState) { + ctx._internal.toolProcessingState.output = message + } + } + // onContentBlock: (contentBlock) => { + // console.log(`[${MIDDLEWARE_NAME}] 📝 Anthropic content block:`, contentBlock.type) 
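+            // 示意(假设性示例,非本次实现):除了打印日志,也可以把内容块信息写入共享状态, + //   其中 lastContentBlockType 仅为演示用的自定义字段名: + //   ctx._internal.customState = { ...ctx._internal.customState, lastContentBlockType: contentBlock.type }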
+ // } + } + + const specificApiClient = ctx.apiClientInstance as AnthropicAPIClient + + const monitoredOutput = specificApiClient.attachRawStreamListener( + result.rawOutput as AnthropicSdkRawOutput, + anthropicListener + ) + return { + ...result, + rawOutput: monitoredOutput + } + } + } + + return result + } diff --git a/src/renderer/src/aiCore/middleware/core/ResponseTransformMiddleware.ts b/src/renderer/src/aiCore/middleware/core/ResponseTransformMiddleware.ts new file mode 100644 index 0000000000..fbb3bac198 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/core/ResponseTransformMiddleware.ts @@ -0,0 +1,85 @@ +import Logger from '@renderer/config/logger' +import { SdkRawChunk } from '@renderer/types/sdk' + +import { ResponseChunkTransformerContext } from '../../clients/types' +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'ResponseTransformMiddleware' + +/** + * 响应转换中间件 + * + * 职责: + * 1. 检测ReadableStream类型的响应流 + * 2. 使用ApiClient的getResponseChunkTransformer()将原始SDK响应块转换为通用格式 + * 3. 将转换后的ReadableStream保存到ctx._internal.apiCall.genericChunkStream,供下游中间件使用 + * + * 注意:此中间件应该在StreamAdapterMiddleware之后执行 + */ +export const ResponseTransformMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + // 调用下游中间件 + const result = await next(ctx, params) + + // 响应后处理:转换原始SDK响应块 + if (result.stream) { + const adaptedStream = result.stream + + // 处理ReadableStream类型的流 + if (adaptedStream instanceof ReadableStream) { + const apiClient = ctx.apiClientInstance + if (!apiClient) { + console.error(`[${MIDDLEWARE_NAME}] ApiClient instance not found in context`) + throw new Error('ApiClient instance not found in context') + } + + // 获取响应转换器 + const responseChunkTransformer = apiClient.getResponseChunkTransformer?.() + if (!responseChunkTransformer) { + Logger.warn(`[${MIDDLEWARE_NAME}] No ResponseChunkTransformer available, skipping transformation`) + return result + } + + const assistant = params.assistant + const model = assistant?.model + + if (!assistant || !model) { + console.error(`[${MIDDLEWARE_NAME}] Assistant or Model not found for transformation`) + throw new Error('Assistant or Model not found for transformation') + } + + const transformerContext: ResponseChunkTransformerContext = { + isStreaming: params.streamOutput || false, + isEnabledToolCalling: (params.mcpTools && params.mcpTools.length > 0) || false, + isEnabledWebSearch: params.enableWebSearch || false, + isEnabledReasoning: params.enableReasoning || false, + mcpTools: params.mcpTools || [], + provider: ctx.apiClientInstance?.provider + } + + console.log(`[${MIDDLEWARE_NAME}] Transforming raw SDK chunks with context:`, transformerContext) + + try { + // 创建转换后的流 + const genericChunkTransformStream = (adaptedStream as ReadableStream).pipeThrough( + new TransformStream(responseChunkTransformer(transformerContext)) + ) + + // 将转换后的ReadableStream保存到result,供下游中间件使用 + return { + ...result, + stream: genericChunkTransformStream + } + } catch (error) { + Logger.error(`[${MIDDLEWARE_NAME}] Error during chunk transformation:`, error) + throw error + } + } + } + + // 如果没有流或不是ReadableStream,返回原始结果 + return result + } diff --git a/src/renderer/src/aiCore/middleware/core/StreamAdapterMiddleware.ts b/src/renderer/src/aiCore/middleware/core/StreamAdapterMiddleware.ts new file mode 100644 index 0000000000..118d96e035 --- /dev/null +++ 
b/src/renderer/src/aiCore/middleware/core/StreamAdapterMiddleware.ts @@ -0,0 +1,57 @@ +import { SdkRawChunk } from '@renderer/types/sdk' +import { asyncGeneratorToReadableStream, createSingleChunkReadableStream } from '@renderer/utils/stream' + +import { CompletionsParams, CompletionsResult } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' +import { isAsyncIterable } from '../utils' + +export const MIDDLEWARE_NAME = 'StreamAdapterMiddleware' + +/** + * 流适配器中间件 + * + * 职责: + * 1. 检测ctx._internal.apiCall.rawSdkOutput(优先)或原始AsyncIterable流 + * 2. 将AsyncIterable转换为WHATWG ReadableStream + * 3. 更新响应结果中的stream + * + * 注意:如果ResponseTransformMiddleware已处理过,会优先使用transformedStream + */ +export const StreamAdapterMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + // TODO:调用开始,因为这个是最靠近接口请求的地方,next执行代表着开始接口请求了 + // 但是这个中间件的职责是流适配,是否在这里调用有待商榷 + // 调用下游中间件 + const result = await next(ctx, params) + + if ( + result.rawOutput && + !(result.rawOutput instanceof ReadableStream) && + isAsyncIterable(result.rawOutput) + ) { + const whatwgReadableStream: ReadableStream = asyncGeneratorToReadableStream( + result.rawOutput + ) + return { + ...result, + stream: whatwgReadableStream + } + } else if (result.rawOutput && result.rawOutput instanceof ReadableStream) { + return { + ...result, + stream: result.rawOutput + } + } else if (result.rawOutput) { + // 非流式输出,强行变为可读流 + const whatwgReadableStream: ReadableStream = createSingleChunkReadableStream( + result.rawOutput + ) + return { + ...result, + stream: whatwgReadableStream + } + } + return result + } diff --git a/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts b/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts new file mode 100644 index 0000000000..2a3255356f --- /dev/null +++ b/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts @@ -0,0 +1,99 @@ +import Logger from '@renderer/config/logger' +import { ChunkType, TextDeltaChunk } from '@renderer/types/chunk' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'TextChunkMiddleware' + +/** + * 文本块处理中间件 + * + * 职责: + * 1. 累积文本内容(TEXT_DELTA) + * 2. 对文本内容进行智能链接转换 + * 3. 生成TEXT_COMPLETE事件 + * 4. 暂存Web搜索结果,用于最终链接完善 + * 5. 
处理 onResponse 回调,实时发送文本更新和最终完整文本 + */ +export const TextChunkMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + // 调用下游中间件 + const result = await next(ctx, params) + + // 响应后处理:转换流式响应中的文本内容 + if (result.stream) { + const resultFromUpstream = result.stream as ReadableStream + + if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { + const assistant = params.assistant + const model = params.assistant?.model + + if (!assistant || !model) { + Logger.warn(`[${MIDDLEWARE_NAME}] Missing assistant or model information, skipping text processing`) + return result + } + + // 用于跨chunk的状态管理 + let accumulatedTextContent = '' + let hasEnqueue = false + const enhancedTextStream = resultFromUpstream.pipeThrough( + new TransformStream({ + transform(chunk: GenericChunk, controller) { + if (chunk.type === ChunkType.TEXT_DELTA) { + const textChunk = chunk as TextDeltaChunk + accumulatedTextContent += textChunk.text + + // 处理 onResponse 回调 - 发送增量文本更新 + if (params.onResponse) { + params.onResponse(accumulatedTextContent, false) + } + + // 创建新的chunk,包含处理后的文本 + controller.enqueue(chunk) + } else if (accumulatedTextContent) { + if (chunk.type !== ChunkType.LLM_RESPONSE_COMPLETE) { + controller.enqueue(chunk) + hasEnqueue = true + } + const finalText = accumulatedTextContent + ctx._internal.customState!.accumulatedText = finalText + if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) { + ctx._internal.toolProcessingState.output = finalText + } + + // 处理 onResponse 回调 - 发送最终完整文本 + if (params.onResponse) { + params.onResponse(finalText, true) + } + + controller.enqueue({ + type: ChunkType.TEXT_COMPLETE, + text: finalText + }) + accumulatedTextContent = '' + if (!hasEnqueue) { + controller.enqueue(chunk) + } + } else { + // 其他类型的chunk直接传递 + controller.enqueue(chunk) + } + } + }) + ) + + // 更新响应结果 + return { + ...result, + stream: enhancedTextStream + } + } else { + Logger.warn(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream. Returning original result.`) + } + } + + return result + } diff --git a/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts b/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts new file mode 100644 index 0000000000..b0df8313a5 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts @@ -0,0 +1,101 @@ +import Logger from '@renderer/config/logger' +import { ChunkType, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'ThinkChunkMiddleware' + +/** + * 处理思考内容的中间件 + * + * 注意:从 v2 版本开始,流结束语义的判断已移至 ApiClient 层处理 + * 此中间件现在主要负责: + * 1. 处理原始SDK chunk中的reasoning字段 + * 2. 计算准确的思考时间 + * 3. 在思考内容结束时生成THINKING_COMPLETE事件 + * + * 职责: + * 1. 累积思考内容(THINKING_DELTA) + * 2. 监听流结束信号,生成THINKING_COMPLETE事件 + * 3. 
计算准确的思考时间 + * + */ +export const ThinkChunkMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + // 调用下游中间件 + const result = await next(ctx, params) + + // 响应后处理:处理思考内容 + if (result.stream) { + const resultFromUpstream = result.stream as ReadableStream + + // 检查是否启用reasoning + const enableReasoning = params.enableReasoning || false + if (!enableReasoning) { + return result + } + + // 检查是否有流需要处理 + if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { + // thinking 处理状态 + let accumulatedThinkingContent = '' + let hasThinkingContent = false + let thinkingStartTime = 0 + + const processedStream = resultFromUpstream.pipeThrough( + new TransformStream({ + transform(chunk: GenericChunk, controller) { + if (chunk.type === ChunkType.THINKING_DELTA) { + const thinkingChunk = chunk as ThinkingDeltaChunk + + // 第一次接收到思考内容时记录开始时间 + if (!hasThinkingContent) { + hasThinkingContent = true + thinkingStartTime = Date.now() + } + + accumulatedThinkingContent += thinkingChunk.text + + // 更新思考时间并传递 + const enhancedChunk: ThinkingDeltaChunk = { + ...thinkingChunk, + thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0 + } + controller.enqueue(enhancedChunk) + } else if (hasThinkingContent && thinkingStartTime > 0) { + // 收到任何非THINKING_DELTA的chunk时,如果有累积的思考内容,生成THINKING_COMPLETE + const thinkingCompleteChunk: ThinkingCompleteChunk = { + type: ChunkType.THINKING_COMPLETE, + text: accumulatedThinkingContent, + thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0 + } + controller.enqueue(thinkingCompleteChunk) + hasThinkingContent = false + accumulatedThinkingContent = '' + thinkingStartTime = 0 + + // 继续传递当前chunk + controller.enqueue(chunk) + } else { + // 其他情况直接传递 + controller.enqueue(chunk) + } + } + }) + ) + + // 更新响应结果 + return { + ...result, + stream: processedStream + } + } else { + Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`) + } + } + + return result + } diff --git a/src/renderer/src/aiCore/middleware/core/TransformCoreToSdkParamsMiddleware.ts b/src/renderer/src/aiCore/middleware/core/TransformCoreToSdkParamsMiddleware.ts new file mode 100644 index 0000000000..12c7a27acd --- /dev/null +++ b/src/renderer/src/aiCore/middleware/core/TransformCoreToSdkParamsMiddleware.ts @@ -0,0 +1,83 @@ +import Logger from '@renderer/config/logger' +import { ChunkType } from '@renderer/types/chunk' + +import { CompletionsParams, CompletionsResult } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'TransformCoreToSdkParamsMiddleware' + +/** + * 中间件:将CoreCompletionsRequest转换为SDK特定的参数 + * 使用上下文中ApiClient实例的requestTransformer进行转换 + */ +export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + Logger.debug(`🔄 [${MIDDLEWARE_NAME}] Starting core to SDK params transformation:`, ctx) + + const internal = ctx._internal + + // 🔧 检测递归调用:检查 params 中是否携带了预处理的 SDK 消息 + const isRecursiveCall = internal?.toolProcessingState?.isRecursiveCall || false + const newSdkMessages = params._internal?.newReqMessages + + const apiClient = ctx.apiClientInstance + + if (!apiClient) { + Logger.error(`🔄 [${MIDDLEWARE_NAME}] ApiClient instance not found in context.`) + throw new Error('ApiClient instance not found in context') + } + + // 检查是否有requestTransformer方法 + const 
requestTransformer = apiClient.getRequestTransformer() + if (!requestTransformer) { + Logger.warn( + `🔄 [${MIDDLEWARE_NAME}] ApiClient does not have getRequestTransformer method, skipping transformation` + ) + const result = await next(ctx, params) + return result + } + + // 确保assistant和model可用,它们是transformer所需的 + const assistant = params.assistant + const model = params.assistant.model + + if (!assistant || !model) { + console.error(`🔄 [${MIDDLEWARE_NAME}] Assistant or Model not found for transformation.`) + throw new Error('Assistant or Model not found for transformation') + } + + try { + const transformResult = await requestTransformer.transform( + params, + assistant, + model, + isRecursiveCall, + newSdkMessages + ) + + const { payload: sdkPayload, metadata } = transformResult + + // 将SDK特定的payload和metadata存储在状态中,供下游中间件使用 + ctx._internal.sdkPayload = sdkPayload + + if (metadata) { + ctx._internal.customState = { + ...ctx._internal.customState, + sdkMetadata: metadata + } + } + + if (params.enableGenerateImage) { + params.onChunk?.({ + type: ChunkType.IMAGE_CREATED + }) + } + return next(ctx, params) + } catch (error) { + Logger.error(`🔄 [${MIDDLEWARE_NAME}] Error during request transformation:`, error) + // 让错误向上传播,或者可以在这里进行特定的错误处理 + throw error + } + } diff --git a/src/renderer/src/aiCore/middleware/core/WebSearchMiddleware.ts b/src/renderer/src/aiCore/middleware/core/WebSearchMiddleware.ts new file mode 100644 index 0000000000..97261e3d52 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/core/WebSearchMiddleware.ts @@ -0,0 +1,76 @@ +import { ChunkType } from '@renderer/types/chunk' +import { smartLinkConverter } from '@renderer/utils/linkConverter' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'WebSearchMiddleware' + +/** + * Web搜索处理中间件 - 基于GenericChunk流处理 + * + * 职责: + * 1. 监听和记录Web搜索事件 + * 2. 可以在此处添加Web搜索结果的后处理逻辑 + * 3. 维护Web搜索相关的状态 + * + * 注意:Web搜索结果的识别和生成已在ApiClient的响应转换器中处理 + */ +export const WebSearchMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + ctx._internal.webSearchState = { + results: undefined + } + // 调用下游中间件 + const result = await next(ctx, params) + + const model = params.assistant?.model! 
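+      // 说明(示意):下面的流程对 TEXT_DELTA 文本块逐段调用 smartLinkConverter 做链接转换, + //   并在收到 LLM_WEB_SEARCH_COMPLETE 时把搜索结果暂存到 ctx._internal.webSearchState, + //   供后续环节完善链接;isFirstChunk 仅用于告知转换器当前是否为首个文本块。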
+ let isFirstChunk = true + + // 响应后处理:记录Web搜索事件 + if (result.stream) { + const resultFromUpstream = result.stream + + if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { + // Web搜索状态跟踪 + const enhancedStream = (resultFromUpstream as ReadableStream).pipeThrough( + new TransformStream({ + transform(chunk: GenericChunk, controller) { + if (chunk.type === ChunkType.TEXT_DELTA) { + const providerType = model.provider || 'openai' + // 使用当前可用的Web搜索结果进行链接转换 + const text = chunk.text + const processedText = smartLinkConverter(text, providerType, isFirstChunk) + if (isFirstChunk) { + isFirstChunk = false + } + controller.enqueue({ + ...chunk, + text: processedText + }) + } else if (chunk.type === ChunkType.LLM_WEB_SEARCH_COMPLETE) { + // 暂存Web搜索结果用于链接完善 + ctx._internal.webSearchState!.results = chunk.llm_web_search + + // 将Web搜索完成事件继续传递下去 + controller.enqueue(chunk) + } else { + controller.enqueue(chunk) + } + } + }) + ) + + return { + ...result, + stream: enhancedStream + } + } else { + console.log(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream.`) + } + } + + return result + } diff --git a/src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts b/src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts new file mode 100644 index 0000000000..560a9e0aac --- /dev/null +++ b/src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts @@ -0,0 +1,132 @@ +import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' +import { isDedicatedImageGenerationModel } from '@renderer/config/models' +import { ChunkType } from '@renderer/types/chunk' +import { findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' +import OpenAI from 'openai' +import { toFile } from 'openai/uploads' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'ImageGenerationMiddleware' + +export const ImageGenerationMiddleware: CompletionsMiddleware = + () => + (next) => + async (context: CompletionsContext, params: CompletionsParams): Promise => { + const { assistant, messages } = params + const client = context.apiClientInstance as BaseApiClient + const signal = context._internal?.flowControl?.abortSignal + + if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') { + return next(context, params) + } + + const stream = new ReadableStream({ + async start(controller) { + const enqueue = (chunk: GenericChunk) => controller.enqueue(chunk) + + try { + if (!assistant.model) { + throw new Error('Assistant model is not defined.') + } + + const sdk = await client.getSdkInstance() + const lastUserMessage = messages.findLast((m) => m.role === 'user') + const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant') + + if (!lastUserMessage) { + throw new Error('No user message found for image generation.') + } + + const prompt = getMainTextContent(lastUserMessage) + let imageFiles: Blob[] = [] + + // Collect images from user message + const userImageBlocks = findImageBlocks(lastUserMessage) + const userImages = await Promise.all( + userImageBlocks.map(async (block) => { + if (!block.file) return null + const binaryData: Uint8Array = await window.api.file.binaryImage(block.file.id) + const mimeType = `${block.file.type}/${block.file.ext.slice(1)}` + return await toFile(new Blob([binaryData]), block.file.origin_name || 'image.png', { type: 
mimeType }) + }) + ) + imageFiles = imageFiles.concat(userImages.filter(Boolean) as Blob[]) + + // Collect images from last assistant message + if (lastAssistantMessage) { + const assistantImageBlocks = findImageBlocks(lastAssistantMessage) + const assistantImages = await Promise.all( + assistantImageBlocks.map(async (block) => { + const b64 = block.url?.replace(/^data:image\/\w+;base64,/, '') + if (!b64) return null + const binary = atob(b64) + const bytes = new Uint8Array(binary.length) + for (let i = 0; i < binary.length; i++) bytes[i] = binary.charCodeAt(i) + return await toFile(new Blob([bytes]), 'assistant_image.png', { type: 'image/png' }) + }) + ) + imageFiles = imageFiles.concat(assistantImages.filter(Boolean) as Blob[]) + } + + enqueue({ type: ChunkType.IMAGE_CREATED }) + + const startTime = Date.now() + let response: OpenAI.Images.ImagesResponse + + const options = { signal, timeout: 300_000 } + + if (imageFiles.length > 0) { + response = await sdk.images.edit( + { + model: assistant.model.id, + image: imageFiles, + prompt: prompt || '' + }, + options + ) + } else { + response = await sdk.images.generate( + { + model: assistant.model.id, + prompt: prompt || '', + response_format: assistant.model.id.includes('gpt-image-1') ? undefined : 'b64_json' + }, + options + ) + } + + const b64_json_array = response.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || [] + + enqueue({ + type: ChunkType.IMAGE_COMPLETE, + image: { type: 'base64', images: b64_json_array } + }) + + const usage = (response as any).usage || { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 } + + enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage, + metrics: { + completion_tokens: usage.completion_tokens, + time_first_token_millsec: 0, + time_completion_millsec: Date.now() - startTime + } + } + }) + } catch (error: any) { + enqueue({ type: ChunkType.ERROR, error }) + } finally { + controller.close() + } + } + }) + + return { + stream, + getText: () => '' + } + } diff --git a/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts b/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts new file mode 100644 index 0000000000..440de40045 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts @@ -0,0 +1,136 @@ +import { Model } from '@renderer/types' +import { ChunkType, TextDeltaChunk, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk' +import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction' +import Logger from 'electron-log/renderer' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'ThinkingTagExtractionMiddleware' + +// 不同模型的思考标签配置 +const reasoningTags: TagConfig[] = [ + { openingTag: '', closingTag: '', separator: '\n' }, + { openingTag: '###Thinking', closingTag: '###Response', separator: '\n' } +] + +const getAppropriateTag = (model?: Model): TagConfig => { + if (model?.id?.includes('qwen3')) return reasoningTags[0] + // 可以在这里添加更多模型特定的标签配置 + return reasoningTags[0] // 默认使用 标签 +} + +/** + * 处理文本流中思考标签提取的中间件 + * + * 该中间件专门处理文本流中的思考标签内容(如 ...) + * 主要用于 OpenAI 等支持思考标签的 provider + * + * 职责: + * 1. 从文本流中提取思考标签内容 + * 2. 将标签内的内容转换为 THINKING_DELTA chunk + * 3. 将标签外的内容作为正常文本输出 + * 4. 处理不同模型的思考标签格式 + * 5. 
在思考内容结束时生成 THINKING_COMPLETE 事件 + */ +export const ThinkingTagExtractionMiddleware: CompletionsMiddleware = + () => + (next) => + async (context: CompletionsContext, params: CompletionsParams): Promise => { + // 调用下游中间件 + const result = await next(context, params) + + // 响应后处理:处理思考标签提取 + if (result.stream) { + const resultFromUpstream = result.stream as ReadableStream + + // 检查是否有流需要处理 + if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { + // 获取当前模型的思考标签配置 + const model = params.assistant?.model + const reasoningTag = getAppropriateTag(model) + + // 创建标签提取器 + const tagExtractor = new TagExtractor(reasoningTag) + + // thinking 处理状态 + let hasThinkingContent = false + let thinkingStartTime = 0 + + const processedStream = resultFromUpstream.pipeThrough( + new TransformStream({ + transform(chunk: GenericChunk, controller) { + if (chunk.type === ChunkType.TEXT_DELTA) { + const textChunk = chunk as TextDeltaChunk + + // 使用 TagExtractor 处理文本 + const extractionResults = tagExtractor.processText(textChunk.text) + + for (const extractionResult of extractionResults) { + if (extractionResult.complete && extractionResult.tagContentExtracted) { + // 生成 THINKING_COMPLETE 事件 + const thinkingCompleteChunk: ThinkingCompleteChunk = { + type: ChunkType.THINKING_COMPLETE, + text: extractionResult.tagContentExtracted, + thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0 + } + controller.enqueue(thinkingCompleteChunk) + + // 重置思考状态 + hasThinkingContent = false + thinkingStartTime = 0 + } else if (extractionResult.content.length > 0) { + if (extractionResult.isTagContent) { + // 第一次接收到思考内容时记录开始时间 + if (!hasThinkingContent) { + hasThinkingContent = true + thinkingStartTime = Date.now() + } + + const thinkingDeltaChunk: ThinkingDeltaChunk = { + type: ChunkType.THINKING_DELTA, + text: extractionResult.content, + thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0 + } + controller.enqueue(thinkingDeltaChunk) + } else { + // 发送清理后的文本内容 + const cleanTextChunk: TextDeltaChunk = { + ...textChunk, + text: extractionResult.content + } + controller.enqueue(cleanTextChunk) + } + } + } + } else { + // 其他类型的chunk直接传递(包括 THINKING_DELTA, THINKING_COMPLETE 等) + controller.enqueue(chunk) + } + }, + flush(controller) { + // 处理可能剩余的思考内容 + const finalResult = tagExtractor.finalize() + if (finalResult?.tagContentExtracted) { + const thinkingCompleteChunk: ThinkingCompleteChunk = { + type: ChunkType.THINKING_COMPLETE, + text: finalResult.tagContentExtracted, + thinking_millsec: thinkingStartTime > 0 ? 
Date.now() - thinkingStartTime : 0 + } + controller.enqueue(thinkingCompleteChunk) + } + } + }) + ) + + // 更新响应结果 + return { + ...result, + stream: processedStream + } + } else { + Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`) + } + } + return result + } diff --git a/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts b/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts new file mode 100644 index 0000000000..5f444953a9 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts @@ -0,0 +1,124 @@ +import { MCPTool } from '@renderer/types' +import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk' +import { parseToolUse } from '@renderer/utils/mcp-tools' +import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'ToolUseExtractionMiddleware' + +// 工具使用标签配置 +const TOOL_USE_TAG_CONFIG: TagConfig = { + openingTag: '', + closingTag: '', + separator: '\n' +} + +/** + * 工具使用提取中间件 + * + * 职责: + * 1. 从文本流中检测并提取 标签 + * 2. 解析工具调用信息并转换为 ToolUseResponse 格式 + * 3. 生成 MCP_TOOL_CREATED chunk 供 McpToolChunkMiddleware 处理 + * 4. 清理文本流,移除工具使用标签但保留正常文本 + * + * 注意:此中间件只负责提取和转换,实际工具调用由 McpToolChunkMiddleware 处理 + */ +export const ToolUseExtractionMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + const mcpTools = params.mcpTools || [] + + // 如果没有工具,直接调用下一个中间件 + if (!mcpTools || mcpTools.length === 0) return next(ctx, params) + + // 调用下游中间件 + const result = await next(ctx, params) + + // 响应后处理:处理工具使用标签提取 + if (result.stream) { + const resultFromUpstream = result.stream as ReadableStream + + const processedStream = resultFromUpstream.pipeThrough(createToolUseExtractionTransform(ctx, mcpTools)) + + return { + ...result, + stream: processedStream + } + } + + return result + } + +/** + * 创建工具使用提取的 TransformStream + */ +function createToolUseExtractionTransform( + _ctx: CompletionsContext, + mcpTools: MCPTool[] +): TransformStream { + const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG) + + return new TransformStream({ + async transform(chunk: GenericChunk, controller) { + try { + // 处理文本内容,检测工具使用标签 + if (chunk.type === ChunkType.TEXT_DELTA) { + const textChunk = chunk as TextDeltaChunk + const extractionResults = tagExtractor.processText(textChunk.text) + + for (const result of extractionResults) { + if (result.complete && result.tagContentExtracted) { + // 提取到完整的工具使用内容,解析并转换为 SDK ToolCall 格式 + const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools) + + if (toolUseResponses.length > 0) { + // 生成 MCP_TOOL_CREATED chunk,复用现有的处理流程 + const mcpToolCreatedChunk: MCPToolCreatedChunk = { + type: ChunkType.MCP_TOOL_CREATED, + tool_use_responses: toolUseResponses + } + controller.enqueue(mcpToolCreatedChunk) + } + } else if (!result.isTagContent && result.content) { + // 发送标签外的正常文本内容 + const cleanTextChunk: TextDeltaChunk = { + ...textChunk, + text: result.content + } + controller.enqueue(cleanTextChunk) + } + // 注意:标签内的内容不会作为TEXT_DELTA转发,避免重复显示 + } + return + } + + // 转发其他所有chunk + controller.enqueue(chunk) + } catch (error) { + console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error) + controller.error(error) + } + }, + + async flush(controller) { + // 
检查是否有未完成的标签内容 + const finalResult = tagExtractor.finalize() + if (finalResult && finalResult.tagContentExtracted) { + const toolUseResponses = parseToolUse(finalResult.tagContentExtracted, mcpTools) + if (toolUseResponses.length > 0) { + const mcpToolCreatedChunk: MCPToolCreatedChunk = { + type: ChunkType.MCP_TOOL_CREATED, + tool_use_responses: toolUseResponses + } + controller.enqueue(mcpToolCreatedChunk) + } + } + } + }) +} + +export default ToolUseExtractionMiddleware diff --git a/src/renderer/src/aiCore/middleware/index.ts b/src/renderer/src/aiCore/middleware/index.ts new file mode 100644 index 0000000000..64be4edd44 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/index.ts @@ -0,0 +1,88 @@ +import { CompletionsMiddleware, MethodMiddleware } from './types' + +// /** +// * Wraps a provider instance with middlewares. +// */ +// export function wrapProviderWithMiddleware( +// apiClientInstance: BaseApiClient, +// middlewareConfig: MiddlewareConfig +// ): BaseApiClient { +// console.log(`[wrapProviderWithMiddleware] Wrapping provider: ${apiClientInstance.provider?.id}`) +// console.log(`[wrapProviderWithMiddleware] Middleware config:`, { +// completions: middlewareConfig.completions?.length || 0, +// methods: Object.keys(middlewareConfig.methods || {}).length +// }) + +// // Cache for already wrapped methods to avoid re-wrapping on every access. +// const wrappedMethodsCache = new Map Promise>() + +// const proxy = new Proxy(apiClientInstance, { +// get(target, propKey, receiver) { +// const methodName = typeof propKey === 'string' ? propKey : undefined + +// if (!methodName) { +// return Reflect.get(target, propKey, receiver) +// } + +// if (wrappedMethodsCache.has(methodName)) { +// console.log(`[wrapProviderWithMiddleware] Using cached wrapped method: ${methodName}`) +// return wrappedMethodsCache.get(methodName) +// } + +// const originalMethod = Reflect.get(target, propKey, receiver) + +// // If the property is not a function, return it directly. +// if (typeof originalMethod !== 'function') { +// return originalMethod +// } + +// let wrappedMethod: ((...args: any[]) => Promise) | undefined + +// // Handle completions method +// if (methodName === 'completions' && middlewareConfig.completions?.length) { +// console.log( +// `[wrapProviderWithMiddleware] Wrapping completions method with ${middlewareConfig.completions.length} middlewares` +// ) +// const completionsOriginalMethod = originalMethod as (params: CompletionsParams) => Promise +// wrappedMethod = applyCompletionsMiddlewares(target, completionsOriginalMethod, middlewareConfig.completions) +// } +// // Handle other methods +// else { +// const methodMiddlewares = middlewareConfig.methods?.[methodName] +// if (methodMiddlewares?.length) { +// console.log( +// `[wrapProviderWithMiddleware] Wrapping method ${methodName} with ${methodMiddlewares.length} middlewares` +// ) +// const genericOriginalMethod = originalMethod as (...args: any[]) => Promise +// wrappedMethod = applyMethodMiddlewares(target, methodName, genericOriginalMethod, methodMiddlewares) +// } +// } + +// if (wrappedMethod) { +// console.log(`[wrapProviderWithMiddleware] Successfully wrapped method: ${methodName}`) +// wrappedMethodsCache.set(methodName, wrappedMethod) +// return wrappedMethod +// } + +// // If no middlewares are configured for this method, return the original method bound to the target. 
/ +// // 如果没有为此方法配置中间件,则返回绑定到目标的原始方法。 +// console.log(`[wrapProviderWithMiddleware] No middlewares for method ${methodName}, returning original`) +// return originalMethod.bind(target) +// } +// }) +// return proxy as BaseApiClient +// } + +// Export types for external use +export type { CompletionsMiddleware, MethodMiddleware } + +// Export MiddlewareBuilder related types and classes +export { + CompletionsMiddlewareBuilder, + createCompletionsBuilder, + createMethodBuilder, + MethodMiddlewareBuilder, + MiddlewareBuilder, + type MiddlewareExecutor, + type NamedMiddleware +} from './builder' diff --git a/src/renderer/src/aiCore/middleware/register.ts b/src/renderer/src/aiCore/middleware/register.ts new file mode 100644 index 0000000000..003ce7e93a --- /dev/null +++ b/src/renderer/src/aiCore/middleware/register.ts @@ -0,0 +1,149 @@ +import * as AbortHandlerModule from './common/AbortHandlerMiddleware' +import * as ErrorHandlerModule from './common/ErrorHandlerMiddleware' +import * as FinalChunkConsumerModule from './common/FinalChunkConsumerMiddleware' +import * as LoggingModule from './common/LoggingMiddleware' +import * as McpToolChunkModule from './core/McpToolChunkMiddleware' +import * as RawStreamListenerModule from './core/RawStreamListenerMiddleware' +import * as ResponseTransformModule from './core/ResponseTransformMiddleware' +// import * as SdkCallModule from './core/SdkCallMiddleware' +import * as StreamAdapterModule from './core/StreamAdapterMiddleware' +import * as TextChunkModule from './core/TextChunkMiddleware' +import * as ThinkChunkModule from './core/ThinkChunkMiddleware' +import * as TransformCoreToSdkParamsModule from './core/TransformCoreToSdkParamsMiddleware' +import * as WebSearchModule from './core/WebSearchMiddleware' +import * as ImageGenerationModule from './feat/ImageGenerationMiddleware' +import * as ThinkingTagExtractionModule from './feat/ThinkingTagExtractionMiddleware' +import * as ToolUseExtractionMiddleware from './feat/ToolUseExtractionMiddleware' + +/** + * 中间件注册表 - 提供所有可用中间件的集中访问 + * 注意:目前中间件文件还未导出 MIDDLEWARE_NAME,会有 linter 错误,这是正常的 + */ +export const MiddlewareRegistry = { + [ErrorHandlerModule.MIDDLEWARE_NAME]: { + name: ErrorHandlerModule.MIDDLEWARE_NAME, + middleware: ErrorHandlerModule.ErrorHandlerMiddleware + }, + // 通用中间件 + [AbortHandlerModule.MIDDLEWARE_NAME]: { + name: AbortHandlerModule.MIDDLEWARE_NAME, + middleware: AbortHandlerModule.AbortHandlerMiddleware + }, + [FinalChunkConsumerModule.MIDDLEWARE_NAME]: { + name: FinalChunkConsumerModule.MIDDLEWARE_NAME, + middleware: FinalChunkConsumerModule.default + }, + + // 核心流程中间件 + [TransformCoreToSdkParamsModule.MIDDLEWARE_NAME]: { + name: TransformCoreToSdkParamsModule.MIDDLEWARE_NAME, + middleware: TransformCoreToSdkParamsModule.TransformCoreToSdkParamsMiddleware + }, + // [SdkCallModule.MIDDLEWARE_NAME]: { + // name: SdkCallModule.MIDDLEWARE_NAME, + // middleware: SdkCallModule.SdkCallMiddleware + // }, + [StreamAdapterModule.MIDDLEWARE_NAME]: { + name: StreamAdapterModule.MIDDLEWARE_NAME, + middleware: StreamAdapterModule.StreamAdapterMiddleware + }, + [RawStreamListenerModule.MIDDLEWARE_NAME]: { + name: RawStreamListenerModule.MIDDLEWARE_NAME, + middleware: RawStreamListenerModule.RawStreamListenerMiddleware + }, + [ResponseTransformModule.MIDDLEWARE_NAME]: { + name: ResponseTransformModule.MIDDLEWARE_NAME, + middleware: ResponseTransformModule.ResponseTransformMiddleware + }, + + // 特性处理中间件 + [ThinkingTagExtractionModule.MIDDLEWARE_NAME]: { + name: 
ThinkingTagExtractionModule.MIDDLEWARE_NAME, + middleware: ThinkingTagExtractionModule.ThinkingTagExtractionMiddleware + }, + [ToolUseExtractionMiddleware.MIDDLEWARE_NAME]: { + name: ToolUseExtractionMiddleware.MIDDLEWARE_NAME, + middleware: ToolUseExtractionMiddleware.ToolUseExtractionMiddleware + }, + [ThinkChunkModule.MIDDLEWARE_NAME]: { + name: ThinkChunkModule.MIDDLEWARE_NAME, + middleware: ThinkChunkModule.ThinkChunkMiddleware + }, + [McpToolChunkModule.MIDDLEWARE_NAME]: { + name: McpToolChunkModule.MIDDLEWARE_NAME, + middleware: McpToolChunkModule.McpToolChunkMiddleware + }, + [WebSearchModule.MIDDLEWARE_NAME]: { + name: WebSearchModule.MIDDLEWARE_NAME, + middleware: WebSearchModule.WebSearchMiddleware + }, + [TextChunkModule.MIDDLEWARE_NAME]: { + name: TextChunkModule.MIDDLEWARE_NAME, + middleware: TextChunkModule.TextChunkMiddleware + }, + [ImageGenerationModule.MIDDLEWARE_NAME]: { + name: ImageGenerationModule.MIDDLEWARE_NAME, + middleware: ImageGenerationModule.ImageGenerationMiddleware + } +} as const + +/** + * 根据名称获取中间件 + * @param name - 中间件名称 + * @returns 对应的中间件信息 + */ +export function getMiddleware(name: string) { + return MiddlewareRegistry[name] +} + +/** + * 获取所有注册的中间件名称 + * @returns 中间件名称列表 + */ +export function getRegisteredMiddlewareNames(): string[] { + return Object.keys(MiddlewareRegistry) +} + +/** + * 默认的 Completions 中间件配置 - NamedMiddleware 格式,用于 MiddlewareBuilder + */ +export const DefaultCompletionsNamedMiddlewares = [ + MiddlewareRegistry[FinalChunkConsumerModule.MIDDLEWARE_NAME], // 最终消费者 + MiddlewareRegistry[ErrorHandlerModule.MIDDLEWARE_NAME], // 错误处理 + MiddlewareRegistry[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME], // 参数转换 + MiddlewareRegistry[AbortHandlerModule.MIDDLEWARE_NAME], // 中止处理 + MiddlewareRegistry[McpToolChunkModule.MIDDLEWARE_NAME], // 工具处理 + MiddlewareRegistry[TextChunkModule.MIDDLEWARE_NAME], // 文本处理 + MiddlewareRegistry[WebSearchModule.MIDDLEWARE_NAME], // Web搜索处理 + MiddlewareRegistry[ToolUseExtractionMiddleware.MIDDLEWARE_NAME], // 工具使用提取处理 + MiddlewareRegistry[ThinkingTagExtractionModule.MIDDLEWARE_NAME], // 思考标签提取处理(特定provider) + MiddlewareRegistry[ThinkChunkModule.MIDDLEWARE_NAME], // 思考处理(通用SDK) + MiddlewareRegistry[ResponseTransformModule.MIDDLEWARE_NAME], // 响应转换 + MiddlewareRegistry[StreamAdapterModule.MIDDLEWARE_NAME], // 流适配器 + MiddlewareRegistry[RawStreamListenerModule.MIDDLEWARE_NAME] // 原始流监听器 +] + +/** + * 默认的通用方法中间件 - 例如翻译、摘要等 + */ +export const DefaultMethodMiddlewares = { + translate: [LoggingModule.createGenericLoggingMiddleware()], + summaries: [LoggingModule.createGenericLoggingMiddleware()] +} + +/** + * 导出所有中间件模块,方便外部使用 + */ +export { + AbortHandlerModule, + FinalChunkConsumerModule, + LoggingModule, + McpToolChunkModule, + ResponseTransformModule, + StreamAdapterModule, + TextChunkModule, + ThinkChunkModule, + ThinkingTagExtractionModule, + TransformCoreToSdkParamsModule, + WebSearchModule +} diff --git a/src/renderer/src/aiCore/middleware/schemas.ts b/src/renderer/src/aiCore/middleware/schemas.ts new file mode 100644 index 0000000000..33d9816b4f --- /dev/null +++ b/src/renderer/src/aiCore/middleware/schemas.ts @@ -0,0 +1,77 @@ +import { Assistant, MCPTool } from '@renderer/types' +import { Chunk } from '@renderer/types/chunk' +import { Message } from '@renderer/types/newMessage' +import { SdkRawChunk, SdkRawOutput } from '@renderer/types/sdk' + +import { ProcessingState } from './types' + +// ============================================================================ +// Core Request Types - 核心请求结构 +// 
============================================================================ + +/** + * 标准化的内部核心请求结构,用于所有AI Provider的统一处理 + * 这是应用层参数转换后的标准格式,不包含回调函数和控制逻辑 + */ +export interface CompletionsParams { + /** + * 调用的业务场景类型,用于中间件判断是否执行 + * 'chat': 主要对话流程 + * 'translate': 翻译 + * 'summary': 摘要 + * 'search': 搜索摘要 + * 'generate': 生成 + * 'check': API检查 + */ + callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check' + + // 基础对话数据 + messages: Message[] | string // 联合类型方便判断是否为空 + + assistant: Assistant // 助手为基本单位 + // model: Model + + onChunk?: (chunk: Chunk) => void + onResponse?: (text: string, isComplete: boolean) => void + + // 错误相关 + onError?: (error: Error) => void + shouldThrow?: boolean + + // 工具相关 + mcpTools?: MCPTool[] + + // 生成参数 + temperature?: number + topP?: number + maxTokens?: number + + // 功能开关 + streamOutput: boolean + enableWebSearch?: boolean + enableReasoning?: boolean + enableGenerateImage?: boolean + + // 上下文控制 + contextCount?: number + + _internal?: ProcessingState +} + +export interface CompletionsResult { + rawOutput?: SdkRawOutput + stream?: ReadableStream | ReadableStream | AsyncIterable + controller?: AbortController + + getText: () => string +} + +// ============================================================================ +// Generic Chunk Types - 通用数据块结构 +// ============================================================================ + +/** + * 通用数据块类型 + * 复用现有的 Chunk 类型,这是所有AI Provider都应该输出的标准化数据块格式 + */ +export type GenericChunk = Chunk diff --git a/src/renderer/src/aiCore/middleware/types.ts b/src/renderer/src/aiCore/middleware/types.ts new file mode 100644 index 0000000000..0a7dbe390b --- /dev/null +++ b/src/renderer/src/aiCore/middleware/types.ts @@ -0,0 +1,166 @@ +import { MCPToolResponse, Metrics, Usage, WebSearchResponse } from '@renderer/types' +import { Chunk, ErrorChunk } from '@renderer/types/chunk' +import { + SdkInstance, + SdkMessageParam, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' + +import { BaseApiClient } from '../clients' +import { CompletionsParams, CompletionsResult } from './schemas' + +/** + * Symbol to uniquely identify middleware context objects. + */ +export const MIDDLEWARE_CONTEXT_SYMBOL = Symbol.for('AiProviderMiddlewareContext') + +/** + * Defines the structure for the onChunk callback function. + */ +export type OnChunkFunction = (chunk: Chunk | ErrorChunk) => void + +/** + * Base context that carries information about the current method call. + */ +export interface BaseContext { + [MIDDLEWARE_CONTEXT_SYMBOL]: true + methodName: string + originalArgs: Readonly +} + +/** + * Processing state shared between middlewares. 
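+ * + * @example + * // 示意:中间件通过 ctx._internal 读写共享状态(可用字段见下方接口定义) + * ctx._internal.toolProcessingState = { ...ctx._internal.toolProcessingState, isRecursiveCall: true, recursionDepth: 1 }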
+ */ +export interface ProcessingState< + TParams extends SdkParams = SdkParams, + TMessageParam extends SdkMessageParam = SdkMessageParam, + TToolCall extends SdkToolCall = SdkToolCall +> { + sdkPayload?: TParams + newReqMessages?: TMessageParam[] + observer?: { + usage?: Usage + metrics?: Metrics + } + toolProcessingState?: { + pendingToolCalls?: Array + executingToolCalls?: Array<{ + sdkToolCall: TToolCall + mcpToolResponse: MCPToolResponse + }> + output?: SdkRawOutput | string + isRecursiveCall?: boolean + recursionDepth?: number + } + webSearchState?: { + results?: WebSearchResponse + } + flowControl?: { + abortController?: AbortController + abortSignal?: AbortSignal + cleanup?: () => void + } + enhancedDispatch?: (context: CompletionsContext, params: CompletionsParams) => Promise + customState?: Record +} + +/** + * Extended context for completions method. + */ +export interface CompletionsContext< + TSdkParams extends SdkParams = SdkParams, + TSdkMessageParam extends SdkMessageParam = SdkMessageParam, + TSdkToolCall extends SdkToolCall = SdkToolCall, + TSdkInstance extends SdkInstance = SdkInstance, + TRawOutput extends SdkRawOutput = SdkRawOutput, + TRawChunk extends SdkRawChunk = SdkRawChunk, + TSdkSpecificTool extends SdkTool = SdkTool +> extends BaseContext { + readonly methodName: 'completions' // 强制方法名为 'completions' + + apiClientInstance: BaseApiClient< + TSdkInstance, + TSdkParams, + TRawOutput, + TRawChunk, + TSdkMessageParam, + TSdkToolCall, + TSdkSpecificTool + > + + // --- Mutable internal state for the duration of the middleware chain --- + _internal: ProcessingState // 包含所有可变的处理状态 +} + +export interface MiddlewareAPI { + getContext: () => Ctx // Function to get the current context / 获取当前上下文的函数 + getOriginalArgs: () => Args // Function to get the original arguments of the method call / 获取方法调用原始参数的函数 +} + +/** + * Base middleware type. + */ +export type Middleware = ( + api: MiddlewareAPI +) => ( + next: (context: TContext, args: any[]) => Promise +) => (context: TContext, args: any[]) => Promise + +export type MethodMiddleware = Middleware + +/** + * Completions middleware type. 
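+ * + * @example + * // 最小骨架(示意),与本目录各中间件相同的三段式写法: + * const NoopMiddleware: CompletionsMiddleware = () => (next) => async (ctx, params) => { + *   // 在调用 next 之前可修改 params / ctx;调用之后可对 result.stream 做进一步处理 + *   return next(ctx, params) + * }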
+ */ +export type CompletionsMiddleware< + TSdkParams extends SdkParams = SdkParams, + TSdkMessageParam extends SdkMessageParam = SdkMessageParam, + TSdkToolCall extends SdkToolCall = SdkToolCall, + TSdkInstance extends SdkInstance = SdkInstance, + TRawOutput extends SdkRawOutput = SdkRawOutput, + TRawChunk extends SdkRawChunk = SdkRawChunk, + TSdkSpecificTool extends SdkTool = SdkTool +> = ( + api: MiddlewareAPI< + CompletionsContext< + TSdkParams, + TSdkMessageParam, + TSdkToolCall, + TSdkInstance, + TRawOutput, + TRawChunk, + TSdkSpecificTool + >, + [CompletionsParams] + > +) => ( + next: ( + context: CompletionsContext< + TSdkParams, + TSdkMessageParam, + TSdkToolCall, + TSdkInstance, + TRawOutput, + TRawChunk, + TSdkSpecificTool + >, + params: CompletionsParams + ) => Promise +) => ( + context: CompletionsContext< + TSdkParams, + TSdkMessageParam, + TSdkToolCall, + TSdkInstance, + TRawOutput, + TRawChunk, + TSdkSpecificTool + >, + params: CompletionsParams +) => Promise + +// Re-export for convenience +export type { Chunk as OnChunkArg } from '@renderer/types/chunk' diff --git a/src/renderer/src/aiCore/middleware/utils.ts b/src/renderer/src/aiCore/middleware/utils.ts new file mode 100644 index 0000000000..12a2fe651d --- /dev/null +++ b/src/renderer/src/aiCore/middleware/utils.ts @@ -0,0 +1,57 @@ +import { ChunkType, ErrorChunk } from '@renderer/types/chunk' + +/** + * Creates an ErrorChunk object with a standardized structure. + * @param error The error object or message. + * @param chunkType The type of chunk, defaults to ChunkType.ERROR. + * @returns An ErrorChunk object. + */ +export function createErrorChunk(error: any, chunkType: ChunkType = ChunkType.ERROR): ErrorChunk { + let errorDetails: Record = {} + + if (error instanceof Error) { + errorDetails = { + message: error.message, + name: error.name, + stack: error.stack + } + } else if (typeof error === 'string') { + errorDetails = { message: error } + } else if (typeof error === 'object' && error !== null) { + errorDetails = Object.getOwnPropertyNames(error).reduce( + (acc, key) => { + acc[key] = error[key] + return acc + }, + {} as Record + ) + if (!errorDetails.message && error.toString && typeof error.toString === 'function') { + const errMsg = error.toString() + if (errMsg !== '[object Object]') { + errorDetails.message = errMsg + } + } + } + + return { + type: chunkType, + error: errorDetails + } as ErrorChunk +} + +// Helper to capitalize method names for hook construction +export function capitalize(str: string): string { + if (!str) return '' + return str.charAt(0).toUpperCase() + str.slice(1) +} + +/** + * 检查对象是否实现了AsyncIterable接口 + */ +export function isAsyncIterable(obj: unknown): obj is AsyncIterable { + return ( + obj !== null && + typeof obj === 'object' && + typeof (obj as Record)[Symbol.asyncIterator] === 'function' + ) +} diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index acebf6171c..361478e411 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -143,7 +143,7 @@ import YiModelLogoDark from '@renderer/assets/images/models/yi_dark.png' import YoudaoLogo from '@renderer/assets/images/providers/netease-youdao.svg' import NomicLogo from '@renderer/assets/images/providers/nomic.png' import { getProviderByModel } from '@renderer/services/AssistantService' -import { Assistant, Model } from '@renderer/types' +import { Model } from '@renderer/types' import OpenAI from 'openai' import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from './prompts' @@ -199,6 
+199,11 @@ export const VISION_REGEX = new RegExp( 'i' ) +// For middleware to identify models that must use the dedicated Image API +export const DEDICATED_IMAGE_MODELS = ['grok-2-image', 'dall-e-3', 'dall-e-2', 'gpt-image-1'] +export const isDedicatedImageGenerationModel = (model: Model): boolean => + DEDICATED_IMAGE_MODELS.filter((m) => model.id.includes(m)).length > 0 + // Text to image models export const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus/i @@ -2246,14 +2251,24 @@ export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [ 'stabilityai/stable-diffusion-xl-base-1.0' ] +export const SUPPORTED_DISABLE_GENERATION_MODELS = [ + 'gemini-2.0-flash-exp', + 'gpt-4o', + 'gpt-4o-mini', + 'gpt-4.1', + 'gpt-4.1-mini', + 'gpt-4.1-nano', + 'o3' +] + export const GENERATE_IMAGE_MODELS = [ 'gemini-2.0-flash-exp-image-generation', 'gemini-2.0-flash-preview-image-generation', - 'gemini-2.0-flash-exp', 'grok-2-image-1212', 'grok-2-image', 'grok-2-image-latest', - 'gpt-image-1' + 'gpt-image-1', + ...SUPPORTED_DISABLE_GENERATION_MODELS ] export const GEMINI_SEARCH_MODELS = [ @@ -2362,10 +2377,32 @@ export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean { ) } -export function isOpenAIWebSearch(model: Model): boolean { +export function isOpenAIChatCompletionOnlyModel(model: Model): boolean { + if (!model) { + return false + } + + return ( + model.id.includes('gpt-4o-search-preview') || + model.id.includes('gpt-4o-mini-search-preview') || + model.id.includes('o1-mini') || + model.id.includes('o1-preview') + ) +} + +export function isOpenAIWebSearchChatCompletionOnlyModel(model: Model): boolean { return model.id.includes('gpt-4o-search-preview') || model.id.includes('gpt-4o-mini-search-preview') } +export function isOpenAIWebSearchModel(model: Model): boolean { + return ( + model.id.includes('gpt-4o-search-preview') || + model.id.includes('gpt-4o-mini-search-preview') || + (model.id.includes('gpt-4.1') && !model.id.includes('gpt-4.1-nano')) || + (model.id.includes('gpt-4o') && !model.id.includes('gpt-4o-image')) + ) +} + export function isSupportedThinkingTokenModel(model?: Model): boolean { if (!model) { return false @@ -2506,7 +2543,7 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean { return true } - if (isOpenAIReasoningModel(model) || isOpenAIWebSearch(model)) { + if (isOpenAIReasoningModel(model) || isOpenAIChatCompletionOnlyModel(model)) { return true } @@ -2536,17 +2573,13 @@ export function isWebSearchModel(model: Model): boolean { return false } + // 不管哪个供应商都判断了 if (model.id.includes('claude')) { return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(model.id) } if (provider.type === 'openai-response') { - if ( - isOpenAILLMModel(model) && - !isTextToImageModel(model) && - !isOpenAIReasoningModel(model) && - !GENERATE_IMAGE_MODELS.includes(model.id) - ) { + if (isOpenAIWebSearchModel(model)) { return true } @@ -2558,12 +2591,7 @@ export function isWebSearchModel(model: Model): boolean { } if (provider.id === 'aihubmix') { - if ( - isOpenAILLMModel(model) && - !isTextToImageModel(model) && - !isOpenAIReasoningModel(model) && - !GENERATE_IMAGE_MODELS.includes(model.id) - ) { + if (isOpenAIWebSearchModel(model)) { return true } @@ -2572,7 +2600,7 @@ export function isWebSearchModel(model: Model): boolean { } if (provider?.type === 'openai') { - if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearch(model)) { + if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearchModel(model)) { return true } } @@ -2606,6 
+2634,20 @@ export function isWebSearchModel(model: Model): boolean { return false } +export function isOpenRouterBuiltInWebSearchModel(model: Model): boolean { + if (!model) { + return false + } + + const provider = getProviderByModel(model) + + if (provider.id !== 'openrouter') { + return false + } + + return isOpenAIWebSearchModel(model) || model.id.includes('sonar') +} + export function isGenerateImageModel(model: Model): boolean { if (!model) { return false @@ -2628,56 +2670,60 @@ export function isGenerateImageModel(model: Model): boolean { return false } -export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Record { - if (isWebSearchModel(model)) { - if (assistant.enableWebSearch) { - const webSearchTools = getWebSearchTools(model) +export function isSupportedDisableGenerationModel(model: Model): boolean { + if (!model) { + return false + } - if (model.provider === 'grok') { - return { - search_parameters: { - mode: 'auto', - return_citations: true, - sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }] - } - } - } + return SUPPORTED_DISABLE_GENERATION_MODELS.includes(model.id) +} - if (model.provider === 'hunyuan') { - return { enable_enhancement: true, citation: true, search_info: true } - } +export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boolean): Record { + if (!isEnableWebSearch) { + return {} + } - if (model.provider === 'dashscope') { - return { - enable_search: true, - search_options: { - forced_search: true - } - } - } + const webSearchTools = getWebSearchTools(model) - if (model.provider === 'openrouter') { - return { - plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }] - } - } - - if (isOpenAIWebSearch(model)) { - return { - web_search_options: {} - } - } - - return { - tools: webSearchTools - } - } else { - if (model.provider === 'hunyuan') { - return { enable_enhancement: false } + if (model.provider === 'grok') { + return { + search_parameters: { + mode: 'auto', + return_citations: true, + sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }] } } } + if (model.provider === 'hunyuan') { + return { enable_enhancement: true, citation: true, search_info: true } + } + + if (model.provider === 'dashscope') { + return { + enable_search: true, + search_options: { + forced_search: true + } + } + } + + if (isOpenAIWebSearchChatCompletionOnlyModel(model)) { + return { + web_search_options: {} + } + } + + if (model.provider === 'openrouter') { + return { + plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }] + } + } + + return { + tools: webSearchTools + } + return {} } diff --git a/src/renderer/src/middlewares/extractReasoningMiddleware.ts b/src/renderer/src/middlewares/extractReasoningMiddleware.ts deleted file mode 100644 index a466822d8c..0000000000 --- a/src/renderer/src/middlewares/extractReasoningMiddleware.ts +++ /dev/null @@ -1,118 +0,0 @@ -// Modified from https://github.com/vercel/ai/blob/845080d80b8538bb9c7e527d2213acb5f33ac9c2/packages/ai/core/middleware/extract-reasoning-middleware.ts - -import { getPotentialStartIndex } from '../utils/getPotentialIndex' - -export interface ExtractReasoningMiddlewareOptions { - openingTag: string - closingTag: string - separator?: string - enableReasoning?: boolean -} - -function escapeRegExp(str: string) { - return str.replace(/[.*+?^${}()|[\\]\\]/g, '\\$&') -} - -// 支持泛型 T,默认 T = { type: string; textDelta: string } -export function extractReasoningMiddleware< - T extends { type: string } & ( - | { type: 'text-delta' | 
'reasoning'; textDelta: string } - | { type: string } // 其他类型 - ) = { type: string; textDelta: string } ->({ openingTag, closingTag, separator = '\n', enableReasoning }: ExtractReasoningMiddlewareOptions) { - const openingTagEscaped = escapeRegExp(openingTag) - const closingTagEscaped = escapeRegExp(closingTag) - - return { - wrapGenerate: async ({ doGenerate }: { doGenerate: () => Promise<{ text: string } & Record> }) => { - const { text: rawText, ...rest } = await doGenerate() - if (rawText == null) { - return { text: rawText, ...rest } - } - const text = rawText - const regexp = new RegExp(`${openingTagEscaped}(.*?)${closingTagEscaped}`, 'gs') - const matches = Array.from(text.matchAll(regexp)) - if (!matches.length) { - return { text, ...rest } - } - const reasoning = matches.map((match: RegExpMatchArray) => match[1]).join(separator) - let textWithoutReasoning = text - for (let i = matches.length - 1; i >= 0; i--) { - const match = matches[i] as RegExpMatchArray - const beforeMatch = textWithoutReasoning.slice(0, match.index as number) - const afterMatch = textWithoutReasoning.slice((match.index as number) + match[0].length) - textWithoutReasoning = - beforeMatch + (beforeMatch.length > 0 && afterMatch.length > 0 ? separator : '') + afterMatch - } - return { ...rest, text: textWithoutReasoning, reasoning } - }, - wrapStream: async ({ - doStream - }: { - doStream: () => Promise<{ stream: ReadableStream } & Record> - }) => { - const { stream, ...rest } = await doStream() - if (!enableReasoning) { - return { - stream, - ...rest - } - } - let isFirstReasoning = true - let isFirstText = true - let afterSwitch = false - let isReasoning = false - let buffer = '' - return { - stream: stream.pipeThrough( - new TransformStream({ - transform: (chunk, controller) => { - if (chunk.type !== 'text-delta') { - controller.enqueue(chunk) - return - } - // textDelta 只在 text-delta/reasoning chunk 上 - buffer += (chunk as { textDelta: string }).textDelta - function publish(text: string) { - if (text.length > 0) { - const prefix = afterSwitch && (isReasoning ? !isFirstReasoning : !isFirstText) ? separator : '' - controller.enqueue({ - ...chunk, - type: isReasoning ? 'reasoning' : 'text-delta', - textDelta: prefix + text - } as T) - afterSwitch = false - if (isReasoning) { - isFirstReasoning = false - } else { - isFirstText = false - } - } - } - while (true) { - const nextTag = isReasoning ? 
closingTag : openingTag - const startIndex = getPotentialStartIndex(buffer, nextTag) - if (startIndex == null) { - publish(buffer) - buffer = '' - break - } - publish(buffer.slice(0, startIndex)) - const foundFullMatch = startIndex + nextTag.length <= buffer.length - if (foundFullMatch) { - buffer = buffer.slice(startIndex + nextTag.length) - isReasoning = !isReasoning - afterSwitch = true - } else { - buffer = buffer.slice(startIndex) - break - } - } - } - }) - ), - ...rest - } - } - } -} diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index 83b3373093..8690d99d84 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -4,6 +4,7 @@ import TranslateButton from '@renderer/components/TranslateButton' import Logger from '@renderer/config/logger' import { isGenerateImageModel, + isSupportedDisableGenerationModel, isSupportedReasoningEffortModel, isSupportedThinkingTokenModel, isVisionModel, @@ -727,7 +728,7 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = if (!isGenerateImageModel(model) && assistant.enableGenerateImage) { updateAssistant({ ...assistant, enableGenerateImage: false }) } - if (isGenerateImageModel(model) && !assistant.enableGenerateImage && model.id !== 'gemini-2.0-flash-exp') { + if (isGenerateImageModel(model) && !assistant.enableGenerateImage && !isSupportedDisableGenerationModel(model)) { updateAssistant({ ...assistant, enableGenerateImage: true }) } }, [assistant, model, updateAssistant]) diff --git a/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx index 09a16c3496..fdd4640d00 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx @@ -40,7 +40,18 @@ function CitationBlock({ block }: { block: CitationMessageBlock }) { __html: (block.response?.results as GroundingMetadata)?.searchEntryPoint?.renderedContent ?.replace(/@media \(prefers-color-scheme: light\)/g, 'body[theme-mode="light"]') - .replace(/@media \(prefers-color-scheme: dark\)/g, 'body[theme-mode="dark"]') || '' + .replace(/@media \(prefers-color-scheme: dark\)/g, 'body[theme-mode="dark"]') + .replace( + /background-color\s*:\s*#[0-9a-fA-F]{3,6}\b|\bbackground-color\s*:\s*[a-zA-Z-]+\b/g, + 'background-color: var(--color-background-soft)' + ) + .replace(/\.gradient\s*{[^}]*background\s*:\s*[^};]+[;}]/g, (match) => { + // Remove the background property while preserving the rest + return match.replace(/background\s*:\s*[^};]+;?\s*/g, '') + }) + .replace(/\.chip {\n/g, '.chip {\n background-color: var(--color-background)!important;\n') + .replace(/border-color\s*:\s*[^};]+;?\s*/g, '') + .replace(/border\s*:\s*[^};]+;?\s*/g, '') || '' }} /> diff --git a/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx index ba11fb1a08..5229b12304 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx @@ -1,6 +1,6 @@ import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring' import ImageViewer from '@renderer/components/ImageViewer' -import type { ImageMessageBlock } from '@renderer/types/newMessage' +import { type ImageMessageBlock, MessageBlockStatus } from '@renderer/types/newMessage' import React from 'react' import styled from 
'styled-components' @@ -9,23 +9,28 @@ interface Props { } const ImageBlock: React.FC = ({ block }) => { - if (block.status !== 'success') return - const images = block.metadata?.generateImageResponse?.images?.length - ? block.metadata?.generateImageResponse?.images - : block?.file?.path - ? [`file://${block?.file?.path}`] - : [] - return ( - - {images.map((src, index) => ( - - ))} - - ) + if (block.status === MessageBlockStatus.STREAMING || block.status === MessageBlockStatus.PROCESSING) + return + if (block.status === MessageBlockStatus.SUCCESS) { + const images = block.metadata?.generateImageResponse?.images?.length + ? block.metadata?.generateImageResponse?.images + : block?.file?.path + ? [`file://${block?.file?.path}`] + : [] + return ( + + {images.map((src, index) => ( + + ))} + + ) + } else { + return <> + } } const Container = styled.div` display: flex; diff --git a/src/renderer/src/pages/home/components/Suggestions.tsx b/src/renderer/src/pages/home/components/Suggestions.tsx deleted file mode 100644 index 50f1dffee0..0000000000 --- a/src/renderer/src/pages/home/components/Suggestions.tsx +++ /dev/null @@ -1,124 +0,0 @@ -import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring' -import { fetchSuggestions } from '@renderer/services/ApiService' -import { getUserMessage } from '@renderer/services/MessagesService' -import { useAppDispatch } from '@renderer/store' -import { sendMessage } from '@renderer/store/thunk/messageThunk' -import { Assistant, Suggestion } from '@renderer/types' -import type { Message } from '@renderer/types/newMessage' -import { last } from 'lodash' -import { FC, memo, useEffect, useState } from 'react' -import styled from 'styled-components' - -interface Props { - assistant: Assistant - messages: Message[] -} - -const suggestionsMap = new Map() - -const Suggestions: FC = ({ assistant, messages }) => { - const dispatch = useAppDispatch() - - const [suggestions, setSuggestions] = useState( - suggestionsMap.get(messages[messages.length - 1]?.id) || [] - ) - const [loadingSuggestions, setLoadingSuggestions] = useState(false) - - const handleSuggestionClick = async (content: string) => { - const { message: userMessage, blocks } = getUserMessage({ - assistant, - topic: assistant.topics[0], - content - }) - - await dispatch(sendMessage(userMessage, blocks, assistant, assistant.topics[0].id)) - } - - const suggestionsHandle = async () => { - if (loadingSuggestions) return - try { - setLoadingSuggestions(true) - const _suggestions = await fetchSuggestions({ - assistant, - messages - }) - if (_suggestions.length) { - setSuggestions(_suggestions) - suggestionsMap.set(messages[messages.length - 1].id, _suggestions) - } - } finally { - setLoadingSuggestions(false) - } - } - - useEffect(() => { - suggestionsHandle() - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) - - useEffect(() => { - setSuggestions(suggestionsMap.get(messages[messages.length - 1]?.id) || []) - }, [messages]) - - if (last(messages)?.status !== 'success') { - return null - } - if (loadingSuggestions) { - return ( - - - - ) - } - - if (suggestions.length === 0) { - return null - } - - return ( - - - {suggestions.map((s, i) => ( - handleSuggestionClick(s.content)}> - {s.content} → - - ))} - - - ) -} - -const Container = styled.div` - display: flex; - flex-direction: column; - padding: 10px 10px 20px 65px; - display: flex; - width: 100%; - flex-direction: row; - flex-wrap: wrap; - gap: 15px; -` - -const SuggestionsContainer = styled.div` - display: flex; - flex-direction: 
row; - flex-wrap: wrap; - gap: 10px; -` - -const SuggestionItem = styled.div` - display: flex; - align-items: center; - width: fit-content; - padding: 5px 10px; - border-radius: 12px; - font-size: 12px; - color: var(--color-text); - background: var(--color-background-mute); - cursor: pointer; - &:hover { - opacity: 0.9; - } -` - -export default memo(Suggestions) diff --git a/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx b/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx index 70f976f4b8..50c9c80fa1 100644 --- a/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx +++ b/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx @@ -1,3 +1,4 @@ +import AiProvider from '@renderer/aiCore' import { TopView } from '@renderer/components/TopView' import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models' @@ -6,7 +7,6 @@ import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers' import { useKnowledgeBases } from '@renderer/hooks/useKnowledge' import { useProviders } from '@renderer/hooks/useProvider' import { SettingHelpText } from '@renderer/pages/settings' -import AiProvider from '@renderer/providers/AiProvider' import { getKnowledgeBaseParams } from '@renderer/services/KnowledgeService' import { getModelUniqId } from '@renderer/services/ModelService' import { KnowledgeBase, Model } from '@renderer/types' diff --git a/src/renderer/src/pages/paintings/AihubmixPage.tsx b/src/renderer/src/pages/paintings/AihubmixPage.tsx index 8e411fcae1..0b6ad7d722 100644 --- a/src/renderer/src/pages/paintings/AihubmixPage.tsx +++ b/src/renderer/src/pages/paintings/AihubmixPage.tsx @@ -11,7 +11,7 @@ import { usePaintings } from '@renderer/hooks/usePaintings' import { useAllProviders } from '@renderer/hooks/useProvider' import { useRuntime } from '@renderer/hooks/useRuntime' import { useSettings } from '@renderer/hooks/useSettings' -import AiProvider from '@renderer/providers/AiProvider' +import AiProvider from '@renderer/aiCore' import FileManager from '@renderer/services/FileManager' import { translateText } from '@renderer/services/TranslateService' import { useAppDispatch } from '@renderer/store' @@ -182,11 +182,9 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => { const base64s = await AI.generateImage({ prompt, model: painting.model, - config: { - aspectRatio: painting.aspectRatio?.replace('ASPECT_', '').replace('_', ':'), - numberOfImages: painting.model.startsWith('imagen-4.0-ultra-generate-exp') ? 1 : painting.numberOfImages, - personGeneration: painting.personGeneration - } + imageSize: painting.aspectRatio?.replace('ASPECT_', '').replace('_', ':') || '1:1', + batchSize: painting.model.startsWith('imagen-4.0-ultra-generate-exp') ? 
1 : painting.numberOfImages || 1, + personGeneration: painting.personGeneration }) if (base64s?.length > 0) { const validFiles = await Promise.all( diff --git a/src/renderer/src/pages/paintings/SiliconPage.tsx b/src/renderer/src/pages/paintings/SiliconPage.tsx index 0a3648260b..6c4331e373 100644 --- a/src/renderer/src/pages/paintings/SiliconPage.tsx +++ b/src/renderer/src/pages/paintings/SiliconPage.tsx @@ -16,7 +16,7 @@ import { usePaintings } from '@renderer/hooks/usePaintings' import { useAllProviders } from '@renderer/hooks/useProvider' import { useRuntime } from '@renderer/hooks/useRuntime' import { useSettings } from '@renderer/hooks/useSettings' -import AiProvider from '@renderer/providers/AiProvider' +import AiProvider from '@renderer/aiCore' import { getProviderByModel } from '@renderer/services/AssistantService' import FileManager from '@renderer/services/FileManager' import { translateText } from '@renderer/services/TranslateService' diff --git a/src/renderer/src/pages/settings/ProviderSettings/ApiCheckPopup.tsx b/src/renderer/src/pages/settings/ProviderSettings/ApiCheckPopup.tsx index 91ae99da3c..afd6ba576e 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/ApiCheckPopup.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/ApiCheckPopup.tsx @@ -51,8 +51,8 @@ const PopupContainer: React.FC = ({ title, provider, model, apiKeys, type try { let valid = false if (type === 'provider' && model) { - const result = await checkApi({ ...(provider as Provider), apiKey: status.key }, model) - valid = result.valid + await checkApi({ ...(provider as Provider), apiKey: status.key }, model) + valid = true } else { const result = await WebSearchService.checkSearch({ ...(provider as WebSearchProvider), @@ -65,7 +65,7 @@ const PopupContainer: React.FC = ({ title, provider, model, apiKeys, type setKeyStatuses((prev) => prev.map((s, idx) => (idx === i ? { ...s, checking: false, isValid: valid } : s))) return { index: i, valid } - } catch (error) { + } catch (error: unknown) { // 处理错误情况 setKeyStatuses((prev) => prev.map((s, idx) => (idx === i ? { ...s, checking: false, isValid: false } : s))) return { index: i, valid: false } @@ -90,8 +90,8 @@ const PopupContainer: React.FC = ({ title, provider, model, apiKeys, type try { let valid = false if (type === 'provider' && model) { - const result = await checkApi({ ...(provider as Provider), apiKey: keyStatuses[keyIndex].key }, model) - valid = result.valid + await checkApi({ ...(provider as Provider), apiKey: keyStatuses[keyIndex].key }, model) + valid = true } else { const result = await WebSearchService.checkSearch({ ...(provider as WebSearchProvider), @@ -103,7 +103,7 @@ const PopupContainer: React.FC = ({ title, provider, model, apiKeys, type setKeyStatuses((prev) => prev.map((status, idx) => (idx === keyIndex ? { ...status, checking: false, isValid: valid } : status)) ) - } catch (error) { + } catch (error: unknown) { setKeyStatuses((prev) => prev.map((status, idx) => (idx === keyIndex ? 
{ ...status, checking: false, isValid: false } : status)) ) diff --git a/src/renderer/src/pages/settings/ProviderSettings/EditModelsPopup.tsx b/src/renderer/src/pages/settings/ProviderSettings/EditModelsPopup.tsx index 4ff54fa55a..f659d7f094 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/EditModelsPopup.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/EditModelsPopup.tsx @@ -145,14 +145,17 @@ const PopupContainer: React.FC = ({ provider: _provider, resolve }) => { setListModels( models .map((model) => ({ - id: model.id, + // @ts-ignore modelId + id: model?.id || model?.name, // @ts-ignore name - name: model.name || model.id, + name: model?.display_name || model?.displayName || model?.name || model?.id, provider: _provider.id, - group: getDefaultGroupName(model.id, _provider.id), - // @ts-ignore name - description: model?.description, - owned_by: model?.owned_by + // @ts-ignore group + group: getDefaultGroupName(model?.id || model?.name, _provider.id), + // @ts-ignore description + description: model?.description || '', + // @ts-ignore owned_by + owned_by: model?.owned_by || '' })) .filter((model) => !isEmpty(model.name)) ) diff --git a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx index 493d0ac4f7..1555f899ab 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx @@ -7,7 +7,7 @@ import { PROVIDER_CONFIG } from '@renderer/config/providers' import { useTheme } from '@renderer/context/ThemeProvider' import { useAllProviders, useProvider, useProviders } from '@renderer/hooks/useProvider' import i18n from '@renderer/i18n' -import { isOpenAIProvider } from '@renderer/providers/AiProvider/ProviderFactory' +import { isOpenAIProvider } from '@renderer/aiCore/clients/ApiClientFactory' import { checkApi, formatApiKeys } from '@renderer/services/ApiService' import { checkModelsHealth, getModelCheckSummary } from '@renderer/services/HealthCheckService' import { isProviderSupportAuth } from '@renderer/services/ProviderService' @@ -231,22 +231,32 @@ const ProviderSetting: FC = ({ provider: _provider }) => { } else { setApiChecking(true) - const { valid, error } = await checkApi({ ...provider, apiKey, apiHost }, model) + try { + await checkApi({ ...provider, apiKey, apiHost }, model) - const errorMessage = error && error?.message ? ' ' + error?.message : '' + window.message.success({ + key: 'api-check', + style: { marginTop: '3vh' }, + duration: 2, + content: i18n.t('message.api.connection.success') + }) - window.message[valid ? 'success' : 'error']({ - key: 'api-check', - style: { marginTop: '3vh' }, - duration: valid ? 2 : 8, - content: valid - ? i18n.t('message.api.connection.success') - : i18n.t('message.api.connection.failed') + errorMessage - }) + setApiValid(true) + setTimeout(() => setApiValid(false), 3000) + } catch (error: any) { + const errorMessage = error?.message ? 
' ' + error.message : '' - setApiValid(valid) - setApiChecking(false) - setTimeout(() => setApiValid(false), 3000) + window.message.error({ + key: 'api-check', + style: { marginTop: '3vh' }, + duration: 8, + content: i18n.t('message.api.connection.failed') + errorMessage + }) + + setApiValid(false) + } finally { + setApiChecking(false) + } } } diff --git a/src/renderer/src/providers/AiProvider/AihubmixProvider.ts b/src/renderer/src/providers/AiProvider/AihubmixProvider.ts deleted file mode 100644 index e42a4e2039..0000000000 --- a/src/renderer/src/providers/AiProvider/AihubmixProvider.ts +++ /dev/null @@ -1,117 +0,0 @@ -import { isOpenAILLMModel } from '@renderer/config/models' -import { getDefaultModel } from '@renderer/services/AssistantService' -import { Assistant, MCPCallToolResponse, MCPTool, MCPToolResponse, Model, Provider, Suggestion } from '@renderer/types' -import { Message } from '@renderer/types/newMessage' -import OpenAI from 'openai' - -import { CompletionsParams } from '.' -import AnthropicProvider from './AnthropicProvider' -import BaseProvider from './BaseProvider' -import GeminiProvider from './GeminiProvider' -import OpenAIProvider from './OpenAIProvider' -import OpenAIResponseProvider from './OpenAIResponseProvider' - -/** - * AihubmixProvider - 根据模型类型自动选择合适的提供商 - * 使用装饰器模式实现 - */ -export default class AihubmixProvider extends BaseProvider { - private providers: Map = new Map() - private defaultProvider: BaseProvider - private currentProvider: BaseProvider - - constructor(provider: Provider) { - super(provider) - - // 初始化各个提供商 - this.providers.set('claude', new AnthropicProvider(provider)) - this.providers.set('gemini', new GeminiProvider({ ...provider, apiHost: 'https://aihubmix.com/gemini' })) - this.providers.set('openai', new OpenAIResponseProvider(provider)) - this.providers.set('default', new OpenAIProvider(provider)) - - // 设置默认提供商 - this.defaultProvider = this.providers.get('default')! - this.currentProvider = this.defaultProvider - } - - /** - * 根据模型获取合适的提供商 - */ - private getProvider(model: Model): BaseProvider { - const id = model.id.toLowerCase() - // claude开头 - if (id.startsWith('claude')) { - return this.providers.get('claude')! - } - // gemini开头 或 imagen开头 且不以-nothink、-search结尾 - if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) { - return this.providers.get('gemini')! - } - if (isOpenAILLMModel(model)) { - return this.providers.get('openai')! - } - - return this.defaultProvider - } - - // 直接使用默认提供商的方法 - public async models(): Promise { - return this.defaultProvider.models() - } - - public async generateText(params: { prompt: string; content: string }): Promise { - return this.defaultProvider.generateText(params) - } - - public async generateImage(params: any): Promise { - return this.getProvider({ - id: params.model - } as unknown as Model).generateImage(params) - } - - public async generateImageByChat(params: any): Promise { - return this.defaultProvider.generateImageByChat(params) - } - - public async completions(params: CompletionsParams): Promise { - const model = params.assistant.model - this.currentProvider = this.getProvider(model!) 
- return this.currentProvider.completions(params) - } - - public async translate( - content: string, - assistant: Assistant, - onResponse?: (text: string, isComplete: boolean) => void - ): Promise { - return this.getProvider(assistant.model || getDefaultModel()).translate(content, assistant, onResponse) - } - - public async summaries(messages: Message[], assistant: Assistant): Promise { - return this.getProvider(assistant.model || getDefaultModel()).summaries(messages, assistant) - } - - public async summaryForSearch(messages: Message[], assistant: Assistant): Promise { - return this.getProvider(assistant.model || getDefaultModel()).summaryForSearch(messages, assistant) - } - - public async suggestions(messages: Message[], assistant: Assistant): Promise { - return this.getProvider(assistant.model || getDefaultModel()).suggestions(messages, assistant) - } - - public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> { - return this.getProvider(model).check(model, stream) - } - - public async getEmbeddingDimensions(model: Model): Promise { - return this.getProvider(model).getEmbeddingDimensions(model) - } - - public convertMcpTools(mcpTools: MCPTool[]) { - return this.currentProvider.convertMcpTools(mcpTools) as T[] - } - - public mcpToolCallResponseToMessage(mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) { - return this.currentProvider.mcpToolCallResponseToMessage(mcpToolResponse, resp, model) - } -} diff --git a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts b/src/renderer/src/providers/AiProvider/AnthropicProvider.ts deleted file mode 100644 index 07ce8eeb87..0000000000 --- a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts +++ /dev/null @@ -1,802 +0,0 @@ -import Anthropic from '@anthropic-ai/sdk' -import { - Base64ImageSource, - ImageBlockParam, - MessageCreateParamsNonStreaming, - MessageParam, - TextBlockParam, - ToolResultBlockParam, - ToolUnion, - ToolUseBlock, - WebSearchResultBlock, - WebSearchTool20250305, - WebSearchToolResultError -} from '@anthropic-ai/sdk/resources' -import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' -import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models' -import { getStoreSetting } from '@renderer/hooks/useSettings' -import i18n from '@renderer/i18n' -import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' -import FileManager from '@renderer/services/FileManager' -import { - filterContextMessages, - filterEmptyMessages, - filterUserRoleStartMessages -} from '@renderer/services/MessagesService' -import { - Assistant, - EFFORT_RATIO, - FileTypes, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Metrics, - Model, - Provider, - Suggestion, - ToolCallResponse, - Usage, - WebSearchSource -} from '@renderer/types' -import { ChunkType } from '@renderer/types/chunk' -import type { Message } from '@renderer/types/newMessage' -import { removeSpecialCharactersForTopicName } from '@renderer/utils' -import { - anthropicToolUseToMcpTool, - isEnabledToolUse, - mcpToolCallResponseToAnthropicMessage, - mcpToolsToAnthropicTools, - parseAndCallTools -} from '@renderer/utils/mcp-tools' -import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' -import { buildSystemPrompt } from '@renderer/utils/prompt' -import { first, flatten, takeRight } from 'lodash' -import OpenAI from 'openai' - -import { 
CompletionsParams } from '.' -import BaseProvider from './BaseProvider' - -interface ReasoningConfig { - type: 'enabled' | 'disabled' - budget_tokens?: number -} - -export default class AnthropicProvider extends BaseProvider { - private sdk: Anthropic - - constructor(provider: Provider) { - super(provider) - this.sdk = new Anthropic({ - apiKey: this.apiKey, - baseURL: this.getBaseURL(), - dangerouslyAllowBrowser: true, - defaultHeaders: { - 'anthropic-beta': 'output-128k-2025-02-19' - } - }) - } - - public getBaseURL(): string { - return this.provider.apiHost - } - - /** - * Get the message parameter - * @param message - The message - * @returns The message parameter - */ - private async getMessageParam(message: Message): Promise { - const parts: MessageParam['content'] = [ - { - type: 'text', - text: getMainTextContent(message) - } - ] - - // Get and process image blocks - const imageBlocks = findImageBlocks(message) - for (const imageBlock of imageBlocks) { - if (imageBlock.file) { - // Handle uploaded file - const file = imageBlock.file - const base64Data = await window.api.file.base64Image(file.id + file.ext) - parts.push({ - type: 'image', - source: { - data: base64Data.base64, - media_type: base64Data.mime.replace('jpg', 'jpeg') as any, - type: 'base64' - } - }) - } - } - // Get and process file blocks - const fileBlocks = findFileBlocks(message) - for (const fileBlock of fileBlocks) { - const { file } = fileBlock - if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { - if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) { - const base64Data = await FileManager.readBase64File(file) - parts.push({ - type: 'document', - source: { - type: 'base64', - media_type: 'application/pdf', - data: base64Data - } - }) - } else { - const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() - parts.push({ - type: 'text', - text: file.origin_name + '\n' + fileContent - }) - } - } - } - - return { - role: message.role === 'system' ? 'user' : message.role, - content: parts - } - } - - private async getWebSearchParams(model: Model): Promise { - if (!isWebSearchModel(model)) { - return undefined - } - - return { - type: 'web_search_20250305', - name: 'web_search', - max_uses: 5 - } as WebSearchTool20250305 - } - - override getTemperature(assistant: Assistant, model: Model): number | undefined { - if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { - return undefined - } - return assistant.settings?.temperature - } - - override getTopP(assistant: Assistant, model: Model): number | undefined { - if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { - return undefined - } - return assistant.settings?.topP - } - - /** - * Get the reasoning effort - * @param assistant - The assistant - * @param model - The model - * @returns The reasoning effort - */ - private getBudgetToken(assistant: Assistant, model: Model): ReasoningConfig | undefined { - if (!isReasoningModel(model)) { - return undefined - } - const { maxTokens } = getAssistantSettings(assistant) - - const reasoningEffort = assistant?.settings?.reasoning_effort - - if (reasoningEffort === undefined) { - return { - type: 'disabled' - } - } - - const effortRatio = EFFORT_RATIO[reasoningEffort] - - const budgetTokens = Math.max( - 1024, - Math.floor( - Math.min( - (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) 
* effortRatio + - findTokenLimit(model.id)?.min!, - (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio - ) - ) - ) - - return { - type: 'enabled', - budget_tokens: budgetTokens - } - } - - /** - * Generate completions - * @param messages - The messages - * @param assistant - The assistant - * @param mcpTools - The MCP tools - * @param onChunk - The onChunk callback - * @param onFilterMessages - The onFilterMessages callback - */ - public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) { - const defaultModel = getDefaultModel() - const model = assistant.model || defaultModel - const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) - - const userMessagesParams: MessageParam[] = [] - - const _messages = filterUserRoleStartMessages( - filterContextMessages(filterEmptyMessages(takeRight(messages, contextCount + 2))) - ) - - onFilterMessages(_messages) - - for (const message of _messages) { - userMessagesParams.push(await this.getMessageParam(message)) - } - - const userMessages = flatten(userMessagesParams) - const lastUserMessage = _messages.findLast((m) => m.role === 'user') - - let systemPrompt = assistant.prompt - - const { tools } = this.setupToolsConfig({ - model, - mcpTools, - enableToolUse: isEnabledToolUse(assistant) - }) - - if (this.useSystemPromptForTools && mcpTools && mcpTools.length) { - systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools) - } - - let systemMessage: TextBlockParam | undefined = undefined - if (systemPrompt) { - systemMessage = { - type: 'text', - text: systemPrompt - } - } - - const isEnabledBuiltinWebSearch = assistant.enableWebSearch && isWebSearchModel(model) - - if (isEnabledBuiltinWebSearch) { - const webSearchTool = await this.getWebSearchParams(model) - if (webSearchTool) { - tools.push(webSearchTool) - } - } - - const body: MessageCreateParamsNonStreaming = { - model: model.id, - messages: userMessages, - max_tokens: maxTokens || DEFAULT_MAX_TOKENS, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - system: systemMessage ? 
[systemMessage] : undefined, - // @ts-ignore thinking - thinking: this.getBudgetToken(assistant, model), - tools: tools, - ...this.getCustomParameters(assistant) - } - - const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id) - const { signal } = abortController - - const finalUsage: Usage = { - completion_tokens: 0, - prompt_tokens: 0, - total_tokens: 0 - } - - const finalMetrics: Metrics = { - completion_tokens: 0, - time_completion_millsec: 0, - time_first_token_millsec: 0 - } - const toolResponses: MCPToolResponse[] = [] - - const processStream = async (body: MessageCreateParamsNonStreaming, idx: number) => { - let time_first_token_millsec = 0 - - if (!streamOutput) { - const message = await this.sdk.messages.create({ ...body, stream: false }) - const time_completion_millsec = new Date().getTime() - start_time_millsec - - let text = '' - let reasoning_content = '' - - if (message.content && message.content.length > 0) { - const thinkingBlock = message.content.find((block) => block.type === 'thinking') - const textBlock = message.content.find((block) => block.type === 'text') - - if (thinkingBlock && 'thinking' in thinkingBlock) { - reasoning_content = thinkingBlock.thinking - } - - if (textBlock && 'text' in textBlock) { - text = textBlock.text - } - } - - return onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - text, - reasoning_content, - usage: message.usage as any, - metrics: { - completion_tokens: message.usage?.output_tokens || 0, - time_completion_millsec, - time_first_token_millsec: 0 - } - } - }) - } - - let thinking_content = '' - let isFirstChunk = true - - return new Promise((resolve, reject) => { - // 等待接口返回流 - const toolCalls: ToolUseBlock[] = [] - - this.sdk.messages - .stream({ ...body, stream: true }, { signal, timeout: 5 * 60 * 1000 }) - .on('text', (text) => { - if (isFirstChunk) { - isFirstChunk = false - if (time_first_token_millsec == 0) { - time_first_token_millsec = new Date().getTime() - } else { - onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: thinking_content, - thinking_millsec: new Date().getTime() - time_first_token_millsec - }) - } - } - - onChunk({ type: ChunkType.TEXT_DELTA, text }) - }) - .on('contentBlock', (block) => { - if (block.type === 'server_tool_use' && block.name === 'web_search') { - onChunk({ - type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS - }) - } else if (block.type === 'web_search_tool_result') { - if ( - block.content && - (block.content as WebSearchToolResultError).type === 'web_search_tool_result_error' - ) { - onChunk({ - type: ChunkType.ERROR, - error: { - code: (block.content as WebSearchToolResultError).error_code, - message: (block.content as WebSearchToolResultError).error_code - } - }) - } else { - onChunk({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - results: block.content as Array, - source: WebSearchSource.ANTHROPIC - } - }) - } - } - if (block.type === 'tool_use') { - toolCalls.push(block) - } - }) - .on('thinking', (thinking) => { - if (time_first_token_millsec == 0) { - time_first_token_millsec = new Date().getTime() - } - - onChunk({ - type: ChunkType.THINKING_DELTA, - text: thinking, - thinking_millsec: new Date().getTime() - time_first_token_millsec - }) - thinking_content += thinking - }) - .on('finalMessage', async (message) => { - const toolResults: Awaited> = [] - // tool call - if (toolCalls.length > 0) { - const mcpToolResponses = toolCalls - .map((toolCall) => { - const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall) - if (!mcpTool) { - 
return undefined - } - return { - id: toolCall.id, - toolCallId: toolCall.id, - tool: mcpTool, - arguments: toolCall.input as Record, - status: 'pending' - } as ToolCallResponse - }) - .filter((t) => typeof t !== 'undefined') - toolResults.push( - ...(await parseAndCallTools( - mcpToolResponses, - toolResponses, - onChunk, - this.mcpToolCallResponseToMessage, - model, - mcpTools - )) - ) - } - - // tool use - const content = message.content[0] - if (content && content.type === 'text') { - onChunk({ type: ChunkType.TEXT_COMPLETE, text: content.text }) - toolResults.push( - ...(await parseAndCallTools( - content.text, - toolResponses, - onChunk, - this.mcpToolCallResponseToMessage, - model, - mcpTools - )) - ) - } - - if (thinking_content) { - onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: thinking_content, - thinking_millsec: new Date().getTime() - time_first_token_millsec - }) - } - - userMessages.push({ - role: message.role, - content: message.content - }) - - if (toolResults.length > 0) { - toolResults.forEach((ts) => userMessages.push(ts as MessageParam)) - const newBody = body - newBody.messages = userMessages - - onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) - try { - await processStream(newBody, idx + 1) - } catch (error) { - console.error('Error processing stream:', error) - reject(error) - } - } - - // 直接修改finalUsage对象会报错,TypeError: Cannot assign to read only property 'prompt_tokens' of object '#' - // 暂未找到原因 - const updatedUsage: Usage = { - ...finalUsage, - prompt_tokens: finalUsage.prompt_tokens + (message.usage?.input_tokens || 0), - completion_tokens: finalUsage.completion_tokens + (message.usage?.output_tokens || 0) - } - updatedUsage.total_tokens = updatedUsage.prompt_tokens + updatedUsage.completion_tokens - - const updatedMetrics: Metrics = { - ...finalMetrics, - completion_tokens: updatedUsage.completion_tokens, - time_completion_millsec: - finalMetrics.time_completion_millsec + (new Date().getTime() - start_time_millsec), - time_first_token_millsec: time_first_token_millsec - start_time_millsec - } - - Object.assign(finalUsage, updatedUsage) - Object.assign(finalMetrics, updatedMetrics) - - onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - usage: updatedUsage, - metrics: updatedMetrics - } - }) - resolve() - }) - .on('error', (error) => reject(error)) - .on('abort', () => { - reject(new Error('Request was aborted.')) - }) - }) - } - onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) - const start_time_millsec = new Date().getTime() - await processStream(body, 0).finally(() => { - cleanup() - }) - } - - /** - * Translate a message - * @param content - * @param assistant - The assistant - * @param onResponse - The onResponse callback - * @returns The translated message - */ - public async translate( - content: string, - assistant: Assistant, - onResponse?: (text: string, isComplete: boolean) => void - ) { - const defaultModel = getDefaultModel() - const model = assistant.model || defaultModel - - const messagesForApi = [{ role: 'user' as const, content: content }] - - const stream = !!onResponse - - const body: MessageCreateParamsNonStreaming = { - model: model.id, - messages: messagesForApi, - max_tokens: 4096, - temperature: assistant?.settings?.temperature, - system: assistant.prompt - } - - if (!stream) { - const response = await this.sdk.messages.create({ ...body, stream: false }) - return response.content[0].type === 'text' ? 
response.content[0].text : '' - } - - let text = '' - - return new Promise((resolve, reject) => { - this.sdk.messages - .stream({ ...body, stream: true }) - .on('text', (_text) => { - text += _text - onResponse?.(text, false) - }) - .on('finalMessage', () => { - onResponse?.(text, true) - resolve(text) - }) - .on('error', (error) => reject(error)) - }) - } - - /** - * Summarize a message - * @param messages - The messages - * @param assistant - The assistant - * @returns The summary - */ - public async summaries(messages: Message[], assistant: Assistant): Promise { - const model = getTopNamingModel() || assistant.model || getDefaultModel() - - const userMessages = takeRight(messages, 5).map((message) => ({ - role: message.role, - content: getMainTextContent(message) - })) - - if (first(userMessages)?.role === 'assistant') { - userMessages.shift() - } - - const userMessageContent = userMessages.reduce((prev, curr) => { - const currentContent = curr.role === 'user' ? `User: ${curr.content}` : `Assistant: ${curr.content}` - return prev + (prev ? '\n' : '') + currentContent - }, '') - - const systemMessage = { - role: 'system', - content: (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title') - } - - const userMessage = { - role: 'user', - content: userMessageContent - } - - const message = await this.sdk.messages.create({ - messages: [userMessage] as Anthropic.Messages.MessageParam[], - model: model.id, - system: systemMessage.content, - stream: false, - max_tokens: 4096 - }) - - const responseContent = message.content[0].type === 'text' ? message.content[0].text : '' - return removeSpecialCharactersForTopicName(responseContent) - } - - /** - * Summarize a message for search - * @param messages - The messages - * @param assistant - The assistant - * @returns The summary - */ - public async summaryForSearch(messages: Message[], assistant: Assistant): Promise { - const model = assistant.model || getDefaultModel() - const systemMessage = { content: assistant.prompt } - - const userMessageContent = messages.map((m) => getMainTextContent(m)).join('\n') - - const userMessage = { - role: 'user' as const, - content: userMessageContent - } - const lastUserMessage = messages[messages.length - 1] - const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id) - const { signal } = abortController - - const response = await this.sdk.messages - .create( - { - messages: [userMessage], - model: model.id, - system: systemMessage.content, - stream: false, - max_tokens: 4096 - }, - { timeout: 20 * 1000, signal } - ) - .finally(cleanup) - - return response.content[0].type === 'text' ? response.content[0].text : '' - } - - /** - * Generate text - * @param prompt - The prompt - * @param content - The content - * @returns The generated text - */ - public async generateText({ prompt, content }: { prompt: string; content: string }): Promise { - const model = getDefaultModel() - - const message = await this.sdk.messages.create({ - model: model.id, - system: prompt, - stream: false, - max_tokens: 4096, - messages: [ - { - role: 'user', - content - } - ] - }) - - return message.content[0].type === 'text' ? 
message.content[0].text : '' - } - - /** - * Generate an image - * @returns The generated image - */ - public async generateImage(): Promise { - return [] - } - - public async generateImageByChat(): Promise { - throw new Error('Method not implemented.') - } - - /** - * Generate suggestions - * @returns The suggestions - */ - public async suggestions(): Promise { - return [] - } - - /** - * Check if the model is valid - * @param model - The model - * @param stream - Whether to use streaming interface - * @returns The validity of the model - */ - public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> { - if (!model) { - return { valid: false, error: new Error('No model found') } - } - - const body = { - model: model.id, - messages: [{ role: 'user' as const, content: 'hi' }], - max_tokens: 2, // api文档写的 x>1 - stream - } - - try { - if (!stream) { - const message = await this.sdk.messages.create(body as MessageCreateParamsNonStreaming) - return { - valid: message.content.length > 0, - error: null - } - } else { - return await new Promise((resolve, reject) => { - let hasContent = false - this.sdk.messages - .stream(body) - .on('text', (text) => { - if (!hasContent && text) { - hasContent = true - resolve({ valid: true, error: null }) - } - }) - .on('finalMessage', (message) => { - if (!hasContent && message.content && message.content.length > 0) { - hasContent = true - resolve({ valid: true, error: null }) - } - if (!hasContent) { - reject(new Error('Empty streaming response')) - } - }) - .on('error', (error) => reject(error)) - }) - } - } catch (error: any) { - return { - valid: false, - error - } - } - } - - /** - * Get the models - * @returns The models - */ - public async models(): Promise { - return [] - } - - public async getEmbeddingDimensions(): Promise { - return 0 - } - - public convertMcpTools(mcpTools: MCPTool[]): T[] { - return mcpToolsToAnthropicTools(mcpTools) as T[] - } - - public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => { - if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { - return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model) - } else if ('toolCallId' in mcpToolResponse) { - return { - role: 'user', - content: [ - { - type: 'tool_result', - tool_use_id: mcpToolResponse.toolCallId!, - content: resp.content - .map((item) => { - if (item.type === 'text') { - return { - type: 'text', - text: item.text || '' - } satisfies TextBlockParam - } - if (item.type === 'image') { - return { - type: 'image', - source: { - data: item.data || '', - media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'], - type: 'base64' - } - } satisfies ImageBlockParam - } - return - }) - .filter((n) => typeof n !== 'undefined'), - is_error: resp.isError - } satisfies ToolResultBlockParam - ] - } - } - return - } -} diff --git a/src/renderer/src/providers/AiProvider/GeminiProvider.ts b/src/renderer/src/providers/AiProvider/GeminiProvider.ts deleted file mode 100644 index 97c90363d0..0000000000 --- a/src/renderer/src/providers/AiProvider/GeminiProvider.ts +++ /dev/null @@ -1,1238 +0,0 @@ -import { - Content, - File, - FileState, - FinishReason, - FunctionCall, - GenerateContentConfig, - GenerateContentResponse, - GenerateImagesParameters, - GoogleGenAI, - HarmBlockThreshold, - HarmCategory, - Modality, - Pager, - Part, - PartUnion, - SafetySetting, - ThinkingConfig, - Tool -} from '@google/genai' -import { nanoid } from 
'@reduxjs/toolkit' -import { - findTokenLimit, - isGeminiReasoningModel, - isGemmaModel, - isGenerateImageModel, - isVisionModel, - isWebSearchModel -} from '@renderer/config/models' -import { getStoreSetting } from '@renderer/hooks/useSettings' -import i18n from '@renderer/i18n' -import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' -import { CacheService } from '@renderer/services/CacheService' -import { EVENT_NAMES } from '@renderer/services/EventService' -import { - filterContextMessages, - filterEmptyMessages, - filterUserRoleStartMessages -} from '@renderer/services/MessagesService' -import { - Assistant, - EFFORT_RATIO, - FileType, - FileTypes, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Metrics, - Model, - Provider, - Suggestion, - ToolCallResponse, - Usage, - WebSearchSource -} from '@renderer/types' -import { BlockCompleteChunk, Chunk, ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk' -import type { Message, Response } from '@renderer/types/newMessage' -import { removeSpecialCharactersForTopicName } from '@renderer/utils' -import { - geminiFunctionCallToMcpTool, - isEnabledToolUse, - mcpToolCallResponseToGeminiMessage, - mcpToolsToGeminiTools, - parseAndCallTools -} from '@renderer/utils/mcp-tools' -import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' -import { buildSystemPrompt } from '@renderer/utils/prompt' -import { MB } from '@shared/config/constant' -import axios from 'axios' -import { flatten, isEmpty, takeRight } from 'lodash' -import OpenAI from 'openai' - -import { CompletionsParams } from '.' -import BaseProvider from './BaseProvider' - -export default class GeminiProvider extends BaseProvider { - private sdk: GoogleGenAI - - constructor(provider: Provider) { - super(provider) - this.sdk = new GoogleGenAI({ vertexai: false, apiKey: this.apiKey, httpOptions: { baseUrl: this.getBaseURL() } }) - } - - public getBaseURL(): string { - return this.provider.apiHost - } - - /** - * Handle a PDF file - * @param file - The file - * @returns The part - */ - private async handlePdfFile(file: FileType): Promise { - const smallFileSize = 20 * MB - const isSmallFile = file.size < smallFileSize - - if (isSmallFile) { - const { data, mimeType } = await this.base64File(file) - return { - inlineData: { - data, - mimeType - } as Part['inlineData'] - } - } - - // Retrieve file from Gemini uploaded files - const fileMetadata: File | undefined = await this.retrieveFile(file) - - if (fileMetadata) { - return { - fileData: { - fileUri: fileMetadata.uri, - mimeType: fileMetadata.mimeType - } as Part['fileData'] - } - } - - // If file is not found, upload it to Gemini - const result = await this.uploadFile(file) - - return { - fileData: { - fileUri: result.uri, - mimeType: result.mimeType - } as Part['fileData'] - } - } - - /** - * Get the message contents - * @param message - The message - * @returns The message contents - */ - private async getMessageContents(message: Message): Promise { - const role = message.role === 'user' ? 
'user' : 'model' - const parts: Part[] = [{ text: await this.getMessageContent(message) }] - // Add any generated images from previous responses - const imageBlocks = findImageBlocks(message) - for (const imageBlock of imageBlocks) { - if ( - imageBlock.metadata?.generateImageResponse?.images && - imageBlock.metadata.generateImageResponse.images.length > 0 - ) { - for (const imageUrl of imageBlock.metadata.generateImageResponse.images) { - if (imageUrl && imageUrl.startsWith('data:')) { - // Extract base64 data and mime type from the data URL - const matches = imageUrl.match(/^data:(.+);base64,(.*)$/) - if (matches && matches.length === 3) { - const mimeType = matches[1] - const base64Data = matches[2] - parts.push({ - inlineData: { - data: base64Data, - mimeType: mimeType - } as Part['inlineData'] - }) - } - } - } - } - const file = imageBlock.file - if (file) { - const base64Data = await window.api.file.base64Image(file.id + file.ext) - parts.push({ - inlineData: { - data: base64Data.base64, - mimeType: base64Data.mime - } as Part['inlineData'] - }) - } - } - - const fileBlocks = findFileBlocks(message) - for (const fileBlock of fileBlocks) { - const file = fileBlock.file - if (file.type === FileTypes.IMAGE) { - const base64Data = await window.api.file.base64Image(file.id + file.ext) - parts.push({ - inlineData: { - data: base64Data.base64, - mimeType: base64Data.mime - } as Part['inlineData'] - }) - } - - if (file.ext === '.pdf') { - parts.push(await this.handlePdfFile(file)) - continue - } - if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { - const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() - parts.push({ - text: file.origin_name + '\n' + fileContent - }) - } - } - - return { - role, - parts: parts - } - } - - private async getImageFileContents(message: Message): Promise { - const role = message.role === 'user' ? 
'user' : 'model' - const content = getMainTextContent(message) - const parts: Part[] = [{ text: content }] - const imageBlocks = findImageBlocks(message) - for (const imageBlock of imageBlocks) { - if ( - imageBlock.metadata?.generateImageResponse?.images && - imageBlock.metadata.generateImageResponse.images.length > 0 - ) { - for (const imageUrl of imageBlock.metadata.generateImageResponse.images) { - if (imageUrl && imageUrl.startsWith('data:')) { - // Extract base64 data and mime type from the data URL - const matches = imageUrl.match(/^data:(.+);base64,(.*)$/) - if (matches && matches.length === 3) { - const mimeType = matches[1] - const base64Data = matches[2] - parts.push({ - inlineData: { - data: base64Data, - mimeType: mimeType - } as Part['inlineData'] - }) - } - } - } - } - const file = imageBlock.file - if (file) { - const base64Data = await window.api.file.base64Image(file.id + file.ext) - parts.push({ - inlineData: { - data: base64Data.base64, - mimeType: base64Data.mime - } as Part['inlineData'] - }) - } - } - return { - role, - parts: parts - } - } - - /** - * Get the safety settings - * @returns The safety settings - */ - private getSafetySettings(): SafetySetting[] { - const safetyThreshold = 'OFF' as HarmBlockThreshold - - return [ - { - category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - threshold: safetyThreshold - }, - { - category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - threshold: safetyThreshold - }, - { - category: HarmCategory.HARM_CATEGORY_HARASSMENT, - threshold: safetyThreshold - }, - { - category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold: safetyThreshold - }, - { - category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, - threshold: HarmBlockThreshold.BLOCK_NONE - } - ] - } - - /** - * Get the reasoning effort for the assistant - * @param assistant - The assistant - * @param model - The model - * @returns The reasoning effort - */ - private getBudgetToken(assistant: Assistant, model: Model) { - if (isGeminiReasoningModel(model)) { - const reasoningEffort = assistant?.settings?.reasoning_effort - const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini-.*-flash.*$') - - // 如果thinking_budget是undefined,不思考 - if (reasoningEffort === undefined) { - return { - thinkingConfig: { - includeThoughts: false, - ...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {}) - } as ThinkingConfig - } - } - - const effortRatio = EFFORT_RATIO[reasoningEffort] - - if (effortRatio > 1) { - return { - thinkingConfig: { - includeThoughts: true - } - } - } - - const { max } = findTokenLimit(model.id) || { max: 0 } - const budget = Math.floor(max * effortRatio) - - return { - thinkingConfig: { - ...(budget > 0 ? { thinkingBudget: budget } : {}), - includeThoughts: true - } as ThinkingConfig - } - } - - return {} - } - - /** - * Generate completions - * @param messages - The messages - * @param assistant - The assistant - * @param mcpTools - The MCP tools - * @param onChunk - The onChunk callback - * @param onFilterMessages - The onFilterMessages callback - */ - public async completions({ - messages, - assistant, - mcpTools, - onChunk, - onFilterMessages - }: CompletionsParams): Promise { - const defaultModel = getDefaultModel() - const model = assistant.model || defaultModel - let canGenerateImage = false - if (isGenerateImageModel(model)) { - if (model.id === 'gemini-2.0-flash-exp') { - canGenerateImage = assistant.enableGenerateImage! 
- } else { - canGenerateImage = true - } - } - if (canGenerateImage) { - await this.generateImageByChat({ messages, assistant, onChunk }) - return - } - const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) - - const userMessages = filterUserRoleStartMessages( - filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2))) - ) - onFilterMessages(userMessages) - - const userLastMessage = userMessages.pop() - - const history: Content[] = [] - - for (const message of userMessages) { - history.push(await this.getMessageContents(message)) - } - - let systemInstruction = assistant.prompt - - const { tools } = this.setupToolsConfig({ - mcpTools, - model, - enableToolUse: isEnabledToolUse(assistant) - }) - - if (this.useSystemPromptForTools) { - systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools) - } - - const toolResponses: MCPToolResponse[] = [] - - if (assistant.enableWebSearch && isWebSearchModel(model)) { - tools.push({ - // @ts-ignore googleSearch is not a valid tool for Gemini - googleSearch: {} - }) - } - - const generateContentConfig: GenerateContentConfig = { - safetySettings: this.getSafetySettings(), - // generate image don't need system instruction - systemInstruction: isGemmaModel(model) ? undefined : systemInstruction, - temperature: this.getTemperature(assistant, model), - topP: this.getTopP(assistant, model), - maxOutputTokens: maxTokens, - tools: tools, - ...this.getBudgetToken(assistant, model), - ...this.getCustomParameters(assistant) - } - - const messageContents: Content = await this.getMessageContents(userLastMessage!) - - const chat = this.sdk.chats.create({ - model: model.id, - config: generateContentConfig, - history: history - }) - - if (isGemmaModel(model) && assistant.prompt) { - const isFirstMessage = history.length === 0 - if (isFirstMessage && messageContents) { - const systemMessage = [ - { - text: - 'user\n' + - systemInstruction + - '\n' + - 'user\n' + - (messageContents?.parts?.[0] as Part).text + - '' - } - ] as Part[] - if (messageContents && messageContents.parts) { - messageContents.parts[0] = systemMessage[0] - } - } - } - - const finalUsage: Usage = { - completion_tokens: 0, - prompt_tokens: 0, - total_tokens: 0 - } - - const finalMetrics: Metrics = { - completion_tokens: 0, - time_completion_millsec: 0, - time_first_token_millsec: 0 - } - - const { cleanup, abortController } = this.createAbortController(userLastMessage?.id, true) - - const processToolResults = async (toolResults: Awaited>, idx: number) => { - if (toolResults.length === 0) return - const newChat = this.sdk.chats.create({ - model: model.id, - config: generateContentConfig, - history: history as Content[] - }) - - const newStream = await newChat.sendMessageStream({ - message: flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion, - config: { - ...generateContentConfig, - abortSignal: abortController.signal - } - }) - await processStream(newStream, idx + 1) - } - - const processToolCalls = async (toolCalls: FunctionCall[]) => { - const mcpToolResponses: ToolCallResponse[] = toolCalls - .map((toolCall) => { - const mcpTool = geminiFunctionCallToMcpTool(mcpTools, toolCall) - if (!mcpTool) return undefined - - const parsedArgs = (() => { - try { - return typeof toolCall.args === 'string' ? 
JSON.parse(toolCall.args) : toolCall.args - } catch { - return toolCall.args - } - })() - - return { - id: toolCall.id || nanoid(), - toolCallId: toolCall.id, - tool: mcpTool, - arguments: parsedArgs, - status: 'pending' - } as ToolCallResponse - }) - .filter((t): t is ToolCallResponse => typeof t !== 'undefined') - - return await parseAndCallTools( - mcpToolResponses, - toolResponses, - onChunk, - this.mcpToolCallResponseToMessage, - model, - mcpTools - ) - } - - const processToolUses = async (content: string) => { - return await parseAndCallTools( - content, - toolResponses, - onChunk, - this.mcpToolCallResponseToMessage, - model, - mcpTools - ) - } - - const processStream = async ( - stream: AsyncGenerator | GenerateContentResponse, - idx: number - ) => { - history.push(messageContents) - - let functionCalls: FunctionCall[] = [] - let time_first_token_millsec = 0 - - if (stream instanceof GenerateContentResponse) { - const time_completion_millsec = new Date().getTime() - start_time_millsec - - const toolResults: Awaited> = [] - if (stream.text?.length) { - toolResults.push(...(await processToolUses(stream.text))) - } - stream.candidates?.forEach((candidate) => { - if (candidate.content) { - history.push(candidate.content) - - candidate.content.parts?.forEach((part) => { - if (part.functionCall) { - functionCalls.push(part.functionCall) - } - const text = part.text || '' - if (part.thought) { - onChunk({ type: ChunkType.THINKING_DELTA, text }) - onChunk({ type: ChunkType.THINKING_COMPLETE, text }) - } else if (part.text) { - onChunk({ type: ChunkType.TEXT_DELTA, text }) - onChunk({ type: ChunkType.TEXT_COMPLETE, text }) - } - }) - } - }) - - if (functionCalls.length) { - toolResults.push(...(await processToolCalls(functionCalls))) - } - if (stream.text?.length) { - toolResults.push(...(await processToolUses(stream.text))) - } - if (toolResults.length) { - await processToolResults(toolResults, idx) - } - onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - text: stream.text, - usage: { - prompt_tokens: stream.usageMetadata?.promptTokenCount || 0, - thoughts_tokens: stream.usageMetadata?.thoughtsTokenCount || 0, - completion_tokens: stream.usageMetadata?.candidatesTokenCount || 0, - total_tokens: stream.usageMetadata?.totalTokenCount || 0 - }, - metrics: { - completion_tokens: stream.usageMetadata?.candidatesTokenCount, - time_completion_millsec, - time_first_token_millsec: 0 - }, - webSearch: { - results: stream.candidates?.[0]?.groundingMetadata, - source: 'gemini' - } - } as Response - } as BlockCompleteChunk) - } else { - let content = '' - let thinkingContent = '' - for await (const chunk of stream) { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break - - if (chunk.candidates?.[0]?.content?.parts && chunk.candidates[0].content.parts.length > 0) { - const parts = chunk.candidates[0].content.parts - for (const part of parts) { - if (!part.text) { - continue - } else if (part.thought) { - if (time_first_token_millsec === 0) { - time_first_token_millsec = new Date().getTime() - } - thinkingContent += part.text - onChunk({ type: ChunkType.THINKING_DELTA, text: part.text || '' }) - } else { - if (time_first_token_millsec == 0) { - time_first_token_millsec = new Date().getTime() - } else { - onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: thinkingContent, - thinking_millsec: new Date().getTime() - time_first_token_millsec - }) - } - content += part.text - onChunk({ type: ChunkType.TEXT_DELTA, text: part.text }) - } - } - } - - if 
(chunk.candidates?.[0]?.finishReason) { - if (chunk.text) { - onChunk({ type: ChunkType.TEXT_COMPLETE, text: content }) - } - if (chunk.usageMetadata) { - finalUsage.prompt_tokens += chunk.usageMetadata.promptTokenCount || 0 - finalUsage.completion_tokens += chunk.usageMetadata.candidatesTokenCount || 0 - finalUsage.total_tokens += chunk.usageMetadata.totalTokenCount || 0 - } - if (chunk.candidates?.[0]?.groundingMetadata) { - const groundingMetadata = chunk.candidates?.[0]?.groundingMetadata - onChunk({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - results: groundingMetadata, - source: WebSearchSource.GEMINI - } - } as LLMWebSearchCompleteChunk) - } - if (chunk.functionCalls) { - chunk.candidates?.forEach((candidate) => { - if (candidate.content) { - history.push(candidate.content) - } - }) - functionCalls = functionCalls.concat(chunk.functionCalls) - } - - finalMetrics.completion_tokens = finalUsage.completion_tokens - finalMetrics.time_completion_millsec += new Date().getTime() - start_time_millsec - finalMetrics.time_first_token_millsec = - (finalMetrics.time_first_token_millsec || 0) + (time_first_token_millsec - start_time_millsec) - } - } - - // --- End Incremental onChunk calls --- - - // Call processToolUses AFTER potentially processing text content in this chunk - // This assumes tools might be specified within the text stream - // Note: parseAndCallTools inside should handle its own onChunk for tool responses - let toolResults: Awaited> = [] - if (functionCalls.length) { - toolResults = await processToolCalls(functionCalls) - } - if (content.length) { - toolResults = toolResults.concat(await processToolUses(content)) - } - if (toolResults.length) { - await processToolResults(toolResults, idx) - } - - // FIXME: 由于递归,会发送n次 - onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - usage: finalUsage, - metrics: finalMetrics - } - }) - } - } - - // 在发起请求之前开始计时 - const start_time_millsec = new Date().getTime() - - if (!streamOutput) { - onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) - const response = await chat.sendMessage({ - message: messageContents as PartUnion, - config: { - ...generateContentConfig, - abortSignal: abortController.signal - } - }) - return await processStream(response, 0).then(cleanup) - } - - onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) - const userMessagesStream = await chat.sendMessageStream({ - message: messageContents as PartUnion, - config: { - ...generateContentConfig, - abortSignal: abortController.signal - } - }) - - await processStream(userMessagesStream, 0).finally(cleanup) - } - - /** - * Translate a message - * @param content - * @param assistant - The assistant - * @param onResponse - The onResponse callback - * @returns The translated message - */ - public async translate( - content: string, - assistant: Assistant, - onResponse?: (text: string, isComplete: boolean) => void - ) { - const defaultModel = getDefaultModel() - const { maxTokens } = getAssistantSettings(assistant) - const model = assistant.model || defaultModel - - const _content = - isGemmaModel(model) && assistant.prompt - ? `user\n${assistant.prompt}\nuser\n${content}` - : content - if (!onResponse) { - const response = await this.sdk.models.generateContent({ - model: model.id, - config: { - maxOutputTokens: maxTokens, - temperature: assistant?.settings?.temperature, - systemInstruction: isGemmaModel(model) ? 
undefined : assistant.prompt - }, - contents: [ - { - role: 'user', - parts: [{ text: _content }] - } - ] - }) - return response.text || '' - } - - const response = await this.sdk.models.generateContentStream({ - model: model.id, - config: { - maxOutputTokens: maxTokens, - temperature: assistant?.settings?.temperature, - systemInstruction: isGemmaModel(model) ? undefined : assistant.prompt - }, - contents: [ - { - role: 'user', - parts: [{ text: content }] - } - ] - }) - let text = '' - - for await (const chunk of response) { - text += chunk.text - onResponse?.(text, false) - } - - onResponse?.(text, true) - - return text - } - - /** - * Summarize a message - * @param messages - The messages - * @param assistant - The assistant - * @returns The summary - */ - public async summaries(messages: Message[], assistant: Assistant): Promise { - const model = getTopNamingModel() || assistant.model || getDefaultModel() - - const userMessages = takeRight(messages, 5).map((message) => ({ - role: message.role, - // Get content using helper - content: getMainTextContent(message) - })) - - const userMessageContent = userMessages.reduce((prev, curr) => { - const content = curr.role === 'user' ? `User: ${curr.content}` : `Assistant: ${curr.content}` - return prev + (prev ? '\n' : '') + content - }, '') - - const systemMessage = { - role: 'system', - content: (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title') - } - - const userMessage = { - role: 'user', - content: userMessageContent - } - - const content = isGemmaModel(model) - ? `user\n${systemMessage.content}\nuser\n${userMessage.content}` - : userMessage.content - - const response = await this.sdk.models.generateContent({ - model: model.id, - config: { - systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content - }, - contents: [ - { - role: 'user', - parts: [{ text: content }] - } - ] - }) - - return removeSpecialCharactersForTopicName(response.text || '') - } - - /** - * Generate text - * @param prompt - The prompt - * @param content - The content - * @returns The generated text - */ - public async generateText({ prompt, content }: { prompt: string; content: string }): Promise { - const model = getDefaultModel() - const MessageContent = isGemmaModel(model) - ? `user\n${prompt}\nuser\n${content}` - : content - const response = await this.sdk.models.generateContent({ - model: model.id, - config: { - systemInstruction: isGemmaModel(model) ? undefined : prompt - }, - contents: [ - { - role: 'user', - parts: [{ text: MessageContent }] - } - ] - }) - - return response.text || '' - } - - /** - * Generate suggestions - * @returns The suggestions - */ - public async suggestions(): Promise { - return [] - } - - /** - * Summarize a message for search - * @param messages - The messages - * @param assistant - The assistant - * @returns The summary - */ - public async summaryForSearch(messages: Message[], assistant: Assistant): Promise { - const model = assistant.model || getDefaultModel() - - const systemMessage = { - role: 'system', - content: assistant.prompt - } - - // Get content using helper - const userMessageContent = messages.map(getMainTextContent).join('\n') - - const content = isGemmaModel(model) - ? 
`user\n${systemMessage.content}\nuser\n${userMessageContent}` - : userMessageContent - - const lastUserMessage = messages[messages.length - 1] - const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id) - const { signal } = abortController - - const response = await this.sdk.models - .generateContent({ - model: model.id, - config: { - systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content, - temperature: assistant?.settings?.temperature, - httpOptions: { - timeout: 20 * 1000 - }, - abortSignal: signal - }, - contents: [ - { - role: 'user', - parts: [{ text: content }] - } - ] - }) - .finally(cleanup) - - return response.text || '' - } - - /** - * Generate an image - * @param params - The parameters for image generation - * @returns The generated image URLs - */ - public async generateImage(params: GenerateImagesParameters): Promise { - try { - console.log('[GeminiProvider] generateImage params:', params) - const response = await this.sdk.models.generateImages(params) - - if (!response.generatedImages || response.generatedImages.length === 0) { - return [] - } - - const images = response.generatedImages - .filter((image) => image.image?.imageBytes) - .map((image) => { - const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,` - return dataPrefix + image.image?.imageBytes - }) - // console.log(response?.generatedImages?.[0]?.image?.imageBytes); - return images - } catch (error) { - console.error('[generateImage] error:', error) - throw error - } - } - - /** - * 处理Gemini图像响应 - * @param chunk - * @param onChunk - 处理生成块的回调 - */ - private processGeminiImageResponse( - chunk: GenerateContentResponse, - onChunk: (chunk: Chunk) => void - ): { type: 'base64'; images: string[] } | undefined { - const parts = chunk.candidates?.[0]?.content?.parts - if (!parts) { - return - } - // 提取图像数据 - const images = parts - .filter((part: Part) => part.inlineData) - .map((part: Part) => { - if (!part.inlineData) { - return null - } - // onChunk的位置需要更改 - onChunk({ - type: ChunkType.IMAGE_CREATED - }) - const dataPrefix = `data:${part.inlineData.mimeType || 'image/png'};base64,` - return part.inlineData.data?.startsWith('data:') ? 
part.inlineData.data : dataPrefix + part.inlineData.data - }) - - return { - type: 'base64', - images: images.filter((image) => image !== null) - } - } - - /** - * Check if the model is valid - * @param model - The model - * @param stream - Whether to use streaming interface - * @returns The validity of the model - */ - public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> { - if (!model) { - return { valid: false, error: new Error('No model found') } - } - - let config: GenerateContentConfig = { - maxOutputTokens: 1 - } - if (isGeminiReasoningModel(model)) { - config = { - ...config, - thinkingConfig: { - includeThoughts: false, - thinkingBudget: 0 - } as ThinkingConfig - } - } - - if (isGenerateImageModel(model)) { - config = { - ...config, - responseModalities: [Modality.TEXT, Modality.IMAGE], - responseMimeType: 'text/plain' - } - } - - try { - if (!stream) { - const result = await this.sdk.models.generateContent({ - model: model.id, - contents: [{ role: 'user', parts: [{ text: 'hi' }] }], - config: config - }) - if (isEmpty(result.text)) { - throw new Error('Empty response') - } - } else { - const response = await this.sdk.models.generateContentStream({ - model: model.id, - contents: [{ role: 'user', parts: [{ text: 'hi' }] }], - config: config - }) - // 等待整个流式响应结束 - let hasContent = false - for await (const chunk of response) { - if (chunk.candidates && chunk.candidates[0].finishReason === FinishReason.MAX_TOKENS) { - hasContent = true - break - } - } - if (!hasContent) { - throw new Error('Empty streaming response') - } - } - return { valid: true, error: null } - } catch (error: any) { - return { - valid: false, - error - } - } - } - - /** - * Get the models - * @returns The models - */ - public async models(): Promise { - try { - const api = this.provider.apiHost + '/v1beta/models' - const { data } = await axios.get(api, { params: { key: this.apiKey } }) - - return data.models.map( - (m) => - ({ - id: m.name.replace('models/', ''), - name: m.displayName, - description: m.description, - object: 'model', - created: Date.now(), - owned_by: 'gemini' - }) as OpenAI.Models.Model - ) - } catch (error) { - return [] - } - } - - /** - * Get the embedding dimensions - * @param model - The model - * @returns The embedding dimensions - */ - public async getEmbeddingDimensions(model: Model): Promise { - const data = await this.sdk.models.embedContent({ - model: model.id, - contents: [{ role: 'user', parts: [{ text: 'hi' }] }] - }) - return data.embeddings?.[0]?.values?.length || 0 - } - - public async generateImageByChat({ messages, assistant, onChunk }): Promise { - const defaultModel = getDefaultModel() - const model = assistant.model || defaultModel - const { contextCount, maxTokens } = getAssistantSettings(assistant) - const userMessages = filterUserRoleStartMessages( - filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2))) - ) - - const userLastMessage = userMessages.pop() - const { abortController } = this.createAbortController(userLastMessage?.id, true) - const { signal } = abortController - const generateContentConfig: GenerateContentConfig = { - responseModalities: [Modality.TEXT, Modality.IMAGE], - responseMimeType: 'text/plain', - safetySettings: this.getSafetySettings(), - temperature: assistant?.settings?.temperature, - topP: assistant?.settings?.top_p, - maxOutputTokens: maxTokens, - abortSignal: signal, - ...this.getCustomParameters(assistant) - } - const history: Content[] = [] - try { - for (const 
message of userMessages) { - history.push(await this.getImageFileContents(message)) - } - - let time_first_token_millsec = 0 - const start_time_millsec = new Date().getTime() - onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) - const chat = this.sdk.chats.create({ - model: model.id, - config: generateContentConfig, - history: history - }) - let content = '' - const finalUsage: Usage = { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0 - } - const userMessage: Content = await this.getImageFileContents(userLastMessage!) - const response = await chat.sendMessageStream({ - message: userMessage.parts!, - config: { - ...generateContentConfig, - abortSignal: signal - } - }) - for await (const chunk of response as AsyncGenerator) { - if (time_first_token_millsec == 0) { - time_first_token_millsec = new Date().getTime() - } - - if (chunk.text !== undefined) { - content += chunk.text - onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text }) - } - const generateImage = this.processGeminiImageResponse(chunk, onChunk) - if (generateImage?.images?.length) { - onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage }) - } - if (chunk.candidates?.[0]?.finishReason) { - if (chunk.text) { - onChunk({ type: ChunkType.TEXT_COMPLETE, text: content }) - } - if (chunk.usageMetadata) { - finalUsage.prompt_tokens = chunk.usageMetadata.promptTokenCount || 0 - finalUsage.completion_tokens = chunk.usageMetadata.candidatesTokenCount || 0 - finalUsage.total_tokens = chunk.usageMetadata.totalTokenCount || 0 - } - } - } - onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - usage: finalUsage, - metrics: { - completion_tokens: finalUsage.completion_tokens, - time_completion_millsec: new Date().getTime() - start_time_millsec, - time_first_token_millsec: time_first_token_millsec - start_time_millsec - } - } - }) - } catch (error) { - console.error('[generateImageByChat] error', error) - onChunk({ - type: ChunkType.ERROR, - error - }) - } - } - - public convertMcpTools(mcpTools: MCPTool[]): T[] { - return mcpToolsToGeminiTools(mcpTools) as T[] - } - - public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => { - if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { - return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model)) - } else if ('toolCallId' in mcpToolResponse) { - return { - role: 'user', - parts: [ - { - functionResponse: { - id: mcpToolResponse.toolCallId, - name: mcpToolResponse.tool.id, - response: { - output: !resp.isError ? resp.content : undefined, - error: resp.isError ? 
resp.content : undefined - } - } - } - ] - } satisfies Content - } - return - } - - private async uploadFile(file: FileType): Promise { - return await this.sdk.files.upload({ - file: file.path, - config: { - mimeType: 'application/pdf', - name: file.id, - displayName: file.origin_name - } - }) - } - - private async base64File(file: FileType) { - const { data } = await window.api.file.base64File(file.id + file.ext) - return { - data, - mimeType: 'application/pdf' - } - } - - private async retrieveFile(file: FileType): Promise { - const cachedResponse = CacheService.get('gemini_file_list') - - if (cachedResponse) { - return this.processResponse(cachedResponse, file) - } - - const response = await this.sdk.files.list() - CacheService.set('gemini_file_list', response, 3000) - - return this.processResponse(response, file) - } - - private async processResponse(response: Pager, file: FileType) { - for await (const f of response) { - if (f.state === FileState.ACTIVE) { - if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) { - return f - } - } - } - - return undefined - } - - // @ts-ignore unused - private async listFiles(): Promise { - const files: File[] = [] - for await (const f of await this.sdk.files.list()) { - files.push(f) - } - return files - } - - // @ts-ignore unused - private async deleteFile(fileId: string) { - await this.sdk.files.delete({ name: fileId }) - } -} diff --git a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts deleted file mode 100644 index 4f68dbfa3f..0000000000 --- a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts +++ /dev/null @@ -1,1277 +0,0 @@ -import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' -import { - findTokenLimit, - getOpenAIWebSearchParams, - isClaudeReasoningModel, - isHunyuanSearchModel, - isOpenAIReasoningModel, - isReasoningModel, - isSupportedModel, - isSupportedReasoningEffortGrokModel, - isSupportedReasoningEffortModel, - isSupportedReasoningEffortOpenAIModel, - isSupportedThinkingTokenClaudeModel, - isSupportedThinkingTokenGeminiModel, - isSupportedThinkingTokenModel, - isSupportedThinkingTokenQwenModel, - isVisionModel, - isWebSearchModel, - isZhipuModel -} from '@renderer/config/models' -import { getStoreSetting } from '@renderer/hooks/useSettings' -import i18n from '@renderer/i18n' -import { extractReasoningMiddleware } from '@renderer/middlewares/extractReasoningMiddleware' -import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' -import { EVENT_NAMES } from '@renderer/services/EventService' -import { - filterContextMessages, - filterEmptyMessages, - filterUserRoleStartMessages -} from '@renderer/services/MessagesService' -import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService' -import store from '@renderer/store' -import { - Assistant, - EFFORT_RATIO, - FileTypes, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Metrics, - Model, - Provider, - Suggestion, - ToolCallResponse, - Usage, - WebSearchSource -} from '@renderer/types' -import { ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk' -import { Message } from '@renderer/types/newMessage' -import { removeSpecialCharactersForTopicName } from '@renderer/utils' -import { addImageFileToContents } from '@renderer/utils/formats' -import { - convertLinks, - convertLinksToHunyuan, - convertLinksToOpenRouter, - convertLinksToZhipu -} from '@renderer/utils/linkConverter' 
-import { - isEnabledToolUse, - mcpToolCallResponseToOpenAICompatibleMessage, - mcpToolsToOpenAIChatTools, - openAIToolsToMcpTool, - parseAndCallTools -} from '@renderer/utils/mcp-tools' -import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' -import { buildSystemPrompt } from '@renderer/utils/prompt' -import { asyncGeneratorToReadableStream, readableStreamAsyncIterable } from '@renderer/utils/stream' -import { isEmpty, takeRight } from 'lodash' -import OpenAI, { AzureOpenAI } from 'openai' -import { - ChatCompletionContentPart, - ChatCompletionCreateParamsNonStreaming, - ChatCompletionMessageParam, - ChatCompletionMessageToolCall, - ChatCompletionTool, - ChatCompletionToolMessageParam -} from 'openai/resources' - -import { CompletionsParams } from '.' -import { BaseOpenAIProvider } from './OpenAIResponseProvider' - -// 1. 定义联合类型 -export type OpenAIStreamChunk = - | { type: 'reasoning' | 'text-delta'; textDelta: string } - | { type: 'tool-calls'; delta: any } - | { type: 'finish'; finishReason: any; usage: any; delta: any; chunk: any } - | { type: 'unknown'; chunk: any } - -export default class OpenAIProvider extends BaseOpenAIProvider { - constructor(provider: Provider) { - super(provider) - - if (provider.id === 'azure-openai' || provider.type === 'azure-openai') { - this.sdk = new AzureOpenAI({ - dangerouslyAllowBrowser: true, - apiKey: this.apiKey, - apiVersion: provider.apiVersion, - endpoint: provider.apiHost - }) - return - } - - this.sdk = new OpenAI({ - dangerouslyAllowBrowser: true, - apiKey: this.apiKey, - baseURL: this.getBaseURL(), - defaultHeaders: { - ...this.defaultHeaders(), - ...(this.provider.id === 'copilot' ? { 'editor-version': 'vscode/1.97.2' } : {}), - ...(this.provider.id === 'copilot' ? { 'copilot-vision-request': 'true' } : {}) - } - }) - } - - /** - * Check if the provider does not support files - * @returns True if the provider does not support files, false otherwise - */ - private get isNotSupportFiles() { - if (this.provider?.isNotSupportArrayContent) { - return true - } - - const providers = ['deepseek', 'baichuan', 'minimax', 'xirang'] - - return providers.includes(this.provider.id) - } - - /** - * Get the message parameter - * @param message - The message - * @param model - The model - * @returns The message parameter - */ - override async getMessageParam( - message: Message, - model: Model - ): Promise { - const isVision = isVisionModel(model) - const content = await this.getMessageContent(message) - const fileBlocks = findFileBlocks(message) - const imageBlocks = findImageBlocks(message) - - if (fileBlocks.length === 0 && imageBlocks.length === 0) { - return { - role: message.role === 'system' ? 'user' : message.role, - content - } - } - - // If the model does not support files, extract the file content - if (this.isNotSupportFiles) { - const fileContent = await this.extractFileContent(message) - - return { - role: message.role === 'system' ? 
'user' : message.role, - content: content + '\n\n---\n\n' + fileContent - } - } - - // If the model supports files, add the file content to the message - const parts: ChatCompletionContentPart[] = [] - - if (content) { - parts.push({ type: 'text', text: content }) - } - - for (const imageBlock of imageBlocks) { - if (isVision) { - if (imageBlock.file) { - const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) - parts.push({ type: 'image_url', image_url: { url: image.data } }) - } else if (imageBlock.url && imageBlock.url.startsWith('data:')) { - parts.push({ type: 'image_url', image_url: { url: imageBlock.url } }) - } - } - } - - for (const fileBlock of fileBlocks) { - const file = fileBlock.file - if (!file) { - continue - } - - if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { - const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() - parts.push({ - type: 'text', - text: file.origin_name + '\n' + fileContent - }) - } - } - - return { - role: message.role === 'system' ? 'user' : message.role, - content: parts - } as ChatCompletionMessageParam - } - - override getTemperature(assistant: Assistant, model: Model): number | undefined { - if (isOpenAIReasoningModel(model) || (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))) { - return undefined - } - return assistant.settings?.temperature - } - - override getTopP(assistant: Assistant, model: Model): number | undefined { - if (isOpenAIReasoningModel(model) || (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))) { - return undefined - } - return assistant.settings?.topP - } - - /** - * Get the provider specific parameters for the assistant - * @param assistant - The assistant - * @param model - The model - * @returns The provider specific parameters - */ - private getProviderSpecificParameters(assistant: Assistant, model: Model) { - const { maxTokens } = getAssistantSettings(assistant) - - if (this.provider.id === 'openrouter') { - if (model.id.includes('deepseek-r1')) { - return { - include_reasoning: true - } - } - } - - if (isOpenAIReasoningModel(model)) { - return { - max_tokens: undefined, - max_completion_tokens: maxTokens - } - } - - return {} - } - - /** - * Get the reasoning effort for the assistant - * @param assistant - The assistant - * @param model - The model - * @returns The reasoning effort - */ - private getReasoningEffort(assistant: Assistant, model: Model) { - if (this.provider.id === 'groq') { - return {} - } - - if (!isReasoningModel(model)) { - return {} - } - const reasoningEffort = assistant?.settings?.reasoning_effort - if (!reasoningEffort) { - if (isSupportedThinkingTokenQwenModel(model)) { - return { enable_thinking: false } - } - - if (isSupportedThinkingTokenClaudeModel(model)) { - return { thinking: { type: 'disabled' } } - } - - if (isSupportedThinkingTokenGeminiModel(model)) { - // openrouter没有提供一个不推理的选项,先隐藏 - if (this.provider.id === 'openrouter') { - return { reasoning: { maxTokens: 0, exclude: true } } - } - return { - reasoning_effort: 'none' - } - } - - return {} - } - const effortRatio = EFFORT_RATIO[reasoningEffort] - const budgetTokens = Math.floor( - (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min! 
- ) - - // OpenRouter models - if (model.provider === 'openrouter') { - if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) { - return { - reasoning: { - effort: assistant?.settings?.reasoning_effort === 'auto' ? 'medium' : assistant?.settings?.reasoning_effort - } - } - } - } - - // Qwen models - if (isSupportedThinkingTokenQwenModel(model)) { - return { - enable_thinking: true, - thinking_budget: budgetTokens - } - } - - // Grok models - if (isSupportedReasoningEffortGrokModel(model)) { - return { - reasoning_effort: assistant?.settings?.reasoning_effort - } - } - - // OpenAI models - if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) { - return { - reasoning_effort: assistant?.settings?.reasoning_effort - } - } - - // Claude models - if (isSupportedThinkingTokenClaudeModel(model)) { - const maxTokens = assistant.settings?.maxTokens - return { - thinking: { - type: 'enabled', - budget_tokens: Math.floor( - Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)) - ) - } - } - } - - // Default case: no special thinking settings - return {} - } - - public convertMcpTools(mcpTools: MCPTool[]): T[] { - return mcpToolsToOpenAIChatTools(mcpTools) as T[] - } - - public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => { - if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { - return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model)) - } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) { - const toolCallOut: ChatCompletionToolMessageParam = { - role: 'tool', - tool_call_id: mcpToolResponse.toolCallId, - content: JSON.stringify(resp.content) - } - return toolCallOut - } - return - } - - /** - * Generate completions for the assistant - * @param messages - The messages - * @param assistant - The assistant - * @param mcpTools - The MCP tools - * @param onChunk - The onChunk callback - * @param onFilterMessages - The onFilterMessages callback - * @returns The completions - */ - async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise { - if (assistant.enableGenerateImage) { - await this.generateImageByChat({ messages, assistant, onChunk } as CompletionsParams) - return - } - const defaultModel = getDefaultModel() - const model = assistant.model || defaultModel - - const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) - const isEnabledBultinWebSearch = assistant.enableWebSearch && isWebSearchModel(model) - messages = addImageFileToContents(messages) - const enableReasoning = - ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) && - assistant.settings?.reasoning_effort !== undefined) || - (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model))) - let systemMessage = { role: 'system', content: assistant.prompt || '' } - if (isSupportedReasoningEffortOpenAIModel(model)) { - systemMessage = { - role: 'developer', - content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}` - } - } - if (model.id.includes('o1-preview') || model.id.includes('o1-mini')) { - systemMessage = { - role: 'assistant', - content: `Formatting re-enabled${systemMessage ? 
'\n' + systemMessage.content : ''}` - } - } - const { tools } = this.setupToolsConfig({ - mcpTools, - model, - enableToolUse: isEnabledToolUse(assistant) - }) - - if (this.useSystemPromptForTools) { - systemMessage.content = await buildSystemPrompt(systemMessage.content || '', mcpTools) - } - - const userMessages: ChatCompletionMessageParam[] = [] - const _messages = filterUserRoleStartMessages( - filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 1))) - ) - - onFilterMessages(_messages) - - for (const message of _messages) { - userMessages.push(await this.getMessageParam(message, model)) - } - - const isSupportStreamOutput = () => { - return streamOutput - } - - const lastUserMessage = _messages.findLast((m) => m.role === 'user') - const { abortController, cleanup, signalPromise } = this.createAbortController(lastUserMessage?.id, true) - const { signal } = abortController - await this.checkIsCopilot() - - const lastUserMsg = userMessages.findLast((m) => m.role === 'user') - if (lastUserMsg && isSupportedThinkingTokenQwenModel(model)) { - const postsuffix = '/no_think' - // qwenThinkMode === true 表示思考模式啓用,此時不應添加 /no_think,如果存在則移除 - const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true - const currentContent = lastUserMsg.content // content 類型:string | ChatCompletionContentPart[] | null - - lastUserMsg.content = processPostsuffixQwen3Model( - currentContent, - postsuffix, - qwenThinkModeEnabled - ) as ChatCompletionContentPart[] - } - - //当 systemMessage 内容为空时不发送 systemMessage - let reqMessages: ChatCompletionMessageParam[] - if (!systemMessage.content) { - reqMessages = [...userMessages] - } else { - reqMessages = [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[] - } - - let finalUsage: Usage = { - completion_tokens: 0, - prompt_tokens: 0, - total_tokens: 0 - } - - const finalMetrics: Metrics = { - completion_tokens: 0, - time_completion_millsec: 0, - time_first_token_millsec: 0 - } - - const toolResponses: MCPToolResponse[] = [] - - const processToolResults = async (toolResults: Awaited>, idx: number) => { - if (toolResults.length === 0) return - - toolResults.forEach((ts) => reqMessages.push(ts as ChatCompletionMessageParam)) - - console.debug('[tool] reqMessages before processing', model.id, reqMessages) - reqMessages = processReqMessages(model, reqMessages) - console.debug('[tool] reqMessages', model.id, reqMessages) - - onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) - const newStream = await this.sdk.chat.completions - // @ts-ignore key is not typed - .create( - { - model: model.id, - messages: reqMessages, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - max_tokens: maxTokens, - keep_alive: this.keepAliveTime, - stream: isSupportStreamOutput(), - tools: !isEmpty(tools) ? 
tools : undefined, - service_tier: this.getServiceTier(model), - ...getOpenAIWebSearchParams(assistant, model), - ...this.getReasoningEffort(assistant, model), - ...this.getProviderSpecificParameters(assistant, model), - ...this.getCustomParameters(assistant) - }, - { - signal - } - ) - await processStream(newStream, idx + 1) - } - - const processToolCalls = async (mcpTools, toolCalls: ChatCompletionMessageToolCall[]) => { - const mcpToolResponses = toolCalls - .map((toolCall) => { - const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall as ChatCompletionMessageToolCall) - if (!mcpTool) return undefined - - const parsedArgs = (() => { - try { - return JSON.parse(toolCall.function.arguments) - } catch { - return toolCall.function.arguments - } - })() - - return { - id: toolCall.id, - toolCallId: toolCall.id, - tool: mcpTool, - arguments: parsedArgs, - status: 'pending' - } as ToolCallResponse - }) - .filter((t): t is ToolCallResponse => typeof t !== 'undefined') - return await parseAndCallTools( - mcpToolResponses, - toolResponses, - onChunk, - this.mcpToolCallResponseToMessage, - model, - mcpTools - ) - } - - const processToolUses = async (content: string) => { - return await parseAndCallTools( - content, - toolResponses, - onChunk, - this.mcpToolCallResponseToMessage, - model, - mcpTools - ) - } - - const processStream = async (stream: any, idx: number) => { - const toolCalls: ChatCompletionMessageToolCall[] = [] - let time_first_token_millsec = 0 - - // Handle non-streaming case (already returns early, no change needed here) - if (!isSupportStreamOutput()) { - // Calculate final metrics once - finalMetrics.completion_tokens = stream.usage?.completion_tokens - finalMetrics.time_completion_millsec = new Date().getTime() - start_time_millsec - - // Create a synthetic usage object if stream.usage is undefined - finalUsage = { ...stream.usage } - // Separate onChunk calls for text and usage/metrics - let content = '' - stream.choices.forEach((choice) => { - const reasoning = choice.message.reasoning || choice.message.reasoning_content - // reasoning - if (reasoning) { - onChunk({ - type: ChunkType.THINKING_DELTA, - text: reasoning - }) - onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: reasoning, - thinking_millsec: new Date().getTime() - start_time_millsec - }) - } - // text - if (choice.message.content) { - content += choice.message.content - onChunk({ type: ChunkType.TEXT_DELTA, text: choice.message.content }) - } - // tool call - if (choice.message.tool_calls && choice.message.tool_calls.length) { - choice.message.tool_calls.forEach((t) => toolCalls.push(t)) - } - - reqMessages.push({ - role: choice.message.role, - content: choice.message.content, - tool_calls: toolCalls.length - ? toolCalls.map((toolCall) => ({ - id: toolCall.id, - function: { - ...toolCall.function, - arguments: - typeof toolCall.function.arguments === 'string' - ? 
toolCall.function.arguments - : JSON.stringify(toolCall.function.arguments) - }, - type: 'function' - })) - : undefined - }) - }) - - if (content.length) { - onChunk({ type: ChunkType.TEXT_COMPLETE, text: content }) - } - - const toolResults: Awaited> = [] - if (toolCalls.length) { - toolResults.push(...(await processToolCalls(mcpTools, toolCalls))) - } - if (stream.choices[0].message?.content) { - toolResults.push(...(await processToolUses(stream.choices[0].message?.content))) - } - await processToolResults(toolResults, idx) - - // Always send usage and metrics data - onChunk({ type: ChunkType.BLOCK_COMPLETE, response: { usage: finalUsage, metrics: finalMetrics } }) - return - } - - let content = '' - let thinkingContent = '' - let isFirstChunk = true - - // 1. 初始化中间件 - const reasoningTags = [ - { openingTag: '', closingTag: '', separator: '\n' }, - { openingTag: '###Thinking', closingTag: '###Response', separator: '\n' } - ] - const getAppropriateTag = (model: Model) => { - if (model.id.includes('qwen3')) return reasoningTags[0] - return reasoningTags[0] - } - const reasoningTag = getAppropriateTag(model) - async function* openAIChunkToTextDelta(stream: any): AsyncGenerator { - try { - for await (const chunk of stream) { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { - break - } - - if (chunk.choices && chunk.choices.length > 0) { - const delta = chunk.choices[0]?.delta - if ( - (delta?.reasoning_content && delta?.reasoning_content !== '\n') || - (delta?.reasoning && delta?.reasoning !== '\n') - ) { - yield { type: 'reasoning', textDelta: delta.reasoning_content || delta.reasoning } - } - if (delta?.content) { - yield { type: 'text-delta', textDelta: delta.content } - } - if (delta?.tool_calls && delta?.tool_calls.length > 0) { - yield { type: 'tool-calls', delta: delta } - } - - const finishReason = chunk?.choices[0]?.finish_reason - if (!isEmpty(finishReason)) { - yield { type: 'finish', finishReason, usage: chunk.usage, delta, chunk } - } - } - } - } catch (error) { - console.error('[openAIChunkToTextDelta] error', error) - throw error - } - } - - // 2. 使用中间件 - const { stream: processedStream } = await extractReasoningMiddleware({ - openingTag: reasoningTag?.openingTag, - closingTag: reasoningTag?.closingTag, - separator: reasoningTag?.separator, - enableReasoning - }).wrapStream({ - doStream: async () => ({ - stream: asyncGeneratorToReadableStream(openAIChunkToTextDelta(stream)) - }) - }) - - // 3. 消费 processedStream,分发 onChunk - for await (const chunk of readableStreamAsyncIterable(processedStream)) { - const delta = chunk.type === 'finish' ? chunk.delta : chunk - const rawChunk = chunk.type === 'finish' ? 
chunk.chunk : chunk - switch (chunk.type) { - case 'reasoning': { - if (time_first_token_millsec === 0) { - time_first_token_millsec = new Date().getTime() - } - thinkingContent += chunk.textDelta - onChunk({ - type: ChunkType.THINKING_DELTA, - text: chunk.textDelta, - thinking_millsec: new Date().getTime() - time_first_token_millsec - }) - break - } - case 'text-delta': { - let textDelta = chunk.textDelta - if (assistant.enableWebSearch && delta) { - const originalDelta = rawChunk?.choices?.[0]?.delta - - if (originalDelta?.annotations) { - textDelta = convertLinks(textDelta, isFirstChunk) - } else if (assistant.model?.provider === 'openrouter') { - textDelta = convertLinksToOpenRouter(textDelta, isFirstChunk) - } else if (isZhipuModel(assistant.model)) { - textDelta = convertLinksToZhipu(textDelta, isFirstChunk) - } else if (isHunyuanSearchModel(assistant.model)) { - const searchResults = rawChunk?.search_info?.search_results || [] - textDelta = convertLinksToHunyuan(textDelta, searchResults, isFirstChunk) - } - } - if (isFirstChunk) { - isFirstChunk = false - if (time_first_token_millsec === 0) { - time_first_token_millsec = new Date().getTime() - } else { - onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: thinkingContent, - thinking_millsec: new Date().getTime() - time_first_token_millsec - }) - } - } - content += textDelta - onChunk({ type: ChunkType.TEXT_DELTA, text: textDelta }) - break - } - case 'tool-calls': { - if (isFirstChunk) { - isFirstChunk = false - if (time_first_token_millsec === 0) { - time_first_token_millsec = new Date().getTime() - } else { - onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: thinkingContent, - thinking_millsec: new Date().getTime() - time_first_token_millsec - }) - } - } - chunk.delta.tool_calls.forEach((toolCall) => { - const { id, index, type, function: fun } = toolCall - if (id && type === 'function' && fun) { - const { name, arguments: args } = fun - toolCalls.push({ - id, - function: { - name: name || '', - arguments: args || '' - }, - type: 'function' - }) - } else if (fun?.arguments) { - toolCalls[index].function.arguments += fun.arguments - } - }) - break - } - case 'finish': { - const finishReason = chunk.finishReason - const usage = chunk.usage - const originalFinishDelta = chunk.delta - const originalFinishRawChunk = chunk.chunk - if (!isEmpty(finishReason)) { - if (content) { - onChunk({ type: ChunkType.TEXT_COMPLETE, text: content }) - } - if (thinkingContent) { - onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: thinkingContent, - thinking_millsec: new Date().getTime() - time_first_token_millsec - }) - } - if (usage) { - finalUsage.completion_tokens += usage.completion_tokens || 0 - finalUsage.prompt_tokens += usage.prompt_tokens || 0 - finalUsage.total_tokens += usage.total_tokens || 0 - finalMetrics.completion_tokens += usage.completion_tokens || 0 - } - finalMetrics.time_completion_millsec += new Date().getTime() - start_time_millsec - finalMetrics.time_first_token_millsec = time_first_token_millsec - start_time_millsec - if (originalFinishDelta?.annotations) { - if (assistant.model?.provider === 'copilot') return - - onChunk({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - results: originalFinishDelta.annotations, - source: WebSearchSource.OPENAI_RESPONSE - } - } as LLMWebSearchCompleteChunk) - } - if (assistant.model?.provider === 'grok') { - const citations = originalFinishRawChunk.citations - if (citations) { - onChunk({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - results: 
citations, - source: WebSearchSource.GROK - } - } as LLMWebSearchCompleteChunk) - } - } - if (assistant.model?.provider === 'perplexity') { - const citations = originalFinishRawChunk.citations - if (citations) { - onChunk({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - results: citations, - source: WebSearchSource.PERPLEXITY - } - } as LLMWebSearchCompleteChunk) - } - } - if ( - isEnabledBultinWebSearch && - isZhipuModel(model) && - finishReason === 'stop' && - originalFinishRawChunk?.web_search - ) { - onChunk({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - results: originalFinishRawChunk.web_search, - source: WebSearchSource.ZHIPU - } - } as LLMWebSearchCompleteChunk) - } - if ( - isEnabledBultinWebSearch && - isHunyuanSearchModel(model) && - originalFinishRawChunk?.search_info?.search_results - ) { - onChunk({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - results: originalFinishRawChunk.search_info.search_results, - source: WebSearchSource.HUNYUAN - } - } as LLMWebSearchCompleteChunk) - } - } - break - } - case 'unknown': { - onChunk({ - type: ChunkType.ERROR, - error: chunk.chunk - }) - } - } - } - - reqMessages.push({ - role: 'assistant', - content: content, - tool_calls: toolCalls.length - ? toolCalls.map((toolCall) => ({ - id: toolCall.id, - function: { - ...toolCall.function, - arguments: - typeof toolCall.function.arguments === 'string' - ? toolCall.function.arguments - : JSON.stringify(toolCall.function.arguments) - }, - type: 'function' - })) - : undefined - }) - let toolResults: Awaited> = [] - if (toolCalls.length) { - toolResults = await processToolCalls(mcpTools, toolCalls) - } - if (content.length) { - toolResults = toolResults.concat(await processToolUses(content)) - } - if (toolResults.length) { - await processToolResults(toolResults, idx) - } - onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - usage: finalUsage, - metrics: finalMetrics - } - }) - } - - reqMessages = processReqMessages(model, reqMessages) - // 等待接口返回流 - onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) - const start_time_millsec = new Date().getTime() - const stream = await this.sdk.chat.completions - // @ts-ignore key is not typed - .create( - { - model: model.id, - messages: reqMessages, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - max_tokens: maxTokens, - keep_alive: this.keepAliveTime, - stream: isSupportStreamOutput(), - tools: !isEmpty(tools) ? tools : undefined, - service_tier: this.getServiceTier(model), - ...getOpenAIWebSearchParams(assistant, model), - ...this.getReasoningEffort(assistant, model), - ...this.getProviderSpecificParameters(assistant, model), - ...this.getCustomParameters(assistant) - }, - { - signal, - timeout: this.getTimeout(model) - } - ) - - await processStream(stream, 0).finally(cleanup) - - // 捕获signal的错误 - await signalPromise?.promise?.catch((error) => { - throw error - }) - } - - /** - * Translate a message - * @param content - * @param assistant - The assistant - * @param onResponse - The onResponse callback - * @returns The translated message - */ - async translate(content: string, assistant: Assistant, onResponse?: (text: string, isComplete: boolean) => void) { - const defaultModel = getDefaultModel() - const model = assistant.model || defaultModel - - const messagesForApi = content - ? 
[ - { role: 'system', content: assistant.prompt }, - { role: 'user', content } - ] - : [{ role: 'user', content: assistant.prompt }] - - const isSupportedStreamOutput = () => { - if (!onResponse) { - return false - } - return true - } - - const stream = isSupportedStreamOutput() - - await this.checkIsCopilot() - - // console.debug('[translate] reqMessages', model.id, message) - // @ts-ignore key is not typed - const response = await this.sdk.chat.completions.create({ - model: model.id, - messages: messagesForApi as ChatCompletionMessageParam[], - stream, - keep_alive: this.keepAliveTime, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - ...this.getReasoningEffort(assistant, model) - }) - - if (!stream) { - return response.choices[0].message?.content || '' - } - - let text = '' - let isThinking = false - const isReasoning = isReasoningModel(model) - - for await (const chunk of response) { - const deltaContent = chunk.choices[0]?.delta?.content || '' - - if (isReasoning) { - if (deltaContent.includes('')) { - isThinking = true - } - - if (!isThinking) { - text += deltaContent - onResponse?.(text, false) - } - - if (deltaContent.includes('')) { - isThinking = false - } - } else { - text += deltaContent - onResponse?.(text, false) - } - } - - onResponse?.(text, true) - - return text - } - - /** - * Summarize a message - * @param messages - The messages - * @param assistant - The assistant - * @returns The summary - */ - public async summaries(messages: Message[], assistant: Assistant): Promise { - const model = getTopNamingModel() || assistant.model || getDefaultModel() - - const userMessages = takeRight(messages, 5).map((message) => ({ - role: message.role, - content: getMainTextContent(message) - })) - - const userMessageContent = userMessages.reduce((prev, curr) => { - const content = curr.role === 'user' ? `User: ${curr.content}` : `Assistant: ${curr.content}` - return prev + (prev ? 
'\n' : '') + content - }, '') - - const systemMessage = { - role: 'system', - content: getStoreSetting('topicNamingPrompt') || i18n.t('prompts.title') - } - - const userMessage = { - role: 'user', - content: userMessageContent - } - - await this.checkIsCopilot() - - // @ts-ignore key is not typed - const response = await this.sdk.chat.completions.create({ - model: model.id, - messages: [systemMessage, userMessage] as ChatCompletionMessageParam[], - stream: false, - keep_alive: this.keepAliveTime, - max_tokens: 1000 - }) - - // 针对思考类模型的返回,总结仅截取之后的内容 - let content = response.choices[0].message?.content || '' - content = content.replace(/^(.*?)<\/think>/s, '') - - return removeSpecialCharactersForTopicName(content.substring(0, 50)) - } - - /** - * Summarize a message for search - * @param messages - The messages - * @param assistant - The assistant - * @returns The summary - */ - public async summaryForSearch(messages: Message[], assistant: Assistant): Promise { - const model = assistant.model || getDefaultModel() - - const systemMessage = { - role: 'system', - content: assistant.prompt - } - - const messageContents = messages.map((m) => getMainTextContent(m)) - const userMessageContent = messageContents.join('\n') - - const userMessage = { - role: 'user', - content: userMessageContent - } - - const lastUserMessage = messages[messages.length - 1] - - const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id) - const { signal } = abortController - - const response = await this.sdk.chat.completions - // @ts-ignore key is not typed - .create( - { - model: model.id, - messages: [systemMessage, userMessage] as ChatCompletionMessageParam[], - stream: false, - keep_alive: this.keepAliveTime, - max_tokens: 1000 - }, - { - timeout: 20 * 1000, - signal: signal - } - ) - .finally(cleanup) - - // 针对思考类模型的返回,总结仅截取之后的内容 - let content = response.choices[0].message?.content || '' - content = content.replace(/^(.*?)<\/think>/s, '') - - return content - } - - /** - * Generate text - * @param prompt - The prompt - * @param content - The content - * @returns The generated text - */ - public async generateText({ prompt, content }: { prompt: string; content: string }): Promise { - const model = getDefaultModel() - - await this.checkIsCopilot() - - const response = await this.sdk.chat.completions.create({ - model: model.id, - stream: false, - messages: [ - { role: 'system', content: prompt }, - { role: 'user', content } - ] - }) - - return response.choices[0].message?.content || '' - } - - /** - * Generate suggestions - * @param messages - The messages - * @param assistant - The assistant - * @returns The suggestions - */ - async suggestions(messages: Message[], assistant: Assistant): Promise { - const { model } = assistant - - if (!model) { - return [] - } - - await this.checkIsCopilot() - - const userMessagesForApi = messages - .filter((m) => m.role === 'user') - .map((m) => ({ - role: m.role, - content: getMainTextContent(m) - })) - - const response: any = await this.sdk.request({ - method: 'post', - path: '/advice_questions', - body: { - messages: userMessagesForApi, - model: model.id, - max_tokens: 0, - temperature: 0, - n: 0 - } - }) - - return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || [] - } - - /** - * Check if the model is valid - * @param model - The model - * @param stream - Whether to use streaming interface - * @returns The validity of the model - */ - public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: 
Error | null }> { - if (!model) { - return { valid: false, error: new Error('No model found') } - } - - const body: any = { - model: model.id, - messages: [{ role: 'user', content: 'hi' }], - stream - } - - if (isSupportedThinkingTokenQwenModel(model)) { - body.enable_thinking = false // qwen3 - } - - try { - await this.checkIsCopilot() - if (!stream) { - const response = await this.sdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming) - if (!response?.choices[0].message) { - throw new Error('Empty response') - } - return { valid: true, error: null } - } else { - const response: any = await this.sdk.chat.completions.create(body as any) - // 等待整个流式响应结束 - let hasContent = false - for await (const chunk of response) { - if (chunk.choices?.[0]?.delta?.content) { - hasContent = true - } - } - if (hasContent) { - return { valid: true, error: null } - } - throw new Error('Empty streaming response') - } - } catch (error: any) { - return { - valid: false, - error - } - } - } - - /** - * Get the models - * @returns The models - */ - public async models(): Promise { - try { - await this.checkIsCopilot() - - const response = await this.sdk.models.list() - - if (this.provider.id === 'github') { - // @ts-ignore key is not typed - return response.body - .map((model) => ({ - id: model.name, - description: model.summary, - object: 'model', - owned_by: model.publisher - })) - .filter(isSupportedModel) - } - - if (this.provider.id === 'together') { - // @ts-ignore key is not typed - return response?.body - .map((model: any) => ({ - id: model.id, - description: model.display_name, - object: 'model', - owned_by: model.organization - })) - .filter(isSupportedModel) - } - - const models = response.data || [] - models.forEach((model) => { - model.id = model.id.trim() - }) - - return models.filter(isSupportedModel) - } catch (error) { - return [] - } - } - - /** - * Get the embedding dimensions - * @param model - The model - * @returns The embedding dimensions - */ - public async getEmbeddingDimensions(model: Model): Promise { - await this.checkIsCopilot() - - try { - const data = await this.sdk.embeddings.create({ - model: model.id, - input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi', - // @ts-ignore voyage api need null - encoding_format: model?.provider === 'voyageai' ? 
null : 'float' - }) - return data.data[0].embedding.length - } catch (e) { - return 0 - } - } - - public async checkIsCopilot() { - if (this.provider.id !== 'copilot') { - return - } - const defaultHeaders = store.getState().copilot.defaultHeaders - // copilot每次请求前需要重新获取token,因为token中附带时间戳 - const { token } = await window.api.copilot.getToken(defaultHeaders) - this.sdk.apiKey = token - } -} diff --git a/src/renderer/src/providers/AiProvider/OpenAIResponseProvider.ts b/src/renderer/src/providers/AiProvider/OpenAIResponseProvider.ts deleted file mode 100644 index 2d2c73e67f..0000000000 --- a/src/renderer/src/providers/AiProvider/OpenAIResponseProvider.ts +++ /dev/null @@ -1,1218 +0,0 @@ -import { - isOpenAIModel, - isOpenAIReasoningModel, - isOpenAIWebSearch, - isSupportedFlexServiceTier, - isSupportedModel, - isSupportedReasoningEffortOpenAIModel, - isVisionModel, - isWebSearchModel -} from '@renderer/config/models' -import { getStoreSetting } from '@renderer/hooks/useSettings' -import i18n from '@renderer/i18n' -import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' -import { EVENT_NAMES } from '@renderer/services/EventService' -import FileManager from '@renderer/services/FileManager' -import { - filterContextMessages, - filterEmptyMessages, - filterUserRoleStartMessages -} from '@renderer/services/MessagesService' -import { - Assistant, - FileTypes, - GenerateImageParams, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Metrics, - Model, - OpenAIServiceTier, - OpenAISummaryText, - Provider, - Suggestion, - ToolCallResponse, - Usage, - WebSearchSource -} from '@renderer/types' -import { ChunkType } from '@renderer/types/chunk' -import { Message } from '@renderer/types/newMessage' -import { removeSpecialCharactersForTopicName } from '@renderer/utils' -import { addImageFileToContents } from '@renderer/utils/formats' -import { convertLinks } from '@renderer/utils/linkConverter' -import { - isEnabledToolUse, - mcpToolCallResponseToOpenAIMessage, - mcpToolsToOpenAIResponseTools, - openAIToolsToMcpTool, - parseAndCallTools -} from '@renderer/utils/mcp-tools' -import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' -import { buildSystemPrompt } from '@renderer/utils/prompt' -import { Base64 } from 'js-base64' -import { isEmpty, takeRight } from 'lodash' -import mime from 'mime' -import OpenAI from 'openai' -import { ChatCompletionContentPart, ChatCompletionMessageParam } from 'openai/resources/chat/completions' -import { Stream } from 'openai/streaming' -import { toFile, Uploadable } from 'openai/uploads' - -import { CompletionsParams } from '.' 
-import BaseProvider from './BaseProvider' -import OpenAIProvider from './OpenAIProvider' - -export abstract class BaseOpenAIProvider extends BaseProvider { - protected sdk: OpenAI - - constructor(provider: Provider) { - super(provider) - - this.sdk = new OpenAI({ - dangerouslyAllowBrowser: true, - apiKey: this.apiKey, - baseURL: this.getBaseURL(), - defaultHeaders: { - ...this.defaultHeaders() - } - }) - } - - abstract convertMcpTools(mcpTools: MCPTool[]): T[] - - abstract mcpToolCallResponseToMessage: ( - mcpToolResponse: MCPToolResponse, - resp: MCPCallToolResponse, - model: Model - ) => OpenAI.Responses.ResponseInputItem | ChatCompletionMessageParam | undefined - - /** - * Extract the file content from the message - * @param message - The message - * @returns The file content - */ - protected async extractFileContent(message: Message) { - const fileBlocks = findFileBlocks(message) - if (fileBlocks.length > 0) { - const textFileBlocks = fileBlocks.filter( - (fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type) - ) - - if (textFileBlocks.length > 0) { - let text = '' - const divider = '\n\n---\n\n' - - for (const fileBlock of textFileBlocks) { - const file = fileBlock.file - const fileContent = (await window.api.file.read(file.id + file.ext)).trim() - const fileNameRow = 'file: ' + file.origin_name + '\n\n' - text = text + fileNameRow + fileContent + divider - } - - return text - } - } - - return '' - } - - private async getReponseMessageParam(message: Message, model: Model): Promise { - const isVision = isVisionModel(model) - const content = await this.getMessageContent(message) - const fileBlocks = findFileBlocks(message) - const imageBlocks = findImageBlocks(message) - - if (fileBlocks.length === 0 && imageBlocks.length === 0) { - if (message.role === 'assistant') { - return { - role: 'assistant', - content: content - } - } else { - return { - role: message.role === 'system' ? 'user' : message.role, - content: content ? [{ type: 'input_text', text: content }] : [] - } as OpenAI.Responses.EasyInputMessage - } - } - - const parts: OpenAI.Responses.ResponseInputContent[] = [] - if (content) { - parts.push({ - type: 'input_text', - text: content - }) - } - - for (const imageBlock of imageBlocks) { - if (isVision) { - if (imageBlock.file) { - const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) - parts.push({ - detail: 'auto', - type: 'input_image', - image_url: image.data as string - }) - } else if (imageBlock.url && imageBlock.url.startsWith('data:')) { - parts.push({ - detail: 'auto', - type: 'input_image', - image_url: imageBlock.url - }) - } - } - } - - for (const fileBlock of fileBlocks) { - const file = fileBlock.file - if (!file) continue - - if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { - const fileContent = (await window.api.file.read(file.id + file.ext)).trim() - parts.push({ - type: 'input_text', - text: file.origin_name + '\n' + fileContent - }) - } - } - - return { - role: message.role === 'system' ? 
'user' : message.role, - content: parts - } - } - - protected getServiceTier(model: Model) { - if (!isOpenAIModel(model) || model.provider === 'github' || model.provider === 'copilot') { - return undefined - } - - const openAI = getStoreSetting('openAI') as any - let serviceTier = 'auto' as OpenAIServiceTier - - if (openAI && openAI?.serviceTier === 'flex') { - if (isSupportedFlexServiceTier(model)) { - serviceTier = 'flex' - } else { - serviceTier = 'auto' - } - } else { - serviceTier = openAI.serviceTier - } - - return serviceTier - } - - protected getTimeout(model: Model) { - if (isSupportedFlexServiceTier(model)) { - return 15 * 1000 * 60 - } - return 5 * 1000 * 60 - } - - private getResponseReasoningEffort(assistant: Assistant, model: Model) { - if (!isSupportedReasoningEffortOpenAIModel(model)) { - return {} - } - - const openAI = getStoreSetting('openAI') as any - const summaryText = (openAI?.summaryText as OpenAISummaryText) || 'off' - - let summary: string | undefined = undefined - - if (summaryText === 'off' || model.id.includes('o1-pro')) { - summary = undefined - } else { - summary = summaryText - } - - const reasoningEffort = assistant?.settings?.reasoning_effort - if (!reasoningEffort) { - return {} - } - - if (isSupportedReasoningEffortOpenAIModel(model)) { - return { - reasoning: { - effort: reasoningEffort as OpenAI.ReasoningEffort, - summary: summary - } as OpenAI.Reasoning - } - } - - return {} - } - - /** - * Get the message parameter - * @param message - The message - * @param model - The model - * @returns The message parameter - */ - protected async getMessageParam( - message: Message, - model: Model - ): Promise { - const isVision = isVisionModel(model) - const content = await this.getMessageContent(message) - const fileBlocks = findFileBlocks(message) - const imageBlocks = findImageBlocks(message) - - if (fileBlocks.length === 0 && imageBlocks.length === 0) { - return { - role: message.role === 'system' ? 'user' : message.role, - content - } - } - - const parts: ChatCompletionContentPart[] = [] - - if (content) { - parts.push({ type: 'text', text: content }) - } - - for (const imageBlock of imageBlocks) { - if (isVision) { - if (imageBlock.file) { - const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) - parts.push({ type: 'image_url', image_url: { url: image.data } }) - } else if (imageBlock.url && imageBlock.url.startsWith('data:')) { - parts.push({ type: 'image_url', image_url: { url: imageBlock.url } }) - } - } - } - - for (const fileBlock of fileBlocks) { - const { file } = fileBlock - if (!file) { - continue - } - - if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { - const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() - parts.push({ - type: 'text', - text: file.origin_name + '\n' + fileContent - }) - } - } - - return { - role: message.role === 'system' ? 
'user' : message.role, - content: parts - } as ChatCompletionMessageParam - } - - /** - * Generate completions for the assistant use Response API - * @param messages - The messages - * @param assistant - The assistant - * @param mcpTools - * @param onChunk - The onChunk callback - * @param onFilterMessages - The onFilterMessages callback - * @returns The completions - */ - async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise { - if (assistant.enableGenerateImage) { - await this.generateImageByChat({ messages, assistant, onChunk } as CompletionsParams) - return - } - const defaultModel = getDefaultModel() - const model = assistant.model || defaultModel - - const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) - const isEnabledBuiltinWebSearch = assistant.enableWebSearch && isWebSearchModel(model) - - let tools: OpenAI.Responses.Tool[] = [] - const toolChoices: OpenAI.Responses.ToolChoiceTypes = { - type: 'web_search_preview' - } - if (isEnabledBuiltinWebSearch) { - tools.push({ - type: 'web_search_preview' - }) - } - messages = addImageFileToContents(messages) - const systemMessage: OpenAI.Responses.EasyInputMessage = { - role: 'system', - content: [] - } - const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = [] - const systemMessageInput: OpenAI.Responses.ResponseInputText = { - text: assistant.prompt || '', - type: 'input_text' - } - if (isSupportedReasoningEffortOpenAIModel(model)) { - systemMessage.role = 'developer' - } - - const { tools: extraTools } = this.setupToolsConfig({ - mcpTools, - model, - enableToolUse: isEnabledToolUse(assistant) - }) - - tools = tools.concat(extraTools) - - if (this.useSystemPromptForTools) { - systemMessageInput.text = await buildSystemPrompt(systemMessageInput.text || '', mcpTools) - } - systemMessageContent.push(systemMessageInput) - systemMessage.content = systemMessageContent - const _messages = filterUserRoleStartMessages( - filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 1))) - ) - - onFilterMessages(_messages) - const userMessage: OpenAI.Responses.ResponseInputItem[] = [] - for (const message of _messages) { - userMessage.push(await this.getReponseMessageParam(message, model)) - } - - const lastUserMessage = _messages.findLast((m) => m.role === 'user') - const { abortController, cleanup, signalPromise } = this.createAbortController(lastUserMessage?.id, true) - const { signal } = abortController - - // 当 systemMessage 内容为空时不发送 systemMessage - let reqMessages: OpenAI.Responses.ResponseInput - if (!systemMessage.content) { - reqMessages = [...userMessage] - } else { - reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[] - } - - const finalUsage: Usage = { - completion_tokens: 0, - prompt_tokens: 0, - total_tokens: 0 - } - - const finalMetrics: Metrics = { - completion_tokens: 0, - time_completion_millsec: 0, - time_first_token_millsec: 0 - } - - const toolResponses: MCPToolResponse[] = [] - - const processToolResults = async (toolResults: Awaited>, idx: number) => { - if (toolResults.length === 0) return - - toolResults.forEach((ts) => reqMessages.push(ts as OpenAI.Responses.EasyInputMessage)) - - onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) - const stream = await this.sdk.responses.create( - { - model: model.id, - input: reqMessages, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - max_output_tokens: maxTokens, - 
stream: streamOutput, - tools: !isEmpty(tools) ? tools : undefined, - service_tier: this.getServiceTier(model), - ...this.getResponseReasoningEffort(assistant, model), - ...this.getCustomParameters(assistant) - }, - { - signal, - timeout: this.getTimeout(model) - } - ) - await processStream(stream, idx + 1) - } - - const processToolCalls = async (mcpTools, toolCalls: OpenAI.Responses.ResponseFunctionToolCall[]) => { - const mcpToolResponses = toolCalls - .map((toolCall) => { - const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall as OpenAI.Responses.ResponseFunctionToolCall) - if (!mcpTool) return undefined - - const parsedArgs = (() => { - try { - return JSON.parse(toolCall.arguments) - } catch { - return toolCall.arguments - } - })() - - return { - id: toolCall.call_id, - toolCallId: toolCall.call_id, - tool: mcpTool, - arguments: parsedArgs, - status: 'pending' - } as ToolCallResponse - }) - .filter((t): t is ToolCallResponse => typeof t !== 'undefined') - - return await parseAndCallTools( - mcpToolResponses, - toolResponses, - onChunk, - this.mcpToolCallResponseToMessage, - model, - mcpTools - ) - } - - const processToolUses = async (content: string) => { - return await parseAndCallTools( - content, - toolResponses, - onChunk, - this.mcpToolCallResponseToMessage, - model, - mcpTools - ) - } - - const processStream = async ( - stream: Stream | OpenAI.Responses.Response, - idx: number - ) => { - const toolCalls: OpenAI.Responses.ResponseFunctionToolCall[] = [] - let time_first_token_millsec = 0 - - if (!streamOutput) { - const nonStream = stream as OpenAI.Responses.Response - const time_completion_millsec = new Date().getTime() - start_time_millsec - const completion_tokens = - (nonStream.usage?.output_tokens || 0) + (nonStream.usage?.output_tokens_details.reasoning_tokens ?? 0) - const total_tokens = - (nonStream.usage?.total_tokens || 0) + (nonStream.usage?.output_tokens_details.reasoning_tokens ?? 
0) - const finalMetrics = { - completion_tokens, - time_completion_millsec, - time_first_token_millsec: 0 - } - const finalUsage = { - completion_tokens, - prompt_tokens: nonStream.usage?.input_tokens || 0, - total_tokens - } - let content = '' - - for (const output of nonStream.output) { - switch (output.type) { - case 'message': - if (output.content[0].type === 'output_text') { - onChunk({ type: ChunkType.TEXT_DELTA, text: output.content[0].text }) - onChunk({ type: ChunkType.TEXT_COMPLETE, text: output.content[0].text }) - content += output.content[0].text - if (output.content[0].annotations && output.content[0].annotations.length > 0) { - onChunk({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - source: WebSearchSource.OPENAI_RESPONSE, - results: output.content[0].annotations - } - }) - } - } - break - case 'reasoning': - onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: output.summary.map((s) => s.text).join('\n'), - thinking_millsec: new Date().getTime() - start_time_millsec - }) - break - case 'function_call': - toolCalls.push(output) - } - } - - if (content) { - reqMessages.push({ - role: 'assistant', - content: content - }) - } - if (toolCalls.length) { - toolCalls.forEach((toolCall) => { - reqMessages.push(toolCall) - }) - } - - const toolResults: Awaited> = [] - if (toolCalls.length) { - toolResults.push(...(await processToolCalls(mcpTools, toolCalls))) - } - if (content.length) { - toolResults.push(...(await processToolUses(content))) - } - await processToolResults(toolResults, idx) - - onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - usage: finalUsage, - metrics: finalMetrics - } - }) - return - } - let content = '' - let thinkContent = '' - - const outputItems: OpenAI.Responses.ResponseOutputItem[] = [] - - for await (const chunk of stream as Stream) { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { - break - } - switch (chunk.type) { - case 'response.output_item.added': - if (chunk.item.type === 'function_call') { - outputItems.push(chunk.item) - } - break - case 'response.reasoning_summary_part.added': - if (time_first_token_millsec === 0) { - time_first_token_millsec = new Date().getTime() - } - // Insert separation between summary parts - if (thinkContent.length > 0) { - const separator = '\n\n' - onChunk({ - type: ChunkType.THINKING_DELTA, - text: separator, - thinking_millsec: new Date().getTime() - time_first_token_millsec - }) - thinkContent += separator - } - break - case 'response.reasoning_summary_text.delta': - onChunk({ - type: ChunkType.THINKING_DELTA, - text: chunk.delta, - thinking_millsec: new Date().getTime() - time_first_token_millsec - }) - thinkContent += chunk.delta - break - case 'response.output_item.done': { - if (thinkContent !== '' && chunk.item.type === 'reasoning') { - onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: thinkContent, - thinking_millsec: new Date().getTime() - time_first_token_millsec - }) - } - break - } - case 'response.content_part.added': { - if (time_first_token_millsec === 0) { - time_first_token_millsec = new Date().getTime() - } - break - } - case 'response.output_text.delta': { - let delta = chunk.delta - if (isEnabledBuiltinWebSearch) { - delta = convertLinks(delta) - } - onChunk({ - type: ChunkType.TEXT_DELTA, - text: delta - }) - content += delta - break - } - case 'response.output_text.done': - onChunk({ - type: ChunkType.TEXT_COMPLETE, - text: content - }) - break - case 'response.function_call_arguments.done': { - const outputItem: OpenAI.Responses.ResponseOutputItem | 
undefined = outputItems.find( - (item) => item.id === chunk.item_id - ) - if (outputItem) { - if (outputItem.type === 'function_call') { - toolCalls.push({ - ...outputItem, - arguments: chunk.arguments - }) - } - } - - break - } - case 'response.content_part.done': - if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) { - onChunk({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - source: WebSearchSource.OPENAI_RESPONSE, - results: chunk.part.annotations - } - }) - } - break - case 'response.completed': { - const completion_tokens = - (chunk.response.usage?.output_tokens || 0) + - (chunk.response.usage?.output_tokens_details.reasoning_tokens ?? 0) - const total_tokens = - (chunk.response.usage?.total_tokens || 0) + - (chunk.response.usage?.output_tokens_details.reasoning_tokens ?? 0) - finalUsage.completion_tokens += completion_tokens - finalUsage.prompt_tokens += chunk.response.usage?.input_tokens || 0 - finalUsage.total_tokens += total_tokens - finalMetrics.completion_tokens += completion_tokens - finalMetrics.time_completion_millsec += new Date().getTime() - start_time_millsec - finalMetrics.time_first_token_millsec = time_first_token_millsec - start_time_millsec - break - } - case 'error': - onChunk({ - type: ChunkType.ERROR, - error: { - message: chunk.message, - code: chunk.code - } - }) - break - } - - // --- End of Incremental onChunk calls --- - } // End of for await loop - if (content) { - reqMessages.push({ - role: 'assistant', - content: content - }) - } - if (toolCalls.length) { - toolCalls.forEach((toolCall) => { - reqMessages.push(toolCall) - }) - } - - // Call processToolUses AFTER the loop finishes processing the main stream content - // Note: parseAndCallTools inside processToolUses should handle its own onChunk for tool responses - const toolResults: Awaited> = [] - if (toolCalls.length) { - toolResults.push(...(await processToolCalls(mcpTools, toolCalls))) - } - if (content) { - toolResults.push(...(await processToolUses(content))) - } - await processToolResults(toolResults, idx) - - onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - usage: finalUsage, - metrics: finalMetrics - } - }) - } - - onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) - const start_time_millsec = new Date().getTime() - const stream = await this.sdk.responses.create( - { - model: model.id, - input: reqMessages, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - max_output_tokens: maxTokens, - stream: streamOutput, - tools: tools.length > 0 ? tools : undefined, - tool_choice: isEnabledBuiltinWebSearch ? toolChoices : undefined, - service_tier: this.getServiceTier(model), - ...this.getResponseReasoningEffort(assistant, model), - ...this.getCustomParameters(assistant) - }, - { - signal, - timeout: this.getTimeout(model) - } - ) - - await processStream(stream, 0).finally(cleanup) - - // 捕获signal的错误 - await signalPromise?.promise?.catch((error) => { - throw error - }) - } - - /** - * Translate the content - * @param content - The content - * @param assistant - The assistant - * @param onResponse - The onResponse callback - * @returns The translated content - */ - async translate( - content: string, - assistant: Assistant, - onResponse?: (text: string, isComplete: boolean) => void - ): Promise { - const defaultModel = getDefaultModel() - const model = assistant.model || defaultModel - const messageForApi: OpenAI.Responses.EasyInputMessage[] = content - ? 
[ - { - role: 'system', - content: assistant.prompt - }, - { - role: 'user', - content - } - ] - : [{ role: 'user', content: assistant.prompt }] - - const isOpenAIReasoning = isOpenAIReasoningModel(model) - const isSupportedStreamOutput = () => { - if (!onResponse) { - return false - } - return !isOpenAIReasoning - } - - const stream = isSupportedStreamOutput() - let text = '' - if (stream) { - const response = await this.sdk.responses.create({ - model: model.id, - input: messageForApi, - stream: true, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - ...this.getResponseReasoningEffort(assistant, model) - }) - - for await (const chunk of response) { - switch (chunk.type) { - case 'response.output_text.delta': - text += chunk.delta - onResponse?.(text, false) - break - case 'response.output_text.done': - onResponse?.(chunk.text, true) - break - } - } - } else { - const response = await this.sdk.responses.create({ - model: model.id, - input: messageForApi, - stream: false, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - ...this.getResponseReasoningEffort(assistant, model) - }) - return response.output_text - } - - return text - } - - /** - * Summarize the messages - * @param messages - The messages - * @param assistant - The assistant - * @returns The summary - */ - public async summaries(messages: Message[], assistant: Assistant): Promise { - const model = getTopNamingModel() || assistant.model || getDefaultModel() - const userMessages = takeRight(messages, 5).map((message) => ({ - role: message.role, - content: getMainTextContent(message) - })) - const userMessageContent = userMessages.reduce((prev, curr) => { - const content = curr.role === 'user' ? `User: ${curr.content}` : `Assistant: ${curr.content}` - return prev + (prev ? 
'\n' : '') + content - }, '') - - const systemMessage: OpenAI.Responses.EasyInputMessage = { - role: 'system', - content: (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title') - } - - const userMessage: OpenAI.Responses.EasyInputMessage = { - role: 'user', - content: userMessageContent - } - - const response = await this.sdk.responses.create({ - model: model.id, - input: [systemMessage, userMessage], - stream: false, - max_output_tokens: 1000 - }) - return removeSpecialCharactersForTopicName(response.output_text.substring(0, 50)) - } - - public async summaryForSearch(messages: Message[], assistant: Assistant): Promise { - const model = assistant.model || getDefaultModel() - - const systemMessage: OpenAI.Responses.EasyInputMessage = { - role: 'system', - content: assistant.prompt - } - - const messageContents = messages.map((m) => getMainTextContent(m)) - const userMessageContent = messageContents.join('\n') - - const userMessage: OpenAI.Responses.EasyInputMessage = { - role: 'user', - content: userMessageContent - } - - const lastUserMessage = messages[messages.length - 1] - - const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id) - const { signal } = abortController - - const response = await this.sdk.responses - .create( - { - model: model.id, - input: [systemMessage, userMessage], - stream: false, - max_output_tokens: 1000 - }, - { - signal, - timeout: 20 * 1000 - } - ) - .finally(cleanup) - - return response.output_text - } - - /** - * Generate suggestions - * @param messages - The messages - * @param assistant - The assistant - * @returns The suggestions - */ - async suggestions(messages: Message[], assistant: Assistant): Promise { - const model = assistant.model - - if (!model) { - return [] - } - - const userMessagesForApi = messages - .filter((m) => m.role === 'user') - .map((m) => ({ - role: m.role, - content: getMainTextContent(m) - })) - - const response: any = await this.sdk.request({ - method: 'post', - path: '/advice_questions', - body: { - messages: userMessagesForApi, - model: model.id, - max_tokens: 0, - temperature: 0, - n: 0 - } - }) - - return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || [] - } - - /** - * Generate text - * @param prompt - The prompt - * @param content - The content - * @returns The generated text - */ - public async generateText({ prompt, content }: { prompt: string; content: string }): Promise { - const model = getDefaultModel() - const response = await this.sdk.responses.create({ - model: model.id, - stream: false, - input: [ - { role: 'system', content: prompt }, - { role: 'user', content } - ] - }) - return response.output_text - } - - /** - * Check if the model is valid - * @param model - The model - * @param stream - Whether to use streaming interface - * @returns The validity of the model - */ - public async check(model: Model, stream: boolean): Promise<{ valid: boolean; error: Error | null }> { - if (!model) { - return { valid: false, error: new Error('No model found') } - } - try { - if (stream) { - const response = await this.sdk.responses.create({ - model: model.id, - input: [{ role: 'user', content: 'hi' }], - stream: true - }) - for await (const chunk of response) { - if (chunk.type === 'response.output_text.delta') { - return { valid: true, error: null } - } - } - return { valid: false, error: new Error('No streaming response') } - } else { - const response = await this.sdk.responses.create({ - model: model.id, - input: [{ role: 'user', content: 'hi' }], - 
stream: false - }) - if (!response.output_text) { - return { valid: false, error: new Error('No response') } - } - return { valid: true, error: null } - } - } catch (error: any) { - return { valid: false, error: error } - } - } - - /** - * Get the models - * @returns The models - */ - public async models(): Promise { - try { - const response = await this.sdk.models.list() - const models = response.data || [] - models.forEach((model) => { - model.id = model.id.trim() - }) - return models.filter(isSupportedModel) - } catch (error) { - return [] - } - } - - /** - * Generate an image - * @param params - The parameters - * @returns The generated image - */ - public async generateImage({ - model, - prompt, - negativePrompt, - imageSize, - batchSize, - seed, - numInferenceSteps, - guidanceScale, - signal, - promptEnhancement - }: GenerateImageParams): Promise { - const response = (await this.sdk.request({ - method: 'post', - path: '/images/generations', - signal, - body: { - model, - prompt, - negative_prompt: negativePrompt, - image_size: imageSize, - batch_size: batchSize, - seed: seed ? parseInt(seed) : undefined, - num_inference_steps: numInferenceSteps, - guidance_scale: guidanceScale, - prompt_enhancement: promptEnhancement - } - })) as { data: Array<{ url: string }> } - - return response.data.map((item) => item.url) - } - - public async generateImageByChat({ messages, assistant, onChunk }: CompletionsParams): Promise { - const defaultModel = getDefaultModel() - const model = assistant.model || defaultModel - // save image data from the last assistant message - messages = addImageFileToContents(messages) - const lastUserMessage = messages.findLast((m) => m.role === 'user') - const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant') - if (!lastUserMessage) { - return - } - - const { abortController } = this.createAbortController(lastUserMessage?.id, true) - const { signal } = abortController - const content = getMainTextContent(lastUserMessage!) - let response: OpenAI.Images.ImagesResponse | null = null - let images: Uploadable[] = [] - - try { - if (lastUserMessage) { - const UserFiles = findImageBlocks(lastUserMessage) - const validUserFiles = UserFiles.filter((f) => f.file) // Filter out files that are undefined first - const userImages = await Promise.all( - validUserFiles.map(async (f) => { - // f.file is guaranteed to exist here due to the filter above - const fileInfo = f.file! 
- const binaryData = await FileManager.readBinaryImage(fileInfo) - return await toFile(binaryData, fileInfo.origin_name || 'image.png', { - type: 'image/png' - }) - }) - ) - images = images.concat(userImages) - } - - if (lastAssistantMessage) { - const assistantFiles = findImageBlocks(lastAssistantMessage) - const assistantImages = await Promise.all( - assistantFiles.filter(Boolean).map(async (f) => { - const match = f?.url?.match(/^data:(image\/\w+);base64,(.+)$/) - if (!match) return null - const mimeType = match[1] - const extension = mime.getExtension(mimeType) || 'bin' - const bytes = Base64.toUint8Array(match[2]) - const fileName = `assistant_image.${extension}` - return await toFile(bytes, fileName, { type: mimeType }) - }) - ) - images = images.concat(assistantImages.filter(Boolean) as Uploadable[]) - } - - onChunk({ - type: ChunkType.LLM_RESPONSE_CREATED - }) - - onChunk({ - type: ChunkType.IMAGE_CREATED - }) - - const start_time_millsec = new Date().getTime() - - if (images.length > 0) { - response = await this.sdk.images.edit( - { - model: model.id, - image: images, - prompt: content || '', - ...this.getCustomParameters(assistant) - }, - { - signal, - timeout: 300_000 - } - ) - } else { - response = await this.sdk.images.generate( - { - model: model.id, - prompt: content || '', - response_format: model.id.includes('gpt-image-1') ? undefined : 'b64_json', - ...this.getCustomParameters(assistant) - }, - { - signal, - timeout: 300_000 - } - ) - } - - onChunk({ - type: ChunkType.IMAGE_COMPLETE, - image: { - type: 'base64', - images: response?.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || [] - } - }) - - onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - usage: { - completion_tokens: response.usage?.output_tokens || 0, - prompt_tokens: response.usage?.input_tokens || 0, - total_tokens: response.usage?.total_tokens || 0 - }, - metrics: { - completion_tokens: response.usage?.output_tokens || 0, - time_first_token_millsec: 0, // Non-streaming, first token time is not relevant - time_completion_millsec: new Date().getTime() - start_time_millsec - } - } - }) - } catch (error: any) { - console.error('[generateImageByChat] error', error) - onChunk({ - type: ChunkType.ERROR, - error - }) - } - } - - /** - * Get the embedding dimensions - * @param model - The model - * @returns The embedding dimensions - */ - public async getEmbeddingDimensions(model: Model): Promise { - const data = await this.sdk.embeddings.create({ - model: model.id, - input: 'hi' - }) - return data.data[0].embedding.length - } -} - -export default class OpenAIResponseProvider extends BaseOpenAIProvider { - private providers: Map = new Map() - - constructor(provider: Provider) { - super(provider) - this.providers.set('openai-compatible', new OpenAIProvider(provider)) - } - - private getProvider(model: Model): BaseOpenAIProvider { - if (isOpenAIWebSearch(model) || model.id.includes('o1-preview') || model.id.includes('o1-mini')) { - return this.providers.get('openai-compatible')! - } else { - return this - } - } - - public completions(params: CompletionsParams): Promise { - const model = params.assistant.model - if (!model) { - return Promise.reject(new Error('Model is required')) - } - - const provider = this.getProvider(model) - return provider === this ? 
super.completions(params) : provider.completions(params) - } - - public convertMcpTools(mcpTools: MCPTool[]) { - return mcpToolsToOpenAIResponseTools(mcpTools) as T[] - } - - public mcpToolCallResponseToMessage = ( - mcpToolResponse: MCPToolResponse, - resp: MCPCallToolResponse, - model: Model - ): OpenAI.Responses.ResponseInputItem | undefined => { - if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { - return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model)) - } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) { - return { - type: 'function_call_output', - call_id: mcpToolResponse.toolCallId, - output: JSON.stringify(resp.content) - } - } - return - } -} diff --git a/src/renderer/src/providers/AiProvider/ProviderFactory.ts b/src/renderer/src/providers/AiProvider/ProviderFactory.ts deleted file mode 100644 index d8c1f40e6f..0000000000 --- a/src/renderer/src/providers/AiProvider/ProviderFactory.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { Provider } from '@renderer/types' - -import AihubmixProvider from './AihubmixProvider' -import AnthropicProvider from './AnthropicProvider' -import BaseProvider from './BaseProvider' -import GeminiProvider from './GeminiProvider' -import OpenAIProvider from './OpenAIProvider' -import OpenAIResponseProvider from './OpenAIResponseProvider' - -export default class ProviderFactory { - static create(provider: Provider): BaseProvider { - if (provider.id === 'aihubmix') { - return new AihubmixProvider(provider) - } - - switch (provider.type) { - case 'openai': - return new OpenAIProvider(provider) - case 'openai-response': - return new OpenAIResponseProvider(provider) - case 'anthropic': - return new AnthropicProvider(provider) - case 'gemini': - return new GeminiProvider(provider) - default: - return new OpenAIProvider(provider) - } - } -} - -export function isOpenAIProvider(provider: Provider) { - return !['anthropic', 'gemini'].includes(provider.type) -} diff --git a/src/renderer/src/providers/AiProvider/index.ts b/src/renderer/src/providers/AiProvider/index.ts deleted file mode 100644 index ef77dff14b..0000000000 --- a/src/renderer/src/providers/AiProvider/index.ts +++ /dev/null @@ -1,94 +0,0 @@ -import { GenerateImagesParameters } from '@google/genai' -import BaseProvider from '@renderer/providers/AiProvider/BaseProvider' -import ProviderFactory from '@renderer/providers/AiProvider/ProviderFactory' -import type { Assistant, GenerateImageParams, MCPTool, Model, Provider, Suggestion } from '@renderer/types' -import { Chunk } from '@renderer/types/chunk' -import type { Message } from '@renderer/types/newMessage' -import OpenAI from 'openai' - -export interface CompletionsParams { - messages: Message[] - assistant: Assistant - onChunk: (chunk: Chunk) => void - onFilterMessages: (messages: Message[]) => void - mcpTools?: MCPTool[] -} - -export default class AiProvider { - private sdk: BaseProvider - - constructor(provider: Provider) { - this.sdk = ProviderFactory.create(provider) - } - - public async fakeCompletions(params: CompletionsParams): Promise { - return this.sdk.fakeCompletions(params) - } - - public async completions({ - messages, - assistant, - mcpTools, - onChunk, - onFilterMessages - }: CompletionsParams): Promise { - return this.sdk.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }) - } - - public async translate( - content: string, - assistant: Assistant, - onResponse?: (text: string, isComplete: boolean) => void - ): Promise { - return 
this.sdk.translate(content, assistant, onResponse) - } - - public async summaries(messages: Message[], assistant: Assistant): Promise { - return this.sdk.summaries(messages, assistant) - } - - public async summaryForSearch(messages: Message[], assistant: Assistant): Promise { - return this.sdk.summaryForSearch(messages, assistant) - } - - public async suggestions(messages: Message[], assistant: Assistant): Promise { - return this.sdk.suggestions(messages, assistant) - } - - public async generateText({ prompt, content }: { prompt: string; content: string }): Promise { - return this.sdk.generateText({ prompt, content }) - } - - public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> { - return this.sdk.check(model, stream) - } - - public async models(): Promise { - return this.sdk.models() - } - - public getApiKey(): string { - return this.sdk.getApiKey() - } - - public async generateImage(params: GenerateImageParams | GenerateImagesParameters): Promise { - return this.sdk.generateImage(params as GenerateImageParams) - } - - public async generateImageByChat({ - messages, - assistant, - onChunk, - onFilterMessages - }: CompletionsParams): Promise { - return this.sdk.generateImageByChat({ messages, assistant, onChunk, onFilterMessages }) - } - - public async getEmbeddingDimensions(model: Model): Promise { - return this.sdk.getEmbeddingDimensions(model) - } - - public getBaseURL(): string { - return this.sdk.getBaseURL() - } -} diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index d0ebdb4c22..d9becd6952 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -1,10 +1,21 @@ +import { CompletionsParams } from '@renderer/aiCore/middleware/schemas' import Logger from '@renderer/config/logger' -import { getOpenAIWebSearchParams, isOpenAIWebSearch } from '@renderer/config/models' +import { + isEmbeddingModel, + isGenerateImageModel, + isOpenRouterBuiltInWebSearchModel, + isReasoningModel, + isSupportedDisableGenerationModel, + isSupportedReasoningEffortModel, + isSupportedThinkingTokenModel, + isWebSearchModel +} from '@renderer/config/models' import { SEARCH_SUMMARY_PROMPT, SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY, SEARCH_SUMMARY_PROMPT_WEB_ONLY } from '@renderer/config/prompts' +import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import { Assistant, @@ -13,20 +24,22 @@ import { MCPTool, Model, Provider, - Suggestion, WebSearchResponse, WebSearchSource } from '@renderer/types' import { type Chunk, ChunkType } from '@renderer/types/chunk' import { Message } from '@renderer/types/newMessage' +import { SdkModel } from '@renderer/types/sdk' +import { removeSpecialCharactersForTopicName } from '@renderer/utils' import { isAbortError } from '@renderer/utils/error' import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract' import { getKnowledgeBaseIds, getMainTextContent } from '@renderer/utils/messageUtils/find' -import { findLast, isEmpty } from 'lodash' +import { findLast, isEmpty, takeRight } from 'lodash' -import AiProvider from '../providers/AiProvider' +import AiProvider from '../aiCore' import { getAssistantProvider, + getAssistantSettings, getDefaultModel, getProviderByModel, getTopNamingModel, @@ -34,7 +47,13 @@ import { } from './AssistantService' import { getDefaultAssistant } from './AssistantService' import { processKnowledgeSearch } from './KnowledgeService' -import { filterContextMessages, 
filterMessages, filterUsefulMessages } from './MessagesService' +import { + filterContextMessages, + filterEmptyMessages, + filterMessages, + filterUsefulMessages, + filterUserRoleStartMessages +} from './MessagesService' import WebSearchService from './WebSearchService' // TODO:考虑拆开 @@ -50,6 +69,7 @@ async function fetchExternalTool( const knowledgeRecognition = assistant.knowledgeRecognition || 'on' const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId) + // 使用外部搜索工具 const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null const shouldKnowledgeSearch = hasKnowledgeBase @@ -83,14 +103,14 @@ async function fetchExternalTool( summaryAssistant.prompt = prompt try { - const keywords = await fetchSearchSummary({ + const result = await fetchSearchSummary({ messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage], assistant: summaryAssistant }) - if (!keywords) return getFallbackResult() + if (!result) return getFallbackResult() - const extracted = extractInfoFromXML(keywords) + const extracted = extractInfoFromXML(result.getText()) // 根据需求过滤结果 return { websearch: needWebExtract ? extracted?.websearch : undefined, @@ -134,12 +154,6 @@ async function fetchExternalTool( return undefined } - // Pass the guaranteed model to the check function - const webSearchParams = getOpenAIWebSearchParams(assistant, assistant.model) - if (!isEmpty(webSearchParams) || isOpenAIWebSearch(assistant.model)) { - return - } - try { // Use the consolidated processWebsearch function WebSearchService.createAbortSignal(lastUserMessage.id) @@ -238,7 +252,7 @@ async function fetchExternalTool( // Get MCP tools (Fix duplicate declaration) let mcpTools: MCPTool[] = [] // Initialize as empty array - const enabledMCPs = lastUserMessage?.enabledMCPs + const enabledMCPs = assistant.mcpServers if (enabledMCPs && enabledMCPs.length > 0) { try { const toolPromises = enabledMCPs.map(async (mcpServer) => { @@ -301,17 +315,52 @@ export async function fetchChatCompletion({ // NOTE: The search results are NOT added to the messages sent to the AI here. // They will be retrieved and used by the messageThunk later to create CitationBlocks. const { mcpTools } = await fetchExternalTool(lastUserMessage, assistant, onChunkReceived, lastAnswer) + const model = assistant.model || getDefaultModel() + + const { maxTokens, contextCount } = getAssistantSettings(assistant) const filteredMessages = filterUsefulMessages(messages) + const _messages = filterUserRoleStartMessages( + filterEmptyMessages(filterContextMessages(takeRight(filteredMessages, contextCount + 2))) // 取原来几个provider的最大值 + ) + + const enableReasoning = + ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) && + assistant.settings?.reasoning_effort !== undefined) || + (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model))) + + const enableWebSearch = + (assistant.enableWebSearch && isWebSearchModel(model)) || + isOpenRouterBuiltInWebSearchModel(model) || + model.id.includes('sonar') || + false + + const enableGenerateImage = + isGenerateImageModel(model) && (isSupportedDisableGenerationModel(model) ? 
assistant.enableGenerateImage : true) + // --- Call AI Completions --- - await AI.completions({ - messages: filteredMessages, - assistant, - onFilterMessages: () => {}, - onChunk: onChunkReceived, - mcpTools: mcpTools - }) + onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED }) + if (enableWebSearch) { + onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS }) + } + await AI.completions( + { + callType: 'chat', + messages: _messages, + assistant, + onChunk: onChunkReceived, + mcpTools: mcpTools, + maxTokens, + streamOutput: assistant.settings?.streamOutput || false, + enableReasoning, + enableWebSearch, + enableGenerateImage + }, + { + streamOutput: assistant.settings?.streamOutput || false + } + ) } interface FetchTranslateProps { @@ -321,7 +370,7 @@ interface FetchTranslateProps { } export async function fetchTranslate({ content, assistant, onResponse }: FetchTranslateProps) { - const model = getTranslateModel() + const model = getTranslateModel() || assistant.model || getDefaultModel() if (!model) { throw new Error(i18n.t('error.provider_disabled')) @@ -333,17 +382,42 @@ export async function fetchTranslate({ content, assistant, onResponse }: FetchTr throw new Error(i18n.t('error.no_api_key')) } + const isSupportedStreamOutput = () => { + if (!onResponse) { + return false + } + return true + } + + const stream = isSupportedStreamOutput() + const enableReasoning = + ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) && + assistant.settings?.reasoning_effort !== undefined) || + (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model))) + + const params: CompletionsParams = { + callType: 'translate', + messages: content, + assistant: { ...assistant, model }, + streamOutput: stream, + enableReasoning, + onResponse + } + const AI = new AiProvider(provider) try { - return await AI.translate(content, assistant, onResponse) + return (await AI.completions(params)).getText() || '' } catch (error: any) { return '' } } export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) { + const prompt = (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title') const model = getTopNamingModel() || assistant.model || getDefaultModel() + const userMessages = takeRight(messages, 5) + const provider = getProviderByModel(model) if (!hasApiKey(provider)) { @@ -352,9 +426,18 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages: const AI = new AiProvider(provider) + const params: CompletionsParams = { + callType: 'summary', + messages: filterMessages(userMessages), + assistant: { ...assistant, prompt, model }, + maxTokens: 1000, + streamOutput: false + } + try { - const text = await AI.summaries(filterMessages(messages), assistant) - return text?.replace(/["']/g, '') || null + const { getText } = await AI.completions(params) + const text = getText() + return removeSpecialCharactersForTopicName(text) || null } catch (error: any) { return null } @@ -370,7 +453,14 @@ export async function fetchSearchSummary({ messages, assistant }: { messages: Me const AI = new AiProvider(provider) - return await AI.summaryForSearch(messages, assistant) + const params: CompletionsParams = { + callType: 'search', + messages: messages, + assistant, + streamOutput: false + } + + return await AI.completions(params) } export async function fetchGenerate({ prompt, content }: { prompt: string; content: string }): Promise { @@ -383,42 +473,32 @@ 
export async function fetchGenerate({ prompt, content }: { prompt: string; conte const AI = new AiProvider(provider) + const assistant = getDefaultAssistant() + assistant.model = model + assistant.prompt = prompt + + const params: CompletionsParams = { + callType: 'generate', + messages: content, + assistant, + streamOutput: false + } + try { - return await AI.generateText({ prompt, content }) + const result = await AI.completions(params) + return result.getText() || '' } catch (error: any) { return '' } } -export async function fetchSuggestions({ - messages, - assistant -}: { - messages: Message[] - assistant: Assistant -}): Promise { - const model = assistant.model - if (!model || model.id.endsWith('global')) { - return [] - } - - const provider = getAssistantProvider(assistant) - const AI = new AiProvider(provider) - - try { - return await AI.suggestions(filterMessages(messages), assistant) - } catch (error: any) { - return [] - } -} - function hasApiKey(provider: Provider) { if (!provider) return false if (provider.id === 'ollama' || provider.id === 'lmstudio') return true return !isEmpty(provider.apiKey) } -export async function fetchModels(provider: Provider) { +export async function fetchModels(provider: Provider): Promise { const AI = new AiProvider(provider) try { @@ -432,68 +512,69 @@ export const formatApiKeys = (value: string) => { return value.replaceAll(',', ',').replaceAll(' ', ',').replaceAll(' ', '').replaceAll('\n', ',') } -export function checkApiProvider(provider: Provider): { - valid: boolean - error: Error | null -} { +export function checkApiProvider(provider: Provider): void { const key = 'api-check' const style = { marginTop: '3vh' } if (provider.id !== 'ollama' && provider.id !== 'lmstudio') { if (!provider.apiKey) { window.message.error({ content: i18n.t('message.error.enter.api.key'), key, style }) - return { - valid: false, - error: new Error(i18n.t('message.error.enter.api.key')) - } + throw new Error(i18n.t('message.error.enter.api.key')) } } if (!provider.apiHost) { window.message.error({ content: i18n.t('message.error.enter.api.host'), key, style }) - return { - valid: false, - error: new Error(i18n.t('message.error.enter.api.host')) - } + throw new Error(i18n.t('message.error.enter.api.host')) } if (isEmpty(provider.models)) { window.message.error({ content: i18n.t('message.error.enter.model'), key, style }) - return { - valid: false, - error: new Error(i18n.t('message.error.enter.model')) - } - } - - return { - valid: true, - error: null + throw new Error(i18n.t('message.error.enter.model')) } } -export async function checkApi(provider: Provider, model: Model): Promise<{ valid: boolean; error: Error | null }> { - const validation = checkApiProvider(provider) - if (!validation.valid) { - return { - valid: validation.valid, - error: validation.error - } - } +export async function checkApi(provider: Provider, model: Model): Promise { + checkApiProvider(provider) const ai = new AiProvider(provider) - // Try streaming check first - const result = await ai.check(model, true) + const assistant = getDefaultAssistant() + assistant.model = model + try { + if (isEmbeddingModel(model)) { + const result = await ai.getEmbeddingDimensions(model) + if (result === 0) { + throw new Error(i18n.t('message.error.enter.model')) + } + } else { + const params: CompletionsParams = { + callType: 'check', + messages: 'hi', + assistant, + streamOutput: true + } - if (result.valid && !result.error) { - return result - } - - // 不应该假设错误由流式引发。多次发起检测请求可能触发429,掩盖了真正的问题。 - // 
但这里错误类型做的很粗糙,暂时先这样 - if (result.error && result.error.message.includes('stream')) { - return ai.check(model, false) - } else { - return result + // Try streaming check first + const result = await ai.completions(params) + if (!result.getText()) { + throw new Error('No response received') + } + } + } catch (error: any) { + if (error.message.includes('stream')) { + const params: CompletionsParams = { + callType: 'check', + messages: 'hi', + assistant, + streamOutput: false + } + const result = await ai.completions(params) + if (!result.getText()) { + throw new Error('No response received') + } + } else { + throw error + } } } diff --git a/src/renderer/src/services/HealthCheckService.ts b/src/renderer/src/services/HealthCheckService.ts index e631e1f40c..598074b87e 100644 --- a/src/renderer/src/services/HealthCheckService.ts +++ b/src/renderer/src/services/HealthCheckService.ts @@ -98,14 +98,20 @@ export async function checkModelWithMultipleKeys( if (isParallel) { // Check all API keys in parallel const keyPromises = apiKeys.map(async (key) => { - const result = await checkModel({ ...provider, apiKey: key }, model) - - return { - key, - isValid: result.valid, - error: result.error?.message, - latency: result.latency - } as ApiKeyCheckStatus + try { + const result = await checkModel({ ...provider, apiKey: key }, model) + return { + key, + isValid: true, + latency: result.latency + } as ApiKeyCheckStatus + } catch (error: unknown) { + return { + key, + isValid: false, + error: error instanceof Error ? error.message.slice(0, 20) + '...' : String(error).slice(0, 20) + '...' + } as ApiKeyCheckStatus + } }) const results = await Promise.allSettled(keyPromises) @@ -125,14 +131,20 @@ export async function checkModelWithMultipleKeys( } else { // Check all API keys serially for (const key of apiKeys) { - const result = await checkModel({ ...provider, apiKey: key }, model) - - keyResults.push({ - key, - isValid: result.valid, - error: result.error?.message, - latency: result.latency - }) + try { + const result = await checkModel({ ...provider, apiKey: key }, model) + keyResults.push({ + key, + isValid: true, + latency: result.latency + }) + } catch (error: unknown) { + keyResults.push({ + key, + isValid: false, + error: error instanceof Error ? error.message.slice(0, 20) + '...' : String(error).slice(0, 20) + '...' 
+ }) + } } } diff --git a/src/renderer/src/services/KnowledgeService.ts b/src/renderer/src/services/KnowledgeService.ts index 8a28732a02..4ddc4360b1 100644 --- a/src/renderer/src/services/KnowledgeService.ts +++ b/src/renderer/src/services/KnowledgeService.ts @@ -1,8 +1,8 @@ import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' +import AiProvider from '@renderer/aiCore' import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant' import { getEmbeddingMaxContext } from '@renderer/config/embedings' import Logger from '@renderer/config/logger' -import AiProvider from '@renderer/providers/AiProvider' import store from '@renderer/store' import { FileType, KnowledgeBase, KnowledgeBaseParams, KnowledgeReference } from '@renderer/types' import { ExtractResults } from '@renderer/utils/extract' diff --git a/src/renderer/src/services/ModelService.ts b/src/renderer/src/services/ModelService.ts index 3ca5c485e5..e9f13dfa94 100644 --- a/src/renderer/src/services/ModelService.ts +++ b/src/renderer/src/services/ModelService.ts @@ -1,11 +1,9 @@ -import { isEmbeddingModel } from '@renderer/config/models' -import AiProvider from '@renderer/providers/AiProvider' import store from '@renderer/store' import { Model, Provider } from '@renderer/types' import { t } from 'i18next' import { pick } from 'lodash' -import { checkApiProvider } from './ApiService' +import { checkApi } from './ApiService' export const getModelUniqId = (m?: Model) => { return m?.id ? JSON.stringify(pick(m, ['id', 'provider'])) : '' @@ -33,64 +31,23 @@ export function getModelName(model?: Model) { return modelName } -// Generic function to perform model checks -// Abstracts provider validation and error handling, allowing different types of check logic +// Generic function to perform model checks with exception handling async function performModelCheck( provider: Provider, model: Model, - checkFn: (ai: AiProvider, model: Model) => Promise, - processResult: (result: T) => { valid: boolean; error: Error | null } -): Promise<{ valid: boolean; error: Error | null; latency?: number }> { - const validation = checkApiProvider(provider) - if (!validation.valid) { - return { - valid: validation.valid, - error: validation.error - } - } + checkFn: (provider: Provider, model: Model) => Promise +): Promise<{ latency: number }> { + const startTime = performance.now() + await checkFn(provider, model) + const latency = performance.now() - startTime - const AI = new AiProvider(provider) - - try { - const startTime = performance.now() - const result = await checkFn(AI, model) - const latency = performance.now() - startTime - - return { - ...processResult(result), - latency - } - } catch (error: any) { - return { - valid: false, - error - } - } + return { latency } } // Unified model check function // Automatically selects appropriate check method based on model type -export async function checkModel(provider: Provider, model: Model) { - if (isEmbeddingModel(model)) { - return performModelCheck( - provider, - model, - (ai, model) => ai.getEmbeddingDimensions(model), - (dimensions) => ({ valid: dimensions > 0, error: null }) - ) - } else { - return performModelCheck( - provider, - model, - async (ai, model) => { - // Try streaming check first - const result = await ai.check(model, true) - if (result.valid && !result.error) { - return result - } - return ai.check(model, false) - }, - ({ valid, error }) => ({ valid, error: error || null }) - ) - } +export async function checkModel(provider: Provider, 
model: Model): Promise<{ latency: number }> { + return performModelCheck(provider, model, async (provider, model) => { + await checkApi(provider, model) + }) } diff --git a/src/renderer/src/services/StreamProcessingService.ts b/src/renderer/src/services/StreamProcessingService.ts index 67acc2f87d..527cc2242f 100644 --- a/src/renderer/src/services/StreamProcessingService.ts +++ b/src/renderer/src/services/StreamProcessingService.ts @@ -28,7 +28,9 @@ export interface StreamProcessorCallbacks { onLLMWebSearchComplete?: (llmWebSearchResult: WebSearchResponse) => void // Image generation chunk received onImageCreated?: () => void - onImageGenerated?: (imageData: GenerateImageResponse) => void + onImageDelta?: (imageData: GenerateImageResponse) => void + onImageGenerated?: (imageData?: GenerateImageResponse) => void + onLLMResponseComplete?: (response?: Response) => void // Called when an error occurs during chunk processing onError?: (error: any) => void // Called when the entire stream processing is signaled as complete (success or failure) @@ -40,59 +42,84 @@ export function createStreamProcessor(callbacks: StreamProcessorCallbacks = {}) // The returned function processes a single chunk or a final signal return (chunk: Chunk) => { try { - // Logger.log(`[${new Date().toLocaleString()}] createStreamProcessor ${chunk.type}`, chunk) - // 1. Handle the manual final signal first - if (chunk?.type === ChunkType.BLOCK_COMPLETE) { - callbacks.onComplete?.(AssistantMessageStatus.SUCCESS, chunk?.response) - return + const data = chunk + switch (data.type) { + case ChunkType.BLOCK_COMPLETE: { + if (callbacks.onComplete) callbacks.onComplete(AssistantMessageStatus.SUCCESS, data?.response) + break + } + case ChunkType.LLM_RESPONSE_CREATED: { + if (callbacks.onLLMResponseCreated) callbacks.onLLMResponseCreated() + break + } + case ChunkType.TEXT_DELTA: { + if (callbacks.onTextChunk) callbacks.onTextChunk(data.text) + break + } + case ChunkType.TEXT_COMPLETE: { + if (callbacks.onTextComplete) callbacks.onTextComplete(data.text) + break + } + case ChunkType.THINKING_DELTA: { + if (callbacks.onThinkingChunk) callbacks.onThinkingChunk(data.text, data.thinking_millsec) + break + } + case ChunkType.THINKING_COMPLETE: { + if (callbacks.onThinkingComplete) callbacks.onThinkingComplete(data.text, data.thinking_millsec) + break + } + case ChunkType.MCP_TOOL_IN_PROGRESS: { + if (callbacks.onToolCallInProgress) + data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp)) + break + } + case ChunkType.MCP_TOOL_COMPLETE: { + if (callbacks.onToolCallComplete && data.responses.length > 0) { + data.responses.forEach((toolResp) => callbacks.onToolCallComplete!(toolResp)) + } + break + } + case ChunkType.EXTERNEL_TOOL_IN_PROGRESS: { + if (callbacks.onExternalToolInProgress) callbacks.onExternalToolInProgress() + break + } + case ChunkType.EXTERNEL_TOOL_COMPLETE: { + if (callbacks.onExternalToolComplete) callbacks.onExternalToolComplete(data.external_tool) + break + } + case ChunkType.LLM_WEB_SEARCH_IN_PROGRESS: { + if (callbacks.onLLMWebSearchInProgress) callbacks.onLLMWebSearchInProgress() + break + } + case ChunkType.LLM_WEB_SEARCH_COMPLETE: { + if (callbacks.onLLMWebSearchComplete) callbacks.onLLMWebSearchComplete(data.llm_web_search) + break + } + case ChunkType.IMAGE_CREATED: { + if (callbacks.onImageCreated) callbacks.onImageCreated() + break + } + case ChunkType.IMAGE_DELTA: { + if (callbacks.onImageDelta) callbacks.onImageDelta(data.image) + break + } + case ChunkType.IMAGE_COMPLETE: { + if 
(callbacks.onImageGenerated) callbacks.onImageGenerated(data.image) + break + } + case ChunkType.LLM_RESPONSE_COMPLETE: { + if (callbacks.onLLMResponseComplete) callbacks.onLLMResponseComplete(data.response) + break + } + case ChunkType.ERROR: { + if (callbacks.onError) callbacks.onError(data.error) + break + } + default: { + // Handle unknown chunk types or log an error + console.warn(`Unknown chunk type: ${data.type}`) + } } - // 2. Process the actual ChunkCallbackData - const data = chunk // Cast after checking for 'final' - // Invoke callbacks based on the fields present in the chunk data - if (data.type === ChunkType.LLM_RESPONSE_CREATED && callbacks.onLLMResponseCreated) { - callbacks.onLLMResponseCreated() - } - if (data.type === ChunkType.TEXT_DELTA && callbacks.onTextChunk) { - callbacks.onTextChunk(data.text) - } - if (data.type === ChunkType.TEXT_COMPLETE && callbacks.onTextComplete) { - callbacks.onTextComplete(data.text) - } - if (data.type === ChunkType.THINKING_DELTA && callbacks.onThinkingChunk) { - callbacks.onThinkingChunk(data.text, data.thinking_millsec) - } - if (data.type === ChunkType.THINKING_COMPLETE && callbacks.onThinkingComplete) { - callbacks.onThinkingComplete(data.text, data.thinking_millsec) - } - if (data.type === ChunkType.MCP_TOOL_IN_PROGRESS && callbacks.onToolCallInProgress) { - data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp)) - } - if (data.type === ChunkType.MCP_TOOL_COMPLETE && data.responses.length > 0 && callbacks.onToolCallComplete) { - data.responses.forEach((toolResp) => callbacks.onToolCallComplete!(toolResp)) - } - if (data.type === ChunkType.EXTERNEL_TOOL_IN_PROGRESS && callbacks.onExternalToolInProgress) { - callbacks.onExternalToolInProgress() - } - if (data.type === ChunkType.EXTERNEL_TOOL_COMPLETE && callbacks.onExternalToolComplete) { - callbacks.onExternalToolComplete(data.external_tool) - } - if (data.type === ChunkType.LLM_WEB_SEARCH_IN_PROGRESS && callbacks.onLLMWebSearchInProgress) { - callbacks.onLLMWebSearchInProgress() - } - if (data.type === ChunkType.LLM_WEB_SEARCH_COMPLETE && callbacks.onLLMWebSearchComplete) { - callbacks.onLLMWebSearchComplete(data.llm_web_search) - } - if (data.type === ChunkType.IMAGE_CREATED && callbacks.onImageCreated) { - callbacks.onImageCreated() - } - if (data.type === ChunkType.IMAGE_COMPLETE && callbacks.onImageGenerated) { - callbacks.onImageGenerated(data.image) - } - if (data.type === ChunkType.ERROR && callbacks.onError) { - callbacks.onError(data.error) - } - // Note: Usage and Metrics are usually handled at the end or accumulated differently, - // so direct callbacks might not be the best fit here. They are often part of the final message state. 
} catch (error) { console.error('Error processing stream chunk:', error) callbacks.onError?.(error) diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts index 0932ff929d..17487ad69b 100644 --- a/src/renderer/src/store/thunk/messageThunk.ts +++ b/src/renderer/src/store/thunk/messageThunk.ts @@ -8,7 +8,6 @@ import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/ import { estimateMessagesUsage } from '@renderer/services/TokenService' import store from '@renderer/store' import type { Assistant, ExternalToolResult, FileType, MCPToolResponse, Model, Topic } from '@renderer/types' -import { WebSearchSource } from '@renderer/types' import type { CitationMessageBlock, FileMessageBlock, @@ -22,7 +21,6 @@ import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@r import { Response } from '@renderer/types/newMessage' import { uuid } from '@renderer/utils' import { formatErrorMessage, isAbortError } from '@renderer/utils/error' -import { extractUrlsFromMarkdown } from '@renderer/utils/linkConverter' import { createAssistantMessage, createBaseMessageBlock, @@ -35,7 +33,8 @@ import { createTranslationBlock, resetAssistantMessage } from '@renderer/utils/messageUtils/create' -import { getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue' +import { getMainTextContent } from '@renderer/utils/messageUtils/find' +import { getTopicQueue } from '@renderer/utils/queue' import { isOnHomePage } from '@renderer/utils/window' import { t } from 'i18next' import { isEmpty, throttle } from 'lodash' @@ -45,10 +44,10 @@ import type { AppDispatch, RootState } from '../index' import { removeManyBlocks, updateOneBlock, upsertManyBlocks, upsertOneBlock } from '../messageBlock' import { newMessagesActions, selectMessagesForTopic } from '../newMessage' -const handleChangeLoadingOfTopic = async (topicId: string) => { - await waitForTopicQueue(topicId) - store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) -} +// const handleChangeLoadingOfTopic = async (topicId: string) => { +// await waitForTopicQueue(topicId) +// store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) +// } // TODO: 后续可以将db操作移到Listener Middleware中 export const saveMessageAndBlocksToDB = async (message: Message, blocks: MessageBlock[], messageIndex: number = -1) => { try { @@ -337,10 +336,17 @@ const fetchAndProcessAssistantResponseImpl = async ( let accumulatedContent = '' let accumulatedThinking = '' + // 专注于管理UI焦点和块切换 let lastBlockId: string | null = null let lastBlockType: MessageBlockType | null = null + // 专注于块内部的生命周期处理 + let initialPlaceholderBlockId: string | null = null let citationBlockId: string | null = null let mainTextBlockId: string | null = null + let thinkingBlockId: string | null = null + let imageBlockId: string | null = null + let toolBlockId: string | null = null + let hasWebSearch = false const toolCallIdToBlockIdMap = new Map() const notificationService = NotificationService.getInstance() @@ -400,129 +406,129 @@ const fetchAndProcessAssistantResponseImpl = async ( } callbacks = { - onLLMResponseCreated: () => { + onLLMResponseCreated: async () => { const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, { status: MessageBlockStatus.PROCESSING }) - handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN) + initialPlaceholderBlockId = baseBlock.id + await handleBlockTransition(baseBlock as PlaceholderMessageBlock, 
MessageBlockType.UNKNOWN) }, - onTextChunk: (text) => { + onTextChunk: async (text) => { accumulatedContent += text - if (lastBlockId) { - if (lastBlockType === MessageBlockType.UNKNOWN) { - const initialChanges: Partial = { - type: MessageBlockType.MAIN_TEXT, - content: accumulatedContent, - status: MessageBlockStatus.STREAMING, - citationReferences: citationBlockId ? [{ citationBlockId }] : [] - } - mainTextBlockId = lastBlockId - lastBlockType = MessageBlockType.MAIN_TEXT - dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges })) - saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState) - } else if (lastBlockType === MessageBlockType.MAIN_TEXT) { - const blockChanges: Partial = { - content: accumulatedContent, - status: MessageBlockStatus.STREAMING - } - throttledBlockUpdate(lastBlockId, blockChanges) - // throttledBlockDbUpdate(lastBlockId, blockChanges) - } else { - const newBlock = createMainTextBlock(assistantMsgId, accumulatedContent, { - status: MessageBlockStatus.STREAMING, - citationReferences: citationBlockId ? [{ citationBlockId }] : [] - }) - handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT) - mainTextBlockId = newBlock.id + if (mainTextBlockId) { + const blockChanges: Partial = { + content: accumulatedContent, + status: MessageBlockStatus.STREAMING } + throttledBlockUpdate(mainTextBlockId, blockChanges) + } else if (initialPlaceholderBlockId) { + // 将占位块转换为主文本块 + const initialChanges: Partial = { + type: MessageBlockType.MAIN_TEXT, + content: accumulatedContent, + status: MessageBlockStatus.STREAMING, + citationReferences: citationBlockId ? [{ citationBlockId }] : [] + } + mainTextBlockId = initialPlaceholderBlockId + // 清理占位块 + initialPlaceholderBlockId = null + lastBlockType = MessageBlockType.MAIN_TEXT + dispatch(updateOneBlock({ id: mainTextBlockId, changes: initialChanges })) + saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState) + } else { + const newBlock = createMainTextBlock(assistantMsgId, accumulatedContent, { + status: MessageBlockStatus.STREAMING, + citationReferences: citationBlockId ? 
[{ citationBlockId }] : [] + }) + mainTextBlockId = newBlock.id // 立即设置ID,防止竞态条件 + await handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT) } }, onTextComplete: async (finalText) => { - if (lastBlockType === MessageBlockType.MAIN_TEXT && lastBlockId) { + if (mainTextBlockId) { const changes = { content: finalText, status: MessageBlockStatus.SUCCESS } - cancelThrottledBlockUpdate(lastBlockId) - dispatch(updateOneBlock({ id: lastBlockId, changes })) - saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState) - - if (assistant.enableWebSearch && assistant.model?.provider === 'openrouter') { - const extractedUrls = extractUrlsFromMarkdown(finalText) - if (extractedUrls.length > 0) { - const citationBlock = createCitationBlock( - assistantMsgId, - { response: { source: WebSearchSource.OPENROUTER, results: extractedUrls } }, - { status: MessageBlockStatus.SUCCESS } - ) - await handleBlockTransition(citationBlock, MessageBlockType.CITATION) - // saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState) - } - } + cancelThrottledBlockUpdate(mainTextBlockId) + dispatch(updateOneBlock({ id: mainTextBlockId, changes })) + saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState) + mainTextBlockId = null } else { console.warn( - `[onTextComplete] Received text.complete but last block was not MAIN_TEXT (was ${lastBlockType}) or lastBlockId is null.` + `[onTextComplete] Received text.complete but last block was not MAIN_TEXT (was ${lastBlockType}) or lastBlockId is null.` ) } - }, - onThinkingChunk: (text, thinking_millsec) => { - accumulatedThinking += text - if (lastBlockId) { - if (lastBlockType === MessageBlockType.UNKNOWN) { - // First chunk for this block: Update type and status immediately - lastBlockType = MessageBlockType.THINKING - const initialChanges: Partial = { - type: MessageBlockType.THINKING, - content: accumulatedThinking, - status: MessageBlockStatus.STREAMING - } - dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges })) - saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState) - } else if (lastBlockType === MessageBlockType.THINKING) { - const blockChanges: Partial = { - content: accumulatedThinking, - status: MessageBlockStatus.STREAMING, - thinking_millsec: thinking_millsec - } - throttledBlockUpdate(lastBlockId, blockChanges) - // throttledBlockDbUpdate(lastBlockId, blockChanges) - } else { - const newBlock = createThinkingBlock(assistantMsgId, accumulatedThinking, { - status: MessageBlockStatus.STREAMING, - thinking_millsec: 0 - }) - handleBlockTransition(newBlock, MessageBlockType.THINKING) + if (citationBlockId && !hasWebSearch) { + const changes: Partial = { + status: MessageBlockStatus.SUCCESS } + dispatch(updateOneBlock({ id: citationBlockId, changes })) + saveUpdatedBlockToDB(citationBlockId, assistantMsgId, topicId, getState) + citationBlockId = null + } + }, + onThinkingChunk: async (text, thinking_millsec) => { + accumulatedThinking += text + if (thinkingBlockId) { + const blockChanges: Partial = { + content: accumulatedThinking, + status: MessageBlockStatus.STREAMING, + thinking_millsec: thinking_millsec + } + throttledBlockUpdate(thinkingBlockId, blockChanges) + } else if (initialPlaceholderBlockId) { + // First chunk for this block: Update type and status immediately + lastBlockType = MessageBlockType.THINKING + const initialChanges: Partial = { + type: MessageBlockType.THINKING, + content: accumulatedThinking, + status: MessageBlockStatus.STREAMING + } + thinkingBlockId = 
initialPlaceholderBlockId + initialPlaceholderBlockId = null + dispatch(updateOneBlock({ id: thinkingBlockId, changes: initialChanges })) + saveUpdatedBlockToDB(thinkingBlockId, assistantMsgId, topicId, getState) + } else { + const newBlock = createThinkingBlock(assistantMsgId, accumulatedThinking, { + status: MessageBlockStatus.STREAMING, + thinking_millsec: 0 + }) + thinkingBlockId = newBlock.id // 立即设置ID,防止竞态条件 + await handleBlockTransition(newBlock, MessageBlockType.THINKING) } }, onThinkingComplete: (finalText, final_thinking_millsec) => { - if (lastBlockType === MessageBlockType.THINKING && lastBlockId) { + if (thinkingBlockId) { const changes = { type: MessageBlockType.THINKING, content: finalText, status: MessageBlockStatus.SUCCESS, thinking_millsec: final_thinking_millsec } - cancelThrottledBlockUpdate(lastBlockId) - dispatch(updateOneBlock({ id: lastBlockId, changes })) - saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState) + cancelThrottledBlockUpdate(thinkingBlockId) + dispatch(updateOneBlock({ id: thinkingBlockId, changes })) + saveUpdatedBlockToDB(thinkingBlockId, assistantMsgId, topicId, getState) } else { console.warn( - `[onThinkingComplete] Received thinking.complete but last block was not THINKING (was ${lastBlockType}) or lastBlockId is null.` + `[onThinkingComplete] Received thinking.complete but last block was not THINKING (was ${lastBlockType}) or lastBlockId is null.` ) } + thinkingBlockId = null }, onToolCallInProgress: (toolResponse: MCPToolResponse) => { - if (lastBlockType === MessageBlockType.UNKNOWN && lastBlockId) { + if (initialPlaceholderBlockId) { lastBlockType = MessageBlockType.TOOL const changes = { type: MessageBlockType.TOOL, status: MessageBlockStatus.PROCESSING, metadata: { rawMcpToolResponse: toolResponse } } - dispatch(updateOneBlock({ id: lastBlockId, changes })) - saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState) - toolCallIdToBlockIdMap.set(toolResponse.id, lastBlockId) + toolBlockId = initialPlaceholderBlockId + initialPlaceholderBlockId = null + dispatch(updateOneBlock({ id: toolBlockId, changes })) + saveUpdatedBlockToDB(toolBlockId, assistantMsgId, topicId, getState) + toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId) } else if (toolResponse.status === 'invoking') { const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, { toolName: toolResponse.tool.name, @@ -539,6 +545,7 @@ const fetchAndProcessAssistantResponseImpl = async ( }, onToolCallComplete: (toolResponse: MCPToolResponse) => { const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id) + toolCallIdToBlockIdMap.delete(toolResponse.id) if (toolResponse.status === 'done' || toolResponse.status === 'error') { if (!existingBlockId) { console.error( @@ -564,10 +571,10 @@ const fetchAndProcessAssistantResponseImpl = async ( ) } }, - onExternalToolInProgress: () => { + onExternalToolInProgress: async () => { const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING }) citationBlockId = citationBlock.id - handleBlockTransition(citationBlock, MessageBlockType.CITATION) + await handleBlockTransition(citationBlock, MessageBlockType.CITATION) // saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState) }, onExternalToolComplete: (externalToolResult: ExternalToolResult) => { @@ -583,35 +590,39 @@ const fetchAndProcessAssistantResponseImpl = async ( console.error('[onExternalToolComplete] citationBlockId is null. 
Cannot update.') } }, - onLLMWebSearchInProgress: () => { - const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING }) - citationBlockId = citationBlock.id - handleBlockTransition(citationBlock, MessageBlockType.CITATION) - // saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState) + onLLMWebSearchInProgress: async () => { + if (initialPlaceholderBlockId) { + lastBlockType = MessageBlockType.CITATION + citationBlockId = initialPlaceholderBlockId + const changes = { + type: MessageBlockType.CITATION, + status: MessageBlockStatus.PROCESSING + } + lastBlockType = MessageBlockType.CITATION + dispatch(updateOneBlock({ id: initialPlaceholderBlockId, changes })) + saveUpdatedBlockToDB(initialPlaceholderBlockId, assistantMsgId, topicId, getState) + initialPlaceholderBlockId = null + } else { + const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING }) + citationBlockId = citationBlock.id + await handleBlockTransition(citationBlock, MessageBlockType.CITATION) + } }, onLLMWebSearchComplete: async (llmWebSearchResult) => { if (citationBlockId) { + hasWebSearch = true const changes: Partial = { response: llmWebSearchResult, status: MessageBlockStatus.SUCCESS } dispatch(updateOneBlock({ id: citationBlockId, changes })) saveUpdatedBlockToDB(citationBlockId, assistantMsgId, topicId, getState) - } else { - const citationBlock = createCitationBlock( - assistantMsgId, - { response: llmWebSearchResult }, - { status: MessageBlockStatus.SUCCESS } - ) - citationBlockId = citationBlock.id - handleBlockTransition(citationBlock, MessageBlockType.CITATION) - } - if (mainTextBlockId) { - const state = getState() - const existingMainTextBlock = state.messageBlocks.entities[mainTextBlockId] - if (existingMainTextBlock && existingMainTextBlock.type === MessageBlockType.MAIN_TEXT) { - const currentRefs = existingMainTextBlock.citationReferences || [] - if (!currentRefs.some((ref) => ref.citationBlockId === citationBlockId)) { + + if (mainTextBlockId) { + const state = getState() + const existingMainTextBlock = state.messageBlocks.entities[mainTextBlockId] + if (existingMainTextBlock && existingMainTextBlock.type === MessageBlockType.MAIN_TEXT) { + const currentRefs = existingMainTextBlock.citationReferences || [] const mainTextChanges = { citationReferences: [ ...currentRefs, @@ -621,40 +632,64 @@ const fetchAndProcessAssistantResponseImpl = async ( dispatch(updateOneBlock({ id: mainTextBlockId, changes: mainTextChanges })) saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState) } + mainTextBlockId = null } } }, - onImageCreated: () => { - if (lastBlockId) { - if (lastBlockType === MessageBlockType.UNKNOWN) { - const initialChanges: Partial = { - type: MessageBlockType.IMAGE, - status: MessageBlockStatus.STREAMING - } - lastBlockType = MessageBlockType.IMAGE - dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges })) - saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState) - } else { - const imageBlock = createImageBlock(assistantMsgId, { - status: MessageBlockStatus.PROCESSING - }) - handleBlockTransition(imageBlock, MessageBlockType.IMAGE) + onImageCreated: async () => { + if (initialPlaceholderBlockId) { + lastBlockType = MessageBlockType.IMAGE + const initialChanges: Partial = { + type: MessageBlockType.IMAGE, + status: MessageBlockStatus.STREAMING } + lastBlockType = MessageBlockType.IMAGE + imageBlockId = initialPlaceholderBlockId + initialPlaceholderBlockId = 
null + dispatch(updateOneBlock({ id: imageBlockId, changes: initialChanges })) + saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState) + } else if (!imageBlockId) { + const imageBlock = createImageBlock(assistantMsgId, { + status: MessageBlockStatus.STREAMING + }) + imageBlockId = imageBlock.id + await handleBlockTransition(imageBlock, MessageBlockType.IMAGE) } }, - onImageGenerated: (imageData) => { + onImageDelta: (imageData) => { const imageUrl = imageData.images?.[0] || 'placeholder_image_url' - if (lastBlockId && lastBlockType === MessageBlockType.IMAGE) { + if (imageBlockId) { const changes: Partial = { url: imageUrl, metadata: { generateImageResponse: imageData }, - status: MessageBlockStatus.SUCCESS + status: MessageBlockStatus.STREAMING + } + dispatch(updateOneBlock({ id: imageBlockId, changes })) + saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState) + } + }, + onImageGenerated: (imageData) => { + if (imageBlockId) { + if (!imageData) { + const changes: Partial = { + status: MessageBlockStatus.SUCCESS + } + dispatch(updateOneBlock({ id: imageBlockId, changes })) + saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState) + } else { + const imageUrl = imageData.images?.[0] || 'placeholder_image_url' + const changes: Partial = { + url: imageUrl, + metadata: { generateImageResponse: imageData }, + status: MessageBlockStatus.SUCCESS + } + dispatch(updateOneBlock({ id: imageBlockId, changes })) + saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState) } - dispatch(updateOneBlock({ id: lastBlockId, changes })) - saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState) } else { console.error('[onImageGenerated] Last block was not an Image block or ID is missing.') } + imageBlockId = null }, onError: async (error) => { console.dir(error, { depth: null }) @@ -683,15 +718,16 @@ const fetchAndProcessAssistantResponseImpl = async ( source: 'assistant' }) } - - if (lastBlockId) { + const possibleBlockId = + mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId + if (possibleBlockId) { // 更改上一个block的状态为ERROR const changes: Partial = { status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR } - cancelThrottledBlockUpdate(lastBlockId) - dispatch(updateOneBlock({ id: lastBlockId, changes })) - saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState) + cancelThrottledBlockUpdate(possibleBlockId) + dispatch(updateOneBlock({ id: possibleBlockId, changes })) + saveUpdatedBlockToDB(possibleBlockId, assistantMsgId, topicId, getState) } const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS }) @@ -721,35 +757,45 @@ const fetchAndProcessAssistantResponseImpl = async ( const contextForUsage = userMsgIndex !== -1 ? 
orderedMsgs.slice(0, userMsgIndex + 1) : [] const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg] - if (lastBlockId) { + const possibleBlockId = + mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId + if (possibleBlockId) { const changes: Partial = { status: MessageBlockStatus.SUCCESS } - cancelThrottledBlockUpdate(lastBlockId) - dispatch(updateOneBlock({ id: lastBlockId, changes })) - saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState) + cancelThrottledBlockUpdate(possibleBlockId) + dispatch(updateOneBlock({ id: possibleBlockId, changes })) + saveUpdatedBlockToDB(possibleBlockId, assistantMsgId, topicId, getState) } - // const content = getMainTextContent(finalAssistantMsg) - // if (!isOnHomePage()) { - // await notificationService.send({ - // id: uuid(), - // type: 'success', - // title: t('notification.assistant'), - // message: content.length > 50 ? content.slice(0, 47) + '...' : content, - // silent: false, - // timestamp: Date.now(), - // source: 'assistant' - // }) - // } + const endTime = Date.now() + const duration = endTime - startTime + const content = getMainTextContent(finalAssistantMsg) + if (!isOnHomePage() && duration > 60 * 1000) { + await notificationService.send({ + id: uuid(), + type: 'success', + title: t('notification.assistant'), + message: content.length > 50 ? content.slice(0, 47) + '...' : content, + silent: false, + timestamp: Date.now(), + source: 'assistant' + }) + } // 更新topic的name autoRenameTopic(assistant, topicId) - if (response && response.usage?.total_tokens === 0) { + if ( + response && + (response.usage?.total_tokens === 0 || + response?.usage?.prompt_tokens === 0 || + response?.usage?.completion_tokens === 0) + ) { const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant }) response.usage = usage } + dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) } if (response && response.metrics) { if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) { @@ -779,6 +825,7 @@ const fetchAndProcessAssistantResponseImpl = async ( const streamProcessorCallbacks = createStreamProcessor(callbacks) + const startTime = Date.now() await fetchChatCompletion({ messages: messagesForContext, assistant: assistant, @@ -833,9 +880,10 @@ export const sendMessage = } } catch (error) { console.error('Error in sendMessage thunk:', error) - } finally { - handleChangeLoadingOfTopic(topicId) } + // finally { + // handleChangeLoadingOfTopic(topicId) + // } } /** @@ -1069,9 +1117,10 @@ export const resendMessageThunk = } } catch (error) { console.error(`[resendMessageThunk] Error resending user message ${userMessageToResend.id}:`, error) - } finally { - handleChangeLoadingOfTopic(topicId) } + // finally { + // handleChangeLoadingOfTopic(topicId) + // } } /** @@ -1179,10 +1228,11 @@ export const regenerateAssistantResponseThunk = `[regenerateAssistantResponseThunk] Error regenerating response for assistant message ${assistantMessageToRegenerate.id}:`, error ) - dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) - } finally { - handleChangeLoadingOfTopic(topicId) + // dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) } + // finally { + // handleChangeLoadingOfTopic(topicId) + // } } // --- Thunk to initiate translation and create the initial block --- @@ -1348,9 +1398,10 @@ export const appendAssistantResponseThunk = console.error(`[appendAssistantResponseThunk] Error appending 
assistant response:`, error) // Optionally dispatch an error action or notification // Resetting loading state should be handled by the underlying fetchAndProcessAssistantResponseImpl - } finally { - handleChangeLoadingOfTopic(topicId) } + // finally { + // handleChangeLoadingOfTopic(topicId) + // } } /** diff --git a/src/renderer/src/types/chunk.ts b/src/renderer/src/types/chunk.ts index dbdff685eb..4cb4755382 100644 --- a/src/renderer/src/types/chunk.ts +++ b/src/renderer/src/types/chunk.ts @@ -1,5 +1,6 @@ -import { ExternalToolResult, KnowledgeReference, MCPToolResponse, WebSearchResponse } from '.' +import { ExternalToolResult, KnowledgeReference, MCPToolResponse, ToolUseResponse, WebSearchResponse } from '.' import { Response, ResponseError } from './newMessage' +import { SdkToolCall } from './sdk' // Define Enum for Chunk Types // 目前用到的,并没有列出完整的生命周期 @@ -11,6 +12,7 @@ export enum ChunkType { WEB_SEARCH_COMPLETE = 'web_search_complete', KNOWLEDGE_SEARCH_IN_PROGRESS = 'knowledge_search_in_progress', KNOWLEDGE_SEARCH_COMPLETE = 'knowledge_search_complete', + MCP_TOOL_CREATED = 'mcp_tool_created', MCP_TOOL_IN_PROGRESS = 'mcp_tool_in_progress', MCP_TOOL_COMPLETE = 'mcp_tool_complete', EXTERNEL_TOOL_COMPLETE = 'externel_tool_complete', @@ -118,7 +120,7 @@ export interface ImageDeltaChunk { /** * A chunk of Base64 encoded image data */ - image: string + image: { type: 'base64'; images: string[] } /** * The type of the chunk @@ -135,7 +137,7 @@ export interface ImageCompleteChunk { /** * The image content of the chunk */ - image: { type: 'base64'; images: string[] } + image?: { type: 'base64'; images: string[] } } export interface ThinkingDeltaChunk { @@ -253,6 +255,12 @@ export interface ExternalToolCompleteChunk { type: ChunkType.EXTERNEL_TOOL_COMPLETE } +export interface MCPToolCreatedChunk { + type: ChunkType.MCP_TOOL_CREATED + tool_calls?: SdkToolCall[] // 工具调用 + tool_use_responses?: ToolUseResponse[] // 工具使用响应 +} + export interface MCPToolInProgressChunk { /** * The type of the chunk @@ -345,6 +353,7 @@ export type Chunk = | WebSearchCompleteChunk // 互联网搜索完成 | KnowledgeSearchInProgressChunk // 知识库搜索进行中 | KnowledgeSearchCompleteChunk // 知识库搜索完成 + | MCPToolCreatedChunk // MCP工具被大模型创建 | MCPToolInProgressChunk // MCP工具调用中 | MCPToolCompleteChunk // MCP工具调用完成 | ExternalToolCompleteChunk // 外部工具调用完成,外部工具包含搜索互联网,知识库,MCP服务器 diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 6c7ef0576f..d82929f0df 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -1,5 +1,5 @@ import type { WebSearchResultBlock } from '@anthropic-ai/sdk/resources' -import type { GenerateImagesConfig, GroundingMetadata } from '@google/genai' +import type { GenerateImagesConfig, GroundingMetadata, PersonGeneration } from '@google/genai' import type OpenAI from 'openai' import type { CSSProperties } from 'react' @@ -444,10 +444,11 @@ export type GenerateImageParams = { imageSize: string batchSize: number seed?: string - numInferenceSteps: number - guidanceScale: number + numInferenceSteps?: number + guidanceScale?: number signal?: AbortSignal promptEnhancement?: boolean + personGeneration?: PersonGeneration } export type GenerateImageResponse = { @@ -520,7 +521,7 @@ export enum WebSearchSource { } export type WebSearchResponse = { - results: WebSearchResults + results?: WebSearchResults source: WebSearchSource } diff --git a/src/renderer/src/types/sdk.ts b/src/renderer/src/types/sdk.ts new file mode 100644 index 0000000000..cef04febff --- /dev/null +++ 
b/src/renderer/src/types/sdk.ts @@ -0,0 +1,107 @@ +import Anthropic from '@anthropic-ai/sdk' +import { + Message, + MessageCreateParams, + MessageParam, + RawMessageStreamEvent, + ToolUnion, + ToolUseBlock +} from '@anthropic-ai/sdk/resources' +import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages' +import { + Content, + CreateChatParameters, + FunctionCall, + GenerateContentResponse, + GoogleGenAI, + Model as GeminiModel, + SendMessageParameters, + Tool +} from '@google/genai' +import OpenAI, { AzureOpenAI } from 'openai' +import { Stream } from 'openai/streaming' + +export type SdkInstance = OpenAI | AzureOpenAI | Anthropic | GoogleGenAI +export type SdkParams = OpenAISdkParams | OpenAIResponseSdkParams | AnthropicSdkParams | GeminiSdkParams +export type SdkRawChunk = OpenAISdkRawChunk | OpenAIResponseSdkRawChunk | AnthropicSdkRawChunk | GeminiSdkRawChunk +export type SdkRawOutput = OpenAISdkRawOutput | OpenAIResponseSdkRawOutput | AnthropicSdkRawOutput | GeminiSdkRawOutput +export type SdkMessageParam = + | OpenAISdkMessageParam + | OpenAIResponseSdkMessageParam + | AnthropicSdkMessageParam + | GeminiSdkMessageParam +export type SdkToolCall = + | OpenAI.Chat.Completions.ChatCompletionMessageToolCall + | ToolUseBlock + | FunctionCall + | OpenAIResponseSdkToolCall +export type SdkTool = OpenAI.Chat.Completions.ChatCompletionTool | ToolUnion | Tool | OpenAIResponseSdkTool +export type SdkModel = OpenAI.Models.Model | Anthropic.ModelInfo | GeminiModel + +export type RequestOptions = Anthropic.RequestOptions | OpenAI.RequestOptions | GeminiOptions + +/** + * OpenAI + */ + +type OpenAIParamsWithoutReasoningEffort = Omit + +export type ReasoningEffortOptionalParams = { + thinking?: { type: 'disabled' | 'enabled'; budget_tokens?: number } + reasoning?: { max_tokens?: number; exclude?: boolean; effort?: string } | OpenAI.Reasoning + reasoning_effort?: OpenAI.Chat.Completions.ChatCompletionCreateParams['reasoning_effort'] | 'none' | 'auto' + enable_thinking?: boolean + thinking_budget?: number + enable_reasoning?: boolean + // Add any other potential reasoning-related keys here if they exist +} + +export type OpenAISdkParams = OpenAIParamsWithoutReasoningEffort & ReasoningEffortOptionalParams +export type OpenAISdkRawChunk = + | OpenAI.Chat.Completions.ChatCompletionChunk + | ({ + _request_id?: string | null | undefined + } & OpenAI.ChatCompletion) + +export type OpenAISdkRawOutput = Stream | OpenAI.ChatCompletion +export type OpenAISdkRawContentSource = + | OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta + | OpenAI.Chat.Completions.ChatCompletionMessage + +export type OpenAISdkMessageParam = OpenAI.Chat.Completions.ChatCompletionMessageParam + +/** + * OpenAI Response + */ + +export type OpenAIResponseSdkParams = OpenAI.Responses.ResponseCreateParams +export type OpenAIResponseSdkRawOutput = Stream | OpenAI.Responses.Response +export type OpenAIResponseSdkRawChunk = OpenAI.Responses.ResponseStreamEvent | OpenAI.Responses.Response +export type OpenAIResponseSdkMessageParam = OpenAI.Responses.ResponseInputItem +export type OpenAIResponseSdkToolCall = OpenAI.Responses.ResponseFunctionToolCall +export type OpenAIResponseSdkTool = OpenAI.Responses.Tool + +/** + * Anthropic + */ + +export type AnthropicSdkParams = MessageCreateParams +export type AnthropicSdkRawOutput = MessageStream | Message +export type AnthropicSdkRawChunk = RawMessageStreamEvent | Message +export type AnthropicSdkMessageParam = MessageParam + +/** + * Gemini + */ + +export type GeminiSdkParams = 
SendMessageParameters & CreateChatParameters +export type GeminiSdkRawOutput = AsyncGenerator | GenerateContentResponse +export type GeminiSdkRawChunk = GenerateContentResponse +export type GeminiSdkMessageParam = Content +export type GeminiSdkToolCall = FunctionCall + +export type GeminiOptions = { + streamOutput: boolean + abortSignal?: AbortSignal + timeout?: number +} diff --git a/src/renderer/src/utils/linkConverter.ts b/src/renderer/src/utils/linkConverter.ts index 258b85c9fd..238c88b10e 100644 --- a/src/renderer/src/utils/linkConverter.ts +++ b/src/renderer/src/utils/linkConverter.ts @@ -369,3 +369,99 @@ export function cleanLinkCommas(text: string): string { // 匹配两个 Markdown 链接之间的英文逗号(可能包含空格) return text.replace(/\]\(([^)]+)\)\s*,\s*\[/g, ']($1)[') } + +/** + * 从文本中识别各种格式的Web搜索引用占位符 + * 支持的格式包括:[1], [ref_1], [1](@ref), [1,2,3](@ref) 等 + * @param {string} text 要分析的文本 + * @returns {Array} 识别到的引用信息数组 + */ +export function extractWebSearchReferences(text: string): Array<{ + match: string + placeholder: string + numbers: number[] + startIndex: number + endIndex: number +}> { + const references: Array<{ + match: string + placeholder: string + numbers: number[] + startIndex: number + endIndex: number + }> = [] + + // 匹配各种引用格式的正则表达式 + const patterns = [ + // [1], [2], [3] - 简单数字引用 + { regex: /\[(\d+)\]/g, type: 'simple' }, + // [ref_1], [ref_2] - Zhipu格式 + { regex: /\[ref_(\d+)\]/g, type: 'zhipu' }, + // [1](@ref), [2](@ref) - Hunyuan单个引用格式 + { regex: /\[(\d+)\]\(@ref\)/g, type: 'hunyuan_single' }, + // [1,2,3](@ref) - Hunyuan多个引用格式 + { regex: /\[([\d,\s]+)\]\(@ref\)/g, type: 'hunyuan_multiple' } + ] + + patterns.forEach(({ regex, type }) => { + let match + while ((match = regex.exec(text)) !== null) { + let numbers: number[] = [] + + if (type === 'hunyuan_multiple') { + // 解析逗号分隔的数字 + numbers = match[1] + .split(',') + .map((num) => parseInt(num.trim())) + .filter((num) => !isNaN(num)) + } else { + // 单个数字 + numbers = [parseInt(match[1])] + } + + references.push({ + match: match[0], + placeholder: match[0], + numbers: numbers, + startIndex: match.index!, + endIndex: match.index! + match[0].length + }) + } + }) + + // 按位置排序 + return references.sort((a, b) => a.startIndex - b.startIndex) +} + +/** + * 智能链接转换器 - 根据文本中的引用模式和Web搜索结果自动选择合适的转换策略 + * @param {string} text 当前文本块 + * @param {any[]} webSearchResults Web搜索结果数组 + * @param {string} providerType Provider类型 ('openai', 'zhipu', 'hunyuan', 'openrouter', etc.) 
+ * @param {boolean} resetCounter 是否重置计数器 + * @returns {string} 转换后的文本 + */ +export function smartLinkConverter( + text: string, + providerType: string = 'openai', + resetCounter: boolean = false +): string { + // 检测文本中的引用模式 + const references = extractWebSearchReferences(text) + + if (references.length === 0) { + // 如果没有特定的引用模式,使用通用转换 + return convertLinks(text, resetCounter) + } + + // 根据检测到的引用模式选择合适的转换器 + const hasZhipuPattern = references.some((ref) => ref.placeholder.includes('ref_')) + + if (hasZhipuPattern) { + return convertLinksToZhipu(text, resetCounter) + } else if (providerType === 'openrouter') { + return convertLinksToOpenRouter(text, resetCounter) + } else { + return convertLinks(text, resetCounter) + } +} diff --git a/src/renderer/src/utils/mcp-tools.ts b/src/renderer/src/utils/mcp-tools.ts index 8bdc499135..b26836ee13 100644 --- a/src/renderer/src/utils/mcp-tools.ts +++ b/src/renderer/src/utils/mcp-tools.ts @@ -1,10 +1,4 @@ -import { - ContentBlockParam, - MessageParam, - ToolResultBlockParam, - ToolUnion, - ToolUseBlock -} from '@anthropic-ai/sdk/resources' +import { ContentBlockParam, MessageParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources' import { Content, FunctionCall, Part, Tool, Type as GeminiSchemaType } from '@google/genai' import Logger from '@renderer/config/logger' import { isFunctionCallingModel, isVisionModel } from '@renderer/config/models' @@ -21,6 +15,7 @@ import { } from '@renderer/types' import type { MCPToolCompleteChunk, MCPToolInProgressChunk } from '@renderer/types/chunk' import { ChunkType } from '@renderer/types/chunk' +import { SdkMessageParam } from '@renderer/types/sdk' import { isArray, isObject, pull, transform } from 'lodash' import { nanoid } from 'nanoid' import OpenAI from 'openai' @@ -31,7 +26,7 @@ import { ChatCompletionTool } from 'openai/resources' -import { CompletionsParams } from '../providers/AiProvider' +import { CompletionsParams } from '../aiCore/middleware/schemas' const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install' const EXTRA_SCHEMA_KEYS = ['schema', 'headers'] @@ -449,13 +444,25 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseRespo if (!content || !mcpTools || mcpTools.length === 0) { return [] } + + // 支持两种格式: + // 1. 完整的 标签包围的内容 + // 2. 
只有内部内容(从 TagExtractor 提取出来的) + + let contentToProcess = content + + // 如果内容不包含 标签,说明是从 TagExtractor 提取的内部内容,需要包装 + if (!content.includes('')) { + contentToProcess = `\n${content}\n` + } + const toolUsePattern = /([\s\S]*?)([\s\S]*?)<\/name>([\s\S]*?)([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g const tools: ToolUseResponse[] = [] let match let idx = 0 // Find all tool use blocks - while ((match = toolUsePattern.exec(content)) !== null) { + while ((match = toolUsePattern.exec(contentToProcess)) !== null) { // const fullMatch = match[0] const toolName = match[2].trim() const toolArgs = match[4].trim() @@ -497,9 +504,7 @@ export async function parseAndCallTools( convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined, model: Model, mcpTools?: MCPTool[] -): Promise< - (ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[] -> +): Promise export async function parseAndCallTools( content: string, @@ -508,9 +513,7 @@ export async function parseAndCallTools( convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined, model: Model, mcpTools?: MCPTool[] -): Promise< - (ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[] -> +): Promise export async function parseAndCallTools( content: string | MCPToolResponse[], @@ -539,7 +542,7 @@ export async function parseAndCallTools( ...toolResponse, status: 'invoking' }, - onChunk + onChunk! ) } @@ -553,7 +556,7 @@ export async function parseAndCallTools( status: 'done', response: toolCallResponse }, - onChunk + onChunk! ) for (const content of toolCallResponse.content) { @@ -563,10 +566,10 @@ export async function parseAndCallTools( } if (images.length) { - onChunk({ + onChunk?.({ type: ChunkType.IMAGE_CREATED }) - onChunk({ + onChunk?.({ type: ChunkType.IMAGE_COMPLETE, image: { type: 'base64', diff --git a/src/renderer/src/utils/naming.ts b/src/renderer/src/utils/naming.ts index d27b1da30b..f475fa7421 100644 --- a/src/renderer/src/utils/naming.ts +++ b/src/renderer/src/utils/naming.ts @@ -101,7 +101,7 @@ export function isEmoji(str: string): boolean { * @returns {string} 处理后的字符串 */ export function removeSpecialCharactersForTopicName(str: string): string { - return str.replace(/[\r\n]+/g, ' ').trim() + return str.replace(/["'\r\n]+/g, ' ').trim() } /** diff --git a/src/renderer/src/utils/stream.ts b/src/renderer/src/utils/stream.ts index e1d34cc94e..b8e9f9fd2c 100644 --- a/src/renderer/src/utils/stream.ts +++ b/src/renderer/src/utils/stream.ts @@ -31,10 +31,12 @@ export function readableStreamAsyncIterable(stream: any): AsyncIterableIterat } } -export function asyncGeneratorToReadableStream(gen: AsyncGenerator): ReadableStream { +export function asyncGeneratorToReadableStream(gen: AsyncIterable): ReadableStream { + const iterator = gen[Symbol.asyncIterator]() + return new ReadableStream({ async pull(controller) { - const { value, done } = await gen.next() + const { value, done } = await iterator.next() if (done) { controller.close() } else { @@ -43,3 +45,17 @@ export function asyncGeneratorToReadableStream(gen: AsyncGenerator): Reada } }) } + +/** + * 将单个数据项转换为可读流 + * @param data 要转换为流的单个数据项 + * @returns 包含单个数据项的ReadableStream + */ +export function createSingleChunkReadableStream(data: T): ReadableStream { + return new ReadableStream({ + start(controller) { + controller.enqueue(data) + controller.close() + } + }) +} diff --git 
a/src/renderer/src/utils/tagExtraction.ts b/src/renderer/src/utils/tagExtraction.ts new file mode 100644 index 0000000000..28664b01c5 --- /dev/null +++ b/src/renderer/src/utils/tagExtraction.ts @@ -0,0 +1,168 @@ +import { getPotentialStartIndex } from './getPotentialIndex' + +export interface TagConfig { + openingTag: string + closingTag: string + separator?: string +} + +export interface TagExtractionState { + textBuffer: string + isInsideTag: boolean + isFirstTag: boolean + isFirstText: boolean + afterSwitch: boolean + accumulatedTagContent: string + hasTagContent: boolean +} + +export interface TagExtractionResult { + content: string + isTagContent: boolean + complete: boolean + tagContentExtracted?: string +} + +/** + * 通用标签提取处理器 + * 可以处理各种形式的标签对,如 ..., ... 等 + */ +export class TagExtractor { + private config: TagConfig + private state: TagExtractionState + + constructor(config: TagConfig) { + this.config = config + this.state = { + textBuffer: '', + isInsideTag: false, + isFirstTag: true, + isFirstText: true, + afterSwitch: false, + accumulatedTagContent: '', + hasTagContent: false + } + } + + /** + * 处理文本块,返回处理结果 + */ + processText(newText: string): TagExtractionResult[] { + this.state.textBuffer += newText + const results: TagExtractionResult[] = [] + + // 处理标签提取逻辑 + while (true) { + const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag + const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag) + + if (startIndex == null) { + const content = this.state.textBuffer + if (content.length > 0) { + results.push({ + content: this.addPrefix(content), + isTagContent: this.state.isInsideTag, + complete: false + }) + + if (this.state.isInsideTag) { + this.state.accumulatedTagContent += this.addPrefix(content) + this.state.hasTagContent = true + } + } + this.state.textBuffer = '' + break + } + + // 处理标签前的内容 + const contentBeforeTag = this.state.textBuffer.slice(0, startIndex) + if (contentBeforeTag.length > 0) { + results.push({ + content: this.addPrefix(contentBeforeTag), + isTagContent: this.state.isInsideTag, + complete: false + }) + + if (this.state.isInsideTag) { + this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag) + this.state.hasTagContent = true + } + } + + const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length + + if (foundFullMatch) { + // 如果找到完整的标签 + this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length) + + // 如果刚刚结束一个标签内容,生成完整的标签内容结果 + if (this.state.isInsideTag && this.state.hasTagContent) { + results.push({ + content: '', + isTagContent: false, + complete: true, + tagContentExtracted: this.state.accumulatedTagContent + }) + this.state.accumulatedTagContent = '' + this.state.hasTagContent = false + } + + this.state.isInsideTag = !this.state.isInsideTag + this.state.afterSwitch = true + + if (this.state.isInsideTag) { + this.state.isFirstTag = false + } else { + this.state.isFirstText = false + } + } else { + this.state.textBuffer = this.state.textBuffer.slice(startIndex) + break + } + } + + return results + } + + /** + * 完成处理,返回任何剩余的标签内容 + */ + finalize(): TagExtractionResult | null { + if (this.state.hasTagContent && this.state.accumulatedTagContent) { + const result = { + content: '', + isTagContent: false, + complete: true, + tagContentExtracted: this.state.accumulatedTagContent + } + this.state.accumulatedTagContent = '' + this.state.hasTagContent = false + return result + } + return null + } + + private addPrefix(text: string): string { + const 
needsPrefix = + this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText) + + const prefix = needsPrefix && this.config.separator ? this.config.separator : '' + this.state.afterSwitch = false + return prefix + text + } + + /** + * 重置状态 + */ + reset(): void { + this.state = { + textBuffer: '', + isInsideTag: false, + isFirstTag: true, + isFirstText: true, + afterSwitch: false, + accumulatedTagContent: '', + hasTagContent: false + } + } +}
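For reference, a usage sketch of the new `TagExtractor` driven by a streaming consumer. The import assumes the `@renderer` alias maps to `src/renderer/src` as it does elsewhere in this codebase; the tag names and sample chunks are invented for illustration and are not code from this patch:

```typescript
import { TagExtractor } from '@renderer/utils/tagExtraction'

// Extract <think>…</think> content from a stream that may split tags across chunks.
const extractor = new TagExtractor({ openingTag: '<think>', closingTag: '</think>', separator: '\n' })

const incomingChunks = ['Hello <thi', 'nk>step 1', ' and step 2</think> world']
let visibleText = ''
let reasoningText = ''

for (const chunk of incomingChunks) {
  for (const result of extractor.processText(chunk)) {
    if (result.complete && result.tagContentExtracted !== undefined) {
      // A full tag pair has just closed; the accumulated inner content is returned once.
      reasoningText += result.tagContentExtracted
    } else if (!result.isTagContent) {
      // Content outside the tags is emitted incrementally.
      visibleText += result.content
    }
    // Results with isTagContent === true are also accumulated internally by the extractor.
  }
}

// Flush an unterminated tag at end of stream, if any.
const leftover = extractor.finalize()
if (leftover?.tagContentExtracted) {
  reasoningText += leftover.tagContentExtracted
}

console.log(visibleText)   // text that appeared outside <think>…</think>
console.log(reasoningText) // text captured inside the tags
```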