mirror of https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-26 03:31:24 +08:00

commit 7a44910847

Merge branch 'main' into feat/sidebar-ui

# Conflicts:
#	package.json
#	src/renderer/src/hooks/useTopic.ts
#	src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx
#	src/renderer/src/pages/home/Messages/MessageTokens.tsx
#	src/renderer/src/store/index.ts
#	src/renderer/src/store/migrate.ts
#	src/renderer/src/store/runtime.ts

.vscode/launch.json (vendored, 1 line changed)
@ -7,7 +7,6 @@
      "request": "launch",
      "cwd": "${workspaceRoot}",
      "runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite",
-     "runtimeVersion": "20",
      "windows": {
        "runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite.cmd"
      },
docs/technical/how-to-write-middlewares.md (new file, 214 lines)
@ -0,0 +1,214 @@
# How to Write Middleware for an AI Provider

This document shows developers how to create and integrate custom middleware for our AI Provider framework. Middleware offers a powerful, flexible way to enhance, modify, or observe a Provider method call, for example logging, caching, request/response transformation, and error handling.

## Architecture Overview

Our middleware architecture borrows Redux's three-stage design and combines it with a JavaScript Proxy to apply middleware to Provider methods dynamically.

- **Proxy**: intercepts calls to Provider methods and routes them into the middleware chain.
- **Middleware chain**: a series of middleware functions executed in order. Each middleware can process the request/response and then hand control to the next middleware in the chain, or in some cases terminate the chain early.
- **Context**: an object passed between middlewares that carries information about the current call (method name, original arguments, the Provider instance, and any data the middleware itself attaches).

## Middleware Types

Two middleware types are currently supported. They share a similar structure but target different scenarios:

1. **`CompletionsMiddleware`**: designed specifically for the `completions` method. This is the most common middleware type, since it allows fine-grained control over the AI model's core chat/text-generation functionality.
2. **`ProviderMethodMiddleware`**: a generic middleware that can be applied to any other Provider method (e.g. `translate`, `summarize`), provided those methods are also wrapped by the middleware system.

## Writing a `CompletionsMiddleware`

The basic signature of a `CompletionsMiddleware` (as a TypeScript type) is:

```typescript
import { AiProviderMiddlewareCompletionsContext, CompletionsParams, MiddlewareAPI } from './AiProviderMiddlewareTypes' // assumed path of the type definition file

export type CompletionsMiddleware = (
  api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>
) => (
  next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any> // next returns Promise<any>: the raw SDK response or the downstream middleware's result
) => (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<void> // the innermost function usually returns Promise<void>, since results are delivered via onChunk or context side effects
```
Let's break this three-stage structure down:

1. **First layer `(api) => { ... }`**:

   - Receives an `api` object.
   - The `api` object provides the following methods:
     - `api.getContext()`: returns the context object for the current call (`AiProviderMiddlewareCompletionsContext`).
     - `api.getOriginalArgs()`: returns the original argument array passed to `completions` (i.e. `[CompletionsParams]`).
     - `api.getProviderId()`: returns the current Provider's ID.
     - `api.getProviderInstance()`: returns the original Provider instance.
   - This layer is typically used for one-time setup or to fetch required services/configuration. It returns the second-layer function.

2. **Second layer `(next) => { ... }`**:

   - Receives a `next` function.
   - `next` represents the next link in the middleware chain. Calling `next(context, params)` passes control to the next middleware or, if the current middleware is the last in the chain, invokes the core Provider logic (e.g. the actual SDK call).
   - `next` receives the current `context` and `params` (which upstream middleware may already have modified).
   - **Importantly**: `next` usually returns `Promise<any>`. For `completions`, if `next` invokes the actual SDK, it returns the raw SDK response (e.g. OpenAI's stream object or a JSON object), and you need to handle that response.
   - This layer returns the third (and central) function.

3. **Third layer `(context, params) => { ... }`**:

   - This is where the middleware's main logic runs.
   - It receives the current `context` (`AiProviderMiddlewareCompletionsContext`) and `params` (`CompletionsParams`).
   - In this function you can:
     - **Before calling `next`**:
       - Read or modify `params`, e.g. add default parameters or convert message formats.
       - Read or modify `context`, e.g. store a timestamp for a later latency calculation.
       - Run checks and, if a condition fails, return or throw without calling `next` (e.g. on parameter validation failure).
     - **Call `await next(context, params)`**:
       - This is the key step that hands control downstream.
       - `next`'s return value is the raw SDK response or the downstream middleware's result; handle it as appropriate (e.g. start consuming it if it is a stream).
     - **After calling `next`**:
       - Process `next`'s result. For example, if `next` returned a stream, iterate over it here and emit data chunks via `context.onChunk`.
       - Act on changes to `context` or on `next`'s result, e.g. compute the total elapsed time and write logs.
       - Modify the final result (although for `completions` the result is usually emitted as a side effect through `onChunk`).
### Example: A Simple Logging Middleware

```typescript
import {
  AiProviderMiddlewareCompletionsContext,
  CompletionsParams,
  MiddlewareAPI,
  OnChunkFunction // assumes the OnChunkFunction type is exported
} from './AiProviderMiddlewareTypes' // adjust the path
import { ChunkType } from '@renderer/types' // adjust the path

export const createSimpleLoggingMiddleware = (): CompletionsMiddleware => {
  return (api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>) => {
    // console.log(`[LoggingMiddleware] Initialized for provider: ${api.getProviderId()}`);

    return (next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any>) => {
      return async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise<void> => {
        const startTime = Date.now()
        // Grab onChunk from the context (it originally comes from params.onChunk)
        const onChunk = context.onChunk

        console.log(
          `[LoggingMiddleware] Request for ${context.methodName} with params:`,
          params.messages?.[params.messages.length - 1]?.content
        )

        try {
          // Call the next middleware, or the core logic.
          // `rawSdkResponse` is the raw downstream response (e.g. an OpenAIStream or a ChatCompletion object).
          const rawSdkResponse = await next(context, params)

          // This simple example does not process rawSdkResponse; it assumes a downstream
          // middleware (such as StreamingResponseHandler) handles it and emits data via onChunk.
          // If this logging middleware runs after StreamingResponseHandler, the stream has already been consumed.
          // If it runs before, it must either consume rawSdkResponse itself or ensure something downstream will.

          const duration = Date.now() - startTime
          console.log(`[LoggingMiddleware] Request for ${context.methodName} completed in ${duration}ms.`)

          // Assumes downstream has already emitted all data through onChunk.
          // If this middleware sits at the end of the chain and must guarantee that
          // BLOCK_COMPLETE is sent, it needs more elaborate logic to track when all data has gone out.
        } catch (error) {
          const duration = Date.now() - startTime
          console.error(`[LoggingMiddleware] Request for ${context.methodName} failed after ${duration}ms:`, error)

          // If onChunk is available, try to emit an error chunk
          if (onChunk) {
            onChunk({
              type: ChunkType.ERROR,
              error: { message: (error as Error).message, name: (error as Error).name, stack: (error as Error).stack }
            })
            // Consider whether a BLOCK_COMPLETE chunk is also needed to end the stream
            onChunk({ type: ChunkType.BLOCK_COMPLETE, response: {} })
          }
          throw error // Re-throw so an upstream or global error handler can catch it
        }
      }
    }
  }
}
```
### The Role of `AiProviderMiddlewareCompletionsContext`

`AiProviderMiddlewareCompletionsContext` is the core vehicle for passing state and data between middlewares. It typically contains:

- `methodName`: the name of the method being called (always `'completions'`).
- `originalArgs`: the original argument array passed to `completions`.
- `providerId`: the Provider's ID.
- `_providerInstance`: the Provider instance.
- `onChunk`: the callback passed in via the original `CompletionsParams`, used to stream data chunks out. **All middleware should emit data through `context.onChunk`.**
- `messages`, `model`, `assistant`, `mcpTools`: frequently used fields lifted out of the original `CompletionsParams` for convenient access.
- **Custom fields**: a middleware can attach custom fields to the context for later middleware to use. For example, a cache middleware might set `context.cacheHit = true`, as the sketch below illustrates.

**Key point**: when you modify `params` or `context` inside a middleware, those modifications propagate to downstream middleware (provided you make them before calling `next`).
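To make custom context fields and early termination concrete, here is a minimal cache-middleware sketch. The `computeCacheKey` helper, the in-memory `Map`, and the text chunk type emitted on a hit are illustrative assumptions; only `ChunkType.BLOCK_COMPLETE` is confirmed by the example above.

```typescript
import { ChunkType } from '@renderer/types'

const responseCache = new Map<string, string>() // demo in-memory cache: key -> full response text

// Hypothetical helper: derive a cache key from the request messages
const computeCacheKey = (params: CompletionsParams): string =>
  JSON.stringify(params.messages?.map((m) => m.content))

export const createCacheMiddleware = (): CompletionsMiddleware => {
  return () => (next) => async (context, params): Promise<void> => {
    const cached = responseCache.get(computeCacheKey(params))

    if (cached !== undefined) {
      // Cache hit: record it on the context for later middleware/telemetry,
      // emit the cached text (exact chunk type assumed), close the block, and
      // terminate the chain early by NOT calling next.
      ;(context as any).cacheHit = true
      context.onChunk?.({ type: (ChunkType as any).TEXT_DELTA, text: cached } as any)
      context.onChunk?.({ type: ChunkType.BLOCK_COMPLETE, response: {} } as any)
      return
    }

    await next(context, params)
    // A full implementation would populate responseCache here, which requires
    // accumulating the streamed text (e.g. from a field a downstream middleware sets).
  }
}
```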
### Middleware Ordering

Execution order matters a great deal. Middlewares run in the order they are listed in the `AiProviderMiddlewareConfig` array.

- The request passes through the first middleware, then the second, and so on.
- The response (the result of the `next` call) "bubbles" back in the reverse order.

For example, with the chain `[AuthMiddleware, CacheMiddleware, LoggingMiddleware]` (a runnable sketch of this ordering follows the list):

1. `AuthMiddleware` runs its "before `next`" logic first.
2. Then `CacheMiddleware` runs its "before `next`" logic.
3. Then `LoggingMiddleware` runs its "before `next`" logic.
4. The core SDK call executes (the end of the chain).
5. `LoggingMiddleware` receives the result first and runs its "after `next`" logic.
6. Then `CacheMiddleware` receives the result (with any context changes made by LoggingMiddleware) and runs its "after `next`" logic (e.g. storing the result).
7. Finally `AuthMiddleware` receives the result and runs its "after `next`" logic.
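A trivial way to see this onion ordering in practice: the `tag` factory below builds a middleware that only logs around `next`. It follows the three-layer signature above; the names are illustrative.

```typescript
// A trivial middleware factory used only to visualize ordering
const tag = (name: string): CompletionsMiddleware =>
  () => (next) => async (context, params) => {
    console.log(`${name}: before next`)
    await next(context, params)
    console.log(`${name}: after next`)
  }

// Registered as completions: [tag('Auth'), tag('Cache'), tag('Logging')],
// a single request logs:
//   Auth: before next
//   Cache: before next
//   Logging: before next
//   (core SDK call)
//   Logging: after next
//   Cache: after next
//   Auth: after next
```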
### Registering Middleware

Middleware is registered in `src/renderer/src/providers/middleware/register.ts` (or a similar configuration file).

```typescript
// register.ts
import { AiProviderMiddlewareConfig } from './AiProviderMiddlewareTypes'
import { createSimpleLoggingMiddleware } from './common/SimpleLoggingMiddleware' // assuming you created this file
import { createCompletionsLoggingMiddleware } from './common/CompletionsLoggingMiddleware' // pre-existing

const middlewareConfig: AiProviderMiddlewareConfig = {
  completions: [
    createSimpleLoggingMiddleware(), // your newly added middleware
    createCompletionsLoggingMiddleware() // the existing logging middleware
    // ... other completions middleware
  ],
  methods: {
    // translate: [createGenericLoggingMiddleware()],
    // ... middleware for other methods
  }
}

export default middlewareConfig
```
### Best Practices

1. **Single responsibility**: each middleware should focus on one specific concern (e.g. logging, caching, transforming a particular kind of data).
2. **No side effects (where possible)**: apart from explicit effects through `context` or `onChunk`, avoid mutating global state or producing other hidden side effects.
3. **Error handling**:
   - Use `try...catch` inside the middleware to handle errors that may occur.
   - Decide whether to handle errors yourself (e.g. by emitting an error chunk via `onChunk`) or re-throw them upstream.
   - If you re-throw, make sure the error object carries enough information.
4. **Performance**: middleware adds overhead to request handling. Avoid expensive synchronous work inside middleware; make IO-bound operations asynchronous.
5. **Configurability**: make middleware behavior adjustable through parameters or configuration. A logging middleware, for instance, can accept a log-level parameter.
6. **Context management**:
   - Add data to `context` sparingly. Avoid polluting it or attaching oversized objects.
   - Be explicit about the purpose and lifetime of every field you add to `context`.
7. **Calling `next`**:
   - Unless you have a good reason to terminate the request early (e.g. cache hit, authorization failure), **always call `await next(context, params)`**. Otherwise downstream middleware and the core logic will never run.
   - Understand `next`'s return value and handle it correctly, especially when it is a stream. You are responsible for consuming the stream or handing it to another component/middleware that can.
8. **Clear naming**: give your middlewares and the functions they create descriptive names.
9. **Documentation and comments**: annotate complex middleware logic, explaining how it works and why it exists.
### Debugging Tips

- Use `console.log` or a debugger at key points in the middleware to inspect `params`, the state of `context`, and `next`'s return value.
- Temporarily simplify the chain to just the middleware under investigation plus the most minimal core logic, to isolate the problem.
- Write unit tests that verify each middleware's behavior in isolation (see the sketch below).

Following these guidelines should let you build powerful, maintainable middleware for our system. If you have questions or need further help, ask the team.
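A minimal sketch of such an isolated test, using Vitest (already in the repo's devDependencies). The `api`, `context`, and `params` stubs are hypothetical stand-ins for the real types.

```typescript
import { describe, expect, it, vi } from 'vitest'
import { createSimpleLoggingMiddleware } from './common/SimpleLoggingMiddleware'

describe('createSimpleLoggingMiddleware', () => {
  it('calls next exactly once and re-throws downstream errors', async () => {
    const middleware = createSimpleLoggingMiddleware()
    const api = { getProviderId: () => 'test' } as any // minimal MiddlewareAPI stub
    const context = { methodName: 'completions', onChunk: vi.fn() } as any
    const params = { messages: [{ content: 'hi' }] } as any

    // Happy path: next resolves, middleware completes without throwing
    const next = vi.fn().mockResolvedValue({})
    await middleware(api)(next)(context, params)
    expect(next).toHaveBeenCalledTimes(1)

    // Error path: next rejects, middleware emits error chunks and re-throws
    const failingNext = vi.fn().mockRejectedValue(new Error('boom'))
    await expect(middleware(api)(failingNext)(context, params)).rejects.toThrow('boom')
    expect(context.onChunk).toHaveBeenCalled()
  })
})
```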
package.json (81 lines changed)
@ -1,6 +1,6 @@
{
  "name": "CherryStudio",
-  "version": "1.4.2-ui-preview",
+  "version": "1.4.2",
  "private": true,
  "description": "A powerful AI assistant for producer.",
  "main": "./out/main/index.js",
@ -58,6 +58,20 @@
    "prepare": "husky"
  },
  "dependencies": {
    "@libsql/client": "0.14.0",
    "@libsql/win32-x64-msvc": "^0.4.7",
    "@strongtz/win32-arm64-msvc": "^0.4.7",
    "jsdom": "26.1.0",
    "os-proxy-config": "^1.1.2",
    "selection-hook": "^0.9.23",
    "turndown": "7.2.0"
  },
  "devDependencies": {
    "@agentic/exa": "^7.3.3",
    "@agentic/searxng": "^7.3.3",
    "@agentic/tavily": "^7.3.3",
    "@ant-design/v5-patch-for-react-19": "^1.0.3",
    "@anthropic-ai/sdk": "^0.41.0",
    "@cherrystudio/embedjs": "^0.1.31",
    "@cherrystudio/embedjs-libsql": "^0.1.31",
    "@cherrystudio/embedjs-loader-csv": "^0.1.31",
@ -70,48 +84,11 @@
    "@cherrystudio/embedjs-loader-xml": "^0.1.31",
    "@cherrystudio/embedjs-ollama": "^0.1.31",
    "@cherrystudio/embedjs-openai": "^0.1.31",
    "@electron-toolkit/utils": "^3.0.0",
    "@langchain/community": "^0.3.36",
    "@langchain/ollama": "^0.2.1",
    "@strongtz/win32-arm64-msvc": "^0.4.7",
    "@tanstack/react-query": "^5.27.0",
    "@types/react-infinite-scroll-component": "^5.0.0",
    "archiver": "^7.0.1",
    "async-mutex": "^0.5.0",
    "diff": "^7.0.0",
    "docx": "^9.0.2",
    "electron-log": "^5.1.5",
    "electron-store": "^8.2.0",
    "electron-updater": "6.6.4",
    "electron-window-state": "^5.0.3",
    "epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
    "fast-xml-parser": "^5.2.0",
    "framer-motion": "^12.17.0",
    "franc-min": "^6.2.0",
    "fs-extra": "^11.2.0",
    "jsdom": "^26.0.0",
    "markdown-it": "^14.1.0",
    "node-stream-zip": "^1.15.0",
    "officeparser": "^4.1.1",
    "os-proxy-config": "^1.1.2",
    "proxy-agent": "^6.5.0",
    "remove-markdown": "^0.6.2",
    "selection-hook": "^0.9.23",
    "tar": "^7.4.3",
    "turndown": "^7.2.0",
    "webdav": "^5.8.0",
    "zipread": "^1.3.3"
  },
  "devDependencies": {
    "@agentic/exa": "^7.3.3",
    "@agentic/searxng": "^7.3.3",
    "@agentic/tavily": "^7.3.3",
    "@ant-design/v5-patch-for-react-19": "^1.0.3",
    "@anthropic-ai/sdk": "^0.41.0",
    "@electron-toolkit/eslint-config-prettier": "^3.0.0",
    "@electron-toolkit/eslint-config-ts": "^3.0.0",
    "@electron-toolkit/preload": "^3.0.0",
    "@electron-toolkit/tsconfig": "^1.0.1",
    "@electron-toolkit/utils": "^3.0.0",
    "@electron/notarize": "^2.5.0",
    "@emotion/is-prop-valid": "^1.3.1",
    "@eslint-react/eslint-plugin": "^1.36.1",
@ -119,6 +96,8 @@
    "@google/genai": "^1.0.1",
    "@hello-pangea/dnd": "^16.6.0",
    "@kangfenmao/keyv-storage": "^0.1.0",
    "@langchain/community": "^0.3.36",
    "@langchain/ollama": "^0.2.1",
    "@modelcontextprotocol/sdk": "^1.11.4",
    "@mozilla/readability": "^0.6.0",
    "@notionhq/client": "^2.2.15",
@ -126,6 +105,7 @@
    "@reduxjs/toolkit": "^2.2.5",
    "@shikijs/markdown-it": "^3.4.2",
    "@swc/plugin-styled-components": "^7.1.5",
    "@tanstack/react-query": "^5.27.0",
    "@testing-library/dom": "^10.4.0",
    "@testing-library/jest-dom": "^6.6.3",
    "@testing-library/react": "^16.3.0",
@ -152,24 +132,37 @@
    "@vitest/web-worker": "^3.1.4",
    "@xyflow/react": "^12.4.4",
    "antd": "^5.22.5",
    "archiver": "^7.0.1",
    "async-mutex": "^0.5.0",
    "axios": "^1.7.3",
    "browser-image-compression": "^2.0.2",
    "color": "^5.0.0",
    "dayjs": "^1.11.11",
    "dexie": "^4.0.8",
    "dexie-react-hooks": "^1.1.7",
    "diff": "^7.0.0",
    "docx": "^9.0.2",
    "dotenv-cli": "^7.4.2",
    "electron": "35.4.0",
    "electron-builder": "26.0.15",
    "electron-devtools-installer": "^3.2.0",
    "electron-log": "^5.1.5",
    "electron-store": "^8.2.0",
    "electron-updater": "6.6.4",
    "electron-vite": "^3.1.0",
    "electron-window-state": "^5.0.3",
    "emittery": "^1.0.3",
    "emoji-picker-element": "^1.22.1",
    "epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
    "eslint": "^9.22.0",
    "eslint-plugin-react-hooks": "^5.2.0",
    "eslint-plugin-simple-import-sort": "^12.1.1",
    "eslint-plugin-unused-imports": "^4.1.4",
    "fast-diff": "^1.3.0",
    "fast-xml-parser": "^5.2.0",
    "framer-motion": "^12.17.3",
    "franc-min": "^6.2.0",
    "fs-extra": "^11.2.0",
    "html-to-image": "^1.11.13",
    "husky": "^9.1.7",
    "i18next": "^23.11.5",
@ -178,14 +171,18 @@
    "lodash": "^4.17.21",
    "lru-cache": "^11.1.0",
    "lucide-react": "^0.487.0",
    "markdown-it": "^14.1.0",
    "mermaid": "^11.6.0",
    "mime": "^4.0.4",
    "motion": "^12.10.5",
    "node-stream-zip": "^1.15.0",
    "npx-scope-finder": "^1.2.0",
    "officeparser": "^4.1.1",
    "openai": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
    "p-queue": "^8.1.0",
    "playwright": "^1.52.0",
    "prettier": "^3.5.3",
    "proxy-agent": "^6.5.0",
    "rc-virtual-list": "^3.18.6",
    "react": "^19.0.0",
    "react-dom": "^19.0.0",
@ -206,17 +203,21 @@
    "remark-cjk-friendly": "^1.1.0",
    "remark-gfm": "^4.0.0",
    "remark-math": "^6.0.0",
    "remove-markdown": "^0.6.2",
    "rollup-plugin-visualizer": "^5.12.0",
    "sass": "^1.88.0",
    "shiki": "^3.4.2",
    "string-width": "^7.2.0",
    "styled-components": "^6.1.11",
    "tar": "^7.4.3",
    "tiny-pinyin": "^1.3.2",
    "tokenx": "^0.4.1",
    "typescript": "^5.6.2",
    "uuid": "^10.0.0",
    "vite": "6.2.6",
-   "vitest": "^3.1.4"
+   "vitest": "^3.1.4",
+   "webdav": "^5.8.0",
+   "zipread": "^1.3.3"
  },
  "resolutions": {
    "pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",

@ -408,3 +408,4 @@ export enum FeedUrl {
   PRODUCTION = 'https://releases.cherry-ai.com',
   EARLY_ACCESS = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download'
 }
+export const defaultTimeout = 5 * 1000 * 60
src/renderer/src/aiCore/AI_CORE_DESIGN.md (new file, 223 lines)
@ -0,0 +1,223 @@
# Cherry Studio AI Provider Technical Architecture (New Design)

## 1. Core Design Philosophy and Goals

This architecture refactors Cherry Studio's AI Provider layer (now called `aiCore`) to achieve the following goals:

- **Clear responsibilities**: delineate each component's duties and reduce coupling.
- **High reuse**: maximize reuse of business and shared processing logic, cutting duplicated code.
- **Easy to extend**: make it quick and convenient to integrate new AI Providers (LLM vendors) and add new AI features (translation, summarization, image generation, etc.).
- **Easy to maintain**: reduce the complexity of individual components and improve readability and maintainability.
- **Standardized**: unify internal data flow and interfaces, simplifying the handling of differences between Providers.

The core idea is a clean separation between a pure **SDK adaptation layer (`XxxApiClient`)**, a **shared logic and intelligent parsing layer (middleware)**, and a **unified business-feature entry layer (`AiCoreService`)**.

## 2. Core Components in Detail

### 2.1. `aiCore` (formerly the `AiProvider` folder)

This is the core module for all AI functionality.

#### 2.1.1. `XxxApiClient` (e.g. `aiCore/clients/openai/OpenAIApiClient.ts`)

- **Responsibility**: a pure adaptation layer over a specific AI Provider SDK.
  - **Parameter adaptation**: convert the app's unified internal `CoreRequest` object (see below) into the request-parameter format the specific SDK requires.
  - **Basic response conversion**: convert the raw data chunks the SDK returns (`RawSdkChunk`, e.g. `OpenAI.Chat.Completions.ChatCompletionChunk`) into a set of the most basic, direct application-level `Chunk` objects (defined in `src/renderer/src/types/chunk.ts`).
    - For example: the SDK's `delta.content` -> `TextDeltaChunk`; the SDK's `delta.reasoning_content` -> `ThinkingDeltaChunk`; the SDK's `delta.tool_calls` -> `RawToolCallChunk` (carrying the raw tool-call data). A minimal sketch follows below.
  - **Key point**: `XxxApiClient` does **not** handle complex structures embedded in the text content, such as `<think>` or `<tool_use>` tags.
- **Character**: extremely lightweight and small, making new Provider adapters easy to implement and maintain.
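A minimal sketch of what this basic response conversion might look like for an OpenAI-style chunk. The simplified `AppChunk` shapes are assumptions for illustration; the real variants live in `src/renderer/src/types/chunk.ts`.

```typescript
import type OpenAI from 'openai'

// Hypothetical simplified chunk shapes, for illustration only
type AppChunk =
  | { type: 'TEXT_DELTA'; text: string }
  | { type: 'THINKING_DELTA'; text: string }

// Map one raw SDK chunk to zero or more basic app-level chunks.
// Note it only lifts fields out of the delta; it does NOT parse <think> or
// <tool_use> tags embedded in the text - that is middleware's job.
function transformChunk(raw: OpenAI.Chat.Completions.ChatCompletionChunk): AppChunk[] {
  const out: AppChunk[] = []
  const delta = raw.choices[0]?.delta as Record<string, unknown> | undefined
  if (typeof delta?.content === 'string' && delta.content.length > 0) {
    out.push({ type: 'TEXT_DELTA', text: delta.content })
  }
  // reasoning_content is a non-standard field some providers attach to the delta
  if (typeof delta?.reasoning_content === 'string') {
    out.push({ type: 'THINKING_DELTA', text: delta.reasoning_content })
  }
  return out
}
```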
#### 2.1.2. `ApiClient.ts` (the core interface behind `BaseApiClient.ts`)

- Defines the interface every `XxxApiClient` must implement, e.g.:
  - `getSdkInstance(): Promise<TSdkInstance> | TSdkInstance`
  - `getRequestTransformer(): RequestTransformer<TSdkParams>`
  - `getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk, TResponseContext>`
  - Other optional, Provider-specific helper methods (such as tool-call conversion).

#### 2.1.3. `ApiClientFactory.ts`

- Dynamically creates and returns the matching `XxxApiClient` instance based on the Provider configuration.

#### 2.1.4. `AiCoreService.ts` (`aiCore/index.ts`)

- **Responsibility**: the unified entry point for all AI-related business features.
  - Exposes high-level, application-facing interfaces, for example:
    - `executeCompletions(params: CompletionsParams): Promise<AggregatedCompletionsResult>`
    - `translateText(params: TranslateParams): Promise<AggregatedTranslateResult>`
    - `summarizeText(params: SummarizeParams): Promise<AggregatedSummarizeResult>`
    - possibly `generateImage(prompt: string): Promise<ImageResult>` and others in the future.
  - **Returns a `Promise`**: each service method returns a `Promise` that resolves, once the whole (possibly streaming) operation finishes, with an object containing all aggregated results (full text, tool-call details, final `usage`/`metrics`, etc.).
  - **Supports streaming callbacks**: the method parameters (e.g. `CompletionsParams`) still include an `onChunk` callback for pushing `Chunk` data to the caller in real time, enabling streaming UI updates.
  - **Encapsulates task-specific prompt engineering**:
    - For example, `translateText` internally builds a `CoreRequest` carrying the specific translation instructions.
  - **Orchestrates and invokes the middleware chain**: through an internal `MiddlewareBuilder` (see `middleware/BUILDER_USAGE.md`), it dynamically builds and arranges the appropriate middleware sequence for the invoked business method and parameters, then executes it via composition functions such as `applyCompletionsMiddlewares`.
    - It obtains the `ApiClient` instance and injects it into the `Context` at the top of the middleware chain.
    - It **passes the `Promise`'s `resolve` and `reject` functions into the middleware chain** (via the `Context`), so `FinalChunkConsumerAndNotifierMiddleware` can settle the `Promise` when the operation completes or fails.
- **Benefits**:
  - Business logic (e.g. prompt construction and flow control for translation or summarization) is implemented once and works for every Provider integrated through an `ApiClient`.
  - **Supports external orchestration**: a caller can `await` a service method for the final aggregated result and feed it into a subsequent operation, easily building multi-step workflows (see the sketch below).
  - **Supports internal composition**: the service itself can `await` other atomic service methods to build more complex composite features.
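A sketch of how a caller might combine the two consumption paths described above. `aiCoreService`, `renderDelta`, and the exact parameter/result shapes are illustrative assumptions.

```typescript
// Illustrative only: names and shapes are assumed, not the exact API.
const result = await aiCoreService.executeCompletions({
  messages,
  model,
  onChunk: (chunk) => {
    // Streaming path: each chunk arrives here in real time for UI updates
    renderDelta(chunk)
  }
})

// Promise path: the same call resolves with the aggregated outcome once the
// stream ends, so it can feed the next step of a multi-step workflow.
console.log(result.text, result.usage)
```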
#### 2.1.5. `coreRequestTypes.ts` (or `types.ts`)

- Defines the core, Provider-agnostic internal request structures, for example:
  - `CoreCompletionsRequest`: the normalized message list, model configuration, tool list, max tokens, streaming flag, and so on.
  - `CoreTranslateRequest`, `CoreSummarizeRequest`, etc. (if they differ substantially from `CoreCompletionsRequest`; otherwise reuse it with a task-type marker).

### 2.2. `middleware`

The middleware layer handles shared logic and specific features across the request and response flow. Its design and usage follow the conventions defined in `middleware/BUILDER_USAGE.md`.

**The core components are:**

- **`MiddlewareBuilder`**: a generic class with a fluent API for dynamically building middleware chains. It supports starting from a base chain and conditionally adding, inserting, replacing, or removing middleware.
- **`applyCompletionsMiddlewares`**: takes the chain built by `MiddlewareBuilder` and executes it in order, specifically for the completions flow.
- **`MiddlewareRegistry`**: the central registry of all available middleware, providing a unified access interface.
- **The individual middleware modules** (in the `common/`, `core/`, and `feat/` subdirectories).

#### 2.2.1. `middlewareTypes.ts`

- Defines the core middleware types, such as `AiProviderMiddlewareContext` (extended to include `_apiClientInstance` and `_coreRequest`), `MiddlewareAPI`, and `CompletionsMiddleware`.

#### 2.2.2. Core middleware (`middleware/core/`)

- **`TransformCoreToSdkParamsMiddleware.ts`**: calls `ApiClient.getRequestTransformer()` to convert the `CoreRequest` into SDK-specific parameters and stores them in the context.
- **`RequestExecutionMiddleware.ts`**: calls `ApiClient.getSdkInstance()` to obtain the SDK instance and executes the actual API call with the converted parameters, returning the raw SDK stream.
- **`StreamAdapterMiddleware.ts`**: adapts raw SDK streams of various shapes (e.g. async iterators) into a uniform `ReadableStream<RawSdkChunk>`.
  - **`RawSdkChunk`**: the raw streaming data-chunk format a specific AI provider SDK returns before any app-level normalization (e.g. OpenAI's `ChatCompletionChunk`, parts of Gemini's `GenerateContentResponse`, etc.).
- **`RawSdkChunkToAppChunkMiddleware.ts`**: (new) consumes the `ReadableStream<RawSdkChunk>`, calls `ApiClient.getResponseChunkTransformer()` on each `RawSdkChunk` to produce one or more basic application-level `Chunk` objects, and outputs a `ReadableStream<Chunk>`.

#### 2.2.3. Feature middleware (`middleware/feat/`)

These middlewares consume the relatively standardized `Chunk` stream emitted by `ResponseTransformMiddleware` and handle the more complex logic.

- **`ThinkingTagExtractionMiddleware.ts`**: inspects `TextDeltaChunk`s, parses any embedded `<think>...</think>` tags in the text, and emits `ThinkingDeltaChunk` and `ThinkingCompleteChunk`.
- **`ToolUseExtractionMiddleware.ts`**: inspects `TextDeltaChunk`s, parses any embedded `<tool_use>...</tool_use>` tags, and emits tool-call-related chunks. If the `ApiClient` emitted native tool-call data, this middleware also converts it into the standard format.

#### 2.2.4. Core processing middleware (`middleware/core/`)

- **`TransformCoreToSdkParamsMiddleware.ts`**: calls `ApiClient.getRequestTransformer()` to convert the `CoreRequest` into SDK-specific parameters and stores them in the context.
- **`SdkCallMiddleware.ts`**: calls `ApiClient.getSdkInstance()` to obtain the SDK instance and executes the actual API call with the converted parameters, returning the raw SDK stream.
- **`StreamAdapterMiddleware.ts`**: adapts raw SDK streams of various shapes into the standard stream format.
- **`ResponseTransformMiddleware.ts`**: converts the raw SDK response into standard application-level `Chunk` objects.
- **`TextChunkMiddleware.ts`**: handles the text-related chunk stream.
- **`ThinkChunkMiddleware.ts`**: handles the thinking-related chunk stream.
- **`McpToolChunkMiddleware.ts`**: handles the tool-call-related chunk stream.
- **`WebSearchMiddleware.ts`**: handles web-search-related logic.

#### 2.2.5. Common middleware (`middleware/common/`)

- **`LoggingMiddleware.ts`**: request and response logging.
- **`AbortHandlerMiddleware.ts`**: handles request abortion.
- **`FinalChunkConsumerMiddleware.ts`**: consumes the final `Chunk` stream and notifies the application layer in real time via the `context.onChunk` callback.
  - **Accumulates data**: during streaming it accumulates the key data, such as text fragments, tool-call information, and `usage`/`metrics`.
  - **Settles the `Promise`**: when the input stream ends, it completes the overall operation with the accumulated, aggregated result.
  - On stream end, it emits a completion signal carrying the final accumulated information.

### 2.3. `types/chunk.ts`

- Defines the app-wide unified `Chunk` type and all of its variants. This includes the basic types (e.g. `TextDeltaChunk`, `ThinkingDeltaChunk`), SDK-native pass-through types (e.g. `RawToolCallChunk`, `RawFinishChunk` - intermediate products of the `ApiClient` conversion), and feature types (e.g. `McpToolCallRequestChunk`, `WebSearchCompleteChunk`).
## 3. Core Execution Flow (using `AiCoreService.executeCompletions` as the example)

```markdown
**Application layer (e.g. a UI component)**
||
\/
**`AiProvider.completions` (`aiCore/index.ts`)**
(1. prepare the ApiClient instance. 2. use `CompletionsMiddlewareBuilder.withDefaults()` to build the middleware chain. 3. call `applyCompletionsMiddlewares`)
||
\/
**`applyCompletionsMiddlewares` (`middleware/composer.ts`)**
(receives the built chain, the ApiClient instance, and the raw SDK method, then executes the middleware in order)
||
\/
**[ Pre-processing middleware ]**
(e.g. `FinalChunkConsumerMiddleware`, `TransformCoreToSdkParamsMiddleware`, `AbortHandlerMiddleware`)
|| (the SDK request parameters are now prepared in the Context)
\/
**[ Processing middleware ]**
(e.g. `McpToolChunkMiddleware`, `WebSearchMiddleware`, `TextChunkMiddleware`, `ThinkingTagExtractionMiddleware`)
|| (handles the various features and chunk types)
\/
**[ SDK-call middleware ]**
(e.g. `ResponseTransformMiddleware`, `StreamAdapterMiddleware`, `SdkCallMiddleware`)
|| (output: the standardized application-level chunk stream)
\/
**`FinalChunkConsumerMiddleware` (core)**
(consumes the final `Chunk` stream, notifies the application layer via the `context.onChunk` callback, and completes the operation when the stream ends)
||
\/
**`AiProvider.completions` returns `Promise<CompletionsResult>`**
```
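A sketch of what steps (1)-(3) in the second box might look like in code, using the names from the diagram. The `provider` value and the exact builder/composer signatures are assumptions; see `middleware/BUILDER_USAGE.md` for the real API.

```typescript
// Sketch only - `provider` and the exact signatures are assumed for illustration.
import { ApiClientFactory } from './clients/ApiClientFactory'

async function completions(params: CompletionsParams): Promise<CompletionsResult> {
  // 1. Prepare the ApiClient instance for the configured provider
  const apiClient = ApiClientFactory.create(provider)

  // 2. Build the middleware chain, starting from the default set
  const chain = CompletionsMiddlewareBuilder.withDefaults().build()

  // 3. Execute the chain; the returned promise settles when the
  //    (possibly streamed) operation completes or fails
  return applyCompletionsMiddlewares(apiClient, chain, params)
}
```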
## 4. Suggested File/Directory Structure

```
src/renderer/src/
└── aiCore/
    ├── clients/
    │   ├── openai/
    │   ├── gemini/
    │   ├── anthropic/
    │   ├── BaseApiClient.ts
    │   ├── ApiClientFactory.ts
    │   ├── AihubmixAPIClient.ts
    │   ├── index.ts
    │   └── types.ts
    ├── middleware/
    │   ├── common/
    │   ├── core/
    │   ├── feat/
    │   ├── builder.ts
    │   ├── composer.ts
    │   ├── index.ts
    │   ├── register.ts
    │   ├── schemas.ts
    │   ├── types.ts
    │   └── utils.ts
    ├── types/
    │   ├── chunk.ts
    │   └── ...
    └── index.ts
```
## 5. Migration and Implementation Advice

- **Small steps, quick iterations**: refactor the core flow first (e.g. `completions`), then migrate the remaining features (`translate`, etc.) and the other Providers step by step.
- **Define the core types first**: the `CoreRequest`, `Chunk`, and `ApiClient` interfaces are the foundation of the whole architecture.
- **Slim down the `ApiClient`s**: move the complex logic in the existing `XxxProvider`s out into new middleware or into `AiCoreService`.
- **Strengthen the middleware**: let middleware take on more of the parsing and feature handling.
- **Write unit and integration tests**: verify each component and the end-to-end flow.

This architecture is meant to provide a more robust, flexible, and maintainable core for AI functionality, supporting Cherry Studio's future growth.

## 6. Migration Strategy and Implementation Notes

This section is distilled from the earlier `migrate.md` document and adjusted to match the latest architecture discussion.

**Recap of the target architecture's core components:**

Consistent with the components described in Section 2: `XxxApiClient`, `AiCoreService`, the middleware chain, the `CoreRequest` types, and the standardized `Chunk` types.

**Migration steps:**

**Phase 0: Preparation and type definitions**

1. **Define the core data structures (TypeScript types):**
   - `CoreCompletionsRequest` (type): the app's unified internal conversation-request structure.
   - `Chunk` (type - review and extend the existing `src/renderer/src/types/chunk.ts` as needed): all possible generic chunk variants.
   - Define similar `CoreXxxRequest` types for the other APIs (translation, summarization).
2. **Define the `ApiClient` interface:** pin down the core methods such as `getRequestTransformer`, `getResponseChunkTransformer`, and `getSdkInstance`.
3. **Adjust `AiProviderMiddlewareContext`:**
   - Make sure it includes `_apiClientInstance: ApiClient<any,any,any>`.
   - Make sure it includes `_coreRequest: CoreRequestType`.
   - Consider adding `resolvePromise: (value: AggregatedResultType) => void` and `rejectPromise: (reason?: any) => void` for the `Promise` returned by `AiCoreService`.

**Phase 1: Implement the first `ApiClient` (using `OpenAIApiClient` as the example)**

1. **Create the `OpenAIApiClient` class:** implement the `ApiClient` interface.
2. **Migrate the SDK instance and configuration.**
3. **Implement `getRequestTransformer()`:** convert `CoreCompletionsRequest` into OpenAI SDK parameters.
4. **Implement `getResponseChunkTransformer()`:** convert `OpenAI.Chat.Completions.ChatCompletionChunk` into basic `
src/renderer/src/aiCore/clients/AihubmixAPIClient.ts (new file, 207 lines)
@ -0,0 +1,207 @@
import { isOpenAILLMModel } from '@renderer/config/models'
import {
  GenerateImageParams,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  Model,
  Provider,
  ToolCallResponse
} from '@renderer/types'
import {
  RequestOptions,
  SdkInstance,
  SdkMessageParam,
  SdkModel,
  SdkParams,
  SdkRawChunk,
  SdkRawOutput,
  SdkTool,
  SdkToolCall
} from '@renderer/types/sdk'

import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import { RequestTransformer, ResponseChunkTransformer } from './types'

/**
 * AihubmixAPIClient - automatically selects the appropriate ApiClient for the model type.
 * Implemented with the decorator pattern; model routing happens at the ApiClient level.
 */
export class AihubmixAPIClient extends BaseApiClient {
  // Use a union type rather than any to preserve type safety
  private clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
    new Map()
  private defaultClient: OpenAIAPIClient
  private currentClient: BaseApiClient

  constructor(provider: Provider) {
    super(provider)

    // Initialize the individual clients - now type-safe
    const claudeClient = new AnthropicAPIClient(provider)
    const geminiClient = new GeminiAPIClient({ ...provider, apiHost: 'https://aihubmix.com/gemini' })
    const openaiClient = new OpenAIResponseAPIClient(provider)
    const defaultClient = new OpenAIAPIClient(provider)

    this.clients.set('claude', claudeClient)
    this.clients.set('gemini', geminiClient)
    this.clients.set('openai', openaiClient)
    this.clients.set('default', defaultClient)

    // Set the default client
    this.defaultClient = defaultClient
    this.currentClient = this.defaultClient as BaseApiClient
  }

  /**
   * Type guard: ensures the client is a BaseApiClient instance
   */
  private isValidClient(client: unknown): client is BaseApiClient {
    return (
      client !== null &&
      client !== undefined &&
      typeof client === 'object' &&
      'createCompletions' in client &&
      'getRequestTransformer' in client &&
      'getResponseChunkTransformer' in client
    )
  }

  /**
   * Pick the appropriate client for a given model
   */
  private getClient(model: Model): BaseApiClient {
    const id = model.id.toLowerCase()

    // IDs starting with "claude"
    if (id.startsWith('claude')) {
      const client = this.clients.get('claude')
      if (!client || !this.isValidClient(client)) {
        throw new Error('Claude client not properly initialized')
      }
      return client
    }

    // IDs starting with "gemini" (or "imagen") that do not end with -nothink or -search
    if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
      const client = this.clients.get('gemini')
      if (!client || !this.isValidClient(client)) {
        throw new Error('Gemini client not properly initialized')
      }
      return client
    }

    // OpenAI-family models
    if (isOpenAILLMModel(model)) {
      const client = this.clients.get('openai')
      if (!client || !this.isValidClient(client)) {
        throw new Error('OpenAI client not properly initialized')
      }
      return client
    }

    return this.defaultClient as BaseApiClient
  }

  /**
   * Select the appropriate client for the model and delegate subsequent calls to it
   */
  public getClientForModel(model: Model): BaseApiClient {
    this.currentClient = this.getClient(model)
    return this.currentClient
  }

  // ============ BaseApiClient abstract method implementations ============

  async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
    // Try to extract model info from the payload to pick a client
    const modelId = this.extractModelFromPayload(payload)
    if (modelId) {
      const modelObj = { id: modelId } as Model
      const targetClient = this.getClient(modelObj)
      return targetClient.createCompletions(payload, options)
    }

    // If no model can be extracted from the payload, fall back to the currently selected client
    return this.currentClient.createCompletions(payload, options)
  }

  /**
   * Extract the model ID from an SDK payload
   */
  private extractModelFromPayload(payload: SdkParams): string | null {
    // Different SDKs may use different field names
    if ('model' in payload && typeof payload.model === 'string') {
      return payload.model
    }
    return null
  }

  async generateImage(params: GenerateImageParams): Promise<string[]> {
    return this.currentClient.generateImage(params)
  }

  async getEmbeddingDimensions(model?: Model): Promise<number> {
    const client = model ? this.getClient(model) : this.currentClient
    return client.getEmbeddingDimensions(model)
  }

  async listModels(): Promise<SdkModel[]> {
    // Could aggregate models from every client; for now use the default client
    return this.defaultClient.listModels()
  }

  async getSdkInstance(): Promise<SdkInstance> {
    return this.currentClient.getSdkInstance()
  }

  getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
    return this.currentClient.getRequestTransformer()
  }

  getResponseChunkTransformer(): ResponseChunkTransformer<SdkRawChunk> {
    return this.currentClient.getResponseChunkTransformer()
  }

  convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
    return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
  }

  convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
    return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
  }

  convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
    return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
  }

  buildSdkMessages(
    currentReqMessages: SdkMessageParam[],
    output: SdkRawOutput | string,
    toolResults: SdkMessageParam[],
    toolCalls?: SdkToolCall[]
  ): SdkMessageParam[] {
    return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
  }

  convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): SdkMessageParam | undefined {
    const client = this.getClient(model)
    return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
  }

  extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
    return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
  }

  estimateMessageTokens(message: SdkMessageParam): number {
    return this.currentClient.estimateMessageTokens(message)
  }
}
src/renderer/src/aiCore/clients/ApiClientFactory.ts (new file, 62 lines)
@ -0,0 +1,62 @@
import { Provider } from '@renderer/types'

import { AihubmixAPIClient } from './AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'

/**
 * Factory for creating ApiClient instances based on provider configuration
 */
export class ApiClientFactory {
  /**
   * Create an ApiClient instance for the given provider
   */
  static create(provider: Provider): BaseApiClient {
    console.log(`[ApiClientFactory] Creating ApiClient for provider:`, {
      id: provider.id,
      type: provider.type
    })

    let instance: BaseApiClient

    // Check special provider ids first
    if (provider.id === 'aihubmix') {
      console.log(`[ApiClientFactory] Creating AihubmixAPIClient for provider: ${provider.id}`)
      instance = new AihubmixAPIClient(provider) as BaseApiClient
      return instance
    }

    // Then check the standard provider type
    switch (provider.type) {
      case 'openai':
      case 'azure-openai':
        console.log(`[ApiClientFactory] Creating OpenAIApiClient for provider: ${provider.id}`)
        instance = new OpenAIAPIClient(provider) as BaseApiClient
        break
      case 'openai-response':
        instance = new OpenAIResponseAPIClient(provider) as BaseApiClient
        break
      case 'gemini':
        instance = new GeminiAPIClient(provider) as BaseApiClient
        break
      case 'anthropic':
        instance = new AnthropicAPIClient(provider) as BaseApiClient
        break
      default:
        console.log(`[ApiClientFactory] Using default OpenAIApiClient for provider: ${provider.id}`)
        instance = new OpenAIAPIClient(provider) as BaseApiClient
        break
    }

    return instance
  }
}

export function isOpenAIProvider(provider: Provider) {
  return !['anthropic', 'gemini'].includes(provider.type)
}
@ -1,40 +1,69 @@
import Logger from '@renderer/config/logger'
import { isFunctionCallingModel, isNotSupportTemperatureAndTopP } from '@renderer/config/models'
import {
  isFunctionCallingModel,
  isNotSupportTemperatureAndTopP,
  isOpenAIModel,
  isSupportedFlexServiceTier
} from '@renderer/config/models'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import type {
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { SettingsState } from '@renderer/store/settings'
import {
  Assistant,
  FileTypes,
  GenerateImageParams,
  KnowledgeReference,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  Model,
  OpenAIServiceTier,
  Provider,
  Suggestion,
  ToolCallResponse,
  WebSearchProviderResponse,
  WebSearchResponse
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import { delay, isJSON, parseJSON } from '@renderer/utils'
import { Message } from '@renderer/types/newMessage'
import {
  RequestOptions,
  SdkInstance,
  SdkMessageParam,
  SdkModel,
  SdkParams,
  SdkRawChunk,
  SdkRawOutput,
  SdkTool,
  SdkToolCall
} from '@renderer/types/sdk'
import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { formatApiHost } from '@renderer/utils/api'
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { defaultTimeout } from '@shared/config/constant'
import Logger from 'electron-log/renderer'
import { isEmpty } from 'lodash'
import type OpenAI from 'openai'

import type { CompletionsParams } from '.'
import { ApiClient, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from './types'

export default abstract class BaseProvider {
  // Threshold for determining whether to use system prompt for tools
/**
 * Abstract base class for API clients.
 * Provides common functionality and structure for specific client implementations.
 */
export abstract class BaseApiClient<
  TSdkInstance extends SdkInstance = SdkInstance,
  TSdkParams extends SdkParams = SdkParams,
  TRawOutput extends SdkRawOutput = SdkRawOutput,
  TRawChunk extends SdkRawChunk = SdkRawChunk,
  TMessageParam extends SdkMessageParam = SdkMessageParam,
  TToolCall extends SdkToolCall = SdkToolCall,
  TSdkSpecificTool extends SdkTool = SdkTool
> implements ApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool>
{
  private static readonly SYSTEM_PROMPT_THRESHOLD: number = 128

  protected provider: Provider
  public provider: Provider
  protected host: string
  protected apiKey: string

  protected useSystemPromptForTools: boolean = true
  protected sdkInstance?: TSdkInstance
  public useSystemPromptForTools: boolean = true

  constructor(provider: Provider) {
    this.provider = provider
@ -42,32 +71,81 @@ export default abstract class BaseProvider {
    this.apiKey = this.getApiKey()
  }

  abstract completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
  abstract translate(
    content: string,
    assistant: Assistant,
    onResponse?: (text: string, isComplete: boolean) => void
  ): Promise<string>
  abstract summaries(messages: Message[], assistant: Assistant): Promise<string>
  abstract summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null>
  abstract suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]>
  abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise<string>
  abstract check(model: Model, stream: boolean): Promise<{ valid: boolean; error: Error | null }>
  abstract models(): Promise<OpenAI.Models.Model[]>
  abstract generateImage(params: GenerateImageParams): Promise<string[]>
  abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
  // Since some embedding models can now choose their embedding dimension, this method (which ignores a dimensions parameter) only applies to models that do not support dimensions
  abstract getEmbeddingDimensions(model: Model): Promise<number>
  public abstract convertMcpTools<T>(mcpTools: MCPTool[]): T[]
  public abstract mcpToolCallResponseToMessage(
  // // The core completions method - under the middleware architecture this is usually just a placeholder
  // abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise<CompletionsResult>

  /**
   * Core API endpoints
   **/

  abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise<TRawOutput>

  abstract generateImage(generateImageParams: GenerateImageParams): Promise<string[]>

  abstract getEmbeddingDimensions(model?: Model): Promise<number>

  abstract listModels(): Promise<SdkModel[]>

  abstract getSdkInstance(): Promise<TSdkInstance> | TSdkInstance

  /**
   * Middleware hooks
   **/

  // Used in CoreRequestToSdkParamsMiddleware
  abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
  // Used in RawSdkChunkToGenericChunkMiddleware
  abstract getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>

  /**
   * Tool conversion
   **/

  // Optional tool conversion methods - implement if needed by the specific provider
  abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]

  abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined

  abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse

  abstract buildSdkMessages(
    currentReqMessages: TMessageParam[],
    output: TRawOutput | string,
    toolResults: TMessageParam[],
    toolCalls?: TToolCall[]
  ): TMessageParam[]

  abstract estimateMessageTokens(message: TMessageParam): number

  abstract convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): any
  ): TMessageParam | undefined

  /**
   * Extract the message array from the SDK payload (for type-safe access in middleware).
   * Different providers may use different field names (e.g. messages, history)
   */
  abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]

  /**
   * Attach a raw stream listener
   */
  public attachRawStreamListener<TListener extends RawStreamListener<TRawChunk>>(
    rawOutput: TRawOutput,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    _listener: TListener
  ): TRawOutput {
    return rawOutput
  }

  /**
   * Shared utilities
   **/

  public getBaseURL(): string {
    const host = this.provider.apiHost
    return formatApiHost(host)
    return this.provider.apiHost
  }

  public getApiKey() {
@ -112,14 +190,32 @@ export default abstract class BaseProvider {
    return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP
  }

  public async fakeCompletions({ onChunk }: CompletionsParams) {
    for (let i = 0; i < 100; i++) {
      await delay(0.01)
      onChunk({
        response: { text: i + '\n', usage: { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0 } },
        type: ChunkType.BLOCK_COMPLETE
      })
  protected getServiceTier(model: Model) {
    if (!isOpenAIModel(model) || model.provider === 'github' || model.provider === 'copilot') {
      return undefined
    }

    const openAI = getStoreSetting('openAI') as SettingsState['openAI']
    let serviceTier = 'auto' as OpenAIServiceTier

    if (openAI && openAI?.serviceTier === 'flex') {
      if (isSupportedFlexServiceTier(model)) {
        serviceTier = 'flex'
      } else {
        serviceTier = 'auto'
      }
    } else {
      serviceTier = openAI.serviceTier
    }

    return serviceTier
  }

  protected getTimeout(model: Model) {
    if (isSupportedFlexServiceTier(model)) {
      return 15 * 1000 * 60
    }
    return defaultTimeout
  }

  public async getMessageContent(message: Message): Promise<string> {
@ -149,6 +245,36 @@ export default abstract class BaseProvider {
    return content
  }

  /**
   * Extract the file content from the message
   * @param message - The message
   * @returns The file content
   */
  protected async extractFileContent(message: Message) {
    const fileBlocks = findFileBlocks(message)
    if (fileBlocks.length > 0) {
      const textFileBlocks = fileBlocks.filter(
        (fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
      )

      if (textFileBlocks.length > 0) {
        let text = ''
        const divider = '\n\n---\n\n'

        for (const fileBlock of textFileBlocks) {
          const file = fileBlock.file
          const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
          const fileNameRow = 'file: ' + file.origin_name + '\n\n'
          text = text + fileNameRow + fileContent + divider
        }

        return text
      }
    }

    return ''
  }

  private async getWebSearchReferencesFromCache(message: Message) {
    const content = getMainTextContent(message)
    if (isEmpty(content)) {
@ -210,7 +336,7 @@ export default abstract class BaseProvider {
    )
  }

  protected createAbortController(messageId?: string, isAddEventListener?: boolean) {
  public createAbortController(messageId?: string, isAddEventListener?: boolean) {
    const abortController = new AbortController()
    const abortFn = () => abortController.abort()

@ -256,11 +382,11 @@
  }

  // Setup tools configuration based on provided parameters
  protected setupToolsConfig<T>(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
    tools: T[]
  public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
    tools: TSdkSpecificTool[]
  } {
    const { mcpTools, model, enableToolUse } = params
    let tools: T[] = []
    let tools: TSdkSpecificTool[] = []

    // If there are no tools, return an empty array
    if (!mcpTools?.length) {
@ -268,14 +394,14 @@
    }

    // If the number of tools exceeds the threshold, use the system prompt
    if (mcpTools.length > BaseProvider.SYSTEM_PROMPT_THRESHOLD) {
    if (mcpTools.length > BaseApiClient.SYSTEM_PROMPT_THRESHOLD) {
      this.useSystemPromptForTools = true
      return { tools }
    }

    // If the model supports function calling and tool usage is enabled
    if (isFunctionCallingModel(model) && enableToolUse) {
      tools = this.convertMcpTools<T>(mcpTools)
      tools = this.convertMcpToolsToSdkTools(mcpTools)
      this.useSystemPromptForTools = false
    }
src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts (new file, 714 lines)
@ -0,0 +1,714 @@
import Anthropic from '@anthropic-ai/sdk'
|
||||
import {
|
||||
Base64ImageSource,
|
||||
ImageBlockParam,
|
||||
MessageParam,
|
||||
TextBlockParam,
|
||||
ToolResultBlockParam,
|
||||
ToolUseBlock,
|
||||
WebSearchTool20250305
|
||||
} from '@anthropic-ai/sdk/resources'
|
||||
import {
|
||||
ContentBlock,
|
||||
ContentBlockParam,
|
||||
MessageCreateParams,
|
||||
MessageCreateParamsBase,
|
||||
RedactedThinkingBlockParam,
|
||||
ServerToolUseBlockParam,
|
||||
ThinkingBlockParam,
|
||||
ThinkingConfigParam,
|
||||
ToolUnion,
|
||||
ToolUseBlockParam,
|
||||
WebSearchResultBlock,
|
||||
WebSearchToolResultBlockParam,
|
||||
WebSearchToolResultError
|
||||
} from '@anthropic-ai/sdk/resources/messages'
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import {
|
||||
Assistant,
|
||||
EFFORT_RATIO,
|
||||
FileTypes,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import {
|
||||
ChunkType,
|
||||
ErrorChunk,
|
||||
LLMWebSearchCompleteChunk,
|
||||
LLMWebSearchInProgressChunk,
|
||||
MCPToolCreatedChunk,
|
||||
TextDeltaChunk,
|
||||
ThinkingDeltaChunk
|
||||
} from '@renderer/types/chunk'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
AnthropicSdkMessageParam,
|
||||
AnthropicSdkParams,
|
||||
AnthropicSdkRawChunk,
|
||||
AnthropicSdkRawOutput
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
anthropicToolUseToMcpTool,
|
||||
isEnabledToolUse,
|
||||
mcpToolCallResponseToAnthropicMessage,
|
||||
mcpToolsToAnthropicTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
export class AnthropicAPIClient extends BaseApiClient<
|
||||
Anthropic,
|
||||
AnthropicSdkParams,
|
||||
AnthropicSdkRawOutput,
|
||||
AnthropicSdkRawChunk,
|
||||
AnthropicSdkMessageParam,
|
||||
ToolUseBlock,
|
||||
ToolUnion
|
||||
> {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
async getSdkInstance(): Promise<Anthropic> {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
this.sdkInstance = new Anthropic({
|
||||
apiKey: this.getApiKey(),
|
||||
baseURL: this.getBaseURL(),
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: {
|
||||
'anthropic-beta': 'output-128k-2025-02-19'
|
||||
}
|
||||
})
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
override async createCompletions(
|
||||
payload: AnthropicSdkParams,
|
||||
options?: Anthropic.RequestOptions
|
||||
): Promise<AnthropicSdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
if (payload.stream) {
|
||||
return sdk.messages.stream(payload, options)
|
||||
}
|
||||
return await sdk.messages.create(payload, options)
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
override async listModels(): Promise<Anthropic.ModelInfo[]> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const response = await sdk.models.list()
|
||||
return response.data
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
override async getEmbeddingDimensions(): Promise<number> {
|
||||
return 0
|
||||
}
|
||||
|
||||
override getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.temperature
|
||||
}
|
||||
|
||||
override getTopP(assistant: Assistant, model: Model): number | undefined {
|
||||
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.topP
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined {
|
||||
if (!isReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
if (reasoningEffort === undefined) {
|
||||
return {
|
||||
type: 'disabled'
|
||||
}
|
||||
}
|
||||
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
const budgetTokens = Math.max(
|
||||
1024,
|
||||
Math.floor(
|
||||
Math.min(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||
findTokenLimit(model.id)?.min!,
|
||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
type: 'enabled',
|
||||
budget_tokens: budgetTokens
|
||||
}
|
||||
}
|
||||
|
||||
  /**
   * Convert a message to an Anthropic SDK message parameter
   * @param message - The message
   * @returns The message parameter
   */
  public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
    const parts: MessageParam['content'] = [
      {
        type: 'text',
        text: getMainTextContent(message)
      }
    ]

    // Get and process image blocks
    const imageBlocks = findImageBlocks(message)
    for (const imageBlock of imageBlocks) {
      if (imageBlock.file) {
        // Handle uploaded file
        const file = imageBlock.file
        const base64Data = await window.api.file.base64Image(file.id + file.ext)
        parts.push({
          type: 'image',
          source: {
            data: base64Data.base64,
            media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
            type: 'base64'
          }
        })
      }
    }
    // Get and process file blocks
    const fileBlocks = findFileBlocks(message)
    for (const fileBlock of fileBlocks) {
      const { file } = fileBlock
      if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
        if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
          const base64Data = await FileManager.readBase64File(file)
          parts.push({
            type: 'document',
            source: {
              type: 'base64',
              media_type: 'application/pdf',
              data: base64Data
            }
          })
        } else {
          const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
          parts.push({
            type: 'text',
            text: file.origin_name + '\n' + fileContent
          })
        }
      }
    }

    return {
      role: message.role === 'system' ? 'user' : message.role,
      content: parts
    }
  }

  public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] {
    return mcpToolsToAnthropicTools(mcpTools)
  }

  public convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): AnthropicSdkMessageParam | undefined {
    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
      return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
    } else if ('toolCallId' in mcpToolResponse) {
      return {
        role: 'user',
        content: [
          {
            type: 'tool_result',
            tool_use_id: mcpToolResponse.toolCallId!,
            content: resp.content
              .map((item) => {
                if (item.type === 'text') {
                  return {
                    type: 'text',
                    text: item.text || ''
                  } satisfies TextBlockParam
                }
                if (item.type === 'image') {
                  return {
                    type: 'image',
                    source: {
                      data: item.data || '',
                      media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
                      type: 'base64'
                    }
                  } satisfies ImageBlockParam
                }
                return
              })
              .filter((n) => typeof n !== 'undefined'),
            is_error: resp.isError
          } satisfies ToolResultBlockParam
        ]
      }
    }
    return
  }

  // Implementing abstract methods from BaseApiClient
  convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined {
    // Based on anthropicToolUseToMcpTool logic in AnthropicProvider
    // This might need adjustment based on how tool calls are specifically handled in the new structure
    const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
    return mcpTool
  }

  convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse {
    return {
      id: toolCall.id,
      toolCallId: toolCall.id,
      tool: mcpTool,
      arguments: toolCall.input as Record<string, unknown>,
      status: 'pending'
    } as ToolCallResponse
  }

  override buildSdkMessages(
    currentReqMessages: AnthropicSdkMessageParam[],
    output: Anthropic.Message,
    toolResults: AnthropicSdkMessageParam[]
  ): AnthropicSdkMessageParam[] {
    const assistantMessage: AnthropicSdkMessageParam = {
      role: output.role,
      content: convertContentBlocksToParams(output.content)
    }

    const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage]
    if (toolResults && toolResults.length > 0) {
      newMessages.push(...toolResults)
    }
    return newMessages
  }

  override estimateMessageTokens(message: AnthropicSdkMessageParam): number {
    if (typeof message.content === 'string') {
      return estimateTextTokens(message.content)
    }
    return message.content
      .map((content) => {
        switch (content.type) {
          case 'text':
            return estimateTextTokens(content.text)
          case 'image':
            if (content.source.type === 'base64') {
              return estimateTextTokens(content.source.data)
            } else {
              return estimateTextTokens(content.source.url)
            }
          case 'tool_use':
            return estimateTextTokens(JSON.stringify(content.input))
          case 'tool_result':
            return estimateTextTokens(JSON.stringify(content.content))
          default:
            return 0
        }
      })
      .reduce((acc, curr) => acc + curr, 0)
  }

  public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam {
    const messageParam: AnthropicSdkMessageParam = {
      role: message.role,
      content: convertContentBlocksToParams(message.content)
    }
    return messageParam
  }

  public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] {
    return sdkPayload.messages || []
  }

  /**
   * Anthropic-specific raw stream listener.
   * Handles the events specific to a MessageStream object.
   */
  override attachRawStreamListener(
    rawOutput: AnthropicSdkRawOutput,
    listener: RawStreamListener<AnthropicSdkRawChunk>
  ): AnthropicSdkRawOutput {
    console.log(`[AnthropicApiClient] Attaching stream listener to raw output`)

    // Check whether this is a MessageStream
    if (rawOutput instanceof MessageStream) {
      console.log(`[AnthropicApiClient] Detected Anthropic MessageStream, attaching dedicated listeners`)

      if (listener.onStart) {
        listener.onStart()
      }

      if (listener.onChunk) {
        rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => {
          listener.onChunk!(event)
        })
      }

      // Anthropic-specific event handling
      const anthropicListener = listener as AnthropicStreamListener

      if (anthropicListener.onContentBlock) {
        rawOutput.on('contentBlock', anthropicListener.onContentBlock)
      }

      if (anthropicListener.onMessage) {
        rawOutput.on('finalMessage', anthropicListener.onMessage)
      }

      if (listener.onEnd) {
        rawOutput.on('end', () => {
          listener.onEnd!()
        })
      }

      if (listener.onError) {
        rawOutput.on('error', (error: Error) => {
          listener.onError!(error)
        })
      }

      return rawOutput
    }

    // Non-MessageStream responses are returned unchanged
    return rawOutput
  }

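  // A minimal usage sketch for the listener above (hypothetical call site —
  // only the fields this method actually reads are shown):
  //
  //   client.attachRawStreamListener(rawOutput, {
  //     onStart: () => console.log('stream started'),
  //     onChunk: (event) => console.log(event.type),
  //     onContentBlock: (block) => console.log('block', block.type),
  //     onMessage: (message) => console.log('final message', message.id),
  //     onEnd: () => console.log('stream ended'),
  //     onError: (error) => console.error(error)
  //   } as AnthropicStreamListener)
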
  private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
    if (!isWebSearchModel(model)) {
      return undefined
    }
    return {
      type: 'web_search_20250305',
      name: 'web_search',
      max_uses: 5
    } as WebSearchTool20250305
  }

  getRequestTransformer(): RequestTransformer<AnthropicSdkParams, AnthropicSdkMessageParam> {
    return {
      transform: async (
        coreRequest,
        assistant,
        model,
        isRecursiveCall,
        recursiveSdkMessages
      ): Promise<{
        payload: AnthropicSdkParams
        messages: AnthropicSdkMessageParam[]
        metadata: Record<string, any>
      }> => {
        const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
        // 1. Process the system message
        let systemPrompt = assistant.prompt

        // 2. Set up tools
        const { tools } = this.setupToolsConfig({
          mcpTools: mcpTools,
          model,
          enableToolUse: isEnabledToolUse(assistant)
        })

        if (this.useSystemPromptForTools) {
          systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools)
        }

        const systemMessage: TextBlockParam | undefined = systemPrompt
          ? { type: 'text', text: systemPrompt }
          : undefined

        // 3. Process user messages
        const sdkMessages: AnthropicSdkMessageParam[] = []
        if (typeof messages === 'string') {
          sdkMessages.push({ role: 'user', content: messages })
        } else {
          const processedMessages = addImageFileToContents(messages)
          for (const message of processedMessages) {
            sdkMessages.push(await this.convertMessageToSdkParam(message))
          }
        }

        if (enableWebSearch) {
          const webSearchTool = await this.getWebSearchParams(model)
          if (webSearchTool) {
            tools.push(webSearchTool)
          }
        }

        const commonParams: MessageCreateParamsBase = {
          model: model.id,
          messages:
            isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
              ? recursiveSdkMessages
              : sdkMessages,
          max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
          temperature: this.getTemperature(assistant, model),
          top_p: this.getTopP(assistant, model),
          system: systemMessage ? [systemMessage] : undefined,
          thinking: this.getBudgetToken(assistant, model),
          tools: tools.length > 0 ? tools : undefined,
          ...this.getCustomParameters(assistant)
        }

        const finalParams: MessageCreateParams = streamOutput
          ? {
              ...commonParams,
              stream: true
            }
          : {
              ...commonParams,
              stream: false
            }

        const timeout = this.getTimeout(model)
        return { payload: finalParams, messages: sdkMessages, metadata: { timeout } }
      }
    }
  }

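  // How the response transformer below works: Anthropic streams tool-call
  // arguments as incremental 'input_json_delta' fragments, so the transformer
  // concatenates them into accumulatedJson and only parses the full string at
  // 'content_block_stop'. For example, the fragments '{"ci' + 'ty":"Par' +
  // 'is"}' arrive as three chunks and are parsed once as {"city":"Paris"}.
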
  getResponseChunkTransformer(): ResponseChunkTransformer<AnthropicSdkRawChunk> {
    return () => {
      let accumulatedJson = ''
      const toolCalls: Record<number, ToolUseBlock> = {}

      return {
        async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
          switch (rawChunk.type) {
            case 'message': {
              for (const content of rawChunk.content) {
                switch (content.type) {
                  case 'text': {
                    controller.enqueue({
                      type: ChunkType.TEXT_DELTA,
                      text: content.text
                    } as TextDeltaChunk)
                    break
                  }
                  case 'tool_use': {
                    toolCalls[0] = content
                    break
                  }
                  case 'thinking': {
                    controller.enqueue({
                      type: ChunkType.THINKING_DELTA,
                      text: content.thinking
                    } as ThinkingDeltaChunk)
                    break
                  }
                  case 'web_search_tool_result': {
                    controller.enqueue({
                      type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                      llm_web_search: {
                        results: content.content,
                        source: WebSearchSource.ANTHROPIC
                      }
                    } as LLMWebSearchCompleteChunk)
                    break
                  }
                }
              }
              break
            }
            case 'content_block_start': {
              const contentBlock = rawChunk.content_block
              switch (contentBlock.type) {
                case 'server_tool_use': {
                  if (contentBlock.name === 'web_search') {
                    controller.enqueue({
                      type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
                    } as LLMWebSearchInProgressChunk)
                  }
                  break
                }
                case 'web_search_tool_result': {
                  if (
                    contentBlock.content &&
                    (contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
                  ) {
                    controller.enqueue({
                      type: ChunkType.ERROR,
                      error: {
                        code: (contentBlock.content as WebSearchToolResultError).error_code,
                        message: (contentBlock.content as WebSearchToolResultError).error_code
                      }
                    } as ErrorChunk)
                  } else {
                    controller.enqueue({
                      type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                      llm_web_search: {
                        results: contentBlock.content as Array<WebSearchResultBlock>,
                        source: WebSearchSource.ANTHROPIC
                      }
                    } as LLMWebSearchCompleteChunk)
                  }
                  break
                }
                case 'tool_use': {
                  toolCalls[rawChunk.index] = contentBlock
                  break
                }
              }
              break
            }
            case 'content_block_delta': {
              const messageDelta = rawChunk.delta
              switch (messageDelta.type) {
                case 'text_delta': {
                  if (messageDelta.text) {
                    controller.enqueue({
                      type: ChunkType.TEXT_DELTA,
                      text: messageDelta.text
                    } as TextDeltaChunk)
                  }
                  break
                }
                case 'thinking_delta': {
                  if (messageDelta.thinking) {
                    controller.enqueue({
                      type: ChunkType.THINKING_DELTA,
                      text: messageDelta.thinking
                    } as ThinkingDeltaChunk)
                  }
                  break
                }
                case 'input_json_delta': {
                  if (messageDelta.partial_json) {
                    accumulatedJson += messageDelta.partial_json
                  }
                  break
                }
              }
              break
            }
            case 'content_block_stop': {
              const toolCall = toolCalls[rawChunk.index]
              if (toolCall) {
                try {
                  toolCall.input = JSON.parse(accumulatedJson)
                  Logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`)
                  controller.enqueue({
                    type: ChunkType.MCP_TOOL_CREATED,
                    tool_calls: [toolCall]
                  } as MCPToolCreatedChunk)
                } catch (error) {
                  Logger.error(`Error parsing tool call input: ${error}`)
                }
              }
              break
            }
            case 'message_delta': {
              controller.enqueue({
                type: ChunkType.LLM_RESPONSE_COMPLETE,
                response: {
                  usage: {
                    prompt_tokens: rawChunk.usage.input_tokens || 0,
                    completion_tokens: rawChunk.usage.output_tokens || 0,
                    total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
                  }
                }
              })
            }
          }
        }
      }
    }
  }
}

/**
 * Convert a ContentBlock array into a ContentBlockParam array.
 * Strips the extra server-generated fields and keeps only what the API needs on requests.
 */
function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] {
  return contentBlocks.map((block): ContentBlockParam => {
    switch (block.type) {
      case 'text':
        // TextBlock -> TextBlockParam, dropping server fields such as citations
        return {
          type: 'text',
          text: block.text
        } satisfies TextBlockParam
      case 'tool_use':
        // ToolUseBlock -> ToolUseBlockParam
        return {
          type: 'tool_use',
          id: block.id,
          name: block.name,
          input: block.input
        } satisfies ToolUseBlockParam
      case 'thinking':
        // ThinkingBlock -> ThinkingBlockParam
        return {
          type: 'thinking',
          thinking: block.thinking,
          signature: block.signature
        } satisfies ThinkingBlockParam
      case 'redacted_thinking':
        // RedactedThinkingBlock -> RedactedThinkingBlockParam
        return {
          type: 'redacted_thinking',
          data: block.data
        } satisfies RedactedThinkingBlockParam
      case 'server_tool_use':
        // ServerToolUseBlock -> ServerToolUseBlockParam
        return {
          type: 'server_tool_use',
          id: block.id,
          name: block.name,
          input: block.input
        } satisfies ServerToolUseBlockParam
      case 'web_search_tool_result':
        // WebSearchToolResultBlock -> WebSearchToolResultBlockParam
        return {
          type: 'web_search_tool_result',
          tool_use_id: block.tool_use_id,
          content: block.content
        } satisfies WebSearchToolResultBlockParam
      default:
        return block as ContentBlockParam
    }
  })
}
786
src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts
Normal file
@ -0,0 +1,786 @@
import {
  Content,
  File,
  FileState,
  FunctionCall,
  GenerateContentConfig,
  GenerateImagesConfig,
  GoogleGenAI,
  HarmBlockThreshold,
  HarmCategory,
  Modality,
  Model as GeminiModel,
  Pager,
  Part,
  SafetySetting,
  SendMessageParameters,
  ThinkingConfig,
  Tool
} from '@google/genai'
import { nanoid } from '@reduxjs/toolkit'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import {
  findTokenLimit,
  GEMINI_FLASH_MODEL_REGEX,
  isGeminiReasoningModel,
  isGemmaModel,
  isVisionModel
} from '@renderer/config/models'
import { CacheService } from '@renderer/services/CacheService'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
  Assistant,
  EFFORT_RATIO,
  FileType,
  FileTypes,
  GenerateImageParams,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  Model,
  Provider,
  ToolCallResponse,
  WebSearchSource
} from '@renderer/types'
import { ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
  GeminiOptions,
  GeminiSdkMessageParam,
  GeminiSdkParams,
  GeminiSdkRawChunk,
  GeminiSdkRawOutput,
  GeminiSdkToolCall
} from '@renderer/types/sdk'
import {
  geminiFunctionCallToMcpTool,
  isEnabledToolUse,
  mcpToolCallResponseToGeminiMessage,
  mcpToolsToGeminiTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { MB } from '@shared/config/constant'

import { BaseApiClient } from '../BaseApiClient'
import { RequestTransformer, ResponseChunkTransformer } from '../types'

export class GeminiAPIClient extends BaseApiClient<
  GoogleGenAI,
  GeminiSdkParams,
  GeminiSdkRawOutput,
  GeminiSdkRawChunk,
  GeminiSdkMessageParam,
  GeminiSdkToolCall,
  Tool
> {
  constructor(provider: Provider) {
    super(provider)
  }

  override async createCompletions(payload: GeminiSdkParams, options?: GeminiOptions): Promise<GeminiSdkRawOutput> {
    const sdk = await this.getSdkInstance()
    const { model, history, ...rest } = payload
    const realPayload: Omit<GeminiSdkParams, 'model'> = {
      ...rest,
      config: {
        ...rest.config,
        abortSignal: options?.abortSignal,
        httpOptions: {
          ...rest.config?.httpOptions,
          timeout: options?.timeout
        }
      }
    } satisfies SendMessageParameters

    const streamOutput = options?.streamOutput

    const chat = sdk.chats.create({
      model: model,
      history: history
    })

    if (streamOutput) {
      const stream = chat.sendMessageStream(realPayload)
      return stream
    } else {
      const response = await chat.sendMessage(realPayload)
      return response
    }
  }

  override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
    const sdk = await this.getSdkInstance()
    try {
      const { model, prompt, imageSize, batchSize, signal } = generateImageParams
      const config: GenerateImagesConfig = {
        numberOfImages: batchSize,
        aspectRatio: imageSize,
        abortSignal: signal,
        httpOptions: {
          timeout: 5 * 60 * 1000
        }
      }
      const response = await sdk.models.generateImages({
        model: model,
        prompt,
        config
      })

      if (!response.generatedImages || response.generatedImages.length === 0) {
        return []
      }

      const images = response.generatedImages
        .filter((image) => image.image?.imageBytes)
        .map((image) => {
          const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,`
          return dataPrefix + image.image?.imageBytes
        })
      // console.log(response?.generatedImages?.[0]?.image?.imageBytes);
      return images
    } catch (error) {
      console.error('[generateImage] error:', error)
      throw error
    }
  }

  override async getEmbeddingDimensions(model: Model): Promise<number> {
    const sdk = await this.getSdkInstance()
    try {
      const data = await sdk.models.embedContent({
        model: model.id,
        contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
      })
      return data.embeddings?.[0]?.values?.length || 0
    } catch (e) {
      return 0
    }
  }

  override async listModels(): Promise<GeminiModel[]> {
    const sdk = await this.getSdkInstance()
    const response = await sdk.models.list()
    const models: GeminiModel[] = []
    for await (const model of response) {
      models.push(model)
    }
    return models
  }

  override async getSdkInstance() {
    if (this.sdkInstance) {
      return this.sdkInstance
    }

    this.sdkInstance = new GoogleGenAI({
      vertexai: false,
      apiKey: this.apiKey,
      httpOptions: { baseUrl: this.getBaseURL() }
    })

    return this.sdkInstance
  }

  /**
   * Handle a PDF file
   * @param file - The file
   * @returns The part
   */
  private async handlePdfFile(file: FileType): Promise<Part> {
    const smallFileSize = 20 * MB
    const isSmallFile = file.size < smallFileSize

    if (isSmallFile) {
      const { data, mimeType } = await this.base64File(file)
      return {
        inlineData: {
          data,
          mimeType
        } as Part['inlineData']
      }
    }

    // Retrieve file from Gemini uploaded files
    const fileMetadata: File | undefined = await this.retrieveFile(file)

    if (fileMetadata) {
      return {
        fileData: {
          fileUri: fileMetadata.uri,
          mimeType: fileMetadata.mimeType
        } as Part['fileData']
      }
    }

    // If the file is not found, upload it to Gemini
    const result = await this.uploadFile(file)

    return {
      fileData: {
        fileUri: result.uri,
        mimeType: result.mimeType
      } as Part['fileData']
    }
  }

  /**
   * Get the message contents
   * @param message - The message
   * @returns The message contents
   */
  private async convertMessageToSdkParam(message: Message): Promise<Content> {
    const role = message.role === 'user' ? 'user' : 'model'
    const parts: Part[] = [{ text: await this.getMessageContent(message) }]
    // Add any generated images from previous responses
    const imageBlocks = findImageBlocks(message)
    for (const imageBlock of imageBlocks) {
      if (
        imageBlock.metadata?.generateImageResponse?.images &&
        imageBlock.metadata.generateImageResponse.images.length > 0
      ) {
        for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
          if (imageUrl && imageUrl.startsWith('data:')) {
            // Extract base64 data and mime type from the data URL
            const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
            if (matches && matches.length === 3) {
              const mimeType = matches[1]
              const base64Data = matches[2]
              parts.push({
                inlineData: {
                  data: base64Data,
                  mimeType: mimeType
                } as Part['inlineData']
              })
            }
          }
        }
      }
      const file = imageBlock.file
      if (file) {
        const base64Data = await window.api.file.base64Image(file.id + file.ext)
        parts.push({
          inlineData: {
            data: base64Data.base64,
            mimeType: base64Data.mime
          } as Part['inlineData']
        })
      }
    }

    const fileBlocks = findFileBlocks(message)
    for (const fileBlock of fileBlocks) {
      const file = fileBlock.file
      if (file.type === FileTypes.IMAGE) {
        const base64Data = await window.api.file.base64Image(file.id + file.ext)
        parts.push({
          inlineData: {
            data: base64Data.base64,
            mimeType: base64Data.mime
          } as Part['inlineData']
        })
      }

      if (file.ext === '.pdf') {
        parts.push(await this.handlePdfFile(file))
        continue
      }
      if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
        const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
        parts.push({
          text: file.origin_name + '\n' + fileContent
        })
      }
    }

    return {
      role,
      parts: parts
    }
  }

  // @ts-ignore unused
  private async getImageFileContents(message: Message): Promise<Content> {
    const role = message.role === 'user' ? 'user' : 'model'
    const content = getMainTextContent(message)
    const parts: Part[] = [{ text: content }]
    const imageBlocks = findImageBlocks(message)
    for (const imageBlock of imageBlocks) {
      if (
        imageBlock.metadata?.generateImageResponse?.images &&
        imageBlock.metadata.generateImageResponse.images.length > 0
      ) {
        for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
          if (imageUrl && imageUrl.startsWith('data:')) {
            // Extract base64 data and mime type from the data URL
            const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
            if (matches && matches.length === 3) {
              const mimeType = matches[1]
              const base64Data = matches[2]
              parts.push({
                inlineData: {
                  data: base64Data,
                  mimeType: mimeType
                } as Part['inlineData']
              })
            }
          }
        }
      }
      const file = imageBlock.file
      if (file) {
        const base64Data = await window.api.file.base64Image(file.id + file.ext)
        parts.push({
          inlineData: {
            data: base64Data.base64,
            mimeType: base64Data.mime
          } as Part['inlineData']
        })
      }
    }
    return {
      role,
      parts: parts
    }
  }

  /**
   * Get the safety settings
   * @returns The safety settings
   */
  private getSafetySettings(): SafetySetting[] {
    const safetyThreshold = 'OFF' as HarmBlockThreshold

    return [
      {
        category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        threshold: safetyThreshold
      },
      {
        category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        threshold: safetyThreshold
      },
      {
        category: HarmCategory.HARM_CATEGORY_HARASSMENT,
        threshold: safetyThreshold
      },
      {
        category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold: safetyThreshold
      },
      {
        category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
        threshold: HarmBlockThreshold.BLOCK_NONE
      }
    ]
  }

  /**
   * Get the reasoning effort for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The reasoning effort
   */
  private getBudgetToken(assistant: Assistant, model: Model) {
    if (isGeminiReasoningModel(model)) {
      const reasoningEffort = assistant?.settings?.reasoning_effort

      // If reasoning_effort is undefined, do not think
      if (reasoningEffort === undefined) {
        return {
          thinkingConfig: {
            includeThoughts: false,
            ...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {})
          } as ThinkingConfig
        }
      }

      const effortRatio = EFFORT_RATIO[reasoningEffort]

      if (effortRatio > 1) {
        return {
          thinkingConfig: {
            includeThoughts: true
          }
        }
      }

      const { max } = findTokenLimit(model.id) || { max: 0 }
      const budget = Math.floor(max * effortRatio)

      return {
        thinkingConfig: {
          ...(budget > 0 ? { thinkingBudget: budget } : {}),
          includeThoughts: true
        } as ThinkingConfig
      }
    }

    return {}
  }

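  // Worked example for the branch above (illustrative numbers only — real
  // values come from EFFORT_RATIO and findTokenLimit): assuming
  // EFFORT_RATIO['medium'] === 0.5 and findTokenLimit(model.id) === { max: 24576 },
  // the result is { thinkingBudget: floor(24576 * 0.5) = 12288, includeThoughts: true };
  // an effort ratio above 1 leaves the budget unset so the model decides for itself.
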
  private getGenerateImageParameter(): Partial<GenerateContentConfig> {
    return {
      systemInstruction: undefined,
      responseModalities: [Modality.TEXT, Modality.IMAGE],
      responseMimeType: 'text/plain'
    }
  }

  getRequestTransformer(): RequestTransformer<GeminiSdkParams, GeminiSdkMessageParam> {
    return {
      transform: async (
        coreRequest,
        assistant,
        model,
        isRecursiveCall,
        recursiveSdkMessages
      ): Promise<{
        payload: GeminiSdkParams
        messages: GeminiSdkMessageParam[]
        metadata: Record<string, any>
      }> => {
        const { messages, mcpTools, maxTokens, enableWebSearch, enableGenerateImage } = coreRequest
        // 1. Process the system message
        let systemInstruction = assistant.prompt

        // 2. Set up tools
        const { tools } = this.setupToolsConfig({
          mcpTools,
          model,
          enableToolUse: isEnabledToolUse(assistant)
        })

        if (this.useSystemPromptForTools) {
          systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools)
        }

        let messageContents: Content
        const history: Content[] = []
        // 3. Process user messages
        if (typeof messages === 'string') {
          messageContents = {
            role: 'user',
            parts: [{ text: messages }]
          }
        } else {
          const userLastMessage = messages.pop()!
          messageContents = await this.convertMessageToSdkParam(userLastMessage)
          for (const message of messages) {
            history.push(await this.convertMessageToSdkParam(message))
          }
        }

        if (enableWebSearch) {
          tools.push({
            googleSearch: {}
          })
        }

        if (isGemmaModel(model) && assistant.prompt) {
          const isFirstMessage = history.length === 0
          if (isFirstMessage && messageContents) {
            const systemMessage = [
              {
                text:
                  '<start_of_turn>user\n' +
                  systemInstruction +
                  '<end_of_turn>\n' +
                  '<start_of_turn>user\n' +
                  (messageContents?.parts?.[0] as Part).text +
                  '<end_of_turn>'
              }
            ] as Part[]
            if (messageContents && messageContents.parts) {
              messageContents.parts[0] = systemMessage[0]
            }
          }
        }

        const newHistory =
          isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
            ? recursiveSdkMessages.slice(0, recursiveSdkMessages.length - 1)
            : history

        const newMessageContents =
          isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
            ? {
                ...messageContents,
                parts: [
                  ...(messageContents.parts || []),
                  ...(recursiveSdkMessages[recursiveSdkMessages.length - 1].parts || [])
                ]
              }
            : messageContents

        const generateContentConfig: GenerateContentConfig = {
          safetySettings: this.getSafetySettings(),
          systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
          temperature: this.getTemperature(assistant, model),
          topP: this.getTopP(assistant, model),
          maxOutputTokens: maxTokens,
          tools: tools,
          ...(enableGenerateImage ? this.getGenerateImageParameter() : {}),
          ...this.getBudgetToken(assistant, model),
          ...this.getCustomParameters(assistant)
        }

        const param: GeminiSdkParams = {
          model: model.id,
          config: generateContentConfig,
          history: newHistory,
          message: newMessageContents.parts!
        }

        return {
          payload: param,
          messages: [messageContents],
          metadata: {}
        }
      }
    }
  }

  getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
    return () => ({
      async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
        let toolCalls: FunctionCall[] = []
        if (chunk.candidates && chunk.candidates.length > 0) {
          for (const candidate of chunk.candidates) {
            if (candidate.content) {
              candidate.content.parts?.forEach((part) => {
                const text = part.text || ''
                if (part.thought) {
                  controller.enqueue({
                    type: ChunkType.THINKING_DELTA,
                    text: text
                  })
                } else if (part.text) {
                  controller.enqueue({
                    type: ChunkType.TEXT_DELTA,
                    text: text
                  })
                } else if (part.inlineData) {
                  controller.enqueue({
                    type: ChunkType.IMAGE_COMPLETE,
                    image: {
                      type: 'base64',
                      images: [
                        part.inlineData?.data?.startsWith('data:')
                          ? part.inlineData?.data
                          : `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}`
                      ]
                    }
                  })
                }
              })
            }

            if (candidate.finishReason) {
              if (candidate.groundingMetadata) {
                controller.enqueue({
                  type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                  llm_web_search: {
                    results: candidate.groundingMetadata,
                    source: WebSearchSource.GEMINI
                  }
                } as LLMWebSearchCompleteChunk)
              }
              if (chunk.functionCalls) {
                toolCalls = toolCalls.concat(chunk.functionCalls)
              }
              controller.enqueue({
                type: ChunkType.LLM_RESPONSE_COMPLETE,
                response: {
                  usage: {
                    prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
                    completion_tokens:
                      (chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0),
                    total_tokens: chunk.usageMetadata?.totalTokenCount || 0
                  }
                }
              })
            }
          }
        }

        if (toolCalls.length > 0) {
          controller.enqueue({
            type: ChunkType.MCP_TOOL_CREATED,
            tool_calls: toolCalls
          })
        }
      }
    })
  }

  public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Tool[] {
    return mcpToolsToGeminiTools(mcpTools)
  }

  public convertSdkToolCallToMcp(toolCall: GeminiSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
    return geminiFunctionCallToMcpTool(mcpTools, toolCall)
  }

  public convertSdkToolCallToMcpToolResponse(toolCall: GeminiSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
    const parsedArgs = (() => {
      try {
        return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args
      } catch {
        return toolCall.args
      }
    })()

    return {
      id: toolCall.id || nanoid(),
      toolCallId: toolCall.id,
      tool: mcpTool,
      arguments: parsedArgs,
      status: 'pending'
    } as ToolCallResponse
  }

  public convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): GeminiSdkMessageParam | undefined {
    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
      return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
    } else if ('toolCallId' in mcpToolResponse) {
      return {
        role: 'user',
        parts: [
          {
            functionResponse: {
              id: mcpToolResponse.toolCallId,
              name: mcpToolResponse.tool.id,
              response: {
                output: !resp.isError ? resp.content : undefined,
                error: resp.isError ? resp.content : undefined
              }
            }
          }
        ]
      } satisfies Content
    }
    return
  }

  public buildSdkMessages(
    currentReqMessages: Content[],
    output: string,
    toolResults: Content[],
    toolCalls: FunctionCall[]
  ): Content[] {
    const parts: Part[] = []
    if (output) {
      parts.push({
        text: output
      })
    }
    toolCalls.forEach((toolCall) => {
      parts.push({
        functionCall: toolCall
      })
    })
    parts.push(
      ...toolResults
        .map((ts) => ts.parts)
        .flat()
        .filter((p) => p !== undefined)
    )

    const userMessage: Content = {
      role: 'user',
      parts: parts
    }

    return [...currentReqMessages, userMessage]
  }

  override estimateMessageTokens(message: GeminiSdkMessageParam): number {
    return (
      message.parts?.reduce((acc, part) => {
        if (part.text) {
          return acc + estimateTextTokens(part.text)
        }
        if (part.functionCall) {
          return acc + estimateTextTokens(JSON.stringify(part.functionCall))
        }
        if (part.functionResponse) {
          return acc + estimateTextTokens(JSON.stringify(part.functionResponse.response))
        }
        if (part.inlineData) {
          return acc + estimateTextTokens(part.inlineData.data || '')
        }
        if (part.fileData) {
          return acc + estimateTextTokens(part.fileData.fileUri || '')
        }
        return acc
      }, 0) || 0
    )
  }

  public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
    return sdkPayload.history || []
  }

  private async uploadFile(file: FileType): Promise<File> {
    return await this.sdkInstance!.files.upload({
      file: file.path,
      config: {
        mimeType: 'application/pdf',
        name: file.id,
        displayName: file.origin_name
      }
    })
  }

  private async base64File(file: FileType) {
    const { data } = await window.api.file.base64File(file.id + file.ext)
    return {
      data,
      mimeType: 'application/pdf'
    }
  }

  private async retrieveFile(file: FileType): Promise<File | undefined> {
    const cachedResponse = CacheService.get<any>('gemini_file_list')

    if (cachedResponse) {
      return this.processResponse(cachedResponse, file)
    }

    const response = await this.sdkInstance!.files.list()
    CacheService.set('gemini_file_list', response, 3000)

    return this.processResponse(response, file)
  }

  private async processResponse(response: Pager<File>, file: FileType) {
    for await (const f of response) {
      if (f.state === FileState.ACTIVE) {
        if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
          return f
        }
      }
    }

    return undefined
  }

  // @ts-ignore unused
  private async listFiles(): Promise<File[]> {
    const files: File[] = []
    for await (const f of await this.sdkInstance!.files.list()) {
      files.push(f)
    }
    return files
  }

  // @ts-ignore unused
  private async deleteFile(fileId: string) {
    await this.sdkInstance!.files.delete({ name: fileId })
  }
}
6
src/renderer/src/aiCore/clients/index.ts
Normal file
@ -0,0 +1,6 @@
export * from './ApiClientFactory'
export * from './BaseApiClient'
export * from './types'

// Export specific clients from subdirectories
export * from './openai/OpenAIApiClient'
682
src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts
Normal file
@ -0,0 +1,682 @@
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import Logger from '@renderer/config/logger'
import {
  findTokenLimit,
  GEMINI_FLASH_MODEL_REGEX,
  getOpenAIWebSearchParams,
  isDoubaoThinkingAutoModel,
  isReasoningModel,
  isSupportedReasoningEffortGrokModel,
  isSupportedReasoningEffortModel,
  isSupportedReasoningEffortOpenAIModel,
  isSupportedThinkingTokenClaudeModel,
  isSupportedThinkingTokenDoubaoModel,
  isSupportedThinkingTokenGeminiModel,
  isSupportedThinkingTokenModel,
  isSupportedThinkingTokenQwenModel,
  isVisionModel
} from '@renderer/config/models'
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
import { estimateTextTokens } from '@renderer/services/TokenService'
// For Copilot token
import {
  Assistant,
  EFFORT_RATIO,
  FileTypes,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  Model,
  Provider,
  ToolCallResponse,
  WebSearchSource
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
  OpenAISdkMessageParam,
  OpenAISdkParams,
  OpenAISdkRawChunk,
  OpenAISdkRawContentSource,
  OpenAISdkRawOutput,
  ReasoningEffortOptionalParams
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
  isEnabledToolUse,
  mcpToolCallResponseToOpenAICompatibleMessage,
  mcpToolsToOpenAIChatTools,
  openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import OpenAI, { AzureOpenAI } from 'openai'
import { ChatCompletionContentPart, ChatCompletionContentPartRefusal, ChatCompletionTool } from 'openai/resources'

import { GenericChunk } from '../../middleware/schemas'
import { RequestTransformer, ResponseChunkTransformer, ResponseChunkTransformerContext } from '../types'
import { OpenAIBaseClient } from './OpenAIBaseClient'

export class OpenAIAPIClient extends OpenAIBaseClient<
  OpenAI | AzureOpenAI,
  OpenAISdkParams,
  OpenAISdkRawOutput,
  OpenAISdkRawChunk,
  OpenAISdkMessageParam,
  OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
  ChatCompletionTool
> {
  constructor(provider: Provider) {
    super(provider)
  }

  override async createCompletions(
    payload: OpenAISdkParams,
    options?: OpenAI.RequestOptions
  ): Promise<OpenAISdkRawOutput> {
    const sdk = await this.getSdkInstance()
    // @ts-ignore - the SDK params may carry extra provider-specific fields
    return await sdk.chat.completions.create(payload, options)
  }

  /**
   * Get the reasoning effort for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The reasoning effort
   */
  // Method for reasoning effort, moved from OpenAIProvider
  override getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
    if (this.provider.id === 'groq') {
      return {}
    }

    if (!isReasoningModel(model)) {
      return {}
    }
    const reasoningEffort = assistant?.settings?.reasoning_effort

    // Doubao thinking-mode support
    if (isSupportedThinkingTokenDoubaoModel(model)) {
      // No reasoningEffort set: disable thinking
      if (!reasoningEffort) {
        return { thinking: { type: 'disabled' } }
      }
      if (reasoningEffort === 'high') {
        return { thinking: { type: 'enabled' } }
      }
      if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) {
        return { thinking: { type: 'auto' } }
      }
      // Otherwise omit the thinking field entirely
      return {}
    }

    if (!reasoningEffort) {
      if (isSupportedThinkingTokenQwenModel(model)) {
        return { enable_thinking: false }
      }

      if (isSupportedThinkingTokenClaudeModel(model)) {
        return {}
      }

      if (isSupportedThinkingTokenGeminiModel(model)) {
        // OpenRouter does not expose a "no reasoning" option, so hide it for now
        if (this.provider.id === 'openrouter') {
          return { reasoning: { max_tokens: 0, exclude: true } }
        }
        if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
          return { reasoning_effort: 'none' }
        }
        return {}
      }

      if (isSupportedThinkingTokenDoubaoModel(model)) {
        return { thinking: { type: 'disabled' } }
      }

      return {}
    }
    const effortRatio = EFFORT_RATIO[reasoningEffort]
    const budgetTokens = Math.floor(
      (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min!
    )

    // OpenRouter models
    if (model.provider === 'openrouter') {
      if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
        return {
          reasoning: {
            effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
          }
        }
      }
    }

    // Qwen models
    if (isSupportedThinkingTokenQwenModel(model)) {
      return {
        enable_thinking: true,
        thinking_budget: budgetTokens
      }
    }

    // Grok models
    if (isSupportedReasoningEffortGrokModel(model)) {
      return {
        reasoning_effort: reasoningEffort
      }
    }

    // OpenAI models
    if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) {
      return {
        reasoning_effort: reasoningEffort
      }
    }

    // Claude models
    if (isSupportedThinkingTokenClaudeModel(model)) {
      const maxTokens = assistant.settings?.maxTokens
      return {
        thinking: {
          type: 'enabled',
          budget_tokens: Math.floor(
            Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio))
          )
        }
      }
    }

    // Doubao models
    if (isSupportedThinkingTokenDoubaoModel(model)) {
      if (assistant.settings?.reasoning_effort === 'high') {
        return {
          thinking: {
            type: 'enabled'
          }
        }
      }
    }

    // Default case: no special thinking settings
    return {}
  }

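  // Worked example for budgetTokens above (illustrative numbers only — real
  // values come from EFFORT_RATIO and findTokenLimit): assuming
  // EFFORT_RATIO['medium'] === 0.5 and findTokenLimit(model.id) === { min: 1024, max: 32000 },
  //   budgetTokens = floor((32000 - 1024) * 0.5 + 1024) = 16512
  // and the Claude branch then clamps it to [1024, maxTokens * effortRatio],
  // e.g. min(16512, 8192 * 0.5) = 4096 when maxTokens === 8192.
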
  /**
   * Check if the provider does not support files
   * @returns True if the provider does not support files, false otherwise
   */
  private get isNotSupportFiles() {
    if (this.provider?.isNotSupportArrayContent) {
      return true
    }

    const providers = ['deepseek', 'baichuan', 'minimax', 'xirang']

    return providers.includes(this.provider.id)
  }

  /**
   * Get the message parameter
   * @param message - The message
   * @param model - The model
   * @returns The message parameter
   */
  public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAISdkMessageParam> {
    const isVision = isVisionModel(model)
    const content = await this.getMessageContent(message)
    const fileBlocks = findFileBlocks(message)
    const imageBlocks = findImageBlocks(message)

    if (fileBlocks.length === 0 && imageBlocks.length === 0) {
      return {
        role: message.role === 'system' ? 'user' : message.role,
        content
      } as OpenAISdkMessageParam
    }

    // If the model does not support files, extract the file content
    if (this.isNotSupportFiles) {
      const fileContent = await this.extractFileContent(message)

      return {
        role: message.role === 'system' ? 'user' : message.role,
        content: content + '\n\n---\n\n' + fileContent
      } as OpenAISdkMessageParam
    }

    // If the model supports files, add the file content to the message
    const parts: ChatCompletionContentPart[] = []

    if (content) {
      parts.push({ type: 'text', text: content })
    }

    for (const imageBlock of imageBlocks) {
      if (isVision) {
        if (imageBlock.file) {
          const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
          parts.push({ type: 'image_url', image_url: { url: image.data } })
        } else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
          parts.push({ type: 'image_url', image_url: { url: imageBlock.url } })
        }
      }
    }

    for (const fileBlock of fileBlocks) {
      const file = fileBlock.file
      if (!file) {
        continue
      }

      if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
        const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
        parts.push({
          type: 'text',
          text: file.origin_name + '\n' + fileContent
        })
      }
    }

    return {
      role: message.role === 'system' ? 'user' : message.role,
      content: parts
    } as OpenAISdkMessageParam
  }

  public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ChatCompletionTool[] {
    return mcpToolsToOpenAIChatTools(mcpTools)
  }

  public convertSdkToolCallToMcp(
    toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
    mcpTools: MCPTool[]
  ): MCPTool | undefined {
    return openAIToolsToMcpTool(mcpTools, toolCall)
  }

  public convertSdkToolCallToMcpToolResponse(
    toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
    mcpTool: MCPTool
  ): ToolCallResponse {
    let parsedArgs: any
    try {
      parsedArgs = JSON.parse(toolCall.function.arguments)
    } catch {
      parsedArgs = toolCall.function.arguments
    }
    return {
      id: toolCall.id,
      toolCallId: toolCall.id,
      tool: mcpTool,
      arguments: parsedArgs,
      status: 'pending'
    } as ToolCallResponse
  }

  public convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): OpenAISdkMessageParam | undefined {
    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
      // This case is for Anthropic/Claude like tool usage, OpenAI uses tool_call_id
      // For OpenAI, we primarily expect toolCallId. This might need adjustment if mixing provider concepts.
      return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model))
    } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
      return {
        role: 'tool',
        tool_call_id: mcpToolResponse.toolCallId,
        content: JSON.stringify(resp.content)
      } as OpenAI.Chat.Completions.ChatCompletionToolMessageParam
    }
    return undefined
  }

  public buildSdkMessages(
    currentReqMessages: OpenAISdkMessageParam[],
    output: string,
    toolResults: OpenAISdkMessageParam[],
    toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[]
  ): OpenAISdkMessageParam[] {
    const assistantMessage: OpenAISdkMessageParam = {
      role: 'assistant',
      content: output,
      tool_calls: toolCalls.length > 0 ? toolCalls : undefined
    }
    const newReqMessages = [...currentReqMessages, assistantMessage, ...toolResults]
    return newReqMessages
  }

  override estimateMessageTokens(message: OpenAISdkMessageParam): number {
    let sum = 0
    if (typeof message.content === 'string') {
      sum += estimateTextTokens(message.content)
    } else if (Array.isArray(message.content)) {
      sum += (message.content || [])
        .map((part: ChatCompletionContentPart | ChatCompletionContentPartRefusal) => {
          switch (part.type) {
            case 'text':
              return estimateTextTokens(part.text)
            case 'image_url':
              return estimateTextTokens(part.image_url.url)
            case 'input_audio':
              return estimateTextTokens(part.input_audio.data)
            case 'file':
              return estimateTextTokens(part.file.file_data || '')
            default:
              return 0
          }
        })
        .reduce((acc, curr) => acc + curr, 0)
    }
    if ('tool_calls' in message && message.tool_calls) {
      sum += message.tool_calls.reduce((acc, toolCall) => {
        return acc + estimateTextTokens(JSON.stringify(toolCall.function.arguments))
      }, 0)
    }
    return sum
  }

  public extractMessagesFromSdkPayload(sdkPayload: OpenAISdkParams): OpenAISdkMessageParam[] {
    return sdkPayload.messages || []
  }

  getRequestTransformer(): RequestTransformer<OpenAISdkParams, OpenAISdkMessageParam> {
    return {
      transform: async (
        coreRequest,
        assistant,
        model,
        isRecursiveCall,
        recursiveSdkMessages
      ): Promise<{
        payload: OpenAISdkParams
        messages: OpenAISdkMessageParam[]
        metadata: Record<string, any>
      }> => {
        const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
        // 1. Process the system message
        let systemMessage = { role: 'system', content: assistant.prompt || '' }

        if (isSupportedReasoningEffortOpenAIModel(model)) {
          systemMessage = {
            role: 'developer',
            content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}`
          }
        }

        if (model.id.includes('o1-mini') || model.id.includes('o1-preview')) {
          systemMessage.role = 'assistant'
        }

        // 2. Set up tools (must happen before this.useSystemPromptForTools is read)
        const { tools } = this.setupToolsConfig({
          mcpTools: mcpTools,
          model,
          enableToolUse: isEnabledToolUse(assistant)
        })

        if (this.useSystemPromptForTools) {
          systemMessage.content = await buildSystemPrompt(systemMessage.content || '', mcpTools)
        }

        // 3. Process user messages
        const userMessages: OpenAISdkMessageParam[] = []
        if (typeof messages === 'string') {
          userMessages.push({ role: 'user', content: messages })
        } else {
          const processedMessages = addImageFileToContents(messages)
          for (const message of processedMessages) {
            userMessages.push(await this.convertMessageToSdkParam(message, model))
          }
        }

        const lastUserMsg = userMessages.findLast((m) => m.role === 'user')
        if (lastUserMsg && isSupportedThinkingTokenQwenModel(model)) {
          const postsuffix = '/no_think'
          const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true
          const currentContent = lastUserMsg.content

          lastUserMsg.content = processPostsuffixQwen3Model(currentContent, postsuffix, qwenThinkModeEnabled) as any
        }

        // 4. Final request messages
        let reqMessages: OpenAISdkMessageParam[]
        if (!systemMessage.content) {
          reqMessages = [...userMessages]
        } else {
          reqMessages = [systemMessage, ...userMessages].filter(Boolean) as OpenAISdkMessageParam[]
        }

        reqMessages = processReqMessages(model, reqMessages)

        // 5. Build the common parameters
        const commonParams = {
          model: model.id,
          messages:
            isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
              ? recursiveSdkMessages
              : reqMessages,
          temperature: this.getTemperature(assistant, model),
          top_p: this.getTopP(assistant, model),
          max_tokens: maxTokens,
          tools: tools.length > 0 ? tools : undefined,
          service_tier: this.getServiceTier(model),
          ...this.getProviderSpecificParameters(assistant, model),
          ...this.getReasoningEffort(assistant, model),
          ...getOpenAIWebSearchParams(model, enableWebSearch),
          ...this.getCustomParameters(assistant)
        }

        // Create the appropriate parameters object based on whether streaming is enabled
        const sdkParams: OpenAISdkParams = streamOutput
          ? {
              ...commonParams,
              stream: true
            }
          : {
              ...commonParams,
              stream: false
            }

        const timeout = this.getTimeout(model)

        return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
      }
    }
  }

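  // The transformer below turns raw OpenAI-compatible chunks into GenericChunk
  // events. Web-search results are collected at most once per stream via the
  // hasBeenCollectedWebSearch flag, because providers tend to repeat the same
  // citations/annotations payload on several chunks of a single response.
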
  // Used by RawSdkChunkToGenericChunkMiddleware
  getResponseChunkTransformer = (): ResponseChunkTransformer<OpenAISdkRawChunk> => {
    let hasBeenCollectedWebSearch = false
    const collectWebSearchData = (
      chunk: OpenAISdkRawChunk,
      contentSource: OpenAISdkRawContentSource,
      context: ResponseChunkTransformerContext
    ) => {
      if (hasBeenCollectedWebSearch) {
        return
      }
      // OpenAI annotations
      // @ts-ignore - annotations may not be in standard type definitions
      const annotations = contentSource.annotations || chunk.annotations
      if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') {
        hasBeenCollectedWebSearch = true
        return {
          results: annotations,
          source: WebSearchSource.OPENAI
        }
      }

      // Grok citations
      // @ts-ignore - citations may not be in standard type definitions
      if (context.provider?.id === 'grok' && chunk.citations) {
        hasBeenCollectedWebSearch = true
        return {
          // @ts-ignore - citations may not be in standard type definitions
          results: chunk.citations,
          source: WebSearchSource.GROK
        }
      }

      // Perplexity citations
      // @ts-ignore - citations may not be in standard type definitions
      if (context.provider?.id === 'perplexity' && chunk.citations && chunk.citations.length > 0) {
        hasBeenCollectedWebSearch = true
        return {
          // @ts-ignore - citations may not be in standard type definitions
          results: chunk.citations,
          source: WebSearchSource.PERPLEXITY
        }
      }

      // OpenRouter citations
      // @ts-ignore - citations may not be in standard type definitions
      if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) {
        hasBeenCollectedWebSearch = true
        return {
          // @ts-ignore - citations may not be in standard type definitions
          results: chunk.citations,
          source: WebSearchSource.OPENROUTER
        }
      }

      // Zhipu web search
      // @ts-ignore - web_search may not be in standard type definitions
      if (context.provider?.id === 'zhipu' && chunk.web_search) {
        hasBeenCollectedWebSearch = true
        return {
          // @ts-ignore - web_search may not be in standard type definitions
          results: chunk.web_search,
          source: WebSearchSource.ZHIPU
        }
      }

      // Hunyuan web search
      // @ts-ignore - search_info may not be in standard type definitions
      if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) {
        hasBeenCollectedWebSearch = true
        return {
          // @ts-ignore - search_info may not be in standard type definitions
          results: chunk.search_info.search_results,
          source: WebSearchSource.HUNYUAN
        }
      }

      // TODO: move this into AnthropicApiClient
      // // Other providers...
      // // @ts-ignore - web_search may not be in standard type definitions
      // if (chunk.web_search) {
      //   const sourceMap: Record<string, string> = {
      //     openai: 'openai',
      //     anthropic: 'anthropic',
      //     qwenlm: 'qwen'
      //   }
      //   const source = sourceMap[context.provider?.id] || 'openai_response'
      //   return {
      //     results: chunk.web_search,
      //     source: source as const
      //   }
      // }

      return null
    }
    const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []
    return (context: ResponseChunkTransformerContext) => ({
      async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
        // Process the chunk
        if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
          const choice = chunk.choices[0]

          if (!choice) return

          // Streaming responses carry a delta; non-streaming responses carry a message
          const contentSource: OpenAISdkRawContentSource | null =
            'delta' in choice ? choice.delta : 'message' in choice ? choice.message : null

          if (!contentSource) return

          const webSearchData = collectWebSearchData(chunk, contentSource, context)
          if (webSearchData) {
            controller.enqueue({
              type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
              llm_web_search: webSearchData
            })
          }

          // Handle reasoning content (e.g. from OpenRouter DeepSeek-R1)
          // @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it
          const reasoningText = contentSource.reasoning_content || contentSource.reasoning
          if (reasoningText) {
            controller.enqueue({
              type: ChunkType.THINKING_DELTA,
              text: reasoningText
            })
          }

          // Handle text content
          if (contentSource.content) {
            controller.enqueue({
              type: ChunkType.TEXT_DELTA,
              text: contentSource.content
            })
          }

          // Handle tool calls
          if (contentSource.tool_calls) {
            for (const toolCall of contentSource.tool_calls) {
              if ('index' in toolCall) {
                const { id, index, function: fun } = toolCall
                if (fun?.name) {
                  toolCalls[index] = {
                    id: id || '',
                    function: {
                      name: fun.name,
                      arguments: fun.arguments || ''
|
||||
},
|
||||
type: 'function'
|
||||
}
|
||||
} else if (fun?.arguments) {
|
||||
toolCalls[index].function.arguments += fun.arguments
|
||||
}
|
||||
} else {
|
||||
toolCalls.push(toolCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理finish_reason,发送流结束信号
|
||||
if ('finish_reason' in choice && choice.finish_reason) {
|
||||
Logger.debug(`[OpenAIApiClient] Stream finished with reason: ${choice.finish_reason}`)
|
||||
if (toolCalls.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
}
|
||||
const webSearchData = collectWebSearchData(chunk, contentSource, context)
|
||||
if (webSearchData) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: webSearchData
|
||||
})
|
||||
}
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: chunk.usage?.prompt_tokens || 0,
|
||||
completion_tokens: chunk.usage?.completion_tokens || 0,
|
||||
total_tokens: (chunk.usage?.prompt_tokens || 0) + (chunk.usage?.completion_tokens || 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
258
src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts
Normal file
@ -0,0 +1,258 @@
import {
  isClaudeReasoningModel,
  isNotSupportTemperatureAndTopP,
  isOpenAIReasoningModel,
  isSupportedModel,
  isSupportedReasoningEffortOpenAIModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { SettingsState } from '@renderer/store/settings'
import { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
import {
  OpenAIResponseSdkMessageParam,
  OpenAIResponseSdkParams,
  OpenAIResponseSdkRawChunk,
  OpenAIResponseSdkRawOutput,
  OpenAIResponseSdkTool,
  OpenAIResponseSdkToolCall,
  OpenAISdkMessageParam,
  OpenAISdkParams,
  OpenAISdkRawChunk,
  OpenAISdkRawOutput,
  ReasoningEffortOptionalParams
} from '@renderer/types/sdk'
import { formatApiHost } from '@renderer/utils/api'
import OpenAI, { AzureOpenAI } from 'openai'

import { BaseApiClient } from '../BaseApiClient'

/**
 * Abstract OpenAI base client holding the functionality shared by the two OpenAI clients
 */
export abstract class OpenAIBaseClient<
  TSdkInstance extends OpenAI | AzureOpenAI,
  TSdkParams extends OpenAISdkParams | OpenAIResponseSdkParams,
  TRawOutput extends OpenAISdkRawOutput | OpenAIResponseSdkRawOutput,
  TRawChunk extends OpenAISdkRawChunk | OpenAIResponseSdkRawChunk,
  TMessageParam extends OpenAISdkMessageParam | OpenAIResponseSdkMessageParam,
  TToolCall extends OpenAI.Chat.Completions.ChatCompletionMessageToolCall | OpenAIResponseSdkToolCall,
  TSdkSpecificTool extends OpenAI.Chat.Completions.ChatCompletionTool | OpenAIResponseSdkTool
> extends BaseApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool> {
  constructor(provider: Provider) {
    super(provider)
  }

  // Only applies to openai
  override getBaseURL(): string {
    const host = this.provider.apiHost
    return formatApiHost(host)
  }

  override async generateImage({
    model,
    prompt,
    negativePrompt,
    imageSize,
    batchSize,
    seed,
    numInferenceSteps,
    guidanceScale,
    signal,
    promptEnhancement
  }: GenerateImageParams): Promise<string[]> {
    const sdk = await this.getSdkInstance()
    const response = (await sdk.request({
      method: 'post',
      path: '/images/generations',
      signal,
      body: {
        model,
        prompt,
        negative_prompt: negativePrompt,
        image_size: imageSize,
        batch_size: batchSize,
        seed: seed ? parseInt(seed) : undefined,
        num_inference_steps: numInferenceSteps,
        guidance_scale: guidanceScale,
        prompt_enhancement: promptEnhancement
      }
    })) as { data: Array<{ url: string }> }

    return response.data.map((item) => item.url)
  }

  override async getEmbeddingDimensions(model: Model): Promise<number> {
    const sdk = await this.getSdkInstance()
    try {
      const data = await sdk.embeddings.create({
        model: model.id,
        input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
        encoding_format: 'float'
      })
      return data.data[0].embedding.length
    } catch (e) {
      return 0
    }
  }

  override async listModels(): Promise<OpenAI.Models.Model[]> {
    try {
      const sdk = await this.getSdkInstance()
      const response = await sdk.models.list()
      if (this.provider.id === 'github') {
        // @ts-ignore key is not typed
        return response?.body
          .map((model) => ({
            id: model.name,
            description: model.summary,
            object: 'model',
            owned_by: model.publisher
          }))
          .filter(isSupportedModel)
      }
      if (this.provider.id === 'together') {
        // @ts-ignore key is not typed
        return response?.body.map((model) => ({
          id: model.id,
          description: model.display_name,
          object: 'model',
          owned_by: model.organization
        }))
      }
      const models = response.data || []
      models.forEach((model) => {
        model.id = model.id.trim()
      })

      return models.filter(isSupportedModel)
    } catch (error) {
      console.error('Error listing models:', error)
      return []
    }
  }

  override async getSdkInstance() {
    if (this.sdkInstance) {
      return this.sdkInstance
    }

    let apiKeyForSdkInstance = this.provider.apiKey

    if (this.provider.id === 'copilot') {
      const defaultHeaders = store.getState().copilot.defaultHeaders
      const { token } = await window.api.copilot.getToken(defaultHeaders)
      // this.provider.apiKey must not be mutated
      // this.provider.apiKey = token
      apiKeyForSdkInstance = token
    }

    if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
      this.sdkInstance = new AzureOpenAI({
        dangerouslyAllowBrowser: true,
        apiKey: apiKeyForSdkInstance,
        apiVersion: this.provider.apiVersion,
        endpoint: this.provider.apiHost
      }) as TSdkInstance
    } else {
      this.sdkInstance = new OpenAI({
        dangerouslyAllowBrowser: true,
        apiKey: apiKeyForSdkInstance,
        baseURL: this.getBaseURL(),
        defaultHeaders: {
          ...this.defaultHeaders(),
          ...(this.provider.id === 'copilot' ? { 'editor-version': 'vscode/1.97.2' } : {}),
          ...(this.provider.id === 'copilot' ? { 'copilot-vision-request': 'true' } : {})
        }
      }) as TSdkInstance
    }
    return this.sdkInstance
  }

  override getTemperature(assistant: Assistant, model: Model): number | undefined {
    if (
      isNotSupportTemperatureAndTopP(model) ||
      (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
    ) {
      return undefined
    }
    return assistant.settings?.temperature
  }

  override getTopP(assistant: Assistant, model: Model): number | undefined {
    if (
      isNotSupportTemperatureAndTopP(model) ||
      (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
    ) {
      return undefined
    }
    return assistant.settings?.topP
  }

  /**
   * Get the provider specific parameters for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The provider specific parameters
   */
  protected getProviderSpecificParameters(assistant: Assistant, model: Model) {
    const { maxTokens } = getAssistantSettings(assistant)

    if (this.provider.id === 'openrouter') {
      if (model.id.includes('deepseek-r1')) {
        return {
          include_reasoning: true
        }
      }
    }

    if (isOpenAIReasoningModel(model)) {
      return {
        max_tokens: undefined,
        max_completion_tokens: maxTokens
      }
    }

    return {}
  }

  /**
   * Get the reasoning effort for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The reasoning effort
   */
  protected getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
    if (!isSupportedReasoningEffortOpenAIModel(model)) {
      return {}
    }

    const openAI = getStoreSetting('openAI') as SettingsState['openAI']
    const summaryText = openAI?.summaryText || 'off'

    let summary: string | undefined = undefined

    if (summaryText === 'off' || model.id.includes('o1-pro')) {
      summary = undefined
    } else {
      summary = summaryText
    }

    const reasoningEffort = assistant?.settings?.reasoning_effort
    if (!reasoningEffort) {
      return {}
    }

    if (isSupportedReasoningEffortOpenAIModel(model)) {
      return {
        reasoning: {
          effort: reasoningEffort as OpenAI.ReasoningEffort,
          summary: summary
        } as OpenAI.Reasoning
      }
    }

    return {}
  }
}
532
src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts
Normal file
@ -0,0 +1,532 @@
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import {
  isOpenAIChatCompletionOnlyModel,
  isSupportedReasoningEffortOpenAIModel,
  isVisionModel
} from '@renderer/config/models'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
  FileTypes,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  Model,
  Provider,
  ToolCallResponse,
  WebSearchSource
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
  OpenAIResponseSdkMessageParam,
  OpenAIResponseSdkParams,
  OpenAIResponseSdkRawChunk,
  OpenAIResponseSdkRawOutput,
  OpenAIResponseSdkTool,
  OpenAIResponseSdkToolCall
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
  isEnabledToolUse,
  mcpToolCallResponseToOpenAIMessage,
  mcpToolsToOpenAIResponseTools,
  openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { isEmpty } from 'lodash'
import OpenAI from 'openai'

import { RequestTransformer, ResponseChunkTransformer } from '../types'
import { OpenAIAPIClient } from './OpenAIApiClient'
import { OpenAIBaseClient } from './OpenAIBaseClient'

export class OpenAIResponseAPIClient extends OpenAIBaseClient<
  OpenAI,
  OpenAIResponseSdkParams,
  OpenAIResponseSdkRawOutput,
  OpenAIResponseSdkRawChunk,
  OpenAIResponseSdkMessageParam,
  OpenAIResponseSdkToolCall,
  OpenAIResponseSdkTool
> {
  private client: OpenAIAPIClient
  constructor(provider: Provider) {
    super(provider)
    this.client = new OpenAIAPIClient(provider)
  }

  /**
   * Pick the appropriate client based on the model's characteristics
   */
  public getClient(model: Model) {
    if (isOpenAIChatCompletionOnlyModel(model)) {
      return this.client
    } else {
      return this
    }
  }

  override async getSdkInstance() {
    if (this.sdkInstance) {
      return this.sdkInstance
    }

    return new OpenAI({
      dangerouslyAllowBrowser: true,
      apiKey: this.provider.apiKey,
      baseURL: this.getBaseURL(),
      defaultHeaders: {
        ...this.defaultHeaders()
      }
    })
  }

  override async createCompletions(
    payload: OpenAIResponseSdkParams,
    options?: OpenAI.RequestOptions
  ): Promise<OpenAIResponseSdkRawOutput> {
    const sdk = await this.getSdkInstance()
    return await sdk.responses.create(payload, options)
  }

  public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAIResponseSdkMessageParam> {
    const isVision = isVisionModel(model)
    const content = await this.getMessageContent(message)
    const fileBlocks = findFileBlocks(message)
    const imageBlocks = findImageBlocks(message)

    if (fileBlocks.length === 0 && imageBlocks.length === 0) {
      if (message.role === 'assistant') {
        return {
          role: 'assistant',
          content: content
        }
      } else {
        return {
          role: message.role === 'system' ? 'user' : message.role,
          content: content ? [{ type: 'input_text', text: content }] : []
        } as OpenAI.Responses.EasyInputMessage
      }
    }

    const parts: OpenAI.Responses.ResponseInputContent[] = []
    if (content) {
      parts.push({
        type: 'input_text',
        text: content
      })
    }

    for (const imageBlock of imageBlocks) {
      if (isVision) {
        if (imageBlock.file) {
          const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
          parts.push({
            detail: 'auto',
            type: 'input_image',
            image_url: image.data as string
          })
        } else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
          parts.push({
            detail: 'auto',
            type: 'input_image',
            image_url: imageBlock.url
          })
        }
      }
    }

    for (const fileBlock of fileBlocks) {
      const file = fileBlock.file
      if (!file) continue

      if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
        const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
        parts.push({
          type: 'input_text',
          text: file.origin_name + '\n' + fileContent
        })
      }
    }

    return {
      role: message.role === 'system' ? 'user' : message.role,
      content: parts
    }
  }

  public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] {
    return mcpToolsToOpenAIResponseTools(mcpTools)
  }

  public convertSdkToolCallToMcp(toolCall: OpenAIResponseSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
    return openAIToolsToMcpTool(mcpTools, toolCall)
  }
  public convertSdkToolCallToMcpToolResponse(toolCall: OpenAIResponseSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
    const parsedArgs = (() => {
      try {
        return JSON.parse(toolCall.arguments)
      } catch {
        return toolCall.arguments
      }
    })()

    return {
      id: toolCall.call_id,
      toolCallId: toolCall.call_id,
      tool: mcpTool,
      arguments: parsedArgs,
      status: 'pending'
    }
  }

  public convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): OpenAIResponseSdkMessageParam | undefined {
    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
      return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model))
    } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
      return {
        type: 'function_call_output',
        call_id: mcpToolResponse.toolCallId,
        output: JSON.stringify(resp.content)
      }
    }
    return
  }

  public buildSdkMessages(
    currentReqMessages: OpenAIResponseSdkMessageParam[],
    output: string,
    toolResults: OpenAIResponseSdkMessageParam[],
    toolCalls: OpenAIResponseSdkToolCall[]
  ): OpenAIResponseSdkMessageParam[] {
    const assistantMessage: OpenAIResponseSdkMessageParam = {
      role: 'assistant',
      content: [{ type: 'input_text', text: output }]
    }
    const newReqMessages = [...currentReqMessages, assistantMessage, ...(toolCalls || []), ...(toolResults || [])]
    return newReqMessages
  }

  override estimateMessageTokens(message: OpenAIResponseSdkMessageParam): number {
    let sum = 0
    if ('content' in message) {
      if (typeof message.content === 'string') {
        sum += estimateTextTokens(message.content)
      } else if (Array.isArray(message.content)) {
        for (const part of message.content) {
          switch (part.type) {
            case 'input_text':
              sum += estimateTextTokens(part.text)
              break
            case 'input_image':
              sum += estimateTextTokens(part.image_url || '')
              break
            default:
              break
          }
        }
      }
    }
    switch (message.type) {
      case 'function_call_output':
        sum += estimateTextTokens(message.output)
        break
      case 'function_call':
        sum += estimateTextTokens(message.arguments)
        break
      default:
        break
    }
    return sum
  }

  public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] {
    if (typeof sdkPayload.input === 'string') {
      return [{ role: 'user', content: sdkPayload.input }]
    }
    return sdkPayload.input
  }

  getRequestTransformer(): RequestTransformer<OpenAIResponseSdkParams, OpenAIResponseSdkMessageParam> {
    return {
      transform: async (
        coreRequest,
        assistant,
        model,
        isRecursiveCall,
        recursiveSdkMessages
      ): Promise<{
        payload: OpenAIResponseSdkParams
        messages: OpenAIResponseSdkMessageParam[]
        metadata: Record<string, any>
      }> => {
        const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch, enableGenerateImage } = coreRequest
        // 1. Handle the system message
        const systemMessage: OpenAI.Responses.EasyInputMessage = {
          role: 'system',
          content: []
        }

        const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = []
        const systemMessageInput: OpenAI.Responses.ResponseInputText = {
          text: assistant.prompt || '',
          type: 'input_text'
        }
        if (isSupportedReasoningEffortOpenAIModel(model)) {
          systemMessage.role = 'developer'
        }

        // 2. Set up tools
        let tools: OpenAI.Responses.Tool[] = []
        const { tools: extraTools } = this.setupToolsConfig({
          mcpTools: mcpTools,
          model,
          enableToolUse: isEnabledToolUse(assistant)
        })

        if (this.useSystemPromptForTools) {
          systemMessageInput.text = await buildSystemPrompt(systemMessageInput.text || '', mcpTools)
        }
        systemMessageContent.push(systemMessageInput)
        systemMessage.content = systemMessageContent

        // 3. Handle user messages
        let userMessage: OpenAI.Responses.ResponseInputItem[] = []
        if (typeof messages === 'string') {
          userMessage.push({ role: 'user', content: messages })
        } else {
          const processedMessages = addImageFileToContents(messages)
          for (const message of processedMessages) {
            userMessage.push(await this.convertMessageToSdkParam(message, model))
          }
        }
        // FIXME: it would be better to handle this via previous_response_id (or store the image_generation_call id in the database)
        if (enableGenerateImage) {
          const finalAssistantMessage = userMessage.findLast(
            (m) => (m as OpenAI.Responses.EasyInputMessage).role === 'assistant'
          ) as OpenAI.Responses.EasyInputMessage
          const finalUserMessage = userMessage.pop() as OpenAI.Responses.EasyInputMessage
          if (
            finalAssistantMessage &&
            Array.isArray(finalAssistantMessage.content) &&
            finalUserMessage &&
            Array.isArray(finalUserMessage.content)
          ) {
            finalAssistantMessage.content = [...finalAssistantMessage.content, ...finalUserMessage.content]
          }
          // The previous assistant message's content (including images and files) is deliberately re-sent as a user message here
          userMessage = [{ ...finalAssistantMessage, role: 'user' } as OpenAI.Responses.EasyInputMessage]
        }

        // 4. Assemble the final request messages
        let reqMessages: OpenAI.Responses.ResponseInput
        if (!systemMessage.content) {
          reqMessages = [...userMessage]
        } else {
          reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[]
        }

        if (enableWebSearch) {
          tools.push({
            type: 'web_search_preview'
          })
        }

        if (enableGenerateImage) {
          tools.push({
            type: 'image_generation',
            partial_images: streamOutput ? 2 : undefined
          })
        }

        const toolChoices: OpenAI.Responses.ToolChoiceTypes = {
          type: 'web_search_preview'
        }

        tools = tools.concat(extraTools)
        const commonParams = {
          model: model.id,
          input:
            isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
              ? recursiveSdkMessages
              : reqMessages,
          temperature: this.getTemperature(assistant, model),
          top_p: this.getTopP(assistant, model),
          max_output_tokens: maxTokens,
          stream: streamOutput,
          tools: !isEmpty(tools) ? tools : undefined,
          tool_choice: enableWebSearch ? toolChoices : undefined,
          service_tier: this.getServiceTier(model),
          ...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning),
          ...this.getCustomParameters(assistant)
        }
        const sdkParams: OpenAIResponseSdkParams = streamOutput
          ? {
              ...commonParams,
              stream: true
            }
          : {
              ...commonParams,
              stream: false
            }
        const timeout = this.getTimeout(model)
        return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
      }
    }
  }

  getResponseChunkTransformer(): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
    const toolCalls: OpenAIResponseSdkToolCall[] = []
    const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
    return () => ({
      async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
        // Process the incoming chunk
        if ('output' in chunk) {
          for (const output of chunk.output) {
            switch (output.type) {
              case 'message':
                if (output.content[0].type === 'output_text') {
                  controller.enqueue({
                    type: ChunkType.TEXT_DELTA,
                    text: output.content[0].text
                  })
                  if (output.content[0].annotations && output.content[0].annotations.length > 0) {
                    controller.enqueue({
                      type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                      llm_web_search: {
                        source: WebSearchSource.OPENAI_RESPONSE,
                        results: output.content[0].annotations
                      }
                    })
                  }
                }
                break
              case 'reasoning':
                controller.enqueue({
                  type: ChunkType.THINKING_DELTA,
                  text: output.summary.map((s) => s.text).join('\n')
                })
                break
              case 'function_call':
                toolCalls.push(output)
                break
              case 'image_generation_call':
                controller.enqueue({
                  type: ChunkType.IMAGE_CREATED
                })
                controller.enqueue({
                  type: ChunkType.IMAGE_COMPLETE,
                  image: {
                    type: 'base64',
                    images: [`data:image/png;base64,${output.result}`]
                  }
                })
            }
          }
        } else {
          switch (chunk.type) {
            case 'response.output_item.added':
              if (chunk.item.type === 'function_call') {
                outputItems.push(chunk.item)
              }
              break
            case 'response.reasoning_summary_text.delta':
              controller.enqueue({
                type: ChunkType.THINKING_DELTA,
                text: chunk.delta
              })
              break
            case 'response.image_generation_call.generating':
              controller.enqueue({
                type: ChunkType.IMAGE_CREATED
              })
              break
            case 'response.image_generation_call.partial_image':
              controller.enqueue({
                type: ChunkType.IMAGE_DELTA,
                image: {
                  type: 'base64',
                  images: [`data:image/png;base64,${chunk.partial_image_b64}`]
                }
              })
              break
            case 'response.image_generation_call.completed':
              controller.enqueue({
                type: ChunkType.IMAGE_COMPLETE
              })
              break
            case 'response.output_text.delta': {
              controller.enqueue({
                type: ChunkType.TEXT_DELTA,
                text: chunk.delta
              })
              break
            }
            case 'response.function_call_arguments.done': {
              const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find(
                (item) => item.id === chunk.item_id
              )
              if (outputItem) {
                if (outputItem.type === 'function_call') {
                  toolCalls.push({
                    ...outputItem,
                    arguments: chunk.arguments
                  })
                }
              }
              break
            }
            case 'response.content_part.done': {
              if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) {
                controller.enqueue({
                  type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                  llm_web_search: {
                    source: WebSearchSource.OPENAI_RESPONSE,
                    results: chunk.part.annotations
                  }
                })
              }
              if (toolCalls.length > 0) {
                controller.enqueue({
                  type: ChunkType.MCP_TOOL_CREATED,
                  tool_calls: toolCalls
                })
              }
              break
            }
            case 'response.completed': {
              const completion_tokens = chunk.response.usage?.output_tokens || 0
              const total_tokens = chunk.response.usage?.total_tokens || 0
              controller.enqueue({
                type: ChunkType.LLM_RESPONSE_COMPLETE,
                response: {
                  usage: {
                    prompt_tokens: chunk.response.usage?.input_tokens || 0,
                    completion_tokens: completion_tokens,
                    total_tokens: total_tokens
                  }
                }
              })
              break
            }
            case 'error': {
              controller.enqueue({
                type: ChunkType.ERROR,
                error: {
                  message: chunk.message,
                  code: chunk.code
                }
              })
              break
            }
          }
        }
      }
    })
  }
}
129
src/renderer/src/aiCore/clients/types.ts
Normal file
@ -0,0 +1,129 @@
import Anthropic from '@anthropic-ai/sdk'
import { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import { Provider } from '@renderer/types'
import {
  AnthropicSdkRawChunk,
  OpenAISdkRawChunk,
  SdkMessageParam,
  SdkParams,
  SdkRawChunk,
  SdkRawOutput,
  SdkTool,
  SdkToolCall
} from '@renderer/types/sdk'
import OpenAI from 'openai'

import { CompletionsParams, GenericChunk } from '../middleware/schemas'

/**
 * Raw stream listener interface
 */
export interface RawStreamListener<TRawChunk = SdkRawChunk> {
  onChunk?: (chunk: TRawChunk) => void
  onStart?: () => void
  onEnd?: () => void
  onError?: (error: Error) => void
}

/**
 * OpenAI-specific stream listener
 */
export interface OpenAIStreamListener extends RawStreamListener<OpenAISdkRawChunk> {
  onChoice?: (choice: OpenAI.Chat.Completions.ChatCompletionChunk.Choice) => void
  onFinishReason?: (reason: string) => void
}

/**
 * Anthropic-specific stream listener
 */
export interface AnthropicStreamListener<TChunk extends AnthropicSdkRawChunk = AnthropicSdkRawChunk>
  extends RawStreamListener<TChunk> {
  onContentBlock?: (contentBlock: Anthropic.Messages.ContentBlock) => void
  onMessage?: (message: Anthropic.Messages.Message) => void
}

/**
 * Request transformer interface
 */
export interface RequestTransformer<
  TSdkParams extends SdkParams = SdkParams,
  TMessageParam extends SdkMessageParam = SdkMessageParam
> {
  transform(
    completionsParams: CompletionsParams,
    assistant: Assistant,
    model: Model,
    isRecursiveCall?: boolean,
    recursiveSdkMessages?: TMessageParam[]
  ): Promise<{
    payload: TSdkParams
    messages: TMessageParam[]
    metadata?: Record<string, any>
  }>
}

/**
 * Response chunk transformer interface
 */
export type ResponseChunkTransformer<TRawChunk extends SdkRawChunk = SdkRawChunk, TContext = any> = (
  context?: TContext
) => Transformer<TRawChunk, GenericChunk>

export interface ResponseChunkTransformerContext {
  isStreaming: boolean
  isEnabledToolCalling: boolean
  isEnabledWebSearch: boolean
  isEnabledReasoning: boolean
  mcpTools: MCPTool[]
  provider: Provider
}

/**
 * API client interface
 */
export interface ApiClient<
  TSdkInstance = any,
  TSdkParams extends SdkParams = SdkParams,
  TRawOutput extends SdkRawOutput = SdkRawOutput,
  TRawChunk extends SdkRawChunk = SdkRawChunk,
  TMessageParam extends SdkMessageParam = SdkMessageParam,
  TToolCall extends SdkToolCall = SdkToolCall,
  TSdkSpecificTool extends SdkTool = SdkTool
> {
  provider: Provider

  // Core method - in the middleware architecture this may be just a placeholder;
  // the actual SDK call is handled by SdkCallMiddleware
  // completions(params: CompletionsParams): Promise<CompletionsResult>

  createCompletions(payload: TSdkParams): Promise<TRawOutput>

  // SDK-related methods
  getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
  getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
  getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>

  // Raw stream listener hook
  attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener<TRawChunk>): TRawOutput

  // Tool conversion methods (kept optional, since not every provider supports tools)
  convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
  convertMcpToolResponseToSdkMessageParam?(
    mcpToolResponse: MCPToolResponse,
    resp: any,
    model: Model
  ): TMessageParam | undefined
  convertSdkToolCallToMcp?(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
  convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse

  // Build the SDK-specific message list used for the recursive call after tool execution
  buildSdkMessages(
    currentReqMessages: TMessageParam[],
    output: TRawOutput | string,
    toolResults: TMessageParam[],
    toolCalls?: TToolCall[]
  ): TMessageParam[]

  // Extract the message array from an SDK payload (for type-safe access inside middlewares)
  extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
}
130
src/renderer/src/aiCore/index.ts
Normal file
@ -0,0 +1,130 @@
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'

import { OpenAIAPIClient } from './clients'
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import { CompletionsParams, CompletionsResult } from './middleware/schemas'

export default class AiProvider {
  private apiClient: BaseApiClient

  constructor(provider: Provider) {
    // Use the new ApiClientFactory to get a BaseApiClient instance
    this.apiClient = ApiClientFactory.create(provider)
  }

  public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
    // 1. Resolve the correct client for the model
    const model = params.assistant.model
    if (!model) {
      return Promise.reject(new Error('Model is required'))
    }

    // Choose the handling path based on the client type
    let client: BaseApiClient

    if (this.apiClient instanceof AihubmixAPIClient) {
      // AihubmixAPIClient: pick the appropriate sub-client for the model
      client = this.apiClient.getClientForModel(model)
      if (client instanceof OpenAIResponseAPIClient) {
        client = client.getClient(model) as BaseApiClient
      }
    } else if (this.apiClient instanceof OpenAIResponseAPIClient) {
      // OpenAIResponseAPIClient: pick the API flavor based on the model's characteristics
      client = this.apiClient.getClient(model) as BaseApiClient
    } else {
      // All other clients are used directly
      client = this.apiClient
    }

    // 2. Build the middleware chain
    const builder = CompletionsMiddlewareBuilder.withDefaults()
    // images api
    if (isDedicatedImageGenerationModel(model)) {
      builder.clear()
      builder
        .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
        .add(MiddlewareRegistry[AbortHandlerMiddlewareName])
        .add(MiddlewareRegistry[ImageGenerationMiddlewareName])
    } else {
      // Existing logic for other models
      if (!params.enableReasoning) {
        builder.remove(ThinkingTagExtractionMiddlewareName)
        builder.remove(ThinkChunkMiddlewareName)
      }
      // Note: checking `client` here would narrow the TypeScript type
      if (!(this.apiClient instanceof OpenAIAPIClient)) {
        builder.remove(ThinkingTagExtractionMiddlewareName)
      }
      if (!(this.apiClient instanceof AnthropicAPIClient)) {
        builder.remove(RawStreamListenerMiddlewareName)
      }
      if (!params.enableWebSearch) {
        builder.remove(WebSearchMiddlewareName)
      }
      if (!params.mcpTools?.length) {
        builder.remove(ToolUseExtractionMiddlewareName)
        builder.remove(McpToolChunkMiddlewareName)
      }
      if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
        builder.remove(ToolUseExtractionMiddlewareName)
      }
      if (params.callType !== 'chat') {
        builder.remove(AbortHandlerMiddlewareName)
      }
    }

    const middlewares = builder.build()

    // 3. Create the wrapped SDK method with middlewares
    const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)

    // 4. Execute the wrapped method with the original params
    return wrappedCompletionMethod(params, options)
  }

  public async models(): Promise<SdkModel[]> {
    return this.apiClient.listModels()
  }

  public async getEmbeddingDimensions(model: Model): Promise<number> {
    try {
      // Use the SDK instance to test embedding capabilities
      const dimensions = await this.apiClient.getEmbeddingDimensions(model)
      return dimensions
    } catch (error) {
      console.error('Error getting embedding dimensions:', error)
      return 0
    }
  }

  public async generateImage(params: GenerateImageParams): Promise<string[]> {
    return this.apiClient.generateImage(params)
  }

  public getBaseURL(): string {
    return this.apiClient.getBaseURL()
  }

  public getApiKey(): string {
    return this.apiClient.getApiKey()
  }
}
182
src/renderer/src/aiCore/middleware/BUILDER_USAGE.md
Normal file
@ -0,0 +1,182 @@
# MiddlewareBuilder Usage Guide

`MiddlewareBuilder` is a tool for dynamically building and managing middleware chains, providing flexible middleware organization and configuration.

## Key Features

### 1. Unified middleware naming

Every middleware is identified by an exported `MIDDLEWARE_NAME` constant:

```typescript
// Example middleware file
export const MIDDLEWARE_NAME = 'SdkCallMiddleware'
export const SdkCallMiddleware: CompletionsMiddleware = ...
```

### 2. The NamedMiddleware interface

Middlewares use the unified `NamedMiddleware` interface format:

```typescript
interface NamedMiddleware<TMiddleware = any> {
  name: string
  middleware: TMiddleware
}
```

### 3. The middleware registry

All available middlewares are centrally managed through `MiddlewareRegistry`:

```typescript
import { MiddlewareRegistry } from './register'

// Look up a middleware by name
const sdkCallMiddleware = MiddlewareRegistry['SdkCallMiddleware']
```

## Basic Usage

### 1. Using the default middleware chain

```typescript
import { CompletionsMiddlewareBuilder } from './builder'

const builder = CompletionsMiddlewareBuilder.withDefaults()
const middlewares = builder.build()
```

### 2. Building a custom middleware chain

```typescript
import { createCompletionsBuilder, MiddlewareRegistry } from './builder'

const builder = createCompletionsBuilder([
  MiddlewareRegistry['AbortHandlerMiddleware'],
  MiddlewareRegistry['TextChunkMiddleware']
])

const middlewares = builder.build()
```

### 3. Adjusting the chain dynamically

```typescript
const builder = CompletionsMiddlewareBuilder.withDefaults()

// Conditionally add, remove, or replace middlewares
if (needsLogging) {
  builder.prepend(MiddlewareRegistry['GenericLoggingMiddleware'])
}

if (disableTools) {
  builder.remove('McpToolChunkMiddleware')
}

if (customThinking) {
  builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware)
}

const middlewares = builder.build()
```

### 4. Chained calls

```typescript
const middlewares = CompletionsMiddlewareBuilder.withDefaults()
  .add(MiddlewareRegistry['CustomMiddleware'])
  .insertBefore('SdkCallMiddleware', MiddlewareRegistry['SecurityCheckMiddleware'])
  .remove('WebSearchMiddleware')
  .build()
```

## API Reference

### CompletionsMiddlewareBuilder

**Static methods:**

- `static withDefaults()`: create a builder preloaded with the default middleware chain

**Instance methods:**

- `add(middleware: NamedMiddleware)`: append a middleware to the end of the chain
- `prepend(middleware: NamedMiddleware)`: add a middleware to the head of the chain
- `insertAfter(targetName: string, middleware: NamedMiddleware)`: insert after the named middleware
- `insertBefore(targetName: string, middleware: NamedMiddleware)`: insert before the named middleware
- `replace(targetName: string, middleware: NamedMiddleware)`: replace the named middleware
- `remove(targetName: string)`: remove the named middleware
- `has(name: string)`: check whether the chain contains the named middleware
- `build()`: produce the final middleware array
- `getChain()`: get the current chain (with name information)
- `clear()`: empty the chain
- `execute(context, params, middlewareExecutor)`: execute the built chain directly

### Factory functions

- `createCompletionsBuilder(baseChain?)`: create a Completions middleware builder
- `createMethodBuilder(baseChain?)`: create a generic method middleware builder
- `addMiddlewareName(middleware, name)`: helper that attaches a name property to a middleware

### Middleware registry

- `MiddlewareRegistry`: central access point for all registered middlewares
- `getMiddleware(name)`: look up a middleware by name
- `getRegisteredMiddlewareNames()`: list all registered middleware names
- `DefaultCompletionsNamedMiddlewares`: the default Completions middleware chain (in NamedMiddleware format)

## Type Safety

The builders come with full TypeScript type support (see the sketch below):

- `CompletionsMiddlewareBuilder` is dedicated to the `CompletionsMiddleware` type
- `MethodMiddlewareBuilder` handles the generic `MethodMiddleware` type
- Every middleware operation is based on the `NamedMiddleware<TMiddleware>` interface

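As a rough illustration, wrapping a custom middleware in the `NamedMiddleware` shape keeps the builder fully typed. This is a sketch only: `myCompletionsMiddleware` is a hypothetical middleware value, not one shipped in the registry.

```typescript
import { CompletionsMiddlewareBuilder, NamedMiddleware } from './builder'
import { CompletionsMiddleware } from './types'

// Hypothetical middleware value; any CompletionsMiddleware works here.
declare const myCompletionsMiddleware: CompletionsMiddleware

const named: NamedMiddleware<CompletionsMiddleware> = {
  name: 'MyCompletionsMiddleware',
  middleware: myCompletionsMiddleware
}

// The builder only accepts NamedMiddleware<CompletionsMiddleware>,
// so passing a MethodMiddleware here is a compile-time error.
const middlewares = CompletionsMiddlewareBuilder.withDefaults().add(named).build()
```
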
## Default Middleware Chain

The default Completions middleware execution order:

1. `FinalChunkConsumerMiddleware` - final chunk consumer
2. `TransformCoreToSdkParamsMiddleware` - parameter transformation
3. `AbortHandlerMiddleware` - abort handling
4. `McpToolChunkMiddleware` - tool handling
5. `WebSearchMiddleware` - web search handling
6. `TextChunkMiddleware` - text handling
7. `ThinkingTagExtractionMiddleware` - thinking-tag extraction
8. `ThinkChunkMiddleware` - thinking handling
9. `ResponseTransformMiddleware` - response transformation
10. `StreamAdapterMiddleware` - stream adapter
11. `SdkCallMiddleware` - SDK call

## Usage in AiProvider

```typescript
export default class AiProvider {
  public async completions(params: CompletionsParams): Promise<CompletionsResult> {
    // 1. Build the middleware chain
    const builder = CompletionsMiddlewareBuilder.withDefaults()

    // 2. Adjust it dynamically based on the params
    if (params.enableCustomFeature) {
      builder.insertAfter('StreamAdapterMiddleware', customFeatureMiddleware)
    }

    // 3. Apply the middlewares
    const middlewares = builder.build()
    const wrappedMethod = applyCompletionsMiddlewares(this.apiClient, this.apiClient.createCompletions, middlewares)

    return wrappedMethod(params)
  }
}
```

## Notes

1. **Type compatibility**: `MethodMiddleware` and `CompletionsMiddleware` are incompatible; use the matching builder for each
2. **Middleware names**: every middleware must export a `MIDDLEWARE_NAME` constant for identification
3. **Registry management**: new middlewares must be registered in `register.ts`
4. **Default chain**: the default chain is provided by `DefaultCompletionsNamedMiddlewares` and supports lazy loading to avoid circular dependencies

This design keeps middleware chain construction flexible and type-safe while preserving a concise API surface.
175
src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md
Normal file
@ -0,0 +1,175 @@
# Cherry Studio Middleware Specification

This document defines the design, implementation, and usage conventions for middlewares in the Cherry Studio `aiCore` module. The goal is a flexible, maintainable, and easily extensible middleware system.

## 1. Core Concepts

### 1.1. Middleware

A middleware is a function or object that runs at a specific stage of an AI request's processing pipeline. It can access and modify the request context (`AiProviderMiddlewareContext`) and request parameters (`Params`), and it controls whether the request is handed to the next middleware or the pipeline terminates early.

Each middleware should focus on a single cross-cutting concern, such as logging, error handling, stream adaptation, or feature parsing.

### 1.2. AiProviderMiddlewareContext (the context object)

This object is passed along through the entire middleware chain and carries the following core data (a TypeScript sketch follows the list):

- `_apiClientInstance: ApiClient<any,any,any>`: the currently selected, already instantiated AI provider client.
- `_coreRequest: CoreRequestType`: the normalized internal core request object.
- `resolvePromise: (value: AggregatedResultType) => void`: resolves the Promise returned by `AiCoreService` when the whole operation completes successfully.
- `rejectPromise: (reason?: any) => void`: rejects the Promise returned by `AiCoreService` when an error occurs.
- `onChunk?: (chunk: Chunk) => void`: the streaming chunk callback supplied by the application layer.
- `abortController?: AbortController`: the controller used to abort the request.
- Any other request-scoped dynamic data that middlewares may read or write.

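A minimal sketch of this context shape in TypeScript, assuming `CoreRequestType`, `AggregatedResultType`, `Chunk`, and `ApiClient` are defined elsewhere in `aiCore` (the real definition lives in the middleware type files):

```typescript
// Sketch only: field names follow the list above.
interface AiProviderMiddlewareContext {
  _apiClientInstance: ApiClient<any, any, any> // the instantiated provider client
  _coreRequest: CoreRequestType // normalized internal request
  resolvePromise: (value: AggregatedResultType) => void // resolves the service promise on success
  rejectPromise: (reason?: any) => void // rejects the service promise on error
  onChunk?: (chunk: Chunk) => void // streaming callback supplied by the caller
  abortController?: AbortController // used to cancel the request
  [key: string]: any // middlewares may attach request-scoped data
}
```
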
### 1.3. MiddlewareName

To make dynamic operations (insert, replace, remove) convenient, every significant middleware that other logic may reference should have a unique, recognizable name. A TypeScript `enum` is recommended:

```typescript
// example
export enum MiddlewareName {
  LOGGING_START = 'LoggingStartMiddleware',
  LOGGING_END = 'LoggingEndMiddleware',
  ERROR_HANDLING = 'ErrorHandlingMiddleware',
  ABORT_HANDLER = 'AbortHandlerMiddleware',
  // Core Flow
  TRANSFORM_CORE_TO_SDK_PARAMS = 'TransformCoreToSdkParamsMiddleware',
  REQUEST_EXECUTION = 'RequestExecutionMiddleware',
  STREAM_ADAPTER = 'StreamAdapterMiddleware',
  RAW_SDK_CHUNK_TO_APP_CHUNK = 'RawSdkChunkToAppChunkMiddleware',
  // Features
  THINKING_TAG_EXTRACTION = 'ThinkingTagExtractionMiddleware',
  TOOL_USE_TAG_EXTRACTION = 'ToolUseTagExtractionMiddleware',
  MCP_TOOL_HANDLER = 'McpToolHandlerMiddleware',
  // Finalization
  FINAL_CHUNK_CONSUMER = 'FinalChunkConsumerAndNotifierMiddleware'
  // Add more as needed
}
```

Middleware instances need some way to expose their `MiddlewareName`, for example via a `name` property.

### 1.4. Middleware execution structure

We adopt a flexible execution structure. A middleware is typically a function that receives the `Context`, the `Params`, and a `next` function (used to invoke the next middleware in the chain).

```typescript
// Simplified middleware function signature
type MiddlewareFunction = (
  context: AiProviderMiddlewareContext,
  params: any, // e.g., CompletionsParams
  next: () => Promise<void> // next usually returns a Promise to support async work
) => Promise<void> // the middleware itself may also return a Promise

// Or the more classic Koa/Express style (three-stage)
// type MiddlewareFactory = (api?: MiddlewareApi) =>
//   (nextMiddleware: (ctx: AiProviderMiddlewareContext, params: any) => Promise<void>) =>
//     (context: AiProviderMiddlewareContext, params: any) => Promise<void>;
// The current design favors the simplified MiddlewareFunction above, with the MiddlewareExecutor orchestrating next.
```

The `MiddlewareExecutor` (or `applyMiddlewares`) is responsible for managing the `next` calls.

## 2. MiddlewareBuilder (the generic middleware builder)

To build and manage middleware chains dynamically, we introduce a generic `MiddlewareBuilder` class.

### 2.1. Design philosophy

`MiddlewareBuilder` offers a fluent API for building middleware chains declaratively. It lets you start from a base chain and then add, insert, replace, or remove middlewares under specific conditions.

### 2.2. API overview

```typescript
class MiddlewareBuilder {
  constructor(baseChain?: Middleware[])

  add(middleware: Middleware): this
  prepend(middleware: Middleware): this
  insertAfter(targetName: MiddlewareName, middlewareToInsert: Middleware): this
  insertBefore(targetName: MiddlewareName, middlewareToInsert: Middleware): this
  replace(targetName: MiddlewareName, newMiddleware: Middleware): this
  remove(targetName: MiddlewareName): this

  build(): Middleware[] // returns the built middleware array

  // Optional: execute the chain directly
  execute(
    context: AiProviderMiddlewareContext,
    params: any,
    middlewareExecutor: (chain: Middleware[], context: AiProviderMiddlewareContext, params: any) => void
  ): void
}
```

### 2.3. Usage example

```typescript
// 1. Define some middleware instances (assumed to carry a .name property)
const loggingStart = { name: MiddlewareName.LOGGING_START, fn: loggingStartFn }
const requestExec = { name: MiddlewareName.REQUEST_EXECUTION, fn: requestExecFn }
const streamAdapter = { name: MiddlewareName.STREAM_ADAPTER, fn: streamAdapterFn }
const customFeature = { name: MiddlewareName.CUSTOM_FEATURE, fn: customFeatureFn } // a hypothetical custom middleware

// 2. Define a base chain (optional)
const BASE_CHAIN: Middleware[] = [loggingStart, requestExec, streamAdapter]

// 3. Use the MiddlewareBuilder
const builder = new MiddlewareBuilder(BASE_CHAIN)

if (params.needsCustomFeature) {
  builder.insertAfter(MiddlewareName.STREAM_ADAPTER, customFeature)
}

if (params.isHighSecurityContext) {
  builder.insertBefore(MiddlewareName.REQUEST_EXECUTION, highSecurityCheckMiddleware)
}

if (params.overrideLogging) {
  builder.replace(MiddlewareName.LOGGING_START, newSpecialLoggingMiddleware)
}

// 4. Get the final chain
const finalChain = builder.build()

// 5. Execute (via an external executor)
// middlewareExecutor(finalChain, context, params);
// or builder.execute(context, params, middlewareExecutor);
```

## 3. MiddlewareExecutor / applyMiddlewares (the middleware executor)

This component receives the middleware chain built by the `MiddlewareBuilder` and actually executes it.

### 3.1. Responsibilities

The executor's responsibilities (a minimal sketch follows the list):

- Receive the `Middleware[]`, the `AiProviderMiddlewareContext`, and the `Params`.
- Iterate over the middlewares in order.
- Supply each middleware with a correct `next` function that, when called, runs the next middleware in the chain.
- Handle the Promises produced during execution (when middlewares are async).
- Provide basic error capture (detailed error handling belongs to an `ErrorHandlingMiddleware` inside the chain).

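Under the simplified `MiddlewareFunction` signature from section 1.4, an executor can be reduced to a dispatcher that threads `next` through the chain. This is a sketch, not the actual implementation in `composer.ts`:

```typescript
// Sketch: run a chain of simplified MiddlewareFunction values in order.
async function applyMiddlewares(
  chain: MiddlewareFunction[],
  context: AiProviderMiddlewareContext,
  params: any
): Promise<void> {
  // dispatch(i) runs the i-th middleware and hands it a `next`
  // that advances to i + 1; past the end of the chain, `next` is a no-op.
  const dispatch = async (i: number): Promise<void> => {
    if (i >= chain.length) return
    await chain[i](context, params, () => dispatch(i + 1))
  }
  try {
    await dispatch(0)
  } catch (error) {
    // Basic error capture; detailed handling belongs to an
    // ErrorHandlingMiddleware inside the chain.
    context.rejectPromise(error)
  }
}
```
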
## 4. Usage in AiCoreService

Each core business method in `AiCoreService` (such as `executeCompletions`) is responsible for:

1. Preparing the base data: instantiating the `ApiClient` and converting `Params` into a `CoreRequest`.
2. Instantiating a `MiddlewareBuilder`, possibly with a base middleware chain specific to that business method.
3. Adjusting the middleware chain dynamically via the `MiddlewareBuilder` methods, based on conditions in `Params` and the `CoreRequest`.
4. Calling `MiddlewareBuilder.build()` to obtain the final middleware chain.
5. Creating the complete `AiProviderMiddlewareContext` (including `resolvePromise`, `rejectPromise`, etc.).
6. Invoking the `MiddlewareExecutor` (or `applyMiddlewares`) to run the built chain.

## 5. Composed Features

For composed features (for example, "Completions then Translate"):

- Building one single, monolithic `MiddlewareBuilder` for the whole composed flow is not recommended.
- Instead, add a new method on `AiCoreService` that `await`s the underlying atomic `AiCoreService` methods in sequence (e.g. first `await this.executeCompletions(...)`, then feed its result into `await this.translateText(...)`); see the sketch after this list.
- Each atomic method internally uses its own `MiddlewareBuilder` instance to build and execute the middleware chain for its own stage.
- This maximizes reuse and keeps each part's responsibilities clear.

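A sketch of such a composed method, assuming the method names from the list above; the exact parameter and result shapes (in particular `completion.getText()`) are assumptions for illustration:

```typescript
// Sketch only: composes two atomic service calls sequentially.
class AiCoreServiceWithComposition extends AiCoreService {
  async completionsThenTranslate(params: CompletionsParams, targetLanguage: string): Promise<string> {
    // First atomic call: internally builds and runs its own middleware chain.
    const completion = await this.executeCompletions(params)
    // Second atomic call: again builds and runs its own chain.
    // getText() is an assumed accessor on the completion result.
    return this.translateText(completion.getText(), targetLanguage)
  }
}
```
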
## 6. Middleware Naming and Discovery

Giving each middleware a unique `MiddlewareName` is essential for the builder's `insertAfter`, `insertBefore`, `replace`, and `remove` operations. Make sure middleware instances expose their name in some way (for example, a `name` property).
241
src/renderer/src/aiCore/middleware/builder.ts
Normal file
@ -0,0 +1,241 @@
import { DefaultCompletionsNamedMiddlewares } from './register'
import { BaseContext, CompletionsMiddleware, MethodMiddleware } from './types'

/**
 * Middleware with a name attached
 */
export interface NamedMiddleware<TMiddleware = any> {
  name: string
  middleware: TMiddleware
}

/**
 * Middleware executor function type
 */
export type MiddlewareExecutor<TContext extends BaseContext = BaseContext> = (
  chain: any[],
  context: TContext,
  params: any
) => Promise<any>

/**
 * Generic middleware builder class
 * Provides a fluent API for dynamically building and managing middleware chains
 *
 * Note: all middlewares are managed through the MiddlewareRegistry, in NamedMiddleware format
 */
export class MiddlewareBuilder<TMiddleware = any> {
  private middlewares: NamedMiddleware<TMiddleware>[]

  /**
   * Constructor
   * @param baseChain - optional base middleware chain (NamedMiddleware format)
   */
  constructor(baseChain?: NamedMiddleware<TMiddleware>[]) {
    this.middlewares = baseChain ? [...baseChain] : []
  }

  /**
   * Append a middleware to the end of the chain
   * @param middleware - the named middleware to add
   * @returns this, for chaining
   */
  add(middleware: NamedMiddleware<TMiddleware>): this {
    this.middlewares.push(middleware)
    return this
  }

  /**
   * Add a middleware to the head of the chain
   * @param middleware - the named middleware to add
   * @returns this, for chaining
   */
  prepend(middleware: NamedMiddleware<TMiddleware>): this {
    this.middlewares.unshift(middleware)
    return this
  }

  /**
   * Insert a middleware after the named target
   * @param targetName - name of the target middleware
   * @param middlewareToInsert - the named middleware to insert
   * @returns this, for chaining
   */
  insertAfter(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
    const index = this.findMiddlewareIndex(targetName)
    if (index !== -1) {
      this.middlewares.splice(index + 1, 0, middlewareToInsert)
    } else {
      console.warn(`MiddlewareBuilder: middleware named '${targetName}' not found, cannot insert`)
    }
    return this
  }

  /**
   * Insert a middleware before the named target
   * @param targetName - name of the target middleware
   * @param middlewareToInsert - the named middleware to insert
   * @returns this, for chaining
   */
  insertBefore(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
    const index = this.findMiddlewareIndex(targetName)
    if (index !== -1) {
      this.middlewares.splice(index, 0, middlewareToInsert)
    } else {
      console.warn(`MiddlewareBuilder: middleware named '${targetName}' not found, cannot insert`)
    }
    return this
  }

  /**
   * Replace the named middleware
   * @param targetName - name of the middleware to replace
   * @param newMiddleware - the new named middleware
   * @returns this, for chaining
   */
  replace(targetName: string, newMiddleware: NamedMiddleware<TMiddleware>): this {
    const index = this.findMiddlewareIndex(targetName)
    if (index !== -1) {
      this.middlewares[index] = newMiddleware
    } else {
      console.warn(`MiddlewareBuilder: middleware named '${targetName}' not found, cannot replace`)
    }
    return this
  }

  /**
   * Remove the named middleware
   * @param targetName - name of the middleware to remove
   * @returns this, for chaining
   */
  remove(targetName: string): this {
    const index = this.findMiddlewareIndex(targetName)
    if (index !== -1) {
      this.middlewares.splice(index, 1)
    }
    return this
  }

  /**
   * Build the final middleware array
   * @returns the built middleware array
   */
  build(): TMiddleware[] {
    return this.middlewares.map((item) => item.middleware)
  }

  /**
   * Get a copy of the current chain (with name information)
   * @returns a copy of the current middleware chain
   */
  getChain(): NamedMiddleware<TMiddleware>[] {
    return [...this.middlewares]
  }

  /**
   * Check whether the chain contains a middleware with the given name
   * @param name - middleware name
   * @returns whether the chain contains that middleware
   */
  has(name: string): boolean {
    return this.findMiddlewareIndex(name) !== -1
  }

  /**
   * Length of the middleware chain
   * @returns the number of middlewares
   */
  get length(): number {
    return this.middlewares.length
  }

  /**
   * Empty the middleware chain
   * @returns this, for chaining
   */
  clear(): this {
    this.middlewares = []
    return this
  }

  /**
   * Execute the built middleware chain directly
   * @param context - middleware context
   * @param params - parameters
   * @param middlewareExecutor - the middleware executor
   * @returns the execution result
   */
  execute<TContext extends BaseContext>(
    context: TContext,
    params: any,
    middlewareExecutor: MiddlewareExecutor<TContext>
  ): Promise<any> {
    const chain = this.build()
    return middlewareExecutor(chain, context, params)
  }

  /**
   * Find the index of a middleware in the chain
   * @param name - middleware name
   * @returns the index, or -1 if not found
   */
  private findMiddlewareIndex(name: string): number {
    return this.middlewares.findIndex((item) => item.name === name)
  }
}

/**
 * Completions middleware builder
 */
export class CompletionsMiddlewareBuilder extends MiddlewareBuilder<CompletionsMiddleware> {
  constructor(baseChain?: NamedMiddleware<CompletionsMiddleware>[]) {
    super(baseChain)
  }

  /**
   * Use the default Completions middleware chain
   * @returns a CompletionsMiddlewareBuilder instance
   */
  static withDefaults(): CompletionsMiddlewareBuilder {
    return new CompletionsMiddlewareBuilder(DefaultCompletionsNamedMiddlewares)
  }
}

/**
 * Generic method middleware builder
 */
export class MethodMiddlewareBuilder extends MiddlewareBuilder<MethodMiddleware> {
  constructor(baseChain?: NamedMiddleware<MethodMiddleware>[]) {
    super(baseChain)
  }
}

// Convenience factory functions

/**
 * Create a Completions middleware builder
 * @param baseChain - optional base chain
 * @returns a Completions middleware builder instance
 */
export function createCompletionsBuilder(
  baseChain?: NamedMiddleware<CompletionsMiddleware>[]
): CompletionsMiddlewareBuilder {
  return new CompletionsMiddlewareBuilder(baseChain)
}

/**
 * Create a generic method middleware builder
 * @param baseChain - optional base chain
 * @returns a generic method middleware builder instance
 */
export function createMethodBuilder(baseChain?: NamedMiddleware<MethodMiddleware>[]): MethodMiddlewareBuilder {
  return new MethodMiddlewareBuilder(baseChain)
}

/**
 * Helper that attaches a name property to a middleware
|
||||
* 可以用于给现有的中间件添加名称属性
|
||||
*/
|
||||
export function addMiddlewareName<T extends object>(middleware: T, name: string): T & { MIDDLEWARE_NAME: string } {
|
||||
return Object.assign(middleware, { MIDDLEWARE_NAME: name })
|
||||
}
|
||||
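// Usage sketch (illustrative; assumes the default chain from './register' contains a
// middleware named 'StreamAdapterMiddleware'):
//
//   const builder = CompletionsMiddlewareBuilder.withDefaults()
//   builder
//     .insertAfter('StreamAdapterMiddleware', { name: 'MyTimingMiddleware', middleware: myTimingMiddleware })
//     .remove('WebSearchMiddleware')
//   const middlewares = builder.build()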
@ -0,0 +1,106 @@
import { Chunk, ChunkType, ErrorChunk } from '@renderer/types/chunk'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'

import { CompletionsParams, CompletionsResult } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'AbortHandlerMiddleware'

export const AbortHandlerMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    const isRecursiveCall = ctx._internal?.toolProcessingState?.isRecursiveCall || false

    // In recursive calls, skip creating an AbortController and reuse the existing one
    if (isRecursiveCall) {
      const result = await next(ctx, params)
      return result
    }

    // Get the current message's ID for abort bookkeeping;
    // prefer the processed messages, falling back to the original ones
    let messageId: string | undefined

    if (typeof params.messages === 'string') {
      messageId = `message-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`
    } else {
      const processedMessages = params.messages
      const lastUserMessage = processedMessages.findLast((m) => m.role === 'user')
      messageId = lastUserMessage?.id
    }

    if (!messageId) {
      console.warn(`[${MIDDLEWARE_NAME}] No messageId found, abort functionality will not be available.`)
      return next(ctx, params)
    }

    const abortController = new AbortController()
    const abortFn = (): void => abortController.abort()

    addAbortController(messageId, abortFn)

    let abortSignal: AbortSignal | null = abortController.signal

    const cleanup = (): void => {
      removeAbortController(messageId as string, abortFn)
      if (ctx._internal?.flowControl) {
        ctx._internal.flowControl.abortController = undefined
        ctx._internal.flowControl.abortSignal = undefined
        ctx._internal.flowControl.cleanup = undefined
      }
      abortSignal = null
    }

    // Attach the controller to the flowControl state in _internal
    if (!ctx._internal.flowControl) {
      ctx._internal.flowControl = {}
    }
    ctx._internal.flowControl.abortController = abortController
    ctx._internal.flowControl.abortSignal = abortSignal
    ctx._internal.flowControl.cleanup = cleanup

    const result = await next(ctx, params)

    const error = new DOMException('Request was aborted', 'AbortError')

    const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
      new TransformStream<Chunk, Chunk | ErrorChunk>({
        transform(chunk, controller) {
          // Check the abort state
          if (abortSignal?.aborted) {
            // Convert to an ErrorChunk
            const errorChunk: ErrorChunk = {
              type: ChunkType.ERROR,
              error
            }

            controller.enqueue(errorChunk)
            cleanup()
            return
          }

          // Pass the chunk through as normal
          controller.enqueue(chunk)
        },

        flush(controller) {
          // Check the abort state again when the stream ends
          if (abortSignal?.aborted) {
            const errorChunk: ErrorChunk = {
              type: ChunkType.ERROR,
              error
            }
            controller.enqueue(errorChunk)
          }
          // Clean up the AbortController once the stream is fully processed
          cleanup()
        }
      })
    )

    return {
      ...result,
      stream: streamWithAbortHandler
    }
  }
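// Wiring sketch (illustrative): the caller side registers an abort handle per message, so a
// hypothetical UI-level helper can cancel the in-flight request:
//
//   const resultPromise = completions(params)   // chain includes AbortHandlerMiddleware
//   cancelMessage(lastUserMessage.id)           // hypothetical helper that invokes the
//                                               // abortFn registered via addAbortController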
@ -0,0 +1,60 @@
import { Chunk } from '@renderer/types/chunk'
import { isAbortError } from '@renderer/utils/error'

import { CompletionsResult } from '../schemas'
import { CompletionsContext } from '../types'
import { createErrorChunk } from '../utils'

export const MIDDLEWARE_NAME = 'ErrorHandlerMiddleware'

/**
 * Creates an error-handling middleware.
 *
 * This is a higher-order function: it takes configuration and returns a standard middleware.
 * Its main job is to catch any error thrown by downstream middlewares or the API call.
 *
 * @returns A configured CompletionsMiddleware.
 */
export const ErrorHandlerMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params): Promise<CompletionsResult> => {
    const { shouldThrow } = params

    try {
      // Try to run the next middleware
      return await next(ctx, params)
    } catch (error: any) {
      let errorStream: ReadableStream<Chunk> | undefined
      // Some SDKs throw the abort error directly
      if (!isAbortError(error)) {
        // 1. Parse the error into the standard format with the shared utility
        const errorChunk = createErrorChunk(error)
        // 2. Invoke the externally supplied onError callback
        if (params.onError) {
          params.onError(error)
        }

        // 3. Depending on configuration, either rethrow the error or pass it downstream as part of the stream
        if (shouldThrow) {
          throw error
        }

        // If not rethrowing, create a stream containing only the error chunk and pass it on
        errorStream = new ReadableStream<Chunk>({
          start(controller) {
            controller.enqueue(errorChunk)
            controller.close()
          }
        })
      }

      return {
        rawOutput: undefined,
        stream: errorStream, // Pass the error-bearing stream downstream
        controller: undefined,
        getText: () => '' // No text result in the error case
      }
    }
  }
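// Behavior sketch (illustrative): with params.shouldThrow set, errors propagate to the caller;
// otherwise the caller receives a one-chunk error stream:
//
//   const { stream } = await completions({ ...params, shouldThrow: false })
//   // on failure, stream yields a single ChunkType.ERROR chunk and then closes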
@ -0,0 +1,183 @@
import Logger from '@renderer/config/logger'
import { Usage } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'

import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'FinalChunkConsumerAndNotifierMiddleware'

/**
 * Final chunk consumer and notifier middleware.
 *
 * Responsibilities:
 * 1. Consume all chunks from the GenericChunk stream and forward them to the onChunk callback
 * 2. Accumulate usage/metrics data (extracted from raw SDK chunks or GenericChunks)
 * 3. Emit a BLOCK_COMPLETE carrying the accumulated data when LLM_RESPONSE_COMPLETE is detected
 * 4. Handle accumulation across the multiple rounds of MCP tool-call requests
 */
const FinalChunkConsumerMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    const isRecursiveCall =
      params._internal?.toolProcessingState?.isRecursiveCall ||
      ctx._internal?.toolProcessingState?.isRecursiveCall ||
      false

    // Initialize the accumulators (only on the top-level call)
    if (!isRecursiveCall) {
      if (!ctx._internal.customState) {
        ctx._internal.customState = {}
      }
      ctx._internal.observer = {
        usage: {
          prompt_tokens: 0,
          completion_tokens: 0,
          total_tokens: 0
        },
        metrics: {
          completion_tokens: 0,
          time_completion_millsec: 0,
          time_first_token_millsec: 0,
          time_thinking_millsec: 0
        }
      }
      // Initialize the text accumulator
      ctx._internal.customState.accumulatedText = ''
      ctx._internal.customState.startTimestamp = Date.now()
    }

    // Call the downstream middleware
    const result = await next(ctx, params)

    // Post-processing: consume the GenericChunk streaming response
    if (result.stream) {
      const resultFromUpstream = result.stream

      if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
        const reader = resultFromUpstream.getReader()

        try {
          while (true) {
            const { done, value: chunk } = await reader.read()
            if (done) {
              Logger.debug(`[${MIDDLEWARE_NAME}] Input stream finished.`)
              break
            }

            if (chunk) {
              const genericChunk = chunk as GenericChunk
              // Extract and accumulate usage/metrics data
              extractAndAccumulateUsageMetrics(ctx, genericChunk)

              const shouldSkipChunk =
                isRecursiveCall &&
                (genericChunk.type === ChunkType.BLOCK_COMPLETE ||
                  genericChunk.type === ChunkType.LLM_RESPONSE_COMPLETE)

              if (!shouldSkipChunk) params.onChunk?.(genericChunk)
            } else {
              Logger.warn(`[${MIDDLEWARE_NAME}] Received undefined chunk before stream was done.`)
            }
          }
        } catch (error) {
          Logger.error(`[${MIDDLEWARE_NAME}] Error consuming stream:`, error)
          throw error
        } finally {
          if (params.onChunk && !isRecursiveCall) {
            params.onChunk({
              type: ChunkType.BLOCK_COMPLETE,
              response: {
                usage: ctx._internal.observer?.usage ? { ...ctx._internal.observer.usage } : undefined,
                metrics: ctx._internal.observer?.metrics ? { ...ctx._internal.observer.metrics } : undefined
              }
            } as Chunk)
            if (ctx._internal.toolProcessingState) {
              ctx._internal.toolProcessingState = {}
            }
          }
        }

        // Provide a getText method for the streaming output
        const modifiedResult = {
          ...result,
          stream: new ReadableStream<GenericChunk>({
            start(controller) {
              controller.close()
            }
          }),
          getText: () => {
            return ctx._internal.customState?.accumulatedText || ''
          }
        }

        return modifiedResult
      } else {
        Logger.debug(`[${MIDDLEWARE_NAME}] No GenericChunk stream to process.`)
      }
    }

    return result
  }

/**
 * Extracts usage/metrics data from GenericChunks (or raw SDK chunks) and accumulates it.
 */
function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: GenericChunk): void {
  if (!ctx._internal.observer?.usage || !ctx._internal.observer?.metrics) {
    return
  }

  try {
    if (ctx._internal.customState && !ctx._internal.customState?.firstTokenTimestamp) {
      ctx._internal.customState.firstTokenTimestamp = Date.now()
      Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
    }
    if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
      Logger.debug(`[${MIDDLEWARE_NAME}] LLM_RESPONSE_COMPLETE chunk received:`, ctx._internal)
      // Extract usage data from the LLM_RESPONSE_COMPLETE chunk
      if (chunk.response?.usage) {
        accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)
      }

      if (ctx._internal.customState && ctx._internal.customState?.firstTokenTimestamp) {
        ctx._internal.observer.metrics.time_first_token_millsec =
          ctx._internal.customState.firstTokenTimestamp - ctx._internal.customState.startTimestamp
        ctx._internal.observer.metrics.time_completion_millsec +=
          Date.now() - ctx._internal.customState.firstTokenTimestamp
      }
    }

    // Metrics data can also be extracted from other chunk types
    if (chunk.type === ChunkType.THINKING_COMPLETE && chunk.thinking_millsec && ctx._internal.observer?.metrics) {
      ctx._internal.observer.metrics.time_thinking_millsec = Math.max(
        ctx._internal.observer.metrics.time_thinking_millsec || 0,
        chunk.thinking_millsec
      )
    }
  } catch (error) {
    console.error(`[${MIDDLEWARE_NAME}] Error extracting usage/metrics from chunk:`, error)
  }
}

/**
 * Accumulates usage data.
 */
function accumulateUsage(accumulated: Usage, newUsage: Usage): void {
  if (newUsage.prompt_tokens !== undefined) {
    accumulated.prompt_tokens += newUsage.prompt_tokens
  }
  if (newUsage.completion_tokens !== undefined) {
    accumulated.completion_tokens += newUsage.completion_tokens
  }
  if (newUsage.total_tokens !== undefined) {
    accumulated.total_tokens += newUsage.total_tokens
  }
  if (newUsage.thoughts_tokens !== undefined) {
    accumulated.thoughts_tokens = (accumulated.thoughts_tokens || 0) + newUsage.thoughts_tokens
  }
}

export default FinalChunkConsumerMiddleware
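// Flow sketch (illustrative): this middleware fully drains the GenericChunk stream itself, so
// callers observe chunks via params.onChunk; the returned stream is an empty placeholder and
// the final text is read back through result.getText().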
@ -0,0 +1,64 @@
import { BaseContext, MethodMiddleware, MiddlewareAPI } from '../types'

export const MIDDLEWARE_NAME = 'GenericLoggingMiddlewares'

/**
 * Helper function to safely stringify arguments for logging, handling circular references and large objects.
 * @param args - The arguments array to stringify.
 * @returns A string representation of the arguments.
 */
const stringifyArgsForLogging = (args: any[]): string => {
  try {
    return args
      .map((arg) => {
        if (typeof arg === 'function') return '[Function]'
        if (typeof arg === 'object' && arg !== null && arg.constructor === Object && Object.keys(arg).length > 20) {
          return '[Object with >20 keys]'
        }
        // Truncate long strings to avoid flooding logs
        const stringifiedArg = JSON.stringify(arg, null, 2)
        return stringifiedArg && stringifiedArg.length > 200 ? stringifiedArg.substring(0, 200) + '...' : stringifiedArg
      })
      .join(', ')
  } catch (e) {
    return '[Error serializing arguments]' // Handle potential errors during stringification
  }
}

/**
 * Creates a generic logging middleware for provider methods.
 * This middleware logs the initiation, success/failure, and duration of a method call.
 * @returns A `MethodMiddleware` instance.
 */
export const createGenericLoggingMiddleware: () => MethodMiddleware = () => {
  const middlewareName = 'GenericLoggingMiddleware'
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  return (_: MiddlewareAPI<BaseContext, any[]>) => (next) => async (ctx, args) => {
    const methodName = ctx.methodName
    const logPrefix = `[${middlewareName} (${methodName})]`
    console.log(`${logPrefix} Initiating. Args:`, stringifyArgsForLogging(args))
    const startTime = Date.now()
    try {
      const result = await next(ctx, args)
      const duration = Date.now() - startTime
      // Log successful completion of the method call with duration.
      console.log(`${logPrefix} Successful. Duration: ${duration}ms`)
      return result
    } catch (error) {
      const duration = Date.now() - startTime
      // Log failure of the method call with duration and error information.
      console.error(`${logPrefix} Failed. Duration: ${duration}ms`, error)
      throw error // Re-throw the error to be handled by subsequent layers or the caller
    }
  }
}
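// Usage sketch (illustrative; applyMethodMiddlewares is defined in composer.ts below):
//
//   const loggedTranslate = applyMethodMiddlewares(
//     'translate',
//     provider.translate.bind(provider),   // assumed provider method
//     [createGenericLoggingMiddleware()]
//   )
//   await loggedTranslate(text, targetLanguage)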
285
src/renderer/src/aiCore/middleware/composer.ts
Normal file
@ -0,0 +1,285 @@
import {
  RequestOptions,
  SdkInstance,
  SdkMessageParam,
  SdkParams,
  SdkRawChunk,
  SdkRawOutput,
  SdkTool,
  SdkToolCall
} from '@renderer/types/sdk'

import { BaseApiClient } from '../clients'
import { CompletionsParams, CompletionsResult } from './schemas'
import {
  BaseContext,
  CompletionsContext,
  CompletionsMiddleware,
  MethodMiddleware,
  MIDDLEWARE_CONTEXT_SYMBOL,
  MiddlewareAPI
} from './types'

/**
 * Creates the initial context for a method call, populating method-specific fields.
 * @param methodName - The name of the method being called.
 * @param originalCallArgs - The actual arguments array from the proxy/method call.
 * @param specificContextFactory - An optional factory function to create a specific context type from the base context and original call arguments.
 * @returns The created context object.
 */
function createInitialCallContext<TContext extends BaseContext, TCallArgs extends unknown[]>(
  methodName: string,
  originalCallArgs: TCallArgs, // Renamed from originalArgs to avoid confusion with context.originalArgs
  // Factory to create specific context from base and the *original call arguments array*
  specificContextFactory?: (base: BaseContext, callArgs: TCallArgs) => TContext
): TContext {
  const baseContext: BaseContext = {
    [MIDDLEWARE_CONTEXT_SYMBOL]: true,
    methodName,
    originalArgs: originalCallArgs // Store the full original arguments array in the context
  }

  if (specificContextFactory) {
    return specificContextFactory(baseContext, originalCallArgs)
  }
  return baseContext as TContext // Fall back to the base context if no specific factory is given
}

/**
 * Composes an array of functions from right to left.
 * `compose(f, g, h)` is `(...args) => f(g(h(...args)))`.
 * Each function in funcs is expected to take the result of the next function
 * (or the initial value for the rightmost function) as its argument.
 * @param funcs - Array of functions to compose.
 * @returns The composed function.
 */
function compose(...funcs: Array<(...args: any[]) => any>): (...args: any[]) => any {
  if (funcs.length === 0) {
    // If there is nothing to compose, return a function that returns its first argument, or undefined if there are none.
    return (...args: any[]) => (args.length > 0 ? args[0] : undefined)
  }
  if (funcs.length === 1) {
    return funcs[0]
  }
  return funcs.reduce(
    (a, b) =>
      (...args: any[]) =>
        a(b(...args))
  )
}

/**
 * Applies an array of Redux-style middlewares to a generic provider method.
 * This version keeps arguments as an array throughout the middleware chain.
 * @param methodName - The name of the method to be enhanced.
 * @param originalMethod - The original method to be wrapped.
 * @param middlewares - An array of `MethodMiddleware` to apply.
 * @param specificContextFactory - An optional factory to create a specific context for this method.
 * @returns An enhanced method with the middlewares applied.
 */
export function applyMethodMiddlewares<
  TArgs extends unknown[] = unknown[], // Original method's arguments array type
  TResult = unknown,
  TContext extends BaseContext = BaseContext
>(
  methodName: string,
  originalMethod: (...args: TArgs) => Promise<TResult>,
  middlewares: MethodMiddleware[], // Expects generic middlewares
  specificContextFactory?: (base: BaseContext, callArgs: TArgs) => TContext
): (...args: TArgs) => Promise<TResult> {
  // Returns a function matching the original method signature.
  return async function enhancedMethod(...methodCallArgs: TArgs): Promise<TResult> {
    const ctx = createInitialCallContext<TContext, TArgs>(
      methodName,
      methodCallArgs, // Pass the actual call arguments array
      specificContextFactory
    )

    const api: MiddlewareAPI<TContext, TArgs> = {
      getContext: () => ctx,
      getOriginalArgs: () => methodCallArgs // The API exposes the original arguments array
    }

    // `finalDispatch` is the function that ultimately calls the original provider method.
    // It receives the current context and arguments, which may have been transformed by middlewares.
    const finalDispatch = async (
      _: TContext,
      currentArgs: TArgs // The generic final dispatch expects an args array
    ): Promise<TResult> => {
      // Spread the args array into the call; `.apply(currentArgs)` would misbind the this-argument
      return originalMethod(...currentArgs)
    }

    const chain = middlewares.map((middleware) => middleware(api)) // Cast the API if TContext/TArgs do not match the generic MethodMiddleware
    const composedMiddlewareLogic = compose(...chain)
    const enhancedDispatch = composedMiddlewareLogic(finalDispatch)

    return enhancedDispatch(ctx, methodCallArgs) // Pass the context and the original args array
  }
}

/**
 * Applies an array of `CompletionsMiddleware` to the `completions` method.
 * This version adapts for `CompletionsMiddleware` expecting a single `params` object.
 * @param originalApiClientInstance - The original API client instance.
 * @param originalCompletionsMethod - The original SDK `createCompletions` method.
 * @param middlewares - An array of `CompletionsMiddleware` to apply.
 * @returns An enhanced `completions` method with the middlewares applied.
 */
export function applyCompletionsMiddlewares<
  TSdkInstance extends SdkInstance = SdkInstance,
  TSdkParams extends SdkParams = SdkParams,
  TRawOutput extends SdkRawOutput = SdkRawOutput,
  TRawChunk extends SdkRawChunk = SdkRawChunk,
  TMessageParam extends SdkMessageParam = SdkMessageParam,
  TToolCall extends SdkToolCall = SdkToolCall,
  TSdkSpecificTool extends SdkTool = SdkTool
>(
  originalApiClientInstance: BaseApiClient<
    TSdkInstance,
    TSdkParams,
    TRawOutput,
    TRawChunk,
    TMessageParam,
    TToolCall,
    TSdkSpecificTool
  >,
  originalCompletionsMethod: (payload: TSdkParams, options?: RequestOptions) => Promise<TRawOutput>,
  middlewares: CompletionsMiddleware<
    TSdkParams,
    TMessageParam,
    TToolCall,
    TSdkInstance,
    TRawOutput,
    TRawChunk,
    TSdkSpecificTool
  >[]
): (params: CompletionsParams, options?: RequestOptions) => Promise<CompletionsResult> {
  // Returns a function matching the original method signature.

  const methodName = 'completions'

  // Factory to create the CompletionsContext.
  const completionsContextFactory = (
    base: BaseContext,
    callArgs: [CompletionsParams]
  ): CompletionsContext<
    TSdkParams,
    TMessageParam,
    TToolCall,
    TSdkInstance,
    TRawOutput,
    TRawChunk,
    TSdkSpecificTool
  > => {
    return {
      ...base,
      methodName,
      apiClientInstance: originalApiClientInstance,
      originalArgs: callArgs,
      _internal: {
        toolProcessingState: {
          recursionDepth: 0,
          isRecursiveCall: false
        },
        observer: {}
      }
    }
  }

  return async function enhancedCompletionsMethod(
    params: CompletionsParams,
    options?: RequestOptions
  ): Promise<CompletionsResult> {
    // `originalCallArgs` for context creation is `[params]`.
    const originalCallArgs: [CompletionsParams] = [params]
    const baseContext: BaseContext = {
      [MIDDLEWARE_CONTEXT_SYMBOL]: true,
      methodName,
      originalArgs: originalCallArgs
    }
    const ctx = completionsContextFactory(baseContext, originalCallArgs)

    const api: MiddlewareAPI<
      CompletionsContext<TSdkParams, TMessageParam, TToolCall, TSdkInstance, TRawOutput, TRawChunk, TSdkSpecificTool>,
      [CompletionsParams]
    > = {
      getContext: () => ctx,
      getOriginalArgs: () => originalCallArgs // The API provides [CompletionsParams]
    }

    // `finalDispatch` for CompletionsMiddleware: expects (context, params), not (context, args_array).
    const finalDispatch = async (
      context: CompletionsContext<
        TSdkParams,
        TMessageParam,
        TToolCall,
        TSdkInstance,
        TRawOutput,
        TRawChunk,
        TSdkSpecificTool
      > // Context passed through
      // _currentParams: CompletionsParams // Takes params directly (unused but required by the middleware signature)
    ): Promise<CompletionsResult> => {
      // At this point, middleware should have transformed CompletionsParams to SDK params
      // and stored them in context. If no transformation happened, we need to handle it.

      const sdkPayload = context._internal?.sdkPayload
      if (!sdkPayload) {
        throw new Error('SDK payload not found in context. Middleware chain should have transformed parameters.')
      }

      const abortSignal = context._internal.flowControl?.abortSignal
      const timeout = context._internal.customState?.sdkMetadata?.timeout

      // Call the original SDK method with the transformed parameters
      const rawOutput = await originalCompletionsMethod.call(originalApiClientInstance, sdkPayload, {
        ...options,
        signal: abortSignal,
        timeout
      })

      // Return the result wrapped in the CompletionsResult format
      return {
        rawOutput
      } as CompletionsResult
    }

    const chain = middlewares.map((middleware) => middleware(api))
    const composedMiddlewareLogic = compose(...chain)

    // `enhancedDispatch` has the signature `(context, params) => Promise<CompletionsResult>`.
    const enhancedDispatch = composedMiddlewareLogic(finalDispatch)

    // Save enhancedDispatch on the context so middlewares can perform recursive calls
    // without re-running the construction of the whole middleware chain
    ctx._internal.enhancedDispatch = enhancedDispatch

    // Execute with the context and the single params object.
    return enhancedDispatch(ctx, params)
  }
}
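// Wiring sketch (illustrative): an ApiClient wraps its raw SDK call once and callers then use
// the enhanced method like the original:
//
//   const completions = applyCompletionsMiddlewares(
//     apiClient,
//     apiClient.createCompletions.bind(apiClient), // assumed raw SDK-facing method
//     CompletionsMiddlewareBuilder.withDefaults().build()
//   )
//   const result = await completions(params)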
@ -0,0 +1,306 @@
import Logger from '@renderer/config/logger'
import { MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
import { parseAndCallTools } from '@renderer/utils/mcp-tools'

import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'McpToolChunkMiddleware'
const MAX_TOOL_RECURSION_DEPTH = 20 // Guards against infinite recursion

/**
 * MCP tool-handling middleware.
 *
 * Responsibilities:
 * 1. Detect and intercept MCP tool progress chunks (both function-call style and tool-use style)
 * 2. Execute the tool calls
 * 3. Recursively process tool results
 * 4. Manage tool-call state and recursion depth
 */
export const McpToolChunkMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    const mcpTools = params.mcpTools || []

    // If there are no tools, just call the next middleware
    if (!mcpTools || mcpTools.length === 0) {
      return next(ctx, params)
    }

    const executeWithToolHandling = async (currentParams: CompletionsParams, depth = 0): Promise<CompletionsResult> => {
      if (depth >= MAX_TOOL_RECURSION_DEPTH) {
        Logger.error(`🔧 [${MIDDLEWARE_NAME}] Maximum recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
        throw new Error(`Maximum tool recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
      }

      let result: CompletionsResult

      if (depth === 0) {
        result = await next(ctx, currentParams)
      } else {
        const enhancedCompletions = ctx._internal.enhancedDispatch
        if (!enhancedCompletions) {
          Logger.error(`🔧 [${MIDDLEWARE_NAME}] Enhanced completions method not found, cannot perform recursive call`)
          throw new Error('Enhanced completions method not found')
        }

        ctx._internal.toolProcessingState!.isRecursiveCall = true
        ctx._internal.toolProcessingState!.recursionDepth = depth

        result = await enhancedCompletions(ctx, currentParams)
      }

      if (!result.stream) {
        Logger.error(`🔧 [${MIDDLEWARE_NAME}] No stream returned from enhanced completions`)
        throw new Error('No stream returned from enhanced completions')
      }

      const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
      const toolHandlingStream = resultFromUpstream.pipeThrough(
        createToolHandlingTransform(ctx, currentParams, mcpTools, depth, executeWithToolHandling)
      )

      return {
        ...result,
        stream: toolHandlingStream
      }
    }

    return executeWithToolHandling(params, 0)
  }

/**
 * Creates the tool-handling TransformStream.
 */
function createToolHandlingTransform(
  ctx: CompletionsContext,
  currentParams: CompletionsParams,
  mcpTools: MCPTool[],
  depth: number,
  executeWithToolHandling: (params: CompletionsParams, depth: number) => Promise<CompletionsResult>
): TransformStream<GenericChunk, GenericChunk> {
  const toolCalls: SdkToolCall[] = []
  const toolUseResponses: MCPToolResponse[] = []
  const allToolResponses: MCPToolResponse[] = [] // Unified array for tool-response state management
  let hasToolCalls = false
  let hasToolUseResponses = false
  let streamEnded = false

  return new TransformStream({
    async transform(chunk: GenericChunk, controller) {
      try {
        // Handle MCP tool progress chunks
        if (chunk.type === ChunkType.MCP_TOOL_CREATED) {
          const createdChunk = chunk as MCPToolCreatedChunk

          // 1. Function-call style tool calls
          if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
            toolCalls.push(...createdChunk.tool_calls)
            hasToolCalls = true
          }

          // 2. Tool-use style tool calls
          if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
            toolUseResponses.push(...createdChunk.tool_use_responses)
            hasToolUseResponses = true
          }

          // Do not forward MCP tool progress chunks, to avoid double handling
          return
        }

        // Forward all other chunks
        controller.enqueue(chunk)
      } catch (error) {
        console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
        controller.error(error)
      }
    },

    async flush(controller) {
      const shouldExecuteToolCalls = hasToolCalls && toolCalls.length > 0
      const shouldExecuteToolUseResponses = hasToolUseResponses && toolUseResponses.length > 0

      if (!streamEnded && (shouldExecuteToolCalls || shouldExecuteToolUseResponses)) {
        streamEnded = true

        try {
          let toolResult: SdkMessageParam[] = []

          if (shouldExecuteToolCalls) {
            toolResult = await executeToolCalls(
              ctx,
              toolCalls,
              mcpTools,
              allToolResponses,
              currentParams.onChunk,
              currentParams.assistant.model!
            )
          } else if (shouldExecuteToolUseResponses) {
            toolResult = await executeToolUseResponses(
              ctx,
              toolUseResponses,
              mcpTools,
              allToolResponses,
              currentParams.onChunk,
              currentParams.assistant.model!
            )
          }

          if (toolResult.length > 0) {
            const output = ctx._internal.toolProcessingState?.output

            const newParams = buildParamsWithToolResults(ctx, currentParams, output!, toolResult, toolCalls)
            await executeWithToolHandling(newParams, depth + 1)
          }
        } catch (error) {
          console.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
          controller.error(error)
        } finally {
          hasToolCalls = false
          hasToolUseResponses = false
        }
      }
    }
  })
}

/**
 * Executes tool calls (function-call style).
 */
async function executeToolCalls(
  ctx: CompletionsContext,
  toolCalls: SdkToolCall[],
  mcpTools: MCPTool[],
  allToolResponses: MCPToolResponse[],
  onChunk: CompletionsParams['onChunk'],
  model: Model
): Promise<SdkMessageParam[]> {
  // Convert to the MCPToolResponse format
  const mcpToolResponses: ToolCallResponse[] = toolCalls
    .map((toolCall) => {
      const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
      if (!mcpTool) {
        return undefined
      }
      return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
    })
    .filter((t): t is ToolCallResponse => typeof t !== 'undefined')

  if (mcpToolResponses.length === 0) {
    console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`)
    return []
  }

  // Execute the tools with the existing parseAndCallTools helper
  const toolResults = await parseAndCallTools(
    mcpToolResponses,
    allToolResponses,
    onChunk,
    (mcpToolResponse, resp, model) => {
      return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
    },
    model,
    mcpTools
  )

  return toolResults
}

/**
 * Executes tool-use responses (tool-use style).
 * Handles already-parsed ToolUseResponse[]; no string re-parsing is needed.
 */
async function executeToolUseResponses(
  ctx: CompletionsContext,
  toolUseResponses: MCPToolResponse[],
  mcpTools: MCPTool[],
  allToolResponses: MCPToolResponse[],
  onChunk: CompletionsParams['onChunk'],
  model: Model
): Promise<SdkMessageParam[]> {
  // Feed the already-parsed ToolUseResponses straight into parseAndCallTools
  const toolResults = await parseAndCallTools(
    toolUseResponses,
    allToolResponses,
    onChunk,
    (mcpToolResponse, resp, model) => {
      return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
    },
    model,
    mcpTools
  )

  return toolResults
}

/**
 * Builds new params that include the tool results.
 */
function buildParamsWithToolResults(
  ctx: CompletionsContext,
  currentParams: CompletionsParams,
  output: SdkRawOutput | string,
  toolResults: SdkMessageParam[],
  toolCalls: SdkToolCall[]
): CompletionsParams {
  // Get the already-converted request messages, falling back to the original messages
  const currentReqMessages = getCurrentReqMessages(ctx)

  const apiClient = ctx.apiClientInstance

  // Build the assistant message from the reply
  const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)

  // Estimate the token cost of the new messages and add it to usage
  if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) {
    try {
      const newMessages = newReqMessages.slice(currentReqMessages.length)
      const additionalTokens = newMessages.reduce((acc, message) => {
        return acc + ctx.apiClientInstance.estimateMessageTokens(message)
      }, 0)

      if (additionalTokens > 0) {
        ctx._internal.observer.usage.prompt_tokens += additionalTokens
        ctx._internal.observer.usage.total_tokens += additionalTokens
      }
    } catch (error) {
      Logger.error(`🔧 [${MIDDLEWARE_NAME}] Error estimating token usage for new messages:`, error)
    }
  }

  // Update recursion state
  if (!ctx._internal.toolProcessingState) {
    ctx._internal.toolProcessingState = {}
  }
  ctx._internal.toolProcessingState.isRecursiveCall = true
  ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState?.recursionDepth || 0) + 1

  return {
    ...currentParams,
    _internal: {
      ...ctx._internal,
      sdkPayload: ctx._internal.sdkPayload,
      newReqMessages: newReqMessages
    }
  }
}

/**
 * Type-safe access to the current request messages.
 * Uses the abstract method provided by the API client so the middleware stays provider-agnostic.
 */
function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] {
  const sdkPayload = ctx._internal.sdkPayload
  if (!sdkPayload) {
    return []
  }

  // Extract messages via the API client's abstract method, staying provider-agnostic
  return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload)
}

export default McpToolChunkMiddleware
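// Recursion sketch (illustrative): round 0 flows through next(); when the model requests tools,
// flush() runs them and re-enters via ctx._internal.enhancedDispatch with depth + 1:
//
//   stream -> MCP_TOOL_CREATED (held back) -> execute tools -> enhancedDispatch(ctx, newParams)
//   ... repeated until no tool calls remain or MAX_TOOL_RECURSION_DEPTH is hit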
@ -0,0 +1,48 @@
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'

import { AnthropicStreamListener } from '../../clients/types'
import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'RawStreamListenerMiddleware'

export const RawStreamListenerMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    const result = await next(ctx, params)

    // This is where the rawest stream returned by the SDK can be observed
    if (result.rawOutput) {
      console.log(`[${MIDDLEWARE_NAME}] Raw SDK output detected, attaching listener`)

      const providerType = ctx.apiClientInstance.provider.type
      // TODO: move this down into AnthropicAPIClient later
      if (providerType === 'anthropic') {
        const anthropicListener: AnthropicStreamListener<AnthropicSdkRawChunk> = {
          onMessage: (message) => {
            if (ctx._internal?.toolProcessingState) {
              ctx._internal.toolProcessingState.output = message
            }
          }
          // onContentBlock: (contentBlock) => {
          //   console.log(`[${MIDDLEWARE_NAME}] 📝 Anthropic content block:`, contentBlock.type)
          // }
        }

        const specificApiClient = ctx.apiClientInstance as AnthropicAPIClient

        const monitoredOutput = specificApiClient.attachRawStreamListener(
          result.rawOutput as AnthropicSdkRawOutput,
          anthropicListener
        )
        return {
          ...result,
          rawOutput: monitoredOutput
        }
      }
    }

    return result
  }
@ -0,0 +1,85 @@
import Logger from '@renderer/config/logger'
import { SdkRawChunk } from '@renderer/types/sdk'

import { ResponseChunkTransformerContext } from '../../clients/types'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'ResponseTransformMiddleware'

/**
 * Response transform middleware.
 *
 * Responsibilities:
 * 1. Detect ReadableStream-typed response streams
 * 2. Use the ApiClient's getResponseChunkTransformer() to convert raw SDK response chunks into the generic format
 * 3. Put the transformed ReadableStream on the result for downstream middlewares to consume
 *
 * Note: this middleware should run after StreamAdapterMiddleware.
 */
export const ResponseTransformMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    // Call the downstream middleware
    const result = await next(ctx, params)

    // Post-processing: transform raw SDK response chunks
    if (result.stream) {
      const adaptedStream = result.stream

      // Handle ReadableStream-typed streams
      if (adaptedStream instanceof ReadableStream) {
        const apiClient = ctx.apiClientInstance
        if (!apiClient) {
          console.error(`[${MIDDLEWARE_NAME}] ApiClient instance not found in context`)
          throw new Error('ApiClient instance not found in context')
        }

        // Get the response chunk transformer
        const responseChunkTransformer = apiClient.getResponseChunkTransformer?.()
        if (!responseChunkTransformer) {
          Logger.warn(`[${MIDDLEWARE_NAME}] No ResponseChunkTransformer available, skipping transformation`)
          return result
        }

        const assistant = params.assistant
        const model = assistant?.model

        if (!assistant || !model) {
          console.error(`[${MIDDLEWARE_NAME}] Assistant or Model not found for transformation`)
          throw new Error('Assistant or Model not found for transformation')
        }

        const transformerContext: ResponseChunkTransformerContext = {
          isStreaming: params.streamOutput || false,
          isEnabledToolCalling: (params.mcpTools && params.mcpTools.length > 0) || false,
          isEnabledWebSearch: params.enableWebSearch || false,
          isEnabledReasoning: params.enableReasoning || false,
          mcpTools: params.mcpTools || [],
          provider: ctx.apiClientInstance?.provider
        }

        console.log(`[${MIDDLEWARE_NAME}] Transforming raw SDK chunks with context:`, transformerContext)

        try {
          // Create the transformed stream
          const genericChunkTransformStream = (adaptedStream as ReadableStream<SdkRawChunk>).pipeThrough<GenericChunk>(
            new TransformStream<SdkRawChunk, GenericChunk>(responseChunkTransformer(transformerContext))
          )

          // Hand the transformed ReadableStream to downstream middlewares on the result
          return {
            ...result,
            stream: genericChunkTransformStream
          }
        } catch (error) {
          Logger.error(`[${MIDDLEWARE_NAME}] Error during chunk transformation:`, error)
          throw error
        }
      }
    }

    // If there is no stream, or it is not a ReadableStream, return the original result
    return result
  }
@ -0,0 +1,57 @@
import { SdkRawChunk } from '@renderer/types/sdk'
import { asyncGeneratorToReadableStream, createSingleChunkReadableStream } from '@renderer/utils/stream'

import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
import { isAsyncIterable } from '../utils'

export const MIDDLEWARE_NAME = 'StreamAdapterMiddleware'

/**
 * Stream adapter middleware.
 *
 * Responsibilities:
 * 1. Detect ctx._internal.apiCall.rawSdkOutput (preferred) or a raw AsyncIterable stream
 * 2. Convert the AsyncIterable into a WHATWG ReadableStream
 * 3. Update the stream on the response result
 *
 * Note: if ResponseTransformMiddleware has already run, its transformedStream takes precedence.
 */
export const StreamAdapterMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    // TODO: this is the point closest to the actual API request; running next() starts the request.
    // Whether that call belongs here is debatable, since this middleware's job is stream adaptation.
    // Call the downstream middleware
    const result = await next(ctx, params)

    if (
      result.rawOutput &&
      !(result.rawOutput instanceof ReadableStream) &&
      isAsyncIterable<SdkRawChunk>(result.rawOutput)
    ) {
      const whatwgReadableStream: ReadableStream<SdkRawChunk> = asyncGeneratorToReadableStream<SdkRawChunk>(
        result.rawOutput
      )
      return {
        ...result,
        stream: whatwgReadableStream
      }
    } else if (result.rawOutput && result.rawOutput instanceof ReadableStream) {
      return {
        ...result,
        stream: result.rawOutput
      }
    } else if (result.rawOutput) {
      // Non-streaming output: coerce it into a readable stream
      const whatwgReadableStream: ReadableStream<SdkRawChunk> = createSingleChunkReadableStream<SdkRawChunk>(
        result.rawOutput
      )
      return {
        ...result,
        stream: whatwgReadableStream
      }
    }
    return result
  }
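// Adaptation sketch (illustrative): all three raw output shapes end up as a ReadableStream:
//
//   AsyncIterable<SdkRawChunk>   -> asyncGeneratorToReadableStream(...)
//   ReadableStream               -> passed through as-is
//   single non-streaming result  -> createSingleChunkReadableStream(...)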
@ -0,0 +1,99 @@
import Logger from '@renderer/config/logger'
import { ChunkType, TextDeltaChunk } from '@renderer/types/chunk'

import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'TextChunkMiddleware'

/**
 * Text chunk middleware.
 *
 * Responsibilities:
 * 1. Accumulate text content (TEXT_DELTA)
 * 2. Apply smart link conversion to the text content
 * 3. Emit the TEXT_COMPLETE event
 * 4. Stash web-search results for final link refinement
 * 5. Drive the onResponse callback, sending live text updates and the final full text
 */
export const TextChunkMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    // Call the downstream middleware
    const result = await next(ctx, params)

    // Post-processing: transform the text content in the streaming response
    if (result.stream) {
      const resultFromUpstream = result.stream as ReadableStream<GenericChunk>

      if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
        const assistant = params.assistant
        const model = params.assistant?.model

        if (!assistant || !model) {
          Logger.warn(`[${MIDDLEWARE_NAME}] Missing assistant or model information, skipping text processing`)
          return result
        }

        // Cross-chunk state management
        let accumulatedTextContent = ''
        let hasEnqueue = false
        const enhancedTextStream = resultFromUpstream.pipeThrough(
          new TransformStream<GenericChunk, GenericChunk>({
            transform(chunk: GenericChunk, controller) {
              if (chunk.type === ChunkType.TEXT_DELTA) {
                const textChunk = chunk as TextDeltaChunk
                accumulatedTextContent += textChunk.text

                // Drive the onResponse callback - send an incremental text update
                if (params.onResponse) {
                  params.onResponse(accumulatedTextContent, false)
                }

                // Forward the chunk containing the processed text
                controller.enqueue(chunk)
              } else if (accumulatedTextContent) {
                if (chunk.type !== ChunkType.LLM_RESPONSE_COMPLETE) {
                  controller.enqueue(chunk)
                  hasEnqueue = true
                }
                const finalText = accumulatedTextContent
                ctx._internal.customState!.accumulatedText = finalText
                if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) {
                  ctx._internal.toolProcessingState.output = finalText
                }

                // Drive the onResponse callback - send the final full text
                if (params.onResponse) {
                  params.onResponse(finalText, true)
                }

                controller.enqueue({
                  type: ChunkType.TEXT_COMPLETE,
                  text: finalText
                })
                accumulatedTextContent = ''
                if (!hasEnqueue) {
                  controller.enqueue(chunk)
                }
              } else {
                // Pass other chunk types straight through
                controller.enqueue(chunk)
              }
            }
          })
        )

        // Update the response result
        return {
          ...result,
          stream: enhancedTextStream
        }
      } else {
        Logger.warn(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream. Returning original result.`)
      }
    }

    return result
  }
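// Event sketch (illustrative), for deltas "Hel", "lo" followed by LLM_RESPONSE_COMPLETE:
//
//   TEXT_DELTA("Hel")     -> onResponse("Hel", false), chunk forwarded
//   TEXT_DELTA("lo")      -> onResponse("Hello", false), chunk forwarded
//   LLM_RESPONSE_COMPLETE -> onResponse("Hello", true), TEXT_COMPLETE("Hello") emitted,
//                            then the completion chunk itself is forwarded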
101
src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts
Normal file
@ -0,0 +1,101 @@
import Logger from '@renderer/config/logger'
import { ChunkType, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'

import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'ThinkChunkMiddleware'

/**
 * Middleware that handles thinking (reasoning) content.
 *
 * Note: since v2, end-of-stream semantics are decided at the ApiClient layer.
 * This middleware is now mainly responsible for:
 * 1. Handling the reasoning field in raw SDK chunks
 * 2. Computing an accurate thinking time
 * 3. Emitting a THINKING_COMPLETE event when the thinking content ends
 *
 * Responsibilities:
 * 1. Accumulate thinking content (THINKING_DELTA)
 * 2. Watch for the end-of-stream signal and emit the THINKING_COMPLETE event
 * 3. Compute an accurate thinking time
 *
 */
export const ThinkChunkMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    // Call the downstream middleware
    const result = await next(ctx, params)

    // Post-processing: handle thinking content
    if (result.stream) {
      const resultFromUpstream = result.stream as ReadableStream<GenericChunk>

      // Check whether reasoning is enabled
      const enableReasoning = params.enableReasoning || false
      if (!enableReasoning) {
        return result
      }

      // Check whether there is a stream to process
      if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
        // Thinking-processing state
        let accumulatedThinkingContent = ''
        let hasThinkingContent = false
        let thinkingStartTime = 0

        const processedStream = resultFromUpstream.pipeThrough(
          new TransformStream<GenericChunk, GenericChunk>({
            transform(chunk: GenericChunk, controller) {
              if (chunk.type === ChunkType.THINKING_DELTA) {
                const thinkingChunk = chunk as ThinkingDeltaChunk

                // Record the start time on the first piece of thinking content
                if (!hasThinkingContent) {
                  hasThinkingContent = true
                  thinkingStartTime = Date.now()
                }

                accumulatedThinkingContent += thinkingChunk.text

                // Update the thinking time and pass the chunk along
                const enhancedChunk: ThinkingDeltaChunk = {
                  ...thinkingChunk,
                  thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
                }
                controller.enqueue(enhancedChunk)
              } else if (hasThinkingContent && thinkingStartTime > 0) {
                // On any non-THINKING_DELTA chunk, emit THINKING_COMPLETE if thinking content has accumulated
                const thinkingCompleteChunk: ThinkingCompleteChunk = {
                  type: ChunkType.THINKING_COMPLETE,
                  text: accumulatedThinkingContent,
                  thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
                }
                controller.enqueue(thinkingCompleteChunk)
                hasThinkingContent = false
                accumulatedThinkingContent = ''
                thinkingStartTime = 0

                // Continue passing the current chunk along
                controller.enqueue(chunk)
              } else {
                // Otherwise pass the chunk straight through
                controller.enqueue(chunk)
              }
            }
          })
        )

        // Update the response result
        return {
          ...result,
          stream: processedStream
        }
      } else {
        Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
      }
    }

    return result
  }
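// Event sketch (illustrative):
//
//   THINKING_DELTA("a")      -> forwarded with thinking_millsec filled in
//   THINKING_DELTA("b")      -> forwarded with thinking_millsec filled in
//   first non-thinking chunk -> THINKING_COMPLETE("ab") emitted, then that chunk forwarded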
@ -0,0 +1,83 @@
import Logger from '@renderer/config/logger'
import { ChunkType } from '@renderer/types/chunk'

import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'TransformCoreToSdkParamsMiddleware'

/**
 * Middleware: transforms the CoreCompletionsRequest into SDK-specific parameters,
 * using the requestTransformer of the ApiClient instance held in the context.
 */
export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    Logger.debug(`🔄 [${MIDDLEWARE_NAME}] Starting core to SDK params transformation:`, ctx)

    const internal = ctx._internal

    // 🔧 Detect recursive calls: check whether params carry pre-processed SDK messages
    const isRecursiveCall = internal?.toolProcessingState?.isRecursiveCall || false
    const newSdkMessages = params._internal?.newReqMessages

    const apiClient = ctx.apiClientInstance

    if (!apiClient) {
      Logger.error(`🔄 [${MIDDLEWARE_NAME}] ApiClient instance not found in context.`)
      throw new Error('ApiClient instance not found in context')
    }

    // Check whether the ApiClient provides a request transformer
    const requestTransformer = apiClient.getRequestTransformer()
    if (!requestTransformer) {
      Logger.warn(
        `🔄 [${MIDDLEWARE_NAME}] ApiClient did not provide a request transformer, skipping transformation`
      )
      const result = await next(ctx, params)
      return result
    }

    // Ensure assistant and model are available; both are required by the transformer
    const assistant = params.assistant
    const model = assistant?.model

    if (!assistant || !model) {
      Logger.error(`🔄 [${MIDDLEWARE_NAME}] Assistant or Model not found for transformation.`)
      throw new Error('Assistant or Model not found for transformation')
    }

    try {
      const transformResult = await requestTransformer.transform(
        params,
        assistant,
        model,
        isRecursiveCall,
        newSdkMessages
      )

      const { payload: sdkPayload, metadata } = transformResult

      // Store the SDK-specific payload and metadata in the state for downstream middlewares
      ctx._internal.sdkPayload = sdkPayload

      if (metadata) {
        ctx._internal.customState = {
          ...ctx._internal.customState,
          sdkMetadata: metadata
        }
      }

      if (params.enableGenerateImage) {
        params.onChunk?.({
          type: ChunkType.IMAGE_CREATED
        })
      }
      return next(ctx, params)
    } catch (error) {
      Logger.error(`🔄 [${MIDDLEWARE_NAME}] Error during request transformation:`, error)
      // Let the error propagate; provider-specific handling could be added here
      throw error
    }
  }
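For reference, the `requestTransformer.transform(...)` call above is expected to resolve to a `{ payload, metadata }` pair. A minimal sketch of a compatible transformer follows; the payload fields shown are assumptions for illustration, not the real SDK schema:

```typescript
// Hypothetical transformer, for illustration only — field names are assumed
const sketchTransformer = {
  async transform(params: any, assistant: any, model: any, isRecursiveCall: boolean, newSdkMessages?: any[]) {
    return {
      payload: { model: model.id, messages: newSdkMessages ?? [] }, // SDK-specific request body
      metadata: { isRecursiveCall } // ends up under ctx._internal.customState.sdkMetadata
    }
  }
}
```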
@ -0,0 +1,76 @@
import { ChunkType } from '@renderer/types/chunk'
import { smartLinkConverter } from '@renderer/utils/linkConverter'

import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'WebSearchMiddleware'

/**
 * Web search middleware - operates on the GenericChunk stream
 *
 * Responsibilities:
 * 1. Observe and record web search events
 * 2. Serve as the hook point for post-processing web search results
 * 3. Maintain web-search-related state
 *
 * Note: recognizing and producing web search results is handled in the ApiClient's response transformer
 */
export const WebSearchMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    ctx._internal.webSearchState = {
      results: undefined
    }
    // Call the downstream middleware
    const result = await next(ctx, params)

    const model = params.assistant?.model!
    let isFirstChunk = true

    // Post-processing: track web search events
    if (result.stream) {
      const resultFromUpstream = result.stream

      if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
        // Web search state tracking
        const enhancedStream = (resultFromUpstream as ReadableStream<GenericChunk>).pipeThrough(
          new TransformStream<GenericChunk, GenericChunk>({
            transform(chunk: GenericChunk, controller) {
              if (chunk.type === ChunkType.TEXT_DELTA) {
                const providerType = model.provider || 'openai'
                // Convert links using the web search results available so far
                const text = chunk.text
                const processedText = smartLinkConverter(text, providerType, isFirstChunk)
                if (isFirstChunk) {
                  isFirstChunk = false
                }
                controller.enqueue({
                  ...chunk,
                  text: processedText
                })
              } else if (chunk.type === ChunkType.LLM_WEB_SEARCH_COMPLETE) {
                // Stash the web search results for later link refinement
                ctx._internal.webSearchState!.results = chunk.llm_web_search

                // Pass the web-search-complete event along
                controller.enqueue(chunk)
              } else {
                controller.enqueue(chunk)
              }
            }
          })
        )

        return {
          ...result,
          stream: enhancedStream
        }
      } else {
        console.log(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream.`)
      }
    }

    return result
  }
@ -0,0 +1,132 @@
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { ChunkType } from '@renderer/types/chunk'
import { findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import OpenAI from 'openai'
import { toFile } from 'openai/uploads'

import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'ImageGenerationMiddleware'

export const ImageGenerationMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    const { assistant, messages } = params
    const client = context.apiClientInstance as BaseApiClient<OpenAI>
    const signal = context._internal?.flowControl?.abortSignal

    if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') {
      return next(context, params)
    }

    const stream = new ReadableStream<GenericChunk>({
      async start(controller) {
        const enqueue = (chunk: GenericChunk) => controller.enqueue(chunk)

        try {
          if (!assistant.model) {
            throw new Error('Assistant model is not defined.')
          }

          const sdk = await client.getSdkInstance()
          const lastUserMessage = messages.findLast((m) => m.role === 'user')
          const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')

          if (!lastUserMessage) {
            throw new Error('No user message found for image generation.')
          }

          const prompt = getMainTextContent(lastUserMessage)
          let imageFiles: Blob[] = []

          // Collect images from user message
          const userImageBlocks = findImageBlocks(lastUserMessage)
          const userImages = await Promise.all(
            userImageBlocks.map(async (block) => {
              if (!block.file) return null
              const binaryData: Uint8Array = await window.api.file.binaryImage(block.file.id)
              const mimeType = `${block.file.type}/${block.file.ext.slice(1)}`
              return await toFile(new Blob([binaryData]), block.file.origin_name || 'image.png', { type: mimeType })
            })
          )
          imageFiles = imageFiles.concat(userImages.filter(Boolean) as Blob[])

          // Collect images from last assistant message
          if (lastAssistantMessage) {
            const assistantImageBlocks = findImageBlocks(lastAssistantMessage)
            const assistantImages = await Promise.all(
              assistantImageBlocks.map(async (block) => {
                const b64 = block.url?.replace(/^data:image\/\w+;base64,/, '')
                if (!b64) return null
                const binary = atob(b64)
                const bytes = new Uint8Array(binary.length)
                for (let i = 0; i < binary.length; i++) bytes[i] = binary.charCodeAt(i)
                return await toFile(new Blob([bytes]), 'assistant_image.png', { type: 'image/png' })
              })
            )
            imageFiles = imageFiles.concat(assistantImages.filter(Boolean) as Blob[])
          }

          enqueue({ type: ChunkType.IMAGE_CREATED })

          const startTime = Date.now()
          let response: OpenAI.Images.ImagesResponse

          const options = { signal, timeout: 300_000 }

          if (imageFiles.length > 0) {
            response = await sdk.images.edit(
              {
                model: assistant.model.id,
                image: imageFiles,
                prompt: prompt || ''
              },
              options
            )
          } else {
            response = await sdk.images.generate(
              {
                model: assistant.model.id,
                prompt: prompt || '',
                response_format: assistant.model.id.includes('gpt-image-1') ? undefined : 'b64_json'
              },
              options
            )
          }

          const b64_json_array = response.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || []

          enqueue({
            type: ChunkType.IMAGE_COMPLETE,
            image: { type: 'base64', images: b64_json_array }
          })

          const usage = (response as any).usage || { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }

          enqueue({
            type: ChunkType.LLM_RESPONSE_COMPLETE,
            response: {
              usage,
              metrics: {
                completion_tokens: usage.completion_tokens,
                time_first_token_millsec: 0,
                time_completion_millsec: Date.now() - startTime
              }
            }
          })
        } catch (error: any) {
          enqueue({ type: ChunkType.ERROR, error })
        } finally {
          controller.close()
        }
      }
    })

    return {
      stream,
      getText: () => ''
    }
  }
@ -0,0 +1,136 @@
import { Model } from '@renderer/types'
import { ChunkType, TextDeltaChunk, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'
import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction'
import Logger from 'electron-log/renderer'

import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'ThinkingTagExtractionMiddleware'

// Thinking-tag configurations for different models
const reasoningTags: TagConfig[] = [
  { openingTag: '<think>', closingTag: '</think>', separator: '\n' },
  { openingTag: '###Thinking', closingTag: '###Response', separator: '\n' }
]

const getAppropriateTag = (model?: Model): TagConfig => {
  if (model?.id?.includes('qwen3')) return reasoningTags[0]
  // More model-specific tag configurations can be added here
  return reasoningTags[0] // default to the <think> tag
}

/**
 * Middleware that extracts thinking tags from the text stream
 *
 * Handles thinking-tag content embedded in the text stream (e.g. <think>...</think>),
 * mainly for providers such as OpenAI whose models emit thinking tags.
 *
 * Responsibilities:
 * 1. Extract thinking-tag content from the text stream
 * 2. Convert tag content into THINKING_DELTA chunks
 * 3. Emit content outside the tags as normal text
 * 4. Handle the thinking-tag formats of different models
 * 5. Emit a THINKING_COMPLETE event when the thinking content ends
 */
export const ThinkingTagExtractionMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    // Call the downstream middleware
    const result = await next(context, params)

    // Post-processing: extract thinking tags
    if (result.stream) {
      const resultFromUpstream = result.stream as ReadableStream<GenericChunk>

      // Check whether there is a stream to process
      if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
        // Pick the thinking-tag configuration for the current model
        const model = params.assistant?.model
        const reasoningTag = getAppropriateTag(model)

        // Create the tag extractor
        const tagExtractor = new TagExtractor(reasoningTag)

        // Thinking processing state
        let hasThinkingContent = false
        let thinkingStartTime = 0

        const processedStream = resultFromUpstream.pipeThrough(
          new TransformStream<GenericChunk, GenericChunk>({
            transform(chunk: GenericChunk, controller) {
              if (chunk.type === ChunkType.TEXT_DELTA) {
                const textChunk = chunk as TextDeltaChunk

                // Run the text through the TagExtractor
                const extractionResults = tagExtractor.processText(textChunk.text)

                for (const extractionResult of extractionResults) {
                  if (extractionResult.complete && extractionResult.tagContentExtracted) {
                    // Emit a THINKING_COMPLETE event
                    const thinkingCompleteChunk: ThinkingCompleteChunk = {
                      type: ChunkType.THINKING_COMPLETE,
                      text: extractionResult.tagContentExtracted,
                      thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
                    }
                    controller.enqueue(thinkingCompleteChunk)

                    // Reset thinking state
                    hasThinkingContent = false
                    thinkingStartTime = 0
                  } else if (extractionResult.content.length > 0) {
                    if (extractionResult.isTagContent) {
                      // Record the start time when thinking content is first received
                      if (!hasThinkingContent) {
                        hasThinkingContent = true
                        thinkingStartTime = Date.now()
                      }

                      const thinkingDeltaChunk: ThinkingDeltaChunk = {
                        type: ChunkType.THINKING_DELTA,
                        text: extractionResult.content,
                        thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
                      }
                      controller.enqueue(thinkingDeltaChunk)
                    } else {
                      // Emit the cleaned text content
                      const cleanTextChunk: TextDeltaChunk = {
                        ...textChunk,
                        text: extractionResult.content
                      }
                      controller.enqueue(cleanTextChunk)
                    }
                  }
                }
              } else {
                // Pass other chunk types through unchanged (including THINKING_DELTA, THINKING_COMPLETE, etc.)
                controller.enqueue(chunk)
              }
            },
            flush(controller) {
              // Handle any remaining thinking content
              const finalResult = tagExtractor.finalize()
              if (finalResult?.tagContentExtracted) {
                const thinkingCompleteChunk: ThinkingCompleteChunk = {
                  type: ChunkType.THINKING_COMPLETE,
                  text: finalResult.tagContentExtracted,
                  thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
                }
                controller.enqueue(thinkingCompleteChunk)
              }
            }
          })
        )

        // Update the response result
        return {
          ...result,
          stream: processedStream
        }
      } else {
        Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
      }
    }
    return result
  }
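To make the extraction behavior above concrete, here is how `TagExtractor` can be driven in isolation, using only the API surface exercised by this middleware (a sketch; the result field names are taken from the code above):

```typescript
import { TagExtractor } from '@renderer/utils/tagExtraction'

const extractor = new TagExtractor({ openingTag: '<think>', closingTag: '</think>', separator: '\n' })
for (const r of extractor.processText('<think>plan the answer</think>final text')) {
  if (r.complete && r.tagContentExtracted) {
    console.log('thinking complete:', r.tagContentExtracted) // becomes THINKING_COMPLETE
  } else if (r.isTagContent) {
    console.log('thinking delta:', r.content) // becomes THINKING_DELTA
  } else if (r.content) {
    console.log('clean text:', r.content) // becomes TEXT_DELTA
  }
}
```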
@ -0,0 +1,124 @@
import { MCPTool } from '@renderer/types'
import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk'
import { parseToolUse } from '@renderer/utils/mcp-tools'
import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction'

import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'ToolUseExtractionMiddleware'

// Tool-use tag configuration
const TOOL_USE_TAG_CONFIG: TagConfig = {
  openingTag: '<tool_use>',
  closingTag: '</tool_use>',
  separator: '\n'
}

/**
 * Tool-use extraction middleware
 *
 * Responsibilities:
 * 1. Detect and extract <tool_use></tool_use> tags from the text stream
 * 2. Parse the tool-call information and convert it into ToolUseResponse format
 * 3. Emit MCP_TOOL_CREATED chunks for McpToolChunkMiddleware to handle
 * 4. Clean the text stream: remove tool-use tags while keeping normal text
 *
 * Note: this middleware only extracts and converts; the actual tool call is executed by McpToolChunkMiddleware
 */
export const ToolUseExtractionMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    const mcpTools = params.mcpTools || []

    // If there are no tools, call the next middleware directly
    if (!mcpTools || mcpTools.length === 0) return next(ctx, params)

    // Call the downstream middleware
    const result = await next(ctx, params)

    // Post-processing: extract tool-use tags
    if (result.stream) {
      const resultFromUpstream = result.stream as ReadableStream<GenericChunk>

      const processedStream = resultFromUpstream.pipeThrough(createToolUseExtractionTransform(ctx, mcpTools))

      return {
        ...result,
        stream: processedStream
      }
    }

    return result
  }

/**
 * Creates the TransformStream for tool-use extraction
 */
function createToolUseExtractionTransform(
  _ctx: CompletionsContext,
  mcpTools: MCPTool[]
): TransformStream<GenericChunk, GenericChunk> {
  const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)

  return new TransformStream({
    async transform(chunk: GenericChunk, controller) {
      try {
        // Process text content and detect tool-use tags
        if (chunk.type === ChunkType.TEXT_DELTA) {
          const textChunk = chunk as TextDeltaChunk
          const extractionResults = tagExtractor.processText(textChunk.text)

          for (const result of extractionResults) {
            if (result.complete && result.tagContentExtracted) {
              // A complete tool-use payload was extracted: parse it into SDK ToolCall format
              const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools)

              if (toolUseResponses.length > 0) {
                // Emit an MCP_TOOL_CREATED chunk, reusing the existing processing flow
                const mcpToolCreatedChunk: MCPToolCreatedChunk = {
                  type: ChunkType.MCP_TOOL_CREATED,
                  tool_use_responses: toolUseResponses
                }
                controller.enqueue(mcpToolCreatedChunk)
              }
            } else if (!result.isTagContent && result.content) {
              // Emit the normal text outside of the tags
              const cleanTextChunk: TextDeltaChunk = {
                ...textChunk,
                text: result.content
              }
              controller.enqueue(cleanTextChunk)
            }
            // Note: content inside the tags is not forwarded as TEXT_DELTA, to avoid duplicate display
          }
          return
        }

        // Forward all other chunks
        controller.enqueue(chunk)
      } catch (error) {
        console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
        controller.error(error)
      }
    },

    async flush(controller) {
      // Check for unfinished tag content
      const finalResult = tagExtractor.finalize()
      if (finalResult && finalResult.tagContentExtracted) {
        const toolUseResponses = parseToolUse(finalResult.tagContentExtracted, mcpTools)
        if (toolUseResponses.length > 0) {
          const mcpToolCreatedChunk: MCPToolCreatedChunk = {
            type: ChunkType.MCP_TOOL_CREATED,
            tool_use_responses: toolUseResponses
          }
          controller.enqueue(mcpToolCreatedChunk)
        }
      }
    }
  })
}

export default ToolUseExtractionMiddleware
88
src/renderer/src/aiCore/middleware/index.ts
Normal file
@ -0,0 +1,88 @@
import { CompletionsMiddleware, MethodMiddleware } from './types'

// /**
//  * Wraps a provider instance with middlewares.
//  */
// export function wrapProviderWithMiddleware(
//   apiClientInstance: BaseApiClient,
//   middlewareConfig: MiddlewareConfig
// ): BaseApiClient {
//   console.log(`[wrapProviderWithMiddleware] Wrapping provider: ${apiClientInstance.provider?.id}`)
//   console.log(`[wrapProviderWithMiddleware] Middleware config:`, {
//     completions: middlewareConfig.completions?.length || 0,
//     methods: Object.keys(middlewareConfig.methods || {}).length
//   })

//   // Cache for already wrapped methods to avoid re-wrapping on every access.
//   const wrappedMethodsCache = new Map<string, (...args: any[]) => Promise<any>>()

//   const proxy = new Proxy(apiClientInstance, {
//     get(target, propKey, receiver) {
//       const methodName = typeof propKey === 'string' ? propKey : undefined

//       if (!methodName) {
//         return Reflect.get(target, propKey, receiver)
//       }

//       if (wrappedMethodsCache.has(methodName)) {
//         console.log(`[wrapProviderWithMiddleware] Using cached wrapped method: ${methodName}`)
//         return wrappedMethodsCache.get(methodName)
//       }

//       const originalMethod = Reflect.get(target, propKey, receiver)

//       // If the property is not a function, return it directly.
//       if (typeof originalMethod !== 'function') {
//         return originalMethod
//       }

//       let wrappedMethod: ((...args: any[]) => Promise<any>) | undefined

//       // Handle completions method
//       if (methodName === 'completions' && middlewareConfig.completions?.length) {
//         console.log(
//           `[wrapProviderWithMiddleware] Wrapping completions method with ${middlewareConfig.completions.length} middlewares`
//         )
//         const completionsOriginalMethod = originalMethod as (params: CompletionsParams) => Promise<any>
//         wrappedMethod = applyCompletionsMiddlewares(target, completionsOriginalMethod, middlewareConfig.completions)
//       }
//       // Handle other methods
//       else {
//         const methodMiddlewares = middlewareConfig.methods?.[methodName]
//         if (methodMiddlewares?.length) {
//           console.log(
//             `[wrapProviderWithMiddleware] Wrapping method ${methodName} with ${methodMiddlewares.length} middlewares`
//           )
//           const genericOriginalMethod = originalMethod as (...args: any[]) => Promise<any>
//           wrappedMethod = applyMethodMiddlewares(target, methodName, genericOriginalMethod, methodMiddlewares)
//         }
//       }

//       if (wrappedMethod) {
//         console.log(`[wrapProviderWithMiddleware] Successfully wrapped method: ${methodName}`)
//         wrappedMethodsCache.set(methodName, wrappedMethod)
//         return wrappedMethod
//       }

//       // If no middlewares are configured for this method, return the original method bound to the target.
//       console.log(`[wrapProviderWithMiddleware] No middlewares for method ${methodName}, returning original`)
//       return originalMethod.bind(target)
//     }
//   })
//   return proxy as BaseApiClient
// }

// Export types for external use
export type { CompletionsMiddleware, MethodMiddleware }

// Export MiddlewareBuilder related types and classes
export {
  CompletionsMiddlewareBuilder,
  createCompletionsBuilder,
  createMethodBuilder,
  MethodMiddlewareBuilder,
  MiddlewareBuilder,
  type MiddlewareExecutor,
  type NamedMiddleware
} from './builder'
149
src/renderer/src/aiCore/middleware/register.ts
Normal file
@ -0,0 +1,149 @@
import * as AbortHandlerModule from './common/AbortHandlerMiddleware'
import * as ErrorHandlerModule from './common/ErrorHandlerMiddleware'
import * as FinalChunkConsumerModule from './common/FinalChunkConsumerMiddleware'
import * as LoggingModule from './common/LoggingMiddleware'
import * as McpToolChunkModule from './core/McpToolChunkMiddleware'
import * as RawStreamListenerModule from './core/RawStreamListenerMiddleware'
import * as ResponseTransformModule from './core/ResponseTransformMiddleware'
// import * as SdkCallModule from './core/SdkCallMiddleware'
import * as StreamAdapterModule from './core/StreamAdapterMiddleware'
import * as TextChunkModule from './core/TextChunkMiddleware'
import * as ThinkChunkModule from './core/ThinkChunkMiddleware'
import * as TransformCoreToSdkParamsModule from './core/TransformCoreToSdkParamsMiddleware'
import * as WebSearchModule from './core/WebSearchMiddleware'
import * as ImageGenerationModule from './feat/ImageGenerationMiddleware'
import * as ThinkingTagExtractionModule from './feat/ThinkingTagExtractionMiddleware'
import * as ToolUseExtractionMiddleware from './feat/ToolUseExtractionMiddleware'

/**
 * Middleware registry - central access point for all available middlewares
 * Note: middleware files that have not yet exported MIDDLEWARE_NAME will show linter errors; this is expected
 */
export const MiddlewareRegistry = {
  [ErrorHandlerModule.MIDDLEWARE_NAME]: {
    name: ErrorHandlerModule.MIDDLEWARE_NAME,
    middleware: ErrorHandlerModule.ErrorHandlerMiddleware
  },
  // Common middlewares
  [AbortHandlerModule.MIDDLEWARE_NAME]: {
    name: AbortHandlerModule.MIDDLEWARE_NAME,
    middleware: AbortHandlerModule.AbortHandlerMiddleware
  },
  [FinalChunkConsumerModule.MIDDLEWARE_NAME]: {
    name: FinalChunkConsumerModule.MIDDLEWARE_NAME,
    middleware: FinalChunkConsumerModule.default
  },

  // Core pipeline middlewares
  [TransformCoreToSdkParamsModule.MIDDLEWARE_NAME]: {
    name: TransformCoreToSdkParamsModule.MIDDLEWARE_NAME,
    middleware: TransformCoreToSdkParamsModule.TransformCoreToSdkParamsMiddleware
  },
  // [SdkCallModule.MIDDLEWARE_NAME]: {
  //   name: SdkCallModule.MIDDLEWARE_NAME,
  //   middleware: SdkCallModule.SdkCallMiddleware
  // },
  [StreamAdapterModule.MIDDLEWARE_NAME]: {
    name: StreamAdapterModule.MIDDLEWARE_NAME,
    middleware: StreamAdapterModule.StreamAdapterMiddleware
  },
  [RawStreamListenerModule.MIDDLEWARE_NAME]: {
    name: RawStreamListenerModule.MIDDLEWARE_NAME,
    middleware: RawStreamListenerModule.RawStreamListenerMiddleware
  },
  [ResponseTransformModule.MIDDLEWARE_NAME]: {
    name: ResponseTransformModule.MIDDLEWARE_NAME,
    middleware: ResponseTransformModule.ResponseTransformMiddleware
  },

  // Feature middlewares
  [ThinkingTagExtractionModule.MIDDLEWARE_NAME]: {
    name: ThinkingTagExtractionModule.MIDDLEWARE_NAME,
    middleware: ThinkingTagExtractionModule.ThinkingTagExtractionMiddleware
  },
  [ToolUseExtractionMiddleware.MIDDLEWARE_NAME]: {
    name: ToolUseExtractionMiddleware.MIDDLEWARE_NAME,
    middleware: ToolUseExtractionMiddleware.ToolUseExtractionMiddleware
  },
  [ThinkChunkModule.MIDDLEWARE_NAME]: {
    name: ThinkChunkModule.MIDDLEWARE_NAME,
    middleware: ThinkChunkModule.ThinkChunkMiddleware
  },
  [McpToolChunkModule.MIDDLEWARE_NAME]: {
    name: McpToolChunkModule.MIDDLEWARE_NAME,
    middleware: McpToolChunkModule.McpToolChunkMiddleware
  },
  [WebSearchModule.MIDDLEWARE_NAME]: {
    name: WebSearchModule.MIDDLEWARE_NAME,
    middleware: WebSearchModule.WebSearchMiddleware
  },
  [TextChunkModule.MIDDLEWARE_NAME]: {
    name: TextChunkModule.MIDDLEWARE_NAME,
    middleware: TextChunkModule.TextChunkMiddleware
  },
  [ImageGenerationModule.MIDDLEWARE_NAME]: {
    name: ImageGenerationModule.MIDDLEWARE_NAME,
    middleware: ImageGenerationModule.ImageGenerationMiddleware
  }
} as const

/**
 * Gets a middleware by name
 * @param name - middleware name
 * @returns the matching middleware entry
 */
export function getMiddleware(name: string) {
  return MiddlewareRegistry[name]
}

/**
 * Gets the names of all registered middlewares
 * @returns list of middleware names
 */
export function getRegisteredMiddlewareNames(): string[] {
  return Object.keys(MiddlewareRegistry)
}

/**
 * Default completions middleware configuration - NamedMiddleware format for the MiddlewareBuilder
 */
export const DefaultCompletionsNamedMiddlewares = [
  MiddlewareRegistry[FinalChunkConsumerModule.MIDDLEWARE_NAME], // Final consumer
  MiddlewareRegistry[ErrorHandlerModule.MIDDLEWARE_NAME], // Error handling
  MiddlewareRegistry[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME], // Parameter transformation
  MiddlewareRegistry[AbortHandlerModule.MIDDLEWARE_NAME], // Abort handling
  MiddlewareRegistry[McpToolChunkModule.MIDDLEWARE_NAME], // Tool handling
  MiddlewareRegistry[TextChunkModule.MIDDLEWARE_NAME], // Text handling
  MiddlewareRegistry[WebSearchModule.MIDDLEWARE_NAME], // Web search handling
  MiddlewareRegistry[ToolUseExtractionMiddleware.MIDDLEWARE_NAME], // Tool-use extraction
  MiddlewareRegistry[ThinkingTagExtractionModule.MIDDLEWARE_NAME], // Thinking tag extraction (provider-specific)
  MiddlewareRegistry[ThinkChunkModule.MIDDLEWARE_NAME], // Thinking handling (generic SDK)
  MiddlewareRegistry[ResponseTransformModule.MIDDLEWARE_NAME], // Response transformation
  MiddlewareRegistry[StreamAdapterModule.MIDDLEWARE_NAME], // Stream adapter
  MiddlewareRegistry[RawStreamListenerModule.MIDDLEWARE_NAME] // Raw stream listener
]

/**
 * Default middlewares for generic methods - e.g. translate, summaries
 */
export const DefaultMethodMiddlewares = {
  translate: [LoggingModule.createGenericLoggingMiddleware()],
  summaries: [LoggingModule.createGenericLoggingMiddleware()]
}

/**
 * Export all middleware modules for external use
 */
export {
  AbortHandlerModule,
  FinalChunkConsumerModule,
  LoggingModule,
  McpToolChunkModule,
  ResponseTransformModule,
  StreamAdapterModule,
  TextChunkModule,
  ThinkChunkModule,
  ThinkingTagExtractionModule,
  TransformCoreToSdkParamsModule,
  WebSearchModule
}
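The registry is consumed through the builder exported from `./builder`. A minimal sketch of wiring the default chain follows; the builder's exact API is not shown in this diff, so the `build()` call is an assumption:

```typescript
import { createCompletionsBuilder } from './index'
import { DefaultCompletionsNamedMiddlewares } from './register'

// Assumed usage: the builder takes NamedMiddleware entries and yields the ordered chain
const chain = createCompletionsBuilder([...DefaultCompletionsNamedMiddlewares]).build()
```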
77
src/renderer/src/aiCore/middleware/schemas.ts
Normal file
@ -0,0 +1,77 @@
import { Assistant, MCPTool } from '@renderer/types'
import { Chunk } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import { SdkRawChunk, SdkRawOutput } from '@renderer/types/sdk'

import { ProcessingState } from './types'

// ============================================================================
// Core Request Types
// ============================================================================

/**
 * Standardized internal core request structure for unified handling across all AI providers.
 * This is the standard format produced by application-layer parameter conversion.
 */
export interface CompletionsParams {
  /**
   * The business scenario of the call, used by middlewares to decide whether to run
   * 'chat': main conversation flow
   * 'translate': translation
   * 'summary': summarization
   * 'search': search summary
   * 'generate': generation
   * 'check': API check
   */
  callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check'

  // Core conversation data
  messages: Message[] | string // union type makes it easy to check for emptiness

  assistant: Assistant // the assistant is the basic unit
  // model: Model

  onChunk?: (chunk: Chunk) => void
  onResponse?: (text: string, isComplete: boolean) => void

  // Error handling
  onError?: (error: Error) => void
  shouldThrow?: boolean

  // Tools
  mcpTools?: MCPTool[]

  // Generation parameters
  temperature?: number
  topP?: number
  maxTokens?: number

  // Feature toggles
  streamOutput: boolean
  enableWebSearch?: boolean
  enableReasoning?: boolean
  enableGenerateImage?: boolean

  // Context control
  contextCount?: number

  _internal?: ProcessingState
}

export interface CompletionsResult {
  rawOutput?: SdkRawOutput
  stream?: ReadableStream<SdkRawChunk> | ReadableStream<Chunk> | AsyncIterable<Chunk>
  controller?: AbortController

  getText: () => string
}

// ============================================================================
// Generic Chunk Types
// ============================================================================

/**
 * Generic chunk type.
 * Reuses the existing Chunk type; this is the standardized chunk format every AI provider should output.
 */
export type GenericChunk = Chunk
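For orientation, a minimal `CompletionsParams` value for a non-streaming translate call might look like this (the `assistant` value is an assumption; a real one comes from the application store):

```typescript
const params: CompletionsParams = {
  callType: 'translate',
  messages: 'Translate this sentence into English.',
  assistant: someAssistant, // assumed: a full Assistant with a model attached
  streamOutput: false,
  onChunk: (chunk) => console.log(chunk.type)
}
```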
166
src/renderer/src/aiCore/middleware/types.ts
Normal file
@ -0,0 +1,166 @@
import { MCPToolResponse, Metrics, Usage, WebSearchResponse } from '@renderer/types'
import { Chunk, ErrorChunk } from '@renderer/types/chunk'
import {
  SdkInstance,
  SdkMessageParam,
  SdkParams,
  SdkRawChunk,
  SdkRawOutput,
  SdkTool,
  SdkToolCall
} from '@renderer/types/sdk'

import { BaseApiClient } from '../clients'
import { CompletionsParams, CompletionsResult } from './schemas'

/**
 * Symbol to uniquely identify middleware context objects.
 */
export const MIDDLEWARE_CONTEXT_SYMBOL = Symbol.for('AiProviderMiddlewareContext')

/**
 * Defines the structure for the onChunk callback function.
 */
export type OnChunkFunction = (chunk: Chunk | ErrorChunk) => void

/**
 * Base context that carries information about the current method call.
 */
export interface BaseContext {
  [MIDDLEWARE_CONTEXT_SYMBOL]: true
  methodName: string
  originalArgs: Readonly<any[]>
}

/**
 * Processing state shared between middlewares.
 */
export interface ProcessingState<
  TParams extends SdkParams = SdkParams,
  TMessageParam extends SdkMessageParam = SdkMessageParam,
  TToolCall extends SdkToolCall = SdkToolCall
> {
  sdkPayload?: TParams
  newReqMessages?: TMessageParam[]
  observer?: {
    usage?: Usage
    metrics?: Metrics
  }
  toolProcessingState?: {
    pendingToolCalls?: Array<TToolCall>
    executingToolCalls?: Array<{
      sdkToolCall: TToolCall
      mcpToolResponse: MCPToolResponse
    }>
    output?: SdkRawOutput | string
    isRecursiveCall?: boolean
    recursionDepth?: number
  }
  webSearchState?: {
    results?: WebSearchResponse
  }
  flowControl?: {
    abortController?: AbortController
    abortSignal?: AbortSignal
    cleanup?: () => void
  }
  enhancedDispatch?: (context: CompletionsContext, params: CompletionsParams) => Promise<CompletionsResult>
  customState?: Record<string, any>
}

/**
 * Extended context for completions method.
 */
export interface CompletionsContext<
  TSdkParams extends SdkParams = SdkParams,
  TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
  TSdkToolCall extends SdkToolCall = SdkToolCall,
  TSdkInstance extends SdkInstance = SdkInstance,
  TRawOutput extends SdkRawOutput = SdkRawOutput,
  TRawChunk extends SdkRawChunk = SdkRawChunk,
  TSdkSpecificTool extends SdkTool = SdkTool
> extends BaseContext {
  readonly methodName: 'completions' // method name is fixed to 'completions'

  apiClientInstance: BaseApiClient<
    TSdkInstance,
    TSdkParams,
    TRawOutput,
    TRawChunk,
    TSdkMessageParam,
    TSdkToolCall,
    TSdkSpecificTool
  >

  // --- Mutable internal state for the duration of the middleware chain ---
  _internal: ProcessingState<TSdkParams, TSdkMessageParam, TSdkToolCall> // holds all mutable processing state
}

export interface MiddlewareAPI<Ctx extends BaseContext = BaseContext, Args extends any[] = any[]> {
  getContext: () => Ctx // Function to get the current context
  getOriginalArgs: () => Args // Function to get the original arguments of the method call
}

/**
 * Base middleware type.
 */
export type Middleware<TContext extends BaseContext> = (
  api: MiddlewareAPI<TContext>
) => (
  next: (context: TContext, args: any[]) => Promise<unknown>
) => (context: TContext, args: any[]) => Promise<unknown>

export type MethodMiddleware = Middleware<BaseContext>

/**
 * Completions middleware type.
 */
export type CompletionsMiddleware<
  TSdkParams extends SdkParams = SdkParams,
  TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
  TSdkToolCall extends SdkToolCall = SdkToolCall,
  TSdkInstance extends SdkInstance = SdkInstance,
  TRawOutput extends SdkRawOutput = SdkRawOutput,
  TRawChunk extends SdkRawChunk = SdkRawChunk,
  TSdkSpecificTool extends SdkTool = SdkTool
> = (
  api: MiddlewareAPI<
    CompletionsContext<
      TSdkParams,
      TSdkMessageParam,
      TSdkToolCall,
      TSdkInstance,
      TRawOutput,
      TRawChunk,
      TSdkSpecificTool
    >,
    [CompletionsParams]
  >
) => (
  next: (
    context: CompletionsContext<
      TSdkParams,
      TSdkMessageParam,
      TSdkToolCall,
      TSdkInstance,
      TRawOutput,
      TRawChunk,
      TSdkSpecificTool
    >,
    params: CompletionsParams
  ) => Promise<CompletionsResult>
) => (
  context: CompletionsContext<
    TSdkParams,
    TSdkMessageParam,
    TSdkToolCall,
    TSdkInstance,
    TRawOutput,
    TRawChunk,
    TSdkSpecificTool
  >,
  params: CompletionsParams
) => Promise<CompletionsResult>

// Re-export for convenience
export type { Chunk as OnChunkArg } from '@renderer/types/chunk'
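Every middleware in this module ultimately matches the three-layer shape above. The smallest valid `CompletionsMiddleware` is a pass-through that simply delegates to `next`:

```typescript
// A no-op middleware satisfying the default generic parameters
export const NoopMiddleware: CompletionsMiddleware = () => (next) => async (context, params) => {
  return next(context, params)
}
```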
57
src/renderer/src/aiCore/middleware/utils.ts
Normal file
@ -0,0 +1,57 @@
import { ChunkType, ErrorChunk } from '@renderer/types/chunk'

/**
 * Creates an ErrorChunk object with a standardized structure.
 * @param error The error object or message.
 * @param chunkType The type of chunk, defaults to ChunkType.ERROR.
 * @returns An ErrorChunk object.
 */
export function createErrorChunk(error: any, chunkType: ChunkType = ChunkType.ERROR): ErrorChunk {
  let errorDetails: Record<string, any> = {}

  if (error instanceof Error) {
    errorDetails = {
      message: error.message,
      name: error.name,
      stack: error.stack
    }
  } else if (typeof error === 'string') {
    errorDetails = { message: error }
  } else if (typeof error === 'object' && error !== null) {
    errorDetails = Object.getOwnPropertyNames(error).reduce(
      (acc, key) => {
        acc[key] = error[key]
        return acc
      },
      {} as Record<string, any>
    )
    if (!errorDetails.message && error.toString && typeof error.toString === 'function') {
      const errMsg = error.toString()
      if (errMsg !== '[object Object]') {
        errorDetails.message = errMsg
      }
    }
  }

  return {
    type: chunkType,
    error: errorDetails
  } as ErrorChunk
}

// Helper to capitalize method names for hook construction
export function capitalize(str: string): string {
  if (!str) return ''
  return str.charAt(0).toUpperCase() + str.slice(1)
}

/**
 * Checks whether an object implements the AsyncIterable interface
 */
export function isAsyncIterable<T = unknown>(obj: unknown): obj is AsyncIterable<T> {
  return (
    obj !== null &&
    typeof obj === 'object' &&
    typeof (obj as Record<symbol, unknown>)[Symbol.asyncIterator] === 'function'
  )
}
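Taken together, a typical use of `createErrorChunk` inside a middleware's catch block looks like this (a sketch):

```typescript
import { ChunkType } from '@renderer/types/chunk'

try {
  throw new Error('SDK request failed')
} catch (e) {
  const chunk = createErrorChunk(e) // { type: ChunkType.ERROR, error: { message, name, stack } }
  console.log(chunk.type === ChunkType.ERROR, chunk.error.message) // true 'SDK request failed'
}
```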
BIN
src/renderer/src/assets/images/models/gpt_image_1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
@ -55,6 +55,7 @@ import {
  default as ChatGptModelLogoDakr,
  default as ChatGPTo1ModelLogoDark
} from '@renderer/assets/images/models/gpt_dark.png'
import ChatGPTImageModelLogo from '@renderer/assets/images/models/gpt_image_1.png'
import ChatGPTo1ModelLogo from '@renderer/assets/images/models/gpt_o1.png'
import GrokModelLogo from '@renderer/assets/images/models/grok.png'
import GrokModelLogoDark from '@renderer/assets/images/models/grok_dark.png'
@ -143,7 +144,7 @@ import YiModelLogoDark from '@renderer/assets/images/models/yi_dark.png'
import YoudaoLogo from '@renderer/assets/images/providers/netease-youdao.svg'
import NomicLogo from '@renderer/assets/images/providers/nomic.png'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { Assistant, Model } from '@renderer/types'
import { Model } from '@renderer/types'
import OpenAI from 'openai'

import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from './prompts'
@ -181,7 +182,8 @@ const visionAllowedModels = [
  'o4(?:-[\\w-]+)?',
  'deepseek-vl(?:[\\w-]+)?',
  'kimi-latest',
  'gemma-3(?:-[\\w-]+)'
  'gemma-3(?:-[\\w-]+)',
  'doubao-1.6-seed(?:-[\\w-]+)'
]

const visionExcludedModels = [
@ -199,6 +201,11 @@ export const VISION_REGEX = new RegExp(
  'i'
)

// For middleware to identify models that must use the dedicated Image API
export const DEDICATED_IMAGE_MODELS = ['grok-2-image', 'dall-e-3', 'dall-e-2', 'gpt-image-1']
export const isDedicatedImageGenerationModel = (model: Model): boolean =>
  DEDICATED_IMAGE_MODELS.filter((m) => model.id.includes(m)).length > 0

// Text to image models
export const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus/i

@ -286,6 +293,7 @@ export function getModelLogo(modelId: string) {
    o1: isLight ? ChatGPTo1ModelLogo : ChatGPTo1ModelLogoDark,
    o3: isLight ? ChatGPTo1ModelLogo : ChatGPTo1ModelLogoDark,
    o4: isLight ? ChatGPTo1ModelLogo : ChatGPTo1ModelLogoDark,
    'gpt-image': ChatGPTImageModelLogo,
    'gpt-3': isLight ? ChatGPT35ModelLogo : ChatGPT35ModelLogoDark,
    'gpt-4': isLight ? ChatGPT4ModelLogo : ChatGPT4ModelLogoDark,
    gpts: isLight ? ChatGPT4ModelLogo : ChatGPT4ModelLogoDark,
@ -307,6 +315,7 @@ export function getModelLogo(modelId: string) {
    mistral: isLight ? MistralModelLogo : MistralModelLogoDark,
    codestral: CodestralModelLogo,
    ministral: isLight ? MistralModelLogo : MistralModelLogoDark,
    magistral: isLight ? MistralModelLogo : MistralModelLogoDark,
    moonshot: isLight ? MoonshotModelLogo : MoonshotModelLogoDark,
    kimi: isLight ? MoonshotModelLogo : MoonshotModelLogoDark,
    phi: isLight ? MicrosoftModelLogo : MicrosoftModelLogoDark,
@ -2246,14 +2255,24 @@ export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [
  'stabilityai/stable-diffusion-xl-base-1.0'
]

export const SUPPORTED_DISABLE_GENERATION_MODELS = [
  'gemini-2.0-flash-exp',
  'gpt-4o',
  'gpt-4o-mini',
  'gpt-4.1',
  'gpt-4.1-mini',
  'gpt-4.1-nano',
  'o3'
]

export const GENERATE_IMAGE_MODELS = [
  'gemini-2.0-flash-exp-image-generation',
  'gemini-2.0-flash-preview-image-generation',
  'gemini-2.0-flash-exp',
  'grok-2-image-1212',
  'grok-2-image',
  'grok-2-image-latest',
  'gpt-image-1'
  'gpt-image-1',
  ...SUPPORTED_DISABLE_GENERATION_MODELS
]

export const GEMINI_SEARCH_MODELS = [
@ -2362,10 +2381,32 @@ export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean {
  )
}

export function isOpenAIWebSearch(model: Model): boolean {
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
  if (!model) {
    return false
  }

  return (
    model.id.includes('gpt-4o-search-preview') ||
    model.id.includes('gpt-4o-mini-search-preview') ||
    model.id.includes('o1-mini') ||
    model.id.includes('o1-preview')
  )
}

export function isOpenAIWebSearchChatCompletionOnlyModel(model: Model): boolean {
  return model.id.includes('gpt-4o-search-preview') || model.id.includes('gpt-4o-mini-search-preview')
}

export function isOpenAIWebSearchModel(model: Model): boolean {
  return (
    model.id.includes('gpt-4o-search-preview') ||
    model.id.includes('gpt-4o-mini-search-preview') ||
    (model.id.includes('gpt-4.1') && !model.id.includes('gpt-4.1-nano')) ||
    (model.id.includes('gpt-4o') && !model.id.includes('gpt-4o-image'))
  )
}

export function isSupportedThinkingTokenModel(model?: Model): boolean {
  if (!model) {
    return false
@ -2374,7 +2415,8 @@ export function isSupportedThinkingTokenModel(model?: Model): boolean {
  return (
    isSupportedThinkingTokenGeminiModel(model) ||
    isSupportedThinkingTokenQwenModel(model) ||
    isSupportedThinkingTokenClaudeModel(model)
    isSupportedThinkingTokenClaudeModel(model) ||
    isSupportedThinkingTokenDoubaoModel(model)
  )
}

@ -2456,6 +2498,14 @@ export function isSupportedThinkingTokenQwenModel(model?: Model): boolean {
  )
}

export function isSupportedThinkingTokenDoubaoModel(model?: Model): boolean {
  if (!model) {
    return false
  }

  return DOUBAO_THINKING_MODEL_REGEX.test(model.id)
}

export function isClaudeReasoningModel(model?: Model): boolean {
  if (!model) {
    return false
@ -2476,7 +2526,12 @@ export function isReasoningModel(model?: Model): boolean {
  }

  if (model.provider === 'doubao') {
    return REASONING_REGEX.test(model.name) || model.type?.includes('reasoning') || false
    return (
      REASONING_REGEX.test(model.name) ||
      model.type?.includes('reasoning') ||
      isSupportedThinkingTokenDoubaoModel(model) ||
      false
    )
  }

  if (
@ -2485,7 +2540,8 @@ export function isReasoningModel(model?: Model): boolean {
    isGeminiReasoningModel(model) ||
    isQwenReasoningModel(model) ||
    isGrokReasoningModel(model) ||
    model.id.includes('glm-z1')
    model.id.includes('glm-z1') ||
    model.id.includes('magistral')
  ) {
    return true
  }
@ -2506,7 +2562,7 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean {
    return true
  }

  if (isOpenAIReasoningModel(model) || isOpenAIWebSearch(model)) {
  if (isOpenAIReasoningModel(model) || isOpenAIChatCompletionOnlyModel(model)) {
    return true
  }

@ -2536,17 +2592,13 @@ export function isWebSearchModel(model: Model): boolean {
    return false
  }

  // This check applies regardless of provider
  if (model.id.includes('claude')) {
    return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(model.id)
  }

  if (provider.type === 'openai-response') {
    if (
      isOpenAILLMModel(model) &&
      !isTextToImageModel(model) &&
      !isOpenAIReasoningModel(model) &&
      !GENERATE_IMAGE_MODELS.includes(model.id)
    ) {
    if (isOpenAIWebSearchModel(model)) {
      return true
    }

@ -2558,12 +2610,7 @@ export function isWebSearchModel(model: Model): boolean {
  }

  if (provider.id === 'aihubmix') {
    if (
      isOpenAILLMModel(model) &&
      !isTextToImageModel(model) &&
      !isOpenAIReasoningModel(model) &&
      !GENERATE_IMAGE_MODELS.includes(model.id)
    ) {
    if (isOpenAIWebSearchModel(model)) {
      return true
    }

@ -2572,7 +2619,7 @@ export function isWebSearchModel(model: Model): boolean {
  }

  if (provider?.type === 'openai') {
    if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearch(model)) {
    if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearchModel(model)) {
      return true
    }
  }
@ -2606,6 +2653,20 @@ export function isWebSearchModel(model: Model): boolean {
  return false
}

export function isOpenRouterBuiltInWebSearchModel(model: Model): boolean {
  if (!model) {
    return false
  }

  const provider = getProviderByModel(model)

  if (provider.id !== 'openrouter') {
    return false
  }

  return isOpenAIWebSearchModel(model) || model.id.includes('sonar')
}

export function isGenerateImageModel(model: Model): boolean {
  if (!model) {
    return false
@ -2628,56 +2689,60 @@ export function isGenerateImageModel(model: Model): boolean {
  return false
}

export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Record<string, any> {
  if (isWebSearchModel(model)) {
    if (assistant.enableWebSearch) {
      const webSearchTools = getWebSearchTools(model)
export function isSupportedDisableGenerationModel(model: Model): boolean {
  if (!model) {
    return false
  }

      if (model.provider === 'grok') {
        return {
          search_parameters: {
            mode: 'auto',
            return_citations: true,
            sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
          }
        }
      }
  return SUPPORTED_DISABLE_GENERATION_MODELS.includes(model.id)
}

      if (model.provider === 'hunyuan') {
        return { enable_enhancement: true, citation: true, search_info: true }
      }
export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boolean): Record<string, any> {
  if (!isEnableWebSearch) {
    return {}
  }

      if (model.provider === 'dashscope') {
        return {
          enable_search: true,
          search_options: {
            forced_search: true
          }
        }
      }
  const webSearchTools = getWebSearchTools(model)

      if (model.provider === 'openrouter') {
        return {
          plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
        }
      }

      if (isOpenAIWebSearch(model)) {
        return {
          web_search_options: {}
        }
      }

      return {
        tools: webSearchTools
      }
    } else {
      if (model.provider === 'hunyuan') {
        return { enable_enhancement: false }
  if (model.provider === 'grok') {
    return {
      search_parameters: {
        mode: 'auto',
        return_citations: true,
        sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
      }
    }
  }

  if (model.provider === 'hunyuan') {
    return { enable_enhancement: true, citation: true, search_info: true }
  }

  if (model.provider === 'dashscope') {
    return {
      enable_search: true,
      search_options: {
        forced_search: true
      }
    }
  }

  if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
    return {
      web_search_options: {}
    }
  }

  if (model.provider === 'openrouter') {
    return {
      plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
    }
  }

  return {
    tools: webSearchTools
  }

  return {}
}

@ -2758,3 +2823,16 @@ export const findTokenLimit = (modelId: string): { min: number; max: number } |
  }
  return undefined
}

// Regex for Doubao models that support thinking mode
export const DOUBAO_THINKING_MODEL_REGEX =
  /doubao-(?:1(\.|-5)-thinking-vision-pro|1(\.|-)5-thinking-pro-m|seed-1\.6|seed-1\.6-flash)(?:-[\w-]+)?/i

// Doubao models that support the 'auto' thinking setting
export const DOUBAO_THINKING_AUTO_MODEL_REGEX = /doubao-(?:1-5-thinking-pro-m|seed-1.6)(?:-[\w-]+)?/i

export function isDoubaoThinkingAutoModel(model: Model): boolean {
  return DOUBAO_THINKING_AUTO_MODEL_REGEX.test(model.id)
}

export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini-.*-flash.*$')

@ -2,6 +2,8 @@ import db from '@renderer/databases'
import i18n from '@renderer/i18n'
import { deleteMessageFiles } from '@renderer/services/MessagesService'
import store from '@renderer/store'
import { setNewlyRenamedTopics, setRenamingTopics } from '@renderer/store/runtime'
import { loadTopicMessagesThunk } from '@renderer/store/thunk/messageThunk'
import { selectTopicById, topicsActions } from '@renderer/store/topics'
import { Assistant, Topic } from '@renderer/types'
import { findMainTextBlocks } from '@renderer/utils/messageUtils/find'
@ -9,13 +11,6 @@ import { isEmpty } from 'lodash'

import { getStoreSetting } from './useSettings'

const renamingTopics = new Set<string>()

export function useTopic(topicId?: string) {
  if (!topicId) return undefined
  return selectTopicById(store.getState(), topicId)
}

export function getTopic(topicId: string) {
  return selectTopicById(store.getState(), topicId)
}
@ -26,13 +21,46 @@ export async function getTopicById(topicId: string) {
  return { ...topic, messages } as Topic
}

/**
 * Marks a topic as being renamed
 */
export const startTopicRenaming = (topicId: string) => {
  const currentIds = store.getState().runtime.chat.renamingTopics
  if (!currentIds.includes(topicId)) {
    store.dispatch(setRenamingTopics([...currentIds, topicId]))
  }
}

/**
 * Finishes renaming a topic
 */
export const finishTopicRenaming = (topicId: string) => {
  const state = store.getState()

  // 1. Remove from renamingTopics immediately
  const currentRenaming = state.runtime.chat.renamingTopics
  store.dispatch(setRenamingTopics(currentRenaming.filter((id) => id !== topicId)))

  // 2. Add to newlyRenamedTopics immediately
  const currentNewlyRenamed = state.runtime.chat.newlyRenamedTopics
  store.dispatch(setNewlyRenamedTopics([...currentNewlyRenamed, topicId]))

  // 3. Remove from newlyRenamedTopics after a delay
  setTimeout(() => {
    const current = store.getState().runtime.chat.newlyRenamedTopics
    store.dispatch(setNewlyRenamedTopics(current.filter((id) => id !== topicId)))
  }, 700)
}

const topicRenamingLocks = new Set<string>()

export const autoRenameTopic = async (assistant: Assistant, topicId: string) => {
  if (renamingTopics.has(topicId)) {
  if (topicRenamingLocks.has(topicId)) {
    return
  }

  try {
    renamingTopics.add(topicId)
    topicRenamingLocks.add(topicId)

    const topic = await getTopicById(topicId)
    const enableTopicNaming = getStoreSetting('enableTopicNaming')
@ -53,22 +81,34 @@ export const autoRenameTopic = async (assistant: Assistant, topicId: string) =>
        .join('\n\n')
        .substring(0, 50)
      if (topicName) {
        const data = { ...topic, name: topicName } as Topic
        store.dispatch(topicsActions.updateTopic({ assistantId: assistant.id, topic: data }))
        try {
          startTopicRenaming(topicId)

          const data = { ...topic, name: topicName } as Topic
          store.dispatch(topicsActions.updateTopic({ assistantId: assistant.id, topic: data }))
        } finally {
          finishTopicRenaming(topicId)
        }
      }
      return
    }

    if (topic && topic.name === i18n.t('chat.default.topic.name') && topic.messages.length >= 2) {
      const { fetchMessagesSummary } = await import('@renderer/services/ApiService')
      const summaryText = await fetchMessagesSummary({ messages: topic.messages, assistant })
      if (summaryText) {
        const data = { ...topic, name: summaryText }
        store.dispatch(topicsActions.updateTopic({ assistantId: assistant.id, topic: data }))
      try {
        startTopicRenaming(topicId)

        const { fetchMessagesSummary } = await import('@renderer/services/ApiService')
        const summaryText = await fetchMessagesSummary({ messages: topic.messages, assistant })
        if (summaryText) {
          const data = { ...topic, name: summaryText }
          store.dispatch(topicsActions.updateTopic({ assistantId: assistant.id, topic: data }))
        }
      } finally {
        finishTopicRenaming(topicId)
      }
    }
  } finally {
    renamingTopics.delete(topicId)
    topicRenamingLocks.delete(topicId)
  }
}

@ -83,9 +123,18 @@ export const TopicManager = {
    return await db.topics.toArray()
  },

  /**
   * Loads and returns the messages of the given topic
   */
  async getTopicMessages(id: string) {
    const topic = await TopicManager.getTopic(id)
    return topic ? topic.messages : []
    if (!topic) return []

    await store.dispatch(loadTopicMessagesThunk(id))

    // Fetch the updated topic
    const updatedTopic = await TopicManager.getTopic(id)
    return updatedTopic?.messages || []
  },

  async removeTopic(id: string) {

@ -1,118 +0,0 @@
// Modified from https://github.com/vercel/ai/blob/845080d80b8538bb9c7e527d2213acb5f33ac9c2/packages/ai/core/middleware/extract-reasoning-middleware.ts

import { getPotentialStartIndex } from '../utils/getPotentialIndex'

export interface ExtractReasoningMiddlewareOptions {
openingTag: string
closingTag: string
separator?: string
enableReasoning?: boolean
}

function escapeRegExp(str: string) {
return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
}

// Supports a generic T; defaults to T = { type: string; textDelta: string }
export function extractReasoningMiddleware<
T extends { type: string } & (
| { type: 'text-delta' | 'reasoning'; textDelta: string }
| { type: string } // other chunk types
) = { type: string; textDelta: string }
>({ openingTag, closingTag, separator = '\n', enableReasoning }: ExtractReasoningMiddlewareOptions) {
const openingTagEscaped = escapeRegExp(openingTag)
const closingTagEscaped = escapeRegExp(closingTag)

return {
wrapGenerate: async ({ doGenerate }: { doGenerate: () => Promise<{ text: string } & Record<string, any>> }) => {
const { text: rawText, ...rest } = await doGenerate()
if (rawText == null) {
return { text: rawText, ...rest }
}
const text = rawText
const regexp = new RegExp(`${openingTagEscaped}(.*?)${closingTagEscaped}`, 'gs')
const matches = Array.from(text.matchAll(regexp))
if (!matches.length) {
return { text, ...rest }
}
const reasoning = matches.map((match: RegExpMatchArray) => match[1]).join(separator)
let textWithoutReasoning = text
for (let i = matches.length - 1; i >= 0; i--) {
const match = matches[i] as RegExpMatchArray
const beforeMatch = textWithoutReasoning.slice(0, match.index as number)
const afterMatch = textWithoutReasoning.slice((match.index as number) + match[0].length)
textWithoutReasoning =
beforeMatch + (beforeMatch.length > 0 && afterMatch.length > 0 ? separator : '') + afterMatch
}
return { ...rest, text: textWithoutReasoning, reasoning }
},
wrapStream: async ({
doStream
}: {
doStream: () => Promise<{ stream: ReadableStream<T> } & Record<string, any>>
}) => {
const { stream, ...rest } = await doStream()
if (!enableReasoning) {
return {
stream,
...rest
}
}
let isFirstReasoning = true
let isFirstText = true
let afterSwitch = false
let isReasoning = false
let buffer = ''
return {
stream: stream.pipeThrough(
new TransformStream<T, T>({
transform: (chunk, controller) => {
if (chunk.type !== 'text-delta') {
controller.enqueue(chunk)
return
}
// textDelta only exists on text-delta/reasoning chunks
buffer += (chunk as { textDelta: string }).textDelta
function publish(text: string) {
if (text.length > 0) {
const prefix = afterSwitch && (isReasoning ? !isFirstReasoning : !isFirstText) ? separator : ''
controller.enqueue({
...chunk,
type: isReasoning ? 'reasoning' : 'text-delta',
textDelta: prefix + text
} as T)
afterSwitch = false
if (isReasoning) {
isFirstReasoning = false
} else {
isFirstText = false
}
}
}
while (true) {
const nextTag = isReasoning ? closingTag : openingTag
const startIndex = getPotentialStartIndex(buffer, nextTag)
if (startIndex == null) {
publish(buffer)
buffer = ''
break
}
publish(buffer.slice(0, startIndex))
const foundFullMatch = startIndex + nextTag.length <= buffer.length
if (foundFullMatch) {
buffer = buffer.slice(startIndex + nextTag.length)
isReasoning = !isReasoning
afterSwitch = true
} else {
buffer = buffer.slice(startIndex)
break
}
}
}
})
),
...rest
}
}
}
}
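The deleted middleware splits one text-delta stream into reasoning and text chunks by scanning for a tag pair, buffering partial tag matches across chunk boundaries via getPotentialStartIndex. A hedged usage sketch; the chunk type, tag strings, and providerStream are illustrative assumptions, only extractReasoningMiddleware itself comes from the file above:

```typescript
type Chunk = { type: string; textDelta: string }

const middleware = extractReasoningMiddleware<Chunk>({
  openingTag: '<think>',
  closingTag: '</think>',
  enableReasoning: true
})

// Wrap a provider stream: '<think>...</think>' spans are re-emitted as
// { type: 'reasoning' } chunks; everything else stays 'text-delta'.
declare const providerStream: ReadableStream<Chunk> // assumed upstream SDK stream
const { stream } = await middleware.wrapStream({
  doStream: async () => ({ stream: providerStream })
})
```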
@ -4,6 +4,7 @@ import TranslateButton from '@renderer/components/TranslateButton'
import Logger from '@renderer/config/logger'
import {
isGenerateImageModel,
isSupportedDisableGenerationModel,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenModel,
isVisionModel,
@ -718,7 +719,7 @@ const Inputbar: FC = () => {
if (!isGenerateImageModel(model) && assistant.enableGenerateImage) {
updateAssistant({ ...assistant, enableGenerateImage: false })
}
if (isGenerateImageModel(model) && !assistant.enableGenerateImage && model.id !== 'gemini-2.0-flash-exp') {
if (isGenerateImageModel(model) && !assistant.enableGenerateImage && !isSupportedDisableGenerationModel(model)) {
updateAssistant({ ...assistant, enableGenerateImage: true })
}
}, [assistant, model, updateAssistant])
@ -7,7 +7,9 @@ import {
} from '@renderer/components/Icons/SVGIcon'
import { useQuickPanel } from '@renderer/components/QuickPanel'
import {
isDoubaoThinkingAutoModel,
isSupportedReasoningEffortGrokModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenGeminiModel,
isSupportedThinkingTokenQwenModel
} from '@renderer/config/models'
@ -35,13 +37,14 @@ const MODEL_SUPPORTED_OPTIONS: Record<string, ThinkingOption[]> = {
default: ['off', 'low', 'medium', 'high'],
grok: ['off', 'low', 'high'],
gemini: ['off', 'low', 'medium', 'high', 'auto'],
qwen: ['off', 'low', 'medium', 'high']
qwen: ['off', 'low', 'medium', 'high'],
doubao: ['off', 'auto', 'high']
}

// Option fallback map: the substitute used when an option is not supported
const OPTION_FALLBACK: Record<ThinkingOption, ThinkingOption> = {
off: 'off',
low: 'low',
low: 'high',
medium: 'high', // medium -> high (for Grok models)
high: 'high',
auto: 'high' // auto -> high (for non-Gemini models)
@ -55,6 +58,7 @@ const ThinkingButton: FC<Props> = ({ ref, model, assistant, ToolbarButton }): Re
const isGrokModel = isSupportedReasoningEffortGrokModel(model)
const isGeminiModel = isSupportedThinkingTokenGeminiModel(model)
const isQwenModel = isSupportedThinkingTokenQwenModel(model)
const isDoubaoModel = isSupportedThinkingTokenDoubaoModel(model)

const currentReasoningEffort = useMemo(() => {
return assistant.settings?.reasoning_effort || 'off'
@ -65,13 +69,20 @@ const ThinkingButton: FC<Props> = ({ ref, model, assistant, ToolbarButton }): Re
if (isGeminiModel) return 'gemini'
if (isGrokModel) return 'grok'
if (isQwenModel) return 'qwen'
if (isDoubaoModel) return 'doubao'
return 'default'
}, [isGeminiModel, isGrokModel, isQwenModel])
}, [isGeminiModel, isGrokModel, isQwenModel, isDoubaoModel])

// Get the options supported by the current model
const supportedOptions = useMemo(() => {
if (modelType === 'doubao') {
if (isDoubaoThinkingAutoModel(model)) {
return ['off', 'auto', 'high'] as ThinkingOption[]
}
return ['off', 'high'] as ThinkingOption[]
}
return MODEL_SUPPORTED_OPTIONS[modelType]
}, [modelType])
}, [model, modelType])

// Check whether the current setting is compatible with the current model
useEffect(() => {
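The compatibility check referenced at the end of this hunk presumably maps a stored option through OPTION_FALLBACK when the active model does not support it. A sketch of that coercion under that assumption, using only the two tables defined above:

```typescript
// Sketch: coerce a saved option to one the current model supports.
// Unsupported options fall back via OPTION_FALLBACK (e.g. 'auto' -> 'high').
function coerceOption(option: ThinkingOption, supported: ThinkingOption[]): ThinkingOption {
  return supported.includes(option) ? option : OPTION_FALLBACK[option]
}

coerceOption('auto', MODEL_SUPPORTED_OPTIONS.grok) // -> 'high'
coerceOption('off', MODEL_SUPPORTED_OPTIONS.qwen) // -> 'off'
```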
@ -24,6 +24,7 @@ import remarkMath from 'remark-math'

import CodeBlock from './CodeBlock'
import Link from './Link'
import Table from './Table'

const ALLOWED_ELEMENTS =
/<(style|p|div|span|b|i|strong|em|ul|ol|li|table|tr|td|th|thead|tbody|h[1-6]|blockquote|pre|code|br|hr|svg|path|circle|rect|line|polyline|polygon|text|g|defs|title|desc|tspan|sub|sup)/i
@ -83,6 +84,7 @@ const Markdown: FC<Props> = ({ block }) => {
code: (props: any) => (
<CodeBlock {...props} id={getCodeBlockId(props?.node?.position?.start)} onSave={onSaveCodeBlock} />
),
table: (props: any) => <Table {...props} blockId={block.id} />,
img: (props: any) => <ImageViewer style={{ maxWidth: 500, maxHeight: 500 }} {...props} />,
pre: (props: any) => <pre style={{ overflow: 'visible' }} {...props} />,
p: (props) => {
@ -91,7 +93,7 @@ const Markdown: FC<Props> = ({ block }) => {
return <p {...props} />
}
} as Partial<Components>
}, [onSaveCodeBlock])
}, [onSaveCodeBlock, block.id])

const urlTransform = useCallback((value: string) => {
if (value.startsWith('data:image/png') || value.startsWith('data:image/jpeg')) return value
120
src/renderer/src/pages/home/Markdown/Table.tsx
Normal file
@ -0,0 +1,120 @@
import store from '@renderer/store'
import { messageBlocksSelectors } from '@renderer/store/messageBlock'
import { Tooltip } from 'antd'
import { Check, Copy } from 'lucide-react'
import React, { memo, useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'

interface Props {
children: React.ReactNode
node?: any
blockId?: string
}

/**
 * Custom Markdown table component that provides copy functionality.
 */
const Table: React.FC<Props> = ({ children, node, blockId }) => {
const { t } = useTranslation()
const [copied, setCopied] = useState(false)

const handleCopyTable = useCallback(() => {
const tableMarkdown = extractTableMarkdown(blockId ?? '', node?.position)
if (!tableMarkdown) return

navigator.clipboard
.writeText(tableMarkdown)
.then(() => {
setCopied(true)
setTimeout(() => setCopied(false), 2000)
})
.catch((error) => {
window.message?.error({ content: `${t('message.copy.failed')}: ${error}`, key: 'copy-table-error' })
})
}, [node, blockId, t])

return (
<TableWrapper className="table-wrapper">
<table>{children}</table>
<ToolbarWrapper className="table-toolbar">
<Tooltip title={t('common.copy')} mouseEnterDelay={0.8}>
<ToolButton role="button" aria-label={t('common.copy')} onClick={handleCopyTable}>
{copied ? (
<Check size={14} style={{ color: 'var(--color-primary)' }} data-testid="check-icon" />
) : (
<Copy size={14} data-testid="copy-icon" />
)}
</ToolButton>
</Tooltip>
</ToolbarWrapper>
</TableWrapper>
)
}

/**
 * Extract the table's source from the original Markdown content.
 * @param blockId The message block ID
 * @param position The position info of the table node
 * @returns The table's Markdown source
 */
export function extractTableMarkdown(blockId: string, position: any): string {
if (!position || !blockId) return ''

const block = messageBlocksSelectors.selectById(store.getState(), blockId)

if (!block || !('content' in block) || typeof block.content !== 'string') return ''

const { start, end } = position
const lines = block.content.split('\n')

// Extract the table's lines (line numbers are 1-based, array indices 0-based)
const tableLines = lines.slice(start.line - 1, end.line)
return tableLines.join('\n').trim()
}

const TableWrapper = styled.div`
position: relative;

.table-toolbar {
border-radius: 4px;
opacity: 0;
transition: opacity 0.2s ease;
transform: translateZ(0);
will-change: opacity;
}
&:hover {
.table-toolbar {
opacity: 1;
}
}
`

const ToolbarWrapper = styled.div`
position: absolute;
top: 8px;
right: 8px;
z-index: 10;
`

const ToolButton = styled.div`
display: flex;
align-items: center;
justify-content: center;
width: 24px;
height: 24px;
border-radius: 4px;
cursor: pointer;
user-select: none;
transition: all 0.2s ease;
opacity: 1;
color: var(--color-text-3);
background-color: var(--color-background-mute);
will-change: background-color, opacity;

&:hover {
background-color: var(--color-background-soft);
}
`

export default memo(Table)
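extractTableMarkdown works because remark keeps each node's source position: start.line and end.line are 1-based and inclusive, so slice(start.line - 1, end.line) recovers exactly the table's lines from the block content. A small worked example under those assumptions:

```typescript
// How a remark position maps back to source lines (1-based, inclusive).
const content = 'intro\n| a | b |\n|---|---|\n| 1 | 2 |\noutro'
const position = { start: { line: 2 }, end: { line: 4 } } // as react-markdown reports it

const lines = content.split('\n')
console.log(lines.slice(position.start.line - 1, position.end.line).join('\n'))
// | a | b |
// |---|---|
// | 1 | 2 |
```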
@ -78,6 +78,18 @@ vi.mock('../Link', () => ({
)
}))

vi.mock('../Table', () => ({
__esModule: true,
default: ({ children, blockId }: any) => (
<div data-testid="table-component" data-block-id={blockId}>
<table>{children}</table>
<button type="button" data-testid="copy-table-button">
Copy Table
</button>
</div>
)
}))

vi.mock('@renderer/components/MarkdownShadowDOMRenderer', () => ({
__esModule: true,
default: ({ children }: any) => <div data-testid="shadow-dom">{children}</div>
@ -104,6 +116,11 @@ vi.mock('react-markdown', () => ({
{components.code({ children: 'test code', node: { position: { start: { line: 1 } } } })}
</div>
)}
{components?.table && (
<div data-testid="has-table-component">
{components.table({ children: 'test table', node: { position: { start: { line: 1 } } } })}
</div>
)}
{components?.img && <span data-testid="has-img-component">img</span>}
{components?.style && <span data-testid="has-style-component">style</span>}
</div>
@ -300,6 +317,16 @@ describe('Markdown', () => {
})
})

it('should integrate Table component with copy functionality', () => {
const block = createMainTextBlock({ id: 'test-block-456' })
render(<Markdown block={block} />)

expect(screen.getByTestId('has-table-component')).toBeInTheDocument()

const tableComponent = screen.getByTestId('table-component')
expect(tableComponent).toHaveAttribute('data-block-id', 'test-block-456')
})

it('should integrate ImagePreview component', () => {
render(<Markdown block={createMainTextBlock()} />)
316
src/renderer/src/pages/home/Markdown/__tests__/Table.test.tsx
Normal file
@ -0,0 +1,316 @@
import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest'

import Table, { extractTableMarkdown } from '../Table'

const mocks = vi.hoisted(() => {
return {
store: {
getState: vi.fn()
},
messageBlocksSelectors: {
selectById: vi.fn()
},
windowMessage: {
error: vi.fn()
}
}
})

// Mock dependencies
vi.mock('@renderer/store', () => ({
__esModule: true,
default: mocks.store
}))

vi.mock('@renderer/store/messageBlock', () => ({
messageBlocksSelectors: mocks.messageBlocksSelectors
}))

vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key
})
}))

vi.mock('antd', () => ({
Tooltip: ({ children, title }: any) => (
<div data-testid="tooltip" title={title}>
{children}
</div>
)
}))

Object.assign(window, {
message: mocks.windowMessage
})

describe('Table', () => {
beforeAll(() => {
vi.stubGlobal('jest', {
advanceTimersByTime: vi.advanceTimersByTime.bind(vi)
})
})

beforeEach(() => {
vi.clearAllMocks()
vi.useFakeTimers()
})

afterEach(() => {
vi.restoreAllMocks()
vi.runOnlyPendingTimers()
vi.useRealTimers()
})

afterAll(() => {
vi.unstubAllGlobals()
})

// https://testing-library.com/docs/user-event/clipboard/
const user = userEvent.setup({
advanceTimers: vi.advanceTimersByTime.bind(vi),
writeToClipboard: true
})

// Test data factories
const createMockBlock = (content: string = defaultTableContent) => ({
id: 'test-block-1',
content
})

const createTablePosition = (startLine = 1, endLine = 3) => ({
start: { line: startLine },
end: { line: endLine }
})

const defaultTableContent = `| Header 1 | Header 2 |
|----------|----------|
| Cell 1 | Cell 2 |`

const defaultProps = {
children: (
<tbody>
<tr>
<td>Cell 1</td>
<td>Cell 2</td>
</tr>
</tbody>
),
blockId: 'test-block-1',
node: { position: createTablePosition() }
}

const getCopyButton = () => screen.getByRole('button', { name: /common\.copy/i })
const getCopyIcon = () => screen.getByTestId('copy-icon')
const getCheckIcon = () => screen.getByTestId('check-icon')
const queryCheckIcon = () => screen.queryByTestId('check-icon')
const queryCopyIcon = () => screen.queryByTestId('copy-icon')

describe('rendering', () => {
it('should render table with children and toolbar', () => {
render(<Table {...defaultProps} />)

expect(screen.getByRole('table')).toBeInTheDocument()
expect(screen.getByText('Cell 1')).toBeInTheDocument()
expect(screen.getByText('Cell 2')).toBeInTheDocument()
expect(screen.getByTestId('tooltip')).toBeInTheDocument()
})

it('should render with table-wrapper and table-toolbar classes', () => {
const { container } = render(<Table {...defaultProps} />)

expect(container.querySelector('.table-wrapper')).toBeInTheDocument()
expect(container.querySelector('.table-toolbar')).toBeInTheDocument()
})

it('should render copy button with correct tooltip', () => {
render(<Table {...defaultProps} />)

const tooltip = screen.getByTestId('tooltip')
expect(tooltip).toHaveAttribute('title', 'common.copy')
})

it('should match snapshot', () => {
const { container } = render(<Table {...defaultProps} />)
expect(container.firstChild).toMatchSnapshot()
})
})

describe('extractTableMarkdown', () => {
beforeEach(() => {
mocks.store.getState.mockReturnValue({})
})

it('should extract table content from specified line range', () => {
const block = createMockBlock()
const position = createTablePosition(1, 3)
mocks.messageBlocksSelectors.selectById.mockReturnValue(block)

const result = extractTableMarkdown('test-block-1', position)

expect(result).toBe(defaultTableContent)
expect(mocks.messageBlocksSelectors.selectById).toHaveBeenCalledWith({}, 'test-block-1')
})

it('should handle line range extraction correctly', () => {
const multiLineContent = `Line 0
| Header 1 | Header 2 |
|----------|----------|
| Cell 1 | Cell 2 |
Line 4`
const block = createMockBlock(multiLineContent)
const position = createTablePosition(2, 4) // Extract lines 2-4 (table part)
mocks.messageBlocksSelectors.selectById.mockReturnValue(block)

const result = extractTableMarkdown('test-block-1', position)

expect(result).toBe(`| Header 1 | Header 2 |
|----------|----------|
| Cell 1 | Cell 2 |`)
})

it('should return empty string when blockId is empty', () => {
const result = extractTableMarkdown('', createTablePosition())
expect(result).toBe('')
expect(mocks.messageBlocksSelectors.selectById).not.toHaveBeenCalled()
})

it('should return empty string when position is null', () => {
const result = extractTableMarkdown('test-block-1', null)
expect(result).toBe('')
expect(mocks.messageBlocksSelectors.selectById).not.toHaveBeenCalled()
})

it('should return empty string when position is undefined', () => {
const result = extractTableMarkdown('test-block-1', undefined)
expect(result).toBe('')
expect(mocks.messageBlocksSelectors.selectById).not.toHaveBeenCalled()
})

it('should return empty string when block does not exist', () => {
mocks.messageBlocksSelectors.selectById.mockReturnValue(null)

const result = extractTableMarkdown('non-existent-block', createTablePosition())

expect(result).toBe('')
})

it('should return empty string when block has no content property', () => {
const blockWithoutContent = { id: 'test-block-1' }
mocks.messageBlocksSelectors.selectById.mockReturnValue(blockWithoutContent)

const result = extractTableMarkdown('test-block-1', createTablePosition())

expect(result).toBe('')
})

it('should return empty string when block content is not a string', () => {
const blockWithInvalidContent = { id: 'test-block-1', content: 123 }
mocks.messageBlocksSelectors.selectById.mockReturnValue(blockWithInvalidContent)

const result = extractTableMarkdown('test-block-1', createTablePosition())

expect(result).toBe('')
})

it('should handle boundary line numbers correctly', () => {
const block = createMockBlock('Line 1\nLine 2\nLine 3')
const position = createTablePosition(1, 3)
mocks.messageBlocksSelectors.selectById.mockReturnValue(block)

const result = extractTableMarkdown('test-block-1', position)

expect(result).toBe('Line 1\nLine 2\nLine 3')
})
})

describe('copy functionality', () => {
beforeEach(() => {
mocks.messageBlocksSelectors.selectById.mockReturnValue(createMockBlock())
})

it('should copy table content to clipboard on button click', async () => {
render(<Table {...defaultProps} />)

const copyButton = getCopyButton()
await user.click(copyButton)

await waitFor(() => {
expect(getCheckIcon()).toBeInTheDocument()
expect(queryCopyIcon()).not.toBeInTheDocument()
})
})

it('should show check icon after successful copy', async () => {
render(<Table {...defaultProps} />)

// Initially shows copy icon
expect(getCopyIcon()).toBeInTheDocument()

const copyButton = getCopyButton()
await user.click(copyButton)

await waitFor(() => {
expect(getCheckIcon()).toBeInTheDocument()
expect(queryCopyIcon()).not.toBeInTheDocument()
})
})

it('should reset to copy icon after 2 seconds', async () => {
render(<Table {...defaultProps} />)

const copyButton = getCopyButton()
await user.click(copyButton)

await waitFor(() => {
expect(getCheckIcon()).toBeInTheDocument()
})

// Fast forward 2 seconds
act(() => {
vi.advanceTimersByTime(2000)
})

await waitFor(() => {
expect(getCopyIcon()).toBeInTheDocument()
expect(queryCheckIcon()).not.toBeInTheDocument()
})
})

it('should not copy when extractTableMarkdown returns empty string', async () => {
mocks.messageBlocksSelectors.selectById.mockReturnValue(null)

render(<Table {...defaultProps} />)

const copyButton = getCopyButton()
await user.click(copyButton)

await waitFor(() => {
expect(getCopyIcon()).toBeInTheDocument()
expect(queryCheckIcon()).not.toBeInTheDocument()
})
})
})

describe('edge cases', () => {
it('should work without blockId', () => {
const propsWithoutBlockId = { ...defaultProps, blockId: undefined }

expect(() => render(<Table {...propsWithoutBlockId} />)).not.toThrow()

const copyButton = getCopyButton()
expect(copyButton).toBeInTheDocument()
})

it('should work without node position', () => {
const propsWithoutPosition = { ...defaultProps, node: undefined }

expect(() => render(<Table {...propsWithoutPosition} />)).not.toThrow()

const copyButton = getCopyButton()
expect(copyButton).toBeInTheDocument()
})
})
})
@ -30,6 +30,24 @@ This is **bold** text.
</button>
</div>
</div>
<div
data-testid="has-table-component"
>
<div
data-block-id="test-block-1"
data-testid="table-component"
>
<table>
test table
</table>
<button
data-testid="copy-table-button"
type="button"
>
Copy Table
</button>
</div>
</div>
<span
data-testid="has-img-component"
>
@ -0,0 +1,103 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html

exports[`Table > rendering > should match snapshot 1`] = `
.c0 {
position: relative;
}

.c0 .table-toolbar {
border-radius: 4px;
opacity: 0;
transition: opacity 0.2s ease;
transform: translateZ(0);
will-change: opacity;
}

.c0:hover .table-toolbar {
opacity: 1;
}

.c1 {
position: absolute;
top: 8px;
right: 8px;
z-index: 10;
}

.c2 {
display: flex;
align-items: center;
justify-content: center;
width: 24px;
height: 24px;
border-radius: 4px;
cursor: pointer;
user-select: none;
transition: all 0.2s ease;
opacity: 1;
color: var(--color-text-3);
background-color: var(--color-background-mute);
will-change: background-color,opacity;
}

.c2:hover {
background-color: var(--color-background-soft);
}

<div
class="c0 table-wrapper"
>
<table>
<tbody>
<tr>
<td>
Cell 1
</td>
<td>
Cell 2
</td>
</tr>
</tbody>
</table>
<div
class="c1 table-toolbar"
>
<div
data-testid="tooltip"
title="common.copy"
>
<div
aria-label="common.copy"
class="c2"
role="button"
>
<svg
class="lucide lucide-copy"
data-testid="copy-icon"
fill="none"
height="14"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
viewBox="0 0 24 24"
width="14"
xmlns="http://www.w3.org/2000/svg"
>
<rect
height="14"
rx="2"
ry="2"
width="14"
x="8"
y="8"
/>
<path
d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"
/>
</svg>
</div>
</div>
</div>
</div>
`;
@ -40,7 +40,18 @@ function CitationBlock({ block }: { block: CitationMessageBlock }) {
__html:
(block.response?.results as GroundingMetadata)?.searchEntryPoint?.renderedContent
?.replace(/@media \(prefers-color-scheme: light\)/g, 'body[theme-mode="light"]')
.replace(/@media \(prefers-color-scheme: dark\)/g, 'body[theme-mode="dark"]') || ''
.replace(/@media \(prefers-color-scheme: dark\)/g, 'body[theme-mode="dark"]')
.replace(
/background-color\s*:\s*#[0-9a-fA-F]{3,6}\b|\bbackground-color\s*:\s*[a-zA-Z-]+\b/g,
'background-color: var(--color-background-soft)'
)
.replace(/\.gradient\s*{[^}]*background\s*:\s*[^};]+[;}]/g, (match) => {
// Remove the background property while preserving the rest
return match.replace(/background\s*:\s*[^};]+;?\s*/g, '')
})
.replace(/\.chip {\n/g, '.chip {\n  background-color: var(--color-background)!important;\n')
.replace(/border-color\s*:\s*[^};]+;?\s*/g, '')
.replace(/border\s*:\s*[^};]+;?\s*/g, '') || ''
}}
/>
</>
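The chained replaces above rewrite the rendered search-entry CSS to follow the app theme: prefers-color-scheme media queries become body[theme-mode=...] selectors, and hard-coded background and border styles are swapped for CSS variables or stripped. A reduced sketch of the first rewrite, on sample input of my own:

```typescript
// The media-query rewrite in isolation, on illustrative input.
const css = '@media (prefers-color-scheme: dark) { .chip { color: #fff } }'
const themed = css.replace(/@media \(prefers-color-scheme: dark\)/g, 'body[theme-mode="dark"]')
// themed === 'body[theme-mode="dark"] { .chip { color: #fff } }'
```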
@ -1,6 +1,6 @@
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
import ImageViewer from '@renderer/components/ImageViewer'
import type { ImageMessageBlock } from '@renderer/types/newMessage'
import { type ImageMessageBlock } from '@renderer/types/newMessage'
import React from 'react'
import styled from 'styled-components'
@ -23,7 +23,8 @@ const EXCLUDED_SELECTORS = [
'.ant-collapse-header',
'.group-menu-bar',
'.code-block',
'.message-editor'
'.message-editor',
'.table-wrapper'
]

// Gap between the navigation bar and the right element
@ -17,7 +17,7 @@ import { isMac } from '@renderer/config/constant'
import { useAssistant, useAssistants, useTopicsForAssistant } from '@renderer/hooks/useAssistant'
import { modelGenerating } from '@renderer/hooks/useRuntime'
import { useSettings } from '@renderer/hooks/useSettings'
import { TopicManager } from '@renderer/hooks/useTopic'
import { finishTopicRenaming, startTopicRenaming, TopicManager } from '@renderer/hooks/useTopic'
import { fetchMessagesSummary } from '@renderer/services/ApiService'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
import store from '@renderer/store'
@ -58,6 +58,9 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic

const topics = useTopicsForAssistant(_assistant.id)

const renamingTopics = useSelector((state: RootState) => state.runtime.chat.renamingTopics)
const newlyRenamedTopics = useSelector((state: RootState) => state.runtime.chat.newlyRenamedTopics)

const borderRadius = showTopicTime ? 12 : 'var(--list-item-border-radius)'

const [deletingTopicId, setDeletingTopicId] = useState<string | null>(null)
@ -85,6 +88,20 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
[activeTopic.id, pendingTopics]
)

const isRenaming = useCallback(
(topicId: string) => {
return renamingTopics.includes(topicId)
},
[renamingTopics]
)

const isNewlyRenamed = useCallback(
(topicId: string) => {
return newlyRenamedTopics.includes(topicId)
},
[newlyRenamedTopics]
)

const handleDeleteClick = useCallback((topicId: string, e: React.MouseEvent) => {
e.stopPropagation()

@ -171,16 +188,22 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
label: t('chat.topics.auto_rename'),
key: 'auto-rename',
icon: <i className="iconfont icon-business-smart-assistant" style={{ fontSize: '14px' }} />,
disabled: isRenaming(topic.id),
async onClick() {
const messages = await TopicManager.getTopicMessages(topic.id)
if (messages.length >= 2) {
const summaryText = await fetchMessagesSummary({ messages, assistant })
if (summaryText) {
const updatedTopic = { ...topic, name: summaryText, isNameManuallyEdited: false }
updateTopic(updatedTopic)
topic.id === activeTopic.id && setActiveTopic(updatedTopic)
} else {
window.message?.error(t('message.error.fetchTopicName'))
startTopicRenaming(topic.id)
try {
const summaryText = await fetchMessagesSummary({ messages, assistant })
if (summaryText) {
const updatedTopic = { ...topic, name: summaryText, isNameManuallyEdited: false }
updateTopic(updatedTopic)
topic.id === activeTopic.id && setActiveTopic(updatedTopic)
} else {
window.message?.error(t('message.error.fetchTopicName'))
}
} finally {
finishTopicRenaming(topic.id)
}
}
}
@ -189,6 +212,7 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
label: t('chat.topics.edit.title'),
key: 'rename',
icon: <EditOutlined />,
disabled: isRenaming(topic.id),
async onClick() {
const name = await PromptPopup.show({
title: t('chat.topics.edit.title'),
@ -372,6 +396,7 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
}, [
targetTopic,
t,
isRenaming,
exportMenuOptions.image,
exportMenuOptions.markdown,
exportMenuOptions.markdown_reason,
@ -414,6 +439,13 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
const topicName = topic.name.replace('`', '')
const topicPrompt = topic.prompt
const fullTopicPrompt = t('common.prompt') + ': ' + topicPrompt

const getTopicNameClassName = () => {
if (isRenaming(topic.id)) return 'shimmer'
if (isNewlyRenamed(topic.id)) return 'typing'
return ''
}

return (
<TopicListItem
onContextMenu={() => setTargetTopic(topic)}
@ -422,7 +454,7 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
style={{ borderRadius }}>
{isPending(topic.id) && !isActive && <PendingIndicator />}
<TopicNameContainer>
<TopicName className="name" title={topicName}>
<TopicName className={getTopicNameClassName()} title={topicName}>
{topicName}
</TopicName>
{isActive && !topic.pinned && (
@ -526,6 +558,46 @@ const TopicName = styled.div`
-webkit-box-orient: vertical;
overflow: hidden;
font-size: 13px;
position: relative;
will-change: background-position, width;

--color-shimmer-mid: var(--color-text-1);
--color-shimmer-end: color-mix(in srgb, var(--color-text-1) 25%, transparent);

&.shimmer {
background: linear-gradient(to left, var(--color-shimmer-end), var(--color-shimmer-mid), var(--color-shimmer-end));
background-size: 200% 100%;
background-clip: text;
color: transparent;
animation: shimmer 3s linear infinite;
}

&.typing {
display: block;
-webkit-line-clamp: unset;
-webkit-box-orient: unset;
white-space: nowrap;
overflow: hidden;
animation: typewriter 0.5s steps(40, end);
}

@keyframes shimmer {
0% {
background-position: 200% 0;
}
100% {
background-position: -200% 0;
}
}

@keyframes typewriter {
from {
width: 0;
}
to {
width: 100%;
}
}
`

const PendingIndicator = styled.div.attrs({
@ -1,124 +0,0 @@
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
import { fetchSuggestions } from '@renderer/services/ApiService'
import { getUserMessage } from '@renderer/services/MessagesService'
import { useAppDispatch } from '@renderer/store'
import { sendMessage } from '@renderer/store/thunk/messageThunk'
import { Assistant, Suggestion } from '@renderer/types'
import type { Message } from '@renderer/types/newMessage'
import { last } from 'lodash'
import { FC, memo, useEffect, useState } from 'react'
import styled from 'styled-components'

interface Props {
assistant: Assistant
messages: Message[]
}

const suggestionsMap = new Map<string, Suggestion[]>()

const Suggestions: FC<Props> = ({ assistant, messages }) => {
const dispatch = useAppDispatch()

const [suggestions, setSuggestions] = useState<Suggestion[]>(
suggestionsMap.get(messages[messages.length - 1]?.id) || []
)
const [loadingSuggestions, setLoadingSuggestions] = useState(false)

const handleSuggestionClick = async (content: string) => {
const { message: userMessage, blocks } = getUserMessage({
assistant,
topic: assistant.topics[0],
content
})

await dispatch(sendMessage(userMessage, blocks, assistant, assistant.topics[0].id))
}

const suggestionsHandle = async () => {
if (loadingSuggestions) return
try {
setLoadingSuggestions(true)
const _suggestions = await fetchSuggestions({
assistant,
messages
})
if (_suggestions.length) {
setSuggestions(_suggestions)
suggestionsMap.set(messages[messages.length - 1].id, _suggestions)
}
} finally {
setLoadingSuggestions(false)
}
}

useEffect(() => {
suggestionsHandle()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])

useEffect(() => {
setSuggestions(suggestionsMap.get(messages[messages.length - 1]?.id) || [])
}, [messages])

if (last(messages)?.status !== 'success') {
return null
}
if (loadingSuggestions) {
return (
<Container>
<SvgSpinners180Ring color="var(--color-text-2)" />
</Container>
)
}

if (suggestions.length === 0) {
return null
}

return (
<Container>
<SuggestionsContainer>
{suggestions.map((s, i) => (
<SuggestionItem key={i} onClick={() => handleSuggestionClick(s.content)}>
{s.content} →
</SuggestionItem>
))}
</SuggestionsContainer>
</Container>
)
}

const Container = styled.div`
display: flex;
flex-direction: column;
padding: 10px 10px 20px 65px;
display: flex;
width: 100%;
flex-direction: row;
flex-wrap: wrap;
gap: 15px;
`

const SuggestionsContainer = styled.div`
display: flex;
flex-direction: row;
flex-wrap: wrap;
gap: 10px;
`

const SuggestionItem = styled.div`
display: flex;
align-items: center;
width: fit-content;
padding: 5px 10px;
border-radius: 12px;
font-size: 12px;
color: var(--color-text);
background: var(--color-background-mute);
cursor: pointer;
&:hover {
opacity: 0.9;
}
`

export default memo(Suggestions)
@ -1,3 +1,4 @@
import AiProvider from '@renderer/aiCore'
import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
@ -6,7 +7,6 @@ import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
import { useProviders } from '@renderer/hooks/useProvider'
import { SettingHelpText } from '@renderer/pages/settings'
import AiProvider from '@renderer/providers/AiProvider'
import { getKnowledgeBaseParams } from '@renderer/services/KnowledgeService'
import { getModelUniqId } from '@renderer/services/ModelService'
import { KnowledgeBase, Model } from '@renderer/types'
@ -11,7 +11,7 @@ import { usePaintings } from '@renderer/hooks/usePaintings'
import { useAllProviders } from '@renderer/hooks/useProvider'
import { useRuntime } from '@renderer/hooks/useRuntime'
import { useSettings } from '@renderer/hooks/useSettings'
import AiProvider from '@renderer/providers/AiProvider'
import AiProvider from '@renderer/aiCore'
import FileManager from '@renderer/services/FileManager'
import { translateText } from '@renderer/services/TranslateService'
import { useAppDispatch } from '@renderer/store'
@ -182,11 +182,9 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
const base64s = await AI.generateImage({
prompt,
model: painting.model,
config: {
aspectRatio: painting.aspectRatio?.replace('ASPECT_', '').replace('_', ':'),
numberOfImages: painting.model.startsWith('imagen-4.0-ultra-generate-exp') ? 1 : painting.numberOfImages,
personGeneration: painting.personGeneration
}
imageSize: painting.aspectRatio?.replace('ASPECT_', '').replace('_', ':') || '1:1',
batchSize: painting.model.startsWith('imagen-4.0-ultra-generate-exp') ? 1 : painting.numberOfImages || 1,
personGeneration: painting.personGeneration
})
if (base64s?.length > 0) {
const validFiles = await Promise.all(
@ -16,7 +16,7 @@ import { usePaintings } from '@renderer/hooks/usePaintings'
import { useAllProviders } from '@renderer/hooks/useProvider'
import { useRuntime } from '@renderer/hooks/useRuntime'
import { useSettings } from '@renderer/hooks/useSettings'
import AiProvider from '@renderer/providers/AiProvider'
import AiProvider from '@renderer/aiCore'
import { getProviderByModel } from '@renderer/services/AssistantService'
import FileManager from '@renderer/services/FileManager'
import { translateText } from '@renderer/services/TranslateService'
@ -51,8 +51,8 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
try {
let valid = false
if (type === 'provider' && model) {
const result = await checkApi({ ...(provider as Provider), apiKey: status.key }, model)
valid = result.valid
await checkApi({ ...(provider as Provider), apiKey: status.key }, model)
valid = true
} else {
const result = await WebSearchService.checkSearch({
...(provider as WebSearchProvider),
@ -65,7 +65,7 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
setKeyStatuses((prev) => prev.map((s, idx) => (idx === i ? { ...s, checking: false, isValid: valid } : s)))

return { index: i, valid }
} catch (error) {
} catch (error: unknown) {
// Handle the error case
setKeyStatuses((prev) => prev.map((s, idx) => (idx === i ? { ...s, checking: false, isValid: false } : s)))
return { index: i, valid: false }
@ -90,8 +90,8 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
try {
let valid = false
if (type === 'provider' && model) {
const result = await checkApi({ ...(provider as Provider), apiKey: keyStatuses[keyIndex].key }, model)
valid = result.valid
await checkApi({ ...(provider as Provider), apiKey: keyStatuses[keyIndex].key }, model)
valid = true
} else {
const result = await WebSearchService.checkSearch({
...(provider as WebSearchProvider),
@ -103,7 +103,7 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
setKeyStatuses((prev) =>
prev.map((status, idx) => (idx === keyIndex ? { ...status, checking: false, isValid: valid } : status))
)
} catch (error) {
} catch (error: unknown) {
setKeyStatuses((prev) =>
prev.map((status, idx) => (idx === keyIndex ? { ...status, checking: false, isValid: false } : status))
)

@ -145,14 +145,17 @@ const PopupContainer: React.FC<Props> = ({ provider: _provider, resolve }) => {
setListModels(
models
.map((model) => ({
id: model.id,
// @ts-ignore modelId
id: model?.id || model?.name,
// @ts-ignore name
name: model.name || model.id,
name: model?.display_name || model?.displayName || model?.name || model?.id,
provider: _provider.id,
group: getDefaultGroupName(model.id, _provider.id),
// @ts-ignore name
description: model?.description,
owned_by: model?.owned_by
// @ts-ignore group
group: getDefaultGroupName(model?.id || model?.name, _provider.id),
// @ts-ignore description
description: model?.description || '',
// @ts-ignore owned_by
owned_by: model?.owned_by || ''
}))
.filter((model) => !isEmpty(model.name))
)
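Both hunks above move checkApi from returning a { valid, error } result object to exception-based signaling: resolving means the key works, throwing means it does not. A sketch of the calling convention this diff assumes:

```typescript
// Sketch of the new contract: checkApi resolves on success and throws on failure.
let valid = false
try {
  await checkApi(provider, model) // assumed to throw an Error for a bad key or host
  valid = true
} catch (error: unknown) {
  valid = false // (error as Error).message can be surfaced to the UI
}
```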
@ -7,7 +7,7 @@ import { PROVIDER_CONFIG } from '@renderer/config/providers'
import { useTheme } from '@renderer/context/ThemeProvider'
import { useAllProviders, useProvider, useProviders } from '@renderer/hooks/useProvider'
import i18n from '@renderer/i18n'
import { isOpenAIProvider } from '@renderer/providers/AiProvider/ProviderFactory'
import { isOpenAIProvider } from '@renderer/aiCore/clients/ApiClientFactory'
import { checkApi, formatApiKeys } from '@renderer/services/ApiService'
import { checkModelsHealth, getModelCheckSummary } from '@renderer/services/HealthCheckService'
import { isProviderSupportAuth } from '@renderer/services/ProviderService'
@ -231,22 +231,32 @@ const ProviderSetting: FC<Props> = ({ provider: _provider }) => {
} else {
setApiChecking(true)

const { valid, error } = await checkApi({ ...provider, apiKey, apiHost }, model)
try {
await checkApi({ ...provider, apiKey, apiHost }, model)

const errorMessage = error && error?.message ? ' ' + error?.message : ''
window.message.success({
key: 'api-check',
style: { marginTop: '3vh' },
duration: 2,
content: i18n.t('message.api.connection.success')
})

window.message[valid ? 'success' : 'error']({
key: 'api-check',
style: { marginTop: '3vh' },
duration: valid ? 2 : 8,
content: valid
? i18n.t('message.api.connection.success')
: i18n.t('message.api.connection.failed') + errorMessage
})
setApiValid(true)
setTimeout(() => setApiValid(false), 3000)
} catch (error: any) {
const errorMessage = error?.message ? ' ' + error.message : ''

setApiValid(valid)
setApiChecking(false)
setTimeout(() => setApiValid(false), 3000)
window.message.error({
key: 'api-check',
style: { marginTop: '3vh' },
duration: 8,
content: i18n.t('message.api.connection.failed') + errorMessage
})

setApiValid(false)
} finally {
setApiChecking(false)
}
}
}
@ -1,117 +0,0 @@
import { isOpenAILLMModel } from '@renderer/config/models'
import { getDefaultModel } from '@renderer/services/AssistantService'
import { Assistant, MCPCallToolResponse, MCPTool, MCPToolResponse, Model, Provider, Suggestion } from '@renderer/types'
import { Message } from '@renderer/types/newMessage'
import OpenAI from 'openai'

import { CompletionsParams } from '.'
import AnthropicProvider from './AnthropicProvider'
import BaseProvider from './BaseProvider'
import GeminiProvider from './GeminiProvider'
import OpenAIProvider from './OpenAIProvider'
import OpenAIResponseProvider from './OpenAIResponseProvider'

/**
 * AihubmixProvider - automatically selects the right provider based on the model type
 * Implemented with the decorator pattern
 */
export default class AihubmixProvider extends BaseProvider {
private providers: Map<string, BaseProvider> = new Map()
private defaultProvider: BaseProvider
private currentProvider: BaseProvider

constructor(provider: Provider) {
super(provider)

// Initialize the individual providers
this.providers.set('claude', new AnthropicProvider(provider))
this.providers.set('gemini', new GeminiProvider({ ...provider, apiHost: 'https://aihubmix.com/gemini' }))
this.providers.set('openai', new OpenAIResponseProvider(provider))
this.providers.set('default', new OpenAIProvider(provider))

// Set the default provider
this.defaultProvider = this.providers.get('default')!
this.currentProvider = this.defaultProvider
}

/**
 * Pick the right provider for a given model
 */
private getProvider(model: Model): BaseProvider {
const id = model.id.toLowerCase()
// Starts with "claude"
if (id.startsWith('claude')) {
return this.providers.get('claude')!
}
// Starts with "gemini" or "imagen", and does not end with -nothink or -search
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
return this.providers.get('gemini')!
}
if (isOpenAILLMModel(model)) {
return this.providers.get('openai')!
}

return this.defaultProvider
}

// Delegate straight to the default provider
public async models(): Promise<OpenAI.Models.Model[]> {
return this.defaultProvider.models()
}

public async generateText(params: { prompt: string; content: string }): Promise<string> {
return this.defaultProvider.generateText(params)
}

public async generateImage(params: any): Promise<string[]> {
return this.getProvider({
id: params.model
} as unknown as Model).generateImage(params)
}

public async generateImageByChat(params: any): Promise<void> {
return this.defaultProvider.generateImageByChat(params)
}

public async completions(params: CompletionsParams): Promise<void> {
const model = params.assistant.model
this.currentProvider = this.getProvider(model!)
return this.currentProvider.completions(params)
}

public async translate(
content: string,
assistant: Assistant,
onResponse?: (text: string, isComplete: boolean) => void
): Promise<string> {
return this.getProvider(assistant.model || getDefaultModel()).translate(content, assistant, onResponse)
}

public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
return this.getProvider(assistant.model || getDefaultModel()).summaries(messages, assistant)
}

public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
return this.getProvider(assistant.model || getDefaultModel()).summaryForSearch(messages, assistant)
}

public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
return this.getProvider(assistant.model || getDefaultModel()).suggestions(messages, assistant)
}

public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
return this.getProvider(model).check(model, stream)
}

public async getEmbeddingDimensions(model: Model): Promise<number> {
return this.getProvider(model).getEmbeddingDimensions(model)
}

public convertMcpTools<T>(mcpTools: MCPTool[]) {
return this.currentProvider.convertMcpTools(mcpTools) as T[]
}

public mcpToolCallResponseToMessage(mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) {
return this.currentProvider.mcpToolCallResponseToMessage(mcpToolResponse, resp, model)
}
}
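The removed class is a decorator that routes each call to a concrete provider by inspecting the model id, with an OpenAI-compatible default for everything else. Its dispatch rule, reduced to a standalone sketch:

```typescript
// The prefix-based routing the deleted AihubmixProvider used, in isolation.
function routeByModelId(id: string): 'claude' | 'gemini' | 'default' {
  const lower = id.toLowerCase()
  if (lower.startsWith('claude')) return 'claude'
  if (
    (lower.startsWith('gemini') || lower.startsWith('imagen')) &&
    !lower.endsWith('-nothink') &&
    !lower.endsWith('-search')
  ) {
    return 'gemini'
  }
  return 'default' // the real class also checks isOpenAILLMModel() before this branch
}

routeByModelId('claude-3-7-sonnet') // -> 'claude'
routeByModelId('gemini-2.0-flash-search') // -> 'default'
```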
@ -1,802 +0,0 @@
import Anthropic from '@anthropic-ai/sdk'
import {
Base64ImageSource,
ImageBlockParam,
MessageCreateParamsNonStreaming,
MessageParam,
TextBlockParam,
ToolResultBlockParam,
ToolUnion,
ToolUseBlock,
WebSearchResultBlock,
WebSearchTool20250305,
WebSearchToolResultError
} from '@anthropic-ai/sdk/resources'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import FileManager from '@renderer/services/FileManager'
import {
filterContextMessages,
filterEmptyMessages,
filterUserRoleStartMessages
} from '@renderer/services/MessagesService'
import {
Assistant,
EFFORT_RATIO,
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Metrics,
Model,
Provider,
Suggestion,
ToolCallResponse,
Usage,
WebSearchSource
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import {
anthropicToolUseToMcpTool,
isEnabledToolUse,
mcpToolCallResponseToAnthropicMessage,
mcpToolsToAnthropicTools,
parseAndCallTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { first, flatten, takeRight } from 'lodash'
import OpenAI from 'openai'

import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'

interface ReasoningConfig {
type: 'enabled' | 'disabled'
budget_tokens?: number
}

export default class AnthropicProvider extends BaseProvider {
private sdk: Anthropic

constructor(provider: Provider) {
super(provider)
this.sdk = new Anthropic({
apiKey: this.apiKey,
baseURL: this.getBaseURL(),
dangerouslyAllowBrowser: true,
defaultHeaders: {
'anthropic-beta': 'output-128k-2025-02-19'
}
})
}

public getBaseURL(): string {
return this.provider.apiHost
}

/**
 * Get the message parameter
 * @param message - The message
 * @returns The message parameter
 */
private async getMessageParam(message: Message): Promise<MessageParam> {
const parts: MessageParam['content'] = [
{
type: 'text',
text: getMainTextContent(message)
}
]

// Get and process image blocks
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (imageBlock.file) {
// Handle uploaded file
const file = imageBlock.file
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
type: 'base64'
}
})
}
}
// Get and process file blocks
const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) {
const { file } = fileBlock
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
const base64Data = await FileManager.readBase64File(file)
parts.push({
type: 'document',
source: {
type: 'base64',
media_type: 'application/pdf',
data: base64Data
}
})
} else {
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
type: 'text',
text: file.origin_name + '\n' + fileContent
})
}
}
}

return {
role: message.role === 'system' ? 'user' : message.role,
content: parts
}
}

private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
if (!isWebSearchModel(model)) {
return undefined
}

return {
type: 'web_search_20250305',
name: 'web_search',
max_uses: 5
} as WebSearchTool20250305
}

override getTemperature(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return assistant.settings?.temperature
}

override getTopP(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return assistant.settings?.topP
}

/**
 * Get the reasoning effort
 * @param assistant - The assistant
 * @param model - The model
 * @returns The reasoning effort
 */
private getBudgetToken(assistant: Assistant, model: Model): ReasoningConfig | undefined {
if (!isReasoningModel(model)) {
return undefined
}
const { maxTokens } = getAssistantSettings(assistant)

const reasoningEffort = assistant?.settings?.reasoning_effort

if (reasoningEffort === undefined) {
return {
type: 'disabled'
}
}

const effortRatio = EFFORT_RATIO[reasoningEffort]

const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
)

return {
type: 'enabled',
budget_tokens: budgetTokens
}
}
  /**
   * Generate completions
   * @param messages - The messages
   * @param assistant - The assistant
   * @param mcpTools - The MCP tools
   * @param onChunk - The onChunk callback
   * @param onFilterMessages - The onFilterMessages callback
   */
  public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
    const defaultModel = getDefaultModel()
    const model = assistant.model || defaultModel
    const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)

    const userMessagesParams: MessageParam[] = []

    const _messages = filterUserRoleStartMessages(
      filterContextMessages(filterEmptyMessages(takeRight(messages, contextCount + 2)))
    )

    onFilterMessages(_messages)

    for (const message of _messages) {
      userMessagesParams.push(await this.getMessageParam(message))
    }

    const userMessages = flatten(userMessagesParams)
    const lastUserMessage = _messages.findLast((m) => m.role === 'user')

    let systemPrompt = assistant.prompt

    const { tools } = this.setupToolsConfig<ToolUnion>({
      model,
      mcpTools,
      enableToolUse: isEnabledToolUse(assistant)
    })

    if (this.useSystemPromptForTools && mcpTools && mcpTools.length) {
      systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools)
    }

    let systemMessage: TextBlockParam | undefined = undefined
    if (systemPrompt) {
      systemMessage = {
        type: 'text',
        text: systemPrompt
      }
    }

    const isEnabledBuiltinWebSearch = assistant.enableWebSearch && isWebSearchModel(model)

    if (isEnabledBuiltinWebSearch) {
      const webSearchTool = await this.getWebSearchParams(model)
      if (webSearchTool) {
        tools.push(webSearchTool)
      }
    }

    const body: MessageCreateParamsNonStreaming = {
      model: model.id,
      messages: userMessages,
      max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
      temperature: this.getTemperature(assistant, model),
      top_p: this.getTopP(assistant, model),
      system: systemMessage ? [systemMessage] : undefined,
      // @ts-ignore thinking
      thinking: this.getBudgetToken(assistant, model),
      tools: tools,
      ...this.getCustomParameters(assistant)
    }

    const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
    const { signal } = abortController

    const finalUsage: Usage = {
      completion_tokens: 0,
      prompt_tokens: 0,
      total_tokens: 0
    }

    const finalMetrics: Metrics = {
      completion_tokens: 0,
      time_completion_millsec: 0,
      time_first_token_millsec: 0
    }
    const toolResponses: MCPToolResponse[] = []

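    // Processes one conversation round: sends the request, forwards text/thinking
    // chunks through onChunk, and when the model requests tool calls, executes them
    // and recurses with the tool results appended to the message list.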
    const processStream = async (body: MessageCreateParamsNonStreaming, idx: number) => {
      let time_first_token_millsec = 0

      if (!streamOutput) {
        const message = await this.sdk.messages.create({ ...body, stream: false })
        const time_completion_millsec = new Date().getTime() - start_time_millsec

        let text = ''
        let reasoning_content = ''

        if (message.content && message.content.length > 0) {
          const thinkingBlock = message.content.find((block) => block.type === 'thinking')
          const textBlock = message.content.find((block) => block.type === 'text')

          if (thinkingBlock && 'thinking' in thinkingBlock) {
            reasoning_content = thinkingBlock.thinking
          }

          if (textBlock && 'text' in textBlock) {
            text = textBlock.text
          }
        }

        return onChunk({
          type: ChunkType.BLOCK_COMPLETE,
          response: {
            text,
            reasoning_content,
            usage: message.usage as any,
            metrics: {
              completion_tokens: message.usage?.output_tokens || 0,
              time_completion_millsec,
              time_first_token_millsec: 0
            }
          }
        })
      }

      let thinking_content = ''
      let isFirstChunk = true

      return new Promise<void>((resolve, reject) => {
        // Wait for the API to return the stream
        const toolCalls: ToolUseBlock[] = []

        this.sdk.messages
          .stream({ ...body, stream: true }, { signal, timeout: 5 * 60 * 1000 })
          .on('text', (text) => {
            if (isFirstChunk) {
              isFirstChunk = false
              if (time_first_token_millsec == 0) {
                time_first_token_millsec = new Date().getTime()
              } else {
                onChunk({
                  type: ChunkType.THINKING_COMPLETE,
                  text: thinking_content,
                  thinking_millsec: new Date().getTime() - time_first_token_millsec
                })
              }
            }

            onChunk({ type: ChunkType.TEXT_DELTA, text })
          })
          .on('contentBlock', (block) => {
            if (block.type === 'server_tool_use' && block.name === 'web_search') {
              onChunk({
                type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
              })
            } else if (block.type === 'web_search_tool_result') {
              if (
                block.content &&
                (block.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
              ) {
                onChunk({
                  type: ChunkType.ERROR,
                  error: {
                    code: (block.content as WebSearchToolResultError).error_code,
                    message: (block.content as WebSearchToolResultError).error_code
                  }
                })
              } else {
                onChunk({
                  type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                  llm_web_search: {
                    results: block.content as Array<WebSearchResultBlock>,
                    source: WebSearchSource.ANTHROPIC
                  }
                })
              }
            }
            if (block.type === 'tool_use') {
              toolCalls.push(block)
            }
          })
          .on('thinking', (thinking) => {
            if (time_first_token_millsec == 0) {
              time_first_token_millsec = new Date().getTime()
            }

            onChunk({
              type: ChunkType.THINKING_DELTA,
              text: thinking,
              thinking_millsec: new Date().getTime() - time_first_token_millsec
            })
            thinking_content += thinking
          })
          .on('finalMessage', async (message) => {
            const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
            // tool call
            if (toolCalls.length > 0) {
              const mcpToolResponses = toolCalls
                .map((toolCall) => {
                  const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
                  if (!mcpTool) {
                    return undefined
                  }
                  return {
                    id: toolCall.id,
                    toolCallId: toolCall.id,
                    tool: mcpTool,
                    arguments: toolCall.input as Record<string, unknown>,
                    status: 'pending'
                  } as ToolCallResponse
                })
                .filter((t) => typeof t !== 'undefined')
              toolResults.push(
                ...(await parseAndCallTools(
                  mcpToolResponses,
                  toolResponses,
                  onChunk,
                  this.mcpToolCallResponseToMessage,
                  model,
                  mcpTools
                ))
              )
            }

            // tool use
            const content = message.content[0]
            if (content && content.type === 'text') {
              onChunk({ type: ChunkType.TEXT_COMPLETE, text: content.text })
              toolResults.push(
                ...(await parseAndCallTools(
                  content.text,
                  toolResponses,
                  onChunk,
                  this.mcpToolCallResponseToMessage,
                  model,
                  mcpTools
                ))
              )
            }

            if (thinking_content) {
              onChunk({
                type: ChunkType.THINKING_COMPLETE,
                text: thinking_content,
                thinking_millsec: new Date().getTime() - time_first_token_millsec
              })
            }

            userMessages.push({
              role: message.role,
              content: message.content
            })

            if (toolResults.length > 0) {
              toolResults.forEach((ts) => userMessages.push(ts as MessageParam))
              const newBody = body
              newBody.messages = userMessages

              onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
              try {
                await processStream(newBody, idx + 1)
              } catch (error) {
                console.error('Error processing stream:', error)
                reject(error)
              }
            }

            // Mutating finalUsage directly throws "TypeError: Cannot assign to read only property 'prompt_tokens' of object '#<Object>'".
            // The cause has not been found yet.
            const updatedUsage: Usage = {
              ...finalUsage,
              prompt_tokens: finalUsage.prompt_tokens + (message.usage?.input_tokens || 0),
              completion_tokens: finalUsage.completion_tokens + (message.usage?.output_tokens || 0)
            }
            updatedUsage.total_tokens = updatedUsage.prompt_tokens + updatedUsage.completion_tokens

            const updatedMetrics: Metrics = {
              ...finalMetrics,
              completion_tokens: updatedUsage.completion_tokens,
              time_completion_millsec:
                finalMetrics.time_completion_millsec + (new Date().getTime() - start_time_millsec),
              time_first_token_millsec: time_first_token_millsec - start_time_millsec
            }

            Object.assign(finalUsage, updatedUsage)
            Object.assign(finalMetrics, updatedMetrics)

            onChunk({
              type: ChunkType.BLOCK_COMPLETE,
              response: {
                usage: updatedUsage,
                metrics: updatedMetrics
              }
            })
            resolve()
          })
          .on('error', (error) => reject(error))
          .on('abort', () => {
            reject(new Error('Request was aborted.'))
          })
      })
    }
    onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
    const start_time_millsec = new Date().getTime()
    await processStream(body, 0).finally(() => {
      cleanup()
    })
  }

  /**
   * Translate a message
   * @param content - The text to translate
   * @param assistant - The assistant
   * @param onResponse - The onResponse callback
   * @returns The translated message
   */
  public async translate(
    content: string,
    assistant: Assistant,
    onResponse?: (text: string, isComplete: boolean) => void
  ) {
    const defaultModel = getDefaultModel()
    const model = assistant.model || defaultModel

    const messagesForApi = [{ role: 'user' as const, content: content }]

    const stream = !!onResponse

    const body: MessageCreateParamsNonStreaming = {
      model: model.id,
      messages: messagesForApi,
      max_tokens: 4096,
      temperature: assistant?.settings?.temperature,
      system: assistant.prompt
    }

    if (!stream) {
      const response = await this.sdk.messages.create({ ...body, stream: false })
      return response.content[0].type === 'text' ? response.content[0].text : ''
    }

    let text = ''

    return new Promise<string>((resolve, reject) => {
      this.sdk.messages
        .stream({ ...body, stream: true })
        .on('text', (_text) => {
          text += _text
          onResponse?.(text, false)
        })
        .on('finalMessage', () => {
          onResponse?.(text, true)
          resolve(text)
        })
        .on('error', (error) => reject(error))
    })
  }

  /**
   * Summarize messages into a topic name
   * @param messages - The messages
   * @param assistant - The assistant
   * @returns The summary
   */
  public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
    const model = getTopNamingModel() || assistant.model || getDefaultModel()

    const userMessages = takeRight(messages, 5).map((message) => ({
      role: message.role,
      content: getMainTextContent(message)
    }))

    if (first(userMessages)?.role === 'assistant') {
      userMessages.shift()
    }

    const userMessageContent = userMessages.reduce((prev, curr) => {
      const currentContent = curr.role === 'user' ? `User: ${curr.content}` : `Assistant: ${curr.content}`
      return prev + (prev ? '\n' : '') + currentContent
    }, '')

    const systemMessage = {
      role: 'system',
      content: (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
    }

    const userMessage = {
      role: 'user',
      content: userMessageContent
    }

    const message = await this.sdk.messages.create({
      messages: [userMessage] as Anthropic.Messages.MessageParam[],
      model: model.id,
      system: systemMessage.content,
      stream: false,
      max_tokens: 4096
    })

    const responseContent = message.content[0].type === 'text' ? message.content[0].text : ''
    return removeSpecialCharactersForTopicName(responseContent)
  }

  /**
   * Summarize a message for search
   * @param messages - The messages
   * @param assistant - The assistant
   * @returns The summary
   */
  public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
    const model = assistant.model || getDefaultModel()
    const systemMessage = { content: assistant.prompt }

    const userMessageContent = messages.map((m) => getMainTextContent(m)).join('\n')

    const userMessage = {
      role: 'user' as const,
      content: userMessageContent
    }
    const lastUserMessage = messages[messages.length - 1]
    const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
    const { signal } = abortController

    const response = await this.sdk.messages
      .create(
        {
          messages: [userMessage],
          model: model.id,
          system: systemMessage.content,
          stream: false,
          max_tokens: 4096
        },
        { timeout: 20 * 1000, signal }
      )
      .finally(cleanup)

    return response.content[0].type === 'text' ? response.content[0].text : ''
  }

  /**
   * Generate text
   * @param prompt - The prompt
   * @param content - The content
   * @returns The generated text
   */
  public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
    const model = getDefaultModel()

    const message = await this.sdk.messages.create({
      model: model.id,
      system: prompt,
      stream: false,
      max_tokens: 4096,
      messages: [
        {
          role: 'user',
          content
        }
      ]
    })

    return message.content[0].type === 'text' ? message.content[0].text : ''
  }

  /**
   * Generate an image
   * @returns The generated image
   */
  public async generateImage(): Promise<string[]> {
    return []
  }

  public async generateImageByChat(): Promise<void> {
    throw new Error('Method not implemented.')
  }

  /**
   * Generate suggestions
   * @returns The suggestions
   */
  public async suggestions(): Promise<Suggestion[]> {
    return []
  }

  /**
   * Check if the model is valid
   * @param model - The model
   * @param stream - Whether to use streaming interface
   * @returns The validity of the model
   */
  public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
    if (!model) {
      return { valid: false, error: new Error('No model found') }
    }

    const body = {
      model: model.id,
      messages: [{ role: 'user' as const, content: 'hi' }],
      max_tokens: 2, // the API docs require a value > 1
      stream
    }

    try {
      if (!stream) {
        const message = await this.sdk.messages.create(body as MessageCreateParamsNonStreaming)
        return {
          valid: message.content.length > 0,
          error: null
        }
      } else {
        return await new Promise((resolve, reject) => {
          let hasContent = false
          this.sdk.messages
            .stream(body)
            .on('text', (text) => {
              if (!hasContent && text) {
                hasContent = true
                resolve({ valid: true, error: null })
              }
            })
            .on('finalMessage', (message) => {
              if (!hasContent && message.content && message.content.length > 0) {
                hasContent = true
                resolve({ valid: true, error: null })
              }
              if (!hasContent) {
                reject(new Error('Empty streaming response'))
              }
            })
            .on('error', (error) => reject(error))
        })
      }
    } catch (error: any) {
      return {
        valid: false,
        error
      }
    }
  }

  /**
   * Get the models
   * @returns The models
   */
  public async models(): Promise<OpenAI.Models.Model[]> {
    return []
  }

  public async getEmbeddingDimensions(): Promise<number> {
    return 0
  }

  public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
    return mcpToolsToAnthropicTools(mcpTools) as T[]
  }

  public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => {
    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
      return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
    } else if ('toolCallId' in mcpToolResponse) {
      return {
        role: 'user',
        content: [
          {
            type: 'tool_result',
            tool_use_id: mcpToolResponse.toolCallId!,
            content: resp.content
              .map((item) => {
                if (item.type === 'text') {
                  return {
                    type: 'text',
                    text: item.text || ''
                  } satisfies TextBlockParam
                }
                if (item.type === 'image') {
                  return {
                    type: 'image',
                    source: {
                      data: item.data || '',
                      media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
                      type: 'base64'
                    }
                  } satisfies ImageBlockParam
                }
                return
              })
              .filter((n) => typeof n !== 'undefined'),
            is_error: resp.isError
          } satisfies ToolResultBlockParam
        ]
      }
    }
    return
  }
}
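
For orientation, a minimal sketch of how this provider's `completions` contract is typically driven — the `messages` and `assistant` values here are illustrative placeholders, not data from this commit:

```typescript
// Hypothetical usage sketch; the concrete Provider/Assistant objects are assumptions.
const anthropic = new AnthropicProvider(provider) // a Provider with type === 'anthropic'

await anthropic.completions({
  messages, // Message[] from the current topic
  assistant, // Assistant carrying the model plus settings (temperature, reasoning_effort, ...)
  mcpTools: [], // optional MCP tools
  onFilterMessages: () => {},
  onChunk: (chunk) => {
    // Chunks arrive typed: TEXT_DELTA, THINKING_DELTA, BLOCK_COMPLETE, ...
    if (chunk.type === ChunkType.TEXT_DELTA) console.log(chunk.text)
  }
})
```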
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,33 +0,0 @@
import { Provider } from '@renderer/types'

import AihubmixProvider from './AihubmixProvider'
import AnthropicProvider from './AnthropicProvider'
import BaseProvider from './BaseProvider'
import GeminiProvider from './GeminiProvider'
import OpenAIProvider from './OpenAIProvider'
import OpenAIResponseProvider from './OpenAIResponseProvider'

export default class ProviderFactory {
  static create(provider: Provider): BaseProvider {
    if (provider.id === 'aihubmix') {
      return new AihubmixProvider(provider)
    }

    switch (provider.type) {
      case 'openai':
        return new OpenAIProvider(provider)
      case 'openai-response':
        return new OpenAIResponseProvider(provider)
      case 'anthropic':
        return new AnthropicProvider(provider)
      case 'gemini':
        return new GeminiProvider(provider)
      default:
        return new OpenAIProvider(provider)
    }
  }
}

export function isOpenAIProvider(provider: Provider) {
  return !['anthropic', 'gemini'].includes(provider.type)
}
@@ -1,94 +0,0 @@
import { GenerateImagesParameters } from '@google/genai'
import BaseProvider from '@renderer/providers/AiProvider/BaseProvider'
import ProviderFactory from '@renderer/providers/AiProvider/ProviderFactory'
import type { Assistant, GenerateImageParams, MCPTool, Model, Provider, Suggestion } from '@renderer/types'
import { Chunk } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import OpenAI from 'openai'

export interface CompletionsParams {
  messages: Message[]
  assistant: Assistant
  onChunk: (chunk: Chunk) => void
  onFilterMessages: (messages: Message[]) => void
  mcpTools?: MCPTool[]
}

export default class AiProvider {
  private sdk: BaseProvider

  constructor(provider: Provider) {
    this.sdk = ProviderFactory.create(provider)
  }

  public async fakeCompletions(params: CompletionsParams): Promise<void> {
    return this.sdk.fakeCompletions(params)
  }

  public async completions({
    messages,
    assistant,
    mcpTools,
    onChunk,
    onFilterMessages
  }: CompletionsParams): Promise<void> {
    return this.sdk.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages })
  }

  public async translate(
    content: string,
    assistant: Assistant,
    onResponse?: (text: string, isComplete: boolean) => void
  ): Promise<string> {
    return this.sdk.translate(content, assistant, onResponse)
  }

  public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
    return this.sdk.summaries(messages, assistant)
  }

  public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
    return this.sdk.summaryForSearch(messages, assistant)
  }

  public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
    return this.sdk.suggestions(messages, assistant)
  }

  public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
    return this.sdk.generateText({ prompt, content })
  }

  public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
    return this.sdk.check(model, stream)
  }

  public async models(): Promise<OpenAI.Models.Model[]> {
    return this.sdk.models()
  }

  public getApiKey(): string {
    return this.sdk.getApiKey()
  }

  public async generateImage(params: GenerateImageParams | GenerateImagesParameters): Promise<string[]> {
    return this.sdk.generateImage(params as GenerateImageParams)
  }

  public async generateImageByChat({
    messages,
    assistant,
    onChunk,
    onFilterMessages
  }: CompletionsParams): Promise<void> {
    return this.sdk.generateImageByChat({ messages, assistant, onChunk, onFilterMessages })
  }

  public async getEmbeddingDimensions(model: Model): Promise<number> {
    return this.sdk.getEmbeddingDimensions(model)
  }

  public getBaseURL(): string {
    return this.sdk.getBaseURL()
  }
}
@@ -1,10 +1,21 @@
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
import Logger from '@renderer/config/logger'
import { getOpenAIWebSearchParams, isOpenAIWebSearch } from '@renderer/config/models'
import {
  isEmbeddingModel,
  isGenerateImageModel,
  isOpenRouterBuiltInWebSearchModel,
  isReasoningModel,
  isSupportedDisableGenerationModel,
  isSupportedReasoningEffortModel,
  isSupportedThinkingTokenModel,
  isWebSearchModel
} from '@renderer/config/models'
import {
  SEARCH_SUMMARY_PROMPT,
  SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
  SEARCH_SUMMARY_PROMPT_WEB_ONLY
} from '@renderer/config/prompts'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import {
  Assistant,
@@ -13,20 +24,22 @@ import {
  MCPTool,
  Model,
  Provider,
  Suggestion,
  WebSearchResponse,
  WebSearchSource
} from '@renderer/types'
import { type Chunk, ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import { SdkModel } from '@renderer/types/sdk'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import { isAbortError } from '@renderer/utils/error'
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
import { getKnowledgeBaseIds, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { findLast, isEmpty } from 'lodash'
import { findLast, isEmpty, takeRight } from 'lodash'

import AiProvider from '../providers/AiProvider'
import AiProvider from '../aiCore'
import {
  getAssistantProvider,
  getAssistantSettings,
  getDefaultModel,
  getProviderByModel,
  getTopNamingModel,
@@ -34,7 +47,13 @@ import {
} from './AssistantService'
import { getDefaultAssistant } from './AssistantService'
import { processKnowledgeSearch } from './KnowledgeService'
import { filterContextMessages, filterMessages, filterUsefulMessages } from './MessagesService'
import {
  filterContextMessages,
  filterEmptyMessages,
  filterMessages,
  filterUsefulMessages,
  filterUserRoleStartMessages
} from './MessagesService'
import WebSearchService from './WebSearchService'

// TODO: consider splitting this file up later
@@ -50,6 +69,7 @@ async function fetchExternalTool(
  const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
  const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId)

  // Use the external search tools
  const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null
  const shouldKnowledgeSearch = hasKnowledgeBase

@@ -83,14 +103,14 @@ async function fetchExternalTool(
  summaryAssistant.prompt = prompt

  try {
    const keywords = await fetchSearchSummary({
    const result = await fetchSearchSummary({
      messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
      assistant: summaryAssistant
    })

    if (!keywords) return getFallbackResult()
    if (!result) return getFallbackResult()

    const extracted = extractInfoFromXML(keywords)
    const extracted = extractInfoFromXML(result.getText())
    // Filter the results as required
    return {
      websearch: needWebExtract ? extracted?.websearch : undefined,
@@ -134,12 +154,6 @@ async function fetchExternalTool(
    return undefined
  }

  // Pass the guaranteed model to the check function
  const webSearchParams = getOpenAIWebSearchParams(assistant, assistant.model)
  if (!isEmpty(webSearchParams) || isOpenAIWebSearch(assistant.model)) {
    return
  }

  try {
    // Use the consolidated processWebsearch function
    WebSearchService.createAbortSignal(lastUserMessage.id)
@@ -238,7 +252,7 @@ async function fetchExternalTool(

  // Get MCP tools (Fix duplicate declaration)
  let mcpTools: MCPTool[] = [] // Initialize as empty array
  const enabledMCPs = lastUserMessage?.enabledMCPs
  const enabledMCPs = assistant.mcpServers
  if (enabledMCPs && enabledMCPs.length > 0) {
    try {
      const toolPromises = enabledMCPs.map(async (mcpServer) => {
@@ -301,17 +315,52 @@ export async function fetchChatCompletion({
  // NOTE: The search results are NOT added to the messages sent to the AI here.
  // They will be retrieved and used by the messageThunk later to create CitationBlocks.
  const { mcpTools } = await fetchExternalTool(lastUserMessage, assistant, onChunkReceived, lastAnswer)
  const model = assistant.model || getDefaultModel()

  const { maxTokens, contextCount } = getAssistantSettings(assistant)

  const filteredMessages = filterUsefulMessages(messages)

  const _messages = filterUserRoleStartMessages(
    filterEmptyMessages(filterContextMessages(takeRight(filteredMessages, contextCount + 2))) // take the largest window used by the previous providers
  )

  const enableReasoning =
    ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
      assistant.settings?.reasoning_effort !== undefined) ||
    (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))

  const enableWebSearch =
    (assistant.enableWebSearch && isWebSearchModel(model)) ||
    isOpenRouterBuiltInWebSearchModel(model) ||
    model.id.includes('sonar') ||
    false

  const enableGenerateImage =
    isGenerateImageModel(model) && (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage : true)

  // --- Call AI Completions ---
  await AI.completions({
    messages: filteredMessages,
    assistant,
    onFilterMessages: () => {},
    onChunk: onChunkReceived,
    mcpTools: mcpTools
  })
  onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
  if (enableWebSearch) {
    onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS })
  }
  await AI.completions(
    {
      callType: 'chat',
      messages: _messages,
      assistant,
      onChunk: onChunkReceived,
      mcpTools: mcpTools,
      maxTokens,
      streamOutput: assistant.settings?.streamOutput || false,
      enableReasoning,
      enableWebSearch,
      enableGenerateImage
    },
    {
      streamOutput: assistant.settings?.streamOutput || false
    }
  )
}

interface FetchTranslateProps {
@@ -321,7 +370,7 @@ interface FetchTranslateProps {
}

export async function fetchTranslate({ content, assistant, onResponse }: FetchTranslateProps) {
  const model = getTranslateModel()
  const model = getTranslateModel() || assistant.model || getDefaultModel()

  if (!model) {
    throw new Error(i18n.t('error.provider_disabled'))
@@ -333,17 +382,42 @@ export async function fetchTranslate({ content, assistant, onResponse }: FetchTr
    throw new Error(i18n.t('error.no_api_key'))
  }

  const isSupportedStreamOutput = () => {
    if (!onResponse) {
      return false
    }
    return true
  }

  const stream = isSupportedStreamOutput()
  const enableReasoning =
    ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
      assistant.settings?.reasoning_effort !== undefined) ||
    (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))

  const params: CompletionsParams = {
    callType: 'translate',
    messages: content,
    assistant: { ...assistant, model },
    streamOutput: stream,
    enableReasoning,
    onResponse
  }

  const AI = new AiProvider(provider)

  try {
    return await AI.translate(content, assistant, onResponse)
    return (await AI.completions(params)).getText() || ''
  } catch (error: any) {
    return ''
  }
}

export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
  const prompt = (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
  const model = getTopNamingModel() || assistant.model || getDefaultModel()
  const userMessages = takeRight(messages, 5)

  const provider = getProviderByModel(model)

  if (!hasApiKey(provider)) {
@@ -352,9 +426,18 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:

  const AI = new AiProvider(provider)

  const params: CompletionsParams = {
    callType: 'summary',
    messages: filterMessages(userMessages),
    assistant: { ...assistant, prompt, model },
    maxTokens: 1000,
    streamOutput: false
  }

  try {
    const text = await AI.summaries(filterMessages(messages), assistant)
    return text?.replace(/["']/g, '') || null
    const { getText } = await AI.completions(params)
    const text = getText()
    return removeSpecialCharactersForTopicName(text) || null
  } catch (error: any) {
    return null
  }
@@ -370,7 +453,14 @@ export async function fetchSearchSummary({ messages, assistant }: { messages: Me

  const AI = new AiProvider(provider)

  return await AI.summaryForSearch(messages, assistant)
  const params: CompletionsParams = {
    callType: 'search',
    messages: messages,
    assistant,
    streamOutput: false
  }

  return await AI.completions(params)
}

export async function fetchGenerate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
@@ -383,42 +473,32 @@ export async function fetchGenerate({ prompt, content }: { prompt: string; conte

  const AI = new AiProvider(provider)

  const assistant = getDefaultAssistant()
  assistant.model = model
  assistant.prompt = prompt

  const params: CompletionsParams = {
    callType: 'generate',
    messages: content,
    assistant,
    streamOutput: false
  }

  try {
    return await AI.generateText({ prompt, content })
    const result = await AI.completions(params)
    return result.getText() || ''
  } catch (error: any) {
    return ''
  }
}

export async function fetchSuggestions({
  messages,
  assistant
}: {
  messages: Message[]
  assistant: Assistant
}): Promise<Suggestion[]> {
  const model = assistant.model
  if (!model || model.id.endsWith('global')) {
    return []
  }

  const provider = getAssistantProvider(assistant)
  const AI = new AiProvider(provider)

  try {
    return await AI.suggestions(filterMessages(messages), assistant)
  } catch (error: any) {
    return []
  }
}

function hasApiKey(provider: Provider) {
  if (!provider) return false
  if (provider.id === 'ollama' || provider.id === 'lmstudio') return true
  return !isEmpty(provider.apiKey)
}

export async function fetchModels(provider: Provider) {
export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
  const AI = new AiProvider(provider)

  try {
@@ -432,68 +512,69 @@ export const formatApiKeys = (value: string) => {
  return value.replaceAll(',', ',').replaceAll(' ', ',').replaceAll(' ', '').replaceAll('\n', ',')
}

export function checkApiProvider(provider: Provider): {
  valid: boolean
  error: Error | null
} {
export function checkApiProvider(provider: Provider): void {
  const key = 'api-check'
  const style = { marginTop: '3vh' }

  if (provider.id !== 'ollama' && provider.id !== 'lmstudio') {
    if (!provider.apiKey) {
      window.message.error({ content: i18n.t('message.error.enter.api.key'), key, style })
      return {
        valid: false,
        error: new Error(i18n.t('message.error.enter.api.key'))
      }
      throw new Error(i18n.t('message.error.enter.api.key'))
    }
  }

  if (!provider.apiHost) {
    window.message.error({ content: i18n.t('message.error.enter.api.host'), key, style })
    return {
      valid: false,
      error: new Error(i18n.t('message.error.enter.api.host'))
    }
    throw new Error(i18n.t('message.error.enter.api.host'))
  }

  if (isEmpty(provider.models)) {
    window.message.error({ content: i18n.t('message.error.enter.model'), key, style })
    return {
      valid: false,
      error: new Error(i18n.t('message.error.enter.model'))
    }
  }

  return {
    valid: true,
    error: null
    throw new Error(i18n.t('message.error.enter.model'))
  }
}

export async function checkApi(provider: Provider, model: Model): Promise<{ valid: boolean; error: Error | null }> {
  const validation = checkApiProvider(provider)
  if (!validation.valid) {
    return {
      valid: validation.valid,
      error: validation.error
    }
  }
export async function checkApi(provider: Provider, model: Model): Promise<void> {
  checkApiProvider(provider)

  const ai = new AiProvider(provider)

  // Try streaming check first
  const result = await ai.check(model, true)
  const assistant = getDefaultAssistant()
  assistant.model = model
  try {
    if (isEmbeddingModel(model)) {
      const result = await ai.getEmbeddingDimensions(model)
      if (result === 0) {
        throw new Error(i18n.t('message.error.enter.model'))
      }
    } else {
      const params: CompletionsParams = {
        callType: 'check',
        messages: 'hi',
        assistant,
        streamOutput: true
      }

  if (result.valid && !result.error) {
    return result
  }

  // We should not assume the error was caused by streaming; firing repeated check requests can trigger 429s and mask the real problem.
  // The error typing here is rough, though, so leave it like this for now.
  if (result.error && result.error.message.includes('stream')) {
    return ai.check(model, false)
  } else {
    return result
      // Try streaming check first
      const result = await ai.completions(params)
      if (!result.getText()) {
        throw new Error('No response received')
      }
    }
  } catch (error: any) {
    if (error.message.includes('stream')) {
      const params: CompletionsParams = {
        callType: 'check',
        messages: 'hi',
        assistant,
        streamOutput: false
      }
      const result = await ai.completions(params)
      if (!result.getText()) {
        throw new Error('No response received')
      }
    } else {
      throw error
    }
  }
}
@@ -98,14 +98,20 @@ export async function checkModelWithMultipleKeys(
  if (isParallel) {
    // Check all API keys in parallel
    const keyPromises = apiKeys.map(async (key) => {
      const result = await checkModel({ ...provider, apiKey: key }, model)

      return {
        key,
        isValid: result.valid,
        error: result.error?.message,
        latency: result.latency
      } as ApiKeyCheckStatus
      try {
        const result = await checkModel({ ...provider, apiKey: key }, model)
        return {
          key,
          isValid: true,
          latency: result.latency
        } as ApiKeyCheckStatus
      } catch (error: unknown) {
        return {
          key,
          isValid: false,
          error: error instanceof Error ? error.message.slice(0, 20) + '...' : String(error).slice(0, 20) + '...'
        } as ApiKeyCheckStatus
      }
    })

    const results = await Promise.allSettled(keyPromises)
@@ -125,14 +131,20 @@ export async function checkModelWithMultipleKeys(
  } else {
    // Check all API keys serially
    for (const key of apiKeys) {
      const result = await checkModel({ ...provider, apiKey: key }, model)

      keyResults.push({
        key,
        isValid: result.valid,
        error: result.error?.message,
        latency: result.latency
      })
      try {
        const result = await checkModel({ ...provider, apiKey: key }, model)
        keyResults.push({
          key,
          isValid: true,
          latency: result.latency
        })
      } catch (error: unknown) {
        keyResults.push({
          key,
          isValid: false,
          error: error instanceof Error ? error.message.slice(0, 20) + '...' : String(error).slice(0, 20) + '...'
        })
      }
    }
  }
@@ -1,8 +1,8 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import AiProvider from '@renderer/aiCore'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import Logger from '@renderer/config/logger'
import AiProvider from '@renderer/providers/AiProvider'
import store from '@renderer/store'
import { FileType, KnowledgeBase, KnowledgeBaseParams, KnowledgeReference } from '@renderer/types'
import { ExtractResults } from '@renderer/utils/extract'
@@ -1,11 +1,9 @@
import { isEmbeddingModel } from '@renderer/config/models'
import AiProvider from '@renderer/providers/AiProvider'
import store from '@renderer/store'
import { Model, Provider } from '@renderer/types'
import { t } from 'i18next'
import { pick } from 'lodash'

import { checkApiProvider } from './ApiService'
import { checkApi } from './ApiService'

export const getModelUniqId = (m?: Model) => {
  return m?.id ? JSON.stringify(pick(m, ['id', 'provider'])) : ''
@@ -33,64 +31,23 @@ export function getModelName(model?: Model) {
  return modelName
}

// Generic function to perform model checks
// Abstracts provider validation and error handling, allowing different types of check logic
// Generic function to perform model checks with exception handling
async function performModelCheck<T>(
  provider: Provider,
  model: Model,
  checkFn: (ai: AiProvider, model: Model) => Promise<T>,
  processResult: (result: T) => { valid: boolean; error: Error | null }
): Promise<{ valid: boolean; error: Error | null; latency?: number }> {
  const validation = checkApiProvider(provider)
  if (!validation.valid) {
    return {
      valid: validation.valid,
      error: validation.error
    }
  }
  checkFn: (provider: Provider, model: Model) => Promise<T>
): Promise<{ latency: number }> {
  const startTime = performance.now()
  await checkFn(provider, model)
  const latency = performance.now() - startTime

  const AI = new AiProvider(provider)

  try {
    const startTime = performance.now()
    const result = await checkFn(AI, model)
    const latency = performance.now() - startTime

    return {
      ...processResult(result),
      latency
    }
  } catch (error: any) {
    return {
      valid: false,
      error
    }
  }
  return { latency }
}

// Unified model check function
// Automatically selects appropriate check method based on model type
export async function checkModel(provider: Provider, model: Model) {
  if (isEmbeddingModel(model)) {
    return performModelCheck(
      provider,
      model,
      (ai, model) => ai.getEmbeddingDimensions(model),
      (dimensions) => ({ valid: dimensions > 0, error: null })
    )
  } else {
    return performModelCheck(
      provider,
      model,
      async (ai, model) => {
        // Try streaming check first
        const result = await ai.check(model, true)
        if (result.valid && !result.error) {
          return result
        }
        return ai.check(model, false)
      },
      ({ valid, error }) => ({ valid, error: error || null })
    )
  }
export async function checkModel(provider: Provider, model: Model): Promise<{ latency: number }> {
  return performModelCheck(provider, model, async (provider, model) => {
    await checkApi(provider, model)
  })
}
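
A minimal sketch of how the new exception-based check flow is consumed — the `provider`/`model` values are placeholders for illustration:

```typescript
// Hypothetical caller; checkModel now resolves with latency and throws on failure.
try {
  const { latency } = await checkModel(provider, model)
  console.log(`model OK in ${latency.toFixed(0)} ms`)
} catch (error) {
  console.error('model check failed:', error)
}
```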
@@ -28,7 +28,9 @@ export interface StreamProcessorCallbacks {
  onLLMWebSearchComplete?: (llmWebSearchResult: WebSearchResponse) => void
  // Image generation chunk received
  onImageCreated?: () => void
  onImageGenerated?: (imageData: GenerateImageResponse) => void
  onImageDelta?: (imageData: GenerateImageResponse) => void
  onImageGenerated?: (imageData?: GenerateImageResponse) => void
  onLLMResponseComplete?: (response?: Response) => void
  // Called when an error occurs during chunk processing
  onError?: (error: any) => void
  // Called when the entire stream processing is signaled as complete (success or failure)
@@ -40,59 +42,84 @@ export function createStreamProcessor(callbacks: StreamProcessorCallbacks = {})
  // The returned function processes a single chunk or a final signal
  return (chunk: Chunk) => {
    try {
      // Logger.log(`[${new Date().toLocaleString()}] createStreamProcessor ${chunk.type}`, chunk)
      // 1. Handle the manual final signal first
      if (chunk?.type === ChunkType.BLOCK_COMPLETE) {
        callbacks.onComplete?.(AssistantMessageStatus.SUCCESS, chunk?.response)
        return
      const data = chunk
      switch (data.type) {
        case ChunkType.BLOCK_COMPLETE: {
          if (callbacks.onComplete) callbacks.onComplete(AssistantMessageStatus.SUCCESS, data?.response)
          break
        }
        case ChunkType.LLM_RESPONSE_CREATED: {
          if (callbacks.onLLMResponseCreated) callbacks.onLLMResponseCreated()
          break
        }
        case ChunkType.TEXT_DELTA: {
          if (callbacks.onTextChunk) callbacks.onTextChunk(data.text)
          break
        }
        case ChunkType.TEXT_COMPLETE: {
          if (callbacks.onTextComplete) callbacks.onTextComplete(data.text)
          break
        }
        case ChunkType.THINKING_DELTA: {
          if (callbacks.onThinkingChunk) callbacks.onThinkingChunk(data.text, data.thinking_millsec)
          break
        }
        case ChunkType.THINKING_COMPLETE: {
          if (callbacks.onThinkingComplete) callbacks.onThinkingComplete(data.text, data.thinking_millsec)
          break
        }
        case ChunkType.MCP_TOOL_IN_PROGRESS: {
          if (callbacks.onToolCallInProgress)
            data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp))
          break
        }
        case ChunkType.MCP_TOOL_COMPLETE: {
          if (callbacks.onToolCallComplete && data.responses.length > 0) {
            data.responses.forEach((toolResp) => callbacks.onToolCallComplete!(toolResp))
          }
          break
        }
        case ChunkType.EXTERNEL_TOOL_IN_PROGRESS: {
          if (callbacks.onExternalToolInProgress) callbacks.onExternalToolInProgress()
          break
        }
        case ChunkType.EXTERNEL_TOOL_COMPLETE: {
          if (callbacks.onExternalToolComplete) callbacks.onExternalToolComplete(data.external_tool)
          break
        }
        case ChunkType.LLM_WEB_SEARCH_IN_PROGRESS: {
          if (callbacks.onLLMWebSearchInProgress) callbacks.onLLMWebSearchInProgress()
          break
        }
        case ChunkType.LLM_WEB_SEARCH_COMPLETE: {
          if (callbacks.onLLMWebSearchComplete) callbacks.onLLMWebSearchComplete(data.llm_web_search)
          break
        }
        case ChunkType.IMAGE_CREATED: {
          if (callbacks.onImageCreated) callbacks.onImageCreated()
          break
        }
        case ChunkType.IMAGE_DELTA: {
          if (callbacks.onImageDelta) callbacks.onImageDelta(data.image)
          break
        }
        case ChunkType.IMAGE_COMPLETE: {
          if (callbacks.onImageGenerated) callbacks.onImageGenerated(data.image)
          break
        }
        case ChunkType.LLM_RESPONSE_COMPLETE: {
          if (callbacks.onLLMResponseComplete) callbacks.onLLMResponseComplete(data.response)
          break
        }
        case ChunkType.ERROR: {
          if (callbacks.onError) callbacks.onError(data.error)
          break
        }
        default: {
          // Handle unknown chunk types or log an error
          console.warn(`Unknown chunk type: ${data.type}`)
        }
      }
      // 2. Process the actual ChunkCallbackData
      const data = chunk // Cast after checking for 'final'
      // Invoke callbacks based on the fields present in the chunk data
      if (data.type === ChunkType.LLM_RESPONSE_CREATED && callbacks.onLLMResponseCreated) {
        callbacks.onLLMResponseCreated()
      }
      if (data.type === ChunkType.TEXT_DELTA && callbacks.onTextChunk) {
        callbacks.onTextChunk(data.text)
      }
      if (data.type === ChunkType.TEXT_COMPLETE && callbacks.onTextComplete) {
        callbacks.onTextComplete(data.text)
      }
      if (data.type === ChunkType.THINKING_DELTA && callbacks.onThinkingChunk) {
        callbacks.onThinkingChunk(data.text, data.thinking_millsec)
      }
      if (data.type === ChunkType.THINKING_COMPLETE && callbacks.onThinkingComplete) {
        callbacks.onThinkingComplete(data.text, data.thinking_millsec)
      }
      if (data.type === ChunkType.MCP_TOOL_IN_PROGRESS && callbacks.onToolCallInProgress) {
        data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp))
      }
      if (data.type === ChunkType.MCP_TOOL_COMPLETE && data.responses.length > 0 && callbacks.onToolCallComplete) {
        data.responses.forEach((toolResp) => callbacks.onToolCallComplete!(toolResp))
      }
      if (data.type === ChunkType.EXTERNEL_TOOL_IN_PROGRESS && callbacks.onExternalToolInProgress) {
        callbacks.onExternalToolInProgress()
      }
      if (data.type === ChunkType.EXTERNEL_TOOL_COMPLETE && callbacks.onExternalToolComplete) {
        callbacks.onExternalToolComplete(data.external_tool)
      }
      if (data.type === ChunkType.LLM_WEB_SEARCH_IN_PROGRESS && callbacks.onLLMWebSearchInProgress) {
        callbacks.onLLMWebSearchInProgress()
      }
      if (data.type === ChunkType.LLM_WEB_SEARCH_COMPLETE && callbacks.onLLMWebSearchComplete) {
        callbacks.onLLMWebSearchComplete(data.llm_web_search)
      }
      if (data.type === ChunkType.IMAGE_CREATED && callbacks.onImageCreated) {
        callbacks.onImageCreated()
      }
      if (data.type === ChunkType.IMAGE_COMPLETE && callbacks.onImageGenerated) {
        callbacks.onImageGenerated(data.image)
      }
      if (data.type === ChunkType.ERROR && callbacks.onError) {
        callbacks.onError(data.error)
      }
      // Note: Usage and Metrics are usually handled at the end or accumulated differently,
      // so direct callbacks might not be the best fit here. They are often part of the final message state.
    } catch (error) {
      console.error('Error processing stream chunk:', error)
      callbacks.onError?.(error)
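
For context, a minimal sketch of how a stream processor built this way is wired up — the chunk source and UI handlers are assumed names, not part of this commit:

```typescript
// Hypothetical wiring; `chunkStream`, `appendTextToUI`, and `finalizeMessage` are placeholders.
const process = createStreamProcessor({
  onTextChunk: (text) => appendTextToUI(text),
  onComplete: (status, response) => finalizeMessage(status, response)
})

for (const chunk of chunkStream) {
  process(chunk) // each typed chunk is dispatched to exactly one callback
}
```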
@@ -8,6 +8,10 @@ export interface ChatState {
  selectedMessageIds: string[]
  activeTopic: Topic | null
  activeAssistant: Assistant | null
  /** topic ids that are currently being renamed */
  renamingTopics: string[]
  /** topic ids that are newly renamed */
  newlyRenamedTopics: string[]
}

export interface UpdateState {
@@ -67,7 +71,9 @@ const initialState: RuntimeState = {
    isMultiSelectMode: false,
    selectedMessageIds: [],
    activeTopic: null,
    activeAssistant: null
    activeAssistant: null,
    renamingTopics: [],
    newlyRenamedTopics: []
  }
}

@@ -123,6 +129,12 @@ const runtimeSlice = createSlice({
    },
    setActiveAssistant: (state, action: PayloadAction<Assistant>) => {
      state.chat.activeAssistant = action.payload
    },
    setRenamingTopics: (state, action: PayloadAction<string[]>) => {
      state.chat.renamingTopics = action.payload
    },
    setNewlyRenamedTopics: (state, action: PayloadAction<string[]>) => {
      state.chat.newlyRenamedTopics = action.payload
    }
  }
})
@@ -143,7 +155,9 @@ export const {
  toggleMultiSelectMode,
  setSelectedMessageIds,
  setActiveTopic,
  setActiveAssistant
  setActiveAssistant,
  setRenamingTopics,
  setNewlyRenamedTopics
} = runtimeSlice.actions

export default runtimeSlice.reducer
@ -8,7 +8,6 @@ import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/
|
||||
import { estimateMessagesUsage } from '@renderer/services/TokenService'
|
||||
import store from '@renderer/store'
|
||||
import type { Assistant, ExternalToolResult, FileType, MCPToolResponse, Model, Topic } from '@renderer/types'
|
||||
import { WebSearchSource } from '@renderer/types'
|
||||
import type {
|
||||
CitationMessageBlock,
|
||||
FileMessageBlock,
|
||||
@ -22,7 +21,6 @@ import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@r
|
||||
import { Response } from '@renderer/types/newMessage'
|
||||
import { uuid } from '@renderer/utils'
|
||||
import { formatErrorMessage, isAbortError } from '@renderer/utils/error'
|
||||
import { extractUrlsFromMarkdown } from '@renderer/utils/linkConverter'
|
||||
import {
|
||||
createAssistantMessage,
|
||||
createBaseMessageBlock,
|
||||
@ -35,7 +33,8 @@ import {
|
||||
createTranslationBlock,
|
||||
resetAssistantMessage
|
||||
} from '@renderer/utils/messageUtils/create'
|
||||
import { getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue'
|
||||
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { getTopicQueue } from '@renderer/utils/queue'
|
||||
import { isOnHomePage } from '@renderer/utils/window'
|
||||
import { t } from 'i18next'
|
||||
import { isEmpty, throttle } from 'lodash'
|
||||
@ -45,10 +44,10 @@ import type { AppDispatch, RootState } from '../index'
|
||||
import { removeManyBlocks, updateOneBlock, upsertManyBlocks, upsertOneBlock } from '../messageBlock'
|
||||
import { newMessagesActions, selectMessagesForTopic } from '../newMessage'
|
||||
|
||||
const handleChangeLoadingOfTopic = async (topicId: string) => {
|
||||
await waitForTopicQueue(topicId)
|
||||
store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
}
|
||||
// const handleChangeLoadingOfTopic = async (topicId: string) => {
|
||||
// await waitForTopicQueue(topicId)
|
||||
// store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
// }
|
||||
// TODO: 后续可以将db操作移到Listener Middleware中
|
||||
export const saveMessageAndBlocksToDB = async (message: Message, blocks: MessageBlock[], messageIndex: number = -1) => {
|
||||
try {
|
||||
@ -337,10 +336,17 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
|
||||
let accumulatedContent = ''
|
||||
let accumulatedThinking = ''
|
||||
// 专注于管理UI焦点和块切换
|
||||
let lastBlockId: string | null = null
|
||||
let lastBlockType: MessageBlockType | null = null
|
||||
// 专注于块内部的生命周期处理
|
||||
let initialPlaceholderBlockId: string | null = null
|
||||
let citationBlockId: string | null = null
|
||||
let mainTextBlockId: string | null = null
|
||||
let thinkingBlockId: string | null = null
|
||||
let imageBlockId: string | null = null
|
||||
let toolBlockId: string | null = null
|
||||
let hasWebSearch = false
|
||||
const toolCallIdToBlockIdMap = new Map<string, string>()
|
||||
const notificationService = NotificationService.getInstance()
|
||||
|
||||
@ -400,129 +406,129 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
}
|
||||
|
||||
callbacks = {
onLLMResponseCreated: () => {
onLLMResponseCreated: async () => {
const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
status: MessageBlockStatus.PROCESSING
})
handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
initialPlaceholderBlockId = baseBlock.id
await handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
},
onTextChunk: (text) => {
onTextChunk: async (text) => {
accumulatedContent += text
if (lastBlockId) {
if (lastBlockType === MessageBlockType.UNKNOWN) {
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.MAIN_TEXT,
content: accumulatedContent,
status: MessageBlockStatus.STREAMING,
citationReferences: citationBlockId ? [{ citationBlockId }] : []
}
mainTextBlockId = lastBlockId
lastBlockType = MessageBlockType.MAIN_TEXT
dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
} else if (lastBlockType === MessageBlockType.MAIN_TEXT) {
const blockChanges: Partial<MessageBlock> = {
content: accumulatedContent,
status: MessageBlockStatus.STREAMING
}
throttledBlockUpdate(lastBlockId, blockChanges)
// throttledBlockDbUpdate(lastBlockId, blockChanges)
} else {
const newBlock = createMainTextBlock(assistantMsgId, accumulatedContent, {
status: MessageBlockStatus.STREAMING,
citationReferences: citationBlockId ? [{ citationBlockId }] : []
})
handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT)
mainTextBlockId = newBlock.id
if (mainTextBlockId) {
const blockChanges: Partial<MessageBlock> = {
content: accumulatedContent,
status: MessageBlockStatus.STREAMING
}
throttledBlockUpdate(mainTextBlockId, blockChanges)
} else if (initialPlaceholderBlockId) {
// Convert the placeholder block into a main text block
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.MAIN_TEXT,
content: accumulatedContent,
status: MessageBlockStatus.STREAMING,
citationReferences: citationBlockId ? [{ citationBlockId }] : []
}
mainTextBlockId = initialPlaceholderBlockId
// Release the placeholder block reference
initialPlaceholderBlockId = null
lastBlockType = MessageBlockType.MAIN_TEXT
dispatch(updateOneBlock({ id: mainTextBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState)
} else {
const newBlock = createMainTextBlock(assistantMsgId, accumulatedContent, {
status: MessageBlockStatus.STREAMING,
citationReferences: citationBlockId ? [{ citationBlockId }] : []
})
mainTextBlockId = newBlock.id // Set the ID immediately to avoid a race condition
await handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT)
}
},
onTextComplete: async (finalText) => {
if (lastBlockType === MessageBlockType.MAIN_TEXT && lastBlockId) {
if (mainTextBlockId) {
const changes = {
content: finalText,
status: MessageBlockStatus.SUCCESS
}
cancelThrottledBlockUpdate(lastBlockId)
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)

if (assistant.enableWebSearch && assistant.model?.provider === 'openrouter') {
const extractedUrls = extractUrlsFromMarkdown(finalText)
if (extractedUrls.length > 0) {
const citationBlock = createCitationBlock(
assistantMsgId,
{ response: { source: WebSearchSource.OPENROUTER, results: extractedUrls } },
{ status: MessageBlockStatus.SUCCESS }
)
await handleBlockTransition(citationBlock, MessageBlockType.CITATION)
// saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState)
}
}
cancelThrottledBlockUpdate(mainTextBlockId)
dispatch(updateOneBlock({ id: mainTextBlockId, changes }))
saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState)
mainTextBlockId = null
} else {
console.warn(
`[onTextComplete] Received text.complete but last block was not MAIN_TEXT (was ${lastBlockType}) or lastBlockId is null.`
)
}
},
onThinkingChunk: (text, thinking_millsec) => {
accumulatedThinking += text
if (lastBlockId) {
if (lastBlockType === MessageBlockType.UNKNOWN) {
// First chunk for this block: Update type and status immediately
lastBlockType = MessageBlockType.THINKING
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.THINKING,
content: accumulatedThinking,
status: MessageBlockStatus.STREAMING
}
dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
} else if (lastBlockType === MessageBlockType.THINKING) {
const blockChanges: Partial<MessageBlock> = {
content: accumulatedThinking,
status: MessageBlockStatus.STREAMING,
thinking_millsec: thinking_millsec
}
throttledBlockUpdate(lastBlockId, blockChanges)
// throttledBlockDbUpdate(lastBlockId, blockChanges)
} else {
const newBlock = createThinkingBlock(assistantMsgId, accumulatedThinking, {
status: MessageBlockStatus.STREAMING,
thinking_millsec: 0
})
handleBlockTransition(newBlock, MessageBlockType.THINKING)
if (citationBlockId && !hasWebSearch) {
const changes: Partial<CitationMessageBlock> = {
status: MessageBlockStatus.SUCCESS
}
dispatch(updateOneBlock({ id: citationBlockId, changes }))
saveUpdatedBlockToDB(citationBlockId, assistantMsgId, topicId, getState)
citationBlockId = null
}
},
onThinkingChunk: async (text, thinking_millsec) => {
accumulatedThinking += text
if (thinkingBlockId) {
const blockChanges: Partial<MessageBlock> = {
content: accumulatedThinking,
status: MessageBlockStatus.STREAMING,
thinking_millsec: thinking_millsec
}
throttledBlockUpdate(thinkingBlockId, blockChanges)
} else if (initialPlaceholderBlockId) {
// First chunk for this block: Update type and status immediately
lastBlockType = MessageBlockType.THINKING
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.THINKING,
content: accumulatedThinking,
status: MessageBlockStatus.STREAMING
}
thinkingBlockId = initialPlaceholderBlockId
initialPlaceholderBlockId = null
dispatch(updateOneBlock({ id: thinkingBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(thinkingBlockId, assistantMsgId, topicId, getState)
} else {
const newBlock = createThinkingBlock(assistantMsgId, accumulatedThinking, {
status: MessageBlockStatus.STREAMING,
thinking_millsec: 0
})
thinkingBlockId = newBlock.id // Set the ID immediately to avoid a race condition
await handleBlockTransition(newBlock, MessageBlockType.THINKING)
}
},
onThinkingComplete: (finalText, final_thinking_millsec) => {
if (lastBlockType === MessageBlockType.THINKING && lastBlockId) {
if (thinkingBlockId) {
const changes = {
type: MessageBlockType.THINKING,
content: finalText,
status: MessageBlockStatus.SUCCESS,
thinking_millsec: final_thinking_millsec
}
cancelThrottledBlockUpdate(lastBlockId)
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
cancelThrottledBlockUpdate(thinkingBlockId)
dispatch(updateOneBlock({ id: thinkingBlockId, changes }))
saveUpdatedBlockToDB(thinkingBlockId, assistantMsgId, topicId, getState)
} else {
console.warn(
`[onThinkingComplete] Received thinking.complete but last block was not THINKING (was ${lastBlockType}) or lastBlockId is null.`
)
}
thinkingBlockId = null
},
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
if (lastBlockType === MessageBlockType.UNKNOWN && lastBlockId) {
if (initialPlaceholderBlockId) {
lastBlockType = MessageBlockType.TOOL
const changes = {
type: MessageBlockType.TOOL,
status: MessageBlockStatus.PROCESSING,
metadata: { rawMcpToolResponse: toolResponse }
}
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
toolCallIdToBlockIdMap.set(toolResponse.id, lastBlockId)
toolBlockId = initialPlaceholderBlockId
initialPlaceholderBlockId = null
dispatch(updateOneBlock({ id: toolBlockId, changes }))
saveUpdatedBlockToDB(toolBlockId, assistantMsgId, topicId, getState)
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId)
} else if (toolResponse.status === 'invoking') {
const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, {
toolName: toolResponse.tool.name,
@ -539,6 +545,7 @@ const fetchAndProcessAssistantResponseImpl = async (
},
onToolCallComplete: (toolResponse: MCPToolResponse) => {
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
toolCallIdToBlockIdMap.delete(toolResponse.id)
if (toolResponse.status === 'done' || toolResponse.status === 'error') {
if (!existingBlockId) {
console.error(
@ -564,10 +571,10 @@ const fetchAndProcessAssistantResponseImpl = async (
)
}
},
onExternalToolInProgress: () => {
onExternalToolInProgress: async () => {
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
citationBlockId = citationBlock.id
handleBlockTransition(citationBlock, MessageBlockType.CITATION)
await handleBlockTransition(citationBlock, MessageBlockType.CITATION)
// saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState)
},
onExternalToolComplete: (externalToolResult: ExternalToolResult) => {
@ -583,35 +590,39 @@ const fetchAndProcessAssistantResponseImpl = async (
console.error('[onExternalToolComplete] citationBlockId is null. Cannot update.')
}
},
onLLMWebSearchInProgress: () => {
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
citationBlockId = citationBlock.id
handleBlockTransition(citationBlock, MessageBlockType.CITATION)
// saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState)
onLLMWebSearchInProgress: async () => {
if (initialPlaceholderBlockId) {
lastBlockType = MessageBlockType.CITATION
citationBlockId = initialPlaceholderBlockId
const changes = {
type: MessageBlockType.CITATION,
status: MessageBlockStatus.PROCESSING
}
lastBlockType = MessageBlockType.CITATION
dispatch(updateOneBlock({ id: initialPlaceholderBlockId, changes }))
saveUpdatedBlockToDB(initialPlaceholderBlockId, assistantMsgId, topicId, getState)
initialPlaceholderBlockId = null
} else {
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
citationBlockId = citationBlock.id
await handleBlockTransition(citationBlock, MessageBlockType.CITATION)
}
},
onLLMWebSearchComplete: async (llmWebSearchResult) => {
if (citationBlockId) {
hasWebSearch = true
const changes: Partial<CitationMessageBlock> = {
response: llmWebSearchResult,
status: MessageBlockStatus.SUCCESS
}
dispatch(updateOneBlock({ id: citationBlockId, changes }))
saveUpdatedBlockToDB(citationBlockId, assistantMsgId, topicId, getState)
} else {
const citationBlock = createCitationBlock(
assistantMsgId,
{ response: llmWebSearchResult },
{ status: MessageBlockStatus.SUCCESS }
)
citationBlockId = citationBlock.id
handleBlockTransition(citationBlock, MessageBlockType.CITATION)
}
if (mainTextBlockId) {
const state = getState()
const existingMainTextBlock = state.messageBlocks.entities[mainTextBlockId]
if (existingMainTextBlock && existingMainTextBlock.type === MessageBlockType.MAIN_TEXT) {
const currentRefs = existingMainTextBlock.citationReferences || []
if (!currentRefs.some((ref) => ref.citationBlockId === citationBlockId)) {

if (mainTextBlockId) {
const state = getState()
const existingMainTextBlock = state.messageBlocks.entities[mainTextBlockId]
if (existingMainTextBlock && existingMainTextBlock.type === MessageBlockType.MAIN_TEXT) {
const currentRefs = existingMainTextBlock.citationReferences || []
const mainTextChanges = {
citationReferences: [
...currentRefs,
@ -621,40 +632,64 @@ const fetchAndProcessAssistantResponseImpl = async (
dispatch(updateOneBlock({ id: mainTextBlockId, changes: mainTextChanges }))
saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState)
}
mainTextBlockId = null
}
}
},
onImageCreated: () => {
if (lastBlockId) {
if (lastBlockType === MessageBlockType.UNKNOWN) {
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.IMAGE,
status: MessageBlockStatus.STREAMING
}
lastBlockType = MessageBlockType.IMAGE
dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
} else {
const imageBlock = createImageBlock(assistantMsgId, {
status: MessageBlockStatus.PROCESSING
})
handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
onImageCreated: async () => {
if (initialPlaceholderBlockId) {
lastBlockType = MessageBlockType.IMAGE
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.IMAGE,
status: MessageBlockStatus.STREAMING
}
lastBlockType = MessageBlockType.IMAGE
imageBlockId = initialPlaceholderBlockId
initialPlaceholderBlockId = null
dispatch(updateOneBlock({ id: imageBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
} else if (!imageBlockId) {
const imageBlock = createImageBlock(assistantMsgId, {
status: MessageBlockStatus.STREAMING
})
imageBlockId = imageBlock.id
await handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
}
},
onImageGenerated: (imageData) => {
onImageDelta: (imageData) => {
const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
if (lastBlockId && lastBlockType === MessageBlockType.IMAGE) {
if (imageBlockId) {
const changes: Partial<ImageMessageBlock> = {
url: imageUrl,
metadata: { generateImageResponse: imageData },
status: MessageBlockStatus.SUCCESS
status: MessageBlockStatus.STREAMING
}
dispatch(updateOneBlock({ id: imageBlockId, changes }))
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
}
},
onImageGenerated: (imageData) => {
if (imageBlockId) {
if (!imageData) {
const changes: Partial<ImageMessageBlock> = {
status: MessageBlockStatus.SUCCESS
}
dispatch(updateOneBlock({ id: imageBlockId, changes }))
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
} else {
const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
const changes: Partial<ImageMessageBlock> = {
url: imageUrl,
metadata: { generateImageResponse: imageData },
status: MessageBlockStatus.SUCCESS
}
dispatch(updateOneBlock({ id: imageBlockId, changes }))
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
}
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
} else {
console.error('[onImageGenerated] Last block was not an Image block or ID is missing.')
}
imageBlockId = null
},
onError: async (error) => {
console.dir(error, { depth: null })
@ -683,15 +718,16 @@ const fetchAndProcessAssistantResponseImpl = async (
source: 'assistant'
})
}

if (lastBlockId) {
const possibleBlockId =
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
if (possibleBlockId) {
// Mark the previous block's status as ERROR
const changes: Partial<MessageBlock> = {
status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR
}
cancelThrottledBlockUpdate(lastBlockId)
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
cancelThrottledBlockUpdate(possibleBlockId)
dispatch(updateOneBlock({ id: possibleBlockId, changes }))
saveUpdatedBlockToDB(possibleBlockId, assistantMsgId, topicId, getState)
}

const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS })
@ -721,35 +757,45 @@ const fetchAndProcessAssistantResponseImpl = async (
const contextForUsage = userMsgIndex !== -1 ? orderedMsgs.slice(0, userMsgIndex + 1) : []
const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg]

if (lastBlockId) {
const possibleBlockId =
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
if (possibleBlockId) {
const changes: Partial<MessageBlock> = {
status: MessageBlockStatus.SUCCESS
}
cancelThrottledBlockUpdate(lastBlockId)
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
cancelThrottledBlockUpdate(possibleBlockId)
dispatch(updateOneBlock({ id: possibleBlockId, changes }))
saveUpdatedBlockToDB(possibleBlockId, assistantMsgId, topicId, getState)
}

// const content = getMainTextContent(finalAssistantMsg)
// if (!isOnHomePage()) {
// await notificationService.send({
// id: uuid(),
// type: 'success',
// title: t('notification.assistant'),
// message: content.length > 50 ? content.slice(0, 47) + '...' : content,
// silent: false,
// timestamp: Date.now(),
// source: 'assistant'
// })
// }
const endTime = Date.now()
const duration = endTime - startTime
const content = getMainTextContent(finalAssistantMsg)
if (!isOnHomePage() && duration > 60 * 1000) {
await notificationService.send({
id: uuid(),
type: 'success',
title: t('notification.assistant'),
message: content.length > 50 ? content.slice(0, 47) + '...' : content,
silent: false,
timestamp: Date.now(),
source: 'assistant'
})
}

// Update the topic name
autoRenameTopic(assistant, topicId)

if (response && response.usage?.total_tokens === 0) {
if (
response &&
(response.usage?.total_tokens === 0 ||
response?.usage?.prompt_tokens === 0 ||
response?.usage?.completion_tokens === 0)
) {
const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
response.usage = usage
}
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
}
if (response && response.metrics) {
if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) {
@ -779,6 +825,7 @@ const fetchAndProcessAssistantResponseImpl = async (

const streamProcessorCallbacks = createStreamProcessor(callbacks)

const startTime = Date.now()
await fetchChatCompletion({
messages: messagesForContext,
assistant: assistant,
@ -833,9 +880,10 @@ export const sendMessage =
}
} catch (error) {
console.error('Error in sendMessage thunk:', error)
} finally {
handleChangeLoadingOfTopic(topicId)
}
// finally {
// handleChangeLoadingOfTopic(topicId)
// }
}

/**
@ -1069,9 +1117,10 @@ export const resendMessageThunk =
}
} catch (error) {
console.error(`[resendMessageThunk] Error resending user message ${userMessageToResend.id}:`, error)
} finally {
handleChangeLoadingOfTopic(topicId)
}
// finally {
// handleChangeLoadingOfTopic(topicId)
// }
}

/**
@ -1179,10 +1228,11 @@ export const regenerateAssistantResponseThunk =
`[regenerateAssistantResponseThunk] Error regenerating response for assistant message ${assistantMessageToRegenerate.id}:`,
error
)
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
} finally {
handleChangeLoadingOfTopic(topicId)
// dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
}
// finally {
// handleChangeLoadingOfTopic(topicId)
// }
}

// --- Thunk to initiate translation and create the initial block ---
@ -1348,9 +1398,10 @@ export const appendAssistantResponseThunk =
console.error(`[appendAssistantResponseThunk] Error appending assistant response:`, error)
// Optionally dispatch an error action or notification
// Resetting loading state should be handled by the underlying fetchAndProcessAssistantResponseImpl
} finally {
handleChangeLoadingOfTopic(topicId)
}
// finally {
// handleChangeLoadingOfTopic(topicId)
// }
}

/**
@ -1,5 +1,6 @@
import { ExternalToolResult, KnowledgeReference, MCPToolResponse, WebSearchResponse } from '.'
import { ExternalToolResult, KnowledgeReference, MCPToolResponse, ToolUseResponse, WebSearchResponse } from '.'
import { Response, ResponseError } from './newMessage'
import { SdkToolCall } from './sdk'

// Define Enum for Chunk Types
// Only the chunk types currently in use; this is not the full lifecycle
@ -11,6 +12,7 @@ export enum ChunkType {
WEB_SEARCH_COMPLETE = 'web_search_complete',
KNOWLEDGE_SEARCH_IN_PROGRESS = 'knowledge_search_in_progress',
KNOWLEDGE_SEARCH_COMPLETE = 'knowledge_search_complete',
MCP_TOOL_CREATED = 'mcp_tool_created',
MCP_TOOL_IN_PROGRESS = 'mcp_tool_in_progress',
MCP_TOOL_COMPLETE = 'mcp_tool_complete',
EXTERNEL_TOOL_COMPLETE = 'externel_tool_complete',
@ -118,7 +120,7 @@ export interface ImageDeltaChunk {
/**
* A chunk of Base64 encoded image data
*/
image: string
image: { type: 'base64'; images: string[] }

/**
* The type of the chunk
@ -135,7 +137,7 @@ export interface ImageCompleteChunk {
/**
* The image content of the chunk
*/
image: { type: 'base64'; images: string[] }
image?: { type: 'base64'; images: string[] }
}

export interface ThinkingDeltaChunk {
@ -253,6 +255,12 @@ export interface ExternalToolCompleteChunk {
type: ChunkType.EXTERNEL_TOOL_COMPLETE
}

export interface MCPToolCreatedChunk {
type: ChunkType.MCP_TOOL_CREATED
tool_calls?: SdkToolCall[] // tool calls requested by the model
tool_use_responses?: ToolUseResponse[] // tool-use responses
}

export interface MCPToolInProgressChunk {
/**
* The type of the chunk
@ -345,6 +353,7 @@ export type Chunk =
| WebSearchCompleteChunk // web search complete
| KnowledgeSearchInProgressChunk // knowledge-base search in progress
| KnowledgeSearchCompleteChunk // knowledge-base search complete
| MCPToolCreatedChunk // MCP tool created by the model
| MCPToolInProgressChunk // MCP tool call in progress
| MCPToolCompleteChunk // MCP tool call complete
| ExternalToolCompleteChunk // external tool complete; external tools include web search, knowledge base, and MCP servers
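To make the new `MCP_TOOL_CREATED` member concrete, here is a minimal consumer sketch (not part of this diff; `handleChunk` is a hypothetical name) showing how the `Chunk` union narrows on its `type` discriminant:

```typescript
import { Chunk, ChunkType } from '@renderer/types/chunk'

// Hypothetical consumer: TypeScript narrows the Chunk union on the `type` discriminant.
function handleChunk(chunk: Chunk): void {
  switch (chunk.type) {
    case ChunkType.MCP_TOOL_CREATED:
      // Both fields are optional on MCPToolCreatedChunk
      console.log('tool calls:', chunk.tool_calls?.length ?? 0)
      break
    case ChunkType.EXTERNEL_TOOL_COMPLETE:
      console.log('external tool finished')
      break
    default:
      break
  }
}
```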

@ -1,5 +1,5 @@
import type { WebSearchResultBlock } from '@anthropic-ai/sdk/resources'
import type { GenerateImagesConfig, GroundingMetadata } from '@google/genai'
import type { GenerateImagesConfig, GroundingMetadata, PersonGeneration } from '@google/genai'
import type OpenAI from 'openai'
import type { CSSProperties } from 'react'

@ -448,10 +448,11 @@ export type GenerateImageParams = {
imageSize: string
batchSize: number
seed?: string
numInferenceSteps: number
guidanceScale: number
numInferenceSteps?: number
guidanceScale?: number
signal?: AbortSignal
promptEnhancement?: boolean
personGeneration?: PersonGeneration
}

export type GenerateImageResponse = {
@ -524,7 +525,7 @@ export enum WebSearchSource {
}

export type WebSearchResponse = {
results: WebSearchResults
results?: WebSearchResults
source: WebSearchSource
}


107
src/renderer/src/types/sdk.ts
Normal file
@ -0,0 +1,107 @@
import Anthropic from '@anthropic-ai/sdk'
import {
Message,
MessageCreateParams,
MessageParam,
RawMessageStreamEvent,
ToolUnion,
ToolUseBlock
} from '@anthropic-ai/sdk/resources'
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import {
Content,
CreateChatParameters,
FunctionCall,
GenerateContentResponse,
GoogleGenAI,
Model as GeminiModel,
SendMessageParameters,
Tool
} from '@google/genai'
import OpenAI, { AzureOpenAI } from 'openai'
import { Stream } from 'openai/streaming'

export type SdkInstance = OpenAI | AzureOpenAI | Anthropic | GoogleGenAI
export type SdkParams = OpenAISdkParams | OpenAIResponseSdkParams | AnthropicSdkParams | GeminiSdkParams
export type SdkRawChunk = OpenAISdkRawChunk | OpenAIResponseSdkRawChunk | AnthropicSdkRawChunk | GeminiSdkRawChunk
export type SdkRawOutput = OpenAISdkRawOutput | OpenAIResponseSdkRawOutput | AnthropicSdkRawOutput | GeminiSdkRawOutput
export type SdkMessageParam =
| OpenAISdkMessageParam
| OpenAIResponseSdkMessageParam
| AnthropicSdkMessageParam
| GeminiSdkMessageParam
export type SdkToolCall =
| OpenAI.Chat.Completions.ChatCompletionMessageToolCall
| ToolUseBlock
| FunctionCall
| OpenAIResponseSdkToolCall
export type SdkTool = OpenAI.Chat.Completions.ChatCompletionTool | ToolUnion | Tool | OpenAIResponseSdkTool
export type SdkModel = OpenAI.Models.Model | Anthropic.ModelInfo | GeminiModel

export type RequestOptions = Anthropic.RequestOptions | OpenAI.RequestOptions | GeminiOptions

/**
* OpenAI
*/

type OpenAIParamsWithoutReasoningEffort = Omit<OpenAI.Chat.Completions.ChatCompletionCreateParams, 'reasoning_effort'>

export type ReasoningEffortOptionalParams = {
thinking?: { type: 'disabled' | 'enabled' | 'auto'; budget_tokens?: number }
reasoning?: { max_tokens?: number; exclude?: boolean; effort?: string } | OpenAI.Reasoning
reasoning_effort?: OpenAI.Chat.Completions.ChatCompletionCreateParams['reasoning_effort'] | 'none' | 'auto'
enable_thinking?: boolean
thinking_budget?: number
enable_reasoning?: boolean
// Add any other potential reasoning-related keys here if they exist
}

export type OpenAISdkParams = OpenAIParamsWithoutReasoningEffort & ReasoningEffortOptionalParams
export type OpenAISdkRawChunk =
| OpenAI.Chat.Completions.ChatCompletionChunk
| ({
_request_id?: string | null | undefined
} & OpenAI.ChatCompletion)

export type OpenAISdkRawOutput = Stream<OpenAI.Chat.Completions.ChatCompletionChunk> | OpenAI.ChatCompletion
export type OpenAISdkRawContentSource =
| OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta
| OpenAI.Chat.Completions.ChatCompletionMessage

export type OpenAISdkMessageParam = OpenAI.Chat.Completions.ChatCompletionMessageParam

/**
* OpenAI Response
*/

export type OpenAIResponseSdkParams = OpenAI.Responses.ResponseCreateParams
export type OpenAIResponseSdkRawOutput = Stream<OpenAI.Responses.ResponseStreamEvent> | OpenAI.Responses.Response
export type OpenAIResponseSdkRawChunk = OpenAI.Responses.ResponseStreamEvent | OpenAI.Responses.Response
export type OpenAIResponseSdkMessageParam = OpenAI.Responses.ResponseInputItem
export type OpenAIResponseSdkToolCall = OpenAI.Responses.ResponseFunctionToolCall
export type OpenAIResponseSdkTool = OpenAI.Responses.Tool

/**
* Anthropic
*/

export type AnthropicSdkParams = MessageCreateParams
export type AnthropicSdkRawOutput = MessageStream | Message
export type AnthropicSdkRawChunk = RawMessageStreamEvent | Message
export type AnthropicSdkMessageParam = MessageParam

/**
* Gemini
*/

export type GeminiSdkParams = SendMessageParameters & CreateChatParameters
export type GeminiSdkRawOutput = AsyncGenerator<GenerateContentResponse> | GenerateContentResponse
export type GeminiSdkRawChunk = GenerateContentResponse
export type GeminiSdkMessageParam = Content
export type GeminiSdkToolCall = FunctionCall

export type GeminiOptions = {
streamOutput: boolean
abortSignal?: AbortSignal
timeout?: number
}
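These aliases exist so call sites can stay provider-agnostic. A minimal sketch of consuming `OpenAISdkRawOutput` under that assumption (`iterateRawOutput` is a hypothetical helper, not part of this file; it relies on the OpenAI `Stream` being async-iterable, which its typings declare):

```typescript
import OpenAI from 'openai'
import { Stream } from 'openai/streaming'
import type { OpenAISdkRawOutput } from '@renderer/types/sdk'

// Hypothetical helper: flattens both raw output shapes into one async sequence.
async function* iterateRawOutput(
  output: OpenAISdkRawOutput
): AsyncGenerator<OpenAI.Chat.Completions.ChatCompletionChunk | OpenAI.ChatCompletion> {
  if (Symbol.asyncIterator in output) {
    // Streaming case: Stream<ChatCompletionChunk> is async-iterable
    for await (const chunk of output as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>) {
      yield chunk
    }
  } else {
    // Non-streaming case: a single ChatCompletion object
    yield output
  }
}
```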
@ -369,3 +369,99 @@ export function cleanLinkCommas(text: string): string {
// Match the English comma (possibly with surrounding spaces) between two Markdown links
return text.replace(/\]\(([^)]+)\)\s*,\s*\[/g, ']($1)[')
}

/**
* Detect web-search citation placeholders of various formats in a text
* Supported formats include [1], [ref_1], [1](@ref), [1,2,3](@ref), etc.
* @param {string} text The text to analyze
* @returns {Array} The detected citation records
*/
export function extractWebSearchReferences(text: string): Array<{
match: string
placeholder: string
numbers: number[]
startIndex: number
endIndex: number
}> {
const references: Array<{
match: string
placeholder: string
numbers: number[]
startIndex: number
endIndex: number
}> = []

// Regular expressions for the various citation formats
const patterns = [
// [1], [2], [3] - simple numeric references
{ regex: /\[(\d+)\]/g, type: 'simple' },
// [ref_1], [ref_2] - Zhipu format
{ regex: /\[ref_(\d+)\]/g, type: 'zhipu' },
// [1](@ref), [2](@ref) - Hunyuan single-reference format
{ regex: /\[(\d+)\]\(@ref\)/g, type: 'hunyuan_single' },
// [1,2,3](@ref) - Hunyuan multi-reference format
{ regex: /\[([\d,\s]+)\]\(@ref\)/g, type: 'hunyuan_multiple' }
]

patterns.forEach(({ regex, type }) => {
let match
while ((match = regex.exec(text)) !== null) {
let numbers: number[] = []

if (type === 'hunyuan_multiple') {
// Parse comma-separated numbers
numbers = match[1]
.split(',')
.map((num) => parseInt(num.trim()))
.filter((num) => !isNaN(num))
} else {
// A single number
numbers = [parseInt(match[1])]
}

references.push({
match: match[0],
placeholder: match[0],
numbers: numbers,
startIndex: match.index!,
endIndex: match.index! + match[0].length
})
}
})

// Sort by position
return references.sort((a, b) => a.startIndex - b.startIndex)
}
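For illustration (not part of the diff), the expected shape on a mixed-format input; the index values assume exactly the string below:

```typescript
const refs = extractWebSearchReferences('According to [1] and [ref_2], see also [3, 4](@ref).')
// Sorted by startIndex, refs would contain roughly:
//   { match: '[1]',          numbers: [1],    startIndex: 13, endIndex: 16, ... }
//   { match: '[ref_2]',      numbers: [2],    ... }
//   { match: '[3, 4](@ref)', numbers: [3, 4], ... }
```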

/**
* Smart link converter - picks the appropriate conversion strategy based on the citation pattern in the text and the web-search results
* @param {string} text The current text chunk
* @param {any[]} webSearchResults The web-search results
* @param {string} providerType The provider type ('openai', 'zhipu', 'hunyuan', 'openrouter', etc.)
* @param {boolean} resetCounter Whether to reset the counter
* @returns {string} The converted text
*/
export function smartLinkConverter(
text: string,
providerType: string = 'openai',
resetCounter: boolean = false
): string {
// Detect citation patterns in the text
const references = extractWebSearchReferences(text)

if (references.length === 0) {
// No specific citation pattern found; use the generic conversion
return convertLinks(text, resetCounter)
}

// Choose the appropriate converter based on the detected citation pattern
const hasZhipuPattern = references.some((ref) => ref.placeholder.includes('ref_'))

if (hasZhipuPattern) {
return convertLinksToZhipu(text, resetCounter)
} else if (providerType === 'openrouter') {
return convertLinksToOpenRouter(text, resetCounter)
} else {
return convertLinks(text, resetCounter)
}
}
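A quick illustration of the dispatch rules above (hypothetical calls, not from the diff):

```typescript
// Zhipu-style [ref_N] placeholders win regardless of providerType:
smartLinkConverter('see [ref_1] and [ref_2]', 'openai') // -> convertLinksToZhipu(...)

// No recognizable pattern falls through to the generic converter:
smartLinkConverter('plain text with [a link](https://example.com)', 'openai') // -> convertLinks(...)

// Otherwise the provider decides, e.g. OpenRouter:
smartLinkConverter('sources: [1] [2]', 'openrouter') // -> convertLinksToOpenRouter(...)
```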

@ -1,10 +1,4 @@
import {
ContentBlockParam,
MessageParam,
ToolResultBlockParam,
ToolUnion,
ToolUseBlock
} from '@anthropic-ai/sdk/resources'
import { ContentBlockParam, MessageParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
import { Content, FunctionCall, Part, Tool, Type as GeminiSchemaType } from '@google/genai'
import Logger from '@renderer/config/logger'
import { isFunctionCallingModel, isVisionModel } from '@renderer/config/models'
@ -21,6 +15,7 @@ import {
} from '@renderer/types'
import type { MCPToolCompleteChunk, MCPToolInProgressChunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { SdkMessageParam } from '@renderer/types/sdk'
import { isArray, isObject, pull, transform } from 'lodash'
import { nanoid } from 'nanoid'
import OpenAI from 'openai'
@ -31,7 +26,7 @@ import {
ChatCompletionTool
} from 'openai/resources'

import { CompletionsParams } from '../providers/AiProvider'
import { CompletionsParams } from '../aiCore/middleware/schemas'

const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
const EXTRA_SCHEMA_KEYS = ['schema', 'headers']
@ -449,13 +444,25 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseRespo
if (!content || !mcpTools || mcpTools.length === 0) {
return []
}

// Two formats are supported:
// 1. Content wrapped in full <tool_use></tool_use> tags
// 2. Inner content only (as extracted by TagExtractor)

let contentToProcess = content

// If the content contains no <tool_use> tag, it is inner content extracted by TagExtractor and needs to be wrapped
if (!content.includes('<tool_use>')) {
contentToProcess = `<tool_use>\n${content}\n</tool_use>`
}

const toolUsePattern =
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
const tools: ToolUseResponse[] = []
let match
let idx = 0
// Find all tool use blocks
while ((match = toolUsePattern.exec(content)) !== null) {
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
// const fullMatch = match[0]
const toolName = match[2].trim()
const toolArgs = match[4].trim()
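The wrapping step means both input shapes parse identically. A sketch, with `mcpTools` assumed to be in scope:

```typescript
const inner = '<name>search</name>\n<arguments>{"query":"cherry"}</arguments>'

// 1. Full tags, e.g. raw model output:
parseToolUse(`<tool_use>\n${inner}\n</tool_use>`, mcpTools)

// 2. Inner content only, e.g. as emitted by TagExtractor - parseToolUse wraps it itself:
parseToolUse(inner, mcpTools)
```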
@ -497,9 +504,7 @@ export async function parseAndCallTools<R>(
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
model: Model,
mcpTools?: MCPTool[]
): Promise<
(ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[]
>
): Promise<SdkMessageParam[]>

export async function parseAndCallTools<R>(
content: string,
@ -508,9 +513,7 @@ export async function parseAndCallTools<R>(
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
model: Model,
mcpTools?: MCPTool[]
): Promise<
(ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[]
>
): Promise<SdkMessageParam[]>

export async function parseAndCallTools<R>(
content: string | MCPToolResponse[],
@ -539,7 +542,7 @@ export async function parseAndCallTools<R>(
...toolResponse,
status: 'invoking'
},
onChunk
onChunk!
)
}

@ -553,7 +556,7 @@ export async function parseAndCallTools<R>(
status: 'done',
response: toolCallResponse
},
onChunk
onChunk!
)

for (const content of toolCallResponse.content) {
@ -563,10 +566,10 @@ export async function parseAndCallTools<R>(
}

if (images.length) {
onChunk({
onChunk?.({
type: ChunkType.IMAGE_CREATED
})
onChunk({
onChunk?.({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',

@ -101,7 +101,7 @@ export function isEmoji(str: string): boolean {
* @returns {string} The processed string
*/
export function removeSpecialCharactersForTopicName(str: string): string {
return str.replace(/[\r\n]+/g, ' ').trim()
return str.replace(/["'\r\n]+/g, ' ').trim()
}
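The widened character class now strips quotes as well as line breaks, so a model-suggested name like the one below comes out clean (illustrative input):

```typescript
removeSpecialCharactersForTopicName('"My\nTopic"')
// before this change: '"My Topic"' - now: 'My Topic'
```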

/**

@ -31,10 +31,12 @@ export function readableStreamAsyncIterable<T>(stream: any): AsyncIterableIterat
}
}

export function asyncGeneratorToReadableStream<T>(gen: AsyncGenerator<T>): ReadableStream<T> {
export function asyncGeneratorToReadableStream<T>(gen: AsyncIterable<T>): ReadableStream<T> {
const iterator = gen[Symbol.asyncIterator]()

return new ReadableStream<T>({
async pull(controller) {
const { value, done } = await gen.next()
const { value, done } = await iterator.next()
if (done) {
controller.close()
} else {
@ -43,3 +45,17 @@ export function asyncGeneratorToReadableStream<T>(gen: AsyncGenerator<T>): Reada
}
})
}

/**
* Convert a single data item into a readable stream
* @param data The single data item to turn into a stream
* @returns A ReadableStream containing the single item
*/
export function createSingleChunkReadableStream<T>(data: T): ReadableStream<T> {
return new ReadableStream<T>({
start(controller) {
controller.enqueue(data)
controller.close()
}
})
}
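Both helpers produce the same consumer-facing interface, which is what lets downstream code treat streaming and non-streaming responses uniformly. A usage sketch (hypothetical values):

```typescript
async function* numbers(): AsyncGenerator<number> {
  yield 1
  yield 2
}

// Any AsyncIterable (not just AsyncGenerator, after this change) becomes a pull-based stream:
const stream = asyncGeneratorToReadableStream(numbers())

// A single, already-materialized value gets the same ReadableStream interface:
const single = createSingleChunkReadableStream('hello')

const reader = single.getReader()
reader.read().then(({ value }) => console.log(value)) // 'hello'
```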

168
src/renderer/src/utils/tagExtraction.ts
Normal file
@ -0,0 +1,168 @@
import { getPotentialStartIndex } from './getPotentialIndex'

export interface TagConfig {
openingTag: string
closingTag: string
separator?: string
}

export interface TagExtractionState {
textBuffer: string
isInsideTag: boolean
isFirstTag: boolean
isFirstText: boolean
afterSwitch: boolean
accumulatedTagContent: string
hasTagContent: boolean
}

export interface TagExtractionResult {
content: string
isTagContent: boolean
complete: boolean
tagContentExtracted?: string
}

/**
* Generic tag-extraction processor
* Handles arbitrary tag pairs such as <think>...</think>, <tool_use>...</tool_use>, etc.
*/
export class TagExtractor {
private config: TagConfig
private state: TagExtractionState

constructor(config: TagConfig) {
this.config = config
this.state = {
textBuffer: '',
isInsideTag: false,
isFirstTag: true,
isFirstText: true,
afterSwitch: false,
accumulatedTagContent: '',
hasTagContent: false
}
}

/**
* Process a chunk of text and return the extraction results
*/
processText(newText: string): TagExtractionResult[] {
this.state.textBuffer += newText
const results: TagExtractionResult[] = []

// Tag-extraction loop
while (true) {
const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)

if (startIndex == null) {
const content = this.state.textBuffer
if (content.length > 0) {
results.push({
content: this.addPrefix(content),
isTagContent: this.state.isInsideTag,
complete: false
})

if (this.state.isInsideTag) {
this.state.accumulatedTagContent += this.addPrefix(content)
this.state.hasTagContent = true
}
}
this.state.textBuffer = ''
break
}

// Handle the content before the tag
const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
if (contentBeforeTag.length > 0) {
results.push({
content: this.addPrefix(contentBeforeTag),
isTagContent: this.state.isInsideTag,
complete: false
})

if (this.state.isInsideTag) {
this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
this.state.hasTagContent = true
}
}

const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length

if (foundFullMatch) {
// A complete tag was found
this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)

// If a tag was just closed, emit the accumulated tag content as a complete result
if (this.state.isInsideTag && this.state.hasTagContent) {
results.push({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: this.state.accumulatedTagContent
})
this.state.accumulatedTagContent = ''
this.state.hasTagContent = false
}

this.state.isInsideTag = !this.state.isInsideTag
this.state.afterSwitch = true

if (this.state.isInsideTag) {
this.state.isFirstTag = false
} else {
this.state.isFirstText = false
}
} else {
this.state.textBuffer = this.state.textBuffer.slice(startIndex)
break
}
}

return results
}

/**
* Finish processing and return any remaining tag content
*/
finalize(): TagExtractionResult | null {
if (this.state.hasTagContent && this.state.accumulatedTagContent) {
const result = {
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: this.state.accumulatedTagContent
}
this.state.accumulatedTagContent = ''
this.state.hasTagContent = false
return result
}
return null
}

private addPrefix(text: string): string {
const needsPrefix =
this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)

const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
this.state.afterSwitch = false
return prefix + text
}

/**
* Reset the state
*/
reset(): void {
this.state = {
textBuffer: '',
isInsideTag: false,
isFirstTag: true,
isFirstText: true,
afterSwitch: false,
accumulatedTagContent: '',
hasTagContent: false
}
}
}
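A usage sketch (illustrative chunk boundaries): the extractor tolerates tags that straddle incoming slices, which is exactly the streaming case it exists for.

```typescript
const extractor = new TagExtractor({ openingTag: '<think>', closingTag: '</think>' })

// Feed streamed text in arbitrary slices; the opening tag is split across chunks here:
for (const piece of ['Hello <th', 'ink>reasoning', '</think> world']) {
  for (const result of extractor.processText(piece)) {
    if (result.complete) {
      console.log('thinking:', result.tagContentExtracted) // 'reasoning'
    } else if (!result.isTagContent) {
      console.log('text:', result.content) // 'Hello ' then ' world'
    }
  }
}
```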
35
yarn.lock
@ -2898,7 +2898,7 @@ __metadata:
languageName: node
linkType: hard

"@libsql/client@npm:^0.14.0":
"@libsql/client@npm:0.14.0, @libsql/client@npm:^0.14.0":
version: 0.14.0
resolution: "@libsql/client@npm:0.14.0"
dependencies:
@ -2991,9 +2991,10 @@ __metadata:
languageName: node
linkType: hard

"@libsql/win32-x64-msvc@npm:0.4.7":
"@libsql/win32-x64-msvc@npm:0.4.7, @libsql/win32-x64-msvc@npm:^0.4.7":
version: 0.4.7
resolution: "@libsql/win32-x64-msvc@npm:0.4.7"
checksum: 10c0/2fcb8715b6f0571dec145eaaf3fd53c7c5aa5bf408fe1be9d84b10adc8a909bb6ee60b45e0d7052b0c1722c30ac212356a3f1adcdf7f57d5a59b48f36ca5bdf5
conditions: os=win32 & cpu=x64
languageName: node
linkType: hard
@ -5593,6 +5594,8 @@ __metadata:
"@kangfenmao/keyv-storage": "npm:^0.1.0"
"@langchain/community": "npm:^0.3.36"
"@langchain/ollama": "npm:^0.2.1"
"@libsql/client": "npm:0.14.0"
"@libsql/win32-x64-msvc": "npm:^0.4.7"
"@modelcontextprotocol/sdk": "npm:^1.11.4"
"@mozilla/readability": "npm:^0.6.0"
"@notionhq/client": "npm:^2.2.15"
@ -5656,14 +5659,14 @@ __metadata:
eslint-plugin-unused-imports: "npm:^4.1.4"
fast-diff: "npm:^1.3.0"
fast-xml-parser: "npm:^5.2.0"
framer-motion: "npm:^12.17.0"
framer-motion: "npm:^12.17.3"
franc-min: "npm:^6.2.0"
fs-extra: "npm:^11.2.0"
html-to-image: "npm:^1.11.13"
husky: "npm:^9.1.7"
i18next: "npm:^23.11.5"
jest-styled-components: "npm:^7.2.0"
jsdom: "npm:^26.0.0"
jsdom: "npm:26.1.0"
lint-staged: "npm:^15.5.0"
lodash: "npm:^4.17.21"
lru-cache: "npm:^11.1.0"
@ -5711,7 +5714,7 @@ __metadata:
tar: "npm:^7.4.3"
tiny-pinyin: "npm:^1.3.2"
tokenx: "npm:^0.4.1"
turndown: "npm:^7.2.0"
turndown: "npm:7.2.0"
typescript: "npm:^5.6.2"
uuid: "npm:^10.0.0"
vite: "npm:6.2.6"
@ -9864,11 +9867,11 @@ __metadata:
languageName: node
linkType: hard

"framer-motion@npm:^12.17.0":
version: 12.17.0
resolution: "framer-motion@npm:12.17.0"
"framer-motion@npm:^12.17.3":
version: 12.17.3
resolution: "framer-motion@npm:12.17.3"
dependencies:
motion-dom: "npm:^12.17.0"
motion-dom: "npm:^12.17.3"
motion-utils: "npm:^12.12.1"
tslib: "npm:^2.4.0"
peerDependencies:
@ -9882,7 +9885,7 @@ __metadata:
optional: true
react-dom:
optional: true
checksum: 10c0/3262ab125650d71cd13eb9f4838da70550ea383d68a2fbd2664b05bac88b7420fe7db25911fbd30cbc237327d98a4567df34e675c8261dde559a9375e580103c
checksum: 10c0/2d8ae235f5b61005d47a7f004f7c04d7484686c07023c06ed546789fdaab5c2e24caac5f23a967263b6a51cbd72dcf41c84ad8b0671472c12f5373055cd6eb46
languageName: node
linkType: hard

@ -11398,7 +11401,7 @@ __metadata:
languageName: node
linkType: hard

"jsdom@npm:^26.0.0":
"jsdom@npm:26.1.0":
version: 26.1.0
resolution: "jsdom@npm:26.1.0"
dependencies:
@ -13575,12 +13578,12 @@ __metadata:
languageName: node
linkType: hard

"motion-dom@npm:^12.17.0":
version: 12.17.0
resolution: "motion-dom@npm:12.17.0"
"motion-dom@npm:^12.17.3":
version: 12.17.3
resolution: "motion-dom@npm:12.17.3"
dependencies:
motion-utils: "npm:^12.12.1"
checksum: 10c0/1ec428e113f334193dcd52293c94bca21fcca97f3825521d1dafe41f6b999e8dda5013b48de2c09e2f32204f80d1d7281079ba3a142c71b8d6923a0ddb056513
checksum: 10c0/6892f070e07fdd4f6d97c347e479a1f706a6ad678f86818ce36d35a89dc79a0cc45804bd5758b95612893f48c8b6353f0cee2a25340b2cde789b8ad323aa592e
languageName: node
linkType: hard

@ -17723,7 +17726,7 @@ __metadata:
languageName: node
linkType: hard

"turndown@npm:^7.2.0":
"turndown@npm:7.2.0":
version: 7.2.0
resolution: "turndown@npm:7.2.0"
dependencies:
