diff --git a/.vscode/launch.json b/.vscode/launch.json index efacfda6f8..0b6b9a6499 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -7,7 +7,6 @@ "request": "launch", "cwd": "${workspaceRoot}", "runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite", - "runtimeVersion": "20", "windows": { "runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite.cmd" }, diff --git a/docs/technical/how-to-write-middlewares.md b/docs/technical/how-to-write-middlewares.md new file mode 100644 index 0000000000..9f3b691309 --- /dev/null +++ b/docs/technical/how-to-write-middlewares.md @@ -0,0 +1,214 @@ +# 如何为 AI Provider 编写中间件 + +本文档旨在指导开发者如何为我们的 AI Provider 框架创建和集成自定义中间件。中间件提供了一种强大而灵活的方式来增强、修改或观察 Provider 方法的调用过程,例如日志记录、缓存、请求/响应转换、错误处理等。 + +## 架构概览 + +我们的中间件架构借鉴了 Redux 的三段式设计,并结合了 JavaScript Proxy 来动态地将中间件应用于 Provider 的方法。 + +- **Proxy**: 拦截对 Provider 方法的调用,并将调用引导至中间件链。 +- **中间件链**: 一系列按顺序执行的中间件函数。每个中间件都可以处理请求/响应,然后将控制权传递给链中的下一个中间件,或者在某些情况下提前终止链。 +- **上下文 (Context)**: 一个在中间件之间传递的对象,携带了关于当前调用的信息(如方法名、原始参数、Provider 实例、以及中间件自定义的数据)。 + +## 中间件的类型 + +目前主要支持两种类型的中间件,它们共享相似的结构但针对不同的场景: + +1. **`CompletionsMiddleware`**: 专门为 `completions` 方法设计。这是最常用的中间件类型,因为它允许对 AI 模型的核心聊天/文本生成功能进行精细控制。 +2. **`ProviderMethodMiddleware`**: 通用中间件,可以应用于 Provider 上的任何其他方法(例如,`translate`, `summarize` 等,如果这些方法也通过中间件系统包装)。 + +## 编写一个 `CompletionsMiddleware` + +`CompletionsMiddleware` 的基本签名(TypeScript 类型)如下: + +```typescript +import { AiProviderMiddlewareCompletionsContext, CompletionsParams, MiddlewareAPI } from './AiProviderMiddlewareTypes' // 假设类型定义文件路径 + +export type CompletionsMiddleware = ( + api: MiddlewareAPI +) => ( + next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise // next 返回 Promise 代表原始SDK响应或下游中间件的结果 +) => (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise // 最内层函数通常返回 Promise,因为结果通过 onChunk 或 context 副作用传递 +``` + +让我们分解这个三段式结构: + +1. **第一层函数 `(api) => { ... 
}`**: + + - 接收一个 `api` 对象。 + - `api` 对象提供了以下方法: + - `api.getContext()`: 获取当前调用的上下文对象 (`AiProviderMiddlewareCompletionsContext`)。 + - `api.getOriginalArgs()`: 获取传递给 `completions` 方法的原始参数数组 (即 `[CompletionsParams]`)。 + - `api.getProviderId()`: 获取当前 Provider 的 ID。 + - `api.getProviderInstance()`: 获取原始的 Provider 实例。 + - 此函数通常用于进行一次性的设置或获取所需的服务/配置。它返回第二层函数。 + +2. **第二层函数 `(next) => { ... }`**: + + - 接收一个 `next` 函数。 + - `next` 函数代表了中间件链中的下一个环节。调用 `next(context, params)` 会将控制权传递给下一个中间件,或者如果当前中间件是链中的最后一个,则会调用核心的 Provider 方法逻辑 (例如,实际的 SDK 调用)。 + - `next` 函数接收当前的 `context` 和 `params` (这些可能已被上游中间件修改)。 + - **重要的是**:`next` 的返回类型通常是 `Promise`。对于 `completions` 方法,如果 `next` 调用了实际的 SDK,它将返回原始的 SDK 响应(例如,OpenAI 的流对象或 JSON 对象)。你需要处理这个响应。 + - 此函数返回第三层(也是最核心的)函数。 + +3. **第三层函数 `(context, params) => { ... }`**: + - 这是执行中间件主要逻辑的地方。 + - 它接收当前的 `context` (`AiProviderMiddlewareCompletionsContext`) 和 `params` (`CompletionsParams`)。 + - 在此函数中,你可以: + - **在调用 `next` 之前**: + - 读取或修改 `params`。例如,添加默认参数、转换消息格式。 + - 读取或修改 `context`。例如,设置一个时间戳用于后续计算延迟。 + - 执行某些检查,如果不满足条件,可以不调用 `next` 而直接返回或抛出错误(例如,参数校验失败)。 + - **调用 `await next(context, params)`**: + - 这是将控制权传递给下游的关键步骤。 + - `next` 的返回值是原始的 SDK 响应或下游中间件的结果,你需要根据情况处理它(例如,如果是流,则开始消费流)。 + - **在调用 `next` 之后**: + - 处理 `next` 的返回结果。例如,如果 `next` 返回了一个流,你可以在这里开始迭代处理这个流,并通过 `context.onChunk` 发送数据块。 + - 基于 `context` 的变化或 `next` 的结果执行进一步操作。例如,计算总耗时、记录日志。 + - 修改最终结果(尽管对于 `completions`,结果通常通过 `onChunk` 副作用发出)。 + +### 示例:一个简单的日志中间件 + +```typescript +import { + AiProviderMiddlewareCompletionsContext, + CompletionsParams, + MiddlewareAPI, + OnChunkFunction // 假设 OnChunkFunction 类型被导出 +} from './AiProviderMiddlewareTypes' // 调整路径 +import { ChunkType } from '@renderer/types' // 调整路径 + +export const createSimpleLoggingMiddleware = (): CompletionsMiddleware => { + return (api: MiddlewareAPI) => { + // console.log(`[LoggingMiddleware] Initialized for provider: ${api.getProviderId()}`); + + return (next: (context: AiProviderMiddlewareCompletionsContext, params: 
CompletionsParams) => Promise) => { + return async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise => { + const startTime = Date.now() + // 从 context 中获取 onChunk (它最初来自 params.onChunk) + const onChunk = context.onChunk + + console.log( + `[LoggingMiddleware] Request for ${context.methodName} with params:`, + params.messages?.[params.messages.length - 1]?.content + ) + + try { + // 调用下一个中间件或核心逻辑 + // `rawSdkResponse` 是来自下游的原始响应 (例如 OpenAIStream 或 ChatCompletion 对象) + const rawSdkResponse = await next(context, params) + + // 此处简单示例不处理 rawSdkResponse,假设下游中间件 (如 StreamingResponseHandler) + // 会处理它并通过 onChunk 发送数据。 + // 如果这个日志中间件在 StreamingResponseHandler 之后,那么流已经被处理。 + // 如果在之前,那么它需要自己处理 rawSdkResponse 或确保下游会处理。 + + const duration = Date.now() - startTime + console.log(`[LoggingMiddleware] Request for ${context.methodName} completed in ${duration}ms.`) + + // 假设下游已经通过 onChunk 发送了所有数据。 + // 如果这个中间件是链的末端,并且需要确保 BLOCK_COMPLETE 被发送, + // 它可能需要更复杂的逻辑来跟踪何时所有数据都已发送。 + } catch (error) { + const duration = Date.now() - startTime + console.error(`[LoggingMiddleware] Request for ${context.methodName} failed after ${duration}ms:`, error) + + // 如果 onChunk 可用,可以尝试发送一个错误块 + if (onChunk) { + onChunk({ + type: ChunkType.ERROR, + error: { message: (error as Error).message, name: (error as Error).name, stack: (error as Error).stack } + }) + // 考虑是否还需要发送 BLOCK_COMPLETE 来结束流 + onChunk({ type: ChunkType.BLOCK_COMPLETE, response: {} }) + } + throw error // 重新抛出错误,以便上层或全局错误处理器可以捕获 + } + } + } + } +} +``` + +### `AiProviderMiddlewareCompletionsContext` 的重要性 + +`AiProviderMiddlewareCompletionsContext` 是在中间件之间传递状态和数据的核心。它通常包含: + +- `methodName`: 当前调用的方法名 (总是 `'completions'`)。 +- `originalArgs`: 传递给 `completions` 的原始参数数组。 +- `providerId`: Provider 的 ID。 +- `_providerInstance`: Provider 实例。 +- `onChunk`: 从原始 `CompletionsParams` 传入的回调函数,用于流式发送数据块。**所有中间件都应该通过 `context.onChunk` 来发送数据。** +- `messages`, `model`, `assistant`, `mcpTools`: 从原始 `CompletionsParams` 
中提取的常用字段,方便访问。 +- **自定义字段**: 中间件可以向上下文中添加自定义字段,以供后续中间件使用。例如,一个缓存中间件可能会添加 `context.cacheHit = true`。 + +**关键**: 当你在中间件中修改 `params` 或 `context` 时,这些修改会向下游中间件传播(如果它们在 `next` 调用之前修改)。 + +### 中间件的顺序 + +中间件的执行顺序非常重要。它们在 `AiProviderMiddlewareConfig` 的数组中定义的顺序就是它们的执行顺序。 + +- 请求首先通过第一个中间件,然后是第二个,依此类推。 +- 响应(或 `next` 的调用结果)则以相反的顺序"冒泡"回来。 + +例如,如果链是 `[AuthMiddleware, CacheMiddleware, LoggingMiddleware]`: + +1. `AuthMiddleware` 先执行其 "调用 `next` 之前" 的逻辑。 +2. 然后 `CacheMiddleware` 执行其 "调用 `next` 之前" 的逻辑。 +3. 然后 `LoggingMiddleware` 执行其 "调用 `next` 之前" 的逻辑。 +4. 核心SDK调用(或链的末端)。 +5. `LoggingMiddleware` 先接收到结果,执行其 "调用 `next` 之后" 的逻辑。 +6. 然后 `CacheMiddleware` 接收到结果(可能已被 LoggingMiddleware 修改的上下文),执行其 "调用 `next` 之后" 的逻辑(例如,存储结果)。 +7. 最后 `AuthMiddleware` 接收到结果,执行其 "调用 `next` 之后" 的逻辑。 + +### 注册中间件 + +中间件在 `src/renderer/src/providers/middleware/register.ts` (或其他类似的配置文件) 中进行注册。 + +```typescript +// register.ts +import { AiProviderMiddlewareConfig } from './AiProviderMiddlewareTypes' +import { createSimpleLoggingMiddleware } from './common/SimpleLoggingMiddleware' // 假设你创建了这个文件 +import { createCompletionsLoggingMiddleware } from './common/CompletionsLoggingMiddleware' // 已有的 + +const middlewareConfig: AiProviderMiddlewareConfig = { + completions: [ + createSimpleLoggingMiddleware(), // 你新加的中间件 + createCompletionsLoggingMiddleware() // 已有的日志中间件 + // ... 其他 completions 中间件 + ], + methods: { + // translate: [createGenericLoggingMiddleware()], + // ... 其他方法的中间件 + } +} + +export default middlewareConfig +``` + +### 最佳实践 + +1. **单一职责**: 每个中间件应专注于一个特定的功能(例如,日志、缓存、转换特定数据)。 +2. **无副作用 (尽可能)**: 除了通过 `context` 或 `onChunk` 明确的副作用外,尽量避免修改全局状态或产生其他隐蔽的副作用。 +3. **错误处理**: + - 在中间件内部使用 `try...catch` 来处理可能发生的错误。 + - 决定是自行处理错误(例如,通过 `onChunk` 发送错误块)还是将错误重新抛出给上游。 + - 如果重新抛出,确保错误对象包含足够的信息。 +4. **性能考虑**: 中间件会增加请求处理的开销。避免在中间件中执行非常耗时的同步操作。对于IO密集型操作,确保它们是异步的。 +5. **可配置性**: 使中间件的行为可通过参数或配置进行调整。例如,日志中间件可以接受一个日志级别参数。 +6. **上下文管理**: + - 谨慎地向 `context` 添加数据。避免污染 `context` 或添加过大的对象。 + - 明确你添加到 `context` 的字段的用途和生命周期。 +7. 
**`next` 的调用**: + - 除非你有充分的理由提前终止请求(例如,缓存命中、授权失败),否则**总是确保调用 `await next(context, params)`**。否则,下游的中间件和核心逻辑将不会执行。 + - 理解 `next` 的返回值并正确处理它,特别是当它是一个流时。你需要负责消费这个流或将其传递给另一个能够消费它的组件/中间件。 +8. **命名清晰**: 给你的中间件和它们创建的函数起描述性的名字。 +9. **文档和注释**: 对复杂的中间件逻辑添加注释,解释其工作原理和目的。 + +### 调试技巧 + +- 在中间件的关键点使用 `console.log` 或调试器来检查 `params`、`context` 的状态以及 `next` 的返回值。 +- 暂时简化中间件链,只保留你正在调试的中间件和最简单的核心逻辑,以隔离问题。 +- 编写单元测试来独立验证每个中间件的行为。 + +通过遵循这些指南,你应该能够有效地为我们的系统创建强大且可维护的中间件。如果你有任何疑问或需要进一步的帮助,请咨询团队。 diff --git a/package.json b/package.json index 050d52de02..ed63b03057 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "CherryStudio", - "version": "1.4.2-ui-preview", + "version": "1.4.2", "private": true, "description": "A powerful AI assistant for producer.", "main": "./out/main/index.js", @@ -58,6 +58,20 @@ "prepare": "husky" }, "dependencies": { + "@libsql/client": "0.14.0", + "@libsql/win32-x64-msvc": "^0.4.7", + "@strongtz/win32-arm64-msvc": "^0.4.7", + "jsdom": "26.1.0", + "os-proxy-config": "^1.1.2", + "selection-hook": "^0.9.23", + "turndown": "7.2.0" + }, + "devDependencies": { + "@agentic/exa": "^7.3.3", + "@agentic/searxng": "^7.3.3", + "@agentic/tavily": "^7.3.3", + "@ant-design/v5-patch-for-react-19": "^1.0.3", + "@anthropic-ai/sdk": "^0.41.0", "@cherrystudio/embedjs": "^0.1.31", "@cherrystudio/embedjs-libsql": "^0.1.31", "@cherrystudio/embedjs-loader-csv": "^0.1.31", @@ -70,48 +84,11 @@ "@cherrystudio/embedjs-loader-xml": "^0.1.31", "@cherrystudio/embedjs-ollama": "^0.1.31", "@cherrystudio/embedjs-openai": "^0.1.31", - "@electron-toolkit/utils": "^3.0.0", - "@langchain/community": "^0.3.36", - "@langchain/ollama": "^0.2.1", - "@strongtz/win32-arm64-msvc": "^0.4.7", - "@tanstack/react-query": "^5.27.0", - "@types/react-infinite-scroll-component": "^5.0.0", - "archiver": "^7.0.1", - "async-mutex": "^0.5.0", - "diff": "^7.0.0", - "docx": "^9.0.2", - "electron-log": "^5.1.5", - "electron-store": "^8.2.0", - "electron-updater": "6.6.4", - "electron-window-state": 
"^5.0.3", - "epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch", - "fast-xml-parser": "^5.2.0", - "framer-motion": "^12.17.0", - "franc-min": "^6.2.0", - "fs-extra": "^11.2.0", - "jsdom": "^26.0.0", - "markdown-it": "^14.1.0", - "node-stream-zip": "^1.15.0", - "officeparser": "^4.1.1", - "os-proxy-config": "^1.1.2", - "proxy-agent": "^6.5.0", - "remove-markdown": "^0.6.2", - "selection-hook": "^0.9.23", - "tar": "^7.4.3", - "turndown": "^7.2.0", - "webdav": "^5.8.0", - "zipread": "^1.3.3" - }, - "devDependencies": { - "@agentic/exa": "^7.3.3", - "@agentic/searxng": "^7.3.3", - "@agentic/tavily": "^7.3.3", - "@ant-design/v5-patch-for-react-19": "^1.0.3", - "@anthropic-ai/sdk": "^0.41.0", "@electron-toolkit/eslint-config-prettier": "^3.0.0", "@electron-toolkit/eslint-config-ts": "^3.0.0", "@electron-toolkit/preload": "^3.0.0", "@electron-toolkit/tsconfig": "^1.0.1", + "@electron-toolkit/utils": "^3.0.0", "@electron/notarize": "^2.5.0", "@emotion/is-prop-valid": "^1.3.1", "@eslint-react/eslint-plugin": "^1.36.1", @@ -119,6 +96,8 @@ "@google/genai": "^1.0.1", "@hello-pangea/dnd": "^16.6.0", "@kangfenmao/keyv-storage": "^0.1.0", + "@langchain/community": "^0.3.36", + "@langchain/ollama": "^0.2.1", "@modelcontextprotocol/sdk": "^1.11.4", "@mozilla/readability": "^0.6.0", "@notionhq/client": "^2.2.15", @@ -126,6 +105,7 @@ "@reduxjs/toolkit": "^2.2.5", "@shikijs/markdown-it": "^3.4.2", "@swc/plugin-styled-components": "^7.1.5", + "@tanstack/react-query": "^5.27.0", "@testing-library/dom": "^10.4.0", "@testing-library/jest-dom": "^6.6.3", "@testing-library/react": "^16.3.0", @@ -152,24 +132,37 @@ "@vitest/web-worker": "^3.1.4", "@xyflow/react": "^12.4.4", "antd": "^5.22.5", + "archiver": "^7.0.1", + "async-mutex": "^0.5.0", "axios": "^1.7.3", "browser-image-compression": "^2.0.2", "color": "^5.0.0", "dayjs": "^1.11.11", "dexie": "^4.0.8", "dexie-react-hooks": "^1.1.7", + "diff": "^7.0.0", + "docx": "^9.0.2", "dotenv-cli": "^7.4.2", "electron": 
"35.4.0", "electron-builder": "26.0.15", "electron-devtools-installer": "^3.2.0", + "electron-log": "^5.1.5", + "electron-store": "^8.2.0", + "electron-updater": "6.6.4", "electron-vite": "^3.1.0", + "electron-window-state": "^5.0.3", "emittery": "^1.0.3", "emoji-picker-element": "^1.22.1", + "epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch", "eslint": "^9.22.0", "eslint-plugin-react-hooks": "^5.2.0", "eslint-plugin-simple-import-sort": "^12.1.1", "eslint-plugin-unused-imports": "^4.1.4", "fast-diff": "^1.3.0", + "fast-xml-parser": "^5.2.0", + "framer-motion": "^12.17.3", + "franc-min": "^6.2.0", + "fs-extra": "^11.2.0", "html-to-image": "^1.11.13", "husky": "^9.1.7", "i18next": "^23.11.5", @@ -178,14 +171,18 @@ "lodash": "^4.17.21", "lru-cache": "^11.1.0", "lucide-react": "^0.487.0", + "markdown-it": "^14.1.0", "mermaid": "^11.6.0", "mime": "^4.0.4", "motion": "^12.10.5", + "node-stream-zip": "^1.15.0", "npx-scope-finder": "^1.2.0", + "officeparser": "^4.1.1", "openai": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch", "p-queue": "^8.1.0", "playwright": "^1.52.0", "prettier": "^3.5.3", + "proxy-agent": "^6.5.0", "rc-virtual-list": "^3.18.6", "react": "^19.0.0", "react-dom": "^19.0.0", @@ -206,17 +203,21 @@ "remark-cjk-friendly": "^1.1.0", "remark-gfm": "^4.0.0", "remark-math": "^6.0.0", + "remove-markdown": "^0.6.2", "rollup-plugin-visualizer": "^5.12.0", "sass": "^1.88.0", "shiki": "^3.4.2", "string-width": "^7.2.0", "styled-components": "^6.1.11", + "tar": "^7.4.3", "tiny-pinyin": "^1.3.2", "tokenx": "^0.4.1", "typescript": "^5.6.2", "uuid": "^10.0.0", "vite": "6.2.6", - "vitest": "^3.1.4" + "vitest": "^3.1.4", + "webdav": "^5.8.0", + "zipread": "^1.3.3" }, "resolutions": { "pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch", diff --git a/packages/shared/config/constant.ts b/packages/shared/config/constant.ts index cfba46df70..5a3465f648 100644 --- 
a/packages/shared/config/constant.ts +++ b/packages/shared/config/constant.ts @@ -408,3 +408,4 @@ export enum FeedUrl { PRODUCTION = 'https://releases.cherry-ai.com', EARLY_ACCESS = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download' } +export const defaultTimeout = 5 * 1000 * 60 diff --git a/src/renderer/src/aiCore/AI_CORE_DESIGN.md b/src/renderer/src/aiCore/AI_CORE_DESIGN.md new file mode 100644 index 0000000000..611c582d83 --- /dev/null +++ b/src/renderer/src/aiCore/AI_CORE_DESIGN.md @@ -0,0 +1,223 @@ +# Cherry Studio AI Provider 技术架构文档 (新方案) + +## 1. 核心设计理念与目标 + +本架构旨在重构 Cherry Studio 的 AI Provider(现称为 `aiCore`)层,以实现以下目标: + +- **职责清晰**:明确划分各组件的职责,降低耦合度。 +- **高度复用**:最大化业务逻辑和通用处理逻辑的复用,减少重复代码。 +- **易于扩展**:方便快捷地接入新的 AI Provider (LLM供应商) 和添加新的 AI 功能 (如翻译、摘要、图像生成等)。 +- **易于维护**:简化单个组件的复杂性,提高代码的可读性和可维护性。 +- **标准化**:统一内部数据流和接口,简化不同 Provider 之间的差异处理。 + +核心思路是将纯粹的 **SDK 适配层 (`XxxApiClient`)**、**通用逻辑处理与智能解析层 (中间件)** 以及 **统一业务功能入口层 (`AiCoreService`)** 清晰地分离开来。 + +## 2. 核心组件详解 + +### 2.1. `aiCore` (原 `AiProvider` 文件夹) + +这是整个 AI 功能的核心模块。 + +#### 2.1.1. `XxxApiClient` (例如 `aiCore/clients/openai/OpenAIApiClient.ts`) + +- **职责**:作为特定 AI Provider SDK 的纯粹适配层。 + - **参数适配**:将应用内部统一的 `CoreRequest` 对象 (见下文) 转换为特定 SDK 所需的请求参数格式。 + - **基础响应转换**:将 SDK 返回的原始数据块 (`RawSdkChunk`,例如 `OpenAI.Chat.Completions.ChatCompletionChunk`) 转换为一组最基础、最直接的应用层 `Chunk` 对象 (定义于 `src/renderer/src/types/chunk.ts`)。 + - 例如:SDK 的 `delta.content` -> `TextDeltaChunk`;SDK 的 `delta.reasoning_content` -> `ThinkingDeltaChunk`;SDK 的 `delta.tool_calls` -> `RawToolCallChunk` (包含原始工具调用数据)。 + - **关键**:`XxxApiClient` **不处理**耦合在文本内容中的复杂结构,如 `` 或 `` 标签。 +- **特点**:极度轻量化,代码量少,易于实现和维护新的 Provider 适配。 + +#### 2.1.2. `ApiClient.ts` (或 `BaseApiClient.ts` 的核心接口) + +- 定义了所有 `XxxApiClient` 必须实现的接口,如: + - `getSdkInstance(): Promise | TSdkInstance` + - `getRequestTransformer(): RequestTransformer` + - `getResponseChunkTransformer(): ResponseChunkTransformer` + - 其他可选的、与特定 Provider 相关的辅助方法 (如工具调用转换)。 + +#### 2.1.3. 
`ApiClientFactory.ts` + +- 根据 Provider 配置动态创建和返回相应的 `XxxApiClient` 实例。 + +#### 2.1.4. `AiCoreService.ts` (`aiCore/index.ts`) + +- **职责**:作为所有 AI 相关业务功能的统一入口。 + - 提供面向应用的高层接口,例如: + - `executeCompletions(params: CompletionsParams): Promise` + - `translateText(params: TranslateParams): Promise` + - `summarizeText(params: SummarizeParams): Promise` + - 未来可能的 `generateImage(prompt: string): Promise` 等。 + - **返回 `Promise`**:每个服务方法返回一个 `Promise`,该 `Promise` 会在整个(可能是流式的)操作完成后,以包含所有聚合结果(如完整文本、工具调用详情、最终的`usage`/`metrics`等)的对象来 `resolve`。 + - **支持流式回调**:服务方法的参数 (如 `CompletionsParams`) 依然包含 `onChunk` 回调,用于向调用方实时推送处理过程中的 `Chunk` 数据,实现流式UI更新。 + - **封装特定任务的提示工程 (Prompt Engineering)**: + - 例如,`translateText` 方法内部会构建一个包含特定翻译指令的 `CoreRequest`。 + - **编排和调用中间件链**:通过内部的 `MiddlewareBuilder` (参见 `middleware/BUILDER_USAGE.md`) 实例,根据调用的业务方法和参数,动态构建和组织合适的中间件序列,然后通过 `applyCompletionsMiddlewares` 等组合函数执行。 + - 获取 `ApiClient` 实例并将其注入到中间件上游的 `Context` 中。 + - **将 `Promise` 的 `resolve` 和 `reject` 函数传递给中间件链** (通过 `Context`),以便 `FinalChunkConsumerAndNotifierMiddleware` 可以在操作完成或发生错误时结束该 `Promise`。 +- **优势**: + - 业务逻辑(如翻译、摘要的提示构建和流程控制)只需实现一次,即可支持所有通过 `ApiClient` 接入的底层 Provider。 + - **支持外部编排**:调用方可以 `await` 服务方法以获取最终聚合结果,然后将此结果作为后续操作的输入,轻松实现多步骤工作流。 + - **支持内部组合**:服务自身也可以通过 `await` 调用其他原子服务方法来构建更复杂的组合功能。 + +#### 2.1.5. `coreRequestTypes.ts` (或 `types.ts`) + +- 定义核心的、Provider 无关的内部请求结构,例如: + - `CoreCompletionsRequest`: 包含标准化后的消息列表、模型配置、工具列表、最大Token数、是否流式输出等。 + - `CoreTranslateRequest`, `CoreSummarizeRequest` 等 (如果与 `CoreCompletionsRequest` 结构差异较大,否则可复用并添加任务类型标记)。 + +### 2.2. `middleware` + +中间件层负责处理请求和响应流中的通用逻辑和特定特性。其设计和使用遵循 `middleware/BUILDER_USAGE.md` 中定义的规范。 + +**核心组件包括:** + +- **`MiddlewareBuilder`**: 一个通用的、提供流式API的类,用于动态构建中间件链。它支持从基础链开始,根据条件添加、插入、替换或移除中间件。 +- **`applyCompletionsMiddlewares`**: 负责接收 `MiddlewareBuilder` 构建的链并按顺序执行,专门用于 Completions 流程。 +- **`MiddlewareRegistry`**: 集中管理所有可用中间件的注册表,提供统一的中间件访问接口。 +- **各种独立的中间件模块** (存放于 `common/`, `core/`, `feat/` 子目录)。 + +#### 2.2.1. 
`middlewareTypes.ts` + +- 定义中间件的核心类型,如 `AiProviderMiddlewareContext` (扩展后包含 `_apiClientInstance` 和 `_coreRequest`)、`MiddlewareAPI`、`CompletionsMiddleware` 等。 + +#### 2.2.2. 核心中间件 (`middleware/core/`) + +- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()` 将 `CoreRequest` 转换为特定 SDK 的参数,并存入上下文。 +- **`RequestExecutionMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。 +- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流 (如异步迭代器) 统一适配为 `ReadableStream`。 + - **`RawSdkChunk`**:指特定AI提供商SDK在流式响应中返回的、未经应用层统一处理的原始数据块格式 (例如 OpenAI 的 `ChatCompletionChunk`,Gemini 的 `GenerateContentResponse` 中的部分等)。 +- **`RawSdkChunkToAppChunkMiddleware.ts`**: (新增) 消费 `ReadableStream`,在其内部对每个 `RawSdkChunk` 调用 `ApiClient.getResponseChunkTransformer()`,将其转换为一个或多个基础的应用层 `Chunk` 对象,并输出 `ReadableStream`。 + +#### 2.2.3. 特性中间件 (`middleware/feat/`) + +这些中间件消费由 `ResponseTransformMiddleware` 输出的、相对标准化的 `Chunk` 流,并处理更复杂的逻辑。 + +- **`ThinkingTagExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `...` 文本内嵌标签,生成 `ThinkingDeltaChunk` 和 `ThinkingCompleteChunk`。 +- **`ToolUseExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `...` 文本内嵌标签,生成工具调用相关的 Chunk。如果 `ApiClient` 输出了原生工具调用数据,此中间件也负责将其转换为标准格式。 + +#### 2.2.4. 核心处理中间件 (`middleware/core/`) + +- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()` 将 `CoreRequest` 转换为特定 SDK 的参数,并存入上下文。 +- **`SdkCallMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。 +- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流统一适配为标准流格式。 +- **`ResponseTransformMiddleware.ts`**: 将原始 SDK 响应转换为应用层标准 `Chunk` 对象。 +- **`TextChunkMiddleware.ts`**: 处理文本相关的 Chunk 流。 +- **`ThinkChunkMiddleware.ts`**: 处理思考相关的 Chunk 流。 +- **`McpToolChunkMiddleware.ts`**: 处理工具调用相关的 Chunk 流。 +- **`WebSearchMiddleware.ts`**: 处理 Web 搜索相关逻辑。 + +#### 2.2.5. 
通用中间件 (`middleware/common/`) + +- **`LoggingMiddleware.ts`**: 请求和响应日志。 +- **`AbortHandlerMiddleware.ts`**: 处理请求中止。 +- **`FinalChunkConsumerMiddleware.ts`**: 消费最终的 `Chunk` 流,通过 `context.onChunk` 回调通知应用层实时数据。 + - **累积数据**:在流式处理过程中,累积关键数据,如文本片段、工具调用信息、`usage`/`metrics` 等。 + - **结束 `Promise`**:当输入流结束时,使用累积的聚合结果来完成整个处理流程。 + - 在流结束时,发送包含最终累加信息的完成信号。 + +### 2.3. `types/chunk.ts` + +- 定义应用全局统一的 `Chunk` 类型及其所有变体。这包括基础类型 (如 `TextDeltaChunk`, `ThinkingDeltaChunk`)、SDK原生数据传递类型 (如 `RawToolCallChunk`, `RawFinishChunk` - 作为 `ApiClient` 转换的中间产物),以及功能性类型 (如 `McpToolCallRequestChunk`, `WebSearchCompleteChunk`)。 + +## 3. 核心执行流程 (以 `AiCoreService.executeCompletions` 为例) + +```markdown +**应用层 (例如 UI 组件)** +|| +\\/ +**`AiProvider.completions` (`aiCore/index.ts`)** +(1. prepare ApiClient instance. 2. use `CompletionsMiddlewareBuilder.withDefaults()` to build middleware chain. 3. call `applyCompletionsMiddlewares`) +|| +\\/ +**`applyCompletionsMiddlewares` (`middleware/composer.ts`)** +(接收构建好的链、ApiClient实例、原始SDK方法,开始按序执行中间件) +|| +\\/ +**[ 预处理阶段中间件 ]** +(例如: `FinalChunkConsumerMiddleware`, `TransformCoreToSdkParamsMiddleware`, `AbortHandlerMiddleware`) +|| (Context 中准备好 SDK 请求参数) +\\/ +**[ 处理阶段中间件 ]** +(例如: `McpToolChunkMiddleware`, `WebSearchMiddleware`, `TextChunkMiddleware`, `ThinkingTagExtractionMiddleware`) +|| (处理各种特性和Chunk类型) +\\/ +**[ SDK调用阶段中间件 ]** +(例如: `ResponseTransformMiddleware`, `StreamAdapterMiddleware`, `SdkCallMiddleware`) +|| (输出: 标准化的应用层Chunk流) +\\/ +**`FinalChunkConsumerMiddleware` (核心)** +(消费最终的 `Chunk` 流, 通过 `context.onChunk` 回调通知应用层, 并在流结束时完成处理) +|| +\\/ +**`AiProvider.completions` 返回 `Promise`** +``` + +## 4. 
建议的文件/目录结构 + +``` +src/renderer/src/ +└── aiCore/ + ├── clients/ + │ ├── openai/ + │ ├── gemini/ + │ ├── anthropic/ + │ ├── BaseApiClient.ts + │ ├── ApiClientFactory.ts + │ ├── AihubmixAPIClient.ts + │ ├── index.ts + │ └── types.ts + ├── middleware/ + │ ├── common/ + │ ├── core/ + │ ├── feat/ + │ ├── builder.ts + │ ├── composer.ts + │ ├── index.ts + │ ├── register.ts + │ ├── schemas.ts + │ ├── types.ts + │ └── utils.ts + ├── types/ + │ ├── chunk.ts + │ └── ... + └── index.ts +``` + +## 5. 迁移和实施建议 + +- **小步快跑,逐步迭代**:优先完成核心流程的重构(例如 `completions`),再逐步迁移其他功能(`translate` 等)和其他 Provider。 +- **优先定义核心类型**:`CoreRequest`, `Chunk`, `ApiClient` 接口是整个架构的基石。 +- **为 `ApiClient` 瘦身**:将现有 `XxxProvider` 中的复杂逻辑剥离到新的中间件或 `AiCoreService` 中。 +- **强化中间件**:让中间件承担起更多解析和特性处理的责任。 +- **编写单元测试和集成测试**:确保每个组件和整体流程的正确性。 + +此架构旨在提供一个更健壮、更灵活、更易于维护的 AI 功能核心,支撑 Cherry Studio 未来的发展。 + +## 6. 迁移策略与实施建议 + +本节内容提炼自早期的 `migrate.md` 文档,并根据最新的架构讨论进行了调整。 + +**目标架构核心组件回顾:** + +与第 2 节描述的核心组件一致,主要包括 `XxxApiClient`, `AiCoreService`, 中间件链, `CoreRequest` 类型, 和标准化的 `Chunk` 类型。 + +**迁移步骤:** + +**Phase 0: 准备工作和类型定义** + +1. **定义核心数据结构 (TypeScript 类型):** + - `CoreCompletionsRequest` (Type):定义应用内部统一的对话请求结构。 + - `Chunk` (Type - 检查并按需扩展现有 `src/renderer/src/types/chunk.ts`):定义所有可能的通用Chunk类型。 + - 为其他API(翻译、总结)定义类似的 `CoreXxxRequest` (Type)。 +2. **定义 `ApiClient` 接口:** 明确 `getRequestTransformer`, `getResponseChunkTransformer`, `getSdkInstance` 等核心方法。 +3. **调整 `AiProviderMiddlewareContext`:** + - 确保包含 `_apiClientInstance: ApiClient`。 + - 确保包含 `_coreRequest: CoreRequestType`。 + - 考虑添加 `resolvePromise: (value: AggregatedResultType) => void` 和 `rejectPromise: (reason?: any) => void` 用于 `AiCoreService` 的 Promise 返回。 + +**Phase 1: 实现第一个 `ApiClient` (以 `OpenAIApiClient` 为例)** + +1. **创建 `OpenAIApiClient` 类:** 实现 `ApiClient` 接口。 +2. **迁移SDK实例和配置。** +3. **实现 `getRequestTransformer()`:** 将 `CoreCompletionsRequest` 转换为 OpenAI SDK 参数。 +4. 
**实现 `getResponseChunkTransformer()`:** 将 `OpenAI.Chat.Completions.ChatCompletionChunk` 转换为基础的 ` diff --git a/src/renderer/src/aiCore/clients/AihubmixAPIClient.ts b/src/renderer/src/aiCore/clients/AihubmixAPIClient.ts new file mode 100644 index 0000000000..3aa0bc4263 --- /dev/null +++ b/src/renderer/src/aiCore/clients/AihubmixAPIClient.ts @@ -0,0 +1,207 @@ +import { isOpenAILLMModel } from '@renderer/config/models' +import { + GenerateImageParams, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse +} from '@renderer/types' +import { + RequestOptions, + SdkInstance, + SdkMessageParam, + SdkModel, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' + +import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient' +import { BaseApiClient } from './BaseApiClient' +import { GeminiAPIClient } from './gemini/GeminiAPIClient' +import { OpenAIAPIClient } from './openai/OpenAIApiClient' +import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient' +import { RequestTransformer, ResponseChunkTransformer } from './types' + +/** + * AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient + * 使用装饰器模式实现,在ApiClient层面进行模型路由 + */ +export class AihubmixAPIClient extends BaseApiClient { + // 使用联合类型而不是any,保持类型安全 + private clients: Map = + new Map() + private defaultClient: OpenAIAPIClient + private currentClient: BaseApiClient + + constructor(provider: Provider) { + super(provider) + + // 初始化各个client - 现在有类型安全 + const claudeClient = new AnthropicAPIClient(provider) + const geminiClient = new GeminiAPIClient({ ...provider, apiHost: 'https://aihubmix.com/gemini' }) + const openaiClient = new OpenAIResponseAPIClient(provider) + const defaultClient = new OpenAIAPIClient(provider) + + this.clients.set('claude', claudeClient) + this.clients.set('gemini', geminiClient) + this.clients.set('openai', openaiClient) + this.clients.set('default', defaultClient) + + // 设置默认client + this.defaultClient = 
defaultClient + this.currentClient = this.defaultClient as BaseApiClient + } + + /** + * 类型守卫:确保client是BaseApiClient的实例 + */ + private isValidClient(client: unknown): client is BaseApiClient { + return ( + client !== null && + client !== undefined && + typeof client === 'object' && + 'createCompletions' in client && + 'getRequestTransformer' in client && + 'getResponseChunkTransformer' in client + ) + } + + /** + * 根据模型获取合适的client + */ + private getClient(model: Model): BaseApiClient { + const id = model.id.toLowerCase() + + // claude开头 + if (id.startsWith('claude')) { + const client = this.clients.get('claude') + if (!client || !this.isValidClient(client)) { + throw new Error('Claude client not properly initialized') + } + return client + } + + // gemini开头 且不以-nothink、-search结尾 + if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) { + const client = this.clients.get('gemini') + if (!client || !this.isValidClient(client)) { + throw new Error('Gemini client not properly initialized') + } + return client + } + + // OpenAI系列模型 + if (isOpenAILLMModel(model)) { + const client = this.clients.get('openai') + if (!client || !this.isValidClient(client)) { + throw new Error('OpenAI client not properly initialized') + } + return client + } + + return this.defaultClient as BaseApiClient + } + + /** + * 根据模型选择合适的client并委托调用 + */ + public getClientForModel(model: Model): BaseApiClient { + this.currentClient = this.getClient(model) + return this.currentClient + } + + // ============ BaseApiClient 抽象方法实现 ============ + + async createCompletions(payload: SdkParams, options?: RequestOptions): Promise { + // 尝试从payload中提取模型信息来选择client + const modelId = this.extractModelFromPayload(payload) + if (modelId) { + const modelObj = { id: modelId } as Model + const targetClient = this.getClient(modelObj) + return targetClient.createCompletions(payload, options) + } + + // 如果无法从payload中提取模型,使用当前设置的client + return 
this.currentClient.createCompletions(payload, options) + } + + /** + * 从SDK payload中提取模型ID + */ + private extractModelFromPayload(payload: SdkParams): string | null { + // 不同的SDK可能有不同的字段名 + if ('model' in payload && typeof payload.model === 'string') { + return payload.model + } + return null + } + + async generateImage(params: GenerateImageParams): Promise { + return this.currentClient.generateImage(params) + } + + async getEmbeddingDimensions(model?: Model): Promise { + const client = model ? this.getClient(model) : this.currentClient + return client.getEmbeddingDimensions(model) + } + + async listModels(): Promise { + // 可以聚合所有client的模型,或者使用默认client + return this.defaultClient.listModels() + } + + async getSdkInstance(): Promise { + return this.currentClient.getSdkInstance() + } + + getRequestTransformer(): RequestTransformer { + return this.currentClient.getRequestTransformer() + } + + getResponseChunkTransformer(): ResponseChunkTransformer { + return this.currentClient.getResponseChunkTransformer() + } + + convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] { + return this.currentClient.convertMcpToolsToSdkTools(mcpTools) + } + + convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { + return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools) + } + + convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse { + return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool) + } + + buildSdkMessages( + currentReqMessages: SdkMessageParam[], + output: SdkRawOutput | string, + toolResults: SdkMessageParam[], + toolCalls?: SdkToolCall[] + ): SdkMessageParam[] { + return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls) + } + + convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): SdkMessageParam | undefined { + const client = this.getClient(model) + 
return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model) + } + + extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] { + return this.currentClient.extractMessagesFromSdkPayload(sdkPayload) + } + + estimateMessageTokens(message: SdkMessageParam): number { + return this.currentClient.estimateMessageTokens(message) + } +} diff --git a/src/renderer/src/aiCore/clients/ApiClientFactory.ts b/src/renderer/src/aiCore/clients/ApiClientFactory.ts new file mode 100644 index 0000000000..0d8a97ed51 --- /dev/null +++ b/src/renderer/src/aiCore/clients/ApiClientFactory.ts @@ -0,0 +1,62 @@ +import { Provider } from '@renderer/types' + +import { AihubmixAPIClient } from './AihubmixAPIClient' +import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient' +import { BaseApiClient } from './BaseApiClient' +import { GeminiAPIClient } from './gemini/GeminiAPIClient' +import { OpenAIAPIClient } from './openai/OpenAIApiClient' +import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient' + +/** + * Factory for creating ApiClient instances based on provider configuration + * 根据提供者配置创建ApiClient实例的工厂 + */ +export class ApiClientFactory { + /** + * Create an ApiClient instance for the given provider + * 为给定的提供者创建ApiClient实例 + */ + static create(provider: Provider): BaseApiClient { + console.log(`[ApiClientFactory] Creating ApiClient for provider:`, { + id: provider.id, + type: provider.type + }) + + let instance: BaseApiClient + + // 首先检查特殊的provider id + if (provider.id === 'aihubmix') { + console.log(`[ApiClientFactory] Creating AihubmixAPIClient for provider: ${provider.id}`) + instance = new AihubmixAPIClient(provider) as BaseApiClient + return instance + } + + // 然后检查标准的provider type + switch (provider.type) { + case 'openai': + case 'azure-openai': + console.log(`[ApiClientFactory] Creating OpenAIApiClient for provider: ${provider.id}`) + instance = new OpenAIAPIClient(provider) as BaseApiClient + break + case 
'openai-response': + instance = new OpenAIResponseAPIClient(provider) as BaseApiClient + break + case 'gemini': + instance = new GeminiAPIClient(provider) as BaseApiClient + break + case 'anthropic': + instance = new AnthropicAPIClient(provider) as BaseApiClient + break + default: + console.log(`[ApiClientFactory] Using default OpenAIApiClient for provider: ${provider.id}`) + instance = new OpenAIAPIClient(provider) as BaseApiClient + break + } + + return instance + } +} + +export function isOpenAIProvider(provider: Provider) { + return !['anthropic', 'gemini'].includes(provider.type) +} diff --git a/src/renderer/src/providers/AiProvider/BaseProvider.ts b/src/renderer/src/aiCore/clients/BaseApiClient.ts similarity index 53% rename from src/renderer/src/providers/AiProvider/BaseProvider.ts rename to src/renderer/src/aiCore/clients/BaseApiClient.ts index f1beba64d4..19e455026d 100644 --- a/src/renderer/src/providers/AiProvider/BaseProvider.ts +++ b/src/renderer/src/aiCore/clients/BaseApiClient.ts @@ -1,40 +1,69 @@ -import Logger from '@renderer/config/logger' -import { isFunctionCallingModel, isNotSupportTemperatureAndTopP } from '@renderer/config/models' +import { + isFunctionCallingModel, + isNotSupportTemperatureAndTopP, + isOpenAIModel, + isSupportedFlexServiceTier +} from '@renderer/config/models' import { REFERENCE_PROMPT } from '@renderer/config/prompts' import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio' -import type { +import { getStoreSetting } from '@renderer/hooks/useSettings' +import { SettingsState } from '@renderer/store/settings' +import { Assistant, + FileTypes, GenerateImageParams, KnowledgeReference, MCPCallToolResponse, MCPTool, MCPToolResponse, Model, + OpenAIServiceTier, Provider, - Suggestion, + ToolCallResponse, WebSearchProviderResponse, WebSearchResponse } from '@renderer/types' -import { ChunkType } from '@renderer/types/chunk' -import type { Message } from '@renderer/types/newMessage' -import { delay, isJSON, parseJSON } 
from '@renderer/utils' +import { Message } from '@renderer/types/newMessage' +import { + RequestOptions, + SdkInstance, + SdkMessageParam, + SdkModel, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' +import { isJSON, parseJSON } from '@renderer/utils' import { addAbortController, removeAbortController } from '@renderer/utils/abortController' -import { formatApiHost } from '@renderer/utils/api' -import { getMainTextContent } from '@renderer/utils/messageUtils/find' +import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' +import { defaultTimeout } from '@shared/config/constant' +import Logger from 'electron-log/renderer' import { isEmpty } from 'lodash' -import type OpenAI from 'openai' -import type { CompletionsParams } from '.' +import { ApiClient, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from './types' -export default abstract class BaseProvider { - // Threshold for determining whether to use system prompt for tools +/** + * Abstract base class for API clients. + * Provides common functionality and structure for specific client implementations. 
+ */ +export abstract class BaseApiClient< + TSdkInstance extends SdkInstance = SdkInstance, + TSdkParams extends SdkParams = SdkParams, + TRawOutput extends SdkRawOutput = SdkRawOutput, + TRawChunk extends SdkRawChunk = SdkRawChunk, + TMessageParam extends SdkMessageParam = SdkMessageParam, + TToolCall extends SdkToolCall = SdkToolCall, + TSdkSpecificTool extends SdkTool = SdkTool +> implements ApiClient +{ private static readonly SYSTEM_PROMPT_THRESHOLD: number = 128 - - protected provider: Provider + public provider: Provider protected host: string protected apiKey: string - - protected useSystemPromptForTools: boolean = true + protected sdkInstance?: TSdkInstance + public useSystemPromptForTools: boolean = true constructor(provider: Provider) { this.provider = provider @@ -42,32 +71,81 @@ export default abstract class BaseProvider { this.apiKey = this.getApiKey() } - abstract completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise - abstract translate( - content: string, - assistant: Assistant, - onResponse?: (text: string, isComplete: boolean) => void - ): Promise - abstract summaries(messages: Message[], assistant: Assistant): Promise - abstract summaryForSearch(messages: Message[], assistant: Assistant): Promise - abstract suggestions(messages: Message[], assistant: Assistant): Promise - abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise - abstract check(model: Model, stream: boolean): Promise<{ valid: boolean; error: Error | null }> - abstract models(): Promise - abstract generateImage(params: GenerateImageParams): Promise - abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise - // 由于现在出现了一些能够选择嵌入维度的嵌入模型,这个不考虑dimensions参数的方法将只能应用于那些不支持dimensions的模型 - abstract getEmbeddingDimensions(model: Model): Promise - public abstract convertMcpTools(mcpTools: MCPTool[]): T[] - public abstract mcpToolCallResponseToMessage( + // // 
核心的completions方法 - 在中间件架构中,这通常只是一个占位符 + // abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise + + /** + * 核心API Endpoint + **/ + + abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise + + abstract generateImage(generateImageParams: GenerateImageParams): Promise + + abstract getEmbeddingDimensions(model?: Model): Promise + + abstract listModels(): Promise + + abstract getSdkInstance(): Promise | TSdkInstance + + /** + * 中间件 + **/ + + // 在 CoreRequestToSdkParamsMiddleware中使用 + abstract getRequestTransformer(): RequestTransformer + // 在RawSdkChunkToGenericChunkMiddleware中使用 + abstract getResponseChunkTransformer(): ResponseChunkTransformer + + /** + * 工具转换 + **/ + + // Optional tool conversion methods - implement if needed by the specific provider + abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[] + + abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined + + abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse + + abstract buildSdkMessages( + currentReqMessages: TMessageParam[], + output: TRawOutput | string, + toolResults: TMessageParam[], + toolCalls?: TToolCall[] + ): TMessageParam[] + + abstract estimateMessageTokens(message: TMessageParam): number + + abstract convertMcpToolResponseToSdkMessageParam( mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model - ): any + ): TMessageParam | undefined + + /** + * 从SDK载荷中提取消息数组(用于中间件中的类型安全访问) + * 不同的提供商可能使用不同的字段名(如messages、history等) + */ + abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[] + + /** + * 附加原始流监听器 + */ + public attachRawStreamListener>( + rawOutput: TRawOutput, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + _listener: TListener + ): TRawOutput { + return rawOutput + } + + /** + * 通用函数 + **/ public getBaseURL(): string { - const host = this.provider.apiHost - return 
formatApiHost(host) + return this.provider.apiHost } public getApiKey() { @@ -112,14 +190,32 @@ export default abstract class BaseProvider { return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP } - public async fakeCompletions({ onChunk }: CompletionsParams) { - for (let i = 0; i < 100; i++) { - await delay(0.01) - onChunk({ - response: { text: i + '\n', usage: { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0 } }, - type: ChunkType.BLOCK_COMPLETE - }) + protected getServiceTier(model: Model) { + if (!isOpenAIModel(model) || model.provider === 'github' || model.provider === 'copilot') { + return undefined } + + const openAI = getStoreSetting('openAI') as SettingsState['openAI'] + let serviceTier = 'auto' as OpenAIServiceTier + + if (openAI && openAI?.serviceTier === 'flex') { + if (isSupportedFlexServiceTier(model)) { + serviceTier = 'flex' + } else { + serviceTier = 'auto' + } + } else { + serviceTier = openAI.serviceTier + } + + return serviceTier + } + + protected getTimeout(model: Model) { + if (isSupportedFlexServiceTier(model)) { + return 15 * 1000 * 60 + } + return defaultTimeout } public async getMessageContent(message: Message): Promise { @@ -149,6 +245,36 @@ export default abstract class BaseProvider { return content } + /** + * Extract the file content from the message + * @param message - The message + * @returns The file content + */ + protected async extractFileContent(message: Message) { + const fileBlocks = findFileBlocks(message) + if (fileBlocks.length > 0) { + const textFileBlocks = fileBlocks.filter( + (fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type) + ) + + if (textFileBlocks.length > 0) { + let text = '' + const divider = '\n\n---\n\n' + + for (const fileBlock of textFileBlocks) { + const file = fileBlock.file + const fileContent = (await window.api.file.read(file.id + file.ext)).trim() + const fileNameRow = 'file: ' + file.origin_name + '\n\n' + text = text + fileNameRow + 
fileContent + divider + } + + return text + } + } + + return '' + } + private async getWebSearchReferencesFromCache(message: Message) { const content = getMainTextContent(message) if (isEmpty(content)) { @@ -210,7 +336,7 @@ export default abstract class BaseProvider { ) } - protected createAbortController(messageId?: string, isAddEventListener?: boolean) { + public createAbortController(messageId?: string, isAddEventListener?: boolean) { const abortController = new AbortController() const abortFn = () => abortController.abort() @@ -256,11 +382,11 @@ export default abstract class BaseProvider { } // Setup tools configuration based on provided parameters - protected setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): { - tools: T[] + public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): { + tools: TSdkSpecificTool[] } { const { mcpTools, model, enableToolUse } = params - let tools: T[] = [] + let tools: TSdkSpecificTool[] = [] // If there are no tools, return an empty array if (!mcpTools?.length) { @@ -268,14 +394,14 @@ export default abstract class BaseProvider { } // If the number of tools exceeds the threshold, use the system prompt - if (mcpTools.length > BaseProvider.SYSTEM_PROMPT_THRESHOLD) { + if (mcpTools.length > BaseApiClient.SYSTEM_PROMPT_THRESHOLD) { this.useSystemPromptForTools = true return { tools } } // If the model supports function calling and tool usage is enabled if (isFunctionCallingModel(model) && enableToolUse) { - tools = this.convertMcpTools(mcpTools) + tools = this.convertMcpToolsToSdkTools(mcpTools) this.useSystemPromptForTools = false } diff --git a/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts b/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts new file mode 100644 index 0000000000..ffbda737c6 --- /dev/null +++ b/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts @@ -0,0 +1,714 @@ +import Anthropic from 
'@anthropic-ai/sdk' +import { + Base64ImageSource, + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlock, + WebSearchTool20250305 +} from '@anthropic-ai/sdk/resources' +import { + ContentBlock, + ContentBlockParam, + MessageCreateParams, + MessageCreateParamsBase, + RedactedThinkingBlockParam, + ServerToolUseBlockParam, + ThinkingBlockParam, + ThinkingConfigParam, + ToolUnion, + ToolUseBlockParam, + WebSearchResultBlock, + WebSearchToolResultBlockParam, + WebSearchToolResultError +} from '@anthropic-ai/sdk/resources/messages' +import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages' +import { GenericChunk } from '@renderer/aiCore/middleware/schemas' +import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' +import Logger from '@renderer/config/logger' +import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models' +import { getAssistantSettings } from '@renderer/services/AssistantService' +import FileManager from '@renderer/services/FileManager' +import { estimateTextTokens } from '@renderer/services/TokenService' +import { + Assistant, + EFFORT_RATIO, + FileTypes, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse, + WebSearchSource +} from '@renderer/types' +import { + ChunkType, + ErrorChunk, + LLMWebSearchCompleteChunk, + LLMWebSearchInProgressChunk, + MCPToolCreatedChunk, + TextDeltaChunk, + ThinkingDeltaChunk +} from '@renderer/types/chunk' +import type { Message } from '@renderer/types/newMessage' +import { + AnthropicSdkMessageParam, + AnthropicSdkParams, + AnthropicSdkRawChunk, + AnthropicSdkRawOutput +} from '@renderer/types/sdk' +import { addImageFileToContents } from '@renderer/utils/formats' +import { + anthropicToolUseToMcpTool, + isEnabledToolUse, + mcpToolCallResponseToAnthropicMessage, + mcpToolsToAnthropicTools +} from '@renderer/utils/mcp-tools' +import { findFileBlocks, findImageBlocks, 
getMainTextContent } from '@renderer/utils/messageUtils/find' +import { buildSystemPrompt } from '@renderer/utils/prompt' + +import { BaseApiClient } from '../BaseApiClient' +import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types' + +export class AnthropicAPIClient extends BaseApiClient< + Anthropic, + AnthropicSdkParams, + AnthropicSdkRawOutput, + AnthropicSdkRawChunk, + AnthropicSdkMessageParam, + ToolUseBlock, + ToolUnion +> { + constructor(provider: Provider) { + super(provider) + } + + async getSdkInstance(): Promise { + if (this.sdkInstance) { + return this.sdkInstance + } + this.sdkInstance = new Anthropic({ + apiKey: this.getApiKey(), + baseURL: this.getBaseURL(), + dangerouslyAllowBrowser: true, + defaultHeaders: { + 'anthropic-beta': 'output-128k-2025-02-19' + } + }) + return this.sdkInstance + } + + override async createCompletions( + payload: AnthropicSdkParams, + options?: Anthropic.RequestOptions + ): Promise { + const sdk = await this.getSdkInstance() + if (payload.stream) { + return sdk.messages.stream(payload, options) + } + return await sdk.messages.create(payload, options) + } + + // @ts-ignore sdk未提供 + // eslint-disable-next-line @typescript-eslint/no-unused-vars + override async generateImage(generateImageParams: GenerateImageParams): Promise { + return [] + } + + override async listModels(): Promise { + const sdk = await this.getSdkInstance() + const response = await sdk.models.list() + return response.data + } + + // @ts-ignore sdk未提供 + override async getEmbeddingDimensions(): Promise { + return 0 + } + + override getTemperature(assistant: Assistant, model: Model): number | undefined { + if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { + return undefined + } + return assistant.settings?.temperature + } + + override getTopP(assistant: Assistant, model: Model): number | undefined { + if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) 
{ + return undefined + } + return assistant.settings?.topP + } + + /** + * Get the reasoning effort + * @param assistant - The assistant + * @param model - The model + * @returns The reasoning effort + */ + private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined { + if (!isReasoningModel(model)) { + return undefined + } + const { maxTokens } = getAssistantSettings(assistant) + + const reasoningEffort = assistant?.settings?.reasoning_effort + + if (reasoningEffort === undefined) { + return { + type: 'disabled' + } + } + + const effortRatio = EFFORT_RATIO[reasoningEffort] + + const budgetTokens = Math.max( + 1024, + Math.floor( + Math.min( + (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + + findTokenLimit(model.id)?.min!, + (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio + ) + ) + ) + + return { + type: 'enabled', + budget_tokens: budgetTokens + } + } + + /** + * Get the message parameter + * @param message - The message + * @param model - The model + * @returns The message parameter + */ + public async convertMessageToSdkParam(message: Message): Promise { + const parts: MessageParam['content'] = [ + { + type: 'text', + text: getMainTextContent(message) + } + ] + + // Get and process image blocks + const imageBlocks = findImageBlocks(message) + for (const imageBlock of imageBlocks) { + if (imageBlock.file) { + // Handle uploaded file + const file = imageBlock.file + const base64Data = await window.api.file.base64Image(file.id + file.ext) + parts.push({ + type: 'image', + source: { + data: base64Data.base64, + media_type: base64Data.mime.replace('jpg', 'jpeg') as any, + type: 'base64' + } + }) + } + } + // Get and process file blocks + const fileBlocks = findFileBlocks(message) + for (const fileBlock of fileBlocks) { + const { file } = fileBlock + if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { + if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) { + const base64Data = await 
FileManager.readBase64File(file) + parts.push({ + type: 'document', + source: { + type: 'base64', + media_type: 'application/pdf', + data: base64Data + } + }) + } else { + const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() + parts.push({ + type: 'text', + text: file.origin_name + '\n' + fileContent + }) + } + } + } + + return { + role: message.role === 'system' ? 'user' : message.role, + content: parts + } + } + + public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] { + return mcpToolsToAnthropicTools(mcpTools) + } + + public convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): AnthropicSdkMessageParam | undefined { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model) + } else if ('toolCallId' in mcpToolResponse) { + return { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: mcpToolResponse.toolCallId!, + content: resp.content + .map((item) => { + if (item.type === 'text') { + return { + type: 'text', + text: item.text || '' + } satisfies TextBlockParam + } + if (item.type === 'image') { + return { + type: 'image', + source: { + data: item.data || '', + media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'], + type: 'base64' + } + } satisfies ImageBlockParam + } + return + }) + .filter((n) => typeof n !== 'undefined'), + is_error: resp.isError + } satisfies ToolResultBlockParam + ] + } + } + return + } + + // Implementing abstract methods from BaseApiClient + convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined { + // Based on anthropicToolUseToMcpTool logic in AnthropicProvider + // This might need adjustment based on how tool calls are specifically handled in the new structure + const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall) + return mcpTool + } + + 
convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse { + return { + id: toolCall.id, + toolCallId: toolCall.id, + tool: mcpTool, + arguments: toolCall.input as Record, + status: 'pending' + } as ToolCallResponse + } + + override buildSdkMessages( + currentReqMessages: AnthropicSdkMessageParam[], + output: Anthropic.Message, + toolResults: AnthropicSdkMessageParam[] + ): AnthropicSdkMessageParam[] { + const assistantMessage: AnthropicSdkMessageParam = { + role: output.role, + content: convertContentBlocksToParams(output.content) + } + + const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage] + if (toolResults && toolResults.length > 0) { + newMessages.push(...toolResults) + } + return newMessages + } + + override estimateMessageTokens(message: AnthropicSdkMessageParam): number { + if (typeof message.content === 'string') { + return estimateTextTokens(message.content) + } + return message.content + .map((content) => { + switch (content.type) { + case 'text': + return estimateTextTokens(content.text) + case 'image': + if (content.source.type === 'base64') { + return estimateTextTokens(content.source.data) + } else { + return estimateTextTokens(content.source.url) + } + case 'tool_use': + return estimateTextTokens(JSON.stringify(content.input)) + case 'tool_result': + return estimateTextTokens(JSON.stringify(content.content)) + default: + return 0 + } + }) + .reduce((acc, curr) => acc + curr, 0) + } + + public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam { + const messageParam: AnthropicSdkMessageParam = { + role: message.role, + content: convertContentBlocksToParams(message.content) + } + return messageParam + } + + public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] { + return sdkPayload.messages || [] + } + + /** + * Anthropic专用的原始流监听器 + * 处理MessageStream对象的特定事件 + */ + override attachRawStreamListener( + rawOutput: 
AnthropicSdkRawOutput, + listener: RawStreamListener + ): AnthropicSdkRawOutput { + console.log(`[AnthropicApiClient] 附加流监听器到原始输出`) + + // 检查是否为MessageStream + if (rawOutput instanceof MessageStream) { + console.log(`[AnthropicApiClient] 检测到 Anthropic MessageStream,附加专用监听器`) + + if (listener.onStart) { + listener.onStart() + } + + if (listener.onChunk) { + rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => { + listener.onChunk!(event) + }) + } + + // 专用的Anthropic事件处理 + const anthropicListener = listener as AnthropicStreamListener + + if (anthropicListener.onContentBlock) { + rawOutput.on('contentBlock', anthropicListener.onContentBlock) + } + + if (anthropicListener.onMessage) { + rawOutput.on('finalMessage', anthropicListener.onMessage) + } + + if (listener.onEnd) { + rawOutput.on('end', () => { + listener.onEnd!() + }) + } + + if (listener.onError) { + rawOutput.on('error', (error: Error) => { + listener.onError!(error) + }) + } + + return rawOutput + } + + // 对于非MessageStream响应 + return rawOutput + } + + private async getWebSearchParams(model: Model): Promise { + if (!isWebSearchModel(model)) { + return undefined + } + return { + type: 'web_search_20250305', + name: 'web_search', + max_uses: 5 + } as WebSearchTool20250305 + } + + getRequestTransformer(): RequestTransformer { + return { + transform: async ( + coreRequest, + assistant, + model, + isRecursiveCall, + recursiveSdkMessages + ): Promise<{ + payload: AnthropicSdkParams + messages: AnthropicSdkMessageParam[] + metadata: Record + }> => { + const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest + // 1. 处理系统消息 + let systemPrompt = assistant.prompt + + // 2. 设置工具 + const { tools } = this.setupToolsConfig({ + mcpTools: mcpTools, + model, + enableToolUse: isEnabledToolUse(assistant) + }) + + if (this.useSystemPromptForTools) { + systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools) + } + + const systemMessage: TextBlockParam | undefined = systemPrompt + ? 
{ type: 'text', text: systemPrompt } + : undefined + + // 3. 处理用户消息 + const sdkMessages: AnthropicSdkMessageParam[] = [] + if (typeof messages === 'string') { + sdkMessages.push({ role: 'user', content: messages }) + } else { + const processedMessages = addImageFileToContents(messages) + for (const message of processedMessages) { + sdkMessages.push(await this.convertMessageToSdkParam(message)) + } + } + + if (enableWebSearch) { + const webSearchTool = await this.getWebSearchParams(model) + if (webSearchTool) { + tools.push(webSearchTool) + } + } + + const commonParams: MessageCreateParamsBase = { + model: model.id, + messages: + isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 + ? recursiveSdkMessages + : sdkMessages, + max_tokens: maxTokens || DEFAULT_MAX_TOKENS, + temperature: this.getTemperature(assistant, model), + top_p: this.getTopP(assistant, model), + system: systemMessage ? [systemMessage] : undefined, + thinking: this.getBudgetToken(assistant, model), + tools: tools.length > 0 ? tools : undefined, + ...this.getCustomParameters(assistant) + } + + const finalParams: MessageCreateParams = streamOutput + ? 
{ + ...commonParams, + stream: true + } + : { + ...commonParams, + stream: false + } + + const timeout = this.getTimeout(model) + return { payload: finalParams, messages: sdkMessages, metadata: { timeout } } + } + } + } + + getResponseChunkTransformer(): ResponseChunkTransformer { + return () => { + let accumulatedJson = '' + const toolCalls: Record = {} + + return { + async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController) { + switch (rawChunk.type) { + case 'message': { + for (const content of rawChunk.content) { + switch (content.type) { + case 'text': { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: content.text + } as TextDeltaChunk) + break + } + case 'tool_use': { + toolCalls[0] = content + break + } + case 'thinking': { + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: content.thinking + } as ThinkingDeltaChunk) + break + } + case 'web_search_tool_result': { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: content.content, + source: WebSearchSource.ANTHROPIC + } + } as LLMWebSearchCompleteChunk) + break + } + } + } + break + } + case 'content_block_start': { + const contentBlock = rawChunk.content_block + switch (contentBlock.type) { + case 'server_tool_use': { + if (contentBlock.name === 'web_search') { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS + } as LLMWebSearchInProgressChunk) + } + break + } + case 'web_search_tool_result': { + if ( + contentBlock.content && + (contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error' + ) { + controller.enqueue({ + type: ChunkType.ERROR, + error: { + code: (contentBlock.content as WebSearchToolResultError).error_code, + message: (contentBlock.content as WebSearchToolResultError).error_code + } + } as ErrorChunk) + } else { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: contentBlock.content as Array, 
+ source: WebSearchSource.ANTHROPIC + } + } as LLMWebSearchCompleteChunk) + } + break + } + case 'tool_use': { + toolCalls[rawChunk.index] = contentBlock + break + } + } + break + } + case 'content_block_delta': { + const messageDelta = rawChunk.delta + switch (messageDelta.type) { + case 'text_delta': { + if (messageDelta.text) { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: messageDelta.text + } as TextDeltaChunk) + } + break + } + case 'thinking_delta': { + if (messageDelta.thinking) { + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: messageDelta.thinking + } as ThinkingDeltaChunk) + } + break + } + case 'input_json_delta': { + if (messageDelta.partial_json) { + accumulatedJson += messageDelta.partial_json + } + break + } + } + break + } + case 'content_block_stop': { + const toolCall = toolCalls[rawChunk.index] + if (toolCall) { + try { + toolCall.input = JSON.parse(accumulatedJson) + Logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`) + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: [toolCall] + } as MCPToolCreatedChunk) + } catch (error) { + Logger.error(`Error parsing tool call input: ${error}`) + } + } + break + } + case 'message_delta': { + controller.enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: rawChunk.usage.input_tokens || 0, + completion_tokens: rawChunk.usage.output_tokens || 0, + total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0) + } + } + }) + } + } + } + } + } + } +} + +/** + * 将 ContentBlock 数组转换为 ContentBlockParam 数组 + * 去除服务器生成的额外字段,只保留发送给API所需的字段 + */ +function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] { + return contentBlocks.map((block): ContentBlockParam => { + switch (block.type) { + case 'text': + // TextBlock -> TextBlockParam,去除 citations 等服务器字段 + return { + type: 'text', + text: block.text + } satisfies TextBlockParam + case 
'tool_use': + // ToolUseBlock -> ToolUseBlockParam + return { + type: 'tool_use', + id: block.id, + name: block.name, + input: block.input + } satisfies ToolUseBlockParam + case 'thinking': + // ThinkingBlock -> ThinkingBlockParam + return { + type: 'thinking', + thinking: block.thinking, + signature: block.signature + } satisfies ThinkingBlockParam + case 'redacted_thinking': + // RedactedThinkingBlock -> RedactedThinkingBlockParam + return { + type: 'redacted_thinking', + data: block.data + } satisfies RedactedThinkingBlockParam + case 'server_tool_use': + // ServerToolUseBlock -> ServerToolUseBlockParam + return { + type: 'server_tool_use', + id: block.id, + name: block.name, + input: block.input + } satisfies ServerToolUseBlockParam + case 'web_search_tool_result': + // WebSearchToolResultBlock -> WebSearchToolResultBlockParam + return { + type: 'web_search_tool_result', + tool_use_id: block.tool_use_id, + content: block.content + } satisfies WebSearchToolResultBlockParam + default: + return block as ContentBlockParam + } + }) +} diff --git a/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts b/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts new file mode 100644 index 0000000000..bc848df7f9 --- /dev/null +++ b/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts @@ -0,0 +1,786 @@ +import { + Content, + File, + FileState, + FunctionCall, + GenerateContentConfig, + GenerateImagesConfig, + GoogleGenAI, + HarmBlockThreshold, + HarmCategory, + Modality, + Model as GeminiModel, + Pager, + Part, + SafetySetting, + SendMessageParameters, + ThinkingConfig, + Tool +} from '@google/genai' +import { nanoid } from '@reduxjs/toolkit' +import { GenericChunk } from '@renderer/aiCore/middleware/schemas' +import { + findTokenLimit, + GEMINI_FLASH_MODEL_REGEX, + isGeminiReasoningModel, + isGemmaModel, + isVisionModel +} from '@renderer/config/models' +import { CacheService } from '@renderer/services/CacheService' +import { estimateTextTokens } from 
'@renderer/services/TokenService' +import { + Assistant, + EFFORT_RATIO, + FileType, + FileTypes, + GenerateImageParams, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse, + WebSearchSource +} from '@renderer/types' +import { ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk' +import { Message } from '@renderer/types/newMessage' +import { + GeminiOptions, + GeminiSdkMessageParam, + GeminiSdkParams, + GeminiSdkRawChunk, + GeminiSdkRawOutput, + GeminiSdkToolCall +} from '@renderer/types/sdk' +import { + geminiFunctionCallToMcpTool, + isEnabledToolUse, + mcpToolCallResponseToGeminiMessage, + mcpToolsToGeminiTools +} from '@renderer/utils/mcp-tools' +import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' +import { buildSystemPrompt } from '@renderer/utils/prompt' +import { MB } from '@shared/config/constant' + +import { BaseApiClient } from '../BaseApiClient' +import { RequestTransformer, ResponseChunkTransformer } from '../types' + +export class GeminiAPIClient extends BaseApiClient< + GoogleGenAI, + GeminiSdkParams, + GeminiSdkRawOutput, + GeminiSdkRawChunk, + GeminiSdkMessageParam, + GeminiSdkToolCall, + Tool +> { + constructor(provider: Provider) { + super(provider) + } + + override async createCompletions(payload: GeminiSdkParams, options?: GeminiOptions): Promise { + const sdk = await this.getSdkInstance() + const { model, history, ...rest } = payload + const realPayload: Omit = { + ...rest, + config: { + ...rest.config, + abortSignal: options?.abortSignal, + httpOptions: { + ...rest.config?.httpOptions, + timeout: options?.timeout + } + } + } satisfies SendMessageParameters + + const streamOutput = options?.streamOutput + + const chat = sdk.chats.create({ + model: model, + history: history + }) + + if (streamOutput) { + const stream = chat.sendMessageStream(realPayload) + return stream + } else { + const response = await chat.sendMessage(realPayload) + 
return response + } + } + + override async generateImage(generateImageParams: GenerateImageParams): Promise { + const sdk = await this.getSdkInstance() + try { + const { model, prompt, imageSize, batchSize, signal } = generateImageParams + const config: GenerateImagesConfig = { + numberOfImages: batchSize, + aspectRatio: imageSize, + abortSignal: signal, + httpOptions: { + timeout: 5 * 60 * 1000 + } + } + const response = await sdk.models.generateImages({ + model: model, + prompt, + config + }) + + if (!response.generatedImages || response.generatedImages.length === 0) { + return [] + } + + const images = response.generatedImages + .filter((image) => image.image?.imageBytes) + .map((image) => { + const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,` + return dataPrefix + image.image?.imageBytes + }) + // console.log(response?.generatedImages?.[0]?.image?.imageBytes); + return images + } catch (error) { + console.error('[generateImage] error:', error) + throw error + } + } + + override async getEmbeddingDimensions(model: Model): Promise { + const sdk = await this.getSdkInstance() + try { + const data = await sdk.models.embedContent({ + model: model.id, + contents: [{ role: 'user', parts: [{ text: 'hi' }] }] + }) + return data.embeddings?.[0]?.values?.length || 0 + } catch (e) { + return 0 + } + } + + override async listModels(): Promise { + const sdk = await this.getSdkInstance() + const response = await sdk.models.list() + const models: GeminiModel[] = [] + for await (const model of response) { + models.push(model) + } + return models + } + + override async getSdkInstance() { + if (this.sdkInstance) { + return this.sdkInstance + } + + this.sdkInstance = new GoogleGenAI({ + vertexai: false, + apiKey: this.apiKey, + httpOptions: { baseUrl: this.getBaseURL() } + }) + + return this.sdkInstance + } + + /** + * Handle a PDF file + * @param file - The file + * @returns The part + */ + private async handlePdfFile(file: FileType): Promise { + const 
smallFileSize = 20 * MB + const isSmallFile = file.size < smallFileSize + + if (isSmallFile) { + const { data, mimeType } = await this.base64File(file) + return { + inlineData: { + data, + mimeType + } as Part['inlineData'] + } + } + + // Retrieve file from Gemini uploaded files + const fileMetadata: File | undefined = await this.retrieveFile(file) + + if (fileMetadata) { + return { + fileData: { + fileUri: fileMetadata.uri, + mimeType: fileMetadata.mimeType + } as Part['fileData'] + } + } + + // If file is not found, upload it to Gemini + const result = await this.uploadFile(file) + + return { + fileData: { + fileUri: result.uri, + mimeType: result.mimeType + } as Part['fileData'] + } + } + + /** + * Get the message contents + * @param message - The message + * @returns The message contents + */ + private async convertMessageToSdkParam(message: Message): Promise { + const role = message.role === 'user' ? 'user' : 'model' + const parts: Part[] = [{ text: await this.getMessageContent(message) }] + // Add any generated images from previous responses + const imageBlocks = findImageBlocks(message) + for (const imageBlock of imageBlocks) { + if ( + imageBlock.metadata?.generateImageResponse?.images && + imageBlock.metadata.generateImageResponse.images.length > 0 + ) { + for (const imageUrl of imageBlock.metadata.generateImageResponse.images) { + if (imageUrl && imageUrl.startsWith('data:')) { + // Extract base64 data and mime type from the data URL + const matches = imageUrl.match(/^data:(.+);base64,(.*)$/) + if (matches && matches.length === 3) { + const mimeType = matches[1] + const base64Data = matches[2] + parts.push({ + inlineData: { + data: base64Data, + mimeType: mimeType + } as Part['inlineData'] + }) + } + } + } + } + const file = imageBlock.file + if (file) { + const base64Data = await window.api.file.base64Image(file.id + file.ext) + parts.push({ + inlineData: { + data: base64Data.base64, + mimeType: base64Data.mime + } as Part['inlineData'] + }) + } + } + + 
const fileBlocks = findFileBlocks(message) + for (const fileBlock of fileBlocks) { + const file = fileBlock.file + if (file.type === FileTypes.IMAGE) { + const base64Data = await window.api.file.base64Image(file.id + file.ext) + parts.push({ + inlineData: { + data: base64Data.base64, + mimeType: base64Data.mime + } as Part['inlineData'] + }) + } + + if (file.ext === '.pdf') { + parts.push(await this.handlePdfFile(file)) + continue + } + if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { + const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() + parts.push({ + text: file.origin_name + '\n' + fileContent + }) + } + } + + return { + role, + parts: parts + } + } + + // @ts-ignore unused + private async getImageFileContents(message: Message): Promise { + const role = message.role === 'user' ? 'user' : 'model' + const content = getMainTextContent(message) + const parts: Part[] = [{ text: content }] + const imageBlocks = findImageBlocks(message) + for (const imageBlock of imageBlocks) { + if ( + imageBlock.metadata?.generateImageResponse?.images && + imageBlock.metadata.generateImageResponse.images.length > 0 + ) { + for (const imageUrl of imageBlock.metadata.generateImageResponse.images) { + if (imageUrl && imageUrl.startsWith('data:')) { + // Extract base64 data and mime type from the data URL + const matches = imageUrl.match(/^data:(.+);base64,(.*)$/) + if (matches && matches.length === 3) { + const mimeType = matches[1] + const base64Data = matches[2] + parts.push({ + inlineData: { + data: base64Data, + mimeType: mimeType + } as Part['inlineData'] + }) + } + } + } + } + const file = imageBlock.file + if (file) { + const base64Data = await window.api.file.base64Image(file.id + file.ext) + parts.push({ + inlineData: { + data: base64Data.base64, + mimeType: base64Data.mime + } as Part['inlineData'] + }) + } + } + return { + role, + parts: parts + } + } + + /** + * Get the safety settings + * @returns The safety settings + */ + 
  /**
   * Get the Gemini "thinking" configuration for the assistant.
   * @param assistant - The assistant (carries `settings.reasoning_effort`)
   * @param model - The target model
   * @returns A partial config containing `thinkingConfig`, or `{}` for
   *          non-reasoning models
   */
  private getBudgetToken(assistant: Assistant, model: Model) {
    if (isGeminiReasoningModel(model)) {
      const reasoningEffort = assistant?.settings?.reasoning_effort

      // If reasoning_effort is undefined, do not think at all.
      // Flash models additionally take an explicit zero thinking budget.
      if (reasoningEffort === undefined) {
        return {
          thinkingConfig: {
            includeThoughts: false,
            ...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {})
          } as ThinkingConfig
        }
      }

      const effortRatio = EFFORT_RATIO[reasoningEffort]

      // A ratio above 1 means "uncapped": include thoughts with no budget.
      if (effortRatio > 1) {
        return {
          thinkingConfig: {
            includeThoughts: true
          }
        }
      }

      // Scale the model's maximum thinking budget by the effort ratio.
      // findTokenLimit may miss unknown model ids; fall back to a zero max.
      const { max } = findTokenLimit(model.id) || { max: 0 }
      const budget = Math.floor(max * effortRatio)

      return {
        thinkingConfig: {
          // Omit thinkingBudget entirely when the computed budget is zero.
          ...(budget > 0 ? { thinkingBudget: budget } : {}),
          includeThoughts: true
        } as ThinkingConfig
      }
    }

    return {}
  }
recursiveSdkMessages.slice(0, recursiveSdkMessages.length - 1) + : history + + const newMessageContents = + isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 + ? { + ...messageContents, + parts: [ + ...(messageContents.parts || []), + ...(recursiveSdkMessages[recursiveSdkMessages.length - 1].parts || []) + ] + } + : messageContents + + const generateContentConfig: GenerateContentConfig = { + safetySettings: this.getSafetySettings(), + systemInstruction: isGemmaModel(model) ? undefined : systemInstruction, + temperature: this.getTemperature(assistant, model), + topP: this.getTopP(assistant, model), + maxOutputTokens: maxTokens, + tools: tools, + ...(enableGenerateImage ? this.getGenerateImageParameter() : {}), + ...this.getBudgetToken(assistant, model), + ...this.getCustomParameters(assistant) + } + + const param: GeminiSdkParams = { + model: model.id, + config: generateContentConfig, + history: newHistory, + message: newMessageContents.parts! + } + + return { + payload: param, + messages: [messageContents], + metadata: {} + } + } + } + } + + getResponseChunkTransformer(): ResponseChunkTransformer { + return () => ({ + async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController) { + let toolCalls: FunctionCall[] = [] + if (chunk.candidates && chunk.candidates.length > 0) { + for (const candidate of chunk.candidates) { + if (candidate.content) { + candidate.content.parts?.forEach((part) => { + const text = part.text || '' + if (part.thought) { + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: text + }) + } else if (part.text) { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: text + }) + } else if (part.inlineData) { + controller.enqueue({ + type: ChunkType.IMAGE_COMPLETE, + image: { + type: 'base64', + images: [ + part.inlineData?.data?.startsWith('data:') + ? 
part.inlineData?.data + : `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}` + ] + } + }) + } + }) + } + + if (candidate.finishReason) { + if (candidate.groundingMetadata) { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: candidate.groundingMetadata, + source: WebSearchSource.GEMINI + } + } as LLMWebSearchCompleteChunk) + } + if (chunk.functionCalls) { + toolCalls = toolCalls.concat(chunk.functionCalls) + } + controller.enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0, + completion_tokens: + (chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0), + total_tokens: chunk.usageMetadata?.totalTokenCount || 0 + } + } + }) + } + } + } + + if (toolCalls.length > 0) { + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: toolCalls + }) + } + } + }) + } + + public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Tool[] { + return mcpToolsToGeminiTools(mcpTools) + } + + public convertSdkToolCallToMcp(toolCall: GeminiSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { + return geminiFunctionCallToMcpTool(mcpTools, toolCall) + } + + public convertSdkToolCallToMcpToolResponse(toolCall: GeminiSdkToolCall, mcpTool: MCPTool): ToolCallResponse { + const parsedArgs = (() => { + try { + return typeof toolCall.args === 'string' ? 
JSON.parse(toolCall.args) : toolCall.args + } catch { + return toolCall.args + } + })() + + return { + id: toolCall.id || nanoid(), + toolCallId: toolCall.id, + tool: mcpTool, + arguments: parsedArgs, + status: 'pending' + } as ToolCallResponse + } + + public convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): GeminiSdkMessageParam | undefined { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model)) + } else if ('toolCallId' in mcpToolResponse) { + return { + role: 'user', + parts: [ + { + functionResponse: { + id: mcpToolResponse.toolCallId, + name: mcpToolResponse.tool.id, + response: { + output: !resp.isError ? resp.content : undefined, + error: resp.isError ? resp.content : undefined + } + } + } + ] + } satisfies Content + } + return + } + + public buildSdkMessages( + currentReqMessages: Content[], + output: string, + toolResults: Content[], + toolCalls: FunctionCall[] + ): Content[] { + const parts: Part[] = [] + if (output) { + parts.push({ + text: output + }) + } + toolCalls.forEach((toolCall) => { + parts.push({ + functionCall: toolCall + }) + }) + parts.push( + ...toolResults + .map((ts) => ts.parts) + .flat() + .filter((p) => p !== undefined) + ) + + const userMessage: Content = { + role: 'user', + parts: parts + } + + return [...currentReqMessages, userMessage] + } + + override estimateMessageTokens(message: GeminiSdkMessageParam): number { + return ( + message.parts?.reduce((acc, part) => { + if (part.text) { + return acc + estimateTextTokens(part.text) + } + if (part.functionCall) { + return acc + estimateTextTokens(JSON.stringify(part.functionCall)) + } + if (part.functionResponse) { + return acc + estimateTextTokens(JSON.stringify(part.functionResponse.response)) + } + if (part.inlineData) { + return acc + estimateTextTokens(part.inlineData.data || '') + } + if (part.fileData) 
  /**
   * Look up a previously uploaded Gemini remote file matching this local file.
   *
   * The remote file listing is cached briefly so bursts of attachments share
   * one `files.list` round-trip instead of re-listing per file.
   * @param file - The local file to match (by display name and size)
   * @returns The matching ACTIVE remote file, or `undefined` if none matches
   */
  private async retrieveFile(file: FileType): Promise<File | undefined> {
    const cachedResponse = CacheService.get('gemini_file_list')

    if (cachedResponse) {
      return this.processResponse(cachedResponse, file)
    }

    const response = await this.sdkInstance!.files.list()
    // Cache the pager for 3000 ms — presumably milliseconds; confirm against
    // CacheService.set's TTL unit.
    CacheService.set('gemini_file_list', response, 3000)

    return this.processResponse(response, file)
  }
'./openai/OpenAIApiClient' diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts new file mode 100644 index 0000000000..07359c837f --- /dev/null +++ b/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts @@ -0,0 +1,682 @@ +import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' +import Logger from '@renderer/config/logger' +import { + findTokenLimit, + GEMINI_FLASH_MODEL_REGEX, + getOpenAIWebSearchParams, + isDoubaoThinkingAutoModel, + isReasoningModel, + isSupportedReasoningEffortGrokModel, + isSupportedReasoningEffortModel, + isSupportedReasoningEffortOpenAIModel, + isSupportedThinkingTokenClaudeModel, + isSupportedThinkingTokenDoubaoModel, + isSupportedThinkingTokenGeminiModel, + isSupportedThinkingTokenModel, + isSupportedThinkingTokenQwenModel, + isVisionModel +} from '@renderer/config/models' +import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService' +import { estimateTextTokens } from '@renderer/services/TokenService' +// For Copilot token +import { + Assistant, + EFFORT_RATIO, + FileTypes, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse, + WebSearchSource +} from '@renderer/types' +import { ChunkType } from '@renderer/types/chunk' +import { Message } from '@renderer/types/newMessage' +import { + OpenAISdkMessageParam, + OpenAISdkParams, + OpenAISdkRawChunk, + OpenAISdkRawContentSource, + OpenAISdkRawOutput, + ReasoningEffortOptionalParams +} from '@renderer/types/sdk' +import { addImageFileToContents } from '@renderer/utils/formats' +import { + isEnabledToolUse, + mcpToolCallResponseToOpenAICompatibleMessage, + mcpToolsToOpenAIChatTools, + openAIToolsToMcpTool +} from '@renderer/utils/mcp-tools' +import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' +import { buildSystemPrompt } from '@renderer/utils/prompt' +import OpenAI, { AzureOpenAI } from 
'openai' +import { ChatCompletionContentPart, ChatCompletionContentPartRefusal, ChatCompletionTool } from 'openai/resources' + +import { GenericChunk } from '../../middleware/schemas' +import { RequestTransformer, ResponseChunkTransformer, ResponseChunkTransformerContext } from '../types' +import { OpenAIBaseClient } from './OpenAIBaseClient' + +export class OpenAIAPIClient extends OpenAIBaseClient< + OpenAI | AzureOpenAI, + OpenAISdkParams, + OpenAISdkRawOutput, + OpenAISdkRawChunk, + OpenAISdkMessageParam, + OpenAI.Chat.Completions.ChatCompletionMessageToolCall, + ChatCompletionTool +> { + constructor(provider: Provider) { + super(provider) + } + + override async createCompletions( + payload: OpenAISdkParams, + options?: OpenAI.RequestOptions + ): Promise { + const sdk = await this.getSdkInstance() + // @ts-ignore - SDK参数可能有额外的字段 + return await sdk.chat.completions.create(payload, options) + } + + /** + * Get the reasoning effort for the assistant + * @param assistant - The assistant + * @param model - The model + * @returns The reasoning effort + */ + // Method for reasoning effort, moved from OpenAIProvider + override getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams { + if (this.provider.id === 'groq') { + return {} + } + + if (!isReasoningModel(model)) { + return {} + } + const reasoningEffort = assistant?.settings?.reasoning_effort + + // Doubao 思考模式支持 + if (isSupportedThinkingTokenDoubaoModel(model)) { + // reasoningEffort 为空,默认开启 enabled + if (!reasoningEffort) { + return { thinking: { type: 'disabled' } } + } + if (reasoningEffort === 'high') { + return { thinking: { type: 'enabled' } } + } + if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) { + return { thinking: { type: 'auto' } } + } + // 其他情况不带 thinking 字段 + return {} + } + + if (!reasoningEffort) { + if (isSupportedThinkingTokenQwenModel(model)) { + return { enable_thinking: false } + } + + if (isSupportedThinkingTokenClaudeModel(model)) { + 
return {} + } + + if (isSupportedThinkingTokenGeminiModel(model)) { + // openrouter没有提供一个不推理的选项,先隐藏 + if (this.provider.id === 'openrouter') { + return { reasoning: { max_tokens: 0, exclude: true } } + } + if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) { + return { reasoning_effort: 'none' } + } + return {} + } + + if (isSupportedThinkingTokenDoubaoModel(model)) { + return { thinking: { type: 'disabled' } } + } + + return {} + } + const effortRatio = EFFORT_RATIO[reasoningEffort] + const budgetTokens = Math.floor( + (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min! + ) + + // OpenRouter models + if (model.provider === 'openrouter') { + if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) { + return { + reasoning: { + effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort + } + } + } + } + + // Qwen models + if (isSupportedThinkingTokenQwenModel(model)) { + return { + enable_thinking: true, + thinking_budget: budgetTokens + } + } + + // Grok models + if (isSupportedReasoningEffortGrokModel(model)) { + return { + reasoning_effort: reasoningEffort + } + } + + // OpenAI models + if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) { + return { + reasoning_effort: reasoningEffort + } + } + + // Claude models + if (isSupportedThinkingTokenClaudeModel(model)) { + const maxTokens = assistant.settings?.maxTokens + return { + thinking: { + type: 'enabled', + budget_tokens: Math.floor( + Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)) + ) + } + } + } + + // Doubao models + if (isSupportedThinkingTokenDoubaoModel(model)) { + if (assistant.settings?.reasoning_effort === 'high') { + return { + thinking: { + type: 'enabled' + } + } + } + } + + // Default case: no special thinking settings + return {} + } + + /** + * Check if the provider does not support files + * @returns True if the provider does 
not support files, false otherwise + */ + private get isNotSupportFiles() { + if (this.provider?.isNotSupportArrayContent) { + return true + } + + const providers = ['deepseek', 'baichuan', 'minimax', 'xirang'] + + return providers.includes(this.provider.id) + } + + /** + * Get the message parameter + * @param message - The message + * @param model - The model + * @returns The message parameter + */ + public async convertMessageToSdkParam(message: Message, model: Model): Promise { + const isVision = isVisionModel(model) + const content = await this.getMessageContent(message) + const fileBlocks = findFileBlocks(message) + const imageBlocks = findImageBlocks(message) + + if (fileBlocks.length === 0 && imageBlocks.length === 0) { + return { + role: message.role === 'system' ? 'user' : message.role, + content + } as OpenAISdkMessageParam + } + + // If the model does not support files, extract the file content + if (this.isNotSupportFiles) { + const fileContent = await this.extractFileContent(message) + + return { + role: message.role === 'system' ? 
'user' : message.role, + content: content + '\n\n---\n\n' + fileContent + } as OpenAISdkMessageParam + } + + // If the model supports files, add the file content to the message + const parts: ChatCompletionContentPart[] = [] + + if (content) { + parts.push({ type: 'text', text: content }) + } + + for (const imageBlock of imageBlocks) { + if (isVision) { + if (imageBlock.file) { + const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) + parts.push({ type: 'image_url', image_url: { url: image.data } }) + } else if (imageBlock.url && imageBlock.url.startsWith('data:')) { + parts.push({ type: 'image_url', image_url: { url: imageBlock.url } }) + } + } + } + + for (const fileBlock of fileBlocks) { + const file = fileBlock.file + if (!file) { + continue + } + + if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { + const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() + parts.push({ + type: 'text', + text: file.origin_name + '\n' + fileContent + }) + } + } + + return { + role: message.role === 'system' ? 
'user' : message.role, + content: parts + } as OpenAISdkMessageParam + } + + public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ChatCompletionTool[] { + return mcpToolsToOpenAIChatTools(mcpTools) + } + + public convertSdkToolCallToMcp( + toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall, + mcpTools: MCPTool[] + ): MCPTool | undefined { + return openAIToolsToMcpTool(mcpTools, toolCall) + } + + public convertSdkToolCallToMcpToolResponse( + toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall, + mcpTool: MCPTool + ): ToolCallResponse { + let parsedArgs: any + try { + parsedArgs = JSON.parse(toolCall.function.arguments) + } catch { + parsedArgs = toolCall.function.arguments + } + return { + id: toolCall.id, + toolCallId: toolCall.id, + tool: mcpTool, + arguments: parsedArgs, + status: 'pending' + } as ToolCallResponse + } + + public convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): OpenAISdkMessageParam | undefined { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + // This case is for Anthropic/Claude like tool usage, OpenAI uses tool_call_id + // For OpenAI, we primarily expect toolCallId. This might need adjustment if mixing provider concepts. 
+ return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model)) + } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) { + return { + role: 'tool', + tool_call_id: mcpToolResponse.toolCallId, + content: JSON.stringify(resp.content) + } as OpenAI.Chat.Completions.ChatCompletionToolMessageParam + } + return undefined + } + + public buildSdkMessages( + currentReqMessages: OpenAISdkMessageParam[], + output: string, + toolResults: OpenAISdkMessageParam[], + toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] + ): OpenAISdkMessageParam[] { + const assistantMessage: OpenAISdkMessageParam = { + role: 'assistant', + content: output, + tool_calls: toolCalls.length > 0 ? toolCalls : undefined + } + const newReqMessages = [...currentReqMessages, assistantMessage, ...toolResults] + return newReqMessages + } + + override estimateMessageTokens(message: OpenAISdkMessageParam): number { + let sum = 0 + if (typeof message.content === 'string') { + sum += estimateTextTokens(message.content) + } else if (Array.isArray(message.content)) { + sum += (message.content || []) + .map((part: ChatCompletionContentPart | ChatCompletionContentPartRefusal) => { + switch (part.type) { + case 'text': + return estimateTextTokens(part.text) + case 'image_url': + return estimateTextTokens(part.image_url.url) + case 'input_audio': + return estimateTextTokens(part.input_audio.data) + case 'file': + return estimateTextTokens(part.file.file_data || '') + default: + return 0 + } + }) + .reduce((acc, curr) => acc + curr, 0) + } + if ('tool_calls' in message && message.tool_calls) { + sum += message.tool_calls.reduce((acc, toolCall) => { + return acc + estimateTextTokens(JSON.stringify(toolCall.function.arguments)) + }, 0) + } + return sum + } + + public extractMessagesFromSdkPayload(sdkPayload: OpenAISdkParams): OpenAISdkMessageParam[] { + return sdkPayload.messages || [] + } + + getRequestTransformer(): RequestTransformer { + return { + 
transform: async ( + coreRequest, + assistant, + model, + isRecursiveCall, + recursiveSdkMessages + ): Promise<{ + payload: OpenAISdkParams + messages: OpenAISdkMessageParam[] + metadata: Record + }> => { + const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest + // 1. 处理系统消息 + let systemMessage = { role: 'system', content: assistant.prompt || '' } + + if (isSupportedReasoningEffortOpenAIModel(model)) { + systemMessage = { + role: 'developer', + content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}` + } + } + + if (model.id.includes('o1-mini') || model.id.includes('o1-preview')) { + systemMessage.role = 'assistant' + } + + // 2. 设置工具(必须在this.usesystemPromptForTools前面) + const { tools } = this.setupToolsConfig({ + mcpTools: mcpTools, + model, + enableToolUse: isEnabledToolUse(assistant) + }) + + if (this.useSystemPromptForTools) { + systemMessage.content = await buildSystemPrompt(systemMessage.content || '', mcpTools) + } + + // 3. 处理用户消息 + const userMessages: OpenAISdkMessageParam[] = [] + if (typeof messages === 'string') { + userMessages.push({ role: 'user', content: messages }) + } else { + const processedMessages = addImageFileToContents(messages) + for (const message of processedMessages) { + userMessages.push(await this.convertMessageToSdkParam(message, model)) + } + } + + const lastUserMsg = userMessages.findLast((m) => m.role === 'user') + if (lastUserMsg && isSupportedThinkingTokenQwenModel(model)) { + const postsuffix = '/no_think' + const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true + const currentContent = lastUserMsg.content + + lastUserMsg.content = processPostsuffixQwen3Model(currentContent, postsuffix, qwenThinkModeEnabled) as any + } + + // 4. 
最终请求消息 + let reqMessages: OpenAISdkMessageParam[] + if (!systemMessage.content) { + reqMessages = [...userMessages] + } else { + reqMessages = [systemMessage, ...userMessages].filter(Boolean) as OpenAISdkMessageParam[] + } + + reqMessages = processReqMessages(model, reqMessages) + + // 5. 创建通用参数 + const commonParams = { + model: model.id, + messages: + isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 + ? recursiveSdkMessages + : reqMessages, + temperature: this.getTemperature(assistant, model), + top_p: this.getTopP(assistant, model), + max_tokens: maxTokens, + tools: tools.length > 0 ? tools : undefined, + service_tier: this.getServiceTier(model), + ...this.getProviderSpecificParameters(assistant, model), + ...this.getReasoningEffort(assistant, model), + ...getOpenAIWebSearchParams(model, enableWebSearch), + ...this.getCustomParameters(assistant) + } + + // Create the appropriate parameters object based on whether streaming is enabled + const sdkParams: OpenAISdkParams = streamOutput + ? 
{ + ...commonParams, + stream: true + } + : { + ...commonParams, + stream: false + } + + const timeout = this.getTimeout(model) + + return { payload: sdkParams, messages: reqMessages, metadata: { timeout } } + } + } + } + + // 在RawSdkChunkToGenericChunkMiddleware中使用 + getResponseChunkTransformer = (): ResponseChunkTransformer => { + let hasBeenCollectedWebSearch = false + const collectWebSearchData = ( + chunk: OpenAISdkRawChunk, + contentSource: OpenAISdkRawContentSource, + context: ResponseChunkTransformerContext + ) => { + if (hasBeenCollectedWebSearch) { + return + } + // OpenAI annotations + // @ts-ignore - annotations may not be in standard type definitions + const annotations = contentSource.annotations || chunk.annotations + if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') { + hasBeenCollectedWebSearch = true + return { + results: annotations, + source: WebSearchSource.OPENAI + } + } + + // Grok citations + // @ts-ignore - citations may not be in standard type definitions + if (context.provider?.id === 'grok' && chunk.citations) { + hasBeenCollectedWebSearch = true + return { + // @ts-ignore - citations may not be in standard type definitions + results: chunk.citations, + source: WebSearchSource.GROK + } + } + + // Perplexity citations + // @ts-ignore - citations may not be in standard type definitions + if (context.provider?.id === 'perplexity' && chunk.citations && chunk.citations.length > 0) { + hasBeenCollectedWebSearch = true + return { + // @ts-ignore - citations may not be in standard type definitions + results: chunk.citations, + source: WebSearchSource.PERPLEXITY + } + } + + // OpenRouter citations + // @ts-ignore - citations may not be in standard type definitions + if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) { + hasBeenCollectedWebSearch = true + return { + // @ts-ignore - citations may not be in standard type definitions + results: chunk.citations, + source: 
WebSearchSource.OPENROUTER + } + } + + // Zhipu web search + // @ts-ignore - web_search may not be in standard type definitions + if (context.provider?.id === 'zhipu' && chunk.web_search) { + hasBeenCollectedWebSearch = true + return { + // @ts-ignore - web_search may not be in standard type definitions + results: chunk.web_search, + source: WebSearchSource.ZHIPU + } + } + + // Hunyuan web search + // @ts-ignore - search_info may not be in standard type definitions + if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) { + hasBeenCollectedWebSearch = true + return { + // @ts-ignore - search_info may not be in standard type definitions + results: chunk.search_info.search_results, + source: WebSearchSource.HUNYUAN + } + } + + // TODO: 放到AnthropicApiClient中 + // // Other providers... + // // @ts-ignore - web_search may not be in standard type definitions + // if (chunk.web_search) { + // const sourceMap: Record = { + // openai: 'openai', + // anthropic: 'anthropic', + // qwenlm: 'qwen' + // } + // const source = sourceMap[context.provider?.id] || 'openai_response' + // return { + // results: chunk.web_search, + // source: source as const + // } + // } + + return null + } + const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = [] + return (context: ResponseChunkTransformerContext) => ({ + async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController) { + // 处理chunk + if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) { + const choice = chunk.choices[0] + + if (!choice) return + + // 对于流式响应,使用delta;对于非流式响应,使用message + const contentSource: OpenAISdkRawContentSource | null = + 'delta' in choice ? choice.delta : 'message' in choice ? 
choice.message : null + + if (!contentSource) return + + const webSearchData = collectWebSearchData(chunk, contentSource, context) + if (webSearchData) { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: webSearchData + }) + } + + // 处理推理内容 (e.g. from OpenRouter DeepSeek-R1) + // @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it + const reasoningText = contentSource.reasoning_content || contentSource.reasoning + if (reasoningText) { + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: reasoningText + }) + } + + // 处理文本内容 + if (contentSource.content) { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: contentSource.content + }) + } + + // 处理工具调用 + if (contentSource.tool_calls) { + for (const toolCall of contentSource.tool_calls) { + if ('index' in toolCall) { + const { id, index, function: fun } = toolCall + if (fun?.name) { + toolCalls[index] = { + id: id || '', + function: { + name: fun.name, + arguments: fun.arguments || '' + }, + type: 'function' + } + } else if (fun?.arguments) { + toolCalls[index].function.arguments += fun.arguments + } + } else { + toolCalls.push(toolCall) + } + } + } + + // 处理finish_reason,发送流结束信号 + if ('finish_reason' in choice && choice.finish_reason) { + Logger.debug(`[OpenAIApiClient] Stream finished with reason: ${choice.finish_reason}`) + if (toolCalls.length > 0) { + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: toolCalls + }) + } + const webSearchData = collectWebSearchData(chunk, contentSource, context) + if (webSearchData) { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: webSearchData + }) + } + controller.enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: chunk.usage?.prompt_tokens || 0, + completion_tokens: chunk.usage?.completion_tokens || 0, + total_tokens: (chunk.usage?.prompt_tokens || 0) + (chunk.usage?.completion_tokens || 0) + 
} + } + }) + } + } + } + }) + } +} diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts new file mode 100644 index 0000000000..a44f25def8 --- /dev/null +++ b/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts @@ -0,0 +1,258 @@ +import { + isClaudeReasoningModel, + isNotSupportTemperatureAndTopP, + isOpenAIReasoningModel, + isSupportedModel, + isSupportedReasoningEffortOpenAIModel +} from '@renderer/config/models' +import { getStoreSetting } from '@renderer/hooks/useSettings' +import { getAssistantSettings } from '@renderer/services/AssistantService' +import store from '@renderer/store' +import { SettingsState } from '@renderer/store/settings' +import { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types' +import { + OpenAIResponseSdkMessageParam, + OpenAIResponseSdkParams, + OpenAIResponseSdkRawChunk, + OpenAIResponseSdkRawOutput, + OpenAIResponseSdkTool, + OpenAIResponseSdkToolCall, + OpenAISdkMessageParam, + OpenAISdkParams, + OpenAISdkRawChunk, + OpenAISdkRawOutput, + ReasoningEffortOptionalParams +} from '@renderer/types/sdk' +import { formatApiHost } from '@renderer/utils/api' +import OpenAI, { AzureOpenAI } from 'openai' + +import { BaseApiClient } from '../BaseApiClient' + +/** + * 抽象的OpenAI基础客户端类,包含两个OpenAI客户端之间的共享功能 + */ +export abstract class OpenAIBaseClient< + TSdkInstance extends OpenAI | AzureOpenAI, + TSdkParams extends OpenAISdkParams | OpenAIResponseSdkParams, + TRawOutput extends OpenAISdkRawOutput | OpenAIResponseSdkRawOutput, + TRawChunk extends OpenAISdkRawChunk | OpenAIResponseSdkRawChunk, + TMessageParam extends OpenAISdkMessageParam | OpenAIResponseSdkMessageParam, + TToolCall extends OpenAI.Chat.Completions.ChatCompletionMessageToolCall | OpenAIResponseSdkToolCall, + TSdkSpecificTool extends OpenAI.Chat.Completions.ChatCompletionTool | OpenAIResponseSdkTool +> extends BaseApiClient { + constructor(provider: Provider) { + 
super(provider) + } + + // 仅适用于openai + override getBaseURL(): string { + const host = this.provider.apiHost + return formatApiHost(host) + } + + override async generateImage({ + model, + prompt, + negativePrompt, + imageSize, + batchSize, + seed, + numInferenceSteps, + guidanceScale, + signal, + promptEnhancement + }: GenerateImageParams): Promise { + const sdk = await this.getSdkInstance() + const response = (await sdk.request({ + method: 'post', + path: '/images/generations', + signal, + body: { + model, + prompt, + negative_prompt: negativePrompt, + image_size: imageSize, + batch_size: batchSize, + seed: seed ? parseInt(seed) : undefined, + num_inference_steps: numInferenceSteps, + guidance_scale: guidanceScale, + prompt_enhancement: promptEnhancement + } + })) as { data: Array<{ url: string }> } + + return response.data.map((item) => item.url) + } + + override async getEmbeddingDimensions(model: Model): Promise { + const sdk = await this.getSdkInstance() + try { + const data = await sdk.embeddings.create({ + model: model.id, + input: model?.provider === 'baidu-cloud' ? 
['hi'] : 'hi', + encoding_format: 'float' + }) + return data.data[0].embedding.length + } catch (e) { + return 0 + } + } + + override async listModels(): Promise { + try { + const sdk = await this.getSdkInstance() + const response = await sdk.models.list() + if (this.provider.id === 'github') { + // @ts-ignore key is not typed + return response?.body + .map((model) => ({ + id: model.name, + description: model.summary, + object: 'model', + owned_by: model.publisher + })) + .filter(isSupportedModel) + } + if (this.provider.id === 'together') { + // @ts-ignore key is not typed + return response?.body.map((model) => ({ + id: model.id, + description: model.display_name, + object: 'model', + owned_by: model.organization + })) + } + const models = response.data || [] + models.forEach((model) => { + model.id = model.id.trim() + }) + + return models.filter(isSupportedModel) + } catch (error) { + console.error('Error listing models:', error) + return [] + } + } + + override async getSdkInstance() { + if (this.sdkInstance) { + return this.sdkInstance + } + + let apiKeyForSdkInstance = this.provider.apiKey + + if (this.provider.id === 'copilot') { + const defaultHeaders = store.getState().copilot.defaultHeaders + const { token } = await window.api.copilot.getToken(defaultHeaders) + // this.provider.apiKey不允许修改 + // this.provider.apiKey = token + apiKeyForSdkInstance = token + } + + if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') { + this.sdkInstance = new AzureOpenAI({ + dangerouslyAllowBrowser: true, + apiKey: apiKeyForSdkInstance, + apiVersion: this.provider.apiVersion, + endpoint: this.provider.apiHost + }) as TSdkInstance + } else { + this.sdkInstance = new OpenAI({ + dangerouslyAllowBrowser: true, + apiKey: apiKeyForSdkInstance, + baseURL: this.getBaseURL(), + defaultHeaders: { + ...this.defaultHeaders(), + ...(this.provider.id === 'copilot' ? { 'editor-version': 'vscode/1.97.2' } : {}), + ...(this.provider.id === 'copilot' ? 
{ 'copilot-vision-request': 'true' } : {}) + } + }) as TSdkInstance + } + return this.sdkInstance + } + + override getTemperature(assistant: Assistant, model: Model): number | undefined { + if ( + isNotSupportTemperatureAndTopP(model) || + (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) + ) { + return undefined + } + return assistant.settings?.temperature + } + + override getTopP(assistant: Assistant, model: Model): number | undefined { + if ( + isNotSupportTemperatureAndTopP(model) || + (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) + ) { + return undefined + } + return assistant.settings?.topP + } + + /** + * Get the provider specific parameters for the assistant + * @param assistant - The assistant + * @param model - The model + * @returns The provider specific parameters + */ + protected getProviderSpecificParameters(assistant: Assistant, model: Model) { + const { maxTokens } = getAssistantSettings(assistant) + + if (this.provider.id === 'openrouter') { + if (model.id.includes('deepseek-r1')) { + return { + include_reasoning: true + } + } + } + + if (isOpenAIReasoningModel(model)) { + return { + max_tokens: undefined, + max_completion_tokens: maxTokens + } + } + + return {} + } + + /** + * Get the reasoning effort for the assistant + * @param assistant - The assistant + * @param model - The model + * @returns The reasoning effort + */ + protected getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams { + if (!isSupportedReasoningEffortOpenAIModel(model)) { + return {} + } + + const openAI = getStoreSetting('openAI') as SettingsState['openAI'] + const summaryText = openAI?.summaryText || 'off' + + let summary: string | undefined = undefined + + if (summaryText === 'off' || model.id.includes('o1-pro')) { + summary = undefined + } else { + summary = summaryText + } + + const reasoningEffort = assistant?.settings?.reasoning_effort + if (!reasoningEffort) { + return {} + } + + if 
(isSupportedReasoningEffortOpenAIModel(model)) { + return { + reasoning: { + effort: reasoningEffort as OpenAI.ReasoningEffort, + summary: summary + } as OpenAI.Reasoning + } + } + + return {} + } +} diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts new file mode 100644 index 0000000000..0fdd65f709 --- /dev/null +++ b/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts @@ -0,0 +1,532 @@ +import { GenericChunk } from '@renderer/aiCore/middleware/schemas' +import { + isOpenAIChatCompletionOnlyModel, + isSupportedReasoningEffortOpenAIModel, + isVisionModel +} from '@renderer/config/models' +import { estimateTextTokens } from '@renderer/services/TokenService' +import { + FileTypes, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse, + WebSearchSource +} from '@renderer/types' +import { ChunkType } from '@renderer/types/chunk' +import { Message } from '@renderer/types/newMessage' +import { + OpenAIResponseSdkMessageParam, + OpenAIResponseSdkParams, + OpenAIResponseSdkRawChunk, + OpenAIResponseSdkRawOutput, + OpenAIResponseSdkTool, + OpenAIResponseSdkToolCall +} from '@renderer/types/sdk' +import { addImageFileToContents } from '@renderer/utils/formats' +import { + isEnabledToolUse, + mcpToolCallResponseToOpenAIMessage, + mcpToolsToOpenAIResponseTools, + openAIToolsToMcpTool +} from '@renderer/utils/mcp-tools' +import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' +import { buildSystemPrompt } from '@renderer/utils/prompt' +import { isEmpty } from 'lodash' +import OpenAI from 'openai' + +import { RequestTransformer, ResponseChunkTransformer } from '../types' +import { OpenAIAPIClient } from './OpenAIApiClient' +import { OpenAIBaseClient } from './OpenAIBaseClient' + +export class OpenAIResponseAPIClient extends OpenAIBaseClient< + OpenAI, + OpenAIResponseSdkParams, + 
OpenAIResponseSdkRawOutput, + OpenAIResponseSdkRawChunk, + OpenAIResponseSdkMessageParam, + OpenAIResponseSdkToolCall, + OpenAIResponseSdkTool +> { + private client: OpenAIAPIClient + constructor(provider: Provider) { + super(provider) + this.client = new OpenAIAPIClient(provider) + } + + /** + * 根据模型特征选择合适的客户端 + */ + public getClient(model: Model) { + if (isOpenAIChatCompletionOnlyModel(model)) { + return this.client + } else { + return this + } + } + + override async getSdkInstance() { + if (this.sdkInstance) { + return this.sdkInstance + } + + return new OpenAI({ + dangerouslyAllowBrowser: true, + apiKey: this.provider.apiKey, + baseURL: this.getBaseURL(), + defaultHeaders: { + ...this.defaultHeaders() + } + }) + } + + override async createCompletions( + payload: OpenAIResponseSdkParams, + options?: OpenAI.RequestOptions + ): Promise { + const sdk = await this.getSdkInstance() + return await sdk.responses.create(payload, options) + } + + public async convertMessageToSdkParam(message: Message, model: Model): Promise { + const isVision = isVisionModel(model) + const content = await this.getMessageContent(message) + const fileBlocks = findFileBlocks(message) + const imageBlocks = findImageBlocks(message) + + if (fileBlocks.length === 0 && imageBlocks.length === 0) { + if (message.role === 'assistant') { + return { + role: 'assistant', + content: content + } + } else { + return { + role: message.role === 'system' ? 'user' : message.role, + content: content ? 
[{ type: 'input_text', text: content }] : [] + } as OpenAI.Responses.EasyInputMessage + } + } + + const parts: OpenAI.Responses.ResponseInputContent[] = [] + if (content) { + parts.push({ + type: 'input_text', + text: content + }) + } + + for (const imageBlock of imageBlocks) { + if (isVision) { + if (imageBlock.file) { + const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) + parts.push({ + detail: 'auto', + type: 'input_image', + image_url: image.data as string + }) + } else if (imageBlock.url && imageBlock.url.startsWith('data:')) { + parts.push({ + detail: 'auto', + type: 'input_image', + image_url: imageBlock.url + }) + } + } + } + + for (const fileBlock of fileBlocks) { + const file = fileBlock.file + if (!file) continue + + if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { + const fileContent = (await window.api.file.read(file.id + file.ext)).trim() + parts.push({ + type: 'input_text', + text: file.origin_name + '\n' + fileContent + }) + } + } + + return { + role: message.role === 'system' ? 
'user' : message.role, + content: parts + } + } + + public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] { + return mcpToolsToOpenAIResponseTools(mcpTools) + } + + public convertSdkToolCallToMcp(toolCall: OpenAIResponseSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { + return openAIToolsToMcpTool(mcpTools, toolCall) + } + public convertSdkToolCallToMcpToolResponse(toolCall: OpenAIResponseSdkToolCall, mcpTool: MCPTool): ToolCallResponse { + const parsedArgs = (() => { + try { + return JSON.parse(toolCall.arguments) + } catch { + return toolCall.arguments + } + })() + + return { + id: toolCall.call_id, + toolCallId: toolCall.call_id, + tool: mcpTool, + arguments: parsedArgs, + status: 'pending' + } + } + + public convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): OpenAIResponseSdkMessageParam | undefined { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model)) + } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) { + return { + type: 'function_call_output', + call_id: mcpToolResponse.toolCallId, + output: JSON.stringify(resp.content) + } + } + return + } + + public buildSdkMessages( + currentReqMessages: OpenAIResponseSdkMessageParam[], + output: string, + toolResults: OpenAIResponseSdkMessageParam[], + toolCalls: OpenAIResponseSdkToolCall[] + ): OpenAIResponseSdkMessageParam[] { + const assistantMessage: OpenAIResponseSdkMessageParam = { + role: 'assistant', + content: [{ type: 'input_text', text: output }] + } + const newReqMessages = [...currentReqMessages, assistantMessage, ...(toolCalls || []), ...(toolResults || [])] + return newReqMessages + } + + override estimateMessageTokens(message: OpenAIResponseSdkMessageParam): number { + let sum = 0 + if ('content' in message) { + if (typeof message.content === 'string') { + sum += 
estimateTextTokens(message.content) + } else if (Array.isArray(message.content)) { + for (const part of message.content) { + switch (part.type) { + case 'input_text': + sum += estimateTextTokens(part.text) + break + case 'input_image': + sum += estimateTextTokens(part.image_url || '') + break + default: + break + } + } + } + } + switch (message.type) { + case 'function_call_output': + sum += estimateTextTokens(message.output) + break + case 'function_call': + sum += estimateTextTokens(message.arguments) + break + default: + break + } + return sum + } + + public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] { + if (typeof sdkPayload.input === 'string') { + return [{ role: 'user', content: sdkPayload.input }] + } + return sdkPayload.input + } + + getRequestTransformer(): RequestTransformer { + return { + transform: async ( + coreRequest, + assistant, + model, + isRecursiveCall, + recursiveSdkMessages + ): Promise<{ + payload: OpenAIResponseSdkParams + messages: OpenAIResponseSdkMessageParam[] + metadata: Record + }> => { + const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch, enableGenerateImage } = coreRequest + // 1. 处理系统消息 + const systemMessage: OpenAI.Responses.EasyInputMessage = { + role: 'system', + content: [] + } + + const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = [] + const systemMessageInput: OpenAI.Responses.ResponseInputText = { + text: assistant.prompt || '', + type: 'input_text' + } + if (isSupportedReasoningEffortOpenAIModel(model)) { + systemMessage.role = 'developer' + } + + // 2. 
设置工具 + let tools: OpenAI.Responses.Tool[] = [] + const { tools: extraTools } = this.setupToolsConfig({ + mcpTools: mcpTools, + model, + enableToolUse: isEnabledToolUse(assistant) + }) + + if (this.useSystemPromptForTools) { + systemMessageInput.text = await buildSystemPrompt(systemMessageInput.text || '', mcpTools) + } + systemMessageContent.push(systemMessageInput) + systemMessage.content = systemMessageContent + + // 3. 处理用户消息 + let userMessage: OpenAI.Responses.ResponseInputItem[] = [] + if (typeof messages === 'string') { + userMessage.push({ role: 'user', content: messages }) + } else { + const processedMessages = addImageFileToContents(messages) + for (const message of processedMessages) { + userMessage.push(await this.convertMessageToSdkParam(message, model)) + } + } + // FIXME: 最好还是直接使用previous_response_id来处理(或者在数据库中存储image_generation_call的id) + if (enableGenerateImage) { + const finalAssistantMessage = userMessage.findLast( + (m) => (m as OpenAI.Responses.EasyInputMessage).role === 'assistant' + ) as OpenAI.Responses.EasyInputMessage + const finalUserMessage = userMessage.pop() as OpenAI.Responses.EasyInputMessage + if ( + finalAssistantMessage && + Array.isArray(finalAssistantMessage.content) && + finalUserMessage && + Array.isArray(finalUserMessage.content) + ) { + finalAssistantMessage.content = [...finalAssistantMessage.content, ...finalUserMessage.content] + } + // 这里是故意将上条助手消息的内容(包含图片和文件)作为用户消息发送 + userMessage = [{ ...finalAssistantMessage, role: 'user' } as OpenAI.Responses.EasyInputMessage] + } + + // 4. 最终请求消息 + let reqMessages: OpenAI.Responses.ResponseInput + if (!systemMessage.content) { + reqMessages = [...userMessage] + } else { + reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[] + } + + if (enableWebSearch) { + tools.push({ + type: 'web_search_preview' + }) + } + + if (enableGenerateImage) { + tools.push({ + type: 'image_generation', + partial_images: streamOutput ? 
2 : undefined + }) + } + + const toolChoices: OpenAI.Responses.ToolChoiceTypes = { + type: 'web_search_preview' + } + + tools = tools.concat(extraTools) + const commonParams = { + model: model.id, + input: + isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 + ? recursiveSdkMessages + : reqMessages, + temperature: this.getTemperature(assistant, model), + top_p: this.getTopP(assistant, model), + max_output_tokens: maxTokens, + stream: streamOutput, + tools: !isEmpty(tools) ? tools : undefined, + tool_choice: enableWebSearch ? toolChoices : undefined, + service_tier: this.getServiceTier(model), + ...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning), + ...this.getCustomParameters(assistant) + } + const sdkParams: OpenAIResponseSdkParams = streamOutput + ? { + ...commonParams, + stream: true + } + : { + ...commonParams, + stream: false + } + const timeout = this.getTimeout(model) + return { payload: sdkParams, messages: reqMessages, metadata: { timeout } } + } + } + } + + getResponseChunkTransformer(): ResponseChunkTransformer { + const toolCalls: OpenAIResponseSdkToolCall[] = [] + const outputItems: OpenAI.Responses.ResponseOutputItem[] = [] + return () => ({ + async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController) { + // 处理chunk + if ('output' in chunk) { + for (const output of chunk.output) { + switch (output.type) { + case 'message': + if (output.content[0].type === 'output_text') { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: output.content[0].text + }) + if (output.content[0].annotations && output.content[0].annotations.length > 0) { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + source: WebSearchSource.OPENAI_RESPONSE, + results: output.content[0].annotations + } + }) + } + } + break + case 'reasoning': + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: output.summary.map((s) => s.text).join('\n') + }) + break + 
case 'function_call': + toolCalls.push(output) + break + case 'image_generation_call': + controller.enqueue({ + type: ChunkType.IMAGE_CREATED + }) + controller.enqueue({ + type: ChunkType.IMAGE_COMPLETE, + image: { + type: 'base64', + images: [`data:image/png;base64,${output.result}`] + } + }) + } + } + } else { + switch (chunk.type) { + case 'response.output_item.added': + if (chunk.item.type === 'function_call') { + outputItems.push(chunk.item) + } + break + case 'response.reasoning_summary_text.delta': + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: chunk.delta + }) + break + case 'response.image_generation_call.generating': + controller.enqueue({ + type: ChunkType.IMAGE_CREATED + }) + break + case 'response.image_generation_call.partial_image': + controller.enqueue({ + type: ChunkType.IMAGE_DELTA, + image: { + type: 'base64', + images: [`data:image/png;base64,${chunk.partial_image_b64}`] + } + }) + break + case 'response.image_generation_call.completed': + controller.enqueue({ + type: ChunkType.IMAGE_COMPLETE + }) + break + case 'response.output_text.delta': { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: chunk.delta + }) + break + } + case 'response.function_call_arguments.done': { + const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find( + (item) => item.id === chunk.item_id + ) + if (outputItem) { + if (outputItem.type === 'function_call') { + toolCalls.push({ + ...outputItem, + arguments: chunk.arguments + }) + } + } + break + } + case 'response.content_part.done': { + if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + source: WebSearchSource.OPENAI_RESPONSE, + results: chunk.part.annotations + } + }) + } + if (toolCalls.length > 0) { + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: toolCalls + }) + } + break + } + case 
'response.completed': { + const completion_tokens = chunk.response.usage?.output_tokens || 0 + const total_tokens = chunk.response.usage?.total_tokens || 0 + controller.enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: chunk.response.usage?.input_tokens || 0, + completion_tokens: completion_tokens, + total_tokens: total_tokens + } + } + }) + break + } + case 'error': { + controller.enqueue({ + type: ChunkType.ERROR, + error: { + message: chunk.message, + code: chunk.code + } + }) + break + } + } + } + } + }) + } +} diff --git a/src/renderer/src/aiCore/clients/types.ts b/src/renderer/src/aiCore/clients/types.ts new file mode 100644 index 0000000000..84562a13e9 --- /dev/null +++ b/src/renderer/src/aiCore/clients/types.ts @@ -0,0 +1,129 @@ +import Anthropic from '@anthropic-ai/sdk' +import { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types' +import { Provider } from '@renderer/types' +import { + AnthropicSdkRawChunk, + OpenAISdkRawChunk, + SdkMessageParam, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' +import OpenAI from 'openai' + +import { CompletionsParams, GenericChunk } from '../middleware/schemas' + +/** + * 原始流监听器接口 + */ +export interface RawStreamListener { + onChunk?: (chunk: TRawChunk) => void + onStart?: () => void + onEnd?: () => void + onError?: (error: Error) => void +} + +/** + * OpenAI 专用的流监听器 + */ +export interface OpenAIStreamListener extends RawStreamListener { + onChoice?: (choice: OpenAI.Chat.Completions.ChatCompletionChunk.Choice) => void + onFinishReason?: (reason: string) => void +} + +/** + * Anthropic 专用的流监听器 + */ +export interface AnthropicStreamListener + extends RawStreamListener { + onContentBlock?: (contentBlock: Anthropic.Messages.ContentBlock) => void + onMessage?: (message: Anthropic.Messages.Message) => void +} + +/** + * 请求转换器接口 + */ +export interface RequestTransformer< + TSdkParams extends SdkParams = 
SdkParams, + TMessageParam extends SdkMessageParam = SdkMessageParam +> { + transform( + completionsParams: CompletionsParams, + assistant: Assistant, + model: Model, + isRecursiveCall?: boolean, + recursiveSdkMessages?: TMessageParam[] + ): Promise<{ + payload: TSdkParams + messages: TMessageParam[] + metadata?: Record + }> +} + +/** + * 响应块转换器接口 + */ +export type ResponseChunkTransformer = ( + context?: TContext +) => Transformer + +export interface ResponseChunkTransformerContext { + isStreaming: boolean + isEnabledToolCalling: boolean + isEnabledWebSearch: boolean + isEnabledReasoning: boolean + mcpTools: MCPTool[] + provider: Provider +} + +/** + * API客户端接口 + */ +export interface ApiClient< + TSdkInstance = any, + TSdkParams extends SdkParams = SdkParams, + TRawOutput extends SdkRawOutput = SdkRawOutput, + TRawChunk extends SdkRawChunk = SdkRawChunk, + TMessageParam extends SdkMessageParam = SdkMessageParam, + TToolCall extends SdkToolCall = SdkToolCall, + TSdkSpecificTool extends SdkTool = SdkTool +> { + provider: Provider + + // 核心方法 - 在中间件架构中,这个方法可能只是一个占位符 + // 实际的SDK调用由SdkCallMiddleware处理 + // completions(params: CompletionsParams): Promise + + createCompletions(payload: TSdkParams): Promise + + // SDK相关方法 + getSdkInstance(): Promise | TSdkInstance + getRequestTransformer(): RequestTransformer + getResponseChunkTransformer(): ResponseChunkTransformer + + // 原始流监听方法 + attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener): TRawOutput + + // 工具转换相关方法 (保持可选,因为不是所有Provider都支持工具) + convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[] + convertMcpToolResponseToSdkMessageParam?( + mcpToolResponse: MCPToolResponse, + resp: any, + model: Model + ): TMessageParam | undefined + convertSdkToolCallToMcp?(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined + convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse + + // 构建SDK特定的消息列表,用于工具调用后的递归调用 + buildSdkMessages( + currentReqMessages: 
TMessageParam[], + output: TRawOutput | string, + toolResults: TMessageParam[], + toolCalls?: TToolCall[] + ): TMessageParam[] + + // 从SDK载荷中提取消息数组(用于中间件中的类型安全访问) + extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[] +} diff --git a/src/renderer/src/aiCore/index.ts b/src/renderer/src/aiCore/index.ts new file mode 100644 index 0000000000..c7bffa17de --- /dev/null +++ b/src/renderer/src/aiCore/index.ts @@ -0,0 +1,130 @@ +import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory' +import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' +import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models' +import type { GenerateImageParams, Model, Provider } from '@renderer/types' +import { RequestOptions, SdkModel } from '@renderer/types/sdk' +import { isEnabledToolUse } from '@renderer/utils/mcp-tools' + +import { OpenAIAPIClient } from './clients' +import { AihubmixAPIClient } from './clients/AihubmixAPIClient' +import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient' +import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient' +import { CompletionsMiddlewareBuilder } from './middleware/builder' +import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware' +import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware' +import { applyCompletionsMiddlewares } from './middleware/composer' +import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware' +import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware' +import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware' +import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware' +import { MIDDLEWARE_NAME as 
ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware' +import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware' +import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware' +import { MiddlewareRegistry } from './middleware/register' +import { CompletionsParams, CompletionsResult } from './middleware/schemas' + +export default class AiProvider { + private apiClient: BaseApiClient + + constructor(provider: Provider) { + // Use the new ApiClientFactory to get a BaseApiClient instance + this.apiClient = ApiClientFactory.create(provider) + } + + public async completions(params: CompletionsParams, options?: RequestOptions): Promise { + // 1. 根据模型识别正确的客户端 + const model = params.assistant.model + if (!model) { + return Promise.reject(new Error('Model is required')) + } + + // 根据client类型选择合适的处理方式 + let client: BaseApiClient + + if (this.apiClient instanceof AihubmixAPIClient) { + // AihubmixAPIClient: 根据模型选择合适的子client + client = this.apiClient.getClientForModel(model) + if (client instanceof OpenAIResponseAPIClient) { + client = client.getClient(model) as BaseApiClient + } + } else if (this.apiClient instanceof OpenAIResponseAPIClient) { + // OpenAIResponseAPIClient: 根据模型特征选择API类型 + client = this.apiClient.getClient(model) as BaseApiClient + } else { + // 其他client直接使用 + client = this.apiClient + } + + // 2. 
构建中间件链 + const builder = CompletionsMiddlewareBuilder.withDefaults() + // images api + if (isDedicatedImageGenerationModel(model)) { + builder.clear() + builder + .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName]) + .add(MiddlewareRegistry[AbortHandlerMiddlewareName]) + .add(MiddlewareRegistry[ImageGenerationMiddlewareName]) + } else { + // Existing logic for other models + if (!params.enableReasoning) { + builder.remove(ThinkingTagExtractionMiddlewareName) + builder.remove(ThinkChunkMiddlewareName) + } + // 注意:用client判断会导致typescript类型收窄 + if (!(this.apiClient instanceof OpenAIAPIClient)) { + builder.remove(ThinkingTagExtractionMiddlewareName) + } + if (!(this.apiClient instanceof AnthropicAPIClient)) { + builder.remove(RawStreamListenerMiddlewareName) + } + if (!params.enableWebSearch) { + builder.remove(WebSearchMiddlewareName) + } + if (!params.mcpTools?.length) { + builder.remove(ToolUseExtractionMiddlewareName) + builder.remove(McpToolChunkMiddlewareName) + } + if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) { + builder.remove(ToolUseExtractionMiddlewareName) + } + if (params.callType !== 'chat') { + builder.remove(AbortHandlerMiddlewareName) + } + } + + const middlewares = builder.build() + + // 3. Create the wrapped SDK method with middlewares + const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares) + + // 4. 
Execute the wrapped method with the original params + return wrappedCompletionMethod(params, options) + } + + public async models(): Promise { + return this.apiClient.listModels() + } + + public async getEmbeddingDimensions(model: Model): Promise { + try { + // Use the SDK instance to test embedding capabilities + const dimensions = await this.apiClient.getEmbeddingDimensions(model) + return dimensions + } catch (error) { + console.error('Error getting embedding dimensions:', error) + return 0 + } + } + + public async generateImage(params: GenerateImageParams): Promise { + return this.apiClient.generateImage(params) + } + + public getBaseURL(): string { + return this.apiClient.getBaseURL() + } + + public getApiKey(): string { + return this.apiClient.getApiKey() + } +} diff --git a/src/renderer/src/aiCore/middleware/BUILDER_USAGE.md b/src/renderer/src/aiCore/middleware/BUILDER_USAGE.md new file mode 100644 index 0000000000..27d9e32136 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/BUILDER_USAGE.md @@ -0,0 +1,182 @@ +# MiddlewareBuilder 使用指南 + +`MiddlewareBuilder` 是一个用于动态构建和管理中间件链的工具,提供灵活的中间件组织和配置能力。 + +## 主要特性 + +### 1. 统一的中间件命名 + +所有中间件都通过导出的 `MIDDLEWARE_NAME` 常量标识: + +```typescript +// 中间件文件示例 +export const MIDDLEWARE_NAME = 'SdkCallMiddleware' +export const SdkCallMiddleware: CompletionsMiddleware = ... +``` + +### 2. NamedMiddleware 接口 + +中间件使用统一的 `NamedMiddleware` 接口格式: + +```typescript +interface NamedMiddleware { + name: string + middleware: TMiddleware +} +``` + +### 3. 中间件注册表 + +通过 `MiddlewareRegistry` 集中管理所有可用中间件: + +```typescript +import { MiddlewareRegistry } from './register' + +// 通过名称获取中间件 +const sdkCallMiddleware = MiddlewareRegistry['SdkCallMiddleware'] +``` + +## 基本用法 + +### 1. 使用默认中间件链 + +```typescript +import { CompletionsMiddlewareBuilder } from './builder' + +const builder = CompletionsMiddlewareBuilder.withDefaults() +const middlewares = builder.build() +``` + +### 2. 
自定义中间件链 + +```typescript +import { createCompletionsBuilder, MiddlewareRegistry } from './builder' + +const builder = createCompletionsBuilder([ + MiddlewareRegistry['AbortHandlerMiddleware'], + MiddlewareRegistry['TextChunkMiddleware'] +]) + +const middlewares = builder.build() +``` + +### 3. 动态调整中间件链 + +```typescript +const builder = CompletionsMiddlewareBuilder.withDefaults() + +// 根据条件添加、移除、替换中间件 +if (needsLogging) { + builder.prepend(MiddlewareRegistry['GenericLoggingMiddleware']) +} + +if (disableTools) { + builder.remove('McpToolChunkMiddleware') +} + +if (customThinking) { + builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware) +} + +const middlewares = builder.build() +``` + +### 4. 链式操作 + +```typescript +const middlewares = CompletionsMiddlewareBuilder.withDefaults() + .add(MiddlewareRegistry['CustomMiddleware']) + .insertBefore('SdkCallMiddleware', MiddlewareRegistry['SecurityCheckMiddleware']) + .remove('WebSearchMiddleware') + .build() +``` + +## API 参考 + +### CompletionsMiddlewareBuilder + +**静态方法:** + +- `static withDefaults()`: 创建带有默认中间件链的构建器 + +**实例方法:** + +- `add(middleware: NamedMiddleware)`: 在链末尾添加中间件 +- `prepend(middleware: NamedMiddleware)`: 在链开头添加中间件 +- `insertAfter(targetName: string, middleware: NamedMiddleware)`: 在指定中间件后插入 +- `insertBefore(targetName: string, middleware: NamedMiddleware)`: 在指定中间件前插入 +- `replace(targetName: string, middleware: NamedMiddleware)`: 替换指定中间件 +- `remove(targetName: string)`: 移除指定中间件 +- `has(name: string)`: 检查是否包含指定中间件 +- `build()`: 构建最终的中间件数组 +- `getChain()`: 获取当前链(包含名称信息) +- `clear()`: 清空中间件链 +- `execute(context, params, middlewareExecutor)`: 直接执行构建好的中间件链 + +### 工厂函数 + +- `createCompletionsBuilder(baseChain?)`: 创建 Completions 中间件构建器 +- `createMethodBuilder(baseChain?)`: 创建通用方法中间件构建器 +- `addMiddlewareName(middleware, name)`: 为中间件添加名称属性的辅助函数 + +### 中间件注册表 + +- `MiddlewareRegistry`: 所有注册中间件的集中访问点 +- `getMiddleware(name)`: 根据名称获取中间件 +- `getRegisteredMiddlewareNames()`: 获取所有注册的中间件名称 +- 
`DefaultCompletionsNamedMiddlewares`: 默认的 Completions 中间件链(NamedMiddleware 格式) + +## 类型安全 + +构建器提供完整的 TypeScript 类型支持: + +- `CompletionsMiddlewareBuilder` 专门用于 `CompletionsMiddleware` 类型 +- `MethodMiddlewareBuilder` 用于通用的 `MethodMiddleware` 类型 +- 所有中间件操作都基于 `NamedMiddleware` 接口 + +## 默认中间件链 + +默认的 Completions 中间件执行顺序: + +1. `FinalChunkConsumerMiddleware` - 最终消费者 +2. `TransformCoreToSdkParamsMiddleware` - 参数转换 +3. `AbortHandlerMiddleware` - 中止处理 +4. `McpToolChunkMiddleware` - 工具处理 +5. `WebSearchMiddleware` - Web搜索处理 +6. `TextChunkMiddleware` - 文本处理 +7. `ThinkingTagExtractionMiddleware` - 思考标签提取处理 +8. `ThinkChunkMiddleware` - 思考处理 +9. `ResponseTransformMiddleware` - 响应转换 +10. `StreamAdapterMiddleware` - 流适配器 +11. `SdkCallMiddleware` - SDK调用 + +## 在 AiProvider 中的使用 + +```typescript +export default class AiProvider { + public async completions(params: CompletionsParams): Promise { + // 1. 构建中间件链 + const builder = CompletionsMiddlewareBuilder.withDefaults() + + // 2. 根据参数动态调整 + if (params.enableCustomFeature) { + builder.insertAfter('StreamAdapterMiddleware', customFeatureMiddleware) + } + + // 3. 应用中间件 + const middlewares = builder.build() + const wrappedMethod = applyCompletionsMiddlewares(this.apiClient, this.apiClient.createCompletions, middlewares) + + return wrappedMethod(params) + } +} +``` + +## 注意事项 + +1. **类型兼容性**:`MethodMiddleware` 和 `CompletionsMiddleware` 不兼容,需要使用对应的构建器 +2. **中间件名称**:所有中间件必须导出 `MIDDLEWARE_NAME` 常量用于标识 +3. **注册表管理**:新增中间件需要在 `register.ts` 中注册 +4. 
**默认链**:默认链通过 `DefaultCompletionsNamedMiddlewares` 提供,支持延迟加载避免循环依赖 + +这种设计使得中间件链的构建既灵活又类型安全,同时保持了简洁的 API 接口。 diff --git a/src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md b/src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md new file mode 100644 index 0000000000..6437282ff2 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md @@ -0,0 +1,175 @@ +# Cherry Studio 中间件规范 + +本文档定义了 Cherry Studio `aiCore` 模块中中间件的设计、实现和使用规范。目标是建立一个灵活、可维护且易于扩展的中间件系统。 + +## 1. 核心概念 + +### 1.1. 中间件 (Middleware) + +中间件是一个函数或对象,它在 AI 请求的处理流程中的特定阶段执行,可以访问和修改请求上下文 (`AiProviderMiddlewareContext`)、请求参数 (`Params`),并控制是否将请求传递给下一个中间件或终止流程。 + +每个中间件应该专注于一个单一的横切关注点,例如日志记录、错误处理、流适配、特性解析等。 + +### 1.2. `AiProviderMiddlewareContext` (上下文对象) + +这是一个在整个中间件链执行过程中传递的对象,包含以下核心信息: + +- `_apiClientInstance: ApiClient`: 当前选定的、已实例化的 AI Provider 客户端。 +- `_coreRequest: CoreRequestType`: 标准化的内部核心请求对象。 +- `resolvePromise: (value: AggregatedResultType) => void`: 用于在整个操作成功完成时解析 `AiCoreService` 返回的 Promise。 +- `rejectPromise: (reason?: any) => void`: 用于在发生错误时拒绝 `AiCoreService` 返回的 Promise。 +- `onChunk?: (chunk: Chunk) => void`: 应用层提供的流式数据块回调。 +- `abortController?: AbortController`: 用于中止请求的控制器。 +- 其他中间件可能读写的、与当前请求相关的动态数据。 + +### 1.3. 
`MiddlewareName` (中间件名称) + +为了方便动态操作(如插入、替换、移除)中间件,每个重要的、可能被其他逻辑引用的中间件都应该有一个唯一的、可识别的名称。推荐使用 TypeScript 的 `enum` 来定义: + +```typescript +// example +export enum MiddlewareName { + LOGGING_START = 'LoggingStartMiddleware', + LOGGING_END = 'LoggingEndMiddleware', + ERROR_HANDLING = 'ErrorHandlingMiddleware', + ABORT_HANDLER = 'AbortHandlerMiddleware', + // Core Flow + TRANSFORM_CORE_TO_SDK_PARAMS = 'TransformCoreToSdkParamsMiddleware', + REQUEST_EXECUTION = 'RequestExecutionMiddleware', + STREAM_ADAPTER = 'StreamAdapterMiddleware', + RAW_SDK_CHUNK_TO_APP_CHUNK = 'RawSdkChunkToAppChunkMiddleware', + // Features + THINKING_TAG_EXTRACTION = 'ThinkingTagExtractionMiddleware', + TOOL_USE_TAG_EXTRACTION = 'ToolUseTagExtractionMiddleware', + MCP_TOOL_HANDLER = 'McpToolHandlerMiddleware', + // Finalization + FINAL_CHUNK_CONSUMER = 'FinalChunkConsumerAndNotifierMiddleware' + // Add more as needed +} +``` + +中间件实例需要某种方式暴露其 `MiddlewareName`,例如通过一个 `name` 属性。 + +### 1.4. 中间件执行结构 + +我们采用一种灵活的中间件执行结构。一个中间件通常是一个函数,它接收 `Context`、`Params`,以及一个 `next` 函数(用于调用链中的下一个中间件)。 + +```typescript +// 简化形式的中间件函数签名 +type MiddlewareFunction = ( + context: AiProviderMiddlewareContext, + params: any, // e.g., CompletionsParams + next: () => Promise // next 通常返回 Promise 以支持异步操作 +) => Promise // 中间件自身也可能返回 Promise + +// 或者更经典的 Koa/Express 风格 (三段式) +// type MiddlewareFactory = (api?: MiddlewareApi) => +// (nextMiddleware: (ctx: AiProviderMiddlewareContext, params: any) => Promise) => +// (context: AiProviderMiddlewareContext, params: any) => Promise; +// 当前设计更倾向于上述简化的 MiddlewareFunction,由 MiddlewareExecutor 负责 next 的编排。 +``` + +`MiddlewareExecutor` (或 `applyMiddlewares`) 会负责管理 `next` 的调用。 + +## 2. `MiddlewareBuilder` (通用中间件构建器) + +为了动态构建和管理中间件链,我们引入一个通用的 `MiddlewareBuilder` 类。 + +### 2.1. 设计理念 + +`MiddlewareBuilder` 提供了一个流式 API,用于以声明式的方式构建中间件链。它允许从一个基础链开始,然后根据特定条件添加、插入、替换或移除中间件。 + +### 2.2. 
API 概览 + +```typescript +class MiddlewareBuilder { + constructor(baseChain?: Middleware[]) + + add(middleware: Middleware): this + prepend(middleware: Middleware): this + insertAfter(targetName: MiddlewareName, middlewareToInsert: Middleware): this + insertBefore(targetName: MiddlewareName, middlewareToInsert: Middleware): this + replace(targetName: MiddlewareName, newMiddleware: Middleware): this + remove(targetName: MiddlewareName): this + + build(): Middleware[] // 返回构建好的中间件数组 + + // 可选:直接执行链 + execute( + context: AiProviderMiddlewareContext, + params: any, + middlewareExecutor: (chain: Middleware[], context: AiProviderMiddlewareContext, params: any) => void + ): void +} +``` + +### 2.3. 使用示例 + +```typescript +// 1. 定义一些中间件实例 (假设它们有 .name 属性) +const loggingStart = { name: MiddlewareName.LOGGING_START, fn: loggingStartFn } +const requestExec = { name: MiddlewareName.REQUEST_EXECUTION, fn: requestExecFn } +const streamAdapter = { name: MiddlewareName.STREAM_ADAPTER, fn: streamAdapterFn } +const customFeature = { name: MiddlewareName.CUSTOM_FEATURE, fn: customFeatureFn } // 假设自定义 + +// 2. 定义一个基础链 (可选) +const BASE_CHAIN: Middleware[] = [loggingStart, requestExec, streamAdapter] + +// 3. 使用 MiddlewareBuilder +const builder = new MiddlewareBuilder(BASE_CHAIN) + +if (params.needsCustomFeature) { + builder.insertAfter(MiddlewareName.STREAM_ADAPTER, customFeature) +} + +if (params.isHighSecurityContext) { + builder.insertBefore(MiddlewareName.REQUEST_EXECUTION, highSecurityCheckMiddleware) +} + +if (params.overrideLogging) { + builder.replace(MiddlewareName.LOGGING_START, newSpecialLoggingMiddleware) +} + +// 4. 获取最终链 +const finalChain = builder.build() + +// 5. 执行 (通过外部执行器) +// middlewareExecutor(finalChain, context, params); +// 或者 builder.execute(context, params, middlewareExecutor); +``` + +## 3. `MiddlewareExecutor` / `applyMiddlewares` (中间件执行器) + +这是负责接收 `MiddlewareBuilder` 构建的中间件链并实际执行它们的组件。 + +### 3.1. 
职责 + +- 接收 `Middleware[]`, `AiProviderMiddlewareContext`, `Params`。 +- 按顺序迭代中间件。 +- 为每个中间件提供正确的 `next` 函数,该函数在被调用时会执行链中的下一个中间件。 +- 处理中间件执行过程中的Promise(如果中间件是异步的)。 +- 基础的错误捕获(具体错误处理应由链内的 `ErrorHandlingMiddleware` 负责)。 + +## 4. 在 `AiCoreService` 中使用 + +`AiCoreService` 中的每个核心业务方法 (如 `executeCompletions`) 将负责: + +1. 准备基础数据:实例化 `ApiClient`,转换 `Params` 为 `CoreRequest`。 +2. 实例化 `MiddlewareBuilder`,可能会传入一个特定于该业务方法的基础中间件链。 +3. 根据 `Params` 和 `CoreRequest` 中的条件,调用 `MiddlewareBuilder` 的方法来动态调整中间件链。 +4. 调用 `MiddlewareBuilder.build()` 获取最终的中间件链。 +5. 创建完整的 `AiProviderMiddlewareContext` (包含 `resolvePromise`, `rejectPromise` 等)。 +6. 调用 `MiddlewareExecutor` (或 `applyMiddlewares`) 来执行构建好的链。 + +## 5. 组合功能 + +对于组合功能(例如 "Completions then Translate"): + +- 不推荐创建一个单一、庞大的 `MiddlewareBuilder` 来处理整个组合流程。 +- 推荐在 `AiCoreService` 中创建一个新的方法,该方法按顺序 `await` 调用底层的原子 `AiCoreService` 方法(例如,先 `await this.executeCompletions(...)`,然后用其结果 `await this.translateText(...)`)。 +- 每个被调用的原子方法内部会使用其自身的 `MiddlewareBuilder` 实例来构建和执行其特定阶段的中间件链。 +- 这种方式最大化了复用,并保持了各部分职责的清晰。 + +## 6. 
中间件命名和发现 + +为中间件赋予唯一的 `MiddlewareName` 对于 `MiddlewareBuilder` 的 `insertAfter`, `insertBefore`, `replace`, `remove` 等操作至关重要。确保中间件实例能够以某种方式暴露其名称(例如,一个 `name` 属性)。 diff --git a/src/renderer/src/aiCore/middleware/builder.ts b/src/renderer/src/aiCore/middleware/builder.ts new file mode 100644 index 0000000000..e76b59c2bd --- /dev/null +++ b/src/renderer/src/aiCore/middleware/builder.ts @@ -0,0 +1,241 @@ +import { DefaultCompletionsNamedMiddlewares } from './register' +import { BaseContext, CompletionsMiddleware, MethodMiddleware } from './types' + +/** + * 带有名称标识的中间件接口 + */ +export interface NamedMiddleware { + name: string + middleware: TMiddleware +} + +/** + * 中间件执行器函数类型 + */ +export type MiddlewareExecutor = ( + chain: any[], + context: TContext, + params: any +) => Promise + +/** + * 通用中间件构建器类 + * 提供流式 API 用于动态构建和管理中间件链 + * + * 注意:所有中间件都通过 MiddlewareRegistry 管理,使用 NamedMiddleware 格式 + */ +export class MiddlewareBuilder { + private middlewares: NamedMiddleware[] + + /** + * 构造函数 + * @param baseChain - 可选的基础中间件链(NamedMiddleware 格式) + */ + constructor(baseChain?: NamedMiddleware[]) { + this.middlewares = baseChain ? 
[...baseChain] : [] + } + + /** + * 在链的末尾添加中间件 + * @param middleware - 要添加的具名中间件 + * @returns this,支持链式调用 + */ + add(middleware: NamedMiddleware): this { + this.middlewares.push(middleware) + return this + } + + /** + * 在链的开头添加中间件 + * @param middleware - 要添加的具名中间件 + * @returns this,支持链式调用 + */ + prepend(middleware: NamedMiddleware): this { + this.middlewares.unshift(middleware) + return this + } + + /** + * 在指定中间件之后插入新中间件 + * @param targetName - 目标中间件名称 + * @param middlewareToInsert - 要插入的具名中间件 + * @returns this,支持链式调用 + */ + insertAfter(targetName: string, middlewareToInsert: NamedMiddleware): this { + const index = this.findMiddlewareIndex(targetName) + if (index !== -1) { + this.middlewares.splice(index + 1, 0, middlewareToInsert) + } else { + console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`) + } + return this + } + + /** + * 在指定中间件之前插入新中间件 + * @param targetName - 目标中间件名称 + * @param middlewareToInsert - 要插入的具名中间件 + * @returns this,支持链式调用 + */ + insertBefore(targetName: string, middlewareToInsert: NamedMiddleware): this { + const index = this.findMiddlewareIndex(targetName) + if (index !== -1) { + this.middlewares.splice(index, 0, middlewareToInsert) + } else { + console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`) + } + return this + } + + /** + * 替换指定的中间件 + * @param targetName - 要替换的中间件名称 + * @param newMiddleware - 新的具名中间件 + * @returns this,支持链式调用 + */ + replace(targetName: string, newMiddleware: NamedMiddleware): this { + const index = this.findMiddlewareIndex(targetName) + if (index !== -1) { + this.middlewares[index] = newMiddleware + } else { + console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法替换`) + } + return this + } + + /** + * 移除指定的中间件 + * @param targetName - 要移除的中间件名称 + * @returns this,支持链式调用 + */ + remove(targetName: string): this { + const index = this.findMiddlewareIndex(targetName) + if (index !== -1) { + this.middlewares.splice(index, 1) + } + return this + } + + /** + * 构建最终的中间件数组 + * @returns 构建好的中间件数组 
+ */ + build(): TMiddleware[] { + return this.middlewares.map((item) => item.middleware) + } + + /** + * 获取当前中间件链的副本(包含名称信息) + * @returns 当前中间件链的副本 + */ + getChain(): NamedMiddleware[] { + return [...this.middlewares] + } + + /** + * 检查是否包含指定名称的中间件 + * @param name - 中间件名称 + * @returns 是否包含该中间件 + */ + has(name: string): boolean { + return this.findMiddlewareIndex(name) !== -1 + } + + /** + * 获取中间件链的长度 + * @returns 中间件数量 + */ + get length(): number { + return this.middlewares.length + } + + /** + * 清空中间件链 + * @returns this,支持链式调用 + */ + clear(): this { + this.middlewares = [] + return this + } + + /** + * 直接执行构建好的中间件链 + * @param context - 中间件上下文 + * @param params - 参数 + * @param middlewareExecutor - 中间件执行器 + * @returns 执行结果 + */ + execute( + context: TContext, + params: any, + middlewareExecutor: MiddlewareExecutor + ): Promise { + const chain = this.build() + return middlewareExecutor(chain, context, params) + } + + /** + * 查找中间件在链中的索引 + * @param name - 中间件名称 + * @returns 索引,如果未找到返回 -1 + */ + private findMiddlewareIndex(name: string): number { + return this.middlewares.findIndex((item) => item.name === name) + } +} + +/** + * Completions 中间件构建器 + */ +export class CompletionsMiddlewareBuilder extends MiddlewareBuilder { + constructor(baseChain?: NamedMiddleware[]) { + super(baseChain) + } + + /** + * 使用默认的 Completions 中间件链 + * @returns CompletionsMiddlewareBuilder 实例 + */ + static withDefaults(): CompletionsMiddlewareBuilder { + return new CompletionsMiddlewareBuilder(DefaultCompletionsNamedMiddlewares) + } +} + +/** + * 通用方法中间件构建器 + */ +export class MethodMiddlewareBuilder extends MiddlewareBuilder { + constructor(baseChain?: NamedMiddleware[]) { + super(baseChain) + } +} + +// 便捷的工厂函数 + +/** + * 创建 Completions 中间件构建器 + * @param baseChain - 可选的基础链 + * @returns Completions 中间件构建器实例 + */ +export function createCompletionsBuilder( + baseChain?: NamedMiddleware[] +): CompletionsMiddlewareBuilder { + return new CompletionsMiddlewareBuilder(baseChain) +} + +/** + * 
创建通用方法中间件构建器 + * @param baseChain - 可选的基础链 + * @returns 通用方法中间件构建器实例 + */ +export function createMethodBuilder(baseChain?: NamedMiddleware[]): MethodMiddlewareBuilder { + return new MethodMiddlewareBuilder(baseChain) +} + +/** + * 为中间件添加名称属性的辅助函数 + * 可以用于给现有的中间件添加名称属性 + */ +export function addMiddlewareName(middleware: T, name: string): T & { MIDDLEWARE_NAME: string } { + return Object.assign(middleware, { MIDDLEWARE_NAME: name }) +} diff --git a/src/renderer/src/aiCore/middleware/common/AbortHandlerMiddleware.ts b/src/renderer/src/aiCore/middleware/common/AbortHandlerMiddleware.ts new file mode 100644 index 0000000000..7186cec12f --- /dev/null +++ b/src/renderer/src/aiCore/middleware/common/AbortHandlerMiddleware.ts @@ -0,0 +1,106 @@ +import { Chunk, ChunkType, ErrorChunk } from '@renderer/types/chunk' +import { addAbortController, removeAbortController } from '@renderer/utils/abortController' + +import { CompletionsParams, CompletionsResult } from '../schemas' +import type { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'AbortHandlerMiddleware' + +export const AbortHandlerMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + const isRecursiveCall = ctx._internal?.toolProcessingState?.isRecursiveCall || false + + // 在递归调用中,跳过 AbortController 的创建,直接使用已有的 + if (isRecursiveCall) { + const result = await next(ctx, params) + return result + } + + // 获取当前消息的ID用于abort管理 + // 优先使用处理过的消息,如果没有则使用原始消息 + let messageId: string | undefined + + if (typeof params.messages === 'string') { + messageId = `message-${Date.now()}-${Math.random().toString(36).substring(2, 9)}` + } else { + const processedMessages = params.messages + const lastUserMessage = processedMessages.findLast((m) => m.role === 'user') + messageId = lastUserMessage?.id + } + + if (!messageId) { + console.warn(`[${MIDDLEWARE_NAME}] No messageId found, abort functionality will not be 
available.`) + return next(ctx, params) + } + + const abortController = new AbortController() + const abortFn = (): void => abortController.abort() + + addAbortController(messageId, abortFn) + + let abortSignal: AbortSignal | null = abortController.signal + + const cleanup = (): void => { + removeAbortController(messageId as string, abortFn) + if (ctx._internal?.flowControl) { + ctx._internal.flowControl.abortController = undefined + ctx._internal.flowControl.abortSignal = undefined + ctx._internal.flowControl.cleanup = undefined + } + abortSignal = null + } + + // 将controller添加到_internal中的flowControl状态 + if (!ctx._internal.flowControl) { + ctx._internal.flowControl = {} + } + ctx._internal.flowControl.abortController = abortController + ctx._internal.flowControl.abortSignal = abortSignal + ctx._internal.flowControl.cleanup = cleanup + + const result = await next(ctx, params) + + const error = new DOMException('Request was aborted', 'AbortError') + + const streamWithAbortHandler = (result.stream as ReadableStream).pipeThrough( + new TransformStream({ + transform(chunk, controller) { + // 检查 abort 状态 + if (abortSignal?.aborted) { + // 转换为 ErrorChunk + const errorChunk: ErrorChunk = { + type: ChunkType.ERROR, + error + } + + controller.enqueue(errorChunk) + cleanup() + return + } + + // 正常传递 chunk + controller.enqueue(chunk) + }, + + flush(controller) { + // 在流结束时再次检查 abort 状态 + if (abortSignal?.aborted) { + const errorChunk: ErrorChunk = { + type: ChunkType.ERROR, + error + } + controller.enqueue(errorChunk) + } + // 在流完全处理完成后清理 AbortController + cleanup() + } + }) + ) + + return { + ...result, + stream: streamWithAbortHandler + } + } diff --git a/src/renderer/src/aiCore/middleware/common/ErrorHandlerMiddleware.ts b/src/renderer/src/aiCore/middleware/common/ErrorHandlerMiddleware.ts new file mode 100644 index 0000000000..2dd5aa9833 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/common/ErrorHandlerMiddleware.ts @@ -0,0 +1,60 @@ +import { Chunk } from 
'@renderer/types/chunk' +import { isAbortError } from '@renderer/utils/error' + +import { CompletionsResult } from '../schemas' +import { CompletionsContext } from '../types' +import { createErrorChunk } from '../utils' + +export const MIDDLEWARE_NAME = 'ErrorHandlerMiddleware' + +/** + * 创建一个错误处理中间件。 + * + * 这是一个高阶函数,它接收配置并返回一个标准的中间件。 + * 它的主要职责是捕获下游中间件或API调用中发生的任何错误。 + * + * @param config - 中间件的配置。 + * @returns 一个配置好的CompletionsMiddleware。 + */ +export const ErrorHandlerMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params): Promise => { + const { shouldThrow } = params + + try { + // 尝试执行下一个中间件 + return await next(ctx, params) + } catch (error: any) { + let errorStream: ReadableStream | undefined + // 有些sdk的abort error 是直接抛出的 + if (!isAbortError(error)) { + // 1. 使用通用的工具函数将错误解析为标准格式 + const errorChunk = createErrorChunk(error) + // 2. 调用从外部传入的 onError 回调 + if (params.onError) { + params.onError(error) + } + + // 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递 + if (shouldThrow) { + throw error + } + + // 如果不抛出,则创建一个只包含该错误块的流并向下传递 + errorStream = new ReadableStream({ + start(controller) { + controller.enqueue(errorChunk) + controller.close() + } + }) + } + + return { + rawOutput: undefined, + stream: errorStream, // 将包含错误的流传递下去 + controller: undefined, + getText: () => '' // 错误情况下没有文本结果 + } + } + } diff --git a/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts b/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts new file mode 100644 index 0000000000..b0b9bd7ce6 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts @@ -0,0 +1,183 @@ +import Logger from '@renderer/config/logger' +import { Usage } from '@renderer/types' +import type { Chunk } from '@renderer/types/chunk' +import { ChunkType } from '@renderer/types/chunk' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + 
+export const MIDDLEWARE_NAME = 'FinalChunkConsumerAndNotifierMiddleware' + +/** + * 最终Chunk消费和通知中间件 + * + * 职责: + * 1. 消费所有GenericChunk流中的chunks并转发给onChunk回调 + * 2. 累加usage/metrics数据(从原始SDK chunks或GenericChunk中提取) + * 3. 在检测到LLM_RESPONSE_COMPLETE时发送包含累计数据的BLOCK_COMPLETE + * 4. 处理MCP工具调用的多轮请求中的数据累加 + */ +const FinalChunkConsumerMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + const isRecursiveCall = + params._internal?.toolProcessingState?.isRecursiveCall || + ctx._internal?.toolProcessingState?.isRecursiveCall || + false + + // 初始化累计数据(只在顶层调用时初始化) + if (!isRecursiveCall) { + if (!ctx._internal.customState) { + ctx._internal.customState = {} + } + ctx._internal.observer = { + usage: { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0 + }, + metrics: { + completion_tokens: 0, + time_completion_millsec: 0, + time_first_token_millsec: 0, + time_thinking_millsec: 0 + } + } + // 初始化文本累积器 + ctx._internal.customState.accumulatedText = '' + ctx._internal.customState.startTimestamp = Date.now() + } + + // 调用下游中间件 + const result = await next(ctx, params) + + // 响应后处理:处理GenericChunk流式响应 + if (result.stream) { + const resultFromUpstream = result.stream + + if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { + const reader = resultFromUpstream.getReader() + + try { + while (true) { + const { done, value: chunk } = await reader.read() + if (done) { + Logger.debug(`[${MIDDLEWARE_NAME}] Input stream finished.`) + break + } + + if (chunk) { + const genericChunk = chunk as GenericChunk + // 提取并累加usage/metrics数据 + extractAndAccumulateUsageMetrics(ctx, genericChunk) + + const shouldSkipChunk = + isRecursiveCall && + (genericChunk.type === ChunkType.BLOCK_COMPLETE || + genericChunk.type === ChunkType.LLM_RESPONSE_COMPLETE) + + if (!shouldSkipChunk) params.onChunk?.(genericChunk) + } else { + Logger.warn(`[${MIDDLEWARE_NAME}] Received undefined chunk before stream was 
done.`) + } + } + } catch (error) { + Logger.error(`[${MIDDLEWARE_NAME}] Error consuming stream:`, error) + throw error + } finally { + if (params.onChunk && !isRecursiveCall) { + params.onChunk({ + type: ChunkType.BLOCK_COMPLETE, + response: { + usage: ctx._internal.observer?.usage ? { ...ctx._internal.observer.usage } : undefined, + metrics: ctx._internal.observer?.metrics ? { ...ctx._internal.observer.metrics } : undefined + } + } as Chunk) + if (ctx._internal.toolProcessingState) { + ctx._internal.toolProcessingState = {} + } + } + } + + // 为流式输出添加getText方法 + const modifiedResult = { + ...result, + stream: new ReadableStream({ + start(controller) { + controller.close() + } + }), + getText: () => { + return ctx._internal.customState?.accumulatedText || '' + } + } + + return modifiedResult + } else { + Logger.debug(`[${MIDDLEWARE_NAME}] No GenericChunk stream to process.`) + } + } + + return result + } + +/** + * 从GenericChunk或原始SDK chunks中提取usage/metrics数据并累加 + */ +function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: GenericChunk): void { + if (!ctx._internal.observer?.usage || !ctx._internal.observer?.metrics) { + return + } + + try { + if (ctx._internal.customState && !ctx._internal.customState?.firstTokenTimestamp) { + ctx._internal.customState.firstTokenTimestamp = Date.now() + Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`) + } + if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) { + Logger.debug(`[${MIDDLEWARE_NAME}] LLM_RESPONSE_COMPLETE chunk received:`, ctx._internal) + // 从LLM_RESPONSE_COMPLETE chunk中提取usage数据 + if (chunk.response?.usage) { + accumulateUsage(ctx._internal.observer.usage, chunk.response.usage) + } + + if (ctx._internal.customState && ctx._internal.customState?.firstTokenTimestamp) { + ctx._internal.observer.metrics.time_first_token_millsec = + ctx._internal.customState.firstTokenTimestamp - ctx._internal.customState.startTimestamp + 
ctx._internal.observer.metrics.time_completion_millsec += + Date.now() - ctx._internal.customState.firstTokenTimestamp + } + } + + // 也可以从其他chunk类型中提取metrics数据 + if (chunk.type === ChunkType.THINKING_COMPLETE && chunk.thinking_millsec && ctx._internal.observer?.metrics) { + ctx._internal.observer.metrics.time_thinking_millsec = Math.max( + ctx._internal.observer.metrics.time_thinking_millsec || 0, + chunk.thinking_millsec + ) + } + } catch (error) { + console.error(`[${MIDDLEWARE_NAME}] Error extracting usage/metrics from chunk:`, error) + } +} + +/** + * 累加usage数据 + */ +function accumulateUsage(accumulated: Usage, newUsage: Usage): void { + if (newUsage.prompt_tokens !== undefined) { + accumulated.prompt_tokens += newUsage.prompt_tokens + } + if (newUsage.completion_tokens !== undefined) { + accumulated.completion_tokens += newUsage.completion_tokens + } + if (newUsage.total_tokens !== undefined) { + accumulated.total_tokens += newUsage.total_tokens + } + if (newUsage.thoughts_tokens !== undefined) { + accumulated.thoughts_tokens = (accumulated.thoughts_tokens || 0) + newUsage.thoughts_tokens + } +} + +export default FinalChunkConsumerMiddleware diff --git a/src/renderer/src/aiCore/middleware/common/LoggingMiddleware.ts b/src/renderer/src/aiCore/middleware/common/LoggingMiddleware.ts new file mode 100644 index 0000000000..361eea3119 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/common/LoggingMiddleware.ts @@ -0,0 +1,64 @@ +import { BaseContext, MethodMiddleware, MiddlewareAPI } from '../types' + +export const MIDDLEWARE_NAME = 'GenericLoggingMiddlewares' + +/** + * Helper function to safely stringify arguments for logging, handling circular references and large objects. + * 安全地字符串化日志参数的辅助函数,处理循环引用和大型对象。 + * @param args - The arguments array to stringify. 要字符串化的参数数组。 + * @returns A string representation of the arguments. 
参数的字符串表示形式。 + */ +const stringifyArgsForLogging = (args: any[]): string => { + try { + return args + .map((arg) => { + if (typeof arg === 'function') return '[Function]' + if (typeof arg === 'object' && arg !== null && arg.constructor === Object && Object.keys(arg).length > 20) { + return '[Object with >20 keys]' + } + // Truncate long strings to avoid flooding logs 截断长字符串以避免日志泛滥 + const stringifiedArg = JSON.stringify(arg, null, 2) + return stringifiedArg && stringifiedArg.length > 200 ? stringifiedArg.substring(0, 200) + '...' : stringifiedArg + }) + .join(', ') + } catch (e) { + return '[Error serializing arguments]' // Handle potential errors during stringification 处理字符串化期间的潜在错误 + } +} + +/** + * Generic logging middleware for provider methods. + * 为提供者方法创建一个通用的日志中间件。 + * This middleware logs the initiation, success/failure, and duration of a method call. + * 此中间件记录方法调用的启动、成功/失败以及持续时间。 + */ + +/** + * Creates a generic logging middleware for provider methods. + * 为提供者方法创建一个通用的日志中间件。 + * @returns A `MethodMiddleware` instance. 一个 `MethodMiddleware` 实例。 + */ +export const createGenericLoggingMiddleware: () => MethodMiddleware = () => { + const middlewareName = 'GenericLoggingMiddleware' + // eslint-disable-next-line @typescript-eslint/no-unused-vars + return (_: MiddlewareAPI) => (next) => async (ctx, args) => { + const methodName = ctx.methodName + const logPrefix = `[${middlewareName} (${methodName})]` + console.log(`${logPrefix} Initiating. Args:`, stringifyArgsForLogging(args)) + const startTime = Date.now() + try { + const result = await next(ctx, args) + const duration = Date.now() - startTime + // Log successful completion of the method call with duration. / + // 记录方法调用成功完成及其持续时间。 + console.log(`${logPrefix} Successful. Duration: ${duration}ms`) + return result + } catch (error) { + const duration = Date.now() - startTime + // Log failure of the method call with duration and error information. 
/ + // 记录方法调用失败及其持续时间和错误信息。 + console.error(`${logPrefix} Failed. Duration: ${duration}ms`, error) + throw error // Re-throw the error to be handled by subsequent layers or the caller / 重新抛出错误,由后续层或调用者处理 + } + } +} diff --git a/src/renderer/src/aiCore/middleware/composer.ts b/src/renderer/src/aiCore/middleware/composer.ts new file mode 100644 index 0000000000..8b93b8015a --- /dev/null +++ b/src/renderer/src/aiCore/middleware/composer.ts @@ -0,0 +1,285 @@ +import { + RequestOptions, + SdkInstance, + SdkMessageParam, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' + +import { BaseApiClient } from '../clients' +import { CompletionsParams, CompletionsResult } from './schemas' +import { + BaseContext, + CompletionsContext, + CompletionsMiddleware, + MethodMiddleware, + MIDDLEWARE_CONTEXT_SYMBOL, + MiddlewareAPI +} from './types' + +/** + * Creates the initial context for a method call, populating method-specific fields. / + * 为方法调用创建初始上下文,并填充特定于该方法的字段。 + * @param methodName - The name of the method being called. / 被调用的方法名。 + * @param originalCallArgs - The actual arguments array from the proxy/method call. / 代理/方法调用的实际参数数组。 + * @param providerId - The ID of the provider, if available. / 提供者的ID(如果可用)。 + * @param providerInstance - The instance of the provider. / 提供者实例。 + * @param specificContextFactory - An optional factory function to create a specific context type from the base context and original call arguments. / 一个可选的工厂函数,用于从基础上下文和原始调用参数创建特定的上下文类型。 + * @returns The created context object. 
/ 创建的上下文对象。 + */ +function createInitialCallContext( + methodName: string, + originalCallArgs: TCallArgs, // Renamed from originalArgs to avoid confusion with context.originalArgs + // Factory to create specific context from base and the *original call arguments array* + specificContextFactory?: (base: BaseContext, callArgs: TCallArgs) => TContext +): TContext { + const baseContext: BaseContext = { + [MIDDLEWARE_CONTEXT_SYMBOL]: true, + methodName, + originalArgs: originalCallArgs // Store the full original arguments array in the context + } + + if (specificContextFactory) { + return specificContextFactory(baseContext, originalCallArgs) + } + return baseContext as TContext // Fallback to base context if no specific factory +} + +/** + * Composes an array of functions from right to left. / + * 从右到左组合一个函数数组。 + * `compose(f, g, h)` is `(...args) => f(g(h(...args)))`. / + * `compose(f, g, h)` 等同于 `(...args) => f(g(h(...args)))`。 + * Each function in funcs is expected to take the result of the next function + * (or the initial value for the rightmost function) as its argument. / + * `funcs` 中的每个函数都期望接收下一个函数的结果(或最右侧函数的初始值)作为其参数。 + * @param funcs - Array of functions to compose. / 要组合的函数数组。 + * @returns The composed function. / 组合后的函数。 + */ +function compose(...funcs: Array<(...args: any[]) => any>): (...args: any[]) => any { + if (funcs.length === 0) { + // If no functions to compose, return a function that returns its first argument, or undefined if no args. / + // 如果没有要组合的函数,则返回一个函数,该函数返回其第一个参数,如果没有参数则返回undefined。 + return (...args: any[]) => (args.length > 0 ? args[0] : undefined) + } + if (funcs.length === 1) { + return funcs[0] + } + return funcs.reduce( + (a, b) => + (...args: any[]) => + a(b(...args)) + ) +} + +/** + * Applies an array of Redux-style middlewares to a generic provider method. / + * 将一组Redux风格的中间件应用于一个通用的提供者方法。 + * This version keeps arguments as an array throughout the middleware chain. 
/ + * 此版本在整个中间件链中将参数保持为数组形式。 + * @param originalProviderInstance - The original provider instance. / 原始提供者实例。 + * @param methodName - The name of the method to be enhanced. / 需要增强的方法名。 + * @param originalMethod - The original method to be wrapped. / 需要包装的原始方法。 + * @param middlewares - An array of `ProviderMethodMiddleware` to apply. / 要应用的 `ProviderMethodMiddleware` 数组。 + * @param specificContextFactory - An optional factory to create a specific context for this method. / 可选的工厂函数,用于为此方法创建特定的上下文。 + * @returns An enhanced method with the middlewares applied. / 应用了中间件的增强方法。 + */ +export function applyMethodMiddlewares< + TArgs extends unknown[] = unknown[], // Original method's arguments array type / 原始方法的参数数组类型 + TResult = unknown, + TContext extends BaseContext = BaseContext +>( + methodName: string, + originalMethod: (...args: TArgs) => Promise, + middlewares: MethodMiddleware[], // Expects generic middlewares / 期望通用中间件 + specificContextFactory?: (base: BaseContext, callArgs: TArgs) => TContext +): (...args: TArgs) => Promise { + // Returns a function matching the original method signature. / + // 返回一个与原始方法签名匹配的函数。 + return async function enhancedMethod(...methodCallArgs: TArgs): Promise { + const ctx = createInitialCallContext( + methodName, + methodCallArgs, // Pass the actual call arguments array / 传递实际的调用参数数组 + specificContextFactory + ) + + const api: MiddlewareAPI = { + getContext: () => ctx, + getOriginalArgs: () => methodCallArgs // API provides the original arguments array / API提供原始参数数组 + } + + // `finalDispatch` is the function that will ultimately call the original provider method. / + // `finalDispatch` 是最终将调用原始提供者方法的函数。 + // It receives the current context and arguments, which may have been transformed by middlewares. 
/ 它接收当前的上下文和参数,这些参数可能已被中间件转换。 + const finalDispatch = async ( + _: TContext, + currentArgs: TArgs // Generic final dispatch expects args array / 通用finalDispatch期望参数数组 + ): Promise => { + return originalMethod(...currentArgs) // Spread the stored call args; `.apply(currentArgs)` would pass them as `this` and invoke with zero arguments / 展开存储的调用参数;`.apply(currentArgs)` 会把参数数组当作 `this` 且不传任何参数 + } + + const chain = middlewares.map((middleware) => middleware(api)) // Cast API if TContext/TArgs mismatch general ProviderMethodMiddleware / 如果TContext/TArgs与通用的ProviderMethodMiddleware不匹配,则转换API + const composedMiddlewareLogic = compose(...chain) + const enhancedDispatch = composedMiddlewareLogic(finalDispatch) + + return enhancedDispatch(ctx, methodCallArgs) // Pass context and original args array / 传递上下文和原始参数数组 + } +} + +/** + * Applies an array of `CompletionsMiddleware` to the `completions` method. / + * 将一组 `CompletionsMiddleware` 应用于 `completions` 方法。 + * This version adapts for `CompletionsMiddleware` expecting a single `params` object. / + * 此版本适配了期望单个 `params` 对象的 `CompletionsMiddleware`。 + * @param originalProviderInstance - The original provider instance. / 原始提供者实例。 + * @param originalCompletionsMethod - The original SDK `createCompletions` method. / 原始的 SDK `createCompletions` 方法。 + * @param middlewares - An array of `CompletionsMiddleware` to apply. / 要应用的 `CompletionsMiddleware` 数组。 + * @returns An enhanced `completions` method with the middlewares applied. 
/ 应用了中间件的增强版 `completions` 方法。
 */
export function applyCompletionsMiddlewares<
  TSdkInstance extends SdkInstance = SdkInstance,
  TSdkParams extends SdkParams = SdkParams,
  TRawOutput extends SdkRawOutput = SdkRawOutput,
  TRawChunk extends SdkRawChunk = SdkRawChunk,
  TMessageParam extends SdkMessageParam = SdkMessageParam,
  TToolCall extends SdkToolCall = SdkToolCall,
  TSdkSpecificTool extends SdkTool = SdkTool
>(
  originalApiClientInstance: BaseApiClient<
    TSdkInstance,
    TSdkParams,
    TRawOutput,
    TRawChunk,
    TMessageParam,
    TToolCall,
    TSdkSpecificTool
  >,
  // NOTE(review): return-type generic was lost in extraction; presumably Promise<TRawOutput> — confirm.
  originalCompletionsMethod: (payload: TSdkParams, options?: RequestOptions) => Promise<TRawOutput>,
  middlewares: CompletionsMiddleware<
    TSdkParams,
    TMessageParam,
    TToolCall,
    TSdkInstance,
    TRawOutput,
    TRawChunk,
    TSdkSpecificTool
  >[]
): (params: CompletionsParams, options?: RequestOptions) => Promise<CompletionsResult> {
  // Returns a function matching the original method signature. /
  // 返回一个与原始方法签名匹配的函数。

  const methodName = 'completions'

  // Factory to create AiProviderMiddlewareCompletionsContext. /
  // 用于创建 AiProviderMiddlewareCompletionsContext 的工厂函数。
  const completionsContextFactory = (
    base: BaseContext,
    callArgs: [CompletionsParams]
  ): CompletionsContext<
    TSdkParams,
    TMessageParam,
    TToolCall,
    TSdkInstance,
    TRawOutput,
    TRawChunk,
    TSdkSpecificTool
  > => {
    return {
      ...base,
      methodName,
      apiClientInstance: originalApiClientInstance,
      originalArgs: callArgs,
      _internal: {
        // Fresh per-call tool-recursion bookkeeping and observer state.
        toolProcessingState: {
          recursionDepth: 0,
          isRecursiveCall: false
        },
        observer: {}
      }
    }
  }

  return async function enhancedCompletionsMethod(
    params: CompletionsParams,
    options?: RequestOptions
  ): Promise<CompletionsResult> {
    // `originalCallArgs` for context creation is `[params]`. /
    // 用于上下文创建的 `originalCallArgs` 是 `[params]`。
    const originalCallArgs: [CompletionsParams] = [params]
    const baseContext: BaseContext = {
      [MIDDLEWARE_CONTEXT_SYMBOL]: true,
      methodName,
      originalArgs: originalCallArgs
    }
    const ctx = completionsContextFactory(baseContext, originalCallArgs)

    const api: MiddlewareAPI<
      CompletionsContext,
      [CompletionsParams]
    > = {
      getContext: () => ctx,
      getOriginalArgs: () => originalCallArgs // API provides [CompletionsParams] / API提供 `[CompletionsParams]`
    }

    // `finalDispatch` for CompletionsMiddleware: expects (context, params) not (context, args_array). /
    // `CompletionsMiddleware` 的 `finalDispatch`:期望 (context, params) 而不是 (context, args_array)。
    const finalDispatch = async (
      context: CompletionsContext<
        TSdkParams,
        TMessageParam,
        TToolCall,
        TSdkInstance,
        TRawOutput,
        TRawChunk,
        TSdkSpecificTool
      > // Context passed through / 上下文透传
      // _currentParams: CompletionsParams // Directly takes params / 直接接收参数 (unused but required for middleware signature)
    ): Promise<CompletionsResult> => {
      // At this point, middleware should have transformed CompletionsParams to SDK params
      // and stored them in context. If no transformation happened, we need to handle it.
      // 此时,中间件应该已经将 CompletionsParams 转换为 SDK 参数并存储在上下文中。
      // 如果没有进行转换,我们需要处理它。

      const sdkPayload = context._internal?.sdkPayload
      if (!sdkPayload) {
        throw new Error('SDK payload not found in context. Middleware chain should have transformed parameters.')
      }

      const abortSignal = context._internal.flowControl?.abortSignal
      const timeout = context._internal.customState?.sdkMetadata?.timeout

      // Call the original SDK method with transformed parameters
      // 使用转换后的参数调用原始 SDK 方法
      const rawOutput = await originalCompletionsMethod.call(originalApiClientInstance, sdkPayload, {
        ...options,
        signal: abortSignal,
        timeout
      })

      // Return result wrapped in CompletionsResult format
      // 以 CompletionsResult 格式返回包装的结果
      return {
        rawOutput
      } as CompletionsResult
    }

    const chain = middlewares.map((middleware) => middleware(api))
    const composedMiddlewareLogic = compose(...chain)

    // `enhancedDispatch` has the signature `(context, params) => Promise<CompletionsResult>`. /
    // `enhancedDispatch` 的签名为 `(context, params) => Promise<CompletionsResult>`。
    const enhancedDispatch = composedMiddlewareLogic(finalDispatch)

    // 将 enhancedDispatch 保存到 context 中,供中间件进行递归调用
    // 这样可以避免重复执行整个中间件链
    ctx._internal.enhancedDispatch = enhancedDispatch

    // Execute with context and the single params object. /
    // 使用上下文和单个参数对象执行。
    return enhancedDispatch(ctx, params)
  }
}
diff --git a/src/renderer/src/aiCore/middleware/core/McpToolChunkMiddleware.ts b/src/renderer/src/aiCore/middleware/core/McpToolChunkMiddleware.ts
new file mode 100644
index 0000000000..3dd046c12e
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/core/McpToolChunkMiddleware.ts
@@ -0,0 +1,306 @@
import Logger from '@renderer/config/logger'
import { MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
import { parseAndCallTools } from '@renderer/utils/mcp-tools'

import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'McpToolChunkMiddleware'
const MAX_TOOL_RECURSION_DEPTH = 20 // 防止无限递归

/**
 * MCP工具处理中间件
 *
 * 职责:
 * 1. 检测并拦截MCP工具进展chunk(Function Call方式和Tool Use方式)
 * 2. 执行工具调用
 * 3. 递归处理工具结果
 * 4.
管理工具调用状态和递归深度
 */
export const McpToolChunkMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    const mcpTools = params.mcpTools || []

    // 如果没有工具,直接调用下一个中间件
    if (!mcpTools || mcpTools.length === 0) {
      return next(ctx, params)
    }

    // Drives one model round-trip; re-entered (depth > 0) after tool results are appended.
    const executeWithToolHandling = async (currentParams: CompletionsParams, depth = 0): Promise<CompletionsResult> => {
      if (depth >= MAX_TOOL_RECURSION_DEPTH) {
        Logger.error(`🔧 [${MIDDLEWARE_NAME}] Maximum recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
        throw new Error(`Maximum tool recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
      }

      let result: CompletionsResult

      if (depth === 0) {
        result = await next(ctx, currentParams)
      } else {
        // Recursive rounds re-enter the whole middleware chain via the dispatch
        // saved on the context by applyCompletionsMiddlewares.
        const enhancedCompletions = ctx._internal.enhancedDispatch
        if (!enhancedCompletions) {
          Logger.error(`🔧 [${MIDDLEWARE_NAME}] Enhanced completions method not found, cannot perform recursive call`)
          throw new Error('Enhanced completions method not found')
        }

        ctx._internal.toolProcessingState!.isRecursiveCall = true
        ctx._internal.toolProcessingState!.recursionDepth = depth

        result = await enhancedCompletions(ctx, currentParams)
      }

      if (!result.stream) {
        Logger.error(`🔧 [${MIDDLEWARE_NAME}] No stream returned from enhanced completions`)
        throw new Error('No stream returned from enhanced completions')
      }

      const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
      const toolHandlingStream = resultFromUpstream.pipeThrough(
        createToolHandlingTransform(ctx, currentParams, mcpTools, depth, executeWithToolHandling)
      )

      return {
        ...result,
        stream: toolHandlingStream
      }
    }

    return executeWithToolHandling(params, 0)
  }

/**
 * 创建工具处理的 TransformStream
 * Buffers tool-call chunks during `transform` and executes them in `flush`,
 * i.e. only after the upstream model turn has fully ended.
 */
function createToolHandlingTransform(
  ctx: CompletionsContext,
  currentParams: CompletionsParams,
  mcpTools: MCPTool[],
  depth: number,
  executeWithToolHandling: (params: CompletionsParams, depth: number) => Promise<CompletionsResult>
): TransformStream<GenericChunk, GenericChunk> {
  const toolCalls: SdkToolCall[] = []
  const toolUseResponses: MCPToolResponse[] = []
  const allToolResponses: MCPToolResponse[] = [] // 统一的工具响应状态管理数组
  let hasToolCalls = false
  let hasToolUseResponses = false
  let streamEnded = false

  return new TransformStream({
    async transform(chunk: GenericChunk, controller) {
      try {
        // 处理MCP工具进展chunk
        if (chunk.type === ChunkType.MCP_TOOL_CREATED) {
          const createdChunk = chunk as MCPToolCreatedChunk

          // 1. 处理Function Call方式的工具调用
          if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
            toolCalls.push(...createdChunk.tool_calls)
            hasToolCalls = true
          }

          // 2. 处理Tool Use方式的工具调用
          if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
            toolUseResponses.push(...createdChunk.tool_use_responses)
            hasToolUseResponses = true
          }

          // 不转发MCP工具进展chunks,避免重复处理
          return
        }

        // 转发其他所有chunk
        controller.enqueue(chunk)
      } catch (error) {
        console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
        controller.error(error)
      }
    },

    async flush(controller) {
      const shouldExecuteToolCalls = hasToolCalls && toolCalls.length > 0
      const shouldExecuteToolUseResponses = hasToolUseResponses && toolUseResponses.length > 0

      if (!streamEnded && (shouldExecuteToolCalls || shouldExecuteToolUseResponses)) {
        streamEnded = true

        try {
          let toolResult: SdkMessageParam[] = []

          if (shouldExecuteToolCalls) {
            toolResult = await executeToolCalls(
              ctx,
              toolCalls,
              mcpTools,
              allToolResponses,
              currentParams.onChunk,
              currentParams.assistant.model!
            )
          } else if (shouldExecuteToolUseResponses) {
            toolResult = await executeToolUseResponses(
              ctx,
              toolUseResponses,
              mcpTools,
              allToolResponses,
              currentParams.onChunk,
              currentParams.assistant.model!
            )
          }

          if (toolResult.length > 0) {
            const output = ctx._internal.toolProcessingState?.output

            // Kick off the next model round with the tool results appended.
            const newParams = buildParamsWithToolResults(ctx, currentParams, output!, toolResult, toolCalls)
            await executeWithToolHandling(newParams, depth + 1)
          }
        } catch (error) {
          console.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
          controller.error(error)
        } finally {
          hasToolCalls = false
          hasToolUseResponses = false
        }
      }
    }
  })
}

/**
 * 执行工具调用(Function Call 方式)
 */
async function executeToolCalls(
  ctx: CompletionsContext,
  toolCalls: SdkToolCall[],
  mcpTools: MCPTool[],
  allToolResponses: MCPToolResponse[],
  onChunk: CompletionsParams['onChunk'],
  model: Model
): Promise<SdkMessageParam[]> {
  // 转换为MCPToolResponse格式
  const mcpToolResponses: ToolCallResponse[] = toolCalls
    .map((toolCall) => {
      const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
      if (!mcpTool) {
        return undefined
      }
      return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
    })
    .filter((t): t is ToolCallResponse => typeof t !== 'undefined')

  if (mcpToolResponses.length === 0) {
    console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`)
    return []
  }

  // 使用现有的parseAndCallTools函数执行工具
  const toolResults = await parseAndCallTools(
    mcpToolResponses,
    allToolResponses,
    onChunk,
    (mcpToolResponse, resp, model) => {
      return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
    },
    model,
    mcpTools
  )

  return toolResults
}

/**
 * 执行工具使用响应(Tool Use Response 方式)
 * 处理已经解析好的 ToolUseResponse[],不需要重新解析字符串
 */
async function executeToolUseResponses(
  ctx: CompletionsContext,
  toolUseResponses: MCPToolResponse[],
  mcpTools: MCPTool[],
  allToolResponses: MCPToolResponse[],
  onChunk: CompletionsParams['onChunk'],
  model: Model
): Promise<SdkMessageParam[]> {
  // 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse
  const toolResults = await
parseAndCallTools(
    toolUseResponses,
    allToolResponses,
    onChunk,
    (mcpToolResponse, resp, model) => {
      return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
    },
    model,
    mcpTools
  )

  return toolResults
}

/**
 * 构建包含工具结果的新参数
 */
function buildParamsWithToolResults(
  ctx: CompletionsContext,
  currentParams: CompletionsParams,
  output: SdkRawOutput | string,
  toolResults: SdkMessageParam[],
  toolCalls: SdkToolCall[]
): CompletionsParams {
  // 获取当前已经转换好的reqMessages,如果没有则使用原始messages
  const currentReqMessages = getCurrentReqMessages(ctx)

  const apiClient = ctx.apiClientInstance

  // 从回复中构建助手消息
  const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)

  // 估算新增消息的 token 消耗并累加到 usage 中
  if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) {
    try {
      const newMessages = newReqMessages.slice(currentReqMessages.length)
      const additionalTokens = newMessages.reduce((acc, message) => {
        return acc + ctx.apiClientInstance.estimateMessageTokens(message)
      }, 0)

      if (additionalTokens > 0) {
        ctx._internal.observer.usage.prompt_tokens += additionalTokens
        ctx._internal.observer.usage.total_tokens += additionalTokens
      }
    } catch (error) {
      // Best-effort estimation: never let accounting break the tool loop.
      Logger.error(`🔧 [${MIDDLEWARE_NAME}] Error estimating token usage for new messages:`, error)
    }
  }

  // 更新递归状态
  if (!ctx._internal.toolProcessingState) {
    ctx._internal.toolProcessingState = {}
  }
  ctx._internal.toolProcessingState.isRecursiveCall = true
  ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState?.recursionDepth || 0) + 1

  return {
    ...currentParams,
    _internal: {
      ...ctx._internal,
      sdkPayload: ctx._internal.sdkPayload,
      newReqMessages: newReqMessages
    }
  }
}

/**
 * 类型安全地获取当前请求消息
 * 使用API客户端提供的抽象方法,保持中间件的provider无关性
 */
function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] {
  const sdkPayload = ctx._internal.sdkPayload
  if (!sdkPayload) {
    return []
  }

  // 使用API客户端的抽象方法来提取消息,保持provider无关性
  return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload)
}

export default McpToolChunkMiddleware
diff --git a/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts b/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts
new file mode 100644
index 0000000000..36c1693b3a
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts
@@ -0,0 +1,48 @@
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'

import { AnthropicStreamListener } from '../../clients/types'
import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'RawStreamListenerMiddleware'

// Attaches a provider-specific listener to the untouched SDK output so the
// final assistant message can be captured for the tool-recursion state.
export const RawStreamListenerMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    const result = await next(ctx, params)

    // 在这里可以监听到从SDK返回的最原始流
    if (result.rawOutput) {
      console.log(`[${MIDDLEWARE_NAME}] 检测到原始SDK输出,准备附加监听器`)

      const providerType = ctx.apiClientInstance.provider.type
      // TODO: 后面下放到AnthropicAPIClient
      if (providerType === 'anthropic') {
        // NOTE(review): generic argument was lost in extraction; presumably
        // AnthropicStreamListener<AnthropicSdkRawChunk> (the import is otherwise unused) — confirm.
        const anthropicListener: AnthropicStreamListener<AnthropicSdkRawChunk> = {
          onMessage: (message) => {
            if (ctx._internal?.toolProcessingState) {
              ctx._internal.toolProcessingState.output = message
            }
          }
          // onContentBlock: (contentBlock) => {
          //   console.log(`[${MIDDLEWARE_NAME}] 📝 Anthropic content block:`, contentBlock.type)
          // }
        }

        const specificApiClient = ctx.apiClientInstance as AnthropicAPIClient

        const monitoredOutput = specificApiClient.attachRawStreamListener(
          result.rawOutput as AnthropicSdkRawOutput,
          anthropicListener
        )
        return {
...result,
          rawOutput: monitoredOutput
        }
      }
    }

    return result
  }
diff --git a/src/renderer/src/aiCore/middleware/core/ResponseTransformMiddleware.ts b/src/renderer/src/aiCore/middleware/core/ResponseTransformMiddleware.ts
new file mode 100644
index 0000000000..fbb3bac198
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/core/ResponseTransformMiddleware.ts
@@ -0,0 +1,85 @@
import Logger from '@renderer/config/logger'
import { SdkRawChunk } from '@renderer/types/sdk'

import { ResponseChunkTransformerContext } from '../../clients/types'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'ResponseTransformMiddleware'

/**
 * 响应转换中间件
 *
 * 职责:
 * 1. 检测ReadableStream类型的响应流
 * 2. 使用ApiClient的getResponseChunkTransformer()将原始SDK响应块转换为通用格式
 * 3. 将转换后的ReadableStream保存到ctx._internal.apiCall.genericChunkStream,供下游中间件使用
 *
 * 注意:此中间件应该在StreamAdapterMiddleware之后执行
 */
export const ResponseTransformMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    // 调用下游中间件
    const result = await next(ctx, params)

    // 响应后处理:转换原始SDK响应块
    if (result.stream) {
      const adaptedStream = result.stream

      // 处理ReadableStream类型的流
      if (adaptedStream instanceof ReadableStream) {
        const apiClient = ctx.apiClientInstance
        if (!apiClient) {
          console.error(`[${MIDDLEWARE_NAME}] ApiClient instance not found in context`)
          throw new Error('ApiClient instance not found in context')
        }

        // 获取响应转换器
        const responseChunkTransformer = apiClient.getResponseChunkTransformer?.()
        if (!responseChunkTransformer) {
          Logger.warn(`[${MIDDLEWARE_NAME}] No ResponseChunkTransformer available, skipping transformation`)
          return result
        }

        const assistant = params.assistant
        const model = assistant?.model

        if (!assistant || !model) {
          console.error(`[${MIDDLEWARE_NAME}] Assistant or Model not found for transformation`)
          throw new Error('Assistant or Model not found for transformation')
        }

        const transformerContext: ResponseChunkTransformerContext = {
          isStreaming: params.streamOutput || false,
          isEnabledToolCalling: (params.mcpTools && params.mcpTools.length > 0) || false,
          isEnabledWebSearch: params.enableWebSearch || false,
          isEnabledReasoning: params.enableReasoning || false,
          mcpTools: params.mcpTools || [],
          provider: ctx.apiClientInstance?.provider
        }

        console.log(`[${MIDDLEWARE_NAME}] Transforming raw SDK chunks with context:`, transformerContext)

        try {
          // 创建转换后的流
          const genericChunkTransformStream = (adaptedStream as ReadableStream<SdkRawChunk>).pipeThrough(
            new TransformStream(responseChunkTransformer(transformerContext))
          )

          // 将转换后的ReadableStream保存到result,供下游中间件使用
          return {
            ...result,
            stream: genericChunkTransformStream
          }
        } catch (error) {
          Logger.error(`[${MIDDLEWARE_NAME}] Error during chunk transformation:`, error)
          throw error
        }
      }
    }

    // 如果没有流或不是ReadableStream,返回原始结果
    return result
  }
diff --git a/src/renderer/src/aiCore/middleware/core/StreamAdapterMiddleware.ts b/src/renderer/src/aiCore/middleware/core/StreamAdapterMiddleware.ts
new file mode 100644
index 0000000000..118d96e035
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/core/StreamAdapterMiddleware.ts
@@ -0,0 +1,57 @@
import { SdkRawChunk } from '@renderer/types/sdk'
import { asyncGeneratorToReadableStream, createSingleChunkReadableStream } from '@renderer/utils/stream'

import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
import { isAsyncIterable } from '../utils'

export const MIDDLEWARE_NAME = 'StreamAdapterMiddleware'

/**
 * 流适配器中间件
 *
 * 职责:
 * 1. 检测ctx._internal.apiCall.rawSdkOutput(优先)或原始AsyncIterable流
 * 2. 将AsyncIterable转换为WHATWG ReadableStream
 * 3.
更新响应结果中的stream + * + * 注意:如果ResponseTransformMiddleware已处理过,会优先使用transformedStream + */ +export const StreamAdapterMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + // TODO:调用开始,因为这个是最靠近接口请求的地方,next执行代表着开始接口请求了 + // 但是这个中间件的职责是流适配,是否在这调用优待商榷 + // 调用下游中间件 + const result = await next(ctx, params) + + if ( + result.rawOutput && + !(result.rawOutput instanceof ReadableStream) && + isAsyncIterable(result.rawOutput) + ) { + const whatwgReadableStream: ReadableStream = asyncGeneratorToReadableStream( + result.rawOutput + ) + return { + ...result, + stream: whatwgReadableStream + } + } else if (result.rawOutput && result.rawOutput instanceof ReadableStream) { + return { + ...result, + stream: result.rawOutput + } + } else if (result.rawOutput) { + // 非流式输出,强行变为可读流 + const whatwgReadableStream: ReadableStream = createSingleChunkReadableStream( + result.rawOutput + ) + return { + ...result, + stream: whatwgReadableStream + } + } + return result + } diff --git a/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts b/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts new file mode 100644 index 0000000000..2a3255356f --- /dev/null +++ b/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts @@ -0,0 +1,99 @@ +import Logger from '@renderer/config/logger' +import { ChunkType, TextDeltaChunk } from '@renderer/types/chunk' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'TextChunkMiddleware' + +/** + * 文本块处理中间件 + * + * 职责: + * 1. 累积文本内容(TEXT_DELTA) + * 2. 对文本内容进行智能链接转换 + * 3. 生成TEXT_COMPLETE事件 + * 4. 暂存Web搜索结果,用于最终链接完善 + * 5. 
处理 onResponse 回调,实时发送文本更新和最终完整文本 + */ +export const TextChunkMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + // 调用下游中间件 + const result = await next(ctx, params) + + // 响应后处理:转换流式响应中的文本内容 + if (result.stream) { + const resultFromUpstream = result.stream as ReadableStream + + if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { + const assistant = params.assistant + const model = params.assistant?.model + + if (!assistant || !model) { + Logger.warn(`[${MIDDLEWARE_NAME}] Missing assistant or model information, skipping text processing`) + return result + } + + // 用于跨chunk的状态管理 + let accumulatedTextContent = '' + let hasEnqueue = false + const enhancedTextStream = resultFromUpstream.pipeThrough( + new TransformStream({ + transform(chunk: GenericChunk, controller) { + if (chunk.type === ChunkType.TEXT_DELTA) { + const textChunk = chunk as TextDeltaChunk + accumulatedTextContent += textChunk.text + + // 处理 onResponse 回调 - 发送增量文本更新 + if (params.onResponse) { + params.onResponse(accumulatedTextContent, false) + } + + // 创建新的chunk,包含处理后的文本 + controller.enqueue(chunk) + } else if (accumulatedTextContent) { + if (chunk.type !== ChunkType.LLM_RESPONSE_COMPLETE) { + controller.enqueue(chunk) + hasEnqueue = true + } + const finalText = accumulatedTextContent + ctx._internal.customState!.accumulatedText = finalText + if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) { + ctx._internal.toolProcessingState.output = finalText + } + + // 处理 onResponse 回调 - 发送最终完整文本 + if (params.onResponse) { + params.onResponse(finalText, true) + } + + controller.enqueue({ + type: ChunkType.TEXT_COMPLETE, + text: finalText + }) + accumulatedTextContent = '' + if (!hasEnqueue) { + controller.enqueue(chunk) + } + } else { + // 其他类型的chunk直接传递 + controller.enqueue(chunk) + } + } + }) + ) + + // 更新响应结果 + return { + ...result, + stream: enhancedTextStream + } + } else 
{ + Logger.warn(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream. Returning original result.`) + } + } + + return result + } diff --git a/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts b/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts new file mode 100644 index 0000000000..b0df8313a5 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts @@ -0,0 +1,101 @@ +import Logger from '@renderer/config/logger' +import { ChunkType, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'ThinkChunkMiddleware' + +/** + * 处理思考内容的中间件 + * + * 注意:从 v2 版本开始,流结束语义的判断已移至 ApiClient 层处理 + * 此中间件现在主要负责: + * 1. 处理原始SDK chunk中的reasoning字段 + * 2. 计算准确的思考时间 + * 3. 在思考内容结束时生成THINKING_COMPLETE事件 + * + * 职责: + * 1. 累积思考内容(THINKING_DELTA) + * 2. 监听流结束信号,生成THINKING_COMPLETE事件 + * 3. 
计算准确的思考时间
 *
 */
export const ThinkChunkMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise => {
    // 调用下游中间件
    const result = await next(ctx, params)

    // 响应后处理:处理思考内容
    if (result.stream) {
      const resultFromUpstream = result.stream as ReadableStream

      // 检查是否启用reasoning
      const enableReasoning = params.enableReasoning || false
      if (!enableReasoning) {
        return result
      }

      // 检查是否有流需要处理
      if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
        // thinking 处理状态
        let accumulatedThinkingContent = ''
        let hasThinkingContent = false
        let thinkingStartTime = 0

        const processedStream = resultFromUpstream.pipeThrough(
          new TransformStream({
            transform(chunk: GenericChunk, controller) {
              if (chunk.type === ChunkType.THINKING_DELTA) {
                const thinkingChunk = chunk as ThinkingDeltaChunk

                // 第一次接收到思考内容时记录开始时间
                if (!hasThinkingContent) {
                  hasThinkingContent = true
                  thinkingStartTime = Date.now()
                }

                accumulatedThinkingContent += thinkingChunk.text

                // 更新思考时间并传递
                const enhancedChunk: ThinkingDeltaChunk = {
                  ...thinkingChunk,
                  thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
                }
                controller.enqueue(enhancedChunk)
              } else if (hasThinkingContent && thinkingStartTime > 0) {
                // 收到任何非THINKING_DELTA的chunk时,如果有累积的思考内容,生成THINKING_COMPLETE
                const thinkingCompleteChunk: ThinkingCompleteChunk = {
                  type: ChunkType.THINKING_COMPLETE,
                  text: accumulatedThinkingContent,
                  thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
                }
                controller.enqueue(thinkingCompleteChunk)
                // Reset so a later thinking segment starts a fresh timer.
                hasThinkingContent = false
                accumulatedThinkingContent = ''
                thinkingStartTime = 0

                // 继续传递当前chunk
                controller.enqueue(chunk)
              } else {
                // 其他情况直接传递
                controller.enqueue(chunk)
              }
            }
          })
        )

        // 更新响应结果
        return {
          ...result,
          stream: processedStream
        }
      } else {
        Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
      }
    }

    return result
  }
diff --git a/src/renderer/src/aiCore/middleware/core/TransformCoreToSdkParamsMiddleware.ts b/src/renderer/src/aiCore/middleware/core/TransformCoreToSdkParamsMiddleware.ts
new file mode 100644
index 0000000000..12c7a27acd
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/core/TransformCoreToSdkParamsMiddleware.ts
@@ -0,0 +1,83 @@
import Logger from '@renderer/config/logger'
import { ChunkType } from '@renderer/types/chunk'

import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'TransformCoreToSdkParamsMiddleware'

/**
 * 中间件:将CoreCompletionsRequest转换为SDK特定的参数
 * 使用上下文中ApiClient实例的requestTransformer进行转换
 */
export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise => {
    Logger.debug(`🔄 [${MIDDLEWARE_NAME}] Starting core to SDK params transformation:`, ctx)

    const internal = ctx._internal

    // 🔧 检测递归调用:检查 params 中是否携带了预处理的 SDK 消息
    const isRecursiveCall = internal?.toolProcessingState?.isRecursiveCall || false
    const newSdkMessages = params._internal?.newReqMessages

    const apiClient = ctx.apiClientInstance

    if (!apiClient) {
      Logger.error(`🔄 [${MIDDLEWARE_NAME}] ApiClient instance not found in context.`)
      throw new Error('ApiClient instance not found in context')
    }

    // 检查是否有requestTransformer方法
    const requestTransformer =
apiClient.getRequestTransformer()
    if (!requestTransformer) {
      Logger.warn(
        `🔄 [${MIDDLEWARE_NAME}] ApiClient does not have getRequestTransformer method, skipping transformation`
      )
      const result = await next(ctx, params)
      return result
    }

    // 确保assistant和model可用,它们是transformer所需的
    const assistant = params.assistant
    const model = params.assistant.model

    if (!assistant || !model) {
      console.error(`🔄 [${MIDDLEWARE_NAME}] Assistant or Model not found for transformation.`)
      throw new Error('Assistant or Model not found for transformation')
    }

    try {
      const transformResult = await requestTransformer.transform(
        params,
        assistant,
        model,
        isRecursiveCall,
        newSdkMessages
      )

      const { payload: sdkPayload, metadata } = transformResult

      // 将SDK特定的payload和metadata存储在状态中,供下游中间件使用
      ctx._internal.sdkPayload = sdkPayload

      if (metadata) {
        ctx._internal.customState = {
          ...ctx._internal.customState,
          sdkMetadata: metadata
        }
      }

      if (params.enableGenerateImage) {
        params.onChunk?.({
          type: ChunkType.IMAGE_CREATED
        })
      }
      return next(ctx, params)
    } catch (error) {
      Logger.error(`🔄 [${MIDDLEWARE_NAME}] Error during request transformation:`, error)
      // 让错误向上传播,或者可以在这里进行特定的错误处理
      throw error
    }
  }
diff --git a/src/renderer/src/aiCore/middleware/core/WebSearchMiddleware.ts b/src/renderer/src/aiCore/middleware/core/WebSearchMiddleware.ts
new file mode 100644
index 0000000000..97261e3d52
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/core/WebSearchMiddleware.ts
@@ -0,0 +1,76 @@
import { ChunkType } from '@renderer/types/chunk'
import { smartLinkConverter } from '@renderer/utils/linkConverter'

import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'WebSearchMiddleware'

/**
 * Web搜索处理中间件 - 基于GenericChunk流处理
 *
 * 职责:
 * 1. 监听和记录Web搜索事件
 * 2. 可以在此处添加Web搜索结果的后处理逻辑
 * 3.
维护Web搜索相关的状态
 *
 * 注意:Web搜索结果的识别和生成已在ApiClient的响应转换器中处理
 */
export const WebSearchMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise => {
    ctx._internal.webSearchState = {
      results: undefined
    }
    // 调用下游中间件
    const result = await next(ctx, params)

    const model = params.assistant?.model!
    let isFirstChunk = true

    // 响应后处理:记录Web搜索事件
    if (result.stream) {
      const resultFromUpstream = result.stream

      if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
        // Web搜索状态跟踪
        const enhancedStream = (resultFromUpstream as ReadableStream).pipeThrough(
          new TransformStream({
            transform(chunk: GenericChunk, controller) {
              if (chunk.type === ChunkType.TEXT_DELTA) {
                const providerType = model.provider || 'openai'
                // 使用当前可用的Web搜索结果进行链接转换
                const text = chunk.text
                const processedText = smartLinkConverter(text, providerType, isFirstChunk)
                if (isFirstChunk) {
                  isFirstChunk = false
                }
                controller.enqueue({
                  ...chunk,
                  text: processedText
                })
              } else if (chunk.type === ChunkType.LLM_WEB_SEARCH_COMPLETE) {
                // 暂存Web搜索结果用于链接完善
                ctx._internal.webSearchState!.results = chunk.llm_web_search

                // 将Web搜索完成事件继续传递下去
                controller.enqueue(chunk)
              } else {
                controller.enqueue(chunk)
              }
            }
          })
        )

        return {
          ...result,
          stream: enhancedStream
        }
      } else {
        console.log(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream.`)
      }
    }

    return result
  }
diff --git a/src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts b/src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts
new file mode 100644
index 0000000000..560a9e0aac
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts
@@ -0,0 +1,132 @@
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { ChunkType } from
'@renderer/types/chunk' +import { findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' +import OpenAI from 'openai' +import { toFile } from 'openai/uploads' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'ImageGenerationMiddleware' + +export const ImageGenerationMiddleware: CompletionsMiddleware = + () => + (next) => + async (context: CompletionsContext, params: CompletionsParams): Promise => { + const { assistant, messages } = params + const client = context.apiClientInstance as BaseApiClient + const signal = context._internal?.flowControl?.abortSignal + + if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') { + return next(context, params) + } + + const stream = new ReadableStream({ + async start(controller) { + const enqueue = (chunk: GenericChunk) => controller.enqueue(chunk) + + try { + if (!assistant.model) { + throw new Error('Assistant model is not defined.') + } + + const sdk = await client.getSdkInstance() + const lastUserMessage = messages.findLast((m) => m.role === 'user') + const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant') + + if (!lastUserMessage) { + throw new Error('No user message found for image generation.') + } + + const prompt = getMainTextContent(lastUserMessage) + let imageFiles: Blob[] = [] + + // Collect images from user message + const userImageBlocks = findImageBlocks(lastUserMessage) + const userImages = await Promise.all( + userImageBlocks.map(async (block) => { + if (!block.file) return null + const binaryData: Uint8Array = await window.api.file.binaryImage(block.file.id) + const mimeType = `${block.file.type}/${block.file.ext.slice(1)}` + return await toFile(new Blob([binaryData]), block.file.origin_name || 'image.png', { type: mimeType }) + }) + ) + imageFiles = 
imageFiles.concat(userImages.filter(Boolean) as Blob[]) + + // Collect images from last assistant message + if (lastAssistantMessage) { + const assistantImageBlocks = findImageBlocks(lastAssistantMessage) + const assistantImages = await Promise.all( + assistantImageBlocks.map(async (block) => { + const b64 = block.url?.replace(/^data:image\/\w+;base64,/, '') + if (!b64) return null + const binary = atob(b64) + const bytes = new Uint8Array(binary.length) + for (let i = 0; i < binary.length; i++) bytes[i] = binary.charCodeAt(i) + return await toFile(new Blob([bytes]), 'assistant_image.png', { type: 'image/png' }) + }) + ) + imageFiles = imageFiles.concat(assistantImages.filter(Boolean) as Blob[]) + } + + enqueue({ type: ChunkType.IMAGE_CREATED }) + + const startTime = Date.now() + let response: OpenAI.Images.ImagesResponse + + const options = { signal, timeout: 300_000 } + + if (imageFiles.length > 0) { + response = await sdk.images.edit( + { + model: assistant.model.id, + image: imageFiles, + prompt: prompt || '' + }, + options + ) + } else { + response = await sdk.images.generate( + { + model: assistant.model.id, + prompt: prompt || '', + response_format: assistant.model.id.includes('gpt-image-1') ? 
undefined : 'b64_json' + }, + options + ) + } + + const b64_json_array = response.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || [] + + enqueue({ + type: ChunkType.IMAGE_COMPLETE, + image: { type: 'base64', images: b64_json_array } + }) + + const usage = (response as any).usage || { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 } + + enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage, + metrics: { + completion_tokens: usage.completion_tokens, + time_first_token_millsec: 0, + time_completion_millsec: Date.now() - startTime + } + } + }) + } catch (error: any) { + enqueue({ type: ChunkType.ERROR, error }) + } finally { + controller.close() + } + } + }) + + return { + stream, + getText: () => '' + } + } diff --git a/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts b/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts new file mode 100644 index 0000000000..440de40045 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts @@ -0,0 +1,136 @@ +import { Model } from '@renderer/types' +import { ChunkType, TextDeltaChunk, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk' +import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction' +import Logger from 'electron-log/renderer' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'ThinkingTagExtractionMiddleware' + +// 不同模型的思考标签配置 +const reasoningTags: TagConfig[] = [ + { openingTag: '<think>', closingTag: '</think>', separator: '\n' }, + { openingTag: '###Thinking', closingTag: '###Response', separator: '\n' } +] + +const getAppropriateTag = (model?: Model): TagConfig => { + if (model?.id?.includes('qwen3')) return reasoningTags[0] + // 可以在这里添加更多模型特定的标签配置 + return reasoningTags[0] // 默认使用 <think> 标签 +} + +/** + * 处理文本流中思考标签提取的中间件 + * + *
该中间件专门处理文本流中的思考标签内容(如 ...) + * 主要用于 OpenAI 等支持思考标签的 provider + * + * 职责: + * 1. 从文本流中提取思考标签内容 + * 2. 将标签内的内容转换为 THINKING_DELTA chunk + * 3. 将标签外的内容作为正常文本输出 + * 4. 处理不同模型的思考标签格式 + * 5. 在思考内容结束时生成 THINKING_COMPLETE 事件 + */ +export const ThinkingTagExtractionMiddleware: CompletionsMiddleware = + () => + (next) => + async (context: CompletionsContext, params: CompletionsParams): Promise => { + // 调用下游中间件 + const result = await next(context, params) + + // 响应后处理:处理思考标签提取 + if (result.stream) { + const resultFromUpstream = result.stream as ReadableStream + + // 检查是否有流需要处理 + if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { + // 获取当前模型的思考标签配置 + const model = params.assistant?.model + const reasoningTag = getAppropriateTag(model) + + // 创建标签提取器 + const tagExtractor = new TagExtractor(reasoningTag) + + // thinking 处理状态 + let hasThinkingContent = false + let thinkingStartTime = 0 + + const processedStream = resultFromUpstream.pipeThrough( + new TransformStream({ + transform(chunk: GenericChunk, controller) { + if (chunk.type === ChunkType.TEXT_DELTA) { + const textChunk = chunk as TextDeltaChunk + + // 使用 TagExtractor 处理文本 + const extractionResults = tagExtractor.processText(textChunk.text) + + for (const extractionResult of extractionResults) { + if (extractionResult.complete && extractionResult.tagContentExtracted) { + // 生成 THINKING_COMPLETE 事件 + const thinkingCompleteChunk: ThinkingCompleteChunk = { + type: ChunkType.THINKING_COMPLETE, + text: extractionResult.tagContentExtracted, + thinking_millsec: thinkingStartTime > 0 ? 
Date.now() - thinkingStartTime : 0 + } + controller.enqueue(thinkingCompleteChunk) + + // 重置思考状态 + hasThinkingContent = false + thinkingStartTime = 0 + } else if (extractionResult.content.length > 0) { + if (extractionResult.isTagContent) { + // 第一次接收到思考内容时记录开始时间 + if (!hasThinkingContent) { + hasThinkingContent = true + thinkingStartTime = Date.now() + } + + const thinkingDeltaChunk: ThinkingDeltaChunk = { + type: ChunkType.THINKING_DELTA, + text: extractionResult.content, + thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0 + } + controller.enqueue(thinkingDeltaChunk) + } else { + // 发送清理后的文本内容 + const cleanTextChunk: TextDeltaChunk = { + ...textChunk, + text: extractionResult.content + } + controller.enqueue(cleanTextChunk) + } + } + } + } else { + // 其他类型的chunk直接传递(包括 THINKING_DELTA, THINKING_COMPLETE 等) + controller.enqueue(chunk) + } + }, + flush(controller) { + // 处理可能剩余的思考内容 + const finalResult = tagExtractor.finalize() + if (finalResult?.tagContentExtracted) { + const thinkingCompleteChunk: ThinkingCompleteChunk = { + type: ChunkType.THINKING_COMPLETE, + text: finalResult.tagContentExtracted, + thinking_millsec: thinkingStartTime > 0 ? 
Date.now() - thinkingStartTime : 0 + } + controller.enqueue(thinkingCompleteChunk) + } + } + }) + ) + + // 更新响应结果 + return { + ...result, + stream: processedStream + } + } else { + Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`) + } + } + return result + } diff --git a/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts b/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts new file mode 100644 index 0000000000..5f444953a9 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts @@ -0,0 +1,124 @@ +import { MCPTool } from '@renderer/types' +import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk' +import { parseToolUse } from '@renderer/utils/mcp-tools' +import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction' + +import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' +import { CompletionsContext, CompletionsMiddleware } from '../types' + +export const MIDDLEWARE_NAME = 'ToolUseExtractionMiddleware' + +// 工具使用标签配置 +const TOOL_USE_TAG_CONFIG: TagConfig = { + openingTag: '<tool_use>', + closingTag: '</tool_use>', + separator: '\n' +} + +/** + * 工具使用提取中间件 + * + * 职责: + * 1. 从文本流中检测并提取 <tool_use> 标签 + * 2. 解析工具调用信息并转换为 ToolUseResponse 格式 + * 3. 生成 MCP_TOOL_CREATED chunk 供 McpToolChunkMiddleware 处理 + * 4.
清理文本流,移除工具使用标签但保留正常文本 + * + * 注意:此中间件只负责提取和转换,实际工具调用由 McpToolChunkMiddleware 处理 + */ +export const ToolUseExtractionMiddleware: CompletionsMiddleware = + () => + (next) => + async (ctx: CompletionsContext, params: CompletionsParams): Promise => { + const mcpTools = params.mcpTools || [] + + // 如果没有工具,直接调用下一个中间件 + if (!mcpTools || mcpTools.length === 0) return next(ctx, params) + + // 调用下游中间件 + const result = await next(ctx, params) + + // 响应后处理:处理工具使用标签提取 + if (result.stream) { + const resultFromUpstream = result.stream as ReadableStream + + const processedStream = resultFromUpstream.pipeThrough(createToolUseExtractionTransform(ctx, mcpTools)) + + return { + ...result, + stream: processedStream + } + } + + return result + } + +/** + * 创建工具使用提取的 TransformStream + */ +function createToolUseExtractionTransform( + _ctx: CompletionsContext, + mcpTools: MCPTool[] +): TransformStream { + const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG) + + return new TransformStream({ + async transform(chunk: GenericChunk, controller) { + try { + // 处理文本内容,检测工具使用标签 + if (chunk.type === ChunkType.TEXT_DELTA) { + const textChunk = chunk as TextDeltaChunk + const extractionResults = tagExtractor.processText(textChunk.text) + + for (const result of extractionResults) { + if (result.complete && result.tagContentExtracted) { + // 提取到完整的工具使用内容,解析并转换为 SDK ToolCall 格式 + const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools) + + if (toolUseResponses.length > 0) { + // 生成 MCP_TOOL_CREATED chunk,复用现有的处理流程 + const mcpToolCreatedChunk: MCPToolCreatedChunk = { + type: ChunkType.MCP_TOOL_CREATED, + tool_use_responses: toolUseResponses + } + controller.enqueue(mcpToolCreatedChunk) + } + } else if (!result.isTagContent && result.content) { + // 发送标签外的正常文本内容 + const cleanTextChunk: TextDeltaChunk = { + ...textChunk, + text: result.content + } + controller.enqueue(cleanTextChunk) + } + // 注意:标签内的内容不会作为TEXT_DELTA转发,避免重复显示 + } + return + } + + // 转发其他所有chunk + 
controller.enqueue(chunk) + } catch (error) { + console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error) + controller.error(error) + } + }, + + async flush(controller) { + // 检查是否有未完成的标签内容 + const finalResult = tagExtractor.finalize() + if (finalResult && finalResult.tagContentExtracted) { + const toolUseResponses = parseToolUse(finalResult.tagContentExtracted, mcpTools) + if (toolUseResponses.length > 0) { + const mcpToolCreatedChunk: MCPToolCreatedChunk = { + type: ChunkType.MCP_TOOL_CREATED, + tool_use_responses: toolUseResponses + } + controller.enqueue(mcpToolCreatedChunk) + } + } + } + }) +} + +export default ToolUseExtractionMiddleware diff --git a/src/renderer/src/aiCore/middleware/index.ts b/src/renderer/src/aiCore/middleware/index.ts new file mode 100644 index 0000000000..64be4edd44 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/index.ts @@ -0,0 +1,88 @@ +import { CompletionsMiddleware, MethodMiddleware } from './types' + +// /** +// * Wraps a provider instance with middlewares. +// */ +// export function wrapProviderWithMiddleware( +// apiClientInstance: BaseApiClient, +// middlewareConfig: MiddlewareConfig +// ): BaseApiClient { +// console.log(`[wrapProviderWithMiddleware] Wrapping provider: ${apiClientInstance.provider?.id}`) +// console.log(`[wrapProviderWithMiddleware] Middleware config:`, { +// completions: middlewareConfig.completions?.length || 0, +// methods: Object.keys(middlewareConfig.methods || {}).length +// }) + +// // Cache for already wrapped methods to avoid re-wrapping on every access. +// const wrappedMethodsCache = new Map Promise>() + +// const proxy = new Proxy(apiClientInstance, { +// get(target, propKey, receiver) { +// const methodName = typeof propKey === 'string' ? 
propKey : undefined + +// if (!methodName) { +// return Reflect.get(target, propKey, receiver) +// } + +// if (wrappedMethodsCache.has(methodName)) { +// console.log(`[wrapProviderWithMiddleware] Using cached wrapped method: ${methodName}`) +// return wrappedMethodsCache.get(methodName) +// } + +// const originalMethod = Reflect.get(target, propKey, receiver) + +// // If the property is not a function, return it directly. +// if (typeof originalMethod !== 'function') { +// return originalMethod +// } + +// let wrappedMethod: ((...args: any[]) => Promise) | undefined + +// // Handle completions method +// if (methodName === 'completions' && middlewareConfig.completions?.length) { +// console.log( +// `[wrapProviderWithMiddleware] Wrapping completions method with ${middlewareConfig.completions.length} middlewares` +// ) +// const completionsOriginalMethod = originalMethod as (params: CompletionsParams) => Promise +// wrappedMethod = applyCompletionsMiddlewares(target, completionsOriginalMethod, middlewareConfig.completions) +// } +// // Handle other methods +// else { +// const methodMiddlewares = middlewareConfig.methods?.[methodName] +// if (methodMiddlewares?.length) { +// console.log( +// `[wrapProviderWithMiddleware] Wrapping method ${methodName} with ${methodMiddlewares.length} middlewares` +// ) +// const genericOriginalMethod = originalMethod as (...args: any[]) => Promise +// wrappedMethod = applyMethodMiddlewares(target, methodName, genericOriginalMethod, methodMiddlewares) +// } +// } + +// if (wrappedMethod) { +// console.log(`[wrapProviderWithMiddleware] Successfully wrapped method: ${methodName}`) +// wrappedMethodsCache.set(methodName, wrappedMethod) +// return wrappedMethod +// } + +// // If no middlewares are configured for this method, return the original method bound to the target. 
/ +// // 如果没有为此方法配置中间件,则返回绑定到目标的原始方法。 +// console.log(`[wrapProviderWithMiddleware] No middlewares for method ${methodName}, returning original`) +// return originalMethod.bind(target) +// } +// }) +// return proxy as BaseApiClient +// } + +// Export types for external use +export type { CompletionsMiddleware, MethodMiddleware } + +// Export MiddlewareBuilder related types and classes +export { + CompletionsMiddlewareBuilder, + createCompletionsBuilder, + createMethodBuilder, + MethodMiddlewareBuilder, + MiddlewareBuilder, + type MiddlewareExecutor, + type NamedMiddleware +} from './builder' diff --git a/src/renderer/src/aiCore/middleware/register.ts b/src/renderer/src/aiCore/middleware/register.ts new file mode 100644 index 0000000000..003ce7e93a --- /dev/null +++ b/src/renderer/src/aiCore/middleware/register.ts @@ -0,0 +1,149 @@ +import * as AbortHandlerModule from './common/AbortHandlerMiddleware' +import * as ErrorHandlerModule from './common/ErrorHandlerMiddleware' +import * as FinalChunkConsumerModule from './common/FinalChunkConsumerMiddleware' +import * as LoggingModule from './common/LoggingMiddleware' +import * as McpToolChunkModule from './core/McpToolChunkMiddleware' +import * as RawStreamListenerModule from './core/RawStreamListenerMiddleware' +import * as ResponseTransformModule from './core/ResponseTransformMiddleware' +// import * as SdkCallModule from './core/SdkCallMiddleware' +import * as StreamAdapterModule from './core/StreamAdapterMiddleware' +import * as TextChunkModule from './core/TextChunkMiddleware' +import * as ThinkChunkModule from './core/ThinkChunkMiddleware' +import * as TransformCoreToSdkParamsModule from './core/TransformCoreToSdkParamsMiddleware' +import * as WebSearchModule from './core/WebSearchMiddleware' +import * as ImageGenerationModule from './feat/ImageGenerationMiddleware' +import * as ThinkingTagExtractionModule from './feat/ThinkingTagExtractionMiddleware' +import * as ToolUseExtractionMiddleware from 
'./feat/ToolUseExtractionMiddleware' + +/** + * 中间件注册表 - 提供所有可用中间件的集中访问 + * 注意:目前中间件文件还未导出 MIDDLEWARE_NAME,会有 linter 错误,这是正常的 + */ +export const MiddlewareRegistry = { + [ErrorHandlerModule.MIDDLEWARE_NAME]: { + name: ErrorHandlerModule.MIDDLEWARE_NAME, + middleware: ErrorHandlerModule.ErrorHandlerMiddleware + }, + // 通用中间件 + [AbortHandlerModule.MIDDLEWARE_NAME]: { + name: AbortHandlerModule.MIDDLEWARE_NAME, + middleware: AbortHandlerModule.AbortHandlerMiddleware + }, + [FinalChunkConsumerModule.MIDDLEWARE_NAME]: { + name: FinalChunkConsumerModule.MIDDLEWARE_NAME, + middleware: FinalChunkConsumerModule.default + }, + + // 核心流程中间件 + [TransformCoreToSdkParamsModule.MIDDLEWARE_NAME]: { + name: TransformCoreToSdkParamsModule.MIDDLEWARE_NAME, + middleware: TransformCoreToSdkParamsModule.TransformCoreToSdkParamsMiddleware + }, + // [SdkCallModule.MIDDLEWARE_NAME]: { + // name: SdkCallModule.MIDDLEWARE_NAME, + // middleware: SdkCallModule.SdkCallMiddleware + // }, + [StreamAdapterModule.MIDDLEWARE_NAME]: { + name: StreamAdapterModule.MIDDLEWARE_NAME, + middleware: StreamAdapterModule.StreamAdapterMiddleware + }, + [RawStreamListenerModule.MIDDLEWARE_NAME]: { + name: RawStreamListenerModule.MIDDLEWARE_NAME, + middleware: RawStreamListenerModule.RawStreamListenerMiddleware + }, + [ResponseTransformModule.MIDDLEWARE_NAME]: { + name: ResponseTransformModule.MIDDLEWARE_NAME, + middleware: ResponseTransformModule.ResponseTransformMiddleware + }, + + // 特性处理中间件 + [ThinkingTagExtractionModule.MIDDLEWARE_NAME]: { + name: ThinkingTagExtractionModule.MIDDLEWARE_NAME, + middleware: ThinkingTagExtractionModule.ThinkingTagExtractionMiddleware + }, + [ToolUseExtractionMiddleware.MIDDLEWARE_NAME]: { + name: ToolUseExtractionMiddleware.MIDDLEWARE_NAME, + middleware: ToolUseExtractionMiddleware.ToolUseExtractionMiddleware + }, + [ThinkChunkModule.MIDDLEWARE_NAME]: { + name: ThinkChunkModule.MIDDLEWARE_NAME, + middleware: ThinkChunkModule.ThinkChunkMiddleware + }, + 
[McpToolChunkModule.MIDDLEWARE_NAME]: { + name: McpToolChunkModule.MIDDLEWARE_NAME, + middleware: McpToolChunkModule.McpToolChunkMiddleware + }, + [WebSearchModule.MIDDLEWARE_NAME]: { + name: WebSearchModule.MIDDLEWARE_NAME, + middleware: WebSearchModule.WebSearchMiddleware + }, + [TextChunkModule.MIDDLEWARE_NAME]: { + name: TextChunkModule.MIDDLEWARE_NAME, + middleware: TextChunkModule.TextChunkMiddleware + }, + [ImageGenerationModule.MIDDLEWARE_NAME]: { + name: ImageGenerationModule.MIDDLEWARE_NAME, + middleware: ImageGenerationModule.ImageGenerationMiddleware + } +} as const + +/** + * 根据名称获取中间件 + * @param name - 中间件名称 + * @returns 对应的中间件信息 + */ +export function getMiddleware(name: string) { + return MiddlewareRegistry[name] +} + +/** + * 获取所有注册的中间件名称 + * @returns 中间件名称列表 + */ +export function getRegisteredMiddlewareNames(): string[] { + return Object.keys(MiddlewareRegistry) +} + +/** + * 默认的 Completions 中间件配置 - NamedMiddleware 格式,用于 MiddlewareBuilder + */ +export const DefaultCompletionsNamedMiddlewares = [ + MiddlewareRegistry[FinalChunkConsumerModule.MIDDLEWARE_NAME], // 最终消费者 + MiddlewareRegistry[ErrorHandlerModule.MIDDLEWARE_NAME], // 错误处理 + MiddlewareRegistry[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME], // 参数转换 + MiddlewareRegistry[AbortHandlerModule.MIDDLEWARE_NAME], // 中止处理 + MiddlewareRegistry[McpToolChunkModule.MIDDLEWARE_NAME], // 工具处理 + MiddlewareRegistry[TextChunkModule.MIDDLEWARE_NAME], // 文本处理 + MiddlewareRegistry[WebSearchModule.MIDDLEWARE_NAME], // Web搜索处理 + MiddlewareRegistry[ToolUseExtractionMiddleware.MIDDLEWARE_NAME], // 工具使用提取处理 + MiddlewareRegistry[ThinkingTagExtractionModule.MIDDLEWARE_NAME], // 思考标签提取处理(特定provider) + MiddlewareRegistry[ThinkChunkModule.MIDDLEWARE_NAME], // 思考处理(通用SDK) + MiddlewareRegistry[ResponseTransformModule.MIDDLEWARE_NAME], // 响应转换 + MiddlewareRegistry[StreamAdapterModule.MIDDLEWARE_NAME], // 流适配器 + MiddlewareRegistry[RawStreamListenerModule.MIDDLEWARE_NAME] // 原始流监听器 +] + +/** + * 默认的通用方法中间件 - 例如翻译、摘要等 + */ 
+export const DefaultMethodMiddlewares = { + translate: [LoggingModule.createGenericLoggingMiddleware()], + summaries: [LoggingModule.createGenericLoggingMiddleware()] +} + +/** + * 导出所有中间件模块,方便外部使用 + */ +export { + AbortHandlerModule, + FinalChunkConsumerModule, + LoggingModule, + McpToolChunkModule, + ResponseTransformModule, + StreamAdapterModule, + TextChunkModule, + ThinkChunkModule, + ThinkingTagExtractionModule, + TransformCoreToSdkParamsModule, + WebSearchModule +} diff --git a/src/renderer/src/aiCore/middleware/schemas.ts b/src/renderer/src/aiCore/middleware/schemas.ts new file mode 100644 index 0000000000..33d9816b4f --- /dev/null +++ b/src/renderer/src/aiCore/middleware/schemas.ts @@ -0,0 +1,77 @@ +import { Assistant, MCPTool } from '@renderer/types' +import { Chunk } from '@renderer/types/chunk' +import { Message } from '@renderer/types/newMessage' +import { SdkRawChunk, SdkRawOutput } from '@renderer/types/sdk' + +import { ProcessingState } from './types' + +// ============================================================================ +// Core Request Types - 核心请求结构 +// ============================================================================ + +/** + * 标准化的内部核心请求结构,用于所有AI Provider的统一处理 + * 这是应用层参数转换后的标准格式,不包含回调函数和控制逻辑 + */ +export interface CompletionsParams { + /** + * 调用的业务场景类型,用于中间件判断是否执行 + * 'chat': 主要对话流程 + * 'translate': 翻译 + * 'summary': 摘要 + * 'search': 搜索摘要 + * 'generate': 生成 + * 'check': API检查 + */ + callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check' + + // 基础对话数据 + messages: Message[] | string // 联合类型方便判断是否为空 + + assistant: Assistant // 助手为基本单位 + // model: Model + + onChunk?: (chunk: Chunk) => void + onResponse?: (text: string, isComplete: boolean) => void + + // 错误相关 + onError?: (error: Error) => void + shouldThrow?: boolean + + // 工具相关 + mcpTools?: MCPTool[] + + // 生成参数 + temperature?: number + topP?: number + maxTokens?: number + + // 功能开关 + streamOutput: boolean + enableWebSearch?: boolean + 
enableReasoning?: boolean + enableGenerateImage?: boolean + + // 上下文控制 + contextCount?: number + + _internal?: ProcessingState +} + +export interface CompletionsResult { + rawOutput?: SdkRawOutput + stream?: ReadableStream | ReadableStream | AsyncIterable + controller?: AbortController + + getText: () => string +} + +// ============================================================================ +// Generic Chunk Types - 通用数据块结构 +// ============================================================================ + +/** + * 通用数据块类型 + * 复用现有的 Chunk 类型,这是所有AI Provider都应该输出的标准化数据块格式 + */ +export type GenericChunk = Chunk diff --git a/src/renderer/src/aiCore/middleware/types.ts b/src/renderer/src/aiCore/middleware/types.ts new file mode 100644 index 0000000000..0a7dbe390b --- /dev/null +++ b/src/renderer/src/aiCore/middleware/types.ts @@ -0,0 +1,166 @@ +import { MCPToolResponse, Metrics, Usage, WebSearchResponse } from '@renderer/types' +import { Chunk, ErrorChunk } from '@renderer/types/chunk' +import { + SdkInstance, + SdkMessageParam, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' + +import { BaseApiClient } from '../clients' +import { CompletionsParams, CompletionsResult } from './schemas' + +/** + * Symbol to uniquely identify middleware context objects. + */ +export const MIDDLEWARE_CONTEXT_SYMBOL = Symbol.for('AiProviderMiddlewareContext') + +/** + * Defines the structure for the onChunk callback function. + */ +export type OnChunkFunction = (chunk: Chunk | ErrorChunk) => void + +/** + * Base context that carries information about the current method call. + */ +export interface BaseContext { + [MIDDLEWARE_CONTEXT_SYMBOL]: true + methodName: string + originalArgs: Readonly +} + +/** + * Processing state shared between middlewares. 
+ */ +export interface ProcessingState< + TParams extends SdkParams = SdkParams, + TMessageParam extends SdkMessageParam = SdkMessageParam, + TToolCall extends SdkToolCall = SdkToolCall +> { + sdkPayload?: TParams + newReqMessages?: TMessageParam[] + observer?: { + usage?: Usage + metrics?: Metrics + } + toolProcessingState?: { + pendingToolCalls?: Array + executingToolCalls?: Array<{ + sdkToolCall: TToolCall + mcpToolResponse: MCPToolResponse + }> + output?: SdkRawOutput | string + isRecursiveCall?: boolean + recursionDepth?: number + } + webSearchState?: { + results?: WebSearchResponse + } + flowControl?: { + abortController?: AbortController + abortSignal?: AbortSignal + cleanup?: () => void + } + enhancedDispatch?: (context: CompletionsContext, params: CompletionsParams) => Promise + customState?: Record +} + +/** + * Extended context for completions method. + */ +export interface CompletionsContext< + TSdkParams extends SdkParams = SdkParams, + TSdkMessageParam extends SdkMessageParam = SdkMessageParam, + TSdkToolCall extends SdkToolCall = SdkToolCall, + TSdkInstance extends SdkInstance = SdkInstance, + TRawOutput extends SdkRawOutput = SdkRawOutput, + TRawChunk extends SdkRawChunk = SdkRawChunk, + TSdkSpecificTool extends SdkTool = SdkTool +> extends BaseContext { + readonly methodName: 'completions' // 强制方法名为 'completions' + + apiClientInstance: BaseApiClient< + TSdkInstance, + TSdkParams, + TRawOutput, + TRawChunk, + TSdkMessageParam, + TSdkToolCall, + TSdkSpecificTool + > + + // --- Mutable internal state for the duration of the middleware chain --- + _internal: ProcessingState // 包含所有可变的处理状态 +} + +export interface MiddlewareAPI { + getContext: () => Ctx // Function to get the current context / 获取当前上下文的函数 + getOriginalArgs: () => Args // Function to get the original arguments of the method call / 获取方法调用原始参数的函数 +} + +/** + * Base middleware type. 
+ */ +export type Middleware = ( + api: MiddlewareAPI +) => ( + next: (context: TContext, args: any[]) => Promise +) => (context: TContext, args: any[]) => Promise + +export type MethodMiddleware = Middleware + +/** + * Completions middleware type. + */ +export type CompletionsMiddleware< + TSdkParams extends SdkParams = SdkParams, + TSdkMessageParam extends SdkMessageParam = SdkMessageParam, + TSdkToolCall extends SdkToolCall = SdkToolCall, + TSdkInstance extends SdkInstance = SdkInstance, + TRawOutput extends SdkRawOutput = SdkRawOutput, + TRawChunk extends SdkRawChunk = SdkRawChunk, + TSdkSpecificTool extends SdkTool = SdkTool +> = ( + api: MiddlewareAPI< + CompletionsContext< + TSdkParams, + TSdkMessageParam, + TSdkToolCall, + TSdkInstance, + TRawOutput, + TRawChunk, + TSdkSpecificTool + >, + [CompletionsParams] + > +) => ( + next: ( + context: CompletionsContext< + TSdkParams, + TSdkMessageParam, + TSdkToolCall, + TSdkInstance, + TRawOutput, + TRawChunk, + TSdkSpecificTool + >, + params: CompletionsParams + ) => Promise +) => ( + context: CompletionsContext< + TSdkParams, + TSdkMessageParam, + TSdkToolCall, + TSdkInstance, + TRawOutput, + TRawChunk, + TSdkSpecificTool + >, + params: CompletionsParams +) => Promise + +// Re-export for convenience +export type { Chunk as OnChunkArg } from '@renderer/types/chunk' diff --git a/src/renderer/src/aiCore/middleware/utils.ts b/src/renderer/src/aiCore/middleware/utils.ts new file mode 100644 index 0000000000..12a2fe651d --- /dev/null +++ b/src/renderer/src/aiCore/middleware/utils.ts @@ -0,0 +1,57 @@ +import { ChunkType, ErrorChunk } from '@renderer/types/chunk' + +/** + * Creates an ErrorChunk object with a standardized structure. + * @param error The error object or message. + * @param chunkType The type of chunk, defaults to ChunkType.ERROR. + * @returns An ErrorChunk object. 
+ */ +export function createErrorChunk(error: any, chunkType: ChunkType = ChunkType.ERROR): ErrorChunk { + let errorDetails: Record = {} + + if (error instanceof Error) { + errorDetails = { + message: error.message, + name: error.name, + stack: error.stack + } + } else if (typeof error === 'string') { + errorDetails = { message: error } + } else if (typeof error === 'object' && error !== null) { + errorDetails = Object.getOwnPropertyNames(error).reduce( + (acc, key) => { + acc[key] = error[key] + return acc + }, + {} as Record + ) + if (!errorDetails.message && error.toString && typeof error.toString === 'function') { + const errMsg = error.toString() + if (errMsg !== '[object Object]') { + errorDetails.message = errMsg + } + } + } + + return { + type: chunkType, + error: errorDetails + } as ErrorChunk +} + +// Helper to capitalize method names for hook construction +export function capitalize(str: string): string { + if (!str) return '' + return str.charAt(0).toUpperCase() + str.slice(1) +} + +/** + * 检查对象是否实现了AsyncIterable接口 + */ +export function isAsyncIterable(obj: unknown): obj is AsyncIterable { + return ( + obj !== null && + typeof obj === 'object' && + typeof (obj as Record)[Symbol.asyncIterator] === 'function' + ) +} diff --git a/src/renderer/src/assets/images/models/gpt_image_1.png b/src/renderer/src/assets/images/models/gpt_image_1.png new file mode 100644 index 0000000000..30f2f2708f Binary files /dev/null and b/src/renderer/src/assets/images/models/gpt_image_1.png differ diff --git a/src/renderer/src/assets/styles/ant.scss b/src/renderer/src/assets/styles/ant.scss index d44bdca70c..0553c65a9c 100644 --- a/src/renderer/src/assets/styles/ant.scss +++ b/src/renderer/src/assets/styles/ant.scss @@ -202,19 +202,14 @@ overflow-y: auto; overflow-x: hidden; border: 0.5px solid var(--color-border); - border-radius: 12px; + border-radius: 10px; } .ant-dropdown { .ant-dropdown-menu { max-height: 50vh; overflow-y: auto; border: 0.5px solid var(--color-border); - 
border-radius: 12px; - - .ant-dropdown-menu-item, - .ant-dropdown-menu-submenu-title { - padding: 5px; - } + border-radius: 10px; } .ant-dropdown-arrow + .ant-dropdown-menu { border: none; @@ -224,7 +219,7 @@ .ant-popover { .ant-popover-inner { border: 0.5px solid var(--color-border); - border-radius: 12px; + border-radius: 10px; .ant-popover-inner-content { max-height: 70vh; overflow-y: auto; @@ -242,12 +237,12 @@ padding: 16px 0 0 0; } .ant-modal-content { - border-radius: 12px; + border-radius: 10px; border: 0.5px solid var(--color-border); padding: 0 0 8px 0; .ant-modal-header { padding: 16px 16px 0 16px; - border-radius: 12px; + border-radius: 10px; } .ant-modal-body { max-height: 80vh; diff --git a/src/renderer/src/components/DragableList/index.tsx b/src/renderer/src/components/DragableList/index.tsx index 4240b8452a..2be4dcf402 100644 --- a/src/renderer/src/components/DragableList/index.tsx +++ b/src/renderer/src/components/DragableList/index.tsx @@ -60,9 +60,9 @@ const DragableList: FC> = ({ {...provided.draggableProps} {...provided.dragHandleProps} style={{ + marginBottom: 8, ...listStyle, - ...provided.draggableProps.style, - marginBottom: 8 + ...provided.draggableProps.style }}> {children(item, index)} diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index acebf6171c..78f4ff3d0b 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -55,6 +55,7 @@ import { default as ChatGptModelLogoDakr, default as ChatGPTo1ModelLogoDark } from '@renderer/assets/images/models/gpt_dark.png' +import ChatGPTImageModelLogo from '@renderer/assets/images/models/gpt_image_1.png' import ChatGPTo1ModelLogo from '@renderer/assets/images/models/gpt_o1.png' import GrokModelLogo from '@renderer/assets/images/models/grok.png' import GrokModelLogoDark from '@renderer/assets/images/models/grok_dark.png' @@ -143,7 +144,7 @@ import YiModelLogoDark from '@renderer/assets/images/models/yi_dark.png' import YoudaoLogo from 
'@renderer/assets/images/providers/netease-youdao.svg' import NomicLogo from '@renderer/assets/images/providers/nomic.png' import { getProviderByModel } from '@renderer/services/AssistantService' -import { Assistant, Model } from '@renderer/types' +import { Model } from '@renderer/types' import OpenAI from 'openai' import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from './prompts' @@ -181,7 +182,8 @@ const visionAllowedModels = [ 'o4(?:-[\\w-]+)?', 'deepseek-vl(?:[\\w-]+)?', 'kimi-latest', - 'gemma-3(?:-[\\w-]+)' + 'gemma-3(?:-[\\w-]+)', + 'doubao-1.6-seed(?:-[\\w-]+)' ] const visionExcludedModels = [ @@ -199,6 +201,11 @@ export const VISION_REGEX = new RegExp( 'i' ) +// For middleware to identify models that must use the dedicated Image API +export const DEDICATED_IMAGE_MODELS = ['grok-2-image', 'dall-e-3', 'dall-e-2', 'gpt-image-1'] +export const isDedicatedImageGenerationModel = (model: Model): boolean => + DEDICATED_IMAGE_MODELS.filter((m) => model.id.includes(m)).length > 0 + // Text to image models export const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus/i @@ -286,6 +293,7 @@ export function getModelLogo(modelId: string) { o1: isLight ? ChatGPTo1ModelLogo : ChatGPTo1ModelLogoDark, o3: isLight ? ChatGPTo1ModelLogo : ChatGPTo1ModelLogoDark, o4: isLight ? ChatGPTo1ModelLogo : ChatGPTo1ModelLogoDark, + 'gpt-image': ChatGPTImageModelLogo, 'gpt-3': isLight ? ChatGPT35ModelLogo : ChatGPT35ModelLogoDark, 'gpt-4': isLight ? ChatGPT4ModelLogo : ChatGPT4ModelLogoDark, gpts: isLight ? ChatGPT4ModelLogo : ChatGPT4ModelLogoDark, @@ -307,6 +315,7 @@ export function getModelLogo(modelId: string) { mistral: isLight ? MistralModelLogo : MistralModelLogoDark, codestral: CodestralModelLogo, ministral: isLight ? MistralModelLogo : MistralModelLogoDark, + magistral: isLight ? MistralModelLogo : MistralModelLogoDark, moonshot: isLight ? MoonshotModelLogo : MoonshotModelLogoDark, kimi: isLight ? MoonshotModelLogo : MoonshotModelLogoDark, phi: isLight ? 
MicrosoftModelLogo : MicrosoftModelLogoDark, @@ -2246,14 +2255,24 @@ export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [ 'stabilityai/stable-diffusion-xl-base-1.0' ] +export const SUPPORTED_DISABLE_GENERATION_MODELS = [ + 'gemini-2.0-flash-exp', + 'gpt-4o', + 'gpt-4o-mini', + 'gpt-4.1', + 'gpt-4.1-mini', + 'gpt-4.1-nano', + 'o3' +] + export const GENERATE_IMAGE_MODELS = [ 'gemini-2.0-flash-exp-image-generation', 'gemini-2.0-flash-preview-image-generation', - 'gemini-2.0-flash-exp', 'grok-2-image-1212', 'grok-2-image', 'grok-2-image-latest', - 'gpt-image-1' + 'gpt-image-1', + ...SUPPORTED_DISABLE_GENERATION_MODELS ] export const GEMINI_SEARCH_MODELS = [ @@ -2362,10 +2381,32 @@ export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean { ) } -export function isOpenAIWebSearch(model: Model): boolean { +export function isOpenAIChatCompletionOnlyModel(model: Model): boolean { + if (!model) { + return false + } + + return ( + model.id.includes('gpt-4o-search-preview') || + model.id.includes('gpt-4o-mini-search-preview') || + model.id.includes('o1-mini') || + model.id.includes('o1-preview') + ) +} + +export function isOpenAIWebSearchChatCompletionOnlyModel(model: Model): boolean { return model.id.includes('gpt-4o-search-preview') || model.id.includes('gpt-4o-mini-search-preview') } +export function isOpenAIWebSearchModel(model: Model): boolean { + return ( + model.id.includes('gpt-4o-search-preview') || + model.id.includes('gpt-4o-mini-search-preview') || + (model.id.includes('gpt-4.1') && !model.id.includes('gpt-4.1-nano')) || + (model.id.includes('gpt-4o') && !model.id.includes('gpt-4o-image')) + ) +} + export function isSupportedThinkingTokenModel(model?: Model): boolean { if (!model) { return false @@ -2374,7 +2415,8 @@ export function isSupportedThinkingTokenModel(model?: Model): boolean { return ( isSupportedThinkingTokenGeminiModel(model) || isSupportedThinkingTokenQwenModel(model) || - isSupportedThinkingTokenClaudeModel(model) + 
isSupportedThinkingTokenClaudeModel(model) || + isSupportedThinkingTokenDoubaoModel(model) ) } @@ -2456,6 +2498,14 @@ export function isSupportedThinkingTokenQwenModel(model?: Model): boolean { ) } +export function isSupportedThinkingTokenDoubaoModel(model?: Model): boolean { + if (!model) { + return false + } + + return DOUBAO_THINKING_MODEL_REGEX.test(model.id) +} + export function isClaudeReasoningModel(model?: Model): boolean { if (!model) { return false @@ -2476,7 +2526,12 @@ export function isReasoningModel(model?: Model): boolean { } if (model.provider === 'doubao') { - return REASONING_REGEX.test(model.name) || model.type?.includes('reasoning') || false + return ( + REASONING_REGEX.test(model.name) || + model.type?.includes('reasoning') || + isSupportedThinkingTokenDoubaoModel(model) || + false + ) } if ( @@ -2485,7 +2540,8 @@ export function isReasoningModel(model?: Model): boolean { isGeminiReasoningModel(model) || isQwenReasoningModel(model) || isGrokReasoningModel(model) || - model.id.includes('glm-z1') + model.id.includes('glm-z1') || + model.id.includes('magistral') ) { return true } @@ -2506,7 +2562,7 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean { return true } - if (isOpenAIReasoningModel(model) || isOpenAIWebSearch(model)) { + if (isOpenAIReasoningModel(model) || isOpenAIChatCompletionOnlyModel(model)) { return true } @@ -2536,17 +2592,13 @@ export function isWebSearchModel(model: Model): boolean { return false } + // 不管哪个供应商都判断了 if (model.id.includes('claude')) { return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(model.id) } if (provider.type === 'openai-response') { - if ( - isOpenAILLMModel(model) && - !isTextToImageModel(model) && - !isOpenAIReasoningModel(model) && - !GENERATE_IMAGE_MODELS.includes(model.id) - ) { + if (isOpenAIWebSearchModel(model)) { return true } @@ -2558,12 +2610,7 @@ export function isWebSearchModel(model: Model): boolean { } if (provider.id === 'aihubmix') { - if ( - isOpenAILLMModel(model) && - 
!isTextToImageModel(model) && - !isOpenAIReasoningModel(model) && - !GENERATE_IMAGE_MODELS.includes(model.id) - ) { + if (isOpenAIWebSearchModel(model)) { return true } @@ -2572,7 +2619,7 @@ export function isWebSearchModel(model: Model): boolean { } if (provider?.type === 'openai') { - if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearch(model)) { + if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearchModel(model)) { return true } } @@ -2606,6 +2653,20 @@ export function isWebSearchModel(model: Model): boolean { return false } +export function isOpenRouterBuiltInWebSearchModel(model: Model): boolean { + if (!model) { + return false + } + + const provider = getProviderByModel(model) + + if (provider.id !== 'openrouter') { + return false + } + + return isOpenAIWebSearchModel(model) || model.id.includes('sonar') +} + export function isGenerateImageModel(model: Model): boolean { if (!model) { return false @@ -2628,56 +2689,60 @@ export function isGenerateImageModel(model: Model): boolean { return false } -export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Record { - if (isWebSearchModel(model)) { - if (assistant.enableWebSearch) { - const webSearchTools = getWebSearchTools(model) +export function isSupportedDisableGenerationModel(model: Model): boolean { + if (!model) { + return false + } - if (model.provider === 'grok') { - return { - search_parameters: { - mode: 'auto', - return_citations: true, - sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }] - } - } - } + return SUPPORTED_DISABLE_GENERATION_MODELS.includes(model.id) +} - if (model.provider === 'hunyuan') { - return { enable_enhancement: true, citation: true, search_info: true } - } +export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boolean): Record { + if (!isEnableWebSearch) { + return {} + } - if (model.provider === 'dashscope') { - return { - enable_search: true, - search_options: { - forced_search: true - } - } - } + 
const webSearchTools = getWebSearchTools(model) - if (model.provider === 'openrouter') { - return { - plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }] - } - } - - if (isOpenAIWebSearch(model)) { - return { - web_search_options: {} - } - } - - return { - tools: webSearchTools - } - } else { - if (model.provider === 'hunyuan') { - return { enable_enhancement: false } + if (model.provider === 'grok') { + return { + search_parameters: { + mode: 'auto', + return_citations: true, + sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }] } } } + if (model.provider === 'hunyuan') { + return { enable_enhancement: true, citation: true, search_info: true } + } + + if (model.provider === 'dashscope') { + return { + enable_search: true, + search_options: { + forced_search: true + } + } + } + + if (isOpenAIWebSearchChatCompletionOnlyModel(model)) { + return { + web_search_options: {} + } + } + + if (model.provider === 'openrouter') { + return { + plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }] + } + } + + return { + tools: webSearchTools + } + return {} } @@ -2758,3 +2823,16 @@ export const findTokenLimit = (modelId: string): { min: number; max: number } | } return undefined } + +// Doubao 支持思考模式的模型正则 +export const DOUBAO_THINKING_MODEL_REGEX = + /doubao-(?:1(\.|-5)-thinking-vision-pro|1(\.|-)5-thinking-pro-m|seed-1\.6|seed-1\.6-flash)(?:-[\\w-]+)?/i + +// 支持 auto 的 Doubao 模型 +export const DOUBAO_THINKING_AUTO_MODEL_REGEX = /doubao-(?:1-5-thinking-pro-m|seed-1.6)(?:-[\\w-]+)?/i + +export function isDoubaoThinkingAutoModel(model: Model): boolean { + return DOUBAO_THINKING_AUTO_MODEL_REGEX.test(model.id) +} + +export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini-.*-flash.*$') diff --git a/src/renderer/src/hooks/useTopic.ts b/src/renderer/src/hooks/useTopic.ts index bd2f0cb7c0..0b5c9fe8d3 100644 --- a/src/renderer/src/hooks/useTopic.ts +++ b/src/renderer/src/hooks/useTopic.ts @@ -2,6 +2,8 @@ import db from 
'@renderer/databases' import i18n from '@renderer/i18n' import { deleteMessageFiles } from '@renderer/services/MessagesService' import store from '@renderer/store' +import { setNewlyRenamedTopics, setRenamingTopics } from '@renderer/store/runtime' +import { loadTopicMessagesThunk } from '@renderer/store/thunk/messageThunk' import { selectTopicById, topicsActions } from '@renderer/store/topics' import { Assistant, Topic } from '@renderer/types' import { findMainTextBlocks } from '@renderer/utils/messageUtils/find' @@ -9,13 +11,6 @@ import { isEmpty } from 'lodash' import { getStoreSetting } from './useSettings' -const renamingTopics = new Set() - -export function useTopic(topicId?: string) { - if (!topicId) return undefined - return selectTopicById(store.getState(), topicId) -} - export function getTopic(topicId: string) { return selectTopicById(store.getState(), topicId) } @@ -26,13 +21,46 @@ export async function getTopicById(topicId: string) { return { ...topic, messages } as Topic } +/** + * 开始重命名指定话题 + */ +export const startTopicRenaming = (topicId: string) => { + const currentIds = store.getState().runtime.chat.renamingTopics + if (!currentIds.includes(topicId)) { + store.dispatch(setRenamingTopics([...currentIds, topicId])) + } +} + +/** + * 完成重命名指定话题 + */ +export const finishTopicRenaming = (topicId: string) => { + const state = store.getState() + + // 1. 立即从 renamingTopics 移除 + const currentRenaming = state.runtime.chat.renamingTopics + store.dispatch(setRenamingTopics(currentRenaming.filter((id) => id !== topicId))) + + // 2. 立即添加到 newlyRenamedTopics + const currentNewlyRenamed = state.runtime.chat.newlyRenamedTopics + store.dispatch(setNewlyRenamedTopics([...currentNewlyRenamed, topicId])) + + // 3. 
延迟从 newlyRenamedTopics 移除 + setTimeout(() => { + const current = store.getState().runtime.chat.newlyRenamedTopics + store.dispatch(setNewlyRenamedTopics(current.filter((id) => id !== topicId))) + }, 700) +} + +const topicRenamingLocks = new Set() + export const autoRenameTopic = async (assistant: Assistant, topicId: string) => { - if (renamingTopics.has(topicId)) { + if (topicRenamingLocks.has(topicId)) { return } try { - renamingTopics.add(topicId) + topicRenamingLocks.add(topicId) const topic = await getTopicById(topicId) const enableTopicNaming = getStoreSetting('enableTopicNaming') @@ -53,22 +81,34 @@ export const autoRenameTopic = async (assistant: Assistant, topicId: string) => .join('\n\n') .substring(0, 50) if (topicName) { - const data = { ...topic, name: topicName } as Topic - store.dispatch(topicsActions.updateTopic({ assistantId: assistant.id, topic: data })) + try { + startTopicRenaming(topicId) + + const data = { ...topic, name: topicName } as Topic + store.dispatch(topicsActions.updateTopic({ assistantId: assistant.id, topic: data })) + } finally { + finishTopicRenaming(topicId) + } } return } if (topic && topic.name === i18n.t('chat.default.topic.name') && topic.messages.length >= 2) { - const { fetchMessagesSummary } = await import('@renderer/services/ApiService') - const summaryText = await fetchMessagesSummary({ messages: topic.messages, assistant }) - if (summaryText) { - const data = { ...topic, name: summaryText } - store.dispatch(topicsActions.updateTopic({ assistantId: assistant.id, topic: data })) + try { + startTopicRenaming(topicId) + + const { fetchMessagesSummary } = await import('@renderer/services/ApiService') + const summaryText = await fetchMessagesSummary({ messages: topic.messages, assistant }) + if (summaryText) { + const data = { ...topic, name: summaryText } + store.dispatch(topicsActions.updateTopic({ assistantId: assistant.id, topic: data })) + } + } finally { + finishTopicRenaming(topicId) } } } finally { - 
renamingTopics.delete(topicId) + topicRenamingLocks.delete(topicId) } } @@ -83,9 +123,18 @@ export const TopicManager = { return await db.topics.toArray() }, + /** + * 加载并返回指定话题的消息 + */ async getTopicMessages(id: string) { const topic = await TopicManager.getTopic(id) - return topic ? topic.messages : [] + if (!topic) return [] + + await store.dispatch(loadTopicMessagesThunk(id)) + + // 获取更新后的话题 + const updatedTopic = await TopicManager.getTopic(id) + return updatedTopic?.messages || [] }, async removeTopic(id: string) { diff --git a/src/renderer/src/middlewares/extractReasoningMiddleware.ts b/src/renderer/src/middlewares/extractReasoningMiddleware.ts deleted file mode 100644 index a466822d8c..0000000000 --- a/src/renderer/src/middlewares/extractReasoningMiddleware.ts +++ /dev/null @@ -1,118 +0,0 @@ -// Modified from https://github.com/vercel/ai/blob/845080d80b8538bb9c7e527d2213acb5f33ac9c2/packages/ai/core/middleware/extract-reasoning-middleware.ts - -import { getPotentialStartIndex } from '../utils/getPotentialIndex' - -export interface ExtractReasoningMiddlewareOptions { - openingTag: string - closingTag: string - separator?: string - enableReasoning?: boolean -} - -function escapeRegExp(str: string) { - return str.replace(/[.*+?^${}()|[\\]\\]/g, '\\$&') -} - -// 支持泛型 T,默认 T = { type: string; textDelta: string } -export function extractReasoningMiddleware< - T extends { type: string } & ( - | { type: 'text-delta' | 'reasoning'; textDelta: string } - | { type: string } // 其他类型 - ) = { type: string; textDelta: string } ->({ openingTag, closingTag, separator = '\n', enableReasoning }: ExtractReasoningMiddlewareOptions) { - const openingTagEscaped = escapeRegExp(openingTag) - const closingTagEscaped = escapeRegExp(closingTag) - - return { - wrapGenerate: async ({ doGenerate }: { doGenerate: () => Promise<{ text: string } & Record> }) => { - const { text: rawText, ...rest } = await doGenerate() - if (rawText == null) { - return { text: rawText, ...rest } - } - const 
text = rawText - const regexp = new RegExp(`${openingTagEscaped}(.*?)${closingTagEscaped}`, 'gs') - const matches = Array.from(text.matchAll(regexp)) - if (!matches.length) { - return { text, ...rest } - } - const reasoning = matches.map((match: RegExpMatchArray) => match[1]).join(separator) - let textWithoutReasoning = text - for (let i = matches.length - 1; i >= 0; i--) { - const match = matches[i] as RegExpMatchArray - const beforeMatch = textWithoutReasoning.slice(0, match.index as number) - const afterMatch = textWithoutReasoning.slice((match.index as number) + match[0].length) - textWithoutReasoning = - beforeMatch + (beforeMatch.length > 0 && afterMatch.length > 0 ? separator : '') + afterMatch - } - return { ...rest, text: textWithoutReasoning, reasoning } - }, - wrapStream: async ({ - doStream - }: { - doStream: () => Promise<{ stream: ReadableStream } & Record> - }) => { - const { stream, ...rest } = await doStream() - if (!enableReasoning) { - return { - stream, - ...rest - } - } - let isFirstReasoning = true - let isFirstText = true - let afterSwitch = false - let isReasoning = false - let buffer = '' - return { - stream: stream.pipeThrough( - new TransformStream({ - transform: (chunk, controller) => { - if (chunk.type !== 'text-delta') { - controller.enqueue(chunk) - return - } - // textDelta 只在 text-delta/reasoning chunk 上 - buffer += (chunk as { textDelta: string }).textDelta - function publish(text: string) { - if (text.length > 0) { - const prefix = afterSwitch && (isReasoning ? !isFirstReasoning : !isFirstText) ? separator : '' - controller.enqueue({ - ...chunk, - type: isReasoning ? 'reasoning' : 'text-delta', - textDelta: prefix + text - } as T) - afterSwitch = false - if (isReasoning) { - isFirstReasoning = false - } else { - isFirstText = false - } - } - } - while (true) { - const nextTag = isReasoning ? 
closingTag : openingTag - const startIndex = getPotentialStartIndex(buffer, nextTag) - if (startIndex == null) { - publish(buffer) - buffer = '' - break - } - publish(buffer.slice(0, startIndex)) - const foundFullMatch = startIndex + nextTag.length <= buffer.length - if (foundFullMatch) { - buffer = buffer.slice(startIndex + nextTag.length) - isReasoning = !isReasoning - afterSwitch = true - } else { - buffer = buffer.slice(startIndex) - break - } - } - } - }) - ), - ...rest - } - } - } -} diff --git a/src/renderer/src/pages/history/HistoryPage.tsx b/src/renderer/src/pages/history/HistoryPage.tsx index a3f60e2cc5..d20accfd87 100644 --- a/src/renderer/src/pages/history/HistoryPage.tsx +++ b/src/renderer/src/pages/history/HistoryPage.tsx @@ -86,7 +86,7 @@ const TopicsPage: FC = () => { ) } - suffix={search.length >= 2 ? : } + suffix={search.length >= 2 ? : null} ref={inputRef} placeholder={t('history.search.placeholder')} value={search} diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index 2f2ba0569f..c8ebcac5c5 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -4,6 +4,7 @@ import TranslateButton from '@renderer/components/TranslateButton' import Logger from '@renderer/config/logger' import { isGenerateImageModel, + isSupportedDisableGenerationModel, isSupportedReasoningEffortModel, isSupportedThinkingTokenModel, isVisionModel, @@ -718,7 +719,7 @@ const Inputbar: FC = () => { if (!isGenerateImageModel(model) && assistant.enableGenerateImage) { updateAssistant({ ...assistant, enableGenerateImage: false }) } - if (isGenerateImageModel(model) && !assistant.enableGenerateImage && model.id !== 'gemini-2.0-flash-exp') { + if (isGenerateImageModel(model) && !assistant.enableGenerateImage && !isSupportedDisableGenerationModel(model)) { updateAssistant({ ...assistant, enableGenerateImage: true }) } }, [assistant, model, updateAssistant]) 
diff --git a/src/renderer/src/pages/home/Inputbar/ThinkingButton.tsx b/src/renderer/src/pages/home/Inputbar/ThinkingButton.tsx index 2caef6c158..21db131cef 100644 --- a/src/renderer/src/pages/home/Inputbar/ThinkingButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/ThinkingButton.tsx @@ -7,7 +7,9 @@ import { } from '@renderer/components/Icons/SVGIcon' import { useQuickPanel } from '@renderer/components/QuickPanel' import { + isDoubaoThinkingAutoModel, isSupportedReasoningEffortGrokModel, + isSupportedThinkingTokenDoubaoModel, isSupportedThinkingTokenGeminiModel, isSupportedThinkingTokenQwenModel } from '@renderer/config/models' @@ -35,13 +37,14 @@ const MODEL_SUPPORTED_OPTIONS: Record = { default: ['off', 'low', 'medium', 'high'], grok: ['off', 'low', 'high'], gemini: ['off', 'low', 'medium', 'high', 'auto'], - qwen: ['off', 'low', 'medium', 'high'] + qwen: ['off', 'low', 'medium', 'high'], + doubao: ['off', 'auto', 'high'] } // 选项转换映射表:当选项不支持时使用的替代选项 const OPTION_FALLBACK: Record = { off: 'off', - low: 'low', + low: 'high', medium: 'high', // medium -> high (for Grok models) high: 'high', auto: 'high' // auto -> high (for non-Gemini models) @@ -55,6 +58,7 @@ const ThinkingButton: FC = ({ ref, model, assistant, ToolbarButton }): Re const isGrokModel = isSupportedReasoningEffortGrokModel(model) const isGeminiModel = isSupportedThinkingTokenGeminiModel(model) const isQwenModel = isSupportedThinkingTokenQwenModel(model) + const isDoubaoModel = isSupportedThinkingTokenDoubaoModel(model) const currentReasoningEffort = useMemo(() => { return assistant.settings?.reasoning_effort || 'off' @@ -65,13 +69,20 @@ const ThinkingButton: FC = ({ ref, model, assistant, ToolbarButton }): Re if (isGeminiModel) return 'gemini' if (isGrokModel) return 'grok' if (isQwenModel) return 'qwen' + if (isDoubaoModel) return 'doubao' return 'default' - }, [isGeminiModel, isGrokModel, isQwenModel]) + }, [isGeminiModel, isGrokModel, isQwenModel, isDoubaoModel]) // 获取当前模型支持的选项 const 
supportedOptions = useMemo(() => { + if (modelType === 'doubao') { + if (isDoubaoThinkingAutoModel(model)) { + return ['off', 'auto', 'high'] as ThinkingOption[] + } + return ['off', 'high'] as ThinkingOption[] + } return MODEL_SUPPORTED_OPTIONS[modelType] - }, [modelType]) + }, [model, modelType]) // 检查当前设置是否与当前模型兼容 useEffect(() => { diff --git a/src/renderer/src/pages/home/Markdown/CitationTooltip.tsx b/src/renderer/src/pages/home/Markdown/CitationTooltip.tsx index 45b804c851..6041b562af 100644 --- a/src/renderer/src/pages/home/Markdown/CitationTooltip.tsx +++ b/src/renderer/src/pages/home/Markdown/CitationTooltip.tsx @@ -54,9 +54,10 @@ const CitationTooltip: React.FC = ({ children, citation }) return ( = ({ block }) => { code: (props: any) => ( ), + table: (props: any) => , img: (props: any) => , pre: (props: any) =>
,
       p: (props) => {
@@ -91,7 +93,7 @@ const Markdown: FC = ({ block }) => {
         return 

} } as Partial - }, [onSaveCodeBlock]) + }, [onSaveCodeBlock, block.id]) const urlTransform = useCallback((value: string) => { if (value.startsWith('data:image/png') || value.startsWith('data:image/jpeg')) return value diff --git a/src/renderer/src/pages/home/Markdown/Table.tsx b/src/renderer/src/pages/home/Markdown/Table.tsx new file mode 100644 index 0000000000..06074d55ac --- /dev/null +++ b/src/renderer/src/pages/home/Markdown/Table.tsx @@ -0,0 +1,120 @@ +import store from '@renderer/store' +import { messageBlocksSelectors } from '@renderer/store/messageBlock' +import { Tooltip } from 'antd' +import { Check, Copy } from 'lucide-react' +import React, { memo, useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import styled from 'styled-components' + +interface Props { + children: React.ReactNode + node?: any + blockId?: string +} + +/** + * 自定义 Markdown 表格组件,提供 copy 功能。 + */ +const Table: React.FC = ({ children, node, blockId }) => { + const { t } = useTranslation() + const [copied, setCopied] = useState(false) + + const handleCopyTable = useCallback(() => { + const tableMarkdown = extractTableMarkdown(blockId ?? '', node?.position) + if (!tableMarkdown) return + + navigator.clipboard + .writeText(tableMarkdown) + .then(() => { + setCopied(true) + setTimeout(() => setCopied(false), 2000) + }) + .catch((error) => { + window.message?.error({ content: `${t('message.copy.failed')}: ${error}`, key: 'copy-table-error' }) + }) + }, [node, blockId, t]) + + return ( + +

{children}
+ + + + {copied ? ( + + ) : ( + + )} + + + + + ) +} + +/** + * 从原始 Markdown 内容中提取表格源代码 + * @param blockId 消息块 ID + * @param position 表格节点的位置信息 + * @returns 源代码 + */ +export function extractTableMarkdown(blockId: string, position: any): string { + if (!position || !blockId) return '' + + const block = messageBlocksSelectors.selectById(store.getState(), blockId) + + if (!block || !('content' in block) || typeof block.content !== 'string') return '' + + const { start, end } = position + const lines = block.content.split('\n') + + // 提取表格对应的行(行号从1开始,数组索引从0开始) + const tableLines = lines.slice(start.line - 1, end.line) + return tableLines.join('\n').trim() +} + +const TableWrapper = styled.div` + position: relative; + + .table-toolbar { + border-radius: 4px; + opacity: 0; + transition: opacity 0.2s ease; + transform: translateZ(0); + will-change: opacity; + } + &:hover { + .table-toolbar { + opacity: 1; + } + } +` + +const ToolbarWrapper = styled.div` + position: absolute; + top: 8px; + right: 8px; + z-index: 10; +` + +const ToolButton = styled.div` + display: flex; + align-items: center; + justify-content: center; + width: 24px; + height: 24px; + border-radius: 4px; + cursor: pointer; + user-select: none; + transition: all 0.2s ease; + opacity: 1; + color: var(--color-text-3); + background-color: var(--color-background-mute); + will-change: background-color, opacity; + + &:hover { + background-color: var(--color-background-soft); + } +` + +export default memo(Table) diff --git a/src/renderer/src/pages/home/Markdown/__tests__/Markdown.test.tsx b/src/renderer/src/pages/home/Markdown/__tests__/Markdown.test.tsx index f5769eb4f8..c72f30de98 100644 --- a/src/renderer/src/pages/home/Markdown/__tests__/Markdown.test.tsx +++ b/src/renderer/src/pages/home/Markdown/__tests__/Markdown.test.tsx @@ -78,6 +78,18 @@ vi.mock('../Link', () => ({ ) })) +vi.mock('../Table', () => ({ + __esModule: true, + default: ({ children, blockId }: any) => ( +
+ {children}
+ +
+ ) +})) + vi.mock('@renderer/components/MarkdownShadowDOMRenderer', () => ({ __esModule: true, default: ({ children }: any) =>
{children}
@@ -104,6 +116,11 @@ vi.mock('react-markdown', () => ({ {components.code({ children: 'test code', node: { position: { start: { line: 1 } } } })} )} + {components?.table && ( +
+ {components.table({ children: 'test table', node: { position: { start: { line: 1 } } } })} +
+ )} {components?.img && img} {components?.style && style} @@ -300,6 +317,16 @@ describe('Markdown', () => { }) }) + it('should integrate Table component with copy functionality', () => { + const block = createMainTextBlock({ id: 'test-block-456' }) + render() + + expect(screen.getByTestId('has-table-component')).toBeInTheDocument() + + const tableComponent = screen.getByTestId('table-component') + expect(tableComponent).toHaveAttribute('data-block-id', 'test-block-456') + }) + it('should integrate ImagePreview component', () => { render() diff --git a/src/renderer/src/pages/home/Markdown/__tests__/Table.test.tsx b/src/renderer/src/pages/home/Markdown/__tests__/Table.test.tsx new file mode 100644 index 0000000000..5a5bbeb90d --- /dev/null +++ b/src/renderer/src/pages/home/Markdown/__tests__/Table.test.tsx @@ -0,0 +1,316 @@ +import { act, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest' + +import Table, { extractTableMarkdown } from '../Table' + +const mocks = vi.hoisted(() => { + return { + store: { + getState: vi.fn() + }, + messageBlocksSelectors: { + selectById: vi.fn() + }, + windowMessage: { + error: vi.fn() + } + } +}) + +// Mock dependencies +vi.mock('@renderer/store', () => ({ + __esModule: true, + default: mocks.store +})) + +vi.mock('@renderer/store/messageBlock', () => ({ + messageBlocksSelectors: mocks.messageBlocksSelectors +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key + }) +})) + +vi.mock('antd', () => ({ + Tooltip: ({ children, title }: any) => ( +
+ {children} +
+ ) +})) + +Object.assign(window, { + message: mocks.windowMessage +}) + +describe('Table', () => { + beforeAll(() => { + vi.stubGlobal('jest', { + advanceTimersByTime: vi.advanceTimersByTime.bind(vi) + }) + }) + + beforeEach(() => { + vi.clearAllMocks() + vi.useFakeTimers() + }) + + afterEach(() => { + vi.restoreAllMocks() + vi.runOnlyPendingTimers() + vi.useRealTimers() + }) + + afterAll(() => { + vi.unstubAllGlobals() + }) + + // https://testing-library.com/docs/user-event/clipboard/ + const user = userEvent.setup({ + advanceTimers: vi.advanceTimersByTime.bind(vi), + writeToClipboard: true + }) + + // Test data factories + const createMockBlock = (content: string = defaultTableContent) => ({ + id: 'test-block-1', + content + }) + + const createTablePosition = (startLine = 1, endLine = 3) => ({ + start: { line: startLine }, + end: { line: endLine } + }) + + const defaultTableContent = `| Header 1 | Header 2 | +|----------|----------| +| Cell 1 | Cell 2 |` + + const defaultProps = { + children: ( + + + Cell 1 + Cell 2 + + + ), + blockId: 'test-block-1', + node: { position: createTablePosition() } + } + + const getCopyButton = () => screen.getByRole('button', { name: /common\.copy/i }) + const getCopyIcon = () => screen.getByTestId('copy-icon') + const getCheckIcon = () => screen.getByTestId('check-icon') + const queryCheckIcon = () => screen.queryByTestId('check-icon') + const queryCopyIcon = () => screen.queryByTestId('copy-icon') + + describe('rendering', () => { + it('should render table with children and toolbar', () => { + render() + + expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByText('Cell 1')).toBeInTheDocument() + expect(screen.getByText('Cell 2')).toBeInTheDocument() + expect(screen.getByTestId('tooltip')).toBeInTheDocument() + }) + + it('should render with table-wrapper and table-toolbar classes', () => { + const { container } = render(
) + + expect(container.querySelector('.table-wrapper')).toBeInTheDocument() + expect(container.querySelector('.table-toolbar')).toBeInTheDocument() + }) + + it('should render copy button with correct tooltip', () => { + render(
) + + const tooltip = screen.getByTestId('tooltip') + expect(tooltip).toHaveAttribute('title', 'common.copy') + }) + + it('should match snapshot', () => { + const { container } = render(
) + expect(container.firstChild).toMatchSnapshot() + }) + }) + + describe('extractTableMarkdown', () => { + beforeEach(() => { + mocks.store.getState.mockReturnValue({}) + }) + + it('should extract table content from specified line range', () => { + const block = createMockBlock() + const position = createTablePosition(1, 3) + mocks.messageBlocksSelectors.selectById.mockReturnValue(block) + + const result = extractTableMarkdown('test-block-1', position) + + expect(result).toBe(defaultTableContent) + expect(mocks.messageBlocksSelectors.selectById).toHaveBeenCalledWith({}, 'test-block-1') + }) + + it('should handle line range extraction correctly', () => { + const multiLineContent = `Line 0 +| Header 1 | Header 2 | +|----------|----------| +| Cell 1 | Cell 2 | +Line 4` + const block = createMockBlock(multiLineContent) + const position = createTablePosition(2, 4) // Extract lines 2-4 (table part) + mocks.messageBlocksSelectors.selectById.mockReturnValue(block) + + const result = extractTableMarkdown('test-block-1', position) + + expect(result).toBe(`| Header 1 | Header 2 | +|----------|----------| +| Cell 1 | Cell 2 |`) + }) + + it('should return empty string when blockId is empty', () => { + const result = extractTableMarkdown('', createTablePosition()) + expect(result).toBe('') + expect(mocks.messageBlocksSelectors.selectById).not.toHaveBeenCalled() + }) + + it('should return empty string when position is null', () => { + const result = extractTableMarkdown('test-block-1', null) + expect(result).toBe('') + expect(mocks.messageBlocksSelectors.selectById).not.toHaveBeenCalled() + }) + + it('should return empty string when position is undefined', () => { + const result = extractTableMarkdown('test-block-1', undefined) + expect(result).toBe('') + expect(mocks.messageBlocksSelectors.selectById).not.toHaveBeenCalled() + }) + + it('should return empty string when block does not exist', () => { + mocks.messageBlocksSelectors.selectById.mockReturnValue(null) + + const result 
= extractTableMarkdown('non-existent-block', createTablePosition()) + + expect(result).toBe('') + }) + + it('should return empty string when block has no content property', () => { + const blockWithoutContent = { id: 'test-block-1' } + mocks.messageBlocksSelectors.selectById.mockReturnValue(blockWithoutContent) + + const result = extractTableMarkdown('test-block-1', createTablePosition()) + + expect(result).toBe('') + }) + + it('should return empty string when block content is not a string', () => { + const blockWithInvalidContent = { id: 'test-block-1', content: 123 } + mocks.messageBlocksSelectors.selectById.mockReturnValue(blockWithInvalidContent) + + const result = extractTableMarkdown('test-block-1', createTablePosition()) + + expect(result).toBe('') + }) + + it('should handle boundary line numbers correctly', () => { + const block = createMockBlock('Line 1\nLine 2\nLine 3') + const position = createTablePosition(1, 3) + mocks.messageBlocksSelectors.selectById.mockReturnValue(block) + + const result = extractTableMarkdown('test-block-1', position) + + expect(result).toBe('Line 1\nLine 2\nLine 3') + }) + }) + + describe('copy functionality', () => { + beforeEach(() => { + mocks.messageBlocksSelectors.selectById.mockReturnValue(createMockBlock()) + }) + + it('should copy table content to clipboard on button click', async () => { + render(
) + + const copyButton = getCopyButton() + await user.click(copyButton) + + await waitFor(() => { + expect(getCheckIcon()).toBeInTheDocument() + expect(queryCopyIcon()).not.toBeInTheDocument() + }) + }) + + it('should show check icon after successful copy', async () => { + render(
) + + // Initially shows copy icon + expect(getCopyIcon()).toBeInTheDocument() + + const copyButton = getCopyButton() + await user.click(copyButton) + + await waitFor(() => { + expect(getCheckIcon()).toBeInTheDocument() + expect(queryCopyIcon()).not.toBeInTheDocument() + }) + }) + + it('should reset to copy icon after 2 seconds', async () => { + render(
) + + const copyButton = getCopyButton() + await user.click(copyButton) + + await waitFor(() => { + expect(getCheckIcon()).toBeInTheDocument() + }) + + // Fast forward 2 seconds + act(() => { + vi.advanceTimersByTime(2000) + }) + + await waitFor(() => { + expect(getCopyIcon()).toBeInTheDocument() + expect(queryCheckIcon()).not.toBeInTheDocument() + }) + }) + + it('should not copy when extractTableMarkdown returns empty string', async () => { + mocks.messageBlocksSelectors.selectById.mockReturnValue(null) + + render(
) + + const copyButton = getCopyButton() + await user.click(copyButton) + + await waitFor(() => { + expect(getCopyIcon()).toBeInTheDocument() + expect(queryCheckIcon()).not.toBeInTheDocument() + }) + }) + }) + + describe('edge cases', () => { + it('should work without blockId', () => { + const propsWithoutBlockId = { ...defaultProps, blockId: undefined } + + expect(() => render(
)).not.toThrow() + + const copyButton = getCopyButton() + expect(copyButton).toBeInTheDocument() + }) + + it('should work without node position', () => { + const propsWithoutPosition = { ...defaultProps, node: undefined } + + expect(() => render(
)).not.toThrow() + + const copyButton = getCopyButton() + expect(copyButton).toBeInTheDocument() + }) + }) +}) diff --git a/src/renderer/src/pages/home/Markdown/__tests__/__snapshots__/Markdown.test.tsx.snap b/src/renderer/src/pages/home/Markdown/__tests__/__snapshots__/Markdown.test.tsx.snap index e055c83f52..29aae68dc0 100644 --- a/src/renderer/src/pages/home/Markdown/__tests__/__snapshots__/Markdown.test.tsx.snap +++ b/src/renderer/src/pages/home/Markdown/__tests__/__snapshots__/Markdown.test.tsx.snap @@ -30,6 +30,24 @@ This is **bold** text. +
+
+
+ test table +
+ + + diff --git a/src/renderer/src/pages/home/Markdown/__tests__/__snapshots__/Table.test.tsx.snap b/src/renderer/src/pages/home/Markdown/__tests__/__snapshots__/Table.test.tsx.snap new file mode 100644 index 0000000000..da85d514be --- /dev/null +++ b/src/renderer/src/pages/home/Markdown/__tests__/__snapshots__/Table.test.tsx.snap @@ -0,0 +1,103 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`Table > rendering > should match snapshot 1`] = ` +.c0 { + position: relative; +} + +.c0 .table-toolbar { + border-radius: 4px; + opacity: 0; + transition: opacity 0.2s ease; + transform: translateZ(0); + will-change: opacity; +} + +.c0:hover .table-toolbar { + opacity: 1; +} + +.c1 { + position: absolute; + top: 8px; + right: 8px; + z-index: 10; +} + +.c2 { + display: flex; + align-items: center; + justify-content: center; + width: 24px; + height: 24px; + border-radius: 4px; + cursor: pointer; + user-select: none; + transition: all 0.2s ease; + opacity: 1; + color: var(--color-text-3); + background-color: var(--color-background-mute); + will-change: background-color,opacity; +} + +.c2:hover { + background-color: var(--color-background-soft); +} + +
+ + + + + + + +
+ Cell 1 + + Cell 2 +
+
+
+
+ + + + +
+
+
+
+`; diff --git a/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx index 09a16c3496..fdd4640d00 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx @@ -40,7 +40,18 @@ function CitationBlock({ block }: { block: CitationMessageBlock }) { __html: (block.response?.results as GroundingMetadata)?.searchEntryPoint?.renderedContent ?.replace(/@media \(prefers-color-scheme: light\)/g, 'body[theme-mode="light"]') - .replace(/@media \(prefers-color-scheme: dark\)/g, 'body[theme-mode="dark"]') || '' + .replace(/@media \(prefers-color-scheme: dark\)/g, 'body[theme-mode="dark"]') + .replace( + /background-color\s*:\s*#[0-9a-fA-F]{3,6}\b|\bbackground-color\s*:\s*[a-zA-Z-]+\b/g, + 'background-color: var(--color-background-soft)' + ) + .replace(/\.gradient\s*{[^}]*background\s*:\s*[^};]+[;}]/g, (match) => { + // Remove the background property while preserving the rest + return match.replace(/background\s*:\s*[^};]+;?\s*/g, '') + }) + .replace(/\.chip {\n/g, '.chip {\n background-color: var(--color-background)!important;\n') + .replace(/border-color\s*:\s*[^};]+;?\s*/g, '') + .replace(/border\s*:\s*[^};]+;?\s*/g, '') || '' }} /> diff --git a/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx index 7f352fbe2a..046c8395ad 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx @@ -1,6 +1,6 @@ import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring' import ImageViewer from '@renderer/components/ImageViewer' -import type { ImageMessageBlock } from '@renderer/types/newMessage' +import { type ImageMessageBlock } from '@renderer/types/newMessage' import React from 'react' import styled from 'styled-components' diff --git 
a/src/renderer/src/pages/home/Messages/ChatNavigation.tsx b/src/renderer/src/pages/home/Messages/ChatNavigation.tsx index b1cbc340ce..955aae645f 100644 --- a/src/renderer/src/pages/home/Messages/ChatNavigation.tsx +++ b/src/renderer/src/pages/home/Messages/ChatNavigation.tsx @@ -23,7 +23,8 @@ const EXCLUDED_SELECTORS = [ '.ant-collapse-header', '.group-menu-bar', '.code-block', - '.message-editor' + '.message-editor', + '.table-wrapper' ] // Gap between the navigation bar and the right element diff --git a/src/renderer/src/pages/home/Messages/CitationsList.tsx b/src/renderer/src/pages/home/Messages/CitationsList.tsx index a33d03343d..0c32f6151a 100644 --- a/src/renderer/src/pages/home/Messages/CitationsList.tsx +++ b/src/renderer/src/pages/home/Messages/CitationsList.tsx @@ -56,13 +56,13 @@ const CitationsList: React.FC = ({ citations }) => { const popoverContent = ( {citations.map((citation) => ( -
+ {citation.type === 'websearch' ? ( ) : ( )} -
+ ))}
) @@ -72,7 +72,17 @@ const CitationsList: React.FC = ({ citations }) => { {t('message.citations')}} + title={ +
+ {t('message.citations')} +
+ } placement="right" trigger="hover" styles={{ @@ -142,13 +152,14 @@ const WebSearchCitation: React.FC<{ citation: Citation }> = ({ citation }) => { - {citation.number} {citation.showFavicon && citation.url && ( )} handleLinkClick(citation.url, e)}> {citation.title || {citation.hostname}} + + {citation.number} {fetchedContent && } {isLoading ? ( @@ -216,10 +227,19 @@ const PreviewIcon = styled.div` ` const CitationIndex = styled.div` - font-size: 14px; + width: 14px; + height: 14px; + display: flex; + align-items: center; + justify-content: center; + border-radius: 50%; + background-color: var(--color-reference); + font-size: 10px; line-height: 1.6; - color: var(--color-text-2); - margin-right: 8px; + color: var(--color-reference-text); + flex-shrink: 0; + opacity: 1; + transition: opacity 0.3s ease; ` const CitationLink = styled.a` @@ -227,7 +247,7 @@ const CitationLink = styled.a` line-height: 1.6; color: var(--color-text-1); text-decoration: none; - + flex: 1; .hostname { color: var(--color-link); } @@ -239,10 +259,14 @@ const CopyIconWrapper = styled.div` align-items: center; justify-content: center; color: var(--color-text-2); - opacity: 0.6; - margin-left: auto; + opacity: 0; padding: 4px; border-radius: 4px; + position: absolute; + right: 0; + top: 50%; + transform: translateY(-50%); + transition: opacity 0.3s ease; &:hover { opacity: 1; @@ -254,10 +278,17 @@ const WebSearchCard = styled.div` display: flex; flex-direction: column; width: 100%; - padding: 12px; - border-radius: var(--list-item-border-radius); + padding: 12px 0; transition: all 0.3s ease; position: relative; + &:hover { + ${CopyIconWrapper} { + opacity: 1; + } + ${CitationIndex} { + opacity: 0; + } + } ` const WebSearchCardHeader = styled.div` @@ -267,6 +298,7 @@ const WebSearchCardHeader = styled.div` gap: 8px; margin-bottom: 6px; width: 100%; + position: relative; ` const WebSearchCardContent = styled.div` @@ -275,6 +307,7 @@ const WebSearchCardContent = styled.div` color: 
var(--color-text-2); user-select: text; cursor: text; + word-break: break-all; &.selectable-text { -webkit-user-select: text; @@ -285,8 +318,15 @@ const WebSearchCardContent = styled.div` ` const PopoverContent = styled.div` - max-width: 300px; - max-height: 50vh; + max-width: min(300px, 60vw); + max-height: 60vh; + padding: 0 12px; +` +const PopoverContentItem = styled.div` + border-bottom: 0.5px solid var(--color-border); + &:last-child { + border-bottom: none; + } ` export default CitationsList diff --git a/src/renderer/src/pages/home/Messages/MessageGroup.tsx b/src/renderer/src/pages/home/Messages/MessageGroup.tsx index e0244ea4f5..e6c831318f 100644 --- a/src/renderer/src/pages/home/Messages/MessageGroup.tsx +++ b/src/renderer/src/pages/home/Messages/MessageGroup.tsx @@ -293,6 +293,7 @@ const GridContainer = styled.div<{ $count: number; $layout: MultiModelMessageSty $layout === 'horizontal' && css` margin-top: 15px; + padding-bottom: 4px; `} ${({ $gridColumns, $layout, $count }) => $layout === 'grid' && diff --git a/src/renderer/src/pages/home/Messages/MessageMenubar.tsx b/src/renderer/src/pages/home/Messages/MessageMenubar.tsx index f3382312ad..4512a7bb38 100644 --- a/src/renderer/src/pages/home/Messages/MessageMenubar.tsx +++ b/src/renderer/src/pages/home/Messages/MessageMenubar.tsx @@ -492,10 +492,10 @@ const MessageMenubar: FC = (props) => { {!isUserMessage && ( e.domEvent.stopPropagation() }} trigger={['click']} - placement="topRight" - arrow> + placement="topRight"> e.stopPropagation()}> diff --git a/src/renderer/src/pages/home/Tabs/AssistantsTab.tsx b/src/renderer/src/pages/home/Tabs/AssistantsTab.tsx index 1e486e4cb7..812acfab6a 100644 --- a/src/renderer/src/pages/home/Tabs/AssistantsTab.tsx +++ b/src/renderer/src/pages/home/Tabs/AssistantsTab.tsx @@ -159,67 +159,63 @@ const Assistants: FC = ({ if (assistantsTabSortType === 'tags') { return ( -
- ({ - ..._, - disabled: _.tag === t('assistants.tags.untagged') - }))} - onUpdate={() => {}} - onDragEnd={handleGroupDragEnd} - style={{ paddingBottom: 0 }}> - {(group) => ( - - {(provided) => ( - - {group.tag !== t('assistants.tags.untagged') && ( - toggleTagCollapse(group.tag)}> - - - {collapsedTags[group.tag] ? ( - - ) : ( - - )} - {group.tag} - - - - - )} - {!collapsedTags[group.tag] && ( -
- {group.assistants.map((assistant, index) => ( - - {(provided) => ( -
- {}} - /> -
- )} -
- ))} -
- )} - {provided.placeholder} -
- )} -
- )} -
-
+ ({ ..._, disabled: _.tag === t('assistants.tags.untagged') }))} + onUpdate={() => {}} + onDragEnd={handleGroupDragEnd}> + {(group) => ( + + {(provided) => ( + + {group.tag !== t('assistants.tags.untagged') && ( + toggleTagCollapse(group.tag)}> + + + {collapsedTags[group.tag] ? ( + + ) : ( + + )} + {group.tag} + + + + + )} + {!collapsedTags[group.tag] && ( + <> + {group.assistants.map((assistant, index) => ( + + {(provided) => ( +
+ {}} + style={{ margin: '4px 0' }} + /> +
+ )} +
+ ))} + + )} + {provided.placeholder} +
+ )} +
+ )} +
@@ -269,13 +265,12 @@ const Assistants: FC = ({ const Container = styled(Scrollbar)` display: flex; flex-direction: column; - padding: 4px 10px; + padding: 0 10px; ` const TagsContainer = styled.div` display: flex; flex-direction: column; - gap: 8px; ` const AssistantAddItem = styled.div` @@ -304,6 +299,7 @@ const GroupTitle = styled.div` justify-content: space-between; align-items: center; height: 24px; + margin: 5px 0; ` const GroupTitleName = styled.div` diff --git a/src/renderer/src/pages/home/Tabs/TopicsTab.tsx b/src/renderer/src/pages/home/Tabs/TopicsTab.tsx index 1929eba045..f39b834b0b 100644 --- a/src/renderer/src/pages/home/Tabs/TopicsTab.tsx +++ b/src/renderer/src/pages/home/Tabs/TopicsTab.tsx @@ -17,7 +17,7 @@ import { isMac } from '@renderer/config/constant' import { useAssistant, useAssistants, useTopicsForAssistant } from '@renderer/hooks/useAssistant' import { modelGenerating } from '@renderer/hooks/useRuntime' import { useSettings } from '@renderer/hooks/useSettings' -import { TopicManager } from '@renderer/hooks/useTopic' +import { finishTopicRenaming, startTopicRenaming, TopicManager } from '@renderer/hooks/useTopic' import { fetchMessagesSummary } from '@renderer/services/ApiService' import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' import store from '@renderer/store' @@ -58,6 +58,9 @@ const Topics: FC = ({ assistant: _assistant, activeTopic, setActiveTopic const topics = useTopicsForAssistant(_assistant.id) + const renamingTopics = useSelector((state: RootState) => state.runtime.chat.renamingTopics) + const newlyRenamedTopics = useSelector((state: RootState) => state.runtime.chat.newlyRenamedTopics) + const borderRadius = showTopicTime ? 
12 : 'var(--list-item-border-radius)' const [deletingTopicId, setDeletingTopicId] = useState(null) @@ -85,6 +88,20 @@ const Topics: FC = ({ assistant: _assistant, activeTopic, setActiveTopic [activeTopic.id, pendingTopics] ) + const isRenaming = useCallback( + (topicId: string) => { + return renamingTopics.includes(topicId) + }, + [renamingTopics] + ) + + const isNewlyRenamed = useCallback( + (topicId: string) => { + return newlyRenamedTopics.includes(topicId) + }, + [newlyRenamedTopics] + ) + const handleDeleteClick = useCallback((topicId: string, e: React.MouseEvent) => { e.stopPropagation() @@ -171,16 +188,22 @@ const Topics: FC = ({ assistant: _assistant, activeTopic, setActiveTopic label: t('chat.topics.auto_rename'), key: 'auto-rename', icon: , + disabled: isRenaming(topic.id), async onClick() { const messages = await TopicManager.getTopicMessages(topic.id) if (messages.length >= 2) { - const summaryText = await fetchMessagesSummary({ messages, assistant }) - if (summaryText) { - const updatedTopic = { ...topic, name: summaryText, isNameManuallyEdited: false } - updateTopic(updatedTopic) - topic.id === activeTopic.id && setActiveTopic(updatedTopic) - } else { - window.message?.error(t('message.error.fetchTopicName')) + startTopicRenaming(topic.id) + try { + const summaryText = await fetchMessagesSummary({ messages, assistant }) + if (summaryText) { + const updatedTopic = { ...topic, name: summaryText, isNameManuallyEdited: false } + updateTopic(updatedTopic) + topic.id === activeTopic.id && setActiveTopic(updatedTopic) + } else { + window.message?.error(t('message.error.fetchTopicName')) + } + } finally { + finishTopicRenaming(topic.id) } } } @@ -189,6 +212,7 @@ const Topics: FC = ({ assistant: _assistant, activeTopic, setActiveTopic label: t('chat.topics.edit.title'), key: 'rename', icon: , + disabled: isRenaming(topic.id), async onClick() { const name = await PromptPopup.show({ title: t('chat.topics.edit.title'), @@ -372,6 +396,7 @@ const Topics: FC = ({ 
assistant: _assistant, activeTopic, setActiveTopic }, [ targetTopic, t, + isRenaming, exportMenuOptions.image, exportMenuOptions.markdown, exportMenuOptions.markdown_reason, @@ -414,6 +439,13 @@ const Topics: FC = ({ assistant: _assistant, activeTopic, setActiveTopic const topicName = topic.name.replace('`', '') const topicPrompt = topic.prompt const fullTopicPrompt = t('common.prompt') + ': ' + topicPrompt + + const getTopicNameClassName = () => { + if (isRenaming(topic.id)) return 'shimmer' + if (isNewlyRenamed(topic.id)) return 'typing' + return '' + } + return ( setTargetTopic(topic)} @@ -422,7 +454,7 @@ const Topics: FC = ({ assistant: _assistant, activeTopic, setActiveTopic style={{ borderRadius }}> {isPending(topic.id) && !isActive && } - + {topicName} {isActive && !topic.pinned && ( @@ -526,6 +558,46 @@ const TopicName = styled.div` -webkit-box-orient: vertical; overflow: hidden; font-size: 13px; + position: relative; + will-change: background-position, width; + + --color-shimmer-mid: var(--color-text-1); + --color-shimmer-end: color-mix(in srgb, var(--color-text-1) 25%, transparent); + + &.shimmer { + background: linear-gradient(to left, var(--color-shimmer-end), var(--color-shimmer-mid), var(--color-shimmer-end)); + background-size: 200% 100%; + background-clip: text; + color: transparent; + animation: shimmer 3s linear infinite; + } + + &.typing { + display: block; + -webkit-line-clamp: unset; + -webkit-box-orient: unset; + white-space: nowrap; + overflow: hidden; + animation: typewriter 0.5s steps(40, end); + } + + @keyframes shimmer { + 0% { + background-position: 200% 0; + } + 100% { + background-position: -200% 0; + } + } + + @keyframes typewriter { + from { + width: 0; + } + to { + width: 100%; + } + } ` const PendingIndicator = styled.div.attrs({ diff --git a/src/renderer/src/pages/home/Tabs/components/AssistantItem.tsx b/src/renderer/src/pages/home/Tabs/components/AssistantItem.tsx index b23416e7e2..512f8e0924 100644 --- 
a/src/renderer/src/pages/home/Tabs/components/AssistantItem.tsx +++ b/src/renderer/src/pages/home/Tabs/components/AssistantItem.tsx @@ -43,6 +43,7 @@ interface AssistantItemProps { onTagClick?: (tag: string) => void handleSortByChange?: (sortType: AssistantsSortType) => void singleLine?: boolean + style?: React.CSSProperties } const AssistantItem: FC = ({ @@ -53,7 +54,8 @@ const AssistantItem: FC = ({ onDelete, addAssistant, handleSortByChange, - singleLine = false + singleLine = false, + style }) => { const { t } = useTranslation() const { allTags } = useTags() @@ -160,7 +162,8 @@ const AssistantItem: FC = ({ return ( + className={classNames({ active: isActive, 'is-menu-open': isMenuOpen, singleLine })} + style={style}> {assistantNave}