diff --git a/.prettierignore b/.prettierignore index 5f6cea6dad..e6e3d34935 100644 --- a/.prettierignore +++ b/.prettierignore @@ -7,4 +7,5 @@ tsconfig.*.json CHANGELOG*.md agents.json src/renderer/src/integration/nutstore/sso/lib +AGENT.md src/main/integration/cherryin/index.js diff --git a/electron-builder.yml b/electron-builder.yml index 6683c22ff2..dbafecb548 100644 --- a/electron-builder.yml +++ b/electron-builder.yml @@ -121,24 +121,12 @@ afterSign: scripts/notarize.js artifactBuildCompleted: scripts/artifact-build-completed.js releaseInfo: releaseNotes: | - ✨ 重要更新: - - 新增笔记模块,支持富文本编辑和管理 - - 内置 GLM-4.5-Flash 免费模型(由智谱开放平台提供) - - 内置 Qwen3-8B 免费模型(由硅基流动提供) - - 新增 Nano Banana(Gemini 2.5 Flash Image)模型支持 - - 新增系统 OCR 功能 (macOS & Windows) - - 新增图片 OCR 识别和翻译功能 - - 模型切换支持通过标签筛选 - - 翻译功能增强:历史搜索和收藏 - 🔧 性能优化: - - 优化历史页面搜索性能 - - 优化拖拽列表组件交互 - - 升级 Electron 到 37.4.0 + - 优化AI服务连接方式,提升响应速度和稳定性 + - 改进模型列表获取功能,减少不必要的网络请求 + - 增强各AI服务商的兼容性和连接可靠性 - 🐛 修复问题: - - 修复知识库加密 PDF 文档处理 - - 修复导航栏在左侧时笔记侧边栏按钮缺失 - - 修复多个模型兼容性问题 - - 修复 MCP 相关问题 - - 其他稳定性改进 + 🐛 问题修复: + - 修复部分AI服务商连接失败的问题 + - 修复模型配置加载时的潜在错误 + - 提升应用整体稳定性和容错能力 diff --git a/electron.vite.config.ts b/electron.vite.config.ts index 7cf20902b9..dff0a94a37 100644 --- a/electron.vite.config.ts +++ b/electron.vite.config.ts @@ -84,6 +84,9 @@ export default defineConfig({ '@logger': resolve('src/renderer/src/services/LoggerService'), '@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'), '@mcp-trace/trace-web': resolve('packages/mcp-trace/trace-web'), + '@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'), + '@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'), + '@cherrystudio/ai-core': resolve('packages/aiCore/src'), '@cherrystudio/extension-table-plus': resolve('packages/extension-table-plus/src') } }, diff --git a/package.json b/package.json index a6380c1ed3..c84bc483d4 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "CherryStudio", 
- "version": "1.5.9", + "version": "1.6.0-beta.6", "private": true, "description": "A powerful AI assistant for producer.", "main": "./out/main/index.js", @@ -88,12 +88,16 @@ "@agentic/exa": "^7.3.3", "@agentic/searxng": "^7.3.3", "@agentic/tavily": "^7.3.3", + "@ai-sdk/amazon-bedrock": "^3.0.0", + "@ai-sdk/google-vertex": "^3.0.0", + "@ai-sdk/mistral": "^2.0.0", "@ant-design/v5-patch-for-react-19": "^1.0.3", "@anthropic-ai/sdk": "^0.41.0", "@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch", "@aws-sdk/client-bedrock": "^3.840.0", "@aws-sdk/client-bedrock-runtime": "^3.840.0", "@aws-sdk/client-s3": "^3.840.0", + "@cherrystudio/ai-core": "workspace:*", "@cherrystudio/embedjs": "^0.1.31", "@cherrystudio/embedjs-libsql": "^0.1.31", "@cherrystudio/embedjs-loader-csv": "^0.1.31", @@ -129,6 +133,7 @@ "@modelcontextprotocol/sdk": "^1.17.0", "@mozilla/readability": "^0.6.0", "@notionhq/client": "^2.2.15", + "@openrouter/ai-sdk-provider": "^1.1.2", "@opentelemetry/api": "^1.9.0", "@opentelemetry/core": "2.0.0", "@opentelemetry/exporter-trace-otlp-http": "^0.200.0", @@ -138,7 +143,7 @@ "@playwright/test": "^1.52.0", "@reduxjs/toolkit": "^2.2.5", "@shikijs/markdown-it": "^3.12.0", - "@swc/plugin-styled-components": "^7.1.5", + "@swc/plugin-styled-components": "^8.0.4", "@tanstack/react-query": "^5.85.5", "@tanstack/react-virtual": "^3.13.12", "@testing-library/dom": "^10.4.0", @@ -189,6 +194,7 @@ "@viz-js/lang-dot": "^1.0.5", "@viz-js/viz": "^3.14.0", "@xyflow/react": "^12.4.4", + "ai": "^5.0.29", "antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch", "archiver": "^7.0.1", "async-mutex": "^0.5.0", @@ -337,7 +343,7 @@ "prettier --write", "eslint --fix" ], - "*.{json,md,yml,yaml,css,scss,html}": [ + "*.{json,yml,yaml,css,scss,html}": [ "prettier --write" ] } diff --git a/packages/aiCore/AI_SDK_ARCHITECTURE.md b/packages/aiCore/AI_SDK_ARCHITECTURE.md new file 
mode 100644 index 0000000000..67af20c0a6 --- /dev/null +++ b/packages/aiCore/AI_SDK_ARCHITECTURE.md @@ -0,0 +1,514 @@ +# AI Core 基于 Vercel AI SDK 的技术架构 + +## 1. 架构设计理念 + +### 1.1 设计目标 + +- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离 +- **统一接口**:使用 Vercel AI SDK 统一不同 AI Provider 的接口差异 +- **动态导入**:通过动态导入实现按需加载,减少打包体积 +- **最小包装**:直接使用 AI SDK 的类型和接口,避免重复定义 +- **插件系统**:基于钩子的通用插件架构,支持请求全生命周期扩展 +- **类型安全**:利用 TypeScript 和 AI SDK 的类型系统确保类型安全 +- **轻量级**:专注核心功能,保持包的轻量和高效 +- **包级独立**:作为独立包管理,便于复用和维护 +- **Agent就绪**:为将来集成 OpenAI Agents SDK 预留扩展空间 + +### 1.2 核心优势 + +- **标准化**:AI SDK 提供统一的模型接口,减少适配工作 +- **简化设计**:函数式API,避免过度抽象 +- **更好的开发体验**:完整的 TypeScript 支持和丰富的生态系统 +- **性能优化**:AI SDK 内置优化和最佳实践 +- **模块化设计**:独立包结构,支持跨项目复用 +- **可扩展插件**:通用的流转换和参数处理插件系统 +- **面向未来**:为 OpenAI Agents SDK 集成做好准备 + +## 2. 整体架构图 + +```mermaid +graph TD + subgraph "用户应用 (如 Cherry Studio)" + UI["用户界面"] + Components["应用组件"] + end + + subgraph "packages/aiCore (AI Core 包)" + subgraph "Runtime Layer (运行时层)" + RuntimeExecutor["RuntimeExecutor (运行时执行器)"] + PluginEngine["PluginEngine (插件引擎)"] + RuntimeAPI["Runtime API (便捷函数)"] + end + + subgraph "Models Layer (模型层)" + ModelFactory["createModel() (模型工厂)"] + ProviderCreator["ProviderCreator (提供商创建器)"] + end + + subgraph "Core Systems (核心系统)" + subgraph "Plugins (插件)" + PluginManager["PluginManager (插件管理)"] + BuiltInPlugins["Built-in Plugins (内置插件)"] + StreamTransforms["Stream Transforms (流转换)"] + end + + subgraph "Middleware (中间件)" + MiddlewareWrapper["wrapModelWithMiddlewares() (中间件包装)"] + end + + subgraph "Providers (提供商)" + Registry["Provider Registry (注册表)"] + Factory["Provider Factory (工厂)"] + end + end + end + + subgraph "Vercel AI SDK" + AICore["ai (核心库)"] + OpenAI["@ai-sdk/openai"] + Anthropic["@ai-sdk/anthropic"] + Google["@ai-sdk/google"] + XAI["@ai-sdk/xai"] + Others["其他 19+ Providers"] + end + + subgraph "Future: OpenAI Agents SDK" + AgentSDK["@openai/agents (未来集成)"] + AgentExtensions["Agent Extensions (预留)"] + end + + UI --> RuntimeAPI + Components 
--> RuntimeExecutor + RuntimeAPI --> RuntimeExecutor + RuntimeExecutor --> PluginEngine + RuntimeExecutor --> ModelFactory + PluginEngine --> PluginManager + ModelFactory --> ProviderCreator + ModelFactory --> MiddlewareWrapper + ProviderCreator --> Registry + Registry --> Factory + Factory --> OpenAI + Factory --> Anthropic + Factory --> Google + Factory --> XAI + Factory --> Others + + RuntimeExecutor --> AICore + AICore --> streamText + AICore --> generateText + AICore --> streamObject + AICore --> generateObject + + PluginManager --> StreamTransforms + PluginManager --> BuiltInPlugins + + %% 未来集成路径 + RuntimeExecutor -.-> AgentSDK + AgentSDK -.-> AgentExtensions +``` + +## 3. 包结构设计 + +### 3.1 新架构文件结构 + +``` +packages/aiCore/ +├── src/ +│ ├── core/ # 核心层 - 内部实现 +│ │ ├── models/ # 模型层 - 模型创建和配置 +│ │ │ ├── factory.ts # 模型工厂函数 ✅ +│ │ │ ├── ModelCreator.ts # 模型创建器 ✅ +│ │ │ ├── ConfigManager.ts # 配置管理器 ✅ +│ │ │ ├── types.ts # 模型类型定义 ✅ +│ │ │ └── index.ts # 模型层导出 ✅ +│ │ ├── runtime/ # 运行时层 - 执行和用户API +│ │ │ ├── executor.ts # 运行时执行器 ✅ +│ │ │ ├── pluginEngine.ts # 插件引擎 ✅ +│ │ │ ├── types.ts # 运行时类型定义 ✅ +│ │ │ └── index.ts # 运行时导出 ✅ +│ │ ├── middleware/ # 中间件系统 +│ │ │ ├── wrapper.ts # 模型包装器 ✅ +│ │ │ ├── manager.ts # 中间件管理器 ✅ +│ │ │ ├── types.ts # 中间件类型 ✅ +│ │ │ └── index.ts # 中间件导出 ✅ +│ │ ├── plugins/ # 插件系统 +│ │ │ ├── types.ts # 插件类型定义 ✅ +│ │ │ ├── manager.ts # 插件管理器 ✅ +│ │ │ ├── built-in/ # 内置插件 ✅ +│ │ │ │ ├── logging.ts # 日志插件 ✅ +│ │ │ │ ├── webSearchPlugin/ # 网络搜索插件 ✅ +│ │ │ │ ├── toolUsePlugin/ # 工具使用插件 ✅ +│ │ │ │ └── index.ts # 内置插件导出 ✅ +│ │ │ ├── README.md # 插件文档 ✅ +│ │ │ └── index.ts # 插件导出 ✅ +│ │ ├── providers/ # 提供商管理 +│ │ │ ├── registry.ts # 提供商注册表 ✅ +│ │ │ ├── factory.ts # 提供商工厂 ✅ +│ │ │ ├── creator.ts # 提供商创建器 ✅ +│ │ │ ├── types.ts # 提供商类型 ✅ +│ │ │ ├── utils.ts # 工具函数 ✅ +│ │ │ └── index.ts # 提供商导出 ✅ +│ │ ├── options/ # 配置选项 +│ │ │ ├── factory.ts # 选项工厂 ✅ +│ │ │ ├── types.ts # 选项类型 ✅ +│ │ │ ├── xai.ts # xAI 选项 ✅ +│ │ │ ├── openrouter.ts # OpenRouter 选项 ✅ +│ │ 
│ ├── examples.ts # 示例配置 ✅ +│ │ │ └── index.ts # 选项导出 ✅ +│ │ └── index.ts # 核心层导出 ✅ +│ ├── types.ts # 全局类型定义 ✅ +│ └── index.ts # 包主入口文件 ✅ +├── package.json # 包配置文件 ✅ +├── tsconfig.json # TypeScript 配置 ✅ +├── README.md # 包说明文档 ✅ +└── AI_SDK_ARCHITECTURE.md # 本文档 ✅ +``` + +## 4. 架构分层详解 + +### 4.1 Models Layer (模型层) + +**职责**:统一的模型创建和配置管理 + +**核心文件**: + +- `factory.ts`: 模型工厂函数 (`createModel`, `createModels`) +- `ProviderCreator.ts`: 底层提供商创建和模型实例化 +- `types.ts`: 模型配置类型定义 + +**设计特点**: + +- 函数式设计,避免不必要的类抽象 +- 统一的模型配置接口 +- 自动处理中间件应用 +- 支持批量模型创建 + +**核心API**: + +```typescript +// 模型配置接口 +export interface ModelConfig { + providerId: ProviderId + modelId: string + options: ProviderSettingsMap[ProviderId] + middlewares?: LanguageModelV1Middleware[] +} + +// 核心模型创建函数 +export async function createModel(config: ModelConfig): Promise +export async function createModels(configs: ModelConfig[]): Promise +``` + +### 4.2 Runtime Layer (运行时层) + +**职责**:运行时执行器和用户面向的API接口 + +**核心组件**: + +- `executor.ts`: 运行时执行器类 +- `plugin-engine.ts`: 插件引擎(原PluginEnabledAiClient) +- `index.ts`: 便捷函数和工厂方法 + +**设计特点**: + +- 提供三种使用方式:类实例、静态工厂、函数式调用 +- 自动集成模型创建和插件处理 +- 完整的类型安全支持 +- 为 OpenAI Agents SDK 预留扩展接口 + +**核心API**: + +```typescript +// 运行时执行器 +export class RuntimeExecutor { + static create( + providerId: T, + options: ProviderSettingsMap[T], + plugins?: AiPlugin[] + ): RuntimeExecutor + + async streamText(modelId: string, params: StreamTextParams): Promise + async generateText(modelId: string, params: GenerateTextParams): Promise + async streamObject(modelId: string, params: StreamObjectParams): Promise + async generateObject(modelId: string, params: GenerateObjectParams): Promise +} + +// 便捷函数式API +export async function streamText( + providerId: T, + options: ProviderSettingsMap[T], + modelId: string, + params: StreamTextParams, + plugins?: AiPlugin[] +): Promise +``` + +### 4.3 Plugin System (插件系统) + +**职责**:可扩展的插件架构 + +**核心组件**: + +- `PluginManager`: 插件生命周期管理 +- `built-in/`: 内置插件集合 +- 流转换收集和应用 + 
+**设计特点**: + +- 借鉴 Rollup 的钩子分类设计 +- 支持流转换 (`experimental_transform`) +- 内置常用插件(日志、计数等) +- 完整的生命周期钩子 + +**插件接口**: + +```typescript +export interface AiPlugin { + name: string + enforce?: 'pre' | 'post' + + // 【First】首个钩子 - 只执行第一个返回值的插件 + resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise + loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise + + // 【Sequential】串行钩子 - 链式执行,支持数据转换 + transformParams?: (params: any, context: AiRequestContext) => any | Promise + transformResult?: (result: any, context: AiRequestContext) => any | Promise + + // 【Parallel】并行钩子 - 不依赖顺序,用于副作用 + onRequestStart?: (context: AiRequestContext) => void | Promise + onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise + onError?: (error: Error, context: AiRequestContext) => void | Promise + + // 【Stream】流处理 + transformStream?: () => TransformStream +} +``` + +### 4.4 Middleware System (中间件系统) + +**职责**:AI SDK原生中间件支持 + +**核心组件**: + +- `ModelWrapper.ts`: 模型包装函数 + +**设计哲学**: + +- 直接使用AI SDK的 `wrapLanguageModel` +- 与插件系统分离,职责明确 +- 函数式设计,简化使用 + +```typescript +export function wrapModelWithMiddlewares(model: LanguageModel, middlewares: LanguageModelV1Middleware[]): LanguageModel +``` + +### 4.5 Provider System (提供商系统) + +**职责**:AI Provider注册表和动态导入 + +**核心组件**: + +- `registry.ts`: 19+ Provider配置和类型 +- `factory.ts`: Provider配置工厂 + +**支持的Providers**: + +- OpenAI, Anthropic, Google, XAI +- Azure OpenAI, Amazon Bedrock, Google Vertex +- Groq, Together.ai, Fireworks, DeepSeek +- 等19+ AI SDK官方支持的providers + +## 5. 使用方式 + +### 5.1 函数式调用 (推荐 - 简单场景) + +```typescript +import { streamText, generateText } from '@cherrystudio/ai-core/runtime' + +// 直接函数调用 +const stream = await streamText( + 'anthropic', + { apiKey: 'your-api-key' }, + 'claude-3', + { messages: [{ role: 'user', content: 'Hello!' 
}] }, + [loggingPlugin] +) +``` + +### 5.2 执行器实例 (推荐 - 复杂场景) + +```typescript +import { createExecutor } from '@cherrystudio/ai-core/runtime' + +// 创建可复用的执行器 +const executor = createExecutor('openai', { apiKey: 'your-api-key' }, [plugin1, plugin2]) + +// 多次使用 +const stream = await executor.streamText('gpt-4', { + messages: [{ role: 'user', content: 'Hello!' }] +}) + +const result = await executor.generateText('gpt-4', { + messages: [{ role: 'user', content: 'How are you?' }] +}) +``` + +### 5.3 静态工厂方法 + +```typescript +import { RuntimeExecutor } from '@cherrystudio/ai-core/runtime' + +// 静态创建 +const executor = RuntimeExecutor.create('anthropic', { apiKey: 'your-api-key' }) +await executor.streamText('claude-3', { messages: [...] }) +``` + +### 5.4 直接模型创建 (高级用法) + +```typescript +import { createModel } from '@cherrystudio/ai-core/models' +import { streamText } from 'ai' + +// 直接创建模型使用 +const model = await createModel({ + providerId: 'openai', + modelId: 'gpt-4', + options: { apiKey: 'your-api-key' }, + middlewares: [middleware1, middleware2] +}) + +// 直接使用 AI SDK +const result = await streamText({ model, messages: [...] }) +``` + +## 6. 为 OpenAI Agents SDK 预留的设计 + +### 6.1 架构兼容性 + +当前架构完全兼容 OpenAI Agents SDK 的集成需求: + +```typescript +// 当前的模型创建 +const model = await createModel({ + providerId: 'anthropic', + modelId: 'claude-3', + options: { apiKey: 'xxx' } +}) + +// 将来可以直接用于 OpenAI Agents SDK +import { Agent, run } from '@openai/agents' + +const agent = new Agent({ + model, // ✅ 直接兼容 LanguageModel 接口 + name: 'Assistant', + instructions: '...', + tools: [tool1, tool2] +}) + +const result = await run(agent, 'user input') +``` + +### 6.2 预留的扩展点 + +1. **runtime/agents/** 目录预留 +2. **AgentExecutor** 类预留 +3. **Agent工具转换插件** 预留 +4. **多Agent编排** 预留 + +### 6.3 未来架构扩展 + +``` +packages/aiCore/src/core/ +├── runtime/ +│ ├── agents/ # 🚀 未来添加 +│ │ ├── AgentExecutor.ts +│ │ ├── WorkflowManager.ts +│ │ └── ConversationManager.ts +│ ├── executor.ts +│ └── index.ts +``` + +## 7. 
架构优势 + +### 7.1 简化设计 + +- **移除过度抽象**:删除了orchestration层和creation层的复杂包装 +- **函数式优先**:models层使用函数而非类 +- **直接明了**:runtime层直接提供用户API + +### 7.2 职责清晰 + +- **Models**: 专注模型创建和配置 +- **Runtime**: 专注执行和用户API +- **Plugins**: 专注扩展功能 +- **Providers**: 专注AI Provider管理 + +### 7.3 类型安全 + +- 完整的 TypeScript 支持 +- AI SDK 类型的直接复用 +- 避免类型重复定义 + +### 7.4 灵活使用 + +- 三种使用模式满足不同需求 +- 从简单函数调用到复杂执行器 +- 支持直接AI SDK使用 + +### 7.5 面向未来 + +- 为 OpenAI Agents SDK 集成做好准备 +- 清晰的扩展点和架构边界 +- 模块化设计便于功能添加 + +## 8. 技术决策记录 + +### 8.1 为什么选择简化的两层架构? + +- **职责分离**:models专注创建,runtime专注执行 +- **模块化**:每层都有清晰的边界和职责 +- **扩展性**:为Agent功能预留了清晰的扩展空间 + +### 8.2 为什么选择函数式设计? + +- **简洁性**:避免不必要的类设计 +- **性能**:减少对象创建开销 +- **易用性**:函数调用更直观 + +### 8.3 为什么分离插件和中间件? + +- **职责明确**: 插件处理应用特定需求 +- **原生支持**: 中间件使用AI SDK原生功能 +- **灵活性**: 两套系统可以独立演进 + +## 9. 总结 + +AI Core架构实现了: + +### 9.1 核心特点 + +- ✅ **简化架构**: 2层核心架构,职责清晰 +- ✅ **函数式设计**: models层完全函数化 +- ✅ **类型安全**: 统一的类型定义和AI SDK类型复用 +- ✅ **插件扩展**: 强大的插件系统 +- ✅ **多种使用方式**: 满足不同复杂度需求 +- ✅ **Agent就绪**: 为OpenAI Agents SDK集成做好准备 + +### 9.2 核心价值 + +- **统一接口**: 一套API支持19+ AI providers +- **灵活使用**: 函数式、实例式、静态工厂式 +- **强类型**: 完整的TypeScript支持 +- **可扩展**: 插件和中间件双重扩展能力 +- **高性能**: 最小化包装,直接使用AI SDK +- **面向未来**: Agent SDK集成架构就绪 + +### 9.3 未来发展 + +这个架构提供了: + +- **优秀的开发体验**: 简洁的API和清晰的使用模式 +- **强大的扩展能力**: 为Agent功能预留了完整的架构空间 +- **良好的维护性**: 职责分离明确,代码易于维护 +- **广泛的适用性**: 既适合简单调用也适合复杂应用 diff --git a/packages/aiCore/README.md b/packages/aiCore/README.md new file mode 100644 index 0000000000..4ca5ea6640 --- /dev/null +++ b/packages/aiCore/README.md @@ -0,0 +1,433 @@ +# @cherrystudio/ai-core + +Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口包,为 AI 应用提供强大的抽象层和插件化架构。 + +## ✨ 核心亮点 + +### 🏗️ 优雅的架构设计 + +- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离 +- **函数式优先**:避免过度抽象,提供简洁直观的 API +- **类型安全**:完整的 TypeScript 支持,直接复用 AI SDK 类型系统 +- **最小包装**:直接使用 AI SDK 的接口,避免重复定义和性能损耗 + +### 🔌 强大的插件系统 + +- **生命周期钩子**:支持请求全生命周期的扩展点 +- **流转换支持**:基于 AI SDK 的 `experimental_transform` 实现流处理 +- 
**插件分类**:First、Sequential、Parallel 三种钩子类型,满足不同场景 +- **内置插件**:webSearch、logging、toolUse 等开箱即用的功能 + +### 🌐 统一多 Provider 接口 + +- **扩展注册**:支持自定义 Provider 注册,无限扩展能力 +- **配置统一**:统一的配置接口,简化多 Provider 管理 + +### 🚀 多种使用方式 + +- **函数式调用**:适合简单场景的直接函数调用 +- **执行器实例**:适合复杂场景的可复用执行器 +- **静态工厂**:便捷的静态创建方法 +- **原生兼容**:完全兼容 AI SDK 原生 Provider Registry + +### 🔮 面向未来 + +- **Agent 就绪**:为 OpenAI Agents SDK 集成预留架构空间 +- **模块化设计**:独立包结构,支持跨项目复用 +- **渐进式迁移**:可以逐步从现有 AI SDK 代码迁移 + +## 特性 + +- 🚀 统一的 AI Provider 接口 +- 🔄 动态导入支持 +- 🛠️ TypeScript 支持 +- 📦 强大的插件系统 +- 🌍 内置webSearch(Openai,Google,Anthropic,xAI) +- 🎯 多种使用模式(函数式/实例式/静态工厂) +- 🔌 可扩展的 Provider 注册系统 +- 🧩 完整的中间件支持 +- 📊 插件统计和调试功能 + +## 支持的 Providers + +基于 [AI SDK 官方支持的 providers](https://ai-sdk.dev/providers/ai-sdk-providers): + +**核心 Providers(内置支持):** + +- OpenAI +- Anthropic +- Google Generative AI +- OpenAI-Compatible +- xAI (Grok) +- Azure OpenAI +- DeepSeek + +**扩展 Providers(通过注册API支持):** + +- Google Vertex AI +- ... +- 自定义 Provider + +## 安装 + +```bash +npm install @cherrystudio/ai-core ai +``` + +### React Native + +如果你在 React Native 项目中使用此包,需要在 `metro.config.js` 中添加以下配置: + +```javascript +// metro.config.js +const { getDefaultConfig } = require('expo/metro-config') + +const config = getDefaultConfig(__dirname) + +// 添加对 @cherrystudio/ai-core 的支持 +config.resolver.resolverMainFields = ['react-native', 'browser', 'main'] +config.resolver.platforms = ['ios', 'android', 'native', 'web'] + +module.exports = config +``` + +还需要安装你要使用的 AI SDK provider: + +```bash +npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google +``` + +## 使用示例 + +### 基础用法 + +```typescript +import { AiCore } from '@cherrystudio/ai-core' + +// 创建 OpenAI executor +const executor = AiCore.create('openai', { + apiKey: 'your-api-key' +}) + +// 流式生成 +const result = await executor.streamText('gpt-4', { + messages: [{ role: 'user', content: 'Hello!' }] +}) + +// 非流式生成 +const response = await executor.generateText('gpt-4', { + messages: [{ role: 'user', content: 'Hello!' 
}] +}) +``` + +### 便捷函数 + +```typescript +import { createOpenAIExecutor } from '@cherrystudio/ai-core' + +// 快速创建 OpenAI executor +const executor = createOpenAIExecutor({ + apiKey: 'your-api-key' +}) + +// 使用 executor +const result = await executor.streamText('gpt-4', { + messages: [{ role: 'user', content: 'Hello!' }] +}) +``` + +### 多 Provider 支持 + +```typescript +import { AiCore } from '@cherrystudio/ai-core' + +// 支持多种 AI providers +const openaiExecutor = AiCore.create('openai', { apiKey: 'openai-key' }) +const anthropicExecutor = AiCore.create('anthropic', { apiKey: 'anthropic-key' }) +const googleExecutor = AiCore.create('google', { apiKey: 'google-key' }) +const xaiExecutor = AiCore.create('xai', { apiKey: 'xai-key' }) +``` + +### 扩展 Provider 注册 + +对于非内置的 providers,可以通过注册 API 扩展支持: + +```typescript +import { registerProvider, AiCore } from '@cherrystudio/ai-core' + +// 方式一:导入并注册第三方 provider +import { createGroq } from '@ai-sdk/groq' + +registerProvider({ + id: 'groq', + name: 'Groq', + creator: createGroq, + supportsImageGeneration: false +}) + +// 现在可以使用 Groq +const groqExecutor = AiCore.create('groq', { apiKey: 'groq-key' }) + +// 方式二:动态导入方式注册 +registerProvider({ + id: 'mistral', + name: 'Mistral AI', + import: () => import('@ai-sdk/mistral'), + creatorFunctionName: 'createMistral' +}) + +const mistralExecutor = AiCore.create('mistral', { apiKey: 'mistral-key' }) +``` + +## 🔌 插件系统 + +AI Core 提供了强大的插件系统,支持请求全生命周期的扩展。 + +### 内置插件 + +#### webSearchPlugin - 网络搜索插件 + +为不同 AI Provider 提供统一的网络搜索能力: + +```typescript +import { webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins' + +const executor = AiCore.create('openai', { apiKey: 'your-key' }, [ + webSearchPlugin({ + openai: { + /* OpenAI 搜索配置 */ + }, + anthropic: { maxUses: 5 }, + google: { + /* Google 搜索配置 */ + }, + xai: { + mode: 'on', + returnCitations: true, + maxSearchResults: 5, + sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }] + } + }) +]) +``` + +#### loggingPlugin - 日志插件 + 
+提供详细的请求日志记录: + +```typescript +import { createLoggingPlugin } from '@cherrystudio/ai-core/built-in/plugins' + +const executor = AiCore.create('openai', { apiKey: 'your-key' }, [ + createLoggingPlugin({ + logLevel: 'info', + includeParams: true, + includeResult: false + }) +]) +``` + +#### promptToolUsePlugin - 提示工具使用插件 + +为不支持原生 Function Call 的模型提供 prompt 方式的工具调用: + +```typescript +import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/built-in/plugins' + +// 对于不支持 function call 的模型 +const executor = AiCore.create( + 'providerId', + { + apiKey: 'your-key', + baseURL: 'https://your-model-endpoint' + }, + [ + createPromptToolUsePlugin({ + enabled: true, + // 可选:自定义系统提示符构建 + buildSystemPrompt: (userPrompt, tools) => { + return `${userPrompt}\n\nAvailable tools: ${Object.keys(tools).join(', ')}` + } + }) + ] +) +``` + +### 自定义插件 + +创建自定义插件非常简单: + +```typescript +import { definePlugin } from '@cherrystudio/ai-core' + +const customPlugin = definePlugin({ + name: 'custom-plugin', + enforce: 'pre', // 'pre' | 'post' | undefined + + // 在请求开始时记录日志 + onRequestStart: async (context) => { + console.log(`Starting request for model: ${context.modelId}`) + }, + + // 转换请求参数 + transformParams: async (params, context) => { + // 添加自定义系统消息 + if (params.messages) { + params.messages.unshift({ + role: 'system', + content: 'You are a helpful assistant.' 
+ }) + } + return params + }, + + // 处理响应结果 + transformResult: async (result, context) => { + // 添加元数据 + if (result.text) { + result.metadata = { + processedAt: new Date().toISOString(), + modelId: context.modelId + } + } + return result + } +}) + +// 使用自定义插件 +const executor = AiCore.create('openai', { apiKey: 'your-key' }, [customPlugin]) +``` + +### 使用 AI SDK 原生 Provider 注册表 + +> https://ai-sdk.dev/docs/reference/ai-sdk-core/provider-registry + +除了使用内建的 provider 管理,你还可以使用 AI SDK 原生的 `createProviderRegistry` 来构建自己的 provider 注册表。 + +#### 基本用法示例 + +```typescript +import { createClient } from '@cherrystudio/ai-core' +import { createProviderRegistry } from 'ai' +import { createOpenAI } from '@ai-sdk/openai' +import { anthropic } from '@ai-sdk/anthropic' + +// 1. 创建 AI SDK 原生注册表 +export const registry = createProviderRegistry({ + // register provider with prefix and default setup: + anthropic, + + // register provider with prefix and custom setup: + openai: createOpenAI({ + apiKey: process.env.OPENAI_API_KEY + }) +}) + +// 2. 创建client,'openai'可以传空或者传providerId(内建的provider) +const client = createClient('openai', { + apiKey: process.env.OPENAI_API_KEY +}) + +// 3. 方式1:使用内建逻辑(传统方式) +const result1 = await client.streamText('gpt-4', { + messages: [{ role: 'user', content: 'Hello with built-in logic!' }] +}) + +// 4. 方式2:使用自定义注册表(灵活方式) +const result2 = await client.streamText({ + model: registry.languageModel('openai:gpt-4'), + messages: [{ role: 'user', content: 'Hello with custom registry!' }] +}) + +// 5. 
支持的重载方法 +await client.generateObject({ + model: registry.languageModel('openai:gpt-4'), + schema: z.object({ name: z.string() }), + messages: [{ role: 'user', content: 'Generate a user' }] +}) + +await client.streamObject({ + model: registry.languageModel('anthropic:claude-3-opus-20240229'), + schema: z.object({ items: z.array(z.string()) }), + messages: [{ role: 'user', content: 'Generate a list' }] +}) +``` + +#### 与插件系统配合使用 + +更强大的是,你还可以将自定义注册表与 Cherry Studio 的插件系统结合使用: + +```typescript +import { PluginEnabledAiClient } from '@cherrystudio/ai-core' +import { createProviderRegistry } from 'ai' +import { createOpenAI } from '@ai-sdk/openai' +import { createAnthropic } from '@ai-sdk/anthropic' + +// 1. 创建带插件的客户端 +const client = PluginEnabledAiClient.create( + 'openai', + { + apiKey: process.env.OPENAI_API_KEY + }, + [LoggingPlugin, RetryPlugin] +) + +// 2. 创建自定义注册表 +const registry = createProviderRegistry({ + openai: createOpenAI({ apiKey: process.env.OPENAI_API_KEY }), + anthropic: createAnthropic({ apiKey: process.env.ANTHROPIC_API_KEY }) +}) + +// 3. 方式1:使用内建逻辑 + 完整插件系统 +await client.streamText('gpt-4', { + messages: [{ role: 'user', content: 'Hello with plugins!' }] +}) + +// 4. 方式2:使用自定义注册表 + 有限插件支持 +await client.streamText({ + model: registry.languageModel('anthropic:claude-3-opus-20240229'), + messages: [{ role: 'user', content: 'Hello from Claude!' }] +}) + +// 5. 
支持的方法 +await client.generateObject({ + model: registry.languageModel('openai:gpt-4'), + schema: z.object({ name: z.string() }), + messages: [{ role: 'user', content: 'Generate a user' }] +}) + +await client.streamObject({ + model: registry.languageModel('openai:gpt-4'), + schema: z.object({ items: z.array(z.string()) }), + messages: [{ role: 'user', content: 'Generate a list' }] +}) +``` + +#### 混合使用的优势 + +- **灵活性**:可以根据需要选择使用内建逻辑或自定义注册表 +- **兼容性**:完全兼容 AI SDK 的 `createProviderRegistry` API +- **渐进式**:可以逐步迁移现有代码,无需一次性重构 +- **插件支持**:自定义注册表仍可享受插件系统的部分功能 +- **最佳实践**:结合两种方式的优点,既有动态加载的性能优势,又有统一注册表的便利性 + +## 📚 相关资源 + +- [Vercel AI SDK 文档](https://ai-sdk.dev/) +- [Cherry Studio 项目](https://github.com/CherryHQ/cherry-studio) +- [AI SDK Providers](https://ai-sdk.dev/providers/ai-sdk-providers) + +## 未来版本 + +- 🔮 多 Agent 编排 +- 🔮 可视化插件配置 +- 🔮 实时监控和分析 +- 🔮 云端插件同步 + +## 📄 License + +MIT License - 详见 [LICENSE](https://github.com/CherryHQ/cherry-studio/blob/main/LICENSE) 文件 + +--- + +**Cherry Studio AI Core** - 让 AI 开发更简单、更强大、更灵活 🚀 diff --git a/packages/aiCore/examples/hub-provider-usage.ts b/packages/aiCore/examples/hub-provider-usage.ts new file mode 100644 index 0000000000..559e812bdb --- /dev/null +++ b/packages/aiCore/examples/hub-provider-usage.ts @@ -0,0 +1,103 @@ +/** + * Hub Provider 使用示例 + * + * 演示如何使用简化后的Hub Provider功能来路由到多个底层provider + */ + +import { createHubProvider, initializeProvider, providerRegistry } from '../src/index' + +async function demonstrateHubProvider() { + try { + // 1. 初始化底层providers + console.log('📦 初始化底层providers...') + + initializeProvider('openai', { + apiKey: process.env.OPENAI_API_KEY || 'sk-test-key' + }) + + initializeProvider('anthropic', { + apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key' + }) + + // 2. 创建Hub Provider(自动包含所有已初始化的providers) + console.log('🌐 创建Hub Provider...') + + const aihubmixProvider = createHubProvider({ + hubId: 'aihubmix', + debug: true + }) + + // 3. 
注册Hub Provider + providerRegistry.registerProvider('aihubmix', aihubmixProvider) + + console.log('✅ Hub Provider "aihubmix" 注册成功') + + // 4. 使用Hub Provider访问不同的模型 + console.log('\n🚀 使用Hub模型...') + + // 通过Hub路由到OpenAI + const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4') + console.log('✓ OpenAI模型已获取:', openaiModel.modelId) + + // 通过Hub路由到Anthropic + const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet') + console.log('✓ Anthropic模型已获取:', anthropicModel.modelId) + + // 5. 演示错误处理 + console.log('\n❌ 演示错误处理...') + + try { + // 尝试访问未初始化的provider + providerRegistry.languageModel('aihubmix:google:gemini-pro') + } catch (error) { + console.log('预期错误:', error.message) + } + + try { + // 尝试使用错误的模型ID格式 + providerRegistry.languageModel('aihubmix:invalid-format') + } catch (error) { + console.log('预期错误:', error.message) + } + + // 6. 多个Hub Provider示例 + console.log('\n🔄 创建多个Hub Provider...') + + const localHubProvider = createHubProvider({ + hubId: 'local-ai' + }) + + providerRegistry.registerProvider('local-ai', localHubProvider) + console.log('✅ Hub Provider "local-ai" 注册成功') + + console.log('\n🎉 Hub Provider演示完成!') + } catch (error) { + console.error('💥 演示过程中发生错误:', error) + } +} + +// 演示简化的使用方式 +function simplifiedUsageExample() { + console.log('\n📝 简化使用示例:') + console.log(` +// 1. 初始化providers +initializeProvider('openai', { apiKey: 'sk-xxx' }) +initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' }) + +// 2. 创建并注册Hub Provider +const hubProvider = createHubProvider({ hubId: 'aihubmix' }) +providerRegistry.registerProvider('aihubmix', hubProvider) + +// 3. 
直接使用 +const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4') +const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet') +`) +} + +// 运行演示 +if (require.main === module) { + demonstrateHubProvider() + simplifiedUsageExample() +} + +export { demonstrateHubProvider, simplifiedUsageExample } diff --git a/packages/aiCore/examples/image-generation.ts b/packages/aiCore/examples/image-generation.ts new file mode 100644 index 0000000000..811aa048c8 --- /dev/null +++ b/packages/aiCore/examples/image-generation.ts @@ -0,0 +1,167 @@ +/** + * Image Generation Example + * 演示如何使用 aiCore 的文生图功能 + */ + +import { createExecutor, generateImage } from '../src/index' + +async function main() { + // 方式1: 使用执行器实例 + console.log('📸 创建 OpenAI 图像生成执行器...') + const executor = createExecutor('openai', { + apiKey: process.env.OPENAI_API_KEY! + }) + + try { + console.log('🎨 使用执行器生成图像...') + const result1 = await executor.generateImage('dall-e-3', { + prompt: 'A futuristic cityscape at sunset with flying cars', + size: '1024x1024', + n: 1 + }) + + console.log('✅ 图像生成成功!') + console.log('📊 结果:', { + imagesCount: result1.images.length, + mediaType: result1.image.mediaType, + hasBase64: !!result1.image.base64, + providerMetadata: result1.providerMetadata + }) + } catch (error) { + console.error('❌ 执行器生成失败:', error) + } + + // 方式2: 使用直接调用 API + try { + console.log('🎨 使用直接 API 生成图像...') + const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! 
}, 'dall-e-3', { + prompt: 'A magical forest with glowing mushrooms and fairy lights', + aspectRatio: '16:9', + providerOptions: { + openai: { + quality: 'hd', + style: 'vivid' + } + } + }) + + console.log('✅ 直接 API 生成成功!') + console.log('📊 结果:', { + imagesCount: result2.images.length, + mediaType: result2.image.mediaType, + hasBase64: !!result2.image.base64 + }) + } catch (error) { + console.error('❌ 直接 API 生成失败:', error) + } + + // 方式3: 支持其他提供商 (Google Imagen) + if (process.env.GOOGLE_API_KEY) { + try { + console.log('🎨 使用 Google Imagen 生成图像...') + const googleExecutor = createExecutor('google', { + apiKey: process.env.GOOGLE_API_KEY! + }) + + const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', { + prompt: 'A serene mountain lake at dawn with mist rising from the water', + aspectRatio: '1:1' + }) + + console.log('✅ Google Imagen 生成成功!') + console.log('📊 结果:', { + imagesCount: result3.images.length, + mediaType: result3.image.mediaType, + hasBase64: !!result3.image.base64 + }) + } catch (error) { + console.error('❌ Google Imagen 生成失败:', error) + } + } + + // 方式4: 支持插件系统 + const pluginExample = async () => { + console.log('🔌 演示插件系统...') + + // 创建一个示例插件,用于修改提示词 + const promptEnhancerPlugin = { + name: 'prompt-enhancer', + transformParams: async (params: any) => { + console.log('🔧 插件: 增强提示词...') + return { + ...params, + prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution` + } + }, + transformResult: async (result: any) => { + console.log('🔧 插件: 处理结果...') + return { + ...result, + enhanced: true + } + } + } + + const executorWithPlugin = createExecutor( + 'openai', + { + apiKey: process.env.OPENAI_API_KEY! 
+ }, + [promptEnhancerPlugin] + ) + + try { + const result4 = await executorWithPlugin.generateImage('dall-e-3', { + prompt: 'A cute robot playing in a garden' + }) + + console.log('✅ 插件系统生成成功!') + console.log('📊 结果:', { + imagesCount: result4.images.length, + enhanced: (result4 as any).enhanced, + mediaType: result4.image.mediaType + }) + } catch (error) { + console.error('❌ 插件系统生成失败:', error) + } + } + + await pluginExample() +} + +// 错误处理演示 +async function errorHandlingExample() { + console.log('⚠️ 演示错误处理...') + + try { + const executor = createExecutor('openai', { + apiKey: 'invalid-key' + }) + + await executor.generateImage('dall-e-3', { + prompt: 'Test image' + }) + } catch (error: any) { + console.log('✅ 成功捕获错误:', error.constructor.name) + console.log('📋 错误信息:', error.message) + console.log('🏷️ 提供商ID:', error.providerId) + console.log('🏷️ 模型ID:', error.modelId) + } +} + +// 运行示例 +if (require.main === module) { + main() + .then(() => { + console.log('🎉 所有示例完成!') + return errorHandlingExample() + }) + .then(() => { + console.log('🎯 示例程序结束') + process.exit(0) + }) + .catch((error) => { + console.error('💥 程序执行出错:', error) + process.exit(1) + }) +} diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json new file mode 100644 index 0000000000..2e02700563 --- /dev/null +++ b/packages/aiCore/package.json @@ -0,0 +1,85 @@ +{ + "name": "@cherrystudio/ai-core", + "version": "1.0.0-alpha.11", + "description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK", + "main": "dist/index.js", + "module": "dist/index.mjs", + "types": "dist/index.d.ts", + "react-native": "dist/index.js", + "scripts": { + "build": "tsdown", + "dev": "tsc -w", + "clean": "rm -rf dist", + "test": "vitest run", + "test:watch": "vitest" + }, + "keywords": [ + "ai", + "sdk", + "openai", + "anthropic", + "google", + "cherry-studio", + "vercel-ai-sdk" + ], + "author": "Cherry Studio", + "license": "MIT", + "repository": { + "type": "git", + "url": 
"git+https://github.com/CherryHQ/cherry-studio.git" + }, + "bugs": { + "url": "https://github.com/CherryHQ/cherry-studio/issues" + }, + "homepage": "https://github.com/CherryHQ/cherry-studio#readme", + "peerDependencies": { + "ai": "^5.0.26" + }, + "dependencies": { + "@ai-sdk/anthropic": "^2.0.5", + "@ai-sdk/azure": "^2.0.16", + "@ai-sdk/deepseek": "^1.0.9", + "@ai-sdk/google": "^2.0.7", + "@ai-sdk/openai": "^2.0.19", + "@ai-sdk/openai-compatible": "^1.0.9", + "@ai-sdk/provider": "^2.0.0", + "@ai-sdk/provider-utils": "^3.0.4", + "@ai-sdk/xai": "^2.0.9", + "zod": "^3.25.0" + }, + "devDependencies": { + "tsdown": "^0.12.9", + "typescript": "^5.0.0", + "vitest": "^3.2.4" + }, + "sideEffects": false, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist" + ], + "exports": { + ".": { + "types": "./dist/index.d.ts", + "react-native": "./dist/index.js", + "import": "./dist/index.mjs", + "require": "./dist/index.js", + "default": "./dist/index.js" + }, + "./built-in/plugins": { + "types": "./dist/built-in/plugins/index.d.ts", + "react-native": "./dist/built-in/plugins/index.js", + "import": "./dist/built-in/plugins/index.mjs", + "require": "./dist/built-in/plugins/index.js", + "default": "./dist/built-in/plugins/index.js" + }, + "./provider": { + "types": "./dist/provider/index.d.ts", + "react-native": "./dist/provider/index.js", + "import": "./dist/provider/index.mjs", + "require": "./dist/provider/index.js", + "default": "./dist/provider/index.js" + } + } +} diff --git a/packages/aiCore/setupVitest.ts b/packages/aiCore/setupVitest.ts new file mode 100644 index 0000000000..c878079ca9 --- /dev/null +++ b/packages/aiCore/setupVitest.ts @@ -0,0 +1,2 @@ +// 模拟 Vite SSR helper,避免 Node 环境找不到时报错 +;(globalThis as any).__vite_ssr_exportName__ = (name: string, value: any) => value diff --git a/packages/aiCore/src/core/README.MD b/packages/aiCore/src/core/README.MD new file mode 100644 index 0000000000..fc33fe18d5 --- /dev/null +++ b/packages/aiCore/src/core/README.MD @@ 
-0,0 +1,3 @@ +# @cherryStudio-aiCore + +Core diff --git a/packages/aiCore/src/core/index.ts b/packages/aiCore/src/core/index.ts new file mode 100644 index 0000000000..2346ea8cd2 --- /dev/null +++ b/packages/aiCore/src/core/index.ts @@ -0,0 +1,17 @@ +/** + * Core 模块导出 + * 内部核心功能,供其他模块使用,不直接面向最终调用者 + */ + +// 中间件系统 +export type { NamedMiddleware } from './middleware' +export { createMiddlewares, wrapModelWithMiddlewares } from './middleware' + +// 创建管理 +export { globalModelResolver, ModelResolver } from './models' +export type { ModelConfig as ModelConfigType } from './models/types' + +// 执行管理 +export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type' +export { createExecutor, createOpenAICompatibleExecutor } from './runtime' +export type { RuntimeConfig } from './runtime/types' diff --git a/packages/aiCore/src/core/middleware/index.ts b/packages/aiCore/src/core/middleware/index.ts new file mode 100644 index 0000000000..535b588098 --- /dev/null +++ b/packages/aiCore/src/core/middleware/index.ts @@ -0,0 +1,8 @@ +/** + * Middleware 模块导出 + * 提供通用的中间件管理能力 + */ + +export { createMiddlewares } from './manager' +export type { NamedMiddleware } from './types' +export { wrapModelWithMiddlewares } from './wrapper' diff --git a/packages/aiCore/src/core/middleware/manager.ts b/packages/aiCore/src/core/middleware/manager.ts new file mode 100644 index 0000000000..bcb044b3a9 --- /dev/null +++ b/packages/aiCore/src/core/middleware/manager.ts @@ -0,0 +1,16 @@ +/** + * 中间件管理器 + * 专注于 AI SDK 中间件的管理,与插件系统分离 + */ +import { LanguageModelV2Middleware } from '@ai-sdk/provider' + +/** + * 创建中间件列表 + * 合并用户提供的中间件 + */ +export function createMiddlewares(userMiddlewares: LanguageModelV2Middleware[] = []): LanguageModelV2Middleware[] { + // 未来可以在这里添加默认的中间件 + const defaultMiddlewares: LanguageModelV2Middleware[] = [] + + return [...defaultMiddlewares, ...userMiddlewares] +} diff --git a/packages/aiCore/src/core/middleware/types.ts 
b/packages/aiCore/src/core/middleware/types.ts new file mode 100644 index 0000000000..50b5210b53 --- /dev/null +++ b/packages/aiCore/src/core/middleware/types.ts @@ -0,0 +1,12 @@ +/** + * 中间件系统类型定义 + */ +import { LanguageModelV2Middleware } from '@ai-sdk/provider' + +/** + * 具名中间件接口 + */ +export interface NamedMiddleware { + name: string + middleware: LanguageModelV2Middleware +} diff --git a/packages/aiCore/src/core/middleware/wrapper.ts b/packages/aiCore/src/core/middleware/wrapper.ts new file mode 100644 index 0000000000..625eddbab3 --- /dev/null +++ b/packages/aiCore/src/core/middleware/wrapper.ts @@ -0,0 +1,23 @@ +/** + * 模型包装工具函数 + * 用于将中间件应用到LanguageModel上 + */ +import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider' +import { wrapLanguageModel } from 'ai' + +/** + * 使用中间件包装模型 + */ +export function wrapModelWithMiddlewares( + model: LanguageModelV2, + middlewares: LanguageModelV2Middleware[] +): LanguageModelV2 { + if (middlewares.length === 0) { + return model + } + + return wrapLanguageModel({ + model, + middleware: middlewares + }) +} diff --git a/packages/aiCore/src/core/models/ModelResolver.ts b/packages/aiCore/src/core/models/ModelResolver.ts new file mode 100644 index 0000000000..0f1bde95c6 --- /dev/null +++ b/packages/aiCore/src/core/models/ModelResolver.ts @@ -0,0 +1,125 @@ +/** + * 模型解析器 - models模块的核心 + * 负责将modelId解析为AI SDK的LanguageModel实例 + * 支持传统格式和命名空间格式 + * 集成了来自 ModelCreator 的特殊处理逻辑 + */ + +import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider' + +import { wrapModelWithMiddlewares } from '../middleware/wrapper' +import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement' + +export class ModelResolver { + /** + * 核心方法:解析任意格式的modelId为语言模型 + * + * @param modelId 模型ID,支持 'gpt-4' 和 'anthropic>claude-3' 两种格式 + * @param fallbackProviderId 当modelId为传统格式时使用的providerId + * @param providerOptions provider配置选项(用于OpenAI模式选择等) + * @param 
middlewares 中间件数组,会应用到最终模型上 + */ + async resolveLanguageModel( + modelId: string, + fallbackProviderId: string, + providerOptions?: any, + middlewares?: LanguageModelV2Middleware[] + ): Promise { + let finalProviderId = fallbackProviderId + let model: LanguageModelV2 + // 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移) + if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') { + finalProviderId = `${fallbackProviderId}-chat` + } + + // 检查是否是命名空间格式 + if (modelId.includes(DEFAULT_SEPARATOR)) { + model = this.resolveNamespacedModel(modelId) + } else { + // 传统格式:使用处理后的 providerId + modelId + model = this.resolveTraditionalModel(finalProviderId, modelId) + } + + // 🎯 应用中间件(如果有) + if (middlewares && middlewares.length > 0) { + model = wrapModelWithMiddlewares(model, middlewares) + } + + return model + } + + /** + * 解析文本嵌入模型 + */ + async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise> { + if (modelId.includes(DEFAULT_SEPARATOR)) { + return this.resolveNamespacedEmbeddingModel(modelId) + } + + return this.resolveTraditionalEmbeddingModel(fallbackProviderId, modelId) + } + + /** + * 解析图像模型 + */ + async resolveImageModel(modelId: string, fallbackProviderId: string): Promise { + if (modelId.includes(DEFAULT_SEPARATOR)) { + return this.resolveNamespacedImageModel(modelId) + } + + return this.resolveTraditionalImageModel(fallbackProviderId, modelId) + } + + /** + * 解析命名空间格式的语言模型 + * aihubmix:anthropic:claude-3 -> globalRegistryManagement.languageModel('aihubmix:anthropic:claude-3') + */ + private resolveNamespacedModel(modelId: string): LanguageModelV2 { + return globalRegistryManagement.languageModel(modelId as any) + } + + /** + * 解析传统格式的语言模型 + * providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4') + */ + private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 { + const fullModelId = 
`${providerId}${DEFAULT_SEPARATOR}${modelId}` + console.log('fullModelId', fullModelId) + return globalRegistryManagement.languageModel(fullModelId as any) + } + + /** + * 解析命名空间格式的嵌入模型 + */ + private resolveNamespacedEmbeddingModel(modelId: string): EmbeddingModelV2 { + return globalRegistryManagement.textEmbeddingModel(modelId as any) + } + + /** + * 解析传统格式的嵌入模型 + */ + private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV2 { + const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}` + return globalRegistryManagement.textEmbeddingModel(fullModelId as any) + } + + /** + * 解析命名空间格式的图像模型 + */ + private resolveNamespacedImageModel(modelId: string): ImageModelV2 { + return globalRegistryManagement.imageModel(modelId as any) + } + + /** + * 解析传统格式的图像模型 + */ + private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV2 { + const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}` + return globalRegistryManagement.imageModel(fullModelId as any) + } +} + +/** + * 全局模型解析器实例 + */ +export const globalModelResolver = new ModelResolver() diff --git a/packages/aiCore/src/core/models/index.ts b/packages/aiCore/src/core/models/index.ts new file mode 100644 index 0000000000..439d3d0f41 --- /dev/null +++ b/packages/aiCore/src/core/models/index.ts @@ -0,0 +1,9 @@ +/** + * Models 模块统一导出 - 简化版 + */ + +// 核心模型解析器 +export { globalModelResolver, ModelResolver } from './ModelResolver' + +// 保留的类型定义(可能被其他地方使用) +export type { ModelConfig as ModelConfigType } from './types' diff --git a/packages/aiCore/src/core/models/types.ts b/packages/aiCore/src/core/models/types.ts new file mode 100644 index 0000000000..57cb72366e --- /dev/null +++ b/packages/aiCore/src/core/models/types.ts @@ -0,0 +1,15 @@ +/** + * Creation 模块类型定义 + */ +import { LanguageModelV2Middleware } from '@ai-sdk/provider' + +import type { ProviderId, ProviderSettingsMap } from '../providers/types' + +export interface ModelConfig { + providerId: 
T + modelId: string + providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' } + middlewares?: LanguageModelV2Middleware[] + // 额外模型参数 + extraModelConfig?: Record +} diff --git a/packages/aiCore/src/core/options/examples.ts b/packages/aiCore/src/core/options/examples.ts new file mode 100644 index 0000000000..9078437d9c --- /dev/null +++ b/packages/aiCore/src/core/options/examples.ts @@ -0,0 +1,87 @@ +import { streamText } from 'ai' + +import { + createAnthropicOptions, + createGenericProviderOptions, + createGoogleOptions, + createOpenAIOptions, + mergeProviderOptions +} from './factory' + +// 示例1: 使用已知供应商的严格类型约束 +export function exampleOpenAIWithOptions() { + const openaiOptions = createOpenAIOptions({ + reasoningEffort: 'medium' + }) + + // 这里会有类型检查,确保选项符合OpenAI的设置 + return streamText({ + model: {} as any, // 实际使用时替换为真实模型 + prompt: 'Hello', + providerOptions: openaiOptions + }) +} + +// 示例2: 使用Anthropic供应商选项 +export function exampleAnthropicWithOptions() { + const anthropicOptions = createAnthropicOptions({ + thinking: { + type: 'enabled', + budgetTokens: 1000 + } + }) + + return streamText({ + model: {} as any, + prompt: 'Hello', + providerOptions: anthropicOptions + }) +} + +// 示例3: 使用Google供应商选项 +export function exampleGoogleWithOptions() { + const googleOptions = createGoogleOptions({ + thinkingConfig: { + includeThoughts: true, + thinkingBudget: 1000 + } + }) + + return streamText({ + model: {} as any, + prompt: 'Hello', + providerOptions: googleOptions + }) +} + +// 示例4: 使用未知供应商(通用类型) +export function exampleUnknownProviderWithOptions() { + const customProviderOptions = createGenericProviderOptions('custom-provider', { + temperature: 0.7, + customSetting: 'value', + anotherOption: true + }) + + return streamText({ + model: {} as any, + prompt: 'Hello', + providerOptions: customProviderOptions + }) +} + +// 示例5: 合并多个供应商选项 +export function exampleMergedOptions() { + const openaiOptions = createOpenAIOptions({}) + + const customOptions = 
createGenericProviderOptions('custom', { + customParam: 'value' + }) + + const mergedOptions = mergeProviderOptions(openaiOptions, customOptions) + + return streamText({ + model: {} as any, + prompt: 'Hello', + providerOptions: mergedOptions + }) +} diff --git a/packages/aiCore/src/core/options/factory.ts b/packages/aiCore/src/core/options/factory.ts new file mode 100644 index 0000000000..4350e9241b --- /dev/null +++ b/packages/aiCore/src/core/options/factory.ts @@ -0,0 +1,71 @@ +import { ExtractProviderOptions, ProviderOptionsMap, TypedProviderOptions } from './types' + +/** + * 创建特定供应商的选项 + * @param provider 供应商名称 + * @param options 供应商特定的选项 + * @returns 格式化的provider options + */ +export function createProviderOptions( + provider: T, + options: ExtractProviderOptions +): Record> { + return { [provider]: options } as Record> +} + +/** + * 创建任意供应商的选项(包括未知供应商) + * @param provider 供应商名称 + * @param options 供应商选项 + * @returns 格式化的provider options + */ +export function createGenericProviderOptions( + provider: T, + options: Record +): Record> { + return { [provider]: options } as Record> +} + +/** + * 合并多个供应商的options + * @param optionsMap 包含多个供应商选项的对象 + * @returns 合并后的TypedProviderOptions + */ +export function mergeProviderOptions(...optionsMap: Partial[]): TypedProviderOptions { + return Object.assign({}, ...optionsMap) +} + +/** + * 创建OpenAI供应商选项的便捷函数 + */ +export function createOpenAIOptions(options: ExtractProviderOptions<'openai'>) { + return createProviderOptions('openai', options) +} + +/** + * 创建Anthropic供应商选项的便捷函数 + */ +export function createAnthropicOptions(options: ExtractProviderOptions<'anthropic'>) { + return createProviderOptions('anthropic', options) +} + +/** + * 创建Google供应商选项的便捷函数 + */ +export function createGoogleOptions(options: ExtractProviderOptions<'google'>) { + return createProviderOptions('google', options) +} + +/** + * 创建OpenRouter供应商选项的便捷函数 + */ +export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) { + 
return createProviderOptions('openrouter', options) +} + +/** + * 创建XAI供应商选项的便捷函数 + */ +export function createXaiOptions(options: ExtractProviderOptions<'xai'>) { + return createProviderOptions('xai', options) +} diff --git a/packages/aiCore/src/core/options/index.ts b/packages/aiCore/src/core/options/index.ts new file mode 100644 index 0000000000..97a7b59914 --- /dev/null +++ b/packages/aiCore/src/core/options/index.ts @@ -0,0 +1,2 @@ +export * from './factory' +export * from './types' diff --git a/packages/aiCore/src/core/options/openrouter.ts b/packages/aiCore/src/core/options/openrouter.ts new file mode 100644 index 0000000000..b351f8fda1 --- /dev/null +++ b/packages/aiCore/src/core/options/openrouter.ts @@ -0,0 +1,38 @@ +export type OpenRouterProviderOptions = { + models?: string[] + + /** + * https://openrouter.ai/docs/use-cases/reasoning-tokens + * One of `max_tokens` or `effort` is required. + * If `exclude` is true, reasoning will be removed from the response. Default is false. + */ + reasoning?: { + exclude?: boolean + } & ( + | { + max_tokens: number + } + | { + effort: 'high' | 'medium' | 'low' + } + ) + + /** + * A unique identifier representing your end-user, which can + * help OpenRouter to monitor and detect abuse. + */ + user?: string + + extraBody?: Record + + /** + * Enable usage accounting to get detailed token usage information. + * https://openrouter.ai/docs/use-cases/usage-accounting + */ + usage?: { + /** + * When true, includes token usage information in the response. 
+ */ + include: boolean + } +} diff --git a/packages/aiCore/src/core/options/types.ts b/packages/aiCore/src/core/options/types.ts new file mode 100644 index 0000000000..724dc30698 --- /dev/null +++ b/packages/aiCore/src/core/options/types.ts @@ -0,0 +1,33 @@ +import { type AnthropicProviderOptions } from '@ai-sdk/anthropic' +import { type GoogleGenerativeAIProviderOptions } from '@ai-sdk/google' +import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai' +import { type SharedV2ProviderMetadata } from '@ai-sdk/provider' + +import { type OpenRouterProviderOptions } from './openrouter' +import { type XaiProviderOptions } from './xai' + +export type ProviderOptions = SharedV2ProviderMetadata[T] + +/** + * 供应商选项类型,如果map中没有,说明没有约束 + */ +export type ProviderOptionsMap = { + openai: OpenAIResponsesProviderOptions + anthropic: AnthropicProviderOptions + google: GoogleGenerativeAIProviderOptions + openrouter: OpenRouterProviderOptions + xai: XaiProviderOptions +} + +// 工具类型,用于从ProviderOptionsMap中提取特定供应商的选项类型 +export type ExtractProviderOptions = ProviderOptionsMap[T] + +/** + * 类型安全的ProviderOptions + * 对于已知供应商使用严格类型,对于未知供应商允许任意Record + */ +export type TypedProviderOptions = { + [K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K] +} & { + [K in string]?: Record +} & SharedV2ProviderMetadata diff --git a/packages/aiCore/src/core/options/xai.ts b/packages/aiCore/src/core/options/xai.ts new file mode 100644 index 0000000000..8d82f587e8 --- /dev/null +++ b/packages/aiCore/src/core/options/xai.ts @@ -0,0 +1,86 @@ +// copy from @ai-sdk/xai/xai-chat-options.ts +// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件 + +import * as z from 'zod/v4' + +const webSourceSchema = z.object({ + type: z.literal('web'), + country: z.string().length(2).optional(), + excludedWebsites: z.array(z.string()).max(5).optional(), + allowedWebsites: z.array(z.string()).max(5).optional(), + safeSearch: z.boolean().optional() +}) + +const xSourceSchema = z.object({ + type: z.literal('x'), + xHandles: 
z.array(z.string()).optional() +}) + +const newsSourceSchema = z.object({ + type: z.literal('news'), + country: z.string().length(2).optional(), + excludedWebsites: z.array(z.string()).max(5).optional(), + safeSearch: z.boolean().optional() +}) + +const rssSourceSchema = z.object({ + type: z.literal('rss'), + links: z.array(z.url()).max(1) // currently only supports one RSS link +}) + +const searchSourceSchema = z.discriminatedUnion('type', [ + webSourceSchema, + xSourceSchema, + newsSourceSchema, + rssSourceSchema +]) + +export const xaiProviderOptions = z.object({ + /** + * reasoning effort for reasoning models + * only supported by grok-3-mini and grok-3-mini-fast models + */ + reasoningEffort: z.enum(['low', 'high']).optional(), + + searchParameters: z + .object({ + /** + * search mode preference + * - "off": disables search completely + * - "auto": model decides whether to search (default) + * - "on": always enables search + */ + mode: z.enum(['off', 'auto', 'on']), + + /** + * whether to return citations in the response + * defaults to true + */ + returnCitations: z.boolean().optional(), + + /** + * start date for search data (ISO8601 format: YYYY-MM-DD) + */ + fromDate: z.string().optional(), + + /** + * end date for search data (ISO8601 format: YYYY-MM-DD) + */ + toDate: z.string().optional(), + + /** + * maximum number of search results to consider + * defaults to 20 + */ + maxSearchResults: z.number().min(1).max(50).optional(), + + /** + * data sources to search from + * defaults to ["web", "x"] if not specified + */ + sources: z.array(searchSourceSchema).optional() + }) + .optional() +}) + +export type XaiProviderOptions = z.infer diff --git a/packages/aiCore/src/core/plugins/README.md b/packages/aiCore/src/core/plugins/README.md new file mode 100644 index 0000000000..266b10c876 --- /dev/null +++ b/packages/aiCore/src/core/plugins/README.md @@ -0,0 +1,257 @@ +# AI Core 插件系统 + +支持四种钩子类型:**First**、**Sequential**、**Parallel** 和 **Stream**。 + +## 🎯 设计理念 + +- 
**语义清晰**:不同钩子有不同的执行语义 +- **类型安全**:TypeScript 完整支持 +- **性能优化**:First 短路、Parallel 并发、Sequential 链式 +- **易于扩展**:`enforce` 排序 + 功能分类 + +## 📋 钩子类型 + +### 🥇 First 钩子 - 首个有效结果 + +```typescript +// 只执行第一个返回值的插件,用于解析和查找 +resolveModel?: (modelId: string, context: AiRequestContext) => string | null +loadTemplate?: (templateName: string, context: AiRequestContext) => any | null +``` + +### 🔄 Sequential 钩子 - 链式数据转换 + +```typescript +// 按顺序链式执行,每个插件可以修改数据 +transformParams?: (params: any, context: AiRequestContext) => any +transformResult?: (result: any, context: AiRequestContext) => any +``` + +### ⚡ Parallel 钩子 - 并行副作用 + +```typescript +// 并发执行,用于日志、监控等副作用 +onRequestStart?: (context: AiRequestContext) => void +onRequestEnd?: (context: AiRequestContext, result: any) => void +onError?: (error: Error, context: AiRequestContext) => void +``` + +### 🌊 Stream 钩子 - 流处理 + +```typescript +// 直接使用 AI SDK 的 TransformStream +transformStream?: () => (options) => TransformStream +``` + +## 🚀 快速开始 + +### 基础用法 + +```typescript +import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware' + +// 创建插件管理器 +const pluginManager = new PluginManager() + +// 添加插件 +pluginManager.use({ + name: 'my-plugin', + async transformParams(params, context) { + return { ...params, temperature: 0.7 } + } +}) + +// 使用插件 +const context = createContext('openai', 'gpt-4', { messages: [] }) +const transformedParams = await pluginManager.executeSequential( + 'transformParams', + { messages: [{ role: 'user', content: 'Hello' }] }, + context +) +``` + +### 完整示例 + +```typescript +import { + PluginManager, + ModelAliasPlugin, + LoggingPlugin, + ParamsValidationPlugin, + createContext +} from '@cherrystudio/ai-core/middleware' + +// 创建插件管理器 +const manager = new PluginManager([ + ModelAliasPlugin, // 模型别名解析 + ParamsValidationPlugin, // 参数验证 + LoggingPlugin // 日志记录 +]) + +// AI 请求流程 +async function aiRequest(providerId: string, modelId: string, params: any) { + const context = 
createContext(providerId, modelId, params) + + try { + // 1. 【并行】触发请求开始事件 + await manager.executeParallel('onRequestStart', context) + + // 2. 【首个】解析模型别名 + const resolvedModel = await manager.executeFirst('resolveModel', modelId, context) + context.modelId = resolvedModel || modelId + + // 3. 【串行】转换请求参数 + const transformedParams = await manager.executeSequential('transformParams', params, context) + + // 4. 【流处理】收集流转换器(AI SDK 原生支持数组) + const streamTransforms = manager.collectStreamTransforms() + + // 5. 调用 AI SDK(这里省略具体实现) + const result = await callAiSdk(transformedParams, streamTransforms) + + // 6. 【串行】转换响应结果 + const transformedResult = await manager.executeSequential('transformResult', result, context) + + // 7. 【并行】触发请求完成事件 + await manager.executeParallel('onRequestEnd', context, transformedResult) + + return transformedResult + } catch (error) { + // 8. 【并行】触发错误事件 + await manager.executeParallel('onError', context, undefined, error) + throw error + } +} +``` + +## 🔧 自定义插件 + +### 模型别名插件 + +```typescript +const ModelAliasPlugin = definePlugin({ + name: 'model-alias', + enforce: 'pre', // 最先执行 + + async resolveModel(modelId) { + const aliases = { + gpt4: 'gpt-4-turbo-preview', + claude: 'claude-3-sonnet-20240229' + } + return aliases[modelId] || null + } +}) +``` + +### 参数验证插件 + +```typescript +const ValidationPlugin = definePlugin({ + name: 'validation', + + async transformParams(params) { + if (!params.messages) { + throw new Error('messages is required') + } + + return { + ...params, + temperature: params.temperature ?? 0.7, + max_tokens: params.max_tokens ?? 
4096 + } + } +}) +``` + +### 监控插件 + +```typescript +const MonitoringPlugin = definePlugin({ + name: 'monitoring', + enforce: 'post', // 最后执行 + + async onRequestEnd(context, result) { + const duration = Date.now() - context.startTime + console.log(`请求耗时: ${duration}ms`) + } +}) +``` + +### 内容过滤插件 + +```typescript +const FilterPlugin = definePlugin({ + name: 'content-filter', + + transformStream() { + return () => + new TransformStream({ + transform(chunk, controller) { + if (chunk.type === 'text-delta') { + const filtered = chunk.textDelta.replace(/敏感词/g, '***') + controller.enqueue({ ...chunk, textDelta: filtered }) + } else { + controller.enqueue(chunk) + } + } + }) + } +}) +``` + +## 📊 执行顺序 + +### 插件排序 + +``` +enforce: 'pre' → normal → enforce: 'post' +``` + +### 钩子执行流程 + +```mermaid +graph TD + A[请求开始] --> B[onRequestStart 并行执行] + B --> C[resolveModel 首个有效] + C --> D[loadTemplate 首个有效] + D --> E[transformParams 串行执行] + E --> F[collectStreamTransforms] + F --> G[AI SDK 调用] + G --> H[transformResult 串行执行] + H --> I[onRequestEnd 并行执行] + + G --> J[异常处理] + J --> K[onError 并行执行] +``` + +## 💡 最佳实践 + +1. **功能单一**:每个插件专注一个功能 +2. **幂等性**:插件应该是幂等的,重复执行不会产生副作用 +3. **错误处理**:插件内部处理异常,不要让异常向上传播 +4. **性能优化**:使用合适的钩子类型(First vs Sequential vs Parallel) +5. 
**命名规范**:使用语义化的插件名称 + +## 🔍 调试工具 + +```typescript +// 查看插件统计信息 +const stats = manager.getStats() +console.log('插件统计:', stats) + +// 查看所有插件 +const plugins = manager.getPlugins() +console.log( + '已注册插件:', + plugins.map((p) => p.name) +) +``` + +## ⚡ 性能优势 + +- **First 钩子**:一旦找到结果立即停止,避免无效计算 +- **Parallel 钩子**:真正并发执行,不阻塞主流程 +- **Sequential 钩子**:保证数据转换的顺序性 +- **Stream 钩子**:直接集成 AI SDK,零开销 + +这个设计兼顾了简洁性和强大功能,为 AI Core 提供了灵活而高效的扩展机制。 diff --git a/packages/aiCore/src/core/plugins/built-in/index.ts b/packages/aiCore/src/core/plugins/built-in/index.ts new file mode 100644 index 0000000000..3c0dfa5a8f --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/index.ts @@ -0,0 +1,10 @@ +/** + * 内置插件命名空间 + * 所有内置插件都以 'built-in:' 为前缀 + */ +export const BUILT_IN_PLUGIN_PREFIX = 'built-in:' + +export { createLoggingPlugin } from './logging' +export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin' +export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type' +export { webSearchPlugin } from './webSearchPlugin' diff --git a/packages/aiCore/src/core/plugins/built-in/logging.ts b/packages/aiCore/src/core/plugins/built-in/logging.ts new file mode 100644 index 0000000000..043765784c --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/logging.ts @@ -0,0 +1,86 @@ +/** + * 内置插件:日志记录 + * 记录AI调用的关键信息,支持性能监控和调试 + */ +import { definePlugin } from '../index' +import type { AiRequestContext } from '../types' + +export interface LoggingConfig { + // 日志级别 + level?: 'debug' | 'info' | 'warn' | 'error' + // 是否记录参数 + logParams?: boolean + // 是否记录结果 + logResult?: boolean + // 是否记录性能数据 + logPerformance?: boolean + // 自定义日志函数 + logger?: (level: string, message: string, data?: any) => void +} + +/** + * 创建日志插件 + */ +export function createLoggingPlugin(config: LoggingConfig = {}) { + const { level = 'info', logParams = true, logResult = false, logPerformance = true, logger = console.log } = config + + const startTimes = new 
Map() + + return definePlugin({ + name: 'built-in:logging', + + onRequestStart: (context: AiRequestContext) => { + const requestId = context.requestId + startTimes.set(requestId, Date.now()) + + logger(level, `🚀 AI Request Started`, { + requestId, + providerId: context.providerId, + modelId: context.modelId, + originalParams: logParams ? context.originalParams : '[hidden]' + }) + }, + + onRequestEnd: (context: AiRequestContext, result: any) => { + const requestId = context.requestId + const startTime = startTimes.get(requestId) + const duration = startTime ? Date.now() - startTime : undefined + startTimes.delete(requestId) + + const logData: any = { + requestId, + providerId: context.providerId, + modelId: context.modelId + } + + if (logPerformance && duration) { + logData.duration = `${duration}ms` + } + + if (logResult) { + logData.result = result + } + + logger(level, `✅ AI Request Completed`, logData) + }, + + onError: (error: Error, context: AiRequestContext) => { + const requestId = context.requestId + const startTime = startTimes.get(requestId) + const duration = startTime ? Date.now() - startTime : undefined + startTimes.delete(requestId) + + logger('error', `❌ AI Request Failed`, { + requestId, + providerId: context.providerId, + modelId: context.modelId, + duration: duration ? 
`${duration}ms` : undefined, + error: { + name: error.name, + message: error.message, + stack: error.stack + } + }) + } + }) +} diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts new file mode 100644 index 0000000000..197b20e9b4 --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts @@ -0,0 +1,139 @@ +/** + * 流事件管理器 + * + * 负责处理 AI SDK 流事件的发送和管理 + * 从 promptToolUsePlugin.ts 中提取出来以降低复杂度 + */ +import type { ModelMessage } from 'ai' + +import type { AiRequestContext } from '../../types' +import type { StreamController } from './ToolExecutor' + +/** + * 流事件管理器类 + */ +export class StreamEventManager { + /** + * 发送工具调用步骤开始事件 + */ + sendStepStartEvent(controller: StreamController): void { + controller.enqueue({ + type: 'start-step', + request: {}, + warnings: [] + }) + } + + /** + * 发送步骤完成事件 + */ + sendStepFinishEvent(controller: StreamController, chunk: any): void { + controller.enqueue({ + type: 'finish-step', + finishReason: 'stop', + response: chunk.response, + usage: chunk.usage, + providerMetadata: chunk.providerMetadata + }) + } + + /** + * 处理递归调用并将结果流接入当前流 + */ + async handleRecursiveCall( + controller: StreamController, + recursiveParams: any, + context: AiRequestContext, + stepId: string + ): Promise { + try { + console.log('[MCP Prompt] Starting recursive call after tool execution...') + + const recursiveResult = await context.recursiveCall(recursiveParams) + + if (recursiveResult && recursiveResult.fullStream) { + await this.pipeRecursiveStream(controller, recursiveResult.fullStream) + } else { + console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult) + } + } catch (error) { + this.handleRecursiveCallError(controller, error, stepId) + } + } + + /** + * 将递归流的数据传递到当前流 + */ + private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): 
Promise { + const reader = recursiveStream.getReader() + try { + while (true) { + const { done, value } = await reader.read() + if (done) { + break + } + if (value.type === 'finish') { + // 迭代的流不发finish + break + } + // 将递归流的数据传递到当前流 + controller.enqueue(value) + } + } finally { + reader.releaseLock() + } + } + + /** + * 处理递归调用错误 + */ + private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void { + console.error('[MCP Prompt] Recursive call failed:', error) + + // 使用 AI SDK 标准错误格式,但不中断流 + controller.enqueue({ + type: 'error', + error: { + message: error instanceof Error ? error.message : String(error), + name: error instanceof Error ? error.name : 'RecursiveCallError' + } + }) + + // 继续发送文本增量,保持流的连续性 + controller.enqueue({ + type: 'text-delta', + id: stepId, + text: '\n\n[工具执行后递归调用失败,继续对话...]' + }) + } + + /** + * 构建递归调用的参数 + */ + buildRecursiveParams(context: AiRequestContext, textBuffer: string, toolResultsText: string, tools: any): any { + // 构建新的对话消息 + const newMessages: ModelMessage[] = [ + ...(context.originalParams.messages || []), + { + role: 'assistant', + content: textBuffer + }, + { + role: 'user', + content: toolResultsText + } + ] + + // 递归调用,继续对话,重新传递 tools + const recursiveParams = { + ...context.originalParams, + messages: newMessages, + tools: tools + } + + // 更新上下文中的消息 + context.originalParams.messages = newMessages + + return recursiveParams + } +} diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/ToolExecutor.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/ToolExecutor.ts new file mode 100644 index 0000000000..ec174fa2ea --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/ToolExecutor.ts @@ -0,0 +1,156 @@ +/** + * 工具执行器 + * + * 负责工具的执行、结果格式化和相关事件发送 + * 从 promptToolUsePlugin.ts 中提取出来以降低复杂度 + */ +import type { ToolSet } from 'ai' + +import type { ToolUseResult } from './type' + +/** + * 工具执行结果 + */ +export interface ExecutedResult { + toolCallId: 
string + toolName: string + result: any + isError?: boolean +} + +/** + * 流控制器类型(从 AI SDK 提取) + */ +export interface StreamController { + enqueue(chunk: any): void +} + +/** + * 工具执行器类 + */ +export class ToolExecutor { + /** + * 执行多个工具调用 + */ + async executeTools( + toolUses: ToolUseResult[], + tools: ToolSet, + controller: StreamController + ): Promise { + const executedResults: ExecutedResult[] = [] + + for (const toolUse of toolUses) { + try { + const tool = tools[toolUse.toolName] + if (!tool || typeof tool.execute !== 'function') { + throw new Error(`Tool "${toolUse.toolName}" has no execute method`) + } + + // 发送工具调用开始事件 + this.sendToolStartEvents(controller, toolUse) + + console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments) + + // 发送 tool-call 事件(input 为实际调用参数,而非工具的 inputSchema) + controller.enqueue({ + type: 'tool-call', + toolCallId: toolUse.id, + toolName: toolUse.toolName, + input: toolUse.arguments + }) + + const result = await tool.execute(toolUse.arguments, { + toolCallId: toolUse.id, + messages: [], + abortSignal: new AbortController().signal + }) + + // 发送 tool-result 事件 + controller.enqueue({ + type: 'tool-result', + toolCallId: toolUse.id, + toolName: toolUse.toolName, + input: toolUse.arguments, + output: result + }) + + executedResults.push({ + toolCallId: toolUse.id, + toolName: toolUse.toolName, + result, + isError: false + }) + } catch (error) { + console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error) + + // 处理错误情况 + const errorResult = this.handleToolError(toolUse, error, controller) + executedResults.push(errorResult) + } + } + + return executedResults + } + + /** + * 格式化工具结果为 Cherry Studio 标准格式 + */ + formatToolResults(executedResults: ExecutedResult[]): string { + return executedResults + .map((tr) => { + if (!tr.isError) { + return `\n ${tr.toolName}\n ${JSON.stringify(tr.result)}\n` + } else { + const error = tr.result || 'Unknown error' + return `\n ${tr.toolName}\n ${error}\n` + } + }) + .join('\n\n') 
+ } + + /** + * 发送工具调用开始相关事件 + */ + private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void { + // 发送 tool-input-start 事件 + controller.enqueue({ + type: 'tool-input-start', + id: toolUse.id, + toolName: toolUse.toolName + }) + } + + /** + * 处理工具执行错误 + */ + private handleToolError( + toolUse: ToolUseResult, + error: unknown, + controller: StreamController + // _tools: ToolSet + ): ExecutedResult { + // 使用 AI SDK 标准错误格式 + // const toolError: TypedToolError = { + // type: 'tool-error', + // toolCallId: toolUse.id, + // toolName: toolUse.toolName, + // input: toolUse.arguments, + // error: error instanceof Error ? error.message : String(error) + // } + + // controller.enqueue(toolError) + + // 发送标准错误事件 + controller.enqueue({ + type: 'error', + error: error instanceof Error ? error.message : String(error) + }) + + return { + toolCallId: toolUse.id, + toolName: toolUse.toolName, + result: error instanceof Error ? error.message : String(error), + isError: true + } + } +} diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts new file mode 100644 index 0000000000..1d795c94a3 --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts @@ -0,0 +1,373 @@ +/** + * 内置插件:MCP Prompt 模式 + * 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用 + * 内置默认逻辑,支持自定义覆盖 + */ +import type { TextStreamPart, ToolSet } from 'ai' + +import { definePlugin } from '../../index' +import type { AiRequestContext } from '../../types' +import { StreamEventManager } from './StreamEventManager' +import { ToolExecutor } from './ToolExecutor' +import { PromptToolUseConfig, ToolUseResult } from './type' + +/** + * 默认系统提示符模板(提取自 Cherry Studio) + */ +const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. 
\\ +You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use. + +## Tool Use Formatting + +Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure: + + + {tool_name} + {json_arguments} + + +The tool name should be the exact name of the tool you are using, and the arguments should be a JSON object containing the parameters required by that tool. For example: + + python_interpreter + {"code": "5 + 3 + 1294.678"} + + +The user will respond with the result of the tool use, which should be formatted as follows: + + + {tool_name} + {result} + + +The result should be a string, which can represent a file or any other output type. You can use this result as input for the next action. +For example, if the result of the tool use is an image file, you can use it in the next action like this: + + + image_transformer + {"image": "image_1.jpg"} + + +Always adhere to this format for the tool use to ensure proper parsing and execution. + +## Tool Use Examples +{{ TOOL_USE_EXAMPLES }} + +## Tool Use Available Tools +Above example were using notional tools that might not exist for you. You only have access to these tools: +{{ AVAILABLE_TOOLS }} + +## Tool Use Rules +Here are the rules you should always follow to solve your task: +1. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead. +2. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself. +3. If no tool call is needed, just answer the question directly. +4. Never re-do a tool call that you previously did with the exact same parameters. +5. 
For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format. + +# User Instructions +{{ USER_SYSTEM_PROMPT }} + +Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.` + +/** + * 默认工具使用示例(提取自 Cherry Studio) + */ +const DEFAULT_TOOL_USE_EXAMPLES = ` +Here are a few examples using notional tools: +--- +User: Generate an image of the oldest person in this document. + +A: I can use the document_qa tool to find out who the oldest person is in the document. + + document_qa + {"document": "document.pdf", "question": "Who is the oldest person mentioned?"} + + +User: + document_qa + John Doe, a 55 year old lumberjack living in Newfoundland. + + +A: I can use the image_generator tool to create a portrait of John Doe. + + image_generator + {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} + + +User: + image_generator + image.png + + +A: the image is generated as image.png + +--- +User: "What is the result of the following operation: 5 + 3 + 1294.678?" + +A: I can use the python_interpreter tool to calculate the result of the operation. + + python_interpreter + {"code": "5 + 3 + 1294.678"} + + +User: + python_interpreter + 1302.678 + + +A: The result of the operation is 1302.678. + +--- +User: "Which city has the highest population , Guangzhou or Shanghai?" + +A: I can use the search tool to find the population of Guangzhou. + + search + {"query": "Population Guangzhou"} + + +User: + search + Guangzhou has a population of 15 million inhabitants as of 2021. + + +A: I can use the search tool to find the population of Shanghai. + + search + {"query": "Population Shanghai"} + + +User: + search + 26 million (2019) + +Assistant: The population of Shanghai is 26 million, while Guangzhou has a population of 15 million. 
Therefore, Shanghai has the highest population.` + +/** + * 构建可用工具部分(提取自 Cherry Studio) + */ +function buildAvailableTools(tools: ToolSet): string { + const availableTools = Object.keys(tools) + .map((toolName: string) => { + const tool = tools[toolName] + return ` + + ${toolName} + ${tool.description || ''} + + ${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''} + + +` + }) + .join('\n') + return ` +${availableTools} +` +} + +/** + * 默认的系统提示符构建函数(提取自 Cherry Studio) + */ +function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string { + const availableTools = buildAvailableTools(tools) + + const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES) + .replace('{{ AVAILABLE_TOOLS }}', availableTools) + .replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt || '') + + return fullPrompt +} + +/** + * 默认工具解析函数(提取自 Cherry Studio) + * 解析 XML 格式的工具调用 + */ +function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUseResult[]; content: string } { + if (!content || !tools || Object.keys(tools).length === 0) { + return { results: [], content: content } + } + + // 支持两种格式: + // 1. 完整的 标签包围的内容 + // 2. 
只有内部内容(从 TagExtractor 提取出来的) + + let contentToProcess = content + // 如果内容不包含 标签,说明是从 TagExtractor 提取的内部内容,需要包装 + if (!content.includes('')) { + contentToProcess = `\n${content}\n` + } + + const toolUsePattern = + /([\s\S]*?)([\s\S]*?)<\/name>([\s\S]*?)([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g + const results: ToolUseResult[] = [] + let match + let idx = 0 + + // Find all tool use blocks + while ((match = toolUsePattern.exec(contentToProcess)) !== null) { + const fullMatch = match[0] + const toolName = match[2].trim() + const toolArgs = match[4].trim() + + // Try to parse the arguments as JSON + let parsedArgs + try { + parsedArgs = JSON.parse(toolArgs) + } catch (error) { + // If parsing fails, use the string as is + parsedArgs = toolArgs + } + + // Find the corresponding tool + const tool = tools[toolName] + if (!tool) { + console.warn(`Tool "${toolName}" not found in available tools`) + continue + } + + // Add to results array + results.push({ + id: `${toolName}-${idx++}`, // Unique ID for each tool use + toolName: toolName, + arguments: parsedArgs, + status: 'pending' + }) + contentToProcess = contentToProcess.replace(fullMatch, '') + } + return { results, content: contentToProcess } +} + +export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => { + const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config + + return definePlugin({ + name: 'built-in:prompt-tool-use', + transformParams: (params: any, context: AiRequestContext) => { + if (!enabled || !params.tools || typeof params.tools !== 'object') { + return params + } + + context.mcpTools = params.tools + console.log('tools stored in context', params.tools) + + // 构建系统提示符 + const userSystemPrompt = typeof params.system === 'string' ? 
params.system : '' + const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools) + let systemMessage: string | null = systemPrompt + console.log('config.context', context) + if (config.createSystemMessage) { + // 🎯 如果用户提供了自定义处理函数,使用它 + systemMessage = config.createSystemMessage(systemPrompt, params, context) + } + + // 移除 tools,改为 prompt 模式 + const transformedParams = { + ...params, + ...(systemMessage ? { system: systemMessage } : {}), + tools: undefined + } + context.originalParams = transformedParams + console.log('transformedParams', transformedParams) + return transformedParams + }, + transformStream: (_: any, context: AiRequestContext) => () => { + let textBuffer = '' + let stepId = '' + + if (!context.mcpTools) { + throw new Error('No tools available') + } + + // 创建工具执行器和流事件管理器 + const toolExecutor = new ToolExecutor() + const streamEventManager = new StreamEventManager() + + type TOOLS = NonNullable + return new TransformStream, TextStreamPart>({ + async transform( + chunk: TextStreamPart, + controller: TransformStreamDefaultController> + ) { + // 收集文本内容 + if (chunk.type === 'text-delta') { + textBuffer += chunk.text || '' + stepId = chunk.id || '' + controller.enqueue(chunk) + return + } + + if (chunk.type === 'text-end' || chunk.type === 'finish-step') { + const tools = context.mcpTools + if (!tools || Object.keys(tools).length === 0) { + controller.enqueue(chunk) + return + } + + // 解析工具调用 + const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools) + const validToolUses = parsedTools.filter((t) => t.status === 'pending') + + // 如果没有有效的工具调用,直接传递原始事件 + if (validToolUses.length === 0) { + controller.enqueue(chunk) + return + } + + if (chunk.type === 'text-end') { + controller.enqueue({ + type: 'text-end', + id: stepId, + providerMetadata: { + text: { + value: parsedContent + } + } + }) + return + } + + controller.enqueue({ + ...chunk, + finishReason: 'tool-calls' + }) + + // 发送步骤开始事件 + 
streamEventManager.sendStepStartEvent(controller) + + // 执行工具调用 + const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller) + + // 发送步骤完成事件 + streamEventManager.sendStepFinishEvent(controller, chunk) + + // 处理递归调用 + if (validToolUses.length > 0) { + const toolResultsText = toolExecutor.formatToolResults(executedResults) + const recursiveParams = streamEventManager.buildRecursiveParams( + context, + textBuffer, + toolResultsText, + tools + ) + + await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId) + } + + // 清理状态 + textBuffer = '' + return + } + + // 对于其他类型的事件,直接传递 + controller.enqueue(chunk) + }, + + flush() { + // 流结束时的清理工作 + console.log('[MCP Prompt] Stream ended, cleaning up...') + } + }) + } + }) +} diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/tagExtraction.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/tagExtraction.ts new file mode 100644 index 0000000000..11d9b934a6 --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/tagExtraction.ts @@ -0,0 +1,196 @@ +// Copied from https://github.com/vercel/ai/blob/main/packages/ai/core/util/get-potential-start-index.ts + +/** + * Returns the index of the start of the searchedText in the text, or null if it + * is not found. + */ +export function getPotentialStartIndex(text: string, searchedText: string): number | null { + // Return null immediately if searchedText is empty. + if (searchedText.length === 0) { + return null + } + + // Check if the searchedText exists as a direct substring of text. + const directIndex = text.indexOf(searchedText) + if (directIndex !== -1) { + return directIndex + } + + // Otherwise, look for the largest suffix of "text" that matches + // a prefix of "searchedText". We go from the end of text inward. 
+ for (let i = text.length - 1; i >= 0; i--) { + const suffix = text.substring(i) + if (searchedText.startsWith(suffix)) { + return i + } + } + + return null +} + +export interface TagConfig { + openingTag: string + closingTag: string + separator?: string +} + +export interface TagExtractionState { + textBuffer: string + isInsideTag: boolean + isFirstTag: boolean + isFirstText: boolean + afterSwitch: boolean + accumulatedTagContent: string + hasTagContent: boolean +} + +export interface TagExtractionResult { + content: string + isTagContent: boolean + complete: boolean + tagContentExtracted?: string +} + +/** + * 通用标签提取处理器 + * 可以处理各种形式的标签对,如 ..., ... 等 + */ +export class TagExtractor { + private config: TagConfig + private state: TagExtractionState + + constructor(config: TagConfig) { + this.config = config + this.state = { + textBuffer: '', + isInsideTag: false, + isFirstTag: true, + isFirstText: true, + afterSwitch: false, + accumulatedTagContent: '', + hasTagContent: false + } + } + + /** + * 处理文本块,返回处理结果 + */ + processText(newText: string): TagExtractionResult[] { + this.state.textBuffer += newText + const results: TagExtractionResult[] = [] + + // 处理标签提取逻辑 + while (true) { + const nextTag = this.state.isInsideTag ? 
this.config.closingTag : this.config.openingTag + const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag) + + if (startIndex == null) { + const content = this.state.textBuffer + if (content.length > 0) { + results.push({ + content: this.addPrefix(content), + isTagContent: this.state.isInsideTag, + complete: false + }) + + if (this.state.isInsideTag) { + this.state.accumulatedTagContent += this.addPrefix(content) + this.state.hasTagContent = true + } + } + this.state.textBuffer = '' + break + } + + // 处理标签前的内容 + const contentBeforeTag = this.state.textBuffer.slice(0, startIndex) + if (contentBeforeTag.length > 0) { + results.push({ + content: this.addPrefix(contentBeforeTag), + isTagContent: this.state.isInsideTag, + complete: false + }) + + if (this.state.isInsideTag) { + this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag) + this.state.hasTagContent = true + } + } + + const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length + + if (foundFullMatch) { + // 如果找到完整的标签 + this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length) + + // 如果刚刚结束一个标签内容,生成完整的标签内容结果 + if (this.state.isInsideTag && this.state.hasTagContent) { + results.push({ + content: '', + isTagContent: false, + complete: true, + tagContentExtracted: this.state.accumulatedTagContent + }) + this.state.accumulatedTagContent = '' + this.state.hasTagContent = false + } + + this.state.isInsideTag = !this.state.isInsideTag + this.state.afterSwitch = true + + if (this.state.isInsideTag) { + this.state.isFirstTag = false + } else { + this.state.isFirstText = false + } + } else { + this.state.textBuffer = this.state.textBuffer.slice(startIndex) + break + } + } + + return results + } + + /** + * 完成处理,返回任何剩余的标签内容 + */ + finalize(): TagExtractionResult | null { + if (this.state.hasTagContent && this.state.accumulatedTagContent) { + const result = { + content: '', + isTagContent: false, + complete: true, + tagContentExtracted: 
this.state.accumulatedTagContent + } + this.state.accumulatedTagContent = '' + this.state.hasTagContent = false + return result + } + return null + } + + private addPrefix(text: string): string { + const needsPrefix = + this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText) + + const prefix = needsPrefix && this.config.separator ? this.config.separator : '' + this.state.afterSwitch = false + return prefix + text + } + + /** + * 重置状态 + */ + reset(): void { + this.state = { + textBuffer: '', + isInsideTag: false, + isFirstTag: true, + isFirstText: true, + afterSwitch: false, + accumulatedTagContent: '', + hasTagContent: false + } + } +} diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/type.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/type.ts new file mode 100644 index 0000000000..33ed6189ed --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/type.ts @@ -0,0 +1,33 @@ +import { ToolSet } from 'ai' + +import { AiRequestContext } from '../..' 
+ +/** + * 解析结果类型 + * 表示从AI响应中解析出的工具使用意图 + */ +export interface ToolUseResult { + id: string + toolName: string + arguments: any + status: 'pending' | 'invoking' | 'done' | 'error' +} + +export interface BaseToolUsePluginConfig { + enabled?: boolean +} + +export interface PromptToolUseConfig extends BaseToolUsePluginConfig { + // 自定义系统提示符构建函数(可选,有默认实现) + buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string + // 自定义工具解析函数(可选,有默认实现) + parseToolUse?: (content: string, tools: ToolSet) => { results: ToolUseResult[]; content: string } + createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null +} + +/** + * 扩展的 AI 请求上下文,支持 MCP 工具存储 + */ +export interface ToolUseRequestContext extends AiRequestContext { + mcpTools: ToolSet +} diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts new file mode 100644 index 0000000000..a7c7187bca --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts @@ -0,0 +1,67 @@ +import { anthropic } from '@ai-sdk/anthropic' +import { google } from '@ai-sdk/google' +import { openai } from '@ai-sdk/openai' + +import { ProviderOptionsMap } from '../../../options/types' + +/** + * 从 AI SDK 的工具函数中提取参数类型,以确保类型安全。 + */ +type OpenAISearchConfig = Parameters[0] +type AnthropicSearchConfig = Parameters[0] +type GoogleSearchConfig = Parameters[0] + +/** + * 插件初始化时接收的完整配置对象 + * + * 其结构与 ProviderOptions 保持一致,方便上游统一管理配置 + */ +export interface WebSearchPluginConfig { + openai?: OpenAISearchConfig + anthropic?: AnthropicSearchConfig + xai?: ProviderOptionsMap['xai']['searchParameters'] + google?: GoogleSearchConfig + 'google-vertex'?: GoogleSearchConfig +} + +/** + * 插件的默认配置 + */ +export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = { + google: {}, + 'google-vertex': {}, + openai: {}, + xai: { + mode: 'on', + returnCitations: true, + maxSearchResults: 5, + 
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }] + }, + anthropic: { + maxUses: 5 + } +} + +export type WebSearchToolOutputSchema = { + // Anthropic 工具 - 手动定义 + anthropicWebSearch: Array<{ + url: string + title: string + pageAge: string | null + encryptedContent: string + type: string + }> + + // OpenAI 工具 - 基于实际输出 + openaiWebSearch: { + status: 'completed' | 'failed' + } + + // Google 工具 + googleSearch: { + webSearchQueries?: string[] + groundingChunks?: Array<{ + web?: { uri: string; title: string } + }> + } +} diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts new file mode 100644 index 0000000000..3d549eeac4 --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts @@ -0,0 +1,69 @@ +/** + * Web Search Plugin + * 提供统一的网络搜索能力,支持多个 AI Provider + */ +import { anthropic } from '@ai-sdk/anthropic' +import { google } from '@ai-sdk/google' +import { openai } from '@ai-sdk/openai' + +import { createXaiOptions, mergeProviderOptions } from '../../../options' +import { definePlugin } from '../../' +import type { AiRequestContext } from '../../types' +import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper' + +/** + * 网络搜索插件 + * + * @param config - 在插件初始化时传入的静态配置 + */ +export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) => + definePlugin({ + name: 'webSearch', + enforce: 'pre', + + transformParams: async (params: any, context: AiRequestContext) => { + const { providerId } = context + switch (providerId) { + case 'openai': { + if (config.openai) { + if (!params.tools) params.tools = {} + params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai) + } + break + } + + case 'anthropic': { + if (config.anthropic) { + if (!params.tools) params.tools = {} + params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic) + } + break + } + + case 'google': 
{ + // case 'google-vertex': + if (!params.tools) params.tools = {} + params.tools.web_search = google.tools.googleSearch(config.google || {}) + break + } + + case 'xai': { + if (config.xai) { + const searchOptions = createXaiOptions({ + searchParameters: { ...config.xai, mode: 'on' } + }) + params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) + } + break + } + } + + return params + } + }) + +// 导出类型定义供开发者使用 +export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper' + +// 默认导出 +export default webSearchPlugin diff --git a/packages/aiCore/src/core/plugins/index.ts b/packages/aiCore/src/core/plugins/index.ts new file mode 100644 index 0000000000..200188d59d --- /dev/null +++ b/packages/aiCore/src/core/plugins/index.ts @@ -0,0 +1,32 @@ +// 核心类型和接口 +export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types' +import type { ProviderId } from '../providers' +import type { AiPlugin, AiRequestContext } from './types' + +// 插件管理器 +export { PluginManager } from './manager' + +// 工具函数 +export function createContext( + providerId: T, + modelId: string, + originalParams: any +): AiRequestContext { + return { + providerId, + modelId, + originalParams, + metadata: {}, + startTime: Date.now(), + requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`, + // 占位 + recursiveCall: () => Promise.resolve(null) + } +} + +// 插件构建器 - 便于创建插件 +export function definePlugin(plugin: AiPlugin): AiPlugin +export function definePlugin AiPlugin>(pluginFactory: T): T +export function definePlugin(plugin: AiPlugin | ((...args: any[]) => AiPlugin)) { + return plugin +} diff --git a/packages/aiCore/src/core/plugins/manager.ts b/packages/aiCore/src/core/plugins/manager.ts new file mode 100644 index 0000000000..4c927ed1de --- /dev/null +++ b/packages/aiCore/src/core/plugins/manager.ts @@ -0,0 +1,184 @@ +import { AiPlugin, AiRequestContext } from './types' + +/** + * 插件管理器 + */ +export 
class PluginManager { + private plugins: AiPlugin[] = [] + + constructor(plugins: AiPlugin[] = []) { + this.plugins = this.sortPlugins(plugins) + } + + /** + * 添加插件 + */ + use(plugin: AiPlugin): this { + this.plugins = this.sortPlugins([...this.plugins, plugin]) + return this + } + + /** + * 移除插件 + */ + remove(pluginName: string): this { + this.plugins = this.plugins.filter((p) => p.name !== pluginName) + return this + } + + /** + * 插件排序:pre -> normal -> post + */ + private sortPlugins(plugins: AiPlugin[]): AiPlugin[] { + const pre: AiPlugin[] = [] + const normal: AiPlugin[] = [] + const post: AiPlugin[] = [] + + plugins.forEach((plugin) => { + if (plugin.enforce === 'pre') { + pre.push(plugin) + } else if (plugin.enforce === 'post') { + post.push(plugin) + } else { + normal.push(plugin) + } + }) + + return [...pre, ...normal, ...post] + } + + /** + * 执行 First 钩子 - 返回第一个有效结果 + */ + async executeFirst( + hookName: 'resolveModel' | 'loadTemplate', + arg: any, + context: AiRequestContext + ): Promise { + for (const plugin of this.plugins) { + const hook = plugin[hookName] + if (hook) { + const result = await hook(arg, context) + if (result !== null && result !== undefined) { + return result as T + } + } + } + return null + } + + /** + * 执行 Sequential 钩子 - 链式数据转换 + */ + async executeSequential( + hookName: 'transformParams' | 'transformResult', + initialValue: T, + context: AiRequestContext + ): Promise { + let result = initialValue + + for (const plugin of this.plugins) { + const hook = plugin[hookName] + if (hook) { + result = await hook(result, context) + } + } + + return result + } + + /** + * 执行 ConfigureContext 钩子 - 串行配置上下文 + */ + async executeConfigureContext(context: AiRequestContext): Promise { + for (const plugin of this.plugins) { + const hook = plugin.configureContext + if (hook) { + await hook(context) + } + } + } + + /** + * 执行 Parallel 钩子 - 并行副作用 + */ + async executeParallel( + hookName: 'onRequestStart' | 'onRequestEnd' | 'onError', + context: 
AiRequestContext, + result?: any, + error?: Error + ): Promise { + const promises = this.plugins + .map((plugin) => { + const hook = plugin[hookName] + if (!hook) return null + + if (hookName === 'onError' && error) { + return (hook as any)(error, context) + } else if (hookName === 'onRequestEnd' && result !== undefined) { + return (hook as any)(context, result) + } else if (hookName === 'onRequestStart') { + return (hook as any)(context) + } + return null + }) + .filter(Boolean) + + // 使用 Promise.all 而不是 allSettled,让插件错误能够抛出 + await Promise.all(promises) + } + + /** + * 收集所有流转换器(返回数组,AI SDK 原生支持) + */ + collectStreamTransforms(params: any, context: AiRequestContext) { + return this.plugins + .filter((plugin) => plugin.transformStream) + .map((plugin) => plugin.transformStream?.(params, context)) + } + + /** + * 获取所有插件信息 + */ + getPlugins(): AiPlugin[] { + return [...this.plugins] + } + + /** + * 获取插件统计信息 + */ + getStats() { + const stats = { + total: this.plugins.length, + pre: 0, + normal: 0, + post: 0, + hooks: { + resolveModel: 0, + loadTemplate: 0, + transformParams: 0, + transformResult: 0, + onRequestStart: 0, + onRequestEnd: 0, + onError: 0, + transformStream: 0 + } + } + + this.plugins.forEach((plugin) => { + // 统计 enforce 类型 + if (plugin.enforce === 'pre') stats.pre++ + else if (plugin.enforce === 'post') stats.post++ + else stats.normal++ + + // 统计钩子数量 + Object.keys(stats.hooks).forEach((hookName) => { + if (plugin[hookName as keyof AiPlugin]) { + stats.hooks[hookName as keyof typeof stats.hooks]++ + } + }) + }) + + return stats + } +} diff --git a/packages/aiCore/src/core/plugins/types.ts b/packages/aiCore/src/core/plugins/types.ts new file mode 100644 index 0000000000..378fa6c3d3 --- /dev/null +++ b/packages/aiCore/src/core/plugins/types.ts @@ -0,0 +1,79 @@ +import type { ImageModelV2 } from '@ai-sdk/provider' +import type { LanguageModel, TextStreamPart, ToolSet } from 'ai' + +import { type ProviderId } from '../providers/types' + +/** + * 递归调用函数类型 + 
* 使用 any 是因为递归调用时参数和返回类型可能完全不同 + */ +export type RecursiveCallFn = (newParams: any) => Promise + +/** + * AI 请求上下文 + */ +export interface AiRequestContext { + providerId: ProviderId + modelId: string + originalParams: any + metadata: Record + startTime: number + requestId: string + recursiveCall: RecursiveCallFn + isRecursiveCall?: boolean + mcpTools?: ToolSet + [key: string]: any +} + +/** + * 钩子分类 + */ +export interface AiPlugin { + name: string + enforce?: 'pre' | 'post' + + // 【First】首个钩子 - 只执行第一个返回值的插件 + resolveModel?: ( + modelId: string, + context: AiRequestContext + ) => Promise | LanguageModel | ImageModelV2 | null + loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise + + // 【Sequential】串行钩子 - 链式执行,支持数据转换 + configureContext?: (context: AiRequestContext) => void | Promise + transformParams?: (params: T, context: AiRequestContext) => T | Promise + transformResult?: (result: T, context: AiRequestContext) => T | Promise + + // 【Parallel】并行钩子 - 不依赖顺序,用于副作用 + onRequestStart?: (context: AiRequestContext) => void | Promise + onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise + onError?: (error: Error, context: AiRequestContext) => void | Promise + + // 【Stream】流处理 - 直接使用 AI SDK + transformStream?: ( + params: any, + context: AiRequestContext + ) => (options?: { + tools: TOOLS + stopStream: () => void + }) => TransformStream, TextStreamPart> + + // AI SDK 原生中间件 + // aiSdkMiddlewares?: LanguageModelV1Middleware[] +} + +/** + * 插件管理器配置 + */ +export interface PluginManagerConfig { + plugins: AiPlugin[] + context: Partial +} + +/** + * 钩子执行结果 + */ +export interface HookResult { + value: T + stop?: boolean +} diff --git a/packages/aiCore/src/core/providers/HubProvider.ts b/packages/aiCore/src/core/providers/HubProvider.ts new file mode 100644 index 0000000000..0283d634b0 --- /dev/null +++ b/packages/aiCore/src/core/providers/HubProvider.ts @@ -0,0 +1,101 @@ +/** + * Hub Provider - 支持路由到多个底层provider + * + * 
支持格式: hubId:providerId:modelId + * 例如: aihubmix:anthropic:claude-3.5-sonnet + */ + +import { ProviderV2 } from '@ai-sdk/provider' +import { customProvider } from 'ai' + +import { globalRegistryManagement } from './RegistryManagement' +import type { AiSdkMethodName, AiSdkModelReturn, AiSdkModelType } from './types' + +export interface HubProviderConfig { + /** Hub的唯一标识符 */ + hubId: string + /** 是否启用调试日志 */ + debug?: boolean +} + +export class HubProviderError extends Error { + constructor( + message: string, + public readonly hubId: string, + public readonly providerId?: string, + public readonly originalError?: Error + ) { + super(message) + this.name = 'HubProviderError' + } +} + +/** + * 解析Hub模型ID + */ +function parseHubModelId(modelId: string): { provider: string; actualModelId: string } { + const parts = modelId.split(':') + if (parts.length !== 2) { + throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown') + } + return { + provider: parts[0], + actualModelId: parts[1] + } +} + +/** + * 创建Hub Provider + */ +export function createHubProvider(config: HubProviderConfig): ProviderV2 { + const { hubId } = config + + function getTargetProvider(providerId: string): ProviderV2 { + // 从全局注册表获取provider实例 + try { + const provider = globalRegistryManagement.getProvider(providerId) + if (!provider) { + throw new HubProviderError( + `Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`, + hubId, + providerId + ) + } + return provider + } catch (error) { + throw new HubProviderError( + `Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`, + hubId, + providerId, + error instanceof Error ? 
error : undefined + ) + } + } + + function resolveModel( + modelId: string, + modelType: T, + methodName: AiSdkMethodName + ): AiSdkModelReturn { + const { provider, actualModelId } = parseHubModelId(modelId) + const targetProvider = getTargetProvider(provider) + + const fn = targetProvider[methodName] as (id: string) => AiSdkModelReturn + + if (!fn) { + throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider) + } + + return fn(actualModelId) + } + + return customProvider({ + fallbackProvider: { + languageModel: (modelId: string) => resolveModel(modelId, 'text', 'languageModel'), + textEmbeddingModel: (modelId: string) => resolveModel(modelId, 'embedding', 'textEmbeddingModel'), + imageModel: (modelId: string) => resolveModel(modelId, 'image', 'imageModel'), + transcriptionModel: (modelId: string) => resolveModel(modelId, 'transcription', 'transcriptionModel'), + speechModel: (modelId: string) => resolveModel(modelId, 'speech', 'speechModel') + } + }) +} diff --git a/packages/aiCore/src/core/providers/RegistryManagement.ts b/packages/aiCore/src/core/providers/RegistryManagement.ts new file mode 100644 index 0000000000..a8aefd44b2 --- /dev/null +++ b/packages/aiCore/src/core/providers/RegistryManagement.ts @@ -0,0 +1,221 @@ +/** + * Provider 注册表管理器 + * 纯粹的管理功能:存储、检索已配置好的 provider 实例 + * 基于 AI SDK 原生的 createProviderRegistry + */ + +import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, ProviderV2 } from '@ai-sdk/provider' +import { createProviderRegistry, type ProviderRegistryProvider } from 'ai' + +type PROVIDERS = Record + +export const DEFAULT_SEPARATOR = '|' + +// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}` + +export class RegistryManagement { + private providers: PROVIDERS = {} + private aliases: Set = new Set() // 记录哪些key是别名 + private separator: SEPARATOR + private registry: ProviderRegistryProvider | null = null + + constructor(options: { separator: SEPARATOR } = { separator: 
DEFAULT_SEPARATOR as SEPARATOR }) { + this.separator = options.separator + } + + /** + * 注册已配置好的 provider 实例 + */ + registerProvider(id: string, provider: ProviderV2, aliases?: string[]): this { + // 注册主provider + this.providers[id] = provider + + // 注册别名(都指向同一个provider实例) + if (aliases) { + aliases.forEach((alias) => { + this.providers[alias] = provider // 直接存储引用 + this.aliases.add(alias) // 标记为别名 + }) + } + + this.rebuildRegistry() + return this + } + + /** + * 获取已注册的provider实例 + */ + getProvider(id: string): ProviderV2 | undefined { + return this.providers[id] + } + + /** + * 批量注册 providers + */ + registerProviders(providers: Record): this { + Object.assign(this.providers, providers) + this.rebuildRegistry() + return this + } + + /** + * 移除 provider(同时清理相关别名) + */ + unregisterProvider(id: string): this { + const provider = this.providers[id] + if (!provider) return this + + // 如果移除的是真实ID,需要清理所有指向它的别名 + if (!this.aliases.has(id)) { + // 找到所有指向此provider的别名并删除 + const aliasesToRemove: string[] = [] + this.aliases.forEach((alias) => { + if (this.providers[alias] === provider) { + aliasesToRemove.push(alias) + } + }) + + aliasesToRemove.forEach((alias) => { + delete this.providers[alias] + this.aliases.delete(alias) + }) + } else { + // 如果移除的是别名,只删除别名记录 + this.aliases.delete(id) + } + + delete this.providers[id] + this.rebuildRegistry() + return this + } + + /** + * 立即重建 registry - 每次变更都重建 + */ + private rebuildRegistry(): void { + if (Object.keys(this.providers).length === 0) { + this.registry = null + return + } + + this.registry = createProviderRegistry(this.providers, { + separator: this.separator + }) + } + + /** + * 获取语言模型 - AI SDK 原生方法 + */ + languageModel(id: `${string}${SEPARATOR}${string}`): LanguageModelV2 { + if (!this.registry) { + throw new Error('No providers registered') + } + return this.registry.languageModel(id) + } + + /** + * 获取文本嵌入模型 - AI SDK 原生方法 + */ + textEmbeddingModel(id: `${string}${SEPARATOR}${string}`): EmbeddingModelV2 { + if 
(!this.registry) { + throw new Error('No providers registered') + } + return this.registry.textEmbeddingModel(id) + } + + /** + * 获取图像模型 - AI SDK 原生方法 + */ + imageModel(id: `${string}${SEPARATOR}${string}`): ImageModelV2 { + if (!this.registry) { + throw new Error('No providers registered') + } + return this.registry.imageModel(id) + } + + /** + * 获取转录模型 - AI SDK 原生方法 + */ + transcriptionModel(id: `${string}${SEPARATOR}${string}`): any { + if (!this.registry) { + throw new Error('No providers registered') + } + return this.registry.transcriptionModel(id) + } + + /** + * 获取语音模型 - AI SDK 原生方法 + */ + speechModel(id: `${string}${SEPARATOR}${string}`): any { + if (!this.registry) { + throw new Error('No providers registered') + } + return this.registry.speechModel(id) + } + + /** + * 获取已注册的 provider 列表 + */ + getRegisteredProviders(): string[] { + return Object.keys(this.providers) + } + + /** + * 检查是否有已注册的 providers + */ + hasProviders(): boolean { + return Object.keys(this.providers).length > 0 + } + + /** + * 清除所有 providers + */ + clear(): this { + this.providers = {} + this.aliases.clear() + this.registry = null + return this + } + + /** + * 解析真实的Provider ID(供getAiSdkProviderId使用) + * 如果传入的是别名,返回真实的Provider ID + * 如果传入的是真实ID,直接返回 + */ + resolveProviderId(id: string): string { + if (!this.aliases.has(id)) return id // 不是别名,直接返回 + + // 是别名,找到真实ID + const targetProvider = this.providers[id] + for (const [realId, provider] of Object.entries(this.providers)) { + if (provider === targetProvider && !this.aliases.has(realId)) { + return realId + } + } + return id + } + + /** + * 检查是否为别名 + */ + isAlias(id: string): boolean { + return this.aliases.has(id) + } + + /** + * 获取所有别名映射关系 + */ + getAllAliases(): Record { + const result: Record = {} + this.aliases.forEach((alias) => { + result[alias] = this.resolveProviderId(alias) + }) + return result + } +} + +/** + * 全局注册表管理器实例 + * 使用 | 作为分隔符,因为 : 会和 :free 等suffix冲突 + */ +export const globalRegistryManagement = new 
RegistryManagement() diff --git a/packages/aiCore/src/core/providers/__tests__/registry-functionality.test.ts b/packages/aiCore/src/core/providers/__tests__/registry-functionality.test.ts new file mode 100644 index 0000000000..a4d7c43b0c --- /dev/null +++ b/packages/aiCore/src/core/providers/__tests__/registry-functionality.test.ts @@ -0,0 +1,632 @@ +/** + * 测试真正的 AiProviderRegistry 功能 + */ + +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// 模拟 AI SDK +vi.mock('@ai-sdk/openai', () => ({ + createOpenAI: vi.fn(() => ({ name: 'openai-mock' })) +})) + +vi.mock('@ai-sdk/anthropic', () => ({ + createAnthropic: vi.fn(() => ({ name: 'anthropic-mock' })) +})) + +vi.mock('@ai-sdk/azure', () => ({ + createAzure: vi.fn(() => ({ name: 'azure-mock' })) +})) + +vi.mock('@ai-sdk/deepseek', () => ({ + createDeepSeek: vi.fn(() => ({ name: 'deepseek-mock' })) +})) + +vi.mock('@ai-sdk/google', () => ({ + createGoogleGenerativeAI: vi.fn(() => ({ name: 'google-mock' })) +})) + +vi.mock('@ai-sdk/openai-compatible', () => ({ + createOpenAICompatible: vi.fn(() => ({ name: 'openai-compatible-mock' })) +})) + +vi.mock('@ai-sdk/xai', () => ({ + createXai: vi.fn(() => ({ name: 'xai-mock' })) +})) + +import { + cleanup, + clearAllProviders, + createAndRegisterProvider, + createProvider, + getAllProviderConfigAliases, + getAllProviderConfigs, + getInitializedProviders, + getLanguageModel, + getProviderConfig, + getProviderConfigByAlias, + getSupportedProviders, + hasInitializedProviders, + hasProviderConfig, + hasProviderConfigByAlias, + isProviderConfigAlias, + ProviderInitializationError, + providerRegistry, + registerMultipleProviderConfigs, + registerProvider, + registerProviderConfig, + resolveProviderConfigId +} from '../registry' +import type { ProviderConfig } from '../schemas' + +describe('Provider Registry 功能测试', () => { + beforeEach(() => { + // 清理状态 + cleanup() + }) + + describe('基础功能', () => { + it('能够获取支持的 providers 列表', () => { + const providers = 
getSupportedProviders() + expect(Array.isArray(providers)).toBe(true) + expect(providers.length).toBeGreaterThan(0) + + // 检查返回的数据结构 + providers.forEach((provider) => { + expect(provider).toHaveProperty('id') + expect(provider).toHaveProperty('name') + expect(typeof provider.id).toBe('string') + expect(typeof provider.name).toBe('string') + }) + + // 包含基础 providers + const providerIds = providers.map((p) => p.id) + expect(providerIds).toContain('openai') + expect(providerIds).toContain('anthropic') + expect(providerIds).toContain('google') + }) + + it('能够获取已初始化的 providers', () => { + // 初始状态下没有已初始化的 providers + expect(getInitializedProviders()).toEqual([]) + expect(hasInitializedProviders()).toBe(false) + }) + + it('能够访问全局注册管理器', () => { + expect(providerRegistry).toBeDefined() + expect(typeof providerRegistry.clear).toBe('function') + expect(typeof providerRegistry.getRegisteredProviders).toBe('function') + expect(typeof providerRegistry.hasProviders).toBe('function') + }) + + it('能够获取语言模型', () => { + // 在没有注册 provider 的情况下,这个函数应该会抛出错误 + expect(() => getLanguageModel('non-existent')).toThrow('No providers registered') + }) + }) + + describe('Provider 配置注册', () => { + it('能够注册自定义 provider 配置', () => { + const config: ProviderConfig = { + id: 'custom-provider', + name: 'Custom Provider', + creator: vi.fn(() => ({ name: 'custom' })), + supportsImageGeneration: false + } + + const success = registerProviderConfig(config) + expect(success).toBe(true) + + expect(hasProviderConfig('custom-provider')).toBe(true) + expect(getProviderConfig('custom-provider')).toEqual(config) + }) + + it('能够注册带别名的 provider 配置', () => { + const config: ProviderConfig = { + id: 'custom-provider-with-aliases', + name: 'Custom Provider with Aliases', + creator: vi.fn(() => ({ name: 'custom-aliased' })), + supportsImageGeneration: false, + aliases: ['alias-1', 'alias-2'] + } + + const success = registerProviderConfig(config) + expect(success).toBe(true) + + 
expect(hasProviderConfigByAlias('alias-1')).toBe(true) + expect(hasProviderConfigByAlias('alias-2')).toBe(true) + expect(getProviderConfigByAlias('alias-1')).toEqual(config) + expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases') + }) + + it('拒绝无效的配置', () => { + // 缺少必要字段 + const invalidConfig = { + id: 'invalid-provider' + // 缺少 name, creator 等 + } + + const success = registerProviderConfig(invalidConfig as any) + expect(success).toBe(false) + }) + + it('能够批量注册 provider 配置', () => { + const configs: ProviderConfig[] = [ + { + id: 'provider-1', + name: 'Provider 1', + creator: vi.fn(() => ({ name: 'provider-1' })), + supportsImageGeneration: false + }, + { + id: 'provider-2', + name: 'Provider 2', + creator: vi.fn(() => ({ name: 'provider-2' })), + supportsImageGeneration: true + }, + { + id: '', // 无效配置 + name: 'Invalid Provider', + creator: vi.fn(() => ({ name: 'invalid' })), + supportsImageGeneration: false + } as any + ] + + const successCount = registerMultipleProviderConfigs(configs) + expect(successCount).toBe(2) // 只有前两个成功 + + expect(hasProviderConfig('provider-1')).toBe(true) + expect(hasProviderConfig('provider-2')).toBe(true) + expect(hasProviderConfig('')).toBe(false) + }) + + it('能够获取所有配置和别名信息', () => { + // 注册一些配置 + registerProviderConfig({ + id: 'test-provider', + name: 'Test Provider', + creator: vi.fn(), + supportsImageGeneration: false, + aliases: ['test-alias'] + }) + + const allConfigs = getAllProviderConfigs() + expect(Array.isArray(allConfigs)).toBe(true) + expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true) + + const aliases = getAllProviderConfigAliases() + expect(aliases['test-alias']).toBe('test-provider') + expect(isProviderConfigAlias('test-alias')).toBe(true) + }) + }) + + describe('Provider 创建和注册', () => { + it('能够创建 provider 实例', async () => { + const config: ProviderConfig = { + id: 'test-create-provider', + name: 'Test Create Provider', + creator: vi.fn(() => ({ name: 
'test-created' })), + supportsImageGeneration: false + } + + // 先注册配置 + registerProviderConfig(config) + + // 创建 provider 实例 + const provider = await createProvider('test-create-provider', { apiKey: 'test' }) + expect(provider).toBeDefined() + expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' }) + }) + + it('能够注册 provider 到全局管理器', () => { + const mockProvider = { name: 'mock-provider' } + const config: ProviderConfig = { + id: 'test-register-provider', + name: 'Test Register Provider', + creator: vi.fn(() => mockProvider), + supportsImageGeneration: false + } + + // 先注册配置 + registerProviderConfig(config) + + // 注册 provider 到全局管理器 + const success = registerProvider('test-register-provider', mockProvider) + expect(success).toBe(true) + + // 验证注册成功 + const registeredProviders = getInitializedProviders() + expect(registeredProviders).toContain('test-register-provider') + expect(hasInitializedProviders()).toBe(true) + }) + + it('能够一步完成创建和注册', async () => { + const config: ProviderConfig = { + id: 'test-create-and-register', + name: 'Test Create and Register', + creator: vi.fn(() => ({ name: 'test-both' })), + supportsImageGeneration: false + } + + // 先注册配置 + registerProviderConfig(config) + + // 一步完成创建和注册 + const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' }) + expect(success).toBe(true) + + // 验证注册成功 + const registeredProviders = getInitializedProviders() + expect(registeredProviders).toContain('test-create-and-register') + }) + }) + + describe('Registry 管理', () => { + it('能够清理所有配置和注册的 providers', () => { + // 注册一些配置 + registerProviderConfig({ + id: 'temp-provider', + name: 'Temp Provider', + creator: vi.fn(() => ({ name: 'temp' })), + supportsImageGeneration: false + }) + + expect(hasProviderConfig('temp-provider')).toBe(true) + + // 清理 + cleanup() + + expect(hasProviderConfig('temp-provider')).toBe(false) + // 但基础配置应该重新加载 + expect(hasProviderConfig('openai')).toBe(true) // 基础 providers 会重新初始化 + }) + + 
it('能够单独清理已注册的 providers', () => { + // 清理所有 providers + clearAllProviders() + + expect(getInitializedProviders()).toEqual([]) + expect(hasInitializedProviders()).toBe(false) + }) + + it('ProviderInitializationError 错误类工作正常', () => { + const error = new ProviderInitializationError('Test error', 'test-provider') + expect(error.message).toBe('Test error') + expect(error.providerId).toBe('test-provider') + expect(error.name).toBe('ProviderInitializationError') + }) + }) + + describe('错误处理', () => { + it('优雅处理空配置', () => { + const success = registerProviderConfig(null as any) + expect(success).toBe(false) + }) + + it('优雅处理未定义配置', () => { + const success = registerProviderConfig(undefined as any) + expect(success).toBe(false) + }) + + it('处理空字符串 ID', () => { + const config = { + id: '', + name: 'Empty ID Provider', + creator: vi.fn(() => ({ name: 'empty' })), + supportsImageGeneration: false + } + + const success = registerProviderConfig(config) + expect(success).toBe(false) + }) + + it('处理创建不存在配置的 provider', async () => { + await expect(createProvider('non-existent-provider', {})).rejects.toThrow( + 'ProviderConfig not found for id: non-existent-provider' + ) + }) + + it('处理注册不存在配置的 provider', () => { + const mockProvider = { name: 'mock' } + const success = registerProvider('non-existent-provider', mockProvider) + expect(success).toBe(false) + }) + + it('处理获取不存在配置的情况', () => { + expect(getProviderConfig('non-existent')).toBeUndefined() + expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined() + expect(hasProviderConfig('non-existent')).toBe(false) + expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false) + }) + + it('处理批量注册时的部分失败', () => { + const mixedConfigs: ProviderConfig[] = [ + { + id: 'valid-provider-1', + name: 'Valid Provider 1', + creator: vi.fn(() => ({ name: 'valid-1' })), + supportsImageGeneration: false + }, + { + id: '', // 无效配置 + name: 'Invalid Provider', + creator: vi.fn(() => ({ name: 'invalid' })), + 
supportsImageGeneration: false + } as any, + { + id: 'valid-provider-2', + name: 'Valid Provider 2', + creator: vi.fn(() => ({ name: 'valid-2' })), + supportsImageGeneration: true + } + ] + + const successCount = registerMultipleProviderConfigs(mixedConfigs) + expect(successCount).toBe(2) // 只有两个有效配置成功 + + expect(hasProviderConfig('valid-provider-1')).toBe(true) + expect(hasProviderConfig('valid-provider-2')).toBe(true) + expect(hasProviderConfig('')).toBe(false) + }) + + it('处理动态导入失败的情况', async () => { + const config: ProviderConfig = { + id: 'import-test-provider', + name: 'Import Test Provider', + import: vi.fn().mockRejectedValue(new Error('Import failed')), + creatorFunctionName: 'createTest', + supportsImageGeneration: false + } + + registerProviderConfig(config) + + await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed') + }) + }) + + describe('集成测试', () => { + it('正确处理复杂的配置、创建、注册和清理场景', async () => { + // 初始状态验证 + const initialConfigs = getAllProviderConfigs() + expect(initialConfigs.length).toBeGreaterThan(0) // 有基础配置 + expect(getInitializedProviders()).toEqual([]) + + // 注册多个带别名的 provider 配置 + const configs: ProviderConfig[] = [ + { + id: 'integration-provider-1', + name: 'Integration Provider 1', + creator: vi.fn(() => ({ name: 'integration-1' })), + supportsImageGeneration: false, + aliases: ['alias-1', 'short-name-1'] + }, + { + id: 'integration-provider-2', + name: 'Integration Provider 2', + creator: vi.fn(() => ({ name: 'integration-2' })), + supportsImageGeneration: true, + aliases: ['alias-2', 'short-name-2'] + } + ] + + const successCount = registerMultipleProviderConfigs(configs) + expect(successCount).toBe(2) + + // 验证配置注册成功 + expect(hasProviderConfig('integration-provider-1')).toBe(true) + expect(hasProviderConfig('integration-provider-2')).toBe(true) + expect(hasProviderConfigByAlias('alias-1')).toBe(true) + expect(hasProviderConfigByAlias('alias-2')).toBe(true) + + // 验证别名映射 + const aliases = 
getAllProviderConfigAliases() + expect(aliases['alias-1']).toBe('integration-provider-1') + expect(aliases['alias-2']).toBe('integration-provider-2') + + // 创建和注册 providers + const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' }) + const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' }) + expect(success1).toBe(true) + expect(success2).toBe(true) + + // 验证注册成功 + const registeredProviders = getInitializedProviders() + expect(registeredProviders).toContain('integration-provider-1') + expect(registeredProviders).toContain('integration-provider-2') + expect(hasInitializedProviders()).toBe(true) + + // 清理 + cleanup() + + // 验证清理后的状态 + expect(getInitializedProviders()).toEqual([]) + expect(hasProviderConfig('integration-provider-1')).toBe(false) + expect(hasProviderConfig('integration-provider-2')).toBe(false) + expect(getAllProviderConfigAliases()).toEqual({}) + + // 基础配置应该重新加载 + expect(hasProviderConfig('openai')).toBe(true) + }) + + it('正确处理动态导入配置的 provider', async () => { + const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) } + const dynamicImportConfig: ProviderConfig = { + id: 'dynamic-import-provider', + name: 'Dynamic Import Provider', + import: vi.fn().mockResolvedValue(mockModule), + creatorFunctionName: 'createCustomProvider', + supportsImageGeneration: false + } + + // 注册配置 + const configSuccess = registerProviderConfig(dynamicImportConfig) + expect(configSuccess).toBe(true) + + // 创建和注册 provider + const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' }) + expect(registerSuccess).toBe(true) + + // 验证导入函数被调用 + expect(dynamicImportConfig.import).toHaveBeenCalled() + expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' }) + + // 验证注册成功 + expect(getInitializedProviders()).toContain('dynamic-import-provider') + }) + + it('正确处理大量配置的注册和管理', () => { + const largeConfigList: 
ProviderConfig[] = [] + + // 生成50个配置 + for (let i = 0; i < 50; i++) { + largeConfigList.push({ + id: `bulk-provider-${i}`, + name: `Bulk Provider ${i}`, + creator: vi.fn(() => ({ name: `bulk-${i}` })), + supportsImageGeneration: i % 2 === 0, // 偶数支持图像生成 + aliases: [`alias-${i}`, `short-${i}`] + }) + } + + const successCount = registerMultipleProviderConfigs(largeConfigList) + expect(successCount).toBe(50) + + // 验证所有配置都被正确注册 + const allConfigs = getAllProviderConfigs() + expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50) + + // 验证别名数量 + const aliases = getAllProviderConfigAliases() + const bulkAliases = Object.keys(aliases).filter( + (alias) => alias.startsWith('alias-') || alias.startsWith('short-') + ) + expect(bulkAliases.length).toBe(100) // 每个 provider 有2个别名 + + // 随机验证几个配置 + expect(hasProviderConfig('bulk-provider-0')).toBe(true) + expect(hasProviderConfig('bulk-provider-25')).toBe(true) + expect(hasProviderConfig('bulk-provider-49')).toBe(true) + + // 验证别名工作正常 + expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25') + expect(isProviderConfigAlias('short-30')).toBe(true) + + // 清理能正确处理大量数据 + cleanup() + const cleanupAliases = getAllProviderConfigAliases() + expect( + Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-')) + ).toEqual([]) + }) + }) + + describe('边界测试', () => { + it('处理包含特殊字符的 provider IDs', () => { + const specialCharsConfigs: ProviderConfig[] = [ + { + id: 'provider-with-dashes', + name: 'Provider With Dashes', + creator: vi.fn(() => ({ name: 'dashes' })), + supportsImageGeneration: false + }, + { + id: 'provider_with_underscores', + name: 'Provider With Underscores', + creator: vi.fn(() => ({ name: 'underscores' })), + supportsImageGeneration: false + }, + { + id: 'provider.with.dots', + name: 'Provider With Dots', + creator: vi.fn(() => ({ name: 'dots' })), + supportsImageGeneration: false + } + ] + + const successCount = 
registerMultipleProviderConfigs(specialCharsConfigs) + expect(successCount).toBeGreaterThan(0) // 至少有一些成功 + + // 验证支持的特殊字符格式 + if (hasProviderConfig('provider-with-dashes')) { + expect(getProviderConfig('provider-with-dashes')).toBeDefined() + } + if (hasProviderConfig('provider_with_underscores')) { + expect(getProviderConfig('provider_with_underscores')).toBeDefined() + } + }) + + it('处理空的批量注册', () => { + const successCount = registerMultipleProviderConfigs([]) + expect(successCount).toBe(0) + + // 确保没有额外的配置被添加 + const configsBefore = getAllProviderConfigs().length + expect(configsBefore).toBeGreaterThan(0) // 应该有基础配置 + }) + + it('处理重复的配置注册', () => { + const config: ProviderConfig = { + id: 'duplicate-test-provider', + name: 'Duplicate Test Provider', + creator: vi.fn(() => ({ name: 'duplicate' })), + supportsImageGeneration: false + } + + // 第一次注册成功 + expect(registerProviderConfig(config)).toBe(true) + expect(hasProviderConfig('duplicate-test-provider')).toBe(true) + + // 重复注册相同的配置(允许覆盖) + const updatedConfig: ProviderConfig = { + ...config, + name: 'Updated Duplicate Test Provider' + } + expect(registerProviderConfig(updatedConfig)).toBe(true) + expect(hasProviderConfig('duplicate-test-provider')).toBe(true) + + // 验证配置被更新 + const retrievedConfig = getProviderConfig('duplicate-test-provider') + expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider') + }) + + it('处理极长的 ID 和名称', () => { + const longId = 'very-long-provider-id-' + 'x'.repeat(100) + const longName = 'Very Long Provider Name ' + 'Y'.repeat(100) + + const config: ProviderConfig = { + id: longId, + name: longName, + creator: vi.fn(() => ({ name: 'long-test' })), + supportsImageGeneration: false + } + + const success = registerProviderConfig(config) + expect(success).toBe(true) + expect(hasProviderConfig(longId)).toBe(true) + + const retrievedConfig = getProviderConfig(longId) + expect(retrievedConfig?.name).toBe(longName) + }) + + it('处理大量别名的配置', () => { + const manyAliases = 
Array.from({ length: 50 }, (_, i) => `alias-${i}`) + + const config: ProviderConfig = { + id: 'provider-with-many-aliases', + name: 'Provider With Many Aliases', + creator: vi.fn(() => ({ name: 'many-aliases' })), + supportsImageGeneration: false, + aliases: manyAliases + } + + const success = registerProviderConfig(config) + expect(success).toBe(true) + + // 验证所有别名都能正确解析 + manyAliases.forEach((alias) => { + expect(hasProviderConfigByAlias(alias)).toBe(true) + expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases') + expect(isProviderConfigAlias(alias)).toBe(true) + }) + }) + }) +}) diff --git a/packages/aiCore/src/core/providers/__tests__/schemas.test.ts b/packages/aiCore/src/core/providers/__tests__/schemas.test.ts new file mode 100644 index 0000000000..82b390ba05 --- /dev/null +++ b/packages/aiCore/src/core/providers/__tests__/schemas.test.ts @@ -0,0 +1,264 @@ +import { describe, expect, it, vi } from 'vitest' + +import { + type BaseProviderId, + baseProviderIds, + baseProviderIdSchema, + baseProviders, + type CustomProviderId, + customProviderIdSchema, + providerConfigSchema, + type ProviderId, + providerIdSchema +} from '../schemas' + +describe('Provider Schemas', () => { + describe('baseProviders', () => { + it('包含所有预期的基础 providers', () => { + expect(baseProviders).toBeDefined() + expect(Array.isArray(baseProviders)).toBe(true) + expect(baseProviders.length).toBeGreaterThan(0) + + const expectedIds = [ + 'openai', + 'openai-responses', + 'openai-compatible', + 'anthropic', + 'google', + 'xai', + 'azure', + 'deepseek' + ] + const actualIds = baseProviders.map((p) => p.id) + expectedIds.forEach((id) => { + expect(actualIds).toContain(id) + }) + }) + + it('每个基础 provider 有必要的属性', () => { + baseProviders.forEach((provider) => { + expect(provider).toHaveProperty('id') + expect(provider).toHaveProperty('name') + expect(provider).toHaveProperty('creator') + expect(provider).toHaveProperty('supportsImageGeneration') + + expect(typeof 
provider.id).toBe('string') + expect(typeof provider.name).toBe('string') + expect(typeof provider.creator).toBe('function') + expect(typeof provider.supportsImageGeneration).toBe('boolean') + }) + }) + + it('provider ID 是唯一的', () => { + const ids = baseProviders.map((p) => p.id) + const uniqueIds = [...new Set(ids)] + expect(ids).toEqual(uniqueIds) + }) + }) + + describe('baseProviderIds', () => { + it('正确提取所有基础 provider IDs', () => { + expect(baseProviderIds).toBeDefined() + expect(Array.isArray(baseProviderIds)).toBe(true) + expect(baseProviderIds.length).toBe(baseProviders.length) + + baseProviders.forEach((provider) => { + expect(baseProviderIds).toContain(provider.id) + }) + }) + }) + + describe('baseProviderIdSchema', () => { + it('验证有效的基础 provider IDs', () => { + baseProviderIds.forEach((id) => { + expect(baseProviderIdSchema.safeParse(id).success).toBe(true) + }) + }) + + it('拒绝无效的基础 provider IDs', () => { + const invalidIds = ['invalid', 'not-exists', ''] + invalidIds.forEach((id) => { + expect(baseProviderIdSchema.safeParse(id).success).toBe(false) + }) + }) + }) + + describe('customProviderIdSchema', () => { + it('接受有效的自定义 provider IDs', () => { + const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2'] + validIds.forEach((id) => { + expect(customProviderIdSchema.safeParse(id).success).toBe(true) + }) + }) + + it('拒绝与基础 provider IDs 冲突的 IDs', () => { + baseProviderIds.forEach((id) => { + expect(customProviderIdSchema.safeParse(id).success).toBe(false) + }) + }) + + it('拒绝空字符串', () => { + expect(customProviderIdSchema.safeParse('').success).toBe(false) + }) + }) + + describe('providerIdSchema', () => { + it('接受基础 provider IDs', () => { + baseProviderIds.forEach((id) => { + expect(providerIdSchema.safeParse(id).success).toBe(true) + }) + }) + + it('接受有效的自定义 provider IDs', () => { + const validCustomIds = ['custom-provider', 'my-ai-service'] + validCustomIds.forEach((id) => { + expect(providerIdSchema.safeParse(id).success).toBe(true) + }) + 
}) + + it('拒绝无效的 IDs', () => { + const invalidIds = ['', undefined, null, 123] + invalidIds.forEach((id) => { + expect(providerIdSchema.safeParse(id).success).toBe(false) + }) + }) + }) + + describe('providerConfigSchema', () => { + it('验证带有 creator 的有效配置', () => { + const validConfig = { + id: 'custom-provider', + name: 'Custom Provider', + creator: vi.fn(), + supportsImageGeneration: true + } + expect(providerConfigSchema.safeParse(validConfig).success).toBe(true) + }) + + it('验证带有 import 配置的有效配置', () => { + const validConfig = { + id: 'custom-provider', + name: 'Custom Provider', + import: vi.fn(), + creatorFunctionName: 'createCustom', + supportsImageGeneration: false + } + expect(providerConfigSchema.safeParse(validConfig).success).toBe(true) + }) + + it('拒绝既没有 creator 也没有 import 配置的配置', () => { + const invalidConfig = { + id: 'invalid', + name: 'Invalid Provider', + supportsImageGeneration: false + } + expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false) + }) + + it('为 supportsImageGeneration 设置默认值', () => { + const config = { + id: 'test', + name: 'Test', + creator: vi.fn() + } + const result = providerConfigSchema.safeParse(config) + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.supportsImageGeneration).toBe(false) + } + }) + + it('拒绝使用基础 provider ID 的配置', () => { + const invalidConfig = { + id: 'openai', // 基础 provider ID + name: 'Should Fail', + creator: vi.fn() + } + expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false) + }) + + it('拒绝缺少必需字段的配置', () => { + const invalidConfigs = [ + { name: 'Missing ID', creator: vi.fn() }, + { id: 'missing-name', creator: vi.fn() }, + { id: '', name: 'Empty ID', creator: vi.fn() }, + { id: 'valid-custom', name: '', creator: vi.fn() } + ] + + invalidConfigs.forEach((config) => { + expect(providerConfigSchema.safeParse(config).success).toBe(false) + }) + }) + }) + + describe('Schema 验证功能', () => { + it('baseProviderIdSchema 正确验证基础 provider IDs', () 
=> { + baseProviderIds.forEach((id) => { + expect(baseProviderIdSchema.safeParse(id).success).toBe(true) + }) + + expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false) + }) + + it('customProviderIdSchema 正确验证自定义 provider IDs', () => { + const customIds = ['custom-provider', 'my-service', 'company-llm'] + customIds.forEach((id) => { + expect(customProviderIdSchema.safeParse(id).success).toBe(true) + }) + + // 拒绝基础 provider IDs + baseProviderIds.forEach((id) => { + expect(customProviderIdSchema.safeParse(id).success).toBe(false) + }) + }) + + it('providerIdSchema 接受基础和自定义 provider IDs', () => { + // 基础 IDs + baseProviderIds.forEach((id) => { + expect(providerIdSchema.safeParse(id).success).toBe(true) + }) + + // 自定义 IDs + const customIds = ['custom-provider', 'my-service'] + customIds.forEach((id) => { + expect(providerIdSchema.safeParse(id).success).toBe(true) + }) + }) + + it('providerConfigSchema 验证完整的 provider 配置', () => { + const validConfig = { + id: 'custom-provider', + name: 'Custom Provider', + creator: vi.fn(), + supportsImageGeneration: true + } + expect(providerConfigSchema.safeParse(validConfig).success).toBe(true) + + const invalidConfig = { + id: 'openai', // 不允许基础 provider ID + name: 'OpenAI', + creator: vi.fn() + } + expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false) + }) + }) + + describe('类型推导', () => { + it('BaseProviderId 类型正确', () => { + const id: BaseProviderId = 'openai' + expect(baseProviderIds).toContain(id) + }) + + it('CustomProviderId 类型是字符串', () => { + const id: CustomProviderId = 'custom-provider' + expect(typeof id).toBe('string') + }) + + it('ProviderId 类型支持基础和自定义 IDs', () => { + const baseId: ProviderId = 'openai' + const customId: ProviderId = 'custom-provider' + expect(typeof baseId).toBe('string') + expect(typeof customId).toBe('string') + }) + }) +}) diff --git a/packages/aiCore/src/core/providers/factory.ts b/packages/aiCore/src/core/providers/factory.ts new file mode 100644 index 
0000000000..831526c192 --- /dev/null +++ b/packages/aiCore/src/core/providers/factory.ts @@ -0,0 +1,291 @@ +/** + * AI Provider 配置工厂 + * 提供类型安全的 Provider 配置构建器 + */ + +import type { ProviderId, ProviderSettingsMap } from './types' + +/** + * 通用配置基础类型,包含所有 Provider 共有的属性 + */ +export interface BaseProviderConfig { + apiKey?: string + baseURL?: string + timeout?: number + headers?: Record + fetch?: typeof globalThis.fetch +} + +/** + * 完整的配置类型,结合基础配置、AI SDK 配置和特定 Provider 配置 + */ +type CompleteProviderConfig = BaseProviderConfig & Partial + +type ConfigHandler = ( + builder: ProviderConfigBuilder, + provider: CompleteProviderConfig +) => void + +const configHandlers: { + [K in ProviderId]?: ConfigHandler +} = { + azure: (builder, provider) => { + const azureBuilder = builder as ProviderConfigBuilder<'azure'> + const azureProvider = provider as CompleteProviderConfig<'azure'> + azureBuilder.withAzureConfig({ + apiVersion: azureProvider.apiVersion, + resourceName: azureProvider.resourceName + }) + } +} + +export class ProviderConfigBuilder { + private config: CompleteProviderConfig = {} as CompleteProviderConfig + + constructor(private providerId: T) {} + + /** + * 设置 API Key + */ + withApiKey(apiKey: string): this + withApiKey(apiKey: string, options: T extends 'openai' ? 
{ organization?: string; project?: string } : never): this + withApiKey(apiKey: string, options?: any): this { + this.config.apiKey = apiKey + + // 类型安全的 OpenAI 特定配置 + if (this.providerId === 'openai' && options) { + const openaiConfig = this.config as CompleteProviderConfig<'openai'> + if (options.organization) { + openaiConfig.organization = options.organization + } + if (options.project) { + openaiConfig.project = options.project + } + } + + return this + } + + /** + * 设置基础 URL + */ + withBaseURL(baseURL: string) { + this.config.baseURL = baseURL + return this + } + + /** + * 设置请求配置 + */ + withRequestConfig(options: { headers?: Record; fetch?: typeof fetch }): this { + if (options.headers) { + this.config.headers = { ...this.config.headers, ...options.headers } + } + if (options.fetch) { + this.config.fetch = options.fetch + } + return this + } + + /** + * Azure OpenAI 特定配置 + */ + withAzureConfig(options: { apiVersion?: string; resourceName?: string }): T extends 'azure' ? this : never + withAzureConfig(options: any): any { + if (this.providerId === 'azure') { + const azureConfig = this.config as CompleteProviderConfig<'azure'> + if (options.apiVersion) { + azureConfig.apiVersion = options.apiVersion + } + if (options.resourceName) { + azureConfig.resourceName = options.resourceName + } + } + return this + } + + /** + * 设置自定义参数 + */ + withCustomParams(params: Record) { + Object.assign(this.config, params) + return this + } + + /** + * 构建最终配置 + */ + build(): ProviderSettingsMap[T] { + return this.config as ProviderSettingsMap[T] + } +} + +/** + * Provider 配置工厂 + * 提供便捷的配置创建方法 + */ +export class ProviderConfigFactory { + /** + * 创建配置构建器 + */ + static builder(providerId: T): ProviderConfigBuilder { + return new ProviderConfigBuilder(providerId) + } + + /** + * 从通用Provider对象创建配置 - 使用更优雅的处理器模式 + */ + static fromProvider( + providerId: T, + provider: CompleteProviderConfig, + options?: { + headers?: Record + [key: string]: any + } + ): ProviderSettingsMap[T] { + const 
builder = new ProviderConfigBuilder(providerId) + + // 设置基本配置 + if (provider.apiKey) { + builder.withApiKey(provider.apiKey) + } + + if (provider.baseURL) { + builder.withBaseURL(provider.baseURL) + } + + // 设置请求配置 + if (options?.headers) { + builder.withRequestConfig({ + headers: options.headers + }) + } + + // 使用配置处理器模式 - 更加优雅和可扩展 + const handler = configHandlers[providerId] + if (handler) { + handler(builder, provider) + } + + // 添加其他自定义参数 + if (options) { + const customOptions = { ...options } + delete customOptions.headers // 已经处理过了 + if (Object.keys(customOptions).length > 0) { + builder.withCustomParams(customOptions) + } + } + + return builder.build() + } + + /** + * 快速创建 OpenAI 配置 + */ + static createOpenAI( + apiKey: string, + options?: { + baseURL?: string + organization?: string + project?: string + } + ) { + const builder = this.builder('openai') + + // 使用类型安全的重载 + if (options?.organization || options?.project) { + builder.withApiKey(apiKey, { + organization: options.organization, + project: options.project + }) + } else { + builder.withApiKey(apiKey) + } + + return builder.withBaseURL(options?.baseURL || 'https://api.openai.com').build() + } + + /** + * 快速创建 Anthropic 配置 + */ + static createAnthropic( + apiKey: string, + options?: { + baseURL?: string + } + ) { + return this.builder('anthropic') + .withApiKey(apiKey) + .withBaseURL(options?.baseURL || 'https://api.anthropic.com') + .build() + } + + /** + * 快速创建 Azure OpenAI 配置 + */ + static createAzureOpenAI( + apiKey: string, + options: { + baseURL: string + apiVersion?: string + resourceName?: string + } + ) { + return this.builder('azure') + .withApiKey(apiKey) + .withBaseURL(options.baseURL) + .withAzureConfig({ + apiVersion: options.apiVersion, + resourceName: options.resourceName + }) + .build() + } + + /** + * 快速创建 Google 配置 + */ + static createGoogle( + apiKey: string, + options?: { + baseURL?: string + projectId?: string + location?: string + } + ) { + return this.builder('google') + 
.withApiKey(apiKey) + .withBaseURL(options?.baseURL || 'https://generativelanguage.googleapis.com') + .build() + } + + /** + * 快速创建 Vertex AI 配置 + */ + static createVertexAI() { + // credentials: { + // clientEmail: string + // privateKey: string + // }, + // options?: { + // project?: string + // location?: string + // } + // return this.builder('google-vertex') + // .withGoogleCredentials(credentials) + // .withGoogleVertexConfig({ + // project: options?.project, + // location: options?.location + // }) + // .build() + } + + static createOpenAICompatible(baseURL: string, apiKey: string) { + return this.builder('openai-compatible').withBaseURL(baseURL).withApiKey(apiKey).build() + } +} + +/** + * 便捷的配置创建函数 + */ +export const createProviderConfig = ProviderConfigFactory.fromProvider +export const providerConfigBuilder = ProviderConfigFactory.builder diff --git a/packages/aiCore/src/core/providers/index.ts b/packages/aiCore/src/core/providers/index.ts new file mode 100644 index 0000000000..3ac445cb22 --- /dev/null +++ b/packages/aiCore/src/core/providers/index.ts @@ -0,0 +1,83 @@ +/** + * Providers 模块统一导出 - 独立Provider包 + */ + +// ==================== 核心管理器 ==================== + +// Provider 注册表管理器 +export { globalRegistryManagement, RegistryManagement } from './RegistryManagement' + +// Provider 核心功能 +export { + // 状态管理 + cleanup, + clearAllProviders, + createAndRegisterProvider, + createProvider, + getAllProviderConfigAliases, + getAllProviderConfigs, + getImageModel, + // 工具函数 + getInitializedProviders, + getLanguageModel, + getProviderConfig, + getProviderConfigByAlias, + getSupportedProviders, + getTextEmbeddingModel, + hasInitializedProviders, + // 工具函数 + hasProviderConfig, + // 别名支持 + hasProviderConfigByAlias, + isProviderConfigAlias, + // 错误类型 + ProviderInitializationError, + // 全局访问 + providerRegistry, + registerMultipleProviderConfigs, + registerProvider, + // 统一Provider系统 + registerProviderConfig, + resolveProviderConfigId +} from './registry' + +// 
==================== 基础数据和类型 ==================== + +// 基础Provider数据源 +export { baseProviderIds, baseProviders } from './schemas' + +// 类型定义和Schema +export type { + BaseProviderId, + CustomProviderId, + DynamicProviderRegistration, + ProviderConfig, + ProviderId +} from './schemas' // 从 schemas 导出的类型 +export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出 +export type { + DynamicProviderRegistry, + ExtensibleProviderSettingsMap, + ProviderError, + ProviderSettingsMap, + ProviderTypeRegistrar +} from './types' + +// ==================== 工具函数 ==================== + +// Provider配置工厂 +export { + type BaseProviderConfig, + createProviderConfig, + ProviderConfigBuilder, + providerConfigBuilder, + ProviderConfigFactory +} from './factory' + +// 工具函数 +export { formatPrivateKey } from './utils' + +// ==================== 扩展功能 ==================== + +// Hub Provider 功能 +export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider' diff --git a/packages/aiCore/src/core/providers/registry.ts b/packages/aiCore/src/core/providers/registry.ts new file mode 100644 index 0000000000..8cf33fdacd --- /dev/null +++ b/packages/aiCore/src/core/providers/registry.ts @@ -0,0 +1,320 @@ +/** + * Provider 初始化器 + * 负责根据配置创建 providers 并注册到全局管理器 + * 集成了来自 ModelCreator 的特殊处理逻辑 + */ + +import { customProvider } from 'ai' + +import { globalRegistryManagement } from './RegistryManagement' +import { baseProviders, type ProviderConfig } from './schemas' + +/** + * Provider 初始化错误类型 + */ +class ProviderInitializationError extends Error { + constructor( + message: string, + public providerId?: string, + public cause?: Error + ) { + super(message) + this.name = 'ProviderInitializationError' + } +} + +// ==================== 全局管理器导出 ==================== + +export { globalRegistryManagement as providerRegistry } + +// ==================== 便捷访问方法 ==================== + +export const getLanguageModel = (id: 
string) => globalRegistryManagement.languageModel(id as any) +export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.textEmbeddingModel(id as any) +export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any) + +// ==================== 工具函数 ==================== + +/** + * 获取支持的 Providers 列表 + */ +export function getSupportedProviders(): Array<{ + id: string + name: string +}> { + return baseProviders.map((provider) => ({ + id: provider.id, + name: provider.name + })) +} + +/** + * 获取所有已初始化的 providers + */ +export function getInitializedProviders(): string[] { + return globalRegistryManagement.getRegisteredProviders() +} + +/** + * 检查是否有任何已初始化的 providers + */ +export function hasInitializedProviders(): boolean { + return globalRegistryManagement.hasProviders() +} + +// ==================== 统一Provider配置系统 ==================== + +// 全局Provider配置存储 +const providerConfigs = new Map() +// 全局ProviderConfig别名映射 - 借鉴RegistryManagement模式 +const providerConfigAliases = new Map() // alias -> realId + +/** + * 初始化内置配置 - 将baseProviders转换为统一格式 + */ +function initializeBuiltInConfigs(): void { + baseProviders.forEach((provider) => { + const config: ProviderConfig = { + id: provider.id, + name: provider.name, + creator: provider.creator as any, // 类型转换以兼容多种creator签名 + supportsImageGeneration: provider.supportsImageGeneration || false + } + providerConfigs.set(provider.id, config) + }) +} + +// 启动时自动注册内置配置 +initializeBuiltInConfigs() + +/** + * 步骤1: 注册Provider配置 - 仅存储配置,不执行创建 + */ +export function registerProviderConfig(config: ProviderConfig): boolean { + try { + // 验证配置 + if (!config || !config.id || !config.name) { + return false + } + + // 检查是否与已有配置冲突(包括内置配置) + if (providerConfigs.has(config.id)) { + console.warn(`ProviderConfig "${config.id}" already exists, will override`) + } + + // 存储配置(内置和用户配置统一处理) + providerConfigs.set(config.id, config) + + // 处理别名 + if (config.aliases && config.aliases.length > 0) { + 
config.aliases.forEach((alias) => { + if (providerConfigAliases.has(alias)) { + console.warn(`ProviderConfig alias "${alias}" already exists, will override`) + } + providerConfigAliases.set(alias, config.id) + }) + } + + return true + } catch (error) { + console.error(`Failed to register ProviderConfig:`, error) + return false + } +} + +/** + * 步骤2: 创建Provider - 根据配置执行实际创建 + */ +export async function createProvider(providerId: string, options: any): Promise { + // 支持通过别名查找配置 + const config = getProviderConfigByAlias(providerId) + + if (!config) { + throw new Error(`ProviderConfig not found for id: ${providerId}`) + } + + try { + let creator: (options: any) => any + + if (config.creator) { + // 方式1: 直接执行 creator + creator = config.creator + } else if (config.import && config.creatorFunctionName) { + // 方式2: 动态导入并执行 + const module = await config.import() + creator = (module as any)[config.creatorFunctionName] + + if (!creator || typeof creator !== 'function') { + throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`) + } + } else { + throw new Error('No valid creator method provided in ProviderConfig') + } + + // 使用真实配置创建provider实例 + return creator(options) + } catch (error) { + console.error(`Failed to create provider "${providerId}":`, error) + throw error + } +} + +/** + * 步骤3: 注册Provider到全局管理器 + */ +export function registerProvider(providerId: string, provider: any): boolean { + try { + const config = providerConfigs.get(providerId) + if (!config) { + console.error(`ProviderConfig not found for id: ${providerId}`) + return false + } + + // 获取aliases配置 + const aliases = config.aliases + + // 处理特殊provider逻辑 + if (providerId === 'openai') { + // 注册默认 openai + globalRegistryManagement.registerProvider(providerId, provider, aliases) + + // 创建并注册 openai-chat 变体 + const openaiChatProvider = customProvider({ + fallbackProvider: { + ...provider, + languageModel: (modelId: string) => provider.chat(modelId) + } + }) + 
globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider) + } else if (providerId === 'azure') { + globalRegistryManagement.registerProvider(`${providerId}-chat`, provider, aliases) + // 跟上面相反,creator产出的默认会调用chat + const azureResponsesProvider = customProvider({ + fallbackProvider: { + ...provider, + languageModel: (modelId: string) => provider.responses(modelId) + } + }) + globalRegistryManagement.registerProvider(providerId, azureResponsesProvider) + } else { + // 其他provider直接注册 + globalRegistryManagement.registerProvider(providerId, provider, aliases) + } + + return true + } catch (error) { + console.error(`Failed to register provider "${providerId}" to global registry:`, error) + return false + } +} + +/** + * 便捷函数: 一次性完成创建+注册 + */ +export async function createAndRegisterProvider(providerId: string, options: any): Promise { + try { + // 步骤2: 创建provider + const provider = await createProvider(providerId, options) + + // 步骤3: 注册到全局管理器 + return registerProvider(providerId, provider) + } catch (error) { + console.error(`Failed to create and register provider "${providerId}":`, error) + return false + } +} + +/** + * 批量注册Provider配置 + */ +export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number { + let successCount = 0 + configs.forEach((config) => { + if (registerProviderConfig(config)) { + successCount++ + } + }) + return successCount +} + +/** + * 检查是否有对应的Provider配置 + */ +export function hasProviderConfig(providerId: string): boolean { + return providerConfigs.has(providerId) +} + +/** + * 通过别名或ID检查是否有对应的Provider配置 + */ +export function hasProviderConfigByAlias(aliasOrId: string): boolean { + const realId = resolveProviderConfigId(aliasOrId) + return providerConfigs.has(realId) +} + +/** + * 获取所有Provider配置 + */ +export function getAllProviderConfigs(): ProviderConfig[] { + return Array.from(providerConfigs.values()) +} + +/** + * 根据ID获取Provider配置 + */ +export function getProviderConfig(providerId: string): 
ProviderConfig | undefined { + return providerConfigs.get(providerId) +} + +/** + * 通过别名或ID获取Provider配置 + */ +export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined { + // 先检查是否为别名,如果是则解析为真实ID + const realId = providerConfigAliases.get(aliasOrId) || aliasOrId + return providerConfigs.get(realId) +} + +/** + * 解析真实的ProviderConfig ID(去别名化) + */ +export function resolveProviderConfigId(aliasOrId: string): string { + return providerConfigAliases.get(aliasOrId) || aliasOrId +} + +/** + * 检查是否为ProviderConfig别名 + */ +export function isProviderConfigAlias(id: string): boolean { + return providerConfigAliases.has(id) +} + +/** + * 获取所有ProviderConfig别名映射关系 + */ +export function getAllProviderConfigAliases(): Record { + const result: Record = {} + providerConfigAliases.forEach((realId, alias) => { + result[alias] = realId + }) + return result +} + +/** + * 清理所有Provider配置和已注册的providers + */ +export function cleanup(): void { + providerConfigs.clear() + providerConfigAliases.clear() // 清理别名映射 + globalRegistryManagement.clear() + // 重新初始化内置配置 + initializeBuiltInConfigs() +} + +export function clearAllProviders(): void { + globalRegistryManagement.clear() +} + +// ==================== 导出错误类型 ==================== + +export { ProviderInitializationError } diff --git a/packages/aiCore/src/core/providers/schemas.ts b/packages/aiCore/src/core/providers/schemas.ts new file mode 100644 index 0000000000..0c1c847d98 --- /dev/null +++ b/packages/aiCore/src/core/providers/schemas.ts @@ -0,0 +1,178 @@ +/** + * Provider Config 定义 + */ + +import { createAnthropic } from '@ai-sdk/anthropic' +import { createAzure } from '@ai-sdk/azure' +import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure' +import { createDeepSeek } from '@ai-sdk/deepseek' +import { createGoogleGenerativeAI } from '@ai-sdk/google' +import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai' +import { createOpenAICompatible } from '@ai-sdk/openai-compatible' +import { 
createXai } from '@ai-sdk/xai' +import { customProvider, type Provider } from 'ai' +import * as z from 'zod' + +/** + * 基础 Provider IDs + */ +export const baseProviderIds = [ + 'openai', + 'openai-chat', + 'openai-compatible', + 'anthropic', + 'google', + 'xai', + 'azure', + 'azure-responses', + 'deepseek' +] as const + +/** + * 基础 Provider ID Schema + */ +export const baseProviderIdSchema = z.enum(baseProviderIds) + +/** + * 基础 Provider ID + */ +export type BaseProviderId = z.infer + +export const baseProviderSchema = z.object({ + id: baseProviderIdSchema, + name: z.string(), + creator: z.function().args(z.any()).returns(z.any()) as z.ZodType<(options: any) => Provider>, + supportsImageGeneration: z.boolean() +}) + +export type BaseProvider = z.infer + +/** + * 基础 Providers 定义 + * 作为唯一数据源,避免重复维护 + */ +export const baseProviders = [ + { + id: 'openai', + name: 'OpenAI', + creator: createOpenAI, + supportsImageGeneration: true + }, + { + id: 'openai-chat', + name: 'OpenAI Chat', + creator: (options: OpenAIProviderSettings) => { + const provider = createOpenAI(options) + return customProvider({ + fallbackProvider: { + ...provider, + languageModel: (modelId: string) => provider.chat(modelId) + } + }) + }, + supportsImageGeneration: true + }, + { + id: 'openai-compatible', + name: 'OpenAI Compatible', + creator: createOpenAICompatible, + supportsImageGeneration: true + }, + { + id: 'anthropic', + name: 'Anthropic', + creator: createAnthropic, + supportsImageGeneration: false + }, + { + id: 'google', + name: 'Google Generative AI', + creator: createGoogleGenerativeAI, + supportsImageGeneration: true + }, + { + id: 'xai', + name: 'xAI (Grok)', + creator: createXai, + supportsImageGeneration: true + }, + { + id: 'azure', + name: 'Azure OpenAI', + creator: createAzure, + supportsImageGeneration: true + }, + { + id: 'azure-responses', + name: 'Azure OpenAI Responses', + creator: (options: AzureOpenAIProviderSettings) => { + const provider = createAzure(options) + return 
customProvider({ + fallbackProvider: { + ...provider, + languageModel: (modelId: string) => provider.responses(modelId) + } + }) + }, + supportsImageGeneration: true + }, + { + id: 'deepseek', + name: 'DeepSeek', + creator: createDeepSeek, + supportsImageGeneration: false + } +] as const satisfies BaseProvider[] + +/** + * 用户自定义 Provider ID Schema + * 允许任意字符串,但排除基础 provider IDs 以避免冲突 + */ +export const customProviderIdSchema = z + .string() + .min(1) + .refine((id) => !baseProviderIds.includes(id as any), { + message: 'Custom provider ID cannot conflict with base provider IDs' + }) + +/** + * Provider ID Schema - 支持基础和自定义 + */ +export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema]) + +/** + * Provider 配置 Schema + * 用于Provider的配置验证 + */ +export const providerConfigSchema = z + .object({ + id: customProviderIdSchema, // 只允许自定义ID + name: z.string().min(1), + creator: z.function().optional(), + import: z.function().optional(), + creatorFunctionName: z.string().optional(), + supportsImageGeneration: z.boolean().default(false), + imageCreator: z.function().optional(), + validateOptions: z.function().optional(), + aliases: z.array(z.string()).optional() + }) + .refine((data) => data.creator || (data.import && data.creatorFunctionName), { + message: 'Must provide either creator function or import configuration' + }) + +/** + * Provider ID 类型 - 基于 zod schema 推导 + */ +export type ProviderId = z.infer +export type CustomProviderId = z.infer + +/** + * Provider 配置类型 + */ +export type ProviderConfig = z.infer + +/** + * 兼容性类型别名 + * @deprecated 使用 ProviderConfig 替代 + */ +export type DynamicProviderRegistration = ProviderConfig diff --git a/packages/aiCore/src/core/providers/types.ts b/packages/aiCore/src/core/providers/types.ts new file mode 100644 index 0000000000..f862f43a75 --- /dev/null +++ b/packages/aiCore/src/core/providers/types.ts @@ -0,0 +1,96 @@ +import { type AnthropicProviderSettings } from '@ai-sdk/anthropic' +import { type 
AzureOpenAIProviderSettings } from '@ai-sdk/azure' +import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek' +import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google' +import { type OpenAIProviderSettings } from '@ai-sdk/openai' +import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible' +import { + EmbeddingModelV2 as EmbeddingModel, + ImageModelV2 as ImageModel, + LanguageModelV2 as LanguageModel, + ProviderV2, + SpeechModelV2 as SpeechModel, + TranscriptionModelV2 as TranscriptionModel +} from '@ai-sdk/provider' +import { type XaiProviderSettings } from '@ai-sdk/xai' + +// 导入基于 Zod 的 ProviderId 类型 +import { type ProviderId as ZodProviderId } from './schemas' + +export interface ExtensibleProviderSettingsMap { + // 基础的静态providers + openai: OpenAIProviderSettings + 'openai-responses': OpenAIProviderSettings + 'openai-compatible': OpenAICompatibleProviderSettings + anthropic: AnthropicProviderSettings + google: GoogleGenerativeAIProviderSettings + xai: XaiProviderSettings + azure: AzureOpenAIProviderSettings + deepseek: DeepSeekProviderSettings +} + +// 动态扩展的provider类型注册表 +export interface DynamicProviderRegistry { + [key: string]: any +} + +// 合并基础和动态provider类型 +export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry + +// 错误类型 +export class ProviderError extends Error { + constructor( + message: string, + public providerId: string, + public code?: string, + public cause?: Error + ) { + super(message) + this.name = 'ProviderError' + } +} + +// 动态ProviderId类型 - 基于 Zod Schema,支持运行时扩展和验证 +export type ProviderId = ZodProviderId + +export interface ProviderTypeRegistrar { + registerProviderType(providerId: T, settingsType: S): void + getProviderSettings(providerId: T): any +} + +// 重新导出所有类型供外部使用 +export type { + AnthropicProviderSettings, + AzureOpenAIProviderSettings, + DeepSeekProviderSettings, + GoogleGenerativeAIProviderSettings, + OpenAICompatibleProviderSettings, + 
OpenAIProviderSettings, + XaiProviderSettings +} + +export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel | TranscriptionModel | SpeechModel + +export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech' + +export const METHOD_MAP = { + text: 'languageModel', + image: 'imageModel', + embedding: 'textEmbeddingModel', + transcription: 'transcriptionModel', + speech: 'speechModel' +} as const satisfies Record + +export type AiSdkModelMethodMap = Record + +export type AiSdkModelReturnMap = { + text: LanguageModel + image: ImageModel + embedding: EmbeddingModel + transcription: TranscriptionModel + speech: SpeechModel +} + +export type AiSdkMethodName = (typeof METHOD_MAP)[T] + +export type AiSdkModelReturn = AiSdkModelReturnMap[T] diff --git a/packages/aiCore/src/core/providers/utils.ts b/packages/aiCore/src/core/providers/utils.ts new file mode 100644 index 0000000000..08f08fa5d1 --- /dev/null +++ b/packages/aiCore/src/core/providers/utils.ts @@ -0,0 +1,86 @@ +/** + * 格式化私钥,确保它包含正确的PEM头部和尾部 + */ +export function formatPrivateKey(privateKey: string): string { + if (!privateKey || typeof privateKey !== 'string') { + throw new Error('Private key must be a non-empty string') + } + + // 先处理 JSON 字符串中的转义换行符 + const key = privateKey.replace(/\\n/g, '\n') + + // 检查是否已经是正确格式的 PEM 私钥 + const hasBeginMarker = key.includes('-----BEGIN PRIVATE KEY-----') + const hasEndMarker = key.includes('-----END PRIVATE KEY-----') + + if (hasBeginMarker && hasEndMarker) { + // 已经是 PEM 格式,但可能格式不规范,重新格式化 + return normalizePemFormat(key) + } + + // 如果没有完整的 PEM 头尾,尝试重新构建 + return reconstructPemKey(key) +} + +/** + * 标准化 PEM 格式 + */ +function normalizePemFormat(pemKey: string): string { + // 分离头部、内容和尾部 + const lines = pemKey + .split('\n') + .map((line) => line.trim()) + .filter((line) => line.length > 0) + + let keyContent = '' + let foundBegin = false + let foundEnd = false + + for (const line of lines) { + if (line === '-----BEGIN PRIVATE KEY-----') { 
+ foundBegin = true + continue + } + if (line === '-----END PRIVATE KEY-----') { + foundEnd = true + break + } + if (foundBegin && !foundEnd) { + keyContent += line + } + } + + if (!foundBegin || !foundEnd || !keyContent) { + throw new Error('Invalid PEM format: missing BEGIN/END markers or key content') + } + + // 重新格式化为 64 字符一行 + const formattedContent = keyContent.match(/.{1,64}/g)?.join('\n') || keyContent + + return `-----BEGIN PRIVATE KEY-----\n${formattedContent}\n-----END PRIVATE KEY-----` +} + +/** + * 重新构建 PEM 私钥 + */ +function reconstructPemKey(key: string): string { + // 移除所有空白字符和可能存在的不完整头尾 + let cleanKey = key.replace(/\s+/g, '') + cleanKey = cleanKey.replace(/-----BEGIN[^-]*-----/g, '') + cleanKey = cleanKey.replace(/-----END[^-]*-----/g, '') + + // 确保私钥内容不为空 + if (!cleanKey) { + throw new Error('Private key content is empty after cleaning') + } + + // 验证是否是有效的 Base64 字符 + if (!/^[A-Za-z0-9+/=]+$/.test(cleanKey)) { + throw new Error('Private key contains invalid characters (not valid Base64)') + } + + // 格式化为 64 字符一行 + const formattedKey = cleanKey.match(/.{1,64}/g)?.join('\n') || cleanKey + + return `-----BEGIN PRIVATE KEY-----\n${formattedKey}\n-----END PRIVATE KEY-----` +} diff --git a/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts b/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts new file mode 100644 index 0000000000..bde5779fd9 --- /dev/null +++ b/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts @@ -0,0 +1,523 @@ +import { ImageModelV2 } from '@ai-sdk/provider' +import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { type AiPlugin } from '../../plugins' +import { globalRegistryManagement } from '../../providers/RegistryManagement' +import { ImageGenerationError, ImageModelResolutionError } from '../errors' +import { RuntimeExecutor } from '../executor' + +// Mock dependencies 
+vi.mock('ai', () => ({ + experimental_generateImage: vi.fn(), + NoImageGeneratedError: class NoImageGeneratedError extends Error { + static isInstance = vi.fn() + constructor() { + super('No image generated') + this.name = 'NoImageGeneratedError' + } + } +})) + +vi.mock('../../providers/RegistryManagement', () => ({ + globalRegistryManagement: { + imageModel: vi.fn() + }, + DEFAULT_SEPARATOR: '|' +})) + +describe('RuntimeExecutor.generateImage', () => { + let executor: RuntimeExecutor<'openai'> + let mockImageModel: ImageModelV2 + let mockGenerateImageResult: any + + beforeEach(() => { + // Reset all mocks + vi.clearAllMocks() + + // Create executor instance + executor = RuntimeExecutor.create('openai', { + apiKey: 'test-key' + }) + + // Mock image model + mockImageModel = { + modelId: 'dall-e-3', + provider: 'openai' + } as ImageModelV2 + + // Mock generateImage result + mockGenerateImageResult = { + image: { + base64: 'base64-encoded-image-data', + uint8Array: new Uint8Array([1, 2, 3]), + mediaType: 'image/png' + }, + images: [ + { + base64: 'base64-encoded-image-data', + uint8Array: new Uint8Array([1, 2, 3]), + mediaType: 'image/png' + } + ], + warnings: [], + providerMetadata: { + openai: { + images: [{ revisedPrompt: 'A detailed prompt' }] + } + }, + responses: [] + } + + // Setup mocks to avoid "No providers registered" error + vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel) + vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult) + }) + + describe('Basic functionality', () => { + it('should generate a single image with minimal parameters', async () => { + const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' }) + + expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3') + + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + prompt: 'A futuristic cityscape at sunset' + }) + + 
expect(result).toEqual(mockGenerateImageResult) + }) + + it('should generate image with pre-created model', async () => { + const result = await executor.generateImage({ + model: mockImageModel, + prompt: 'A beautiful landscape' + }) + + // Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + prompt: 'A beautiful landscape' + }) + + expect(result).toEqual(mockGenerateImageResult) + }) + + it('should support multiple images generation', async () => { + await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape', n: 3 }) + + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + prompt: 'A futuristic cityscape', + n: 3 + }) + }) + + it('should support size specification', async () => { + await executor.generateImage({ model: 'dall-e-3', prompt: 'A beautiful sunset', size: '1024x1024' }) + + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + prompt: 'A beautiful sunset', + size: '1024x1024' + }) + }) + + it('should support aspect ratio specification', async () => { + await executor.generateImage({ model: 'dall-e-3', prompt: 'A mountain landscape', aspectRatio: '16:9' }) + + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + prompt: 'A mountain landscape', + aspectRatio: '16:9' + }) + }) + + it('should support seed for consistent output', async () => { + await executor.generateImage({ model: 'dall-e-3', prompt: 'A cat in space', seed: 1234567890 }) + + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + prompt: 'A cat in space', + seed: 1234567890 + }) + }) + + it('should support abort signal', async () => { + const abortController = new AbortController() + + await executor.generateImage({ model: 'dall-e-3', prompt: 'A cityscape', abortSignal: abortController.signal }) + + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + 
prompt: 'A cityscape', + abortSignal: abortController.signal + }) + }) + + it('should support provider-specific options', async () => { + await executor.generateImage({ + model: 'dall-e-3', + prompt: 'A space station', + providerOptions: { + openai: { + quality: 'hd', + style: 'vivid' + } + } + }) + + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + prompt: 'A space station', + providerOptions: { + openai: { + quality: 'hd', + style: 'vivid' + } + } + }) + }) + + it('should support custom headers', async () => { + await executor.generateImage({ + model: 'dall-e-3', + prompt: 'A robot', + headers: { + 'X-Custom-Header': 'test-value' + } + }) + + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + prompt: 'A robot', + headers: { + 'X-Custom-Header': 'test-value' + } + }) + }) + }) + + describe('Plugin integration', () => { + it('should execute plugins in correct order', async () => { + const pluginCallOrder: string[] = [] + + const testPlugin: AiPlugin = { + name: 'test-plugin', + onRequestStart: vi.fn(async () => { + pluginCallOrder.push('onRequestStart') + }), + transformParams: vi.fn(async (params) => { + pluginCallOrder.push('transformParams') + return { ...params, size: '512x512' } + }), + transformResult: vi.fn(async (result) => { + pluginCallOrder.push('transformResult') + return { ...result, processed: true } + }), + onRequestEnd: vi.fn(async () => { + pluginCallOrder.push('onRequestEnd') + }) + } + + const executorWithPlugin = RuntimeExecutor.create( + 'openai', + { + apiKey: 'test-key' + }, + [testPlugin] + ) + + const result = await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' }) + + expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd']) + + expect(testPlugin.transformParams).toHaveBeenCalledWith( + { prompt: 'A test image' }, + expect.objectContaining({ + providerId: 'openai', + modelId: 'dall-e-3' + }) + ) + + 
expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + prompt: 'A test image', + size: '512x512' // Should be transformed by plugin + }) + + expect(result).toEqual({ + ...mockGenerateImageResult, + processed: true // Should be transformed by plugin + }) + }) + + it('should handle model resolution through plugins', async () => { + const customImageModel = { + modelId: 'custom-model', + provider: 'openai' + } as ImageModelV2 + + const modelResolutionPlugin: AiPlugin = { + name: 'model-resolver', + resolveModel: vi.fn(async () => customImageModel) + } + + const executorWithPlugin = RuntimeExecutor.create( + 'openai', + { + apiKey: 'test-key' + }, + [modelResolutionPlugin] + ) + + await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' }) + + expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith( + 'dall-e-3', + expect.objectContaining({ + providerId: 'openai', + modelId: 'dall-e-3' + }) + ) + + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: customImageModel, + prompt: 'A test image' + }) + }) + + it('should support recursive calls from plugins', async () => { + const recursivePlugin: AiPlugin = { + name: 'recursive-plugin', + transformParams: vi.fn(async (params, context) => { + if (!context.isRecursiveCall && params.prompt === 'original') { + // Make a recursive call with modified prompt + await context.recursiveCall({ + model: 'dall-e-3', + prompt: 'modified' + }) + } + return params + }) + } + + const executorWithPlugin = RuntimeExecutor.create( + 'openai', + { + apiKey: 'test-key' + }, + [recursivePlugin] + ) + + await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'original' }) + + expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2) + expect(aiGenerateImage).toHaveBeenCalledTimes(2) + }) + }) + + describe('Error handling', () => { + it('should handle model creation errors', async () => { + const modelError = new Error('Failed to get image model') + 
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => { + throw modelError + }) + + await expect(executor.generateImage({ model: 'invalid-model', prompt: 'A test image' })).rejects.toThrow( + ImageGenerationError + ) + }) + + it('should handle ImageModelResolutionError correctly', async () => { + const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found')) + vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => { + throw resolutionError + }) + + const thrownError = await executor + .generateImage({ model: 'invalid-model', prompt: 'A test image' }) + .catch((error) => error) + + expect(thrownError).toBeInstanceOf(ImageGenerationError) + expect(thrownError.message).toContain('Failed to generate image:') + expect(thrownError.providerId).toBe('openai') + expect(thrownError.modelId).toBe('invalid-model') + expect(thrownError.cause).toBeInstanceOf(ImageModelResolutionError) + expect(thrownError.cause.message).toContain('Failed to resolve image model: invalid-model') + }) + + it('should handle ImageModelResolutionError without provider', async () => { + const resolutionError = new ImageModelResolutionError('unknown-model') + vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => { + throw resolutionError + }) + + await expect(executor.generateImage({ model: 'unknown-model', prompt: 'A test image' })).rejects.toThrow( + ImageGenerationError + ) + }) + + it('should handle image generation API errors', async () => { + const apiError = new Error('API request failed') + vi.mocked(aiGenerateImage).mockRejectedValue(apiError) + + await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow( + 'Failed to generate image:' + ) + }) + + it('should handle NoImageGeneratedError', async () => { + const noImageError = new NoImageGeneratedError({ + cause: new Error('No image generated'), + responses: [] + }) + + 
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError) + vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true) + + await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow( + 'Failed to generate image:' + ) + }) + + it('should execute onError plugin hook on failure', async () => { + const error = new Error('Generation failed') + vi.mocked(aiGenerateImage).mockRejectedValue(error) + + const errorPlugin: AiPlugin = { + name: 'error-handler', + onError: vi.fn() + } + + const executorWithPlugin = RuntimeExecutor.create( + 'openai', + { + apiKey: 'test-key' + }, + [errorPlugin] + ) + + await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow( + 'Failed to generate image:' + ) + + expect(errorPlugin.onError).toHaveBeenCalledWith( + error, + expect.objectContaining({ + providerId: 'openai', + modelId: 'dall-e-3' + }) + ) + }) + + it('should handle abort signal timeout', async () => { + const abortError = new Error('Operation was aborted') + abortError.name = 'AbortError' + vi.mocked(aiGenerateImage).mockRejectedValue(abortError) + + const abortController = new AbortController() + setTimeout(() => abortController.abort(), 10) + + await expect( + executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal }) + ).rejects.toThrow('Failed to generate image:') + }) + }) + + describe('Multiple providers support', () => { + it('should work with different providers', async () => { + const googleExecutor = RuntimeExecutor.create('google', { + apiKey: 'google-key' + }) + + await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' }) + + expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002') + }) + + it('should support xAI Grok image models', async () => { + const xaiExecutor = RuntimeExecutor.create('xai', { + apiKey: 'xai-key' + }) + + await 
xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' }) + + expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai|grok-2-image') + }) + }) + + describe('Advanced features', () => { + it('should support batch image generation with maxImagesPerCall', async () => { + await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', n: 10, maxImagesPerCall: 5 }) + + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + prompt: 'A test image', + n: 10, + maxImagesPerCall: 5 + }) + }) + + it('should support retries with maxRetries', async () => { + await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', maxRetries: 3 }) + + expect(aiGenerateImage).toHaveBeenCalledWith({ + model: mockImageModel, + prompt: 'A test image', + maxRetries: 3 + }) + }) + + it('should handle warnings from the model', async () => { + const resultWithWarnings = { + ...mockGenerateImageResult, + warnings: [ + { + type: 'unsupported-setting', + message: 'Size parameter not supported for this model' + } + ] + } + + vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings) + + const result = await executor.generateImage({ + model: 'dall-e-3', + prompt: 'A test image', + size: '2048x2048' // Unsupported size + }) + + expect(result.warnings).toHaveLength(1) + expect(result.warnings[0].type).toBe('unsupported-setting') + }) + + it('should provide access to provider metadata', async () => { + const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' }) + + expect(result.providerMetadata).toBeDefined() + expect(result.providerMetadata.openai).toBeDefined() + }) + + it('should provide response metadata', async () => { + const resultWithMetadata = { + ...mockGenerateImageResult, + responses: [ + { + timestamp: new Date(), + modelId: 'dall-e-3', + headers: { 'x-request-id': 'test-123' } + } + ] + } + + vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata) + + const result = 
await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' }) + + expect(result.responses).toHaveLength(1) + expect(result.responses[0].modelId).toBe('dall-e-3') + expect(result.responses[0].headers).toEqual({ 'x-request-id': 'test-123' }) + }) + }) +}) diff --git a/packages/aiCore/src/core/runtime/errors.ts b/packages/aiCore/src/core/runtime/errors.ts new file mode 100644 index 0000000000..f3d7cbd1f5 --- /dev/null +++ b/packages/aiCore/src/core/runtime/errors.ts @@ -0,0 +1,38 @@ +/** + * Error classes for runtime operations + */ + +/** + * Error thrown when image generation fails + */ +export class ImageGenerationError extends Error { + constructor( + message: string, + public providerId?: string, + public modelId?: string, + public cause?: Error + ) { + super(message) + this.name = 'ImageGenerationError' + + // Maintain proper stack trace (for V8 engines) + if (Error.captureStackTrace) { + Error.captureStackTrace(this, ImageGenerationError) + } + } +} + +/** + * Error thrown when model resolution fails during image generation + */ +export class ImageModelResolutionError extends ImageGenerationError { + constructor(modelId: string, providerId?: string, cause?: Error) { + super( + `Failed to resolve image model: ${modelId}${providerId ? 
` for provider: ${providerId}` : ''}`, + providerId, + modelId, + cause + ) + this.name = 'ImageModelResolutionError' + } +} diff --git a/packages/aiCore/src/core/runtime/executor.ts b/packages/aiCore/src/core/runtime/executor.ts new file mode 100644 index 0000000000..ab80f9cecc --- /dev/null +++ b/packages/aiCore/src/core/runtime/executor.ts @@ -0,0 +1,321 @@ +/** + * 运行时执行器 + * 专注于插件化的AI调用处理 + */ +import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider' +import { + experimental_generateImage as generateImage, + generateObject, + generateText, + LanguageModel, + streamObject, + streamText +} from 'ai' + +import { globalModelResolver } from '../models' +import { type ModelConfig } from '../models/types' +import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins' +import { type ProviderId } from '../providers' +import { ImageGenerationError, ImageModelResolutionError } from './errors' +import { PluginEngine } from './pluginEngine' +import { type RuntimeConfig } from './types' + +export class RuntimeExecutor { + public pluginEngine: PluginEngine + // private options: ProviderSettingsMap[T] + private config: RuntimeConfig + + constructor(config: RuntimeConfig) { + // if (!isProviderSupported(config.providerId)) { + // throw new Error(`Unsupported provider: ${config.providerId}`) + // } + + // 存储options供后续使用 + // this.options = config.options + this.config = config + // 创建插件客户端 + this.pluginEngine = new PluginEngine(config.providerId, config.plugins || []) + } + + private createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) { + return definePlugin({ + name: '_internal_resolveModel', + enforce: 'post', + + resolveModel: async (modelId: string) => { + // 注意:extraModelConfig 暂时不支持,已在新架构中移除 + return await this.resolveModel(modelId, middlewares) + } + }) + } + + private createResolveImageModelPlugin() { + return definePlugin({ + name: '_internal_resolveImageModel', + enforce: 'post', + + resolveModel: 
async (modelId: string) => { + return await this.resolveImageModel(modelId) + } + }) + } + + private createConfigureContextPlugin() { + return definePlugin({ + name: '_internal_configureContext', + configureContext: async (context: AiRequestContext) => { + context.executor = this + } + }) + } + + // === 高阶重载:直接使用模型 === + + /** + * 流式文本生成 + */ + async streamText( + params: Parameters[0], + options?: { + middlewares?: LanguageModelV2Middleware[] + } + ): Promise> { + const { model, ...restParams } = params + + // 根据 model 类型决定插件配置 + if (typeof model === 'string') { + this.pluginEngine.usePlugins([ + this.createResolveModelPlugin(options?.middlewares), + this.createConfigureContextPlugin() + ]) + } else { + this.pluginEngine.usePlugins([this.createConfigureContextPlugin()]) + } + + return this.pluginEngine.executeStreamWithPlugins( + 'streamText', + model, + restParams, + async (resolvedModel, transformedParams, streamTransforms) => { + const experimental_transform = + params?.experimental_transform ?? (streamTransforms.length > 0 ? 
streamTransforms : undefined) + + const finalParams = { + model: resolvedModel, + ...transformedParams, + experimental_transform + } as Parameters[0] + + return await streamText(finalParams) + } + ) + } + + // === 其他方法的重载 === + + /** + * 生成文本 + */ + async generateText( + params: Parameters[0], + options?: { + middlewares?: LanguageModelV2Middleware[] + } + ): Promise> { + const { model, ...restParams } = params + + // 根据 model 类型决定插件配置 + if (typeof model === 'string') { + this.pluginEngine.usePlugins([ + this.createResolveModelPlugin(options?.middlewares), + this.createConfigureContextPlugin() + ]) + } else { + this.pluginEngine.usePlugins([this.createConfigureContextPlugin()]) + } + + return this.pluginEngine.executeWithPlugins( + 'generateText', + model, + restParams, + async (resolvedModel, transformedParams) => + generateText({ model: resolvedModel, ...transformedParams } as Parameters[0]) + ) + } + + /** + * 生成结构化对象 + */ + async generateObject( + params: Parameters[0], + options?: { + middlewares?: LanguageModelV2Middleware[] + } + ): Promise> { + const { model, ...restParams } = params + + // 根据 model 类型决定插件配置 + if (typeof model === 'string') { + this.pluginEngine.usePlugins([ + this.createResolveModelPlugin(options?.middlewares), + this.createConfigureContextPlugin() + ]) + } else { + this.pluginEngine.usePlugins([this.createConfigureContextPlugin()]) + } + + return this.pluginEngine.executeWithPlugins( + 'generateObject', + model, + restParams, + async (resolvedModel, transformedParams) => + generateObject({ model: resolvedModel, ...transformedParams } as Parameters[0]) + ) + } + + /** + * 流式生成结构化对象 + */ + async streamObject( + params: Parameters[0], + options?: { + middlewares?: LanguageModelV2Middleware[] + } + ): Promise> { + const { model, ...restParams } = params + + // 根据 model 类型决定插件配置 + if (typeof model === 'string') { + this.pluginEngine.usePlugins([ + this.createResolveModelPlugin(options?.middlewares), + this.createConfigureContextPlugin() + ]) + 
} else { + this.pluginEngine.usePlugins([this.createConfigureContextPlugin()]) + } + + return this.pluginEngine.executeWithPlugins( + 'streamObject', + model, + restParams, + async (resolvedModel, transformedParams) => + streamObject({ model: resolvedModel, ...transformedParams } as Parameters[0]) + ) + } + + /** + * 生成图像 + */ + async generateImage( + params: Omit[0], 'model'> & { model: string | ImageModelV2 } + ): Promise> { + try { + const { model, ...restParams } = params + + // 根据 model 类型决定插件配置 + if (typeof model === 'string') { + this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()]) + } else { + this.pluginEngine.usePlugins([this.createConfigureContextPlugin()]) + } + + return await this.pluginEngine.executeImageWithPlugins( + 'generateImage', + model, + restParams, + async (resolvedModel, transformedParams) => { + return await generateImage({ model: resolvedModel, ...transformedParams }) + } + ) + } catch (error) { + if (error instanceof Error) { + const modelId = typeof params.model === 'string' ? 
params.model : params.model.modelId + throw new ImageGenerationError( + `Failed to generate image: ${error.message}`, + this.config.providerId, + modelId, + error + ) + } + throw error + } + } + + // === 辅助方法 === + + /** + * 解析模型:如果是字符串则创建模型,如果是模型则直接返回 + */ + private async resolveModel( + modelOrId: LanguageModel, + middlewares?: LanguageModelV2Middleware[] + ): Promise { + if (typeof modelOrId === 'string') { + // 🎯 字符串modelId,使用新的ModelResolver解析,传递完整参数 + return await globalModelResolver.resolveLanguageModel( + modelOrId, // 支持 'gpt-4' 和 'aihubmix:anthropic:claude-3.5-sonnet' + this.config.providerId, // fallback provider + this.config.providerSettings, // provider options + middlewares // 中间件数组 + ) + } else { + // 已经是模型,直接返回 + return modelOrId + } + } + + /** + * 解析图像模型:如果是字符串则创建图像模型,如果是模型则直接返回 + */ + private async resolveImageModel(modelOrId: ImageModelV2 | string): Promise { + try { + if (typeof modelOrId === 'string') { + // 字符串modelId,使用新的ModelResolver解析 + return await globalModelResolver.resolveImageModel( + modelOrId, // 支持 'dall-e-3' 和 'aihubmix:openai:dall-e-3' + this.config.providerId // fallback provider + ) + } else { + // 已经是模型,直接返回 + return modelOrId + } + } catch (error) { + throw new ImageModelResolutionError( + typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId, + this.config.providerId, + error instanceof Error ? 
error : undefined + ) + } + } + + // === 静态工厂方法 === + + /** + * 创建执行器 - 支持已知provider的类型安全 + */ + static create( + providerId: T, + options: ModelConfig['providerSettings'], + plugins?: AiPlugin[] + ): RuntimeExecutor { + return new RuntimeExecutor({ + providerId, + providerSettings: options, + plugins + }) + } + + /** + * 创建OpenAI Compatible执行器 + */ + static createOpenAICompatible( + options: ModelConfig<'openai-compatible'>['providerSettings'], + plugins: AiPlugin[] = [] + ): RuntimeExecutor<'openai-compatible'> { + return new RuntimeExecutor({ + providerId: 'openai-compatible', + providerSettings: options, + plugins + }) + } +} diff --git a/packages/aiCore/src/core/runtime/index.ts b/packages/aiCore/src/core/runtime/index.ts new file mode 100644 index 0000000000..37aa4fec34 --- /dev/null +++ b/packages/aiCore/src/core/runtime/index.ts @@ -0,0 +1,117 @@ +/** + * Runtime 模块导出 + * 专注于运行时插件化AI调用处理 + */ + +// 主要的运行时执行器 +export { RuntimeExecutor } from './executor' + +// 导出类型 +export type { RuntimeConfig } from './types' + +// === 便捷工厂函数 === + +import { LanguageModelV2Middleware } from '@ai-sdk/provider' + +import { type AiPlugin } from '../plugins' +import { type ProviderId, type ProviderSettingsMap } from '../providers/types' +import { RuntimeExecutor } from './executor' + +/** + * 创建运行时执行器 - 支持类型安全的已知provider + */ +export function createExecutor( + providerId: T, + options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }, + plugins?: AiPlugin[] +): RuntimeExecutor { + return RuntimeExecutor.create(providerId, options, plugins) +} + +/** + * 创建OpenAI Compatible执行器 + */ +export function createOpenAICompatibleExecutor( + options: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' }, + plugins: AiPlugin[] = [] +): RuntimeExecutor<'openai-compatible'> { + return RuntimeExecutor.createOpenAICompatible(options, plugins) +} + +// === 直接调用API(无需创建executor实例)=== + +/** + * 直接流式文本生成 - 支持middlewares + */ +export async function streamText( + 
providerId: T, + options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }, + params: Parameters['streamText']>[0], + plugins?: AiPlugin[], + middlewares?: LanguageModelV2Middleware[] +): Promise['streamText']>> { + const executor = createExecutor(providerId, options, plugins) + return executor.streamText(params, { middlewares }) +} + +/** + * 直接生成文本 - 支持middlewares + */ +export async function generateText( + providerId: T, + options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }, + params: Parameters['generateText']>[0], + plugins?: AiPlugin[], + middlewares?: LanguageModelV2Middleware[] +): Promise['generateText']>> { + const executor = createExecutor(providerId, options, plugins) + return executor.generateText(params, { middlewares }) +} + +/** + * 直接生成结构化对象 - 支持middlewares + */ +export async function generateObject( + providerId: T, + options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }, + params: Parameters['generateObject']>[0], + plugins?: AiPlugin[], + middlewares?: LanguageModelV2Middleware[] +): Promise['generateObject']>> { + const executor = createExecutor(providerId, options, plugins) + return executor.generateObject(params, { middlewares }) +} + +/** + * 直接流式生成结构化对象 - 支持middlewares + */ +export async function streamObject( + providerId: T, + options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }, + params: Parameters['streamObject']>[0], + plugins?: AiPlugin[], + middlewares?: LanguageModelV2Middleware[] +): Promise['streamObject']>> { + const executor = createExecutor(providerId, options, plugins) + return executor.streamObject(params, { middlewares }) +} + +/** + * 直接生成图像 - 支持middlewares + */ +export async function generateImage( + providerId: T, + options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }, + params: Parameters['generateImage']>[0], + plugins?: AiPlugin[] +): Promise['generateImage']>> { + const executor = createExecutor(providerId, options, plugins) + return 
executor.generateImage(params) +} + +// === Agent 功能预留 === +// 未来将在 ../agents/ 文件夹中添加: +// - AgentExecutor.ts +// - WorkflowManager.ts +// - ConversationManager.ts +// 并在此处导出相关API diff --git a/packages/aiCore/src/core/runtime/pluginEngine.ts b/packages/aiCore/src/core/runtime/pluginEngine.ts new file mode 100644 index 0000000000..7a4bb440f7 --- /dev/null +++ b/packages/aiCore/src/core/runtime/pluginEngine.ts @@ -0,0 +1,290 @@ +/* eslint-disable @eslint-react/naming-convention/context-name */ +import { ImageModelV2 } from '@ai-sdk/provider' +import { LanguageModel } from 'ai' + +import { type AiPlugin, createContext, PluginManager } from '../plugins' +import { type ProviderId } from '../providers/types' + +/** + * 插件增强的 AI 客户端 + * 专注于插件处理,不暴露用户API + */ +export class PluginEngine { + private pluginManager: PluginManager + + constructor( + private readonly providerId: T, + // private readonly options: ProviderSettingsMap[T], + plugins: AiPlugin[] = [] + ) { + this.pluginManager = new PluginManager(plugins) + } + + /** + * 添加插件 + */ + use(plugin: AiPlugin): this { + this.pluginManager.use(plugin) + return this + } + + /** + * 批量添加插件 + */ + usePlugins(plugins: AiPlugin[]): this { + plugins.forEach((plugin) => this.use(plugin)) + return this + } + + /** + * 移除插件 + */ + removePlugin(pluginName: string): this { + this.pluginManager.remove(pluginName) + return this + } + + /** + * 获取插件统计 + */ + getPluginStats() { + return this.pluginManager.getStats() + } + + /** + * 获取所有插件 + */ + getPlugins() { + return this.pluginManager.getPlugins() + } + + /** + * 执行带插件的操作(非流式) + * 提供给AiExecutor使用 + */ + async executeWithPlugins( + methodName: string, + model: LanguageModel, + params: TParams, + executor: (model: LanguageModel, transformedParams: TParams) => Promise, + _context?: ReturnType + ): Promise { + // 统一处理模型解析 + let resolvedModel: LanguageModel | undefined + let modelId: string + + if (typeof model === 'string') { + // 字符串:需要通过插件解析 + modelId = model + } else { + // 模型对象:直接使用 + 
resolvedModel = model + modelId = model.modelId + } + + // 使用正确的createContext创建请求上下文 + const context = _context ? _context : createContext(this.providerId, modelId, params) + + // 🔥 为上下文添加递归调用能力 + context.recursiveCall = async (newParams: any): Promise => { + // 递归调用自身,重新走完整的插件流程 + context.isRecursiveCall = true + const result = await this.executeWithPlugins(methodName, model, newParams, executor, context) + context.isRecursiveCall = false + return result + } + + try { + // 0. 配置上下文 + await this.pluginManager.executeConfigureContext(context) + + // 1. 触发请求开始事件 + await this.pluginManager.executeParallel('onRequestStart', context) + + // 2. 解析模型(如果是字符串) + if (typeof model === 'string') { + const resolved = await this.pluginManager.executeFirst('resolveModel', modelId, context) + if (!resolved) { + throw new Error(`Failed to resolve model: ${modelId}`) + } + resolvedModel = resolved + } + + if (!resolvedModel) { + throw new Error(`Model resolution failed: no model available`) + } + + // 3. 转换请求参数 + const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) + + // 4. 执行具体的 API 调用 + const result = await executor(resolvedModel, transformedParams) + + // 5. 转换结果(对于非流式调用) + const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) + + // 6. 触发完成事件 + await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult) + + return transformedResult + } catch (error) { + // 7. 
触发错误事件 + await this.pluginManager.executeParallel('onError', context, undefined, error as Error) + throw error + } + } + + /** + * 执行带插件的图像生成操作 + * 提供给AiExecutor使用 + */ + async executeImageWithPlugins( + methodName: string, + model: ImageModelV2 | string, + params: TParams, + executor: (model: ImageModelV2, transformedParams: TParams) => Promise, + _context?: ReturnType + ): Promise { + // 统一处理模型解析 + let resolvedModel: ImageModelV2 | undefined + let modelId: string + + if (typeof model === 'string') { + // 字符串:需要通过插件解析 + modelId = model + } else { + // 模型对象:直接使用 + resolvedModel = model + modelId = model.modelId + } + + // 使用正确的createContext创建请求上下文 + const context = _context ? _context : createContext(this.providerId, modelId, params) + + // 🔥 为上下文添加递归调用能力 + context.recursiveCall = async (newParams: any): Promise => { + // 递归调用自身,重新走完整的插件流程 + context.isRecursiveCall = true + const result = await this.executeImageWithPlugins(methodName, model, newParams, executor, context) + context.isRecursiveCall = false + return result + } + + try { + // 0. 配置上下文 + await this.pluginManager.executeConfigureContext(context) + + // 1. 触发请求开始事件 + await this.pluginManager.executeParallel('onRequestStart', context) + + // 2. 解析模型(如果是字符串) + if (typeof model === 'string') { + const resolved = await this.pluginManager.executeFirst('resolveModel', modelId, context) + if (!resolved) { + throw new Error(`Failed to resolve image model: ${modelId}`) + } + resolvedModel = resolved + } + + if (!resolvedModel) { + throw new Error(`Image model resolution failed: no model available`) + } + + // 3. 转换请求参数 + const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) + + // 4. 执行具体的 API 调用 + const result = await executor(resolvedModel, transformedParams) + + // 5. 转换结果 + const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) + + // 6. 
触发完成事件 + await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult) + + return transformedResult + } catch (error) { + // 7. 触发错误事件 + await this.pluginManager.executeParallel('onError', context, undefined, error as Error) + throw error + } + } + + /** + * 执行流式调用的通用逻辑(支持流转换器) + * 提供给AiExecutor使用 + */ + async executeStreamWithPlugins( + methodName: string, + model: LanguageModel, + params: TParams, + executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise, + _context?: ReturnType + ): Promise { + // 统一处理模型解析 + let resolvedModel: LanguageModel | undefined + let modelId: string + + if (typeof model === 'string') { + // 字符串:需要通过插件解析 + modelId = model + } else { + // 模型对象:直接使用 + resolvedModel = model + modelId = model.modelId + } + + // 创建请求上下文 + const context = _context ? _context : createContext(this.providerId, modelId, params) + + // 🔥 为上下文添加递归调用能力 + context.recursiveCall = async (newParams: any): Promise => { + // 递归调用自身,重新走完整的插件流程 + context.isRecursiveCall = true + const result = await this.executeStreamWithPlugins(methodName, model, newParams, executor, context) + context.isRecursiveCall = false + return result + } + + try { + // 0. 配置上下文 + await this.pluginManager.executeConfigureContext(context) + + // 1. 触发请求开始事件 + await this.pluginManager.executeParallel('onRequestStart', context) + + // 2. 解析模型(如果是字符串) + if (typeof model === 'string') { + const resolved = await this.pluginManager.executeFirst('resolveModel', modelId, context) + if (!resolved) { + throw new Error(`Failed to resolve model: ${modelId}`) + } + resolvedModel = resolved + } + + if (!resolvedModel) { + throw new Error(`Model resolution failed: no model available`) + } + + // 3. 转换请求参数 + const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) + + // 4. 收集流转换器 + const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context) + + // 5. 
执行流式 API 调用 + const result = await executor(resolvedModel, transformedParams, streamTransforms) + + const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) + + // 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件) + await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult) + + return transformedResult + } catch (error) { + // 7. 触发错误事件 + await this.pluginManager.executeParallel('onError', context, undefined, error as Error) + throw error + } + } +} diff --git a/packages/aiCore/src/core/runtime/types.ts b/packages/aiCore/src/core/runtime/types.ts new file mode 100644 index 0000000000..f98e9034c6 --- /dev/null +++ b/packages/aiCore/src/core/runtime/types.ts @@ -0,0 +1,15 @@ +/** + * Runtime 层类型定义 + */ +import { type ModelConfig } from '../models/types' +import { type AiPlugin } from '../plugins' +import { type ProviderId } from '../providers/types' + +/** + * 运行时执行器配置 + */ +export interface RuntimeConfig { + providerId: T + providerSettings: ModelConfig['providerSettings'] & { mode?: 'chat' | 'responses' } + plugins?: AiPlugin[] +} diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts new file mode 100644 index 0000000000..9db95e512c --- /dev/null +++ b/packages/aiCore/src/index.ts @@ -0,0 +1,46 @@ +/** + * Cherry Studio AI Core Package + * 基于 Vercel AI SDK 的统一 AI Provider 接口 + */ + +// 导入内部使用的类和函数 + +// ==================== 主要用户接口 ==================== +export { + createExecutor, + createOpenAICompatibleExecutor, + generateImage, + generateObject, + generateText, + streamText +} from './core/runtime' + +// ==================== 高级API ==================== +export { globalModelResolver as modelResolver } from './core/models' + +// ==================== 插件系统 ==================== +export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins' +export { createContext, definePlugin, PluginManager } from './core/plugins' +// export { createPromptToolUsePlugin, 
webSearchPlugin } from './core/plugins/built-in' +export { PluginEngine } from './core/runtime/pluginEngine' + +// ==================== AI SDK 常用类型导出 ==================== +// 直接导出 AI SDK 的常用类型,方便使用 +export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider' +export type { ToolCall } from '@ai-sdk/provider-utils' +export type { ReasoningPart } from '@ai-sdk/provider-utils' + +// ==================== 选项 ==================== +export { + createAnthropicOptions, + createGoogleOptions, + createOpenAIOptions, + type ExtractProviderOptions, + mergeProviderOptions, + type ProviderOptionsMap, + type TypedProviderOptions +} from './core/options' + +// ==================== 包信息 ==================== +export const AI_CORE_VERSION = '1.0.0' +export const AI_CORE_NAME = '@cherrystudio/ai-core' diff --git a/packages/aiCore/src/types.ts b/packages/aiCore/src/types.ts new file mode 100644 index 0000000000..d7796a943d --- /dev/null +++ b/packages/aiCore/src/types.ts @@ -0,0 +1,2 @@ +// 重新导出插件类型 +export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types' diff --git a/packages/aiCore/tsconfig.json b/packages/aiCore/tsconfig.json new file mode 100644 index 0000000000..9ee30166c1 --- /dev/null +++ b/packages/aiCore/tsconfig.json @@ -0,0 +1,26 @@ +{ + "compilerOptions": { + "target": "ES2020", + "module": "ESNext", + "moduleResolution": "bundler", + "declaration": true, + "outDir": "./dist", + "rootDir": "./src", + "strict": true, + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true, + "allowSyntheticDefaultImports": true, + "noEmitOnError": false, + "experimentalDecorators": true, + "emitDecoratorMetadata": true + }, + "include": [ + "src/**/*" + ], + "exclude": [ + "node_modules", + "dist" + ] +} \ No newline at end of file diff --git a/packages/aiCore/tsdown.config.ts b/packages/aiCore/tsdown.config.ts new file mode 100644 index 
0000000000..6f1e6978f3 --- /dev/null +++ b/packages/aiCore/tsdown.config.ts @@ -0,0 +1,14 @@ +import { defineConfig } from 'tsdown' + +export default defineConfig({ + entry: { + index: 'src/index.ts', + 'built-in/plugins/index': 'src/core/plugins/built-in/index.ts', + 'provider/index': 'src/core/providers/index.ts' + }, + outDir: 'dist', + format: ['esm', 'cjs'], + clean: true, + dts: true, + tsconfig: 'tsconfig.json' +}) diff --git a/packages/aiCore/vitest.config.ts b/packages/aiCore/vitest.config.ts new file mode 100644 index 0000000000..0cc6b51df4 --- /dev/null +++ b/packages/aiCore/vitest.config.ts @@ -0,0 +1,15 @@ +import { defineConfig } from 'vitest/config' + +export default defineConfig({ + test: { + globals: true + }, + resolve: { + alias: { + '@': './src' + } + }, + esbuild: { + target: 'node18' + } +}) diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 9e2c3f88d7..538a0ede79 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -570,7 +570,8 @@ class McpService { ...tool, id: buildFunctionCallToolName(server.name, tool.name), serverId: server.id, - serverName: server.name + serverName: server.name, + type: 'mcp' } serverTools.push(serverTool) }) diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts new file mode 100644 index 0000000000..183e469cdd --- /dev/null +++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts @@ -0,0 +1,314 @@ +/** + * AI SDK 到 Cherry Studio Chunk 适配器 + * 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式 + */ + +import { loggerService } from '@logger' +import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types' +import { Chunk, ChunkType } from '@renderer/types/chunk' +import type { TextStreamPart, ToolSet } from 'ai' + +import { ToolCallChunkHandler } from './handleToolCallChunk' + +const logger = loggerService.withContext('AiSdkToChunkAdapter') + +export interface 
CherryStudioChunk { + type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error' + text?: string + toolCall?: any + toolResult?: any + finishReason?: string + usage?: any + error?: any +} + +/** + * AI SDK 到 Cherry Studio Chunk 适配器类 + * 处理 fullStream 到 Cherry Studio chunk 的转换 + */ +export class AiSdkToChunkAdapter { + toolCallHandler: ToolCallChunkHandler + private accumulate: boolean | undefined + constructor( + private onChunk: (chunk: Chunk) => void, + mcpTools: MCPTool[] = [], + accumulate?: boolean + ) { + this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools) + this.accumulate = accumulate + } + + /** + * 处理 AI SDK 流结果 + * @param aiSdkResult AI SDK 的流结果对象 + * @returns 最终的文本内容 + */ + async processStream(aiSdkResult: any): Promise { + // 如果是流式且有 fullStream + if (aiSdkResult.fullStream) { + await this.readFullStream(aiSdkResult.fullStream) + } + + // 使用 streamResult.text 获取最终结果 + return await aiSdkResult.text + } + + /** + * 读取 fullStream 并转换为 Cherry Studio chunks + * @param fullStream AI SDK 的 fullStream (ReadableStream) + */ + private async readFullStream(fullStream: ReadableStream>) { + const reader = fullStream.getReader() + const final = { + text: '', + reasoningContent: '', + webSearchResults: [], + reasoningId: '' + } + try { + while (true) { + const { done, value } = await reader.read() + + if (done) { + break + } + + // 转换并发送 chunk + this.convertAndEmitChunk(value, final) + } + } finally { + reader.releaseLock() + } + } + + /** + * 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调 + * @param chunk AI SDK 的 chunk 数据 + */ + private convertAndEmitChunk( + chunk: TextStreamPart, + final: { text: string; reasoningContent: string; webSearchResults: any[]; reasoningId: string } + ) { + logger.info(`AI SDK chunk type: ${chunk.type}`, chunk) + switch (chunk.type) { + // === 文本相关事件 === + case 'text-start': + this.onChunk({ + type: ChunkType.TEXT_START + }) + break + case 'text-delta': + if (this.accumulate) { + final.text += 
chunk.text || '' + } else { + final.text = chunk.text || '' + } + this.onChunk({ + type: ChunkType.TEXT_DELTA, + text: final.text || '' + }) + break + case 'text-end': + this.onChunk({ + type: ChunkType.TEXT_COMPLETE, + text: (chunk.providerMetadata?.text?.value as string) ?? final.text ?? '' + }) + final.text = '' + break + case 'reasoning-start': + // if (final.reasoningId !== chunk.id) { + final.reasoningId = chunk.id + this.onChunk({ + type: ChunkType.THINKING_START + }) + // } + break + case 'reasoning-delta': + final.reasoningContent += chunk.text || '' + this.onChunk({ + type: ChunkType.THINKING_DELTA, + text: final.reasoningContent || '', + thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0 + }) + break + case 'reasoning-end': + this.onChunk({ + type: ChunkType.THINKING_COMPLETE, + text: (chunk.providerMetadata?.metadata?.thinking_content as string) || '', + thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0 + }) + final.reasoningContent = '' + break + + // === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) === + + // case 'tool-input-start': + // case 'tool-input-delta': + // case 'tool-input-end': + // this.toolCallHandler.handleToolCallCreated(chunk) + // break + + // case 'tool-input-delta': + // this.toolCallHandler.handleToolCallCreated(chunk) + // break + case 'tool-call': + // 原始的工具调用(未被中间件处理) + this.toolCallHandler.handleToolCall(chunk) + break + + case 'tool-result': + // 原始的工具调用结果(未被中间件处理) + this.toolCallHandler.handleToolResult(chunk) + break + + // === 步骤相关事件 === + // case 'start': + // this.onChunk({ + // type: ChunkType.LLM_RESPONSE_CREATED + // }) + // break + // TODO: 需要区分接口开始和步骤开始 + // case 'start-step': + // this.onChunk({ + // type: ChunkType.BLOCK_CREATED + // }) + // break + // case 'step-finish': + // this.onChunk({ + // type: ChunkType.TEXT_COMPLETE, + // text: final.text || '' // TEXT_COMPLETE 需要 text 字段 + // }) + // final.text = '' + // break + + case 'finish-step': { + const 
{ providerMetadata, finishReason } = chunk + // google web search + if (providerMetadata?.google?.groundingMetadata) { + this.onChunk({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: providerMetadata.google?.groundingMetadata as WebSearchResults, + source: WebSearchSource.GEMINI + } + }) + } else if (final.webSearchResults.length) { + const providerName = Object.keys(providerMetadata || {})[0] + const sourceMap: Record = { + [WebSearchSource.OPENAI]: WebSearchSource.OPENAI_RESPONSE, + [WebSearchSource.ANTHROPIC]: WebSearchSource.ANTHROPIC, + [WebSearchSource.OPENROUTER]: WebSearchSource.OPENROUTER, + [WebSearchSource.GEMINI]: WebSearchSource.GEMINI, + [WebSearchSource.PERPLEXITY]: WebSearchSource.PERPLEXITY, + [WebSearchSource.QWEN]: WebSearchSource.QWEN, + [WebSearchSource.HUNYUAN]: WebSearchSource.HUNYUAN, + [WebSearchSource.ZHIPU]: WebSearchSource.ZHIPU, + [WebSearchSource.GROK]: WebSearchSource.GROK, + [WebSearchSource.WEBSEARCH]: WebSearchSource.WEBSEARCH + } + const source = sourceMap[providerName] || WebSearchSource.AISDK + + this.onChunk({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: final.webSearchResults, + source + } + }) + } + if (finishReason === 'tool-calls') { + this.onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) + } + + final.webSearchResults = [] + // final.reasoningId = '' + break + } + + case 'finish': + this.onChunk({ + type: ChunkType.BLOCK_COMPLETE, + response: { + text: final.text || '', + reasoning_content: final.reasoningContent || '', + usage: { + completion_tokens: chunk.totalUsage.outputTokens || 0, + prompt_tokens: chunk.totalUsage.inputTokens || 0, + total_tokens: chunk.totalUsage.totalTokens || 0 + }, + metrics: chunk.totalUsage + ?
{ + completion_tokens: chunk.totalUsage.outputTokens || 0, + time_completion_millsec: 0 + } + : undefined + } + }) + this.onChunk({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + text: final.text || '', + reasoning_content: final.reasoningContent || '', + usage: { + completion_tokens: chunk.totalUsage.outputTokens || 0, + prompt_tokens: chunk.totalUsage.inputTokens || 0, + total_tokens: chunk.totalUsage.totalTokens || 0 + }, + metrics: chunk.totalUsage + ? { + completion_tokens: chunk.totalUsage.outputTokens || 0, + time_completion_millsec: 0 + } + : undefined + } + }) + break + + // === 源和文件相关事件 === + case 'source': + if (chunk.sourceType === 'url') { + // if (final.webSearchResults.length === 0) { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { sourceType: _, ...rest } = chunk + final.webSearchResults.push(rest) + // } + // this.onChunk({ + // type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + // llm_web_search: { + // source: WebSearchSource.AISDK, + // results: final.webSearchResults + // } + // }) + } + break + case 'file': + // 文件相关事件,可能是图片生成 + this.onChunk({ + type: ChunkType.IMAGE_COMPLETE, + image: { + type: 'base64', + images: [`data:${chunk.file.mediaType};base64,${chunk.file.base64}`] + } + }) + break + case 'abort': + this.onChunk({ + type: ChunkType.ERROR, + error: new DOMException('Request was aborted', 'AbortError') + }) + break + case 'error': + this.onChunk({ + type: ChunkType.ERROR, + error: chunk.error as Record + }) + break + + default: + // 其他类型的 chunk 可以忽略或记录日志 + // console.log('Unhandled AI SDK chunk type:', chunk.type, chunk) + } + } +} + +export default AiSdkToChunkAdapter diff --git a/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts b/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts new file mode 100644 index 0000000000..ebe3f4e5c0 --- /dev/null +++ b/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts @@ -0,0 +1,266 @@ +/** + * 工具调用 Chunk 处理模块 + * TODO: Tool包含了providerTool和普通的Tool还有MCPTool,后面需要重构 + * 
提供工具调用相关的处理API,每个交互使用一个新的实例 + */ + +import { loggerService } from '@logger' +import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types' +import { Chunk, ChunkType } from '@renderer/types/chunk' +import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai' +// import type { +// AnthropicSearchOutput, +// WebSearchPluginConfig +// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin' + +const logger = loggerService.withContext('ToolCallChunkHandler') + +/** + * 工具调用处理器类 + */ +export class ToolCallChunkHandler { + // private onChunk: (chunk: Chunk) => void + private activeToolCalls = new Map< + string, + { + toolCallId: string + toolName: string + args: any + // mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型 + tool: BaseTool + } + >() + constructor( + private onChunk: (chunk: Chunk) => void, + private mcpTools: MCPTool[] + ) {} + + // /** + // * 设置 onChunk 回调 + // */ + // public setOnChunk(callback: (chunk: Chunk) => void): void { + // this.onChunk = callback + // } + + handleToolCallCreated( + chunk: + | { + type: 'tool-input-start' + id: string + toolName: string + providerMetadata?: ProviderMetadata + providerExecuted?: boolean + } + | { + type: 'tool-input-end' + id: string + providerMetadata?: ProviderMetadata + } + | { + type: 'tool-input-delta' + id: string + delta: string + providerMetadata?: ProviderMetadata + } + ): void { + switch (chunk.type) { + case 'tool-input-start': { + // 能拿到说明是mcpTool + // if (this.activeToolCalls.get(chunk.id)) return + + const tool: BaseTool | MCPTool = { + id: chunk.id, + name: chunk.toolName, + description: chunk.toolName, + type: chunk.toolName.startsWith('builtin_') ? 
'builtin' : 'provider' + } + this.activeToolCalls.set(chunk.id, { + toolCallId: chunk.id, + toolName: chunk.toolName, + args: '', + tool + }) + const toolResponse: MCPToolResponse | NormalToolResponse = { + id: chunk.id, + tool: tool, + arguments: {}, + status: 'pending', + toolCallId: chunk.id + } + this.onChunk({ + type: ChunkType.MCP_TOOL_PENDING, + responses: [toolResponse] + }) + break + } + case 'tool-input-delta': { + const toolCall = this.activeToolCalls.get(chunk.id) + if (!toolCall) { + logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`) + return + } + toolCall.args += chunk.delta + break + } + case 'tool-input-end': { + const toolCall = this.activeToolCalls.get(chunk.id) + this.activeToolCalls.delete(chunk.id) + if (!toolCall) { + logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`) + return + } + // const toolResponse: ToolCallResponse = { + // id: toolCall.toolCallId, + // tool: toolCall.tool, + // arguments: toolCall.args, + // status: 'pending', + // toolCallId: toolCall.toolCallId + // } + // logger.debug('toolResponse', toolResponse) + // this.onChunk({ + // type: ChunkType.MCP_TOOL_PENDING, + // responses: [toolResponse] + // }) + break + } + } + // if (!toolCall) { + // Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`) + // return + // } + // this.onChunk({ + // type: ChunkType.MCP_TOOL_CREATED, + // tool_calls: [ + // { + // id: chunk.id, + // name: chunk.toolName, + // status: 'pending' + // } + // ] + // }) + } + + /** + * 处理工具调用事件 + */ + public handleToolCall( + chunk: { + type: 'tool-call' + } & TypedToolCall + ): void { + const { toolCallId, toolName, input: args, providerExecuted } = chunk + + if (!toolCallId || !toolName) { + logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`) + return + } + + let tool: BaseTool + let mcpTool: MCPTool | undefined + + // 根据 providerExecuted 标志区分处理逻辑 + if (providerExecuted) { + // 如果是 Provider 执行的工具(如 
web_search) + logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`) + tool = { + id: toolCallId, + name: toolName, + description: toolName, + type: 'provider' + } as BaseTool + } else if (toolName.startsWith('builtin_')) { + // 如果是内置工具,沿用现有逻辑 + logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`) + tool = { + id: toolCallId, + name: toolName, + description: toolName, + type: 'builtin' + } as BaseTool + } else if ((mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool)) { + // 如果是客户端执行的 MCP 工具,沿用现有逻辑 + logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`) + // mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool + // if (!mcpTool) { + // logger.warn(`[ToolCallChunkHandler] MCP tool not found: ${toolName}`) + // return + // } + tool = mcpTool + } else { + tool = { + id: toolCallId, + name: toolName, + description: toolName, + type: 'provider' + } + } + + // 记录活跃的工具调用 + this.activeToolCalls.set(toolCallId, { + toolCallId, + toolName, + args, + tool + }) + + // 创建 MCPToolResponse 格式 + const toolResponse: MCPToolResponse | NormalToolResponse = { + id: toolCallId, + tool: tool, + arguments: args, + status: 'pending', + toolCallId: toolCallId + } + + // 调用 onChunk + if (this.onChunk) { + this.onChunk({ + type: ChunkType.MCP_TOOL_PENDING, + responses: [toolResponse] + }) + } + } + + /** + * 处理工具调用结果事件 + */ + public handleToolResult( + chunk: { + type: 'tool-result' + } & TypedToolResult + ): void { + const { toolCallId, output, input } = chunk + + if (!toolCallId) { + logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`) + return + } + + // 查找对应的工具调用信息 + const toolCallInfo = this.activeToolCalls.get(toolCallId) + if (!toolCallInfo) { + logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`) + return + } + + // 创建工具调用结果的 MCPToolResponse 格式 + const toolResponse: MCPToolResponse | NormalToolResponse = { + id: 
toolCallInfo.toolCallId, + tool: toolCallInfo.tool, + arguments: input, + status: 'done', + response: output, + toolCallId: toolCallId + } + // 从活跃调用中移除(交互结束后整个实例会被丢弃) + this.activeToolCalls.delete(toolCallId) + + // 调用 onChunk + if (this.onChunk) { + this.onChunk({ + type: ChunkType.MCP_TOOL_COMPLETE, + responses: [toolResponse] + }) + } + } +} diff --git a/src/renderer/src/aiCore/index.ts b/src/renderer/src/aiCore/index.ts index 2b48137b24..eb68da74ea 100644 --- a/src/renderer/src/aiCore/index.ts +++ b/src/renderer/src/aiCore/index.ts @@ -1,189 +1,16 @@ -import { loggerService } from '@logger' -import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory' -import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' -import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models' -import { getProviderByModel } from '@renderer/services/AssistantService' -import { withSpanResult } from '@renderer/services/SpanManagerService' -import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' -import type { GenerateImageParams, Model, Provider } from '@renderer/types' -import type { RequestOptions, SdkModel } from '@renderer/types/sdk' -import { isEnabledToolUse } from '@renderer/utils/mcp-tools' +/** + * Cherry Studio AI Core - 统一入口点 + * + * 这是新的统一入口,保持向后兼容性 + * 默认导出legacy AiProvider以保持现有代码的兼容性 + */ -import { AihubmixAPIClient } from './clients/aihubmix/AihubmixAPIClient' -import { VertexAPIClient } from './clients/gemini/VertexAPIClient' -import { NewAPIClient } from './clients/newapi/NewAPIClient' -import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient' -import { CompletionsMiddlewareBuilder } from './middleware/builder' -import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware' -import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware' -import { MIDDLEWARE_NAME as 
FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware' -import { applyCompletionsMiddlewares } from './middleware/composer' -import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware' -import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware' -import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware' -import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware' -import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware' -import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware' -import { MiddlewareRegistry } from './middleware/register' -import type { CompletionsParams, CompletionsResult } from './middleware/schemas' +// 导出Legacy AiProvider作为默认导出(保持向后兼容) +export { default } from './legacy/index' -const logger = loggerService.withContext('AiProvider') +// 同时导出Modern AiProvider供新代码使用 +export { default as ModernAiProvider } from './index_new' -export default class AiProvider { - private apiClient: BaseApiClient - - constructor(provider: Provider) { - // Use the new ApiClientFactory to get a BaseApiClient instance - this.apiClient = ApiClientFactory.create(provider) - } - - public async completions(params: CompletionsParams, options?: RequestOptions): Promise { - // 1. 
根据模型识别正确的客户端 - const model = params.assistant.model - if (!model) { - return Promise.reject(new Error('Model is required')) - } - - // 根据client类型选择合适的处理方式 - let client: BaseApiClient - - if (this.apiClient instanceof AihubmixAPIClient) { - // AihubmixAPIClient: 根据模型选择合适的子client - client = this.apiClient.getClientForModel(model) - if (client instanceof OpenAIResponseAPIClient) { - client = client.getClient(model) as BaseApiClient - } - } else if (this.apiClient instanceof NewAPIClient) { - client = this.apiClient.getClientForModel(model) - if (client instanceof OpenAIResponseAPIClient) { - client = client.getClient(model) as BaseApiClient - } - } else if (this.apiClient instanceof OpenAIResponseAPIClient) { - // OpenAIResponseAPIClient: 根据模型特征选择API类型 - client = this.apiClient.getClient(model) as BaseApiClient - } else if (this.apiClient instanceof VertexAPIClient) { - client = this.apiClient.getClient(model) as BaseApiClient - } else { - // 其他client直接使用 - client = this.apiClient - } - - // 2. 
构建中间件链 - const builder = CompletionsMiddlewareBuilder.withDefaults() - // images api - if (isDedicatedImageGenerationModel(model)) { - builder.clear() - builder - .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName]) - .add(MiddlewareRegistry[ErrorHandlerMiddlewareName]) - .add(MiddlewareRegistry[AbortHandlerMiddlewareName]) - .add(MiddlewareRegistry[ImageGenerationMiddlewareName]) - } else { - // Existing logic for other models - logger.silly('Builder Params', params) - // 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题 - const clientTypes = client.getClientCompatibilityType(model) - const isOpenAICompatible = - clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient') - if (!isOpenAICompatible) { - logger.silly('ThinkingTagExtractionMiddleware is removed') - builder.remove(ThinkingTagExtractionMiddlewareName) - } - - const isAnthropicOrOpenAIResponseCompatible = - clientTypes.includes('AnthropicAPIClient') || - clientTypes.includes('OpenAIResponseAPIClient') || - clientTypes.includes('AnthropicVertexAPIClient') - if (!isAnthropicOrOpenAIResponseCompatible) { - logger.silly('RawStreamListenerMiddleware is removed') - builder.remove(RawStreamListenerMiddlewareName) - } - if (!params.enableWebSearch) { - logger.silly('WebSearchMiddleware is removed') - builder.remove(WebSearchMiddlewareName) - } - if (!params.mcpTools?.length) { - builder.remove(ToolUseExtractionMiddlewareName) - logger.silly('ToolUseExtractionMiddleware is removed') - builder.remove(McpToolChunkMiddlewareName) - logger.silly('McpToolChunkMiddleware is removed') - } - if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) { - builder.remove(ToolUseExtractionMiddlewareName) - logger.silly('ToolUseExtractionMiddleware is removed') - } - if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') { - logger.silly('AbortHandlerMiddleware is removed') - builder.remove(AbortHandlerMiddlewareName) - } - if (params.callType 
=== 'test') { - builder.remove(ErrorHandlerMiddlewareName) - logger.silly('ErrorHandlerMiddleware is removed') - builder.remove(FinalChunkConsumerMiddlewareName) - logger.silly('FinalChunkConsumerMiddleware is removed') - } - } - - const middlewares = builder.build() - logger.silly( - 'middlewares', - middlewares.map((m) => m.name) - ) - - // 3. Create the wrapped SDK method with middlewares - const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares) - - // 4. Execute the wrapped method with the original params - const result = wrappedCompletionMethod(params, options) - return result - } - - public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise { - const traceName = params.assistant.model?.name - ? `${params.assistant.model?.name}.${params.callType}` - : `LLM.${params.callType}` - - const traceParams: StartSpanParams = { - name: traceName, - tag: 'LLM', - topicId: params.topicId || '', - modelName: params.assistant.model?.name - } - - return await withSpanResult(this.completions.bind(this), traceParams, params, options) - } - - public async models(): Promise { - return this.apiClient.listModels() - } - - public async getEmbeddingDimensions(model: Model): Promise { - try { - // Use the SDK instance to test embedding capabilities - if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') { - this.apiClient = this.apiClient.getClient(model) as BaseApiClient - } - const dimensions = await this.apiClient.getEmbeddingDimensions(model) - return dimensions - } catch (error) { - logger.error('Error getting embedding dimensions:', error as Error) - throw error - } - } - - public async generateImage(params: GenerateImageParams): Promise { - if (this.apiClient instanceof AihubmixAPIClient) { - const client = this.apiClient.getClientForModel({ id: params.model } as Model) - return client.generateImage(params) - } - return 
this.apiClient.generateImage(params) - } - - public getBaseURL(): string { - return this.apiClient.getBaseURL() - } - - public getApiKey(): string { - return this.apiClient.getApiKey() - } -} +// 导出一些常用的类型和工具 +export * from './legacy/clients/types' +export * from './legacy/middleware/schemas' diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts new file mode 100644 index 0000000000..9991dffd1f --- /dev/null +++ b/src/renderer/src/aiCore/index_new.ts @@ -0,0 +1,506 @@ +/** + * Cherry Studio AI Core - 新版本入口 + * 集成 @cherrystudio/ai-core 库的渐进式重构方案 + * + * 融合方案:简化实现,专注于核心功能 + * 1. 优先使用新AI SDK + * 2. 暂时保持接口兼容性 + */ + +import { createExecutor } from '@cherrystudio/ai-core' +import { loggerService } from '@logger' +import { getEnableDeveloperMode } from '@renderer/hooks/useSettings' +import { addSpan, endSpan } from '@renderer/services/SpanManagerService' +import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' +import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types' +import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes' +import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai' + +import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter' +import LegacyAiProvider from './legacy/index' +import { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas' +import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder' +import { buildPlugins } from './plugins/PluginBuilder' +import { createAiSdkProvider } from './provider/factory' +import { + getActualProvider, + isModernSdkSupported, + prepareSpecialProviderConfig, + providerToAiSdkConfig +} from './provider/providerConfig' + +const logger = loggerService.withContext('ModernAiProvider') + +export type ModernAiProviderConfig = AiSdkMiddlewareConfig & { + assistant: Assistant + // topicId for tracing + topicId?: 
string + callType: string +} + +export default class ModernAiProvider { + private legacyProvider: LegacyAiProvider + private config?: ReturnType + private actualProvider: Provider + private model?: Model + private localProvider: Awaited | null = null + + // 构造函数重载签名 + constructor(model: Model, provider?: Provider) + constructor(provider: Provider) + constructor(modelOrProvider: Model | Provider, provider?: Provider) + constructor(modelOrProvider: Model | Provider, provider?: Provider) { + if (this.isModel(modelOrProvider)) { + // 传入的是 Model + this.model = modelOrProvider + this.actualProvider = provider || getActualProvider(modelOrProvider) + // 只保存配置,不预先创建executor + this.config = providerToAiSdkConfig(this.actualProvider, modelOrProvider) + } else { + // 传入的是 Provider + this.actualProvider = modelOrProvider + // model为可选,某些操作(如fetchModels)不需要model + } + + this.legacyProvider = new LegacyAiProvider(this.actualProvider) + } + + /** + * 类型守卫函数:通过 provider 属性区分 Model 和 Provider + */ + private isModel(obj: Model | Provider): obj is Model { + return 'provider' in obj && typeof obj.provider === 'string' + } + + public getActualProvider() { + return this.actualProvider + } + + public async completions(modelId: string, params: StreamTextParams, config: ModernAiProviderConfig) { + // 检查model是否存在 + if (!this.model) { + throw new Error('Model is required for completions. 
Please use constructor with model parameter.') + } + + // 确保配置存在 + if (!this.config) { + this.config = providerToAiSdkConfig(this.actualProvider, this.model) + } + + // 准备特殊配置 + await prepareSpecialProviderConfig(this.actualProvider, this.config) + + // 提前创建本地 provider 实例 + if (!this.localProvider) { + this.localProvider = await createAiSdkProvider(this.config) + } + + // 提前构建中间件 + const middlewares = buildAiSdkMiddlewares({ + ...config, + provider: this.actualProvider + }) + logger.debug('Built middlewares in completions', { + middlewareCount: middlewares.length, + isImageGeneration: config.isImageGenerationEndpoint + }) + if (!this.localProvider) { + throw new Error('Local provider not created') + } + + // 根据endpoint类型创建对应的模型 + let model: AiSdkModel | undefined + if (config.isImageGenerationEndpoint) { + model = this.localProvider.imageModel(modelId) + } else { + model = this.localProvider.languageModel(modelId) + // 如果有中间件,应用到语言模型上 + if (middlewares.length > 0 && typeof model === 'object') { + model = wrapLanguageModel({ model, middleware: middlewares }) + } + } + + if (config.topicId && getEnableDeveloperMode()) { + // TypeScript类型窄化:确保topicId是string类型 + const traceConfig = { + ...config, + topicId: config.topicId + } + return await this._completionsForTrace(model, params, traceConfig) + } else { + return await this._completionsOrImageGeneration(model, params, config) + } + } + + private async _completionsOrImageGeneration( + model: AiSdkModel, + params: StreamTextParams, + config: ModernAiProviderConfig + ): Promise { + if (config.isImageGenerationEndpoint) { + // 使用 legacy 实现处理图像生成(支持图片编辑等高级功能) + if (!config.uiMessages) { + throw new Error('uiMessages is required for image generation endpoint') + } + + const legacyParams: CompletionsParams = { + callType: 'chat', + messages: config.uiMessages, // 使用原始的 UI 消息格式 + assistant: config.assistant, + streamOutput: config.streamOutput ?? 
true, + onChunk: config.onChunk, + topicId: config.topicId, + mcpTools: config.mcpTools, + enableWebSearch: config.enableWebSearch + } + + // 调用 legacy 的 completions,会自动使用 ImageGenerationMiddleware + return await this.legacyProvider.completions(legacyParams) + } + + return await this.modernCompletions(model as LanguageModel, params, config) + } + + /** + * 带trace支持的completions方法 + * 类似于legacy的completionsForTrace,确保AI SDK spans在正确的trace上下文中 + */ + private async _completionsForTrace( + model: AiSdkModel, + params: StreamTextParams, + config: ModernAiProviderConfig & { topicId: string } + ): Promise { + const modelId = this.model!.id + const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}` + const traceParams: StartSpanParams = { + name: traceName, + tag: 'LLM', + topicId: config.topicId, + modelName: config.assistant.model?.name, // 使用modelId而不是provider名称 + inputs: params + } + + logger.info('Starting AI SDK trace span', { + traceName, + topicId: config.topicId, + modelId, + hasTools: !!params.tools && Object.keys(params.tools).length > 0, + toolNames: params.tools ? 
Object.keys(params.tools) : [], + isImageGeneration: config.isImageGenerationEndpoint + }) + + const span = addSpan(traceParams) + if (!span) { + logger.warn('Failed to create span, falling back to regular completions', { + topicId: config.topicId, + modelId, + traceName + }) + return await this._completionsOrImageGeneration(model, params, config) + } + + try { + logger.info('Created parent span, now calling completions', { + spanId: span.spanContext().spanId, + traceId: span.spanContext().traceId, + topicId: config.topicId, + modelId, + parentSpanCreated: true + }) + + const result = await this._completionsOrImageGeneration(model, params, config) + + logger.info('Completions finished, ending parent span', { + spanId: span.spanContext().spanId, + traceId: span.spanContext().traceId, + topicId: config.topicId, + modelId, + resultLength: result.getText().length + }) + + // 标记span完成 + endSpan({ + topicId: config.topicId, + outputs: result, + span, + modelName: modelId // 使用modelId保持一致性 + }) + + return result + } catch (error) { + logger.error('Error in completionsForTrace, ending parent span with error', error as Error, { + spanId: span.spanContext().spanId, + traceId: span.spanContext().traceId, + topicId: config.topicId, + modelId + }) + + // 标记span出错 + endSpan({ + topicId: config.topicId, + error: error as Error, + span, + modelName: modelId // 使用modelId保持一致性 + }) + throw error + } + } + + /** + * 使用现代化AI SDK的completions实现 + */ + private async modernCompletions( + model: LanguageModel, + params: StreamTextParams, + config: ModernAiProviderConfig + ): Promise { + const modelId = this.model!.id + logger.info('Starting modernCompletions', { + modelId, + providerId: this.config!.providerId, + topicId: config.topicId, + hasOnChunk: !!config.onChunk, + hasTools: !!params.tools && Object.keys(params.tools).length > 0, + toolCount: params.tools ? 
Object.keys(params.tools).length : 0 + }) + + // 根据条件构建插件数组 + const plugins = buildPlugins(config) + + // 用构建好的插件数组创建executor + const executor = createExecutor(this.config!.providerId, this.config!.options, plugins) + + // 创建带有中间件的执行器 + if (config.onChunk) { + const accumulate = this.model!.supported_text_delta !== false // true and undefined + const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate) + + const streamResult = await executor.streamText({ + ...params, + model, + experimental_context: { onChunk: config.onChunk } + }) + + const finalText = await adapter.processStream(streamResult) + + return { + getText: () => finalText + } + } else { + const streamResult = await executor.streamText({ + ...params, + model + }) + + // 强制消费流,不然await streamResult.text会阻塞 + await streamResult?.consumeStream() + + const finalText = await streamResult.text + + return { + getText: () => finalText + } + } + } + + /** + * 使用现代化 AI SDK 的图像生成实现,支持流式输出 + * @deprecated 已改为使用 legacy 实现以支持图片编辑等高级功能 + */ + /* + private async modernImageGeneration( + model: ImageModel, + params: StreamTextParams, + config: ModernAiProviderConfig + ): Promise { + const { onChunk } = config + + try { + // 检查 messages 是否存在 + if (!params.messages || params.messages.length === 0) { + throw new Error('No messages provided for image generation.') + } + + // 从最后一条用户消息中提取 prompt + const lastUserMessage = params.messages.findLast((m) => m.role === 'user') + if (!lastUserMessage) { + throw new Error('No user message found for image generation.') + } + + // 直接使用消息内容,避免类型转换问题 + const prompt = + typeof lastUserMessage.content === 'string' + ? lastUserMessage.content + : lastUserMessage.content?.map((part) => ('text' in part ? 
part.text : '')).join('') || '' + + if (!prompt) { + throw new Error('No prompt found in user message.') + } + + const startTime = Date.now() + + // 发送图像生成开始事件 + if (onChunk) { + onChunk({ type: ChunkType.IMAGE_CREATED }) + } + + // 构建图像生成参数 + const imageParams = { + prompt, + size: isNotSupportedImageSizeModel(config.model) ? undefined : ('1024x1024' as `${number}x${number}`), // 默认尺寸,使用正确的类型 + n: 1, + ...(params.abortSignal && { abortSignal: params.abortSignal }) + } + + // 调用新 AI SDK 的图像生成功能 + const executor = createExecutor(this.config!.providerId, this.config!.options, []) + const result = await executor.generateImage({ + model, + ...imageParams + }) + + // 转换结果格式 + const images: string[] = [] + const imageType: 'url' | 'base64' = 'base64' + + if (result.images) { + for (const image of result.images) { + if ('base64' in image && image.base64) { + images.push(`data:${image.mediaType};base64,${image.base64}`) + } + } + } + + // 发送图像生成完成事件 + if (onChunk && images.length > 0) { + onChunk({ + type: ChunkType.IMAGE_COMPLETE, + image: { type: imageType, images } + }) + } + + // 发送块完成事件(类似于 modernCompletions 的处理) + if (onChunk) { + const usage = { + prompt_tokens: prompt.length, // 估算的 token 数量 + completion_tokens: 0, // 图像生成没有 completion tokens + total_tokens: prompt.length + } + + onChunk({ + type: ChunkType.BLOCK_COMPLETE, + response: { + usage, + metrics: { + completion_tokens: usage.completion_tokens, + time_first_token_millsec: 0, + time_completion_millsec: Date.now() - startTime + } + } + }) + + // 发送 LLM 响应完成事件 + onChunk({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage, + metrics: { + completion_tokens: usage.completion_tokens, + time_first_token_millsec: 0, + time_completion_millsec: Date.now() - startTime + } + } + }) + } + + return { + getText: () => '' // 图像生成不返回文本 + } + } catch (error) { + // 发送错误事件 + if (onChunk) { + onChunk({ type: ChunkType.ERROR, error: error as any }) + } + throw error + } + } + */ + + // 代理其他方法到原有实现 + public async 
models() { + return this.legacyProvider.models() + } + + public async getEmbeddingDimensions(model: Model): Promise { + return this.legacyProvider.getEmbeddingDimensions(model) + } + + public async generateImage(params: GenerateImageParams): Promise { + // 如果支持新的 AI SDK,使用现代化实现 + if (isModernSdkSupported(this.actualProvider)) { + try { + // 确保本地provider已创建 + if (!this.localProvider) { + this.localProvider = await createAiSdkProvider(this.config) + if (!this.localProvider) { + throw new Error('Local provider not created') + } + } + + const result = await this.modernGenerateImage(params) + return result + } catch (error) { + logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error) + // fallback 到传统实现 + return this.legacyProvider.generateImage(params) + } + } + + // 直接使用传统实现 + return this.legacyProvider.generateImage(params) + } + + /** + * 使用现代化 AI SDK 的图像生成实现 + */ + private async modernGenerateImage(params: GenerateImageParams): Promise { + const { model, prompt, imageSize, batchSize, signal } = params + + // 转换参数格式 + const aiSdkParams = { + prompt, + size: (imageSize || '1024x1024') as `${number}x${number}`, + n: batchSize || 1, + ...(signal && { abortSignal: signal }) + } + + const executor = createExecutor(this.config!.providerId, this.config!.options, []) + const result = await executor.generateImage({ + model: this.localProvider?.imageModel(model) as ImageModel, + ...aiSdkParams + }) + + // 转换结果格式 + const images: string[] = [] + if (result.images) { + for (const image of result.images) { + if ('base64' in image && image.base64) { + images.push(`data:image/png;base64,${image.base64}`) + } + } + } + + return images + } + + public getBaseURL(): string { + return this.legacyProvider.getBaseURL() + } + + public getApiKey(): string { + return this.legacyProvider.getApiKey() + } +} + +// 为了方便调试,导出一些工具函数 +export { isModernSdkSupported, providerToAiSdkConfig } diff --git a/src/renderer/src/aiCore/clients/ApiClientFactory.ts 
b/src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts similarity index 97% rename from src/renderer/src/aiCore/clients/ApiClientFactory.ts rename to src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts index 7c5575aa08..31a911533e 100644 --- a/src/renderer/src/aiCore/clients/ApiClientFactory.ts +++ b/src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts @@ -75,6 +75,7 @@ export class ApiClientFactory { instance = new GeminiAPIClient(provider) as BaseApiClient break case 'vertexai': + logger.debug(`Creating VertexAPIClient for provider: ${provider.id}`) instance = new VertexAPIClient(provider) as BaseApiClient break case 'anthropic': diff --git a/src/renderer/src/aiCore/clients/BaseApiClient.ts b/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/BaseApiClient.ts rename to src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts diff --git a/src/renderer/src/aiCore/clients/MixedBaseApiClient.ts b/src/renderer/src/aiCore/legacy/clients/MixedBaseApiClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/MixedBaseApiClient.ts rename to src/renderer/src/aiCore/legacy/clients/MixedBaseApiClient.ts diff --git a/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts b/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts similarity index 99% rename from src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts rename to src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts index 4d58c78772..081469516b 100644 --- a/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts +++ b/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts @@ -66,7 +66,8 @@ vi.mock('@renderer/config/models', () => ({ SYSTEM_MODELS: { silicon: [], defaultModel: [] - } + }, + isOpenAIModel: vi.fn(() => false) })) describe('ApiClientFactory', () => { diff --git 
a/src/renderer/src/aiCore/__tests__/index.clientCompatibilityTypes.test.ts b/src/renderer/src/aiCore/legacy/clients/__tests__/index.clientCompatibilityTypes.test.ts similarity index 92% rename from src/renderer/src/aiCore/__tests__/index.clientCompatibilityTypes.test.ts rename to src/renderer/src/aiCore/legacy/clients/__tests__/index.clientCompatibilityTypes.test.ts index 12571875db..343bc4d544 100644 --- a/src/renderer/src/aiCore/__tests__/index.clientCompatibilityTypes.test.ts +++ b/src/renderer/src/aiCore/legacy/clients/__tests__/index.clientCompatibilityTypes.test.ts @@ -1,11 +1,11 @@ -import { AihubmixAPIClient } from '@renderer/aiCore/clients/aihubmix/AihubmixAPIClient' -import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient' -import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory' -import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient' -import { VertexAPIClient } from '@renderer/aiCore/clients/gemini/VertexAPIClient' -import { NewAPIClient } from '@renderer/aiCore/clients/newapi/NewAPIClient' -import { OpenAIAPIClient } from '@renderer/aiCore/clients/openai/OpenAIApiClient' -import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient' +import { AihubmixAPIClient } from '@renderer/aiCore/legacy/clients/aihubmix/AihubmixAPIClient' +import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient' +import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory' +import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient' +import { VertexAPIClient } from '@renderer/aiCore/legacy/clients/gemini/VertexAPIClient' +import { NewAPIClient } from '@renderer/aiCore/legacy/clients/newapi/NewAPIClient' +import { OpenAIAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIApiClient' +import { OpenAIResponseAPIClient } from 
'@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient' import { EndpointType, Model, Provider } from '@renderer/types' import { beforeEach, describe, expect, it, vi } from 'vitest' @@ -22,6 +22,7 @@ vi.mock('@renderer/config/models', () => ({ anthropic: [], gemini: [] }, + isOpenAIModel: vi.fn().mockReturnValue(true), isOpenAILLMModel: vi.fn().mockReturnValue(true), isOpenAIChatCompletionOnlyModel: vi.fn().mockReturnValue(false), isAnthropicLLMModel: vi.fn().mockReturnValue(false), @@ -80,6 +81,7 @@ vi.mock('@logger', () => ({ } })) +// 到底是谁想出来的在服务层调用 React Hook ????????? // Mock additional services and hooks that might be imported vi.mock('@renderer/hooks/useVertexAI', () => ({ getVertexAILocation: vi.fn().mockReturnValue('us-central1'), @@ -87,7 +89,9 @@ vi.mock('@renderer/hooks/useVertexAI', () => ({ getVertexAIServiceAccount: vi.fn().mockReturnValue({ privateKey: 'test-key', clientEmail: 'test@example.com' - }) + }), + isVertexAIConfigured: vi.fn().mockReturnValue(true), + isVertexProvider: vi.fn().mockReturnValue(true) })) vi.mock('@renderer/hooks/useSettings', () => ({ @@ -131,7 +135,7 @@ vi.mock('@google-cloud/vertexai', () => ({ })) // Mock the circular dependency between VertexAPIClient and AnthropicVertexClient -vi.mock('@renderer/aiCore/clients/anthropic/AnthropicVertexClient', () => { +vi.mock('@renderer/aiCore/legacy/clients/anthropic/AnthropicVertexClient', () => { const MockAnthropicVertexClient = vi.fn() MockAnthropicVertexClient.prototype.getClientCompatibilityType = vi.fn().mockReturnValue(['AnthropicVertexAPIClient']) return { diff --git a/src/renderer/src/aiCore/clients/aihubmix/AihubmixAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/aihubmix/AihubmixAPIClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/aihubmix/AihubmixAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/aihubmix/AihubmixAPIClient.ts diff --git a/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts 
b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts similarity index 99% rename from src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts index f286b40d59..8b42643472 100644 --- a/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts @@ -25,7 +25,6 @@ import { import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages' import AnthropicVertex from '@anthropic-ai/vertex-sdk' import { loggerService } from '@logger' -import { GenericChunk } from '@renderer/aiCore/middleware/schemas' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models' import { getAssistantSettings } from '@renderer/services/AssistantService' @@ -64,13 +63,14 @@ import { import { addImageFileToContents } from '@renderer/utils/formats' import { anthropicToolUseToMcpTool, - isEnabledToolUse, + isSupportedToolUse, mcpToolCallResponseToAnthropicMessage, mcpToolsToAnthropicTools } from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' import { t } from 'i18next' +import { GenericChunk } from '../../middleware/schemas' import { BaseApiClient } from '../BaseApiClient' import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types' @@ -457,7 +457,7 @@ export class AnthropicAPIClient extends BaseApiClient< const { tools } = this.setupToolsConfig({ mcpTools: mcpTools, model, - enableToolUse: isEnabledToolUse(assistant) + enableToolUse: isSupportedToolUse(assistant) }) const systemMessage: TextBlockParam | undefined = systemPrompt diff --git a/src/renderer/src/aiCore/clients/anthropic/AnthropicVertexClient.ts 
b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicVertexClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/anthropic/AnthropicVertexClient.ts rename to src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicVertexClient.ts diff --git a/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts similarity index 99% rename from src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts index 1d990dfbda..42c5118580 100644 --- a/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts @@ -6,7 +6,7 @@ import { InvokeModelWithResponseStreamCommand } from '@aws-sdk/client-bedrock-runtime' import { loggerService } from '@logger' -import { GenericChunk } from '@renderer/aiCore/middleware/schemas' +import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { findTokenLimit, isReasoningModel } from '@renderer/config/models' import { @@ -50,7 +50,7 @@ import { import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils' import { awsBedrockToolUseToMcpTool, - isEnabledToolUse, + isSupportedToolUse, mcpToolCallResponseToAwsBedrockMessage, mcpToolsToAwsBedrockTools } from '@renderer/utils/mcp-tools' @@ -739,7 +739,7 @@ export class AwsBedrockAPIClient extends BaseApiClient< const { tools } = this.setupToolsConfig({ mcpTools: mcpTools, model, - enableToolUse: isEnabledToolUse(assistant) + enableToolUse: isSupportedToolUse(assistant) }) // 3. 
处理消息 diff --git a/src/renderer/src/aiCore/clients/cherryin/CherryinAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/cherryin/CherryinAPIClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/cherryin/CherryinAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/cherryin/CherryinAPIClient.ts diff --git a/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts similarity index 99% rename from src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts index 70b2997fd2..153b53a362 100644 --- a/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts @@ -18,7 +18,6 @@ import { } from '@google/genai' import { loggerService } from '@logger' import { nanoid } from '@reduxjs/toolkit' -import { GenericChunk } from '@renderer/aiCore/middleware/schemas' import { findTokenLimit, GEMINI_FLASH_MODEL_REGEX, @@ -55,7 +54,7 @@ import { import { isToolUseModeFunction } from '@renderer/utils/assistant' import { geminiFunctionCallToMcpTool, - isEnabledToolUse, + isSupportedToolUse, mcpToolCallResponseToGeminiMessage, mcpToolsToGeminiTools } from '@renderer/utils/mcp-tools' @@ -63,6 +62,7 @@ import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/u import { defaultTimeout, MB } from '@shared/config/constant' import { t } from 'i18next' +import { GenericChunk } from '../../middleware/schemas' import { BaseApiClient } from '../BaseApiClient' import { RequestTransformer, ResponseChunkTransformer } from '../types' @@ -454,7 +454,7 @@ export class GeminiAPIClient extends BaseApiClient< const { tools } = this.setupToolsConfig({ mcpTools, model, - enableToolUse: isEnabledToolUse(assistant) + enableToolUse: isSupportedToolUse(assistant) }) let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents 
diff --git a/src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts similarity index 70% rename from src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts index a5328e9e61..37e6677367 100644 --- a/src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts @@ -1,7 +1,7 @@ import { GoogleGenAI } from '@google/genai' import { loggerService } from '@logger' -import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI' -import { Model, Provider } from '@renderer/types' +import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' +import { Model, Provider, VertexProvider } from '@renderer/types' import { isEmpty } from 'lodash' import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient' @@ -12,10 +12,21 @@ export class VertexAPIClient extends GeminiAPIClient { private authHeaders?: Record private authHeadersExpiry?: number private anthropicVertexClient: AnthropicVertexClient + private vertexProvider: VertexProvider constructor(provider: Provider) { super(provider) + // 检查 VertexAI 配置 + if (!isVertexAIConfigured()) { + throw new Error('VertexAI is not configured. 
Please configure project, location and service account credentials.') + } this.anthropicVertexClient = new AnthropicVertexClient(provider) + // 如果传入的是普通 Provider,转换为 VertexProvider + if (isVertexProvider(provider)) { + this.vertexProvider = provider + } else { + this.vertexProvider = createVertexProvider(provider) + } } override getClientCompatibilityType(model?: Model): string[] { @@ -56,11 +67,9 @@ export class VertexAPIClient extends GeminiAPIClient { return this.sdkInstance } - const serviceAccount = getVertexAIServiceAccount() - const projectId = getVertexAIProjectId() - const location = getVertexAILocation() + const { googleCredentials, project, location } = this.vertexProvider - if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) { + if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project || !location) { throw new Error('Vertex AI settings are not configured') } @@ -68,7 +77,7 @@ export class VertexAPIClient extends GeminiAPIClient { this.sdkInstance = new GoogleGenAI({ vertexai: true, - project: projectId, + project: project, location: location, httpOptions: { apiVersion: this.getApiVersion(), @@ -84,11 +93,10 @@ export class VertexAPIClient extends GeminiAPIClient { * 获取认证头,如果配置了 service account 则从主进程获取 */ private async getServiceAccountAuthHeaders(): Promise | undefined> { - const serviceAccount = getVertexAIServiceAccount() - const projectId = getVertexAIProjectId() + const { googleCredentials, project } = this.vertexProvider // 检查是否配置了 service account - if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) { + if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project) { return undefined } @@ -101,10 +109,10 @@ export class VertexAPIClient extends GeminiAPIClient { try { // 从主进程获取认证头 this.authHeaders = await window.api.vertexAI.getAuthHeaders({ - projectId, + projectId: project, serviceAccount: { - privateKey: serviceAccount.privateKey, - 
clientEmail: serviceAccount.clientEmail + privateKey: googleCredentials.privateKey, + clientEmail: googleCredentials.clientEmail } }) @@ -125,11 +133,10 @@ export class VertexAPIClient extends GeminiAPIClient { this.authHeaders = undefined this.authHeadersExpiry = undefined - const serviceAccount = getVertexAIServiceAccount() - const projectId = getVertexAIProjectId() + const { googleCredentials, project } = this.vertexProvider - if (projectId && serviceAccount.clientEmail) { - window.api.vertexAI.clearAuthCache(projectId, serviceAccount.clientEmail) + if (project && googleCredentials.clientEmail) { + window.api.vertexAI.clearAuthCache(project, googleCredentials.clientEmail) } } } diff --git a/src/renderer/src/aiCore/clients/index.ts b/src/renderer/src/aiCore/legacy/clients/index.ts similarity index 100% rename from src/renderer/src/aiCore/clients/index.ts rename to src/renderer/src/aiCore/legacy/clients/index.ts diff --git a/src/renderer/src/aiCore/clients/newapi/NewAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/newapi/NewAPIClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/newapi/NewAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/newapi/NewAPIClient.ts diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts similarity index 99% rename from src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts rename to src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts index 7d527c0577..92de478313 100644 --- a/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts @@ -71,7 +71,7 @@ import { } from '@renderer/types/sdk' import { addImageFileToContents } from '@renderer/utils/formats' import { - isEnabledToolUse, + isSupportedToolUse, mcpToolCallResponseToOpenAICompatibleMessage, mcpToolsToOpenAIChatTools, openAIToolsToMcpTool @@ -611,7 +611,7 @@ export class 
OpenAIAPIClient extends OpenAIBaseClient< const { tools } = this.setupToolsConfig({ mcpTools: mcpTools, model, - enableToolUse: isEnabledToolUse(assistant) + enableToolUse: isSupportedToolUse(assistant) }) // 3. 处理用户消息 diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts rename to src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts similarity index 99% rename from src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts index 36666fcaf2..2d5bc6010a 100644 --- a/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts @@ -1,6 +1,6 @@ import { loggerService } from '@logger' -import { GenericChunk } from '@renderer/aiCore/middleware/schemas' -import { CompletionsContext } from '@renderer/aiCore/middleware/types' +import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas' +import { CompletionsContext } from '@renderer/aiCore/legacy/middleware/types' import { isGPT5SeriesModel, isOpenAIChatCompletionOnlyModel, @@ -36,7 +36,7 @@ import { } from '@renderer/types/sdk' import { addImageFileToContents } from '@renderer/utils/formats' import { - isEnabledToolUse, + isSupportedToolUse, mcpToolCallResponseToOpenAIMessage, mcpToolsToOpenAIResponseTools, openAIToolsToMcpTool @@ -388,7 +388,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient< const { tools: extraTools } = this.setupToolsConfig({ mcpTools: mcpTools, model, - enableToolUse: isEnabledToolUse(assistant) + enableToolUse: isSupportedToolUse(assistant) }) 
systemMessageContent.push(systemMessageInput) diff --git a/src/renderer/src/aiCore/clients/ppio/PPIOAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/ppio/PPIOAPIClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/ppio/PPIOAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/ppio/PPIOAPIClient.ts diff --git a/src/renderer/src/aiCore/clients/types.ts b/src/renderer/src/aiCore/legacy/clients/types.ts similarity index 100% rename from src/renderer/src/aiCore/clients/types.ts rename to src/renderer/src/aiCore/legacy/clients/types.ts diff --git a/src/renderer/src/aiCore/clients/zhipu/ZhipuAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/zhipu/ZhipuAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts diff --git a/src/renderer/src/aiCore/legacy/index.ts b/src/renderer/src/aiCore/legacy/index.ts new file mode 100644 index 0000000000..adc81f03ad --- /dev/null +++ b/src/renderer/src/aiCore/legacy/index.ts @@ -0,0 +1,189 @@ +import { loggerService } from '@logger' +import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory' +import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient' +import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models' +import { getProviderByModel } from '@renderer/services/AssistantService' +import { withSpanResult } from '@renderer/services/SpanManagerService' +import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' +import type { GenerateImageParams, Model, Provider } from '@renderer/types' +import type { RequestOptions, SdkModel } from '@renderer/types/sdk' +import { isSupportedToolUse } from '@renderer/utils/mcp-tools' + +import { AihubmixAPIClient } from './clients/aihubmix/AihubmixAPIClient' +import { VertexAPIClient } from './clients/gemini/VertexAPIClient' +import { NewAPIClient } from 
'./clients/newapi/NewAPIClient' +import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient' +import { CompletionsMiddlewareBuilder } from './middleware/builder' +import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware' +import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware' +import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware' +import { applyCompletionsMiddlewares } from './middleware/composer' +import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware' +import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware' +import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware' +import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware' +import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware' +import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware' +import { MiddlewareRegistry } from './middleware/register' +import type { CompletionsParams, CompletionsResult } from './middleware/schemas' + +const logger = loggerService.withContext('AiProvider') + +export default class AiProvider { + private apiClient: BaseApiClient + + constructor(provider: Provider) { + // Use the new ApiClientFactory to get a BaseApiClient instance + this.apiClient = ApiClientFactory.create(provider) + } + + public async completions(params: CompletionsParams, options?: RequestOptions): Promise { + // 1. 
根据模型识别正确的客户端 + const model = params.assistant.model + if (!model) { + return Promise.reject(new Error('Model is required')) + } + + // 根据client类型选择合适的处理方式 + let client: BaseApiClient + + if (this.apiClient instanceof AihubmixAPIClient) { + // AihubmixAPIClient: 根据模型选择合适的子client + client = this.apiClient.getClientForModel(model) + if (client instanceof OpenAIResponseAPIClient) { + client = client.getClient(model) as BaseApiClient + } + } else if (this.apiClient instanceof NewAPIClient) { + client = this.apiClient.getClientForModel(model) + if (client instanceof OpenAIResponseAPIClient) { + client = client.getClient(model) as BaseApiClient + } + } else if (this.apiClient instanceof OpenAIResponseAPIClient) { + // OpenAIResponseAPIClient: 根据模型特征选择API类型 + client = this.apiClient.getClient(model) as BaseApiClient + } else if (this.apiClient instanceof VertexAPIClient) { + client = this.apiClient.getClient(model) as BaseApiClient + } else { + // 其他client直接使用 + client = this.apiClient + } + + // 2. 
构建中间件链 + const builder = CompletionsMiddlewareBuilder.withDefaults() + // images api + if (isDedicatedImageGenerationModel(model)) { + builder.clear() + builder + .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName]) + .add(MiddlewareRegistry[ErrorHandlerMiddlewareName]) + .add(MiddlewareRegistry[AbortHandlerMiddlewareName]) + .add(MiddlewareRegistry[ImageGenerationMiddlewareName]) + } else { + // Existing logic for other models + logger.silly('Builder Params', params) + // 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题 + const clientTypes = client.getClientCompatibilityType(model) + const isOpenAICompatible = + clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient') + if (!isOpenAICompatible) { + logger.silly('ThinkingTagExtractionMiddleware is removed') + builder.remove(ThinkingTagExtractionMiddlewareName) + } + + const isAnthropicOrOpenAIResponseCompatible = + clientTypes.includes('AnthropicAPIClient') || + clientTypes.includes('OpenAIResponseAPIClient') || + clientTypes.includes('AnthropicVertexAPIClient') + if (!isAnthropicOrOpenAIResponseCompatible) { + logger.silly('RawStreamListenerMiddleware is removed') + builder.remove(RawStreamListenerMiddlewareName) + } + if (!params.enableWebSearch) { + logger.silly('WebSearchMiddleware is removed') + builder.remove(WebSearchMiddlewareName) + } + if (!params.mcpTools?.length) { + builder.remove(ToolUseExtractionMiddlewareName) + logger.silly('ToolUseExtractionMiddleware is removed') + builder.remove(McpToolChunkMiddlewareName) + logger.silly('McpToolChunkMiddleware is removed') + } + if (isSupportedToolUse(params.assistant) && isFunctionCallingModel(model)) { + builder.remove(ToolUseExtractionMiddlewareName) + logger.silly('ToolUseExtractionMiddleware is removed') + } + if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') { + logger.silly('AbortHandlerMiddleware is removed') + builder.remove(AbortHandlerMiddlewareName) + } + if (params.callType 
=== 'test') { + builder.remove(ErrorHandlerMiddlewareName) + logger.silly('ErrorHandlerMiddleware is removed') + builder.remove(FinalChunkConsumerMiddlewareName) + logger.silly('FinalChunkConsumerMiddleware is removed') + } + } + + const middlewares = builder.build() + logger.silly( + 'middlewares', + middlewares.map((m) => m.name) + ) + + // 3. Create the wrapped SDK method with middlewares + const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares) + + // 4. Execute the wrapped method with the original params + const result = wrappedCompletionMethod(params, options) + return result + } + + public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise { + const traceName = params.assistant.model?.name + ? `${params.assistant.model?.name}.${params.callType}` + : `LLM.${params.callType}` + + const traceParams: StartSpanParams = { + name: traceName, + tag: 'LLM', + topicId: params.topicId || '', + modelName: params.assistant.model?.name + } + + return await withSpanResult(this.completions.bind(this), traceParams, params, options) + } + + public async models(): Promise { + return this.apiClient.listModels() + } + + public async getEmbeddingDimensions(model: Model): Promise { + try { + // Use the SDK instance to test embedding capabilities + if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') { + this.apiClient = this.apiClient.getClient(model) as BaseApiClient + } + const dimensions = await this.apiClient.getEmbeddingDimensions(model) + return dimensions + } catch (error) { + logger.error('Error getting embedding dimensions:', error as Error) + throw error + } + } + + public async generateImage(params: GenerateImageParams): Promise { + if (this.apiClient instanceof AihubmixAPIClient) { + const client = this.apiClient.getClientForModel({ id: params.model } as Model) + return client.generateImage(params) + } + return 
this.apiClient.generateImage(params) + } + + public getBaseURL(): string { + return this.apiClient.getBaseURL() + } + + public getApiKey(): string { + return this.apiClient.getApiKey() + } +} diff --git a/src/renderer/src/aiCore/middleware/BUILDER_USAGE.md b/src/renderer/src/aiCore/legacy/middleware/BUILDER_USAGE.md similarity index 100% rename from src/renderer/src/aiCore/middleware/BUILDER_USAGE.md rename to src/renderer/src/aiCore/legacy/middleware/BUILDER_USAGE.md diff --git a/src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md b/src/renderer/src/aiCore/legacy/middleware/MIDDLEWARE_SPECIFICATION.md similarity index 100% rename from src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md rename to src/renderer/src/aiCore/legacy/middleware/MIDDLEWARE_SPECIFICATION.md diff --git a/src/renderer/src/aiCore/middleware/__tests__/utils.test.ts b/src/renderer/src/aiCore/legacy/middleware/__tests__/utils.test.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/__tests__/utils.test.ts rename to src/renderer/src/aiCore/legacy/middleware/__tests__/utils.test.ts diff --git a/src/renderer/src/aiCore/middleware/builder.ts b/src/renderer/src/aiCore/legacy/middleware/builder.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/builder.ts rename to src/renderer/src/aiCore/legacy/middleware/builder.ts diff --git a/src/renderer/src/aiCore/middleware/common/AbortHandlerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/AbortHandlerMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/common/AbortHandlerMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/common/AbortHandlerMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/common/ErrorHandlerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/ErrorHandlerMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/common/ErrorHandlerMiddleware.ts rename to 
src/renderer/src/aiCore/legacy/middleware/common/ErrorHandlerMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/FinalChunkConsumerMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/common/FinalChunkConsumerMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/common/LoggingMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/LoggingMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/common/LoggingMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/common/LoggingMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/composer.ts b/src/renderer/src/aiCore/legacy/middleware/composer.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/composer.ts rename to src/renderer/src/aiCore/legacy/middleware/composer.ts diff --git a/src/renderer/src/aiCore/middleware/core/McpToolChunkMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/McpToolChunkMiddleware.ts similarity index 98% rename from src/renderer/src/aiCore/middleware/core/McpToolChunkMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/McpToolChunkMiddleware.ts index c0d85b0fde..fc0327925e 100644 --- a/src/renderer/src/aiCore/middleware/core/McpToolChunkMiddleware.ts +++ b/src/renderer/src/aiCore/legacy/middleware/core/McpToolChunkMiddleware.ts @@ -1,5 +1,5 @@ import { loggerService } from '@logger' -import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types' +import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model } from '@renderer/types' import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk' import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk' import { @@ -230,7 +230,7 @@ async function 
executeToolCalls( model: Model, topicId?: string ): Promise<{ toolResults: SdkMessageParam[]; confirmedToolCalls: SdkToolCall[] }> { - const mcpToolResponses: ToolCallResponse[] = toolCalls + const mcpToolResponses: MCPToolResponse[] = toolCalls .map((toolCall) => { const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools) if (!mcpTool) { @@ -238,7 +238,7 @@ async function executeToolCalls( } return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool) }) - .filter((t): t is ToolCallResponse => typeof t !== 'undefined') + .filter((t): t is MCPToolResponse => typeof t !== 'undefined') if (mcpToolResponses.length === 0) { logger.warn(`No valid MCP tool responses to execute`) diff --git a/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/RawStreamListenerMiddleware.ts similarity index 94% rename from src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/RawStreamListenerMiddleware.ts index fa936af479..0d59ad9de6 100644 --- a/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts +++ b/src/renderer/src/aiCore/legacy/middleware/core/RawStreamListenerMiddleware.ts @@ -1,4 +1,4 @@ -import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient' +import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient' import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk' import { AnthropicStreamListener } from '../../clients/types' diff --git a/src/renderer/src/aiCore/middleware/core/ResponseTransformMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/ResponseTransformMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/ResponseTransformMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/ResponseTransformMiddleware.ts diff 
--git a/src/renderer/src/aiCore/middleware/core/StreamAdapterMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/StreamAdapterMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/StreamAdapterMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/StreamAdapterMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/TextChunkMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/TextChunkMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/ThinkChunkMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/ThinkChunkMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/core/TransformCoreToSdkParamsMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/TransformCoreToSdkParamsMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/TransformCoreToSdkParamsMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/TransformCoreToSdkParamsMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/core/WebSearchMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/WebSearchMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/WebSearchMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/WebSearchMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts similarity index 98% rename from src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts rename to 
src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts index ceb8d791d7..0f89e8aca8 100644 --- a/src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts +++ b/src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts @@ -1,4 +1,3 @@ -import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' import { isDedicatedImageGenerationModel } from '@renderer/config/models' import FileManager from '@renderer/services/FileManager' import { ChunkType } from '@renderer/types/chunk' @@ -7,6 +6,7 @@ import { defaultTimeout } from '@shared/config/constant' import OpenAI from 'openai' import { toFile } from 'openai/uploads' +import { BaseApiClient } from '../../clients/BaseApiClient' import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' import { CompletionsContext, CompletionsMiddleware } from '../types' diff --git a/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/feat/ThinkingTagExtractionMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/feat/ThinkingTagExtractionMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/feat/ToolUseExtractionMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/feat/ToolUseExtractionMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/index.ts b/src/renderer/src/aiCore/legacy/middleware/index.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/index.ts rename to src/renderer/src/aiCore/legacy/middleware/index.ts diff --git a/src/renderer/src/aiCore/middleware/register.ts b/src/renderer/src/aiCore/legacy/middleware/register.ts similarity 
index 100% rename from src/renderer/src/aiCore/middleware/register.ts rename to src/renderer/src/aiCore/legacy/middleware/register.ts diff --git a/src/renderer/src/aiCore/middleware/schemas.ts b/src/renderer/src/aiCore/legacy/middleware/schemas.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/schemas.ts rename to src/renderer/src/aiCore/legacy/middleware/schemas.ts diff --git a/src/renderer/src/aiCore/middleware/types.ts b/src/renderer/src/aiCore/legacy/middleware/types.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/types.ts rename to src/renderer/src/aiCore/legacy/middleware/types.ts diff --git a/src/renderer/src/aiCore/middleware/utils.ts b/src/renderer/src/aiCore/legacy/middleware/utils.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/utils.ts rename to src/renderer/src/aiCore/legacy/middleware/utils.ts diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts new file mode 100644 index 0000000000..f331d36a7e --- /dev/null +++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts @@ -0,0 +1,210 @@ +import { loggerService } from '@logger' +import type { MCPTool, Message, Model, Provider } from '@renderer/types' +import type { Chunk } from '@renderer/types/chunk' +import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai' + +const logger = loggerService.withContext('AiSdkMiddlewareBuilder') + +/** + * AI SDK 中间件配置项 + */ +export interface AiSdkMiddlewareConfig { + streamOutput: boolean + onChunk?: (chunk: Chunk) => void + model?: Model + provider?: Provider + enableReasoning: boolean + // 是否开启提示词工具调用 + isPromptToolUse: boolean + // 是否支持工具调用 + isSupportedToolUse: boolean + // image generation endpoint + isImageGenerationEndpoint: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + mcpTools?: MCPTool[] + uiMessages?: Message[] +} + +/** + * 具名的 AI SDK 
中间件 + */ +export interface NamedAiSdkMiddleware { + name: string + middleware: LanguageModelMiddleware +} + +/** + * AI SDK 中间件建造者 + * 用于根据不同条件动态构建中间件数组 + */ +export class AiSdkMiddlewareBuilder { + private middlewares: NamedAiSdkMiddleware[] = [] + + /** + * 添加具名中间件 + */ + public add(namedMiddleware: NamedAiSdkMiddleware): this { + this.middlewares.push(namedMiddleware) + return this + } + + /** + * 在指定位置插入中间件 + */ + public insertAfter(targetName: string, middleware: NamedAiSdkMiddleware): this { + const index = this.middlewares.findIndex((m) => m.name === targetName) + if (index !== -1) { + this.middlewares.splice(index + 1, 0, middleware) + } else { + logger.warn(`AiSdkMiddlewareBuilder: Middleware named '${targetName}' not found, cannot insert`) + } + return this + } + + /** + * 检查是否包含指定名称的中间件 + */ + public has(name: string): boolean { + return this.middlewares.some((m) => m.name === name) + } + + /** + * 移除指定名称的中间件 + */ + public remove(name: string): this { + this.middlewares = this.middlewares.filter((m) => m.name !== name) + return this + } + + /** + * 构建最终的中间件数组 + */ + public build(): LanguageModelMiddleware[] { + return this.middlewares.map((m) => m.middleware) + } + + /** + * 获取具名中间件数组(用于调试) + */ + public buildNamed(): NamedAiSdkMiddleware[] { + return [...this.middlewares] + } + + /** + * 清空所有中间件 + */ + public clear(): this { + this.middlewares = [] + return this + } + + /** + * 获取中间件总数 + */ + public get length(): number { + return this.middlewares.length + } +} + +/** + * 根据配置构建AI SDK中间件的工厂函数 + * 这里要注意构建顺序,因为有些中间件需要依赖其他中间件的结果 + */ +export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] { + const builder = new AiSdkMiddlewareBuilder() + + // 1. 根据provider添加特定中间件 + if (config.provider) { + addProviderSpecificMiddlewares(builder, config) + } + + // 2. 根据模型类型添加特定中间件 + if (config.model) { + addModelSpecificMiddlewares(builder, config) + } + + // 3. 
非流式输出时添加模拟流中间件 + if (config.streamOutput === false) { + builder.add({ + name: 'simulate-streaming', + middleware: simulateStreamingMiddleware() + }) + } + + logger.info('builder.build()', builder.buildNamed()) + return builder.build() +} + +const tagNameArray = ['think', 'thought'] + +/** + * 添加provider特定的中间件 + */ +function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void { + if (!config.provider) return + + // 根据不同provider添加特定中间件 + switch (config.provider.type) { + case 'anthropic': + // Anthropic特定中间件 + break + case 'openai': + case 'azure-openai': { + if (config.enableReasoning) { + const tagName = config.model?.id.includes('gemini') ? tagNameArray[1] : tagNameArray[0] + builder.add({ + name: 'thinking-tag-extraction', + middleware: extractReasoningMiddleware({ tagName }) + }) + } + break + } + case 'gemini': + // Gemini特定中间件 + break + default: + // 其他provider的通用处理 + break + } +} + +/** + * 添加模型特定的中间件 + */ +function addModelSpecificMiddlewares(_: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void { + if (!config.model) return + + // 可以根据模型ID或特性添加特定中间件 + // 例如:图像生成模型、多模态模型等 + + // 示例:某些模型需要特殊处理 + if (config.model.id.includes('dalle') || config.model.id.includes('midjourney')) { + // 图像生成相关中间件 + } +} + +/** + * 创建一个预配置的中间件建造者 + */ +export function createAiSdkMiddlewareBuilder(): AiSdkMiddlewareBuilder { + return new AiSdkMiddlewareBuilder() +} + +/** + * 创建一个带有默认中间件的建造者 + */ +export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfig): AiSdkMiddlewareBuilder { + const builder = new AiSdkMiddlewareBuilder() + const defaultMiddlewares = buildAiSdkMiddlewares(config) + + // 将普通中间件数组转换为具名中间件并添加 + defaultMiddlewares.forEach((middleware, index) => { + builder.add({ + name: `default-middleware-${index}`, + middleware + }) + }) + + return builder +} diff --git a/src/renderer/src/aiCore/middleware/README.md b/src/renderer/src/aiCore/middleware/README.md new file mode 100644 index 
0000000000..7731d263c3 --- /dev/null +++ b/src/renderer/src/aiCore/middleware/README.md @@ -0,0 +1,140 @@ +# AI SDK 中间件建造者 + +## 概述 + +`AiSdkMiddlewareBuilder` 是一个用于动态构建 AI SDK 中间件数组的建造者模式实现。它可以根据不同的条件(如流式输出、思考模型、provider类型等)自动构建合适的中间件组合。 + +## 使用方式 + +### 基本用法 + +```typescript +import { buildAiSdkMiddlewares, type AiSdkMiddlewareConfig } from './AiSdkMiddlewareBuilder' + +// 配置中间件参数 +const config: AiSdkMiddlewareConfig = { + streamOutput: false, // 非流式输出 + onChunk: chunkHandler, // chunk回调函数 + model: currentModel, // 当前模型 + provider: currentProvider, // 当前provider + enableReasoning: true, // 启用推理 + enableTool: false, // 禁用工具 + enableWebSearch: false // 禁用网页搜索 +} + +// 构建中间件数组 +const middlewares = buildAiSdkMiddlewares(config) + +// 创建带有中间件的客户端 +const client = createClient(providerId, options, middlewares) +``` + +### 手动构建 + +```typescript +import { AiSdkMiddlewareBuilder, createAiSdkMiddlewareBuilder } from './AiSdkMiddlewareBuilder' + +const builder = createAiSdkMiddlewareBuilder() + +// 添加特定中间件 +builder.add({ + name: 'custom-middleware', + aiSdkMiddlewares: [customMiddleware()] +}) + +// 检查是否包含某个中间件 +if (builder.has('thinking-time')) { + console.log('已包含思考时间中间件') +} + +// 移除不需要的中间件 +builder.remove('simulate-streaming') + +// 构建最终数组 +const middlewares = builder.build() +``` + +## 支持的条件 + +### 1. 流式输出控制 + +- **streamOutput = false**: 自动添加 `simulateStreamingMiddleware` +- **streamOutput = true**: 使用原生流式处理 + +### 2. 思考模型处理 + +- **条件**: `onChunk` 存在 && `isReasoningModel(model)` 为 true +- **效果**: 自动添加 `thinkingTimeMiddleware` + +### 3. Provider 特定中间件 + +根据不同的 provider 类型添加特定中间件: + +- **anthropic**: Anthropic 特定处理 +- **openai**: OpenAI 特定处理 +- **gemini**: Gemini 特定处理 + +### 4. 
模型特定中间件 + +根据模型特性添加中间件: + +- **图像生成模型**: 添加图像处理相关中间件 +- **多模态模型**: 添加多模态处理中间件 + +## 扩展指南 + +### 添加新的条件判断 + +在 `buildAiSdkMiddlewares` 函数中添加新的条件: + +```typescript +// 例如:添加缓存中间件 +if (config.enableCache) { + builder.add({ + name: 'cache', + aiSdkMiddlewares: [cacheMiddleware(config.cacheOptions)] + }) +} +``` + +### 添加 Provider 特定处理 + +在 `addProviderSpecificMiddlewares` 函数中添加: + +```typescript +case 'custom-provider': + builder.add({ + name: 'custom-provider-middleware', + aiSdkMiddlewares: [customProviderMiddleware()] + }) + break +``` + +### 添加模型特定处理 + +在 `addModelSpecificMiddlewares` 函数中添加: + +```typescript +if (config.model.id.includes('custom-model')) { + builder.add({ + name: 'custom-model-middleware', + aiSdkMiddlewares: [customModelMiddleware()] + }) +} +``` + +## 中间件执行顺序 + +中间件按照添加顺序执行: + +1. **simulate-streaming** (如果 streamOutput = false) +2. **thinking-time** (如果是思考模型且有 onChunk) +3. **provider-specific** (根据 provider 类型) +4. **model-specific** (根据模型类型) + +## 注意事项 + +1. 中间件的执行顺序很重要,确保按正确顺序添加 +2. 避免添加冲突的中间件 +3. 某些中间件可能有依赖关系,需要确保依赖的中间件先添加 +4. 
建议在开发环境下启用日志,以便调试中间件构建过程 diff --git a/src/renderer/src/aiCore/plugins/PluginBuilder.ts b/src/renderer/src/aiCore/plugins/PluginBuilder.ts new file mode 100644 index 0000000000..792e05b240 --- /dev/null +++ b/src/renderer/src/aiCore/plugins/PluginBuilder.ts @@ -0,0 +1,80 @@ +import { AiPlugin } from '@cherrystudio/ai-core' +import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins' +import { loggerService } from '@logger' +import { getEnableDeveloperMode } from '@renderer/hooks/useSettings' +import { Assistant } from '@renderer/types' + +import { AiSdkMiddlewareConfig } from '../middleware/AiSdkMiddlewareBuilder' +import reasoningTimePlugin from './reasoningTimePlugin' +import { searchOrchestrationPlugin } from './searchOrchestrationPlugin' +import { createTelemetryPlugin } from './telemetryPlugin' + +const logger = loggerService.withContext('PluginBuilder') +/** + * 根据条件构建插件数组 + */ +export function buildPlugins( + middlewareConfig: AiSdkMiddlewareConfig & { assistant: Assistant; topicId?: string } +): AiPlugin[] { + const plugins: AiPlugin[] = [] + + if (middlewareConfig.topicId && getEnableDeveloperMode()) { + // 0. 添加 telemetry 插件 + plugins.push( + createTelemetryPlugin({ + enabled: true, + topicId: middlewareConfig.topicId, + assistant: middlewareConfig.assistant + }) + ) + } + + // 1. 模型内置搜索 + if (middlewareConfig.enableWebSearch) { + // 内置了默认搜索参数,如果改的话可以传config进去 + plugins.push(webSearchPlugin()) + } + // 2. 支持工具调用时添加搜索插件 + if (middlewareConfig.isSupportedToolUse) { + plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant, middlewareConfig.topicId || '')) + } + + // 3. 推理模型时添加推理插件 + if (middlewareConfig.enableReasoning) { + plugins.push(reasoningTimePlugin) + } + + // 4. 
启用Prompt工具调用时添加工具插件 + if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) { + plugins.push( + createPromptToolUsePlugin({ + enabled: true, + createSystemMessage: (systemPrompt, params, context) => { + if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) { + if (context.isRecursiveCall) { + return null + } + params.messages = [ + { + role: 'assistant', + content: systemPrompt + }, + ...params.messages + ] + return null + } + return systemPrompt + } + }) + ) + } + + // if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) { + // plugins.push(createNativeToolUsePlugin()) + // } + logger.info( + 'Final plugin list:', + plugins.map((p) => p.name) + ) + return plugins +} diff --git a/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts b/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts new file mode 100644 index 0000000000..1fe0a177c3 --- /dev/null +++ b/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts @@ -0,0 +1,56 @@ +import { definePlugin } from '@cherrystudio/ai-core' +import type { TextStreamPart, ToolSet } from 'ai' + +export default definePlugin({ + name: 'reasoningTimePlugin', + + transformStream: () => () => { + // === 时间跟踪状态 === + let thinkingStartTime = 0 + let hasStartedThinking = false + let accumulatedThinkingContent = '' + let reasoningBlockId = '' + + return new TransformStream, TextStreamPart>({ + transform(chunk: TextStreamPart, controller: TransformStreamDefaultController>) { + // === 处理 reasoning 类型 === + if (chunk.type === 'reasoning-start') { + controller.enqueue(chunk) + hasStartedThinking = true + thinkingStartTime = performance.now() + reasoningBlockId = chunk.id + } else if (chunk.type === 'reasoning-delta') { + accumulatedThinkingContent += chunk.text + controller.enqueue({ + ...chunk, + providerMetadata: { + ...chunk.providerMetadata, + metadata: { + ...chunk.providerMetadata?.metadata, + 
thinking_millsec: performance.now() - thinkingStartTime, + thinking_content: accumulatedThinkingContent + } + } + }) + } else if (chunk.type === 'reasoning-end' && hasStartedThinking) { + controller.enqueue({ + type: 'reasoning-end', + id: reasoningBlockId, + providerMetadata: { + metadata: { + thinking_millsec: performance.now() - thinkingStartTime, + thinking_content: accumulatedThinkingContent + } + } + }) + accumulatedThinkingContent = '' + hasStartedThinking = false + thinkingStartTime = 0 + reasoningBlockId = '' + } else { + controller.enqueue(chunk) + } + } + }) + } +}) diff --git a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts new file mode 100644 index 0000000000..3e6ca8cc3d --- /dev/null +++ b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts @@ -0,0 +1,431 @@ +/** + * 搜索编排插件 + * + * 功能: + * 1. onRequestStart: 智能意图识别 - 分析是否需要网络搜索、知识库搜索、记忆搜索 + * 2. transformParams: 根据意图分析结果动态添加对应的工具 + * 3. onRequestEnd: 自动记忆存储 + */ +import { type AiRequestContext, definePlugin } from '@cherrystudio/ai-core' +import { loggerService } from '@logger' +// import { generateObject } from '@cherrystudio/ai-core' +import { + SEARCH_SUMMARY_PROMPT, + SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY, + SEARCH_SUMMARY_PROMPT_WEB_ONLY +} from '@renderer/config/prompts' +import { getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService' +import store from '@renderer/store' +import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory' +import type { Assistant } from '@renderer/types' +import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract' +import type { ModelMessage } from 'ai' +import { isEmpty } from 'lodash' + +import { MemoryProcessor } from '../../services/MemoryProcessor' +import { knowledgeSearchTool } from '../tools/KnowledgeSearchTool' +import { memorySearchTool } from '../tools/MemorySearchTool' +import { 
webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool' + +const logger = loggerService.withContext('SearchOrchestrationPlugin') + +const getMessageContent = (message: ModelMessage) => { + if (typeof message.content === 'string') return message.content + return message.content.reduce((acc, part) => { + if (part.type === 'text') { + return acc + part.text + '\n' + } + return acc + }, '') +} + +// === Schema Definitions === + +// const WebSearchSchema = z.object({ +// question: z +// .array(z.string()) +// .describe('Search queries for web search. Use "not_needed" if no web search is required.'), +// links: z.array(z.string()).optional().describe('Specific URLs to search or summarize if mentioned in the query.') +// }) + +// const KnowledgeSearchSchema = z.object({ +// question: z +// .array(z.string()) +// .describe('Search queries for knowledge base. Use "not_needed" if no knowledge search is required.'), +// rewrite: z +// .string() +// .describe('Rewritten query with alternative phrasing while preserving original intent and meaning.') +// }) + +// const SearchIntentAnalysisSchema = z.object({ +// websearch: WebSearchSchema.optional().describe('Web search intent analysis results.'), +// knowledge: KnowledgeSearchSchema.optional().describe('Knowledge base search intent analysis results.') +// }) + +// type SearchIntentResult = z.infer + +// let isAnalyzing = false +/** + * 🧠 意图分析函数 - 使用 XML 解析 + */ +async function analyzeSearchIntent( + lastUserMessage: ModelMessage, + assistant: Assistant, + options: { + shouldWebSearch?: boolean + shouldKnowledgeSearch?: boolean + shouldMemorySearch?: boolean + lastAnswer?: ModelMessage + context: AiRequestContext & { + isAnalyzing?: boolean + } + topicId: string + } +): Promise { + const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer, context } = options + + if (!lastUserMessage) return undefined + + // 根据配置决定是否需要提取 + const needWebExtract = shouldWebSearch + const needKnowledgeExtract = 
shouldKnowledgeSearch + + if (!needWebExtract && !needKnowledgeExtract) return undefined + + // 选择合适的提示词 + let prompt: string + // let schema: z.Schema + + if (needWebExtract && !needKnowledgeExtract) { + prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY + // schema = z.object({ websearch: WebSearchSchema }) + } else if (!needWebExtract && needKnowledgeExtract) { + prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY + // schema = z.object({ knowledge: KnowledgeSearchSchema }) + } else { + prompt = SEARCH_SUMMARY_PROMPT + // schema = SearchIntentAnalysisSchema + } + + // 构建消息上下文 - 简化逻辑 + const chatHistory = lastAnswer ? `assistant: ${getMessageContent(lastAnswer)}` : '' + const question = getMessageContent(lastUserMessage) || '' + + // 使用模板替换变量 + const formattedPrompt = prompt.replace('{chat_history}', chatHistory).replace('{question}', question) + + // 获取模型和provider信息 + const model = assistant.model || getDefaultModel() + const provider = getProviderByModel(model) + + if (!provider || isEmpty(provider.apiKey)) { + logger.error('Provider not found or missing API key') + return getFallbackResult() + } + // console.log('formattedPrompt', schema) + try { + context.isAnalyzing = true + logger.info('Starting intent analysis generateText call', { + modelId: model.id, + topicId: options.topicId, + requestId: context.requestId, + hasWebSearch: needWebExtract, + hasKnowledgeSearch: needKnowledgeExtract + }) + + const { text: result } = await context.executor + .generateText(model.id, { + prompt: formattedPrompt + }) + .finally(() => { + context.isAnalyzing = false + logger.info('Intent analysis generateText call completed', { + modelId: model.id, + topicId: options.topicId, + requestId: context.requestId + }) + }) + const parsedResult = extractInfoFromXML(result) + logger.debug('Intent analysis result', { parsedResult }) + + // 根据需求过滤结果 + return { + websearch: needWebExtract ? parsedResult?.websearch : undefined, + knowledge: needKnowledgeExtract ? 
parsedResult?.knowledge : undefined + } + } catch (e: any) { + logger.error('Intent analysis failed', e as Error) + return getFallbackResult() + } + + function getFallbackResult(): ExtractResults { + const fallbackContent = getMessageContent(lastUserMessage) + return { + websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined, + knowledge: shouldKnowledgeSearch + ? { + question: [fallbackContent || 'search'], + rewrite: fallbackContent || 'search' + } + : undefined + } + } +} + +/** + * 🧠 记忆存储函数 - 基于注释代码中的 processConversationMemory + */ +async function storeConversationMemory( + messages: ModelMessage[], + assistant: Assistant, + context: AiRequestContext +): Promise { + const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) + + if (!globalMemoryEnabled || !assistant.enableMemory) { + // console.log('Memory storage is disabled') + return + } + + try { + const memoryConfig = selectMemoryConfig(store.getState()) + + // 转换消息为记忆处理器期望的格式 + const conversationMessages = messages + .filter((msg) => msg.role === 'user' || msg.role === 'assistant') + .map((msg) => ({ + role: msg.role, + content: getMessageContent(msg) || '' + })) + .filter((msg) => msg.content.trim().length > 0) + logger.debug('conversationMessages', conversationMessages) + if (conversationMessages.length < 2) { + logger.info('Need at least a user message and assistant response for memory processing') + return + } + + const currentUserId = selectCurrentUserId(store.getState()) + // const lastUserMessage = messages.findLast((m) => m.role === 'user') + + const processorConfig = MemoryProcessor.getProcessorConfig( + memoryConfig, + assistant.id, + currentUserId, + context.requestId + ) + + logger.info('Processing conversation memory...', { messageCount: conversationMessages.length }) + + // 后台处理对话记忆(不阻塞 UI) + const memoryProcessor = new MemoryProcessor() + memoryProcessor + .processConversation(conversationMessages, processorConfig) + .then((result) => { + 
logger.info('Memory processing completed:', result) + if (result.facts?.length > 0) { + logger.info('Extracted facts from conversation:', result.facts) + logger.info('Memory operations performed:', result.operations) + } else { + logger.info('No facts extracted from conversation') + } + }) + .catch((error) => { + logger.error('Background memory processing failed:', error as Error) + }) + } catch (error) { + logger.error('Error in conversation memory processing:', error as Error) + // 不抛出错误,避免影响主流程 + } +} + +/** + * 🎯 搜索编排插件 + */ +export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) => { + // 存储意图分析结果 + const intentAnalysisResults: { [requestId: string]: ExtractResults } = {} + const userMessages: { [requestId: string]: ModelMessage } = {} + let currentContext: AiRequestContext | null = null + + return definePlugin({ + name: 'search-orchestration', + enforce: 'pre', // 确保在其他插件之前执行 + + configureContext: (context: AiRequestContext) => { + if (currentContext) { + context.isAnalyzing = currentContext.isAnalyzing + } + currentContext = context + }, + + /** + * 🔍 Step 1: 意图识别阶段 + */ + onRequestStart: async (context: AiRequestContext) => { + if (context.isAnalyzing) return + + // 没开启任何搜索则不进行意图分析 + if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return + + try { + const messages = context.originalParams.messages + if (!messages || messages.length === 0) { + return + } + + const lastUserMessage = messages[messages.length - 1] + const lastAssistantMessage = messages.length >= 2 ? 
messages[messages.length - 2] : undefined + + // 存储用户消息用于后续记忆存储 + userMessages[context.requestId] = lastUserMessage + + // 判断是否需要各种搜索 + const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) + const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) + const knowledgeRecognition = assistant.knowledgeRecognition || 'on' + const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) + + const shouldWebSearch = !!assistant.webSearchProviderId + const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on' + const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory + + // 执行意图分析 + if (shouldWebSearch || hasKnowledgeBase) { + const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, { + shouldWebSearch, + shouldKnowledgeSearch, + shouldMemorySearch, + lastAnswer: lastAssistantMessage, + context, + topicId + }) + + if (analysisResult) { + intentAnalysisResults[context.requestId] = analysisResult + // logger.info('🧠 Intent analysis completed:', analysisResult) + } + } + } catch (error) { + logger.error('🧠 Intent analysis failed:', error as Error) + // 不抛出错误,让流程继续 + } + }, + + /** + * 🔧 Step 2: 工具配置阶段 + */ + transformParams: async (params: any, context: AiRequestContext) => { + if (context.isAnalyzing) return params + // logger.info('🔧 Configuring tools based on intent...', context.requestId) + + try { + const analysisResult = intentAnalysisResults[context.requestId] + // if (!analysisResult || !assistant) { + // logger.info('🔧 No analysis result or assistant, skipping tool configuration') + // return params + // } + + // 确保 tools 对象存在 + if (!params.tools) { + params.tools = {} + } + + // 🌐 网络搜索工具配置 + if (analysisResult?.websearch && assistant.webSearchProviderId) { + const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed' + + if (needsSearch) { + // onChunk({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS }) + // logger.info('🌐 Adding web search 
tool with pre-extracted keywords') + params.tools['builtin_web_search'] = webSearchToolWithPreExtractedKeywords( + assistant.webSearchProviderId, + analysisResult.websearch, + context.requestId + ) + } + } + + // 📚 知识库搜索工具配置 + const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) + const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) + const knowledgeRecognition = assistant.knowledgeRecognition || 'on' + + if (hasKnowledgeBase) { + if (knowledgeRecognition === 'off') { + // off 模式:直接添加知识库搜索工具,使用用户消息作为搜索关键词 + const userMessage = userMessages[context.requestId] + const fallbackKeywords = { + question: [getMessageContent(userMessage) || 'search'], + rewrite: getMessageContent(userMessage) || 'search' + } + // logger.info('📚 Adding knowledge search tool (force mode)') + params.tools['builtin_knowledge_search'] = knowledgeSearchTool( + assistant, + fallbackKeywords, + getMessageContent(userMessage), + topicId + ) + // params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' } + } else { + // on 模式:根据意图识别结果决定是否添加工具 + const needsKnowledgeSearch = + analysisResult?.knowledge && + analysisResult.knowledge.question && + analysisResult.knowledge.question[0] !== 'not_needed' + + if (needsKnowledgeSearch && analysisResult.knowledge) { + // logger.info('📚 Adding knowledge search tool (intent-based)') + const userMessage = userMessages[context.requestId] + params.tools['builtin_knowledge_search'] = knowledgeSearchTool( + assistant, + analysisResult.knowledge, + getMessageContent(userMessage), + topicId + ) + } + } + } + + // 🧠 记忆搜索工具配置 + const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) + if (globalMemoryEnabled && assistant.enableMemory) { + // logger.info('🧠 Adding memory search tool') + params.tools['builtin_memory_search'] = memorySearchTool() + } + + // logger.info('🔧 Tools configured:', Object.keys(params.tools)) + return params + } catch (error) { + logger.error('🔧 Tool configuration failed:', error as Error) + 
return params + } + }, + + /** + * 💾 Step 3: 记忆存储阶段 + */ + + onRequestEnd: async (context: AiRequestContext) => { + // context.isAnalyzing = false + // logger.info('context.isAnalyzing', context, result) + // logger.info('💾 Starting memory storage...', context.requestId) + if (context.isAnalyzing) return + try { + const messages = context.originalParams.messages + + if (messages && assistant) { + await storeConversationMemory(messages, assistant, context) + } + + // 清理缓存 + delete intentAnalysisResults[context.requestId] + delete userMessages[context.requestId] + } catch (error) { + logger.error('💾 Memory storage failed:', error as Error) + // 不抛出错误,避免影响主流程 + } + } + }) +} + +export default searchOrchestrationPlugin diff --git a/src/renderer/src/aiCore/plugins/telemetryPlugin.ts b/src/renderer/src/aiCore/plugins/telemetryPlugin.ts new file mode 100644 index 0000000000..2083f8a098 --- /dev/null +++ b/src/renderer/src/aiCore/plugins/telemetryPlugin.ts @@ -0,0 +1,422 @@ +/** + * Telemetry Plugin for AI SDK Integration + * + * 在 transformParams 钩子中注入 experimental_telemetry 参数, + * 实现 AI SDK trace 与现有手动 trace 系统的统一 + * 集成 AiSdkSpanAdapter 将 AI SDK trace 数据转换为现有格式 + */ + +import { definePlugin } from '@cherrystudio/ai-core' +import { loggerService } from '@logger' +import { Context, context as otelContext, Span, SpanContext, trace, Tracer } from '@opentelemetry/api' +import { currentSpan } from '@renderer/services/SpanManagerService' +import { webTraceService } from '@renderer/services/WebTraceService' +import { Assistant } from '@renderer/types' + +import { AiSdkSpanAdapter } from '../trace/AiSdkSpanAdapter' + +const logger = loggerService.withContext('TelemetryPlugin') + +export interface TelemetryPluginConfig { + enabled?: boolean + recordInputs?: boolean + recordOutputs?: boolean + topicId: string + assistant: Assistant +} + +/** + * 自定义 Tracer,集成适配器转换逻辑 + */ +class AdapterTracer { + private originalTracer: Tracer + private topicId?: string + private modelName?: 
string + private parentSpanContext?: SpanContext + private cachedParentContext?: Context + + constructor(originalTracer: Tracer, topicId?: string, modelName?: string, parentSpanContext?: SpanContext) { + this.originalTracer = originalTracer + this.topicId = topicId + this.modelName = modelName + this.parentSpanContext = parentSpanContext + // 预构建一个包含父 SpanContext 的 Context,便于复用 + try { + this.cachedParentContext = this.parentSpanContext + ? trace.setSpanContext(otelContext.active(), this.parentSpanContext) + : undefined + } catch { + this.cachedParentContext = undefined + } + + logger.info('AdapterTracer created with parent context info', { + topicId, + modelName, + parentTraceId: this.parentSpanContext?.traceId, + parentSpanId: this.parentSpanContext?.spanId, + hasOriginalTracer: !!originalTracer + }) + } + + // startSpan(name: string, options?: any, context?: any): Span { + // // 如果提供了父 SpanContext 且未显式传入 context,则使用父上下文 + // const contextToUse = context ?? this.cachedParentContext ?? otelContext.active() + + // const span = this.originalTracer.startSpan(name, options, contextToUse) + + // // 标记父子关系,便于在转换阶段兜底重建层级 + // try { + // if (this.parentSpanContext) { + // span.setAttribute('trace.parentSpanId', this.parentSpanContext.spanId) + // span.setAttribute('trace.parentTraceId', this.parentSpanContext.traceId) + // } + // if (this.topicId) { + // span.setAttribute('trace.topicId', this.topicId) + // } + // } catch (e) { + // logger.debug('Failed to set trace parent attributes', e as Error) + // } + + // logger.info('AI SDK span created via AdapterTracer', { + // spanName: name, + // spanId: span.spanContext().spanId, + // traceId: span.spanContext().traceId, + // parentTraceId: this.parentSpanContext?.traceId, + // topicId: this.topicId, + // modelName: this.modelName, + // traceIdMatches: this.parentSpanContext ? 
span.spanContext().traceId === this.parentSpanContext.traceId : undefined + // }) + + // // 包装 span 的 end 方法,在结束时进行数据转换 + // const originalEnd = span.end.bind(span) + // span.end = (endTime?: any) => { + // logger.info('AI SDK span.end() called - about to convert span', { + // spanName: name, + // spanId: span.spanContext().spanId, + // traceId: span.spanContext().traceId, + // topicId: this.topicId, + // modelName: this.modelName + // }) + + // // 调用原始 end 方法 + // originalEnd(endTime) + + // // 转换并保存 span 数据 + // try { + // logger.info('Converting AI SDK span to SpanEntity', { + // spanName: name, + // spanId: span.spanContext().spanId, + // traceId: span.spanContext().traceId, + // topicId: this.topicId, + // modelName: this.modelName + // }) + // logger.info('spanspanspanspanspanspan', span) + // const spanEntity = AiSdkSpanAdapter.convertToSpanEntity({ + // span, + // topicId: this.topicId, + // modelName: this.modelName + // }) + + // // 保存转换后的数据 + // window.api.trace.saveEntity(spanEntity) + + // logger.info('AI SDK span converted and saved successfully', { + // spanName: name, + // spanId: span.spanContext().spanId, + // traceId: span.spanContext().traceId, + // topicId: this.topicId, + // modelName: this.modelName, + // hasUsage: !!spanEntity.usage, + // usage: spanEntity.usage + // }) + // } catch (error) { + // logger.error('Failed to convert AI SDK span', error as Error, { + // spanName: name, + // spanId: span.spanContext().spanId, + // traceId: span.spanContext().traceId, + // topicId: this.topicId, + // modelName: this.modelName + // }) + // } + // } + + // return span + // } + + startActiveSpan any>(name: string, fn: F): ReturnType + startActiveSpan any>(name: string, options: any, fn: F): ReturnType + startActiveSpan any>(name: string, options: any, context: any, fn: F): ReturnType + startActiveSpan any>(name: string, arg2?: any, arg3?: any, arg4?: any): ReturnType { + logger.info('AdapterTracer.startActiveSpan called', { + spanName: name, + 
topicId: this.topicId, + modelName: this.modelName, + argCount: arguments.length + }) + + // 包装函数来添加span转换逻辑 + const wrapFunction = (originalFn: F, span: Span): F => { + const wrappedFn = ((passedSpan: Span) => { + // 注入父子关系属性(兜底重建层级用) + try { + if (this.parentSpanContext) { + passedSpan.setAttribute('trace.parentSpanId', this.parentSpanContext.spanId) + passedSpan.setAttribute('trace.parentTraceId', this.parentSpanContext.traceId) + } + if (this.topicId) { + passedSpan.setAttribute('trace.topicId', this.topicId) + } + } catch (e) { + logger.debug('Failed to set trace parent attributes in startActiveSpan', e as Error) + } + // 包装span的end方法 + const originalEnd = span.end.bind(span) + span.end = (endTime?: any) => { + logger.info('AI SDK span.end() called in startActiveSpan - about to convert span', { + spanName: name, + spanId: span.spanContext().spanId, + traceId: span.spanContext().traceId, + topicId: this.topicId, + modelName: this.modelName + }) + + // 调用原始 end 方法 + originalEnd(endTime) + + // 转换并保存 span 数据 + try { + logger.info('Converting AI SDK span to SpanEntity (from startActiveSpan)', { + spanName: name, + spanId: span.spanContext().spanId, + traceId: span.spanContext().traceId, + topicId: this.topicId, + modelName: this.modelName + }) + logger.info('span', span) + const spanEntity = AiSdkSpanAdapter.convertToSpanEntity({ + span, + topicId: this.topicId, + modelName: this.modelName + }) + + // 保存转换后的数据 + window.api.trace.saveEntity(spanEntity) + + logger.info('AI SDK span converted and saved successfully (from startActiveSpan)', { + spanName: name, + spanId: span.spanContext().spanId, + traceId: span.spanContext().traceId, + topicId: this.topicId, + modelName: this.modelName, + hasUsage: !!spanEntity.usage, + usage: spanEntity.usage + }) + } catch (error) { + logger.error('Failed to convert AI SDK span (from startActiveSpan)', error as Error, { + spanName: name, + spanId: span.spanContext().spanId, + traceId: span.spanContext().traceId, + topicId: 
this.topicId, + modelName: this.modelName + }) + } + } + + return originalFn(passedSpan) + }) as F + return wrappedFn + } + + // 创建包含父 SpanContext 的上下文(如果有的话) + const createContextWithParent = () => { + if (this.cachedParentContext) { + return this.cachedParentContext + } + if (this.parentSpanContext) { + try { + const ctx = trace.setSpanContext(otelContext.active(), this.parentSpanContext) + logger.info('Created active context with parent SpanContext for startActiveSpan', { + spanName: name, + parentTraceId: this.parentSpanContext.traceId, + parentSpanId: this.parentSpanContext.spanId, + topicId: this.topicId + }) + return ctx + } catch (error) { + logger.warn('Failed to create context with parent SpanContext in startActiveSpan', error as Error) + } + } + return otelContext.active() + } + + // 根据参数数量确定调用方式,注入包含mainTraceId的上下文 + if (typeof arg2 === 'function') { + return this.originalTracer.startActiveSpan(name, {}, createContextWithParent(), (span: Span) => { + return wrapFunction(arg2, span)(span) + }) + } else if (typeof arg3 === 'function') { + return this.originalTracer.startActiveSpan(name, arg2, createContextWithParent(), (span: Span) => { + return wrapFunction(arg3, span)(span) + }) + } else if (typeof arg4 === 'function') { + // 如果调用方提供了 context,则保留以维护嵌套关系;否则回退到父上下文 + const ctx = arg3 ?? 
createContextWithParent() + return this.originalTracer.startActiveSpan(name, arg2, ctx, (span: Span) => { + return wrapFunction(arg4, span)(span) + }) + } else { + throw new Error('Invalid arguments for startActiveSpan') + } + } +} + +export function createTelemetryPlugin(config: TelemetryPluginConfig) { + const { enabled = true, recordInputs = true, recordOutputs = true, topicId } = config + + return definePlugin({ + name: 'telemetryPlugin', + enforce: 'pre', // 在其他插件之前执行,确保 telemetry 配置被正确注入 + + transformParams: (params, context) => { + if (!enabled) { + return params + } + + // 获取共享的 tracer + const originalTracer = webTraceService.getTracer() + if (!originalTracer) { + logger.warn('No tracer available from WebTraceService') + return params + } + + // 获取topicId和modelName + const effectiveTopicId = context.topicId || topicId + // 使用与父span创建时一致的modelName - 应该是完整的modelId + const modelName = config.assistant.model?.name || context.modelId + + // 获取当前活跃的 span,确保 AI SDK spans 与手动 spans 在同一个 trace 中 + let parentSpan: Span | undefined = undefined + let parentSpanContext: SpanContext | undefined = undefined + + // 只有在有topicId时才尝试查找父span + if (effectiveTopicId) { + try { + // 从 SpanManagerService 获取当前的 span + logger.info('Attempting to find parent span', { + topicId: effectiveTopicId, + requestId: context.requestId, + modelName: modelName, + contextModelId: context.modelId, + providerId: context.providerId + }) + + parentSpan = currentSpan(effectiveTopicId, modelName) + if (parentSpan) { + // 直接使用父 span 的 SpanContext,避免手动拼装字段遗漏 + parentSpanContext = parentSpan.spanContext() + logger.info('Found active parent span for AI SDK', { + parentSpanId: parentSpanContext.spanId, + parentTraceId: parentSpanContext.traceId, + topicId: effectiveTopicId, + requestId: context.requestId, + modelName: modelName + }) + } else { + logger.warn('No active parent span found in SpanManagerService', { + topicId: effectiveTopicId, + requestId: context.requestId, + modelId: context.modelId, + 
modelName: modelName, + providerId: context.providerId, + // 更详细的调试信息 + searchedModelName: modelName, + contextModelId: context.modelId, + isAnalyzing: context.isAnalyzing + }) + } + } catch (error) { + logger.error('Error getting current span from SpanManagerService', error as Error, { + topicId: effectiveTopicId, + requestId: context.requestId, + modelName: modelName + }) + } + } else { + logger.debug('No topicId provided, skipping parent span lookup', { + requestId: context.requestId, + contextTopicId: context.topicId, + configTopicId: topicId, + modelName: modelName + }) + } + + // 创建适配器包装的 tracer,传入获取到的父 SpanContext + const adapterTracer = new AdapterTracer(originalTracer, effectiveTopicId, modelName, parentSpanContext) + + // 注入 AI SDK telemetry 配置 + const telemetryConfig = { + isEnabled: true, + recordInputs, + recordOutputs, + tracer: adapterTracer, // 使用包装后的 tracer + functionId: `ai-request-${context.requestId}`, + metadata: { + providerId: context.providerId, + modelId: context.modelId, + topicId: effectiveTopicId, + requestId: context.requestId, + modelName: modelName, + // 确保topicId也作为标准属性传递 + 'trace.topicId': effectiveTopicId, + 'trace.modelName': modelName, + // 添加父span信息用于调试 + parentSpanId: parentSpanContext?.spanId, + parentTraceId: parentSpanContext?.traceId + } + } + + // 如果有父span,尝试在telemetry配置中设置父上下文 + if (parentSpan) { + try { + // 设置活跃上下文,确保 AI SDK spans 在正确的 trace 上下文中创建 + const activeContext = trace.setSpan(otelContext.active(), parentSpan) + + // 更新全局上下文 + otelContext.with(activeContext, () => { + logger.debug('Updated active context with parent span') + }) + + logger.info('Set parent context for AI SDK spans', { + parentSpanId: parentSpanContext?.spanId, + parentTraceId: parentSpanContext?.traceId, + hasActiveContext: !!activeContext, + hasParentSpan: !!parentSpan + }) + } catch (error) { + logger.warn('Failed to set parent context in telemetry config', error as Error) + } + } + + logger.info('Injecting AI SDK telemetry config with 
adapter', { + requestId: context.requestId, + topicId: effectiveTopicId, + modelId: context.modelId, + modelName: modelName, + hasParentSpan: !!parentSpan, + parentSpanId: parentSpanContext?.spanId, + parentTraceId: parentSpanContext?.traceId, + functionId: telemetryConfig.functionId, + hasTracer: !!telemetryConfig.tracer, + tracerType: telemetryConfig.tracer?.constructor?.name || 'unknown' + }) + + return { + ...params, + experimental_telemetry: telemetryConfig + } + } + }) +} + +// 默认导出便于使用 +export default createTelemetryPlugin diff --git a/src/renderer/src/aiCore/prepareParams/fileProcessor.ts b/src/renderer/src/aiCore/prepareParams/fileProcessor.ts new file mode 100644 index 0000000000..2defe2c711 --- /dev/null +++ b/src/renderer/src/aiCore/prepareParams/fileProcessor.ts @@ -0,0 +1,192 @@ +/** + * 文件处理模块 + * 处理文件内容提取、文件格式转换、文件上传等逻辑 + */ + +import { loggerService } from '@logger' +import { getProviderByModel } from '@renderer/services/AssistantService' +import type { FileMetadata, Message, Model } from '@renderer/types' +import { FileTypes } from '@renderer/types' +import { FileMessageBlock } from '@renderer/types/newMessage' +import { findFileBlocks } from '@renderer/utils/messageUtils/find' +import type { FilePart, TextPart } from 'ai' + +import { getAiSdkProviderId } from '../provider/factory' +import { getFileSizeLimit, supportsImageInput, supportsLargeFileUpload, supportsPdfInput } from './modelCapabilities' + +const logger = loggerService.withContext('fileProcessor') + +/** + * 提取文件内容 + */ +export async function extractFileContent(message: Message): Promise { + const fileBlocks = findFileBlocks(message) + if (fileBlocks.length > 0) { + const textFileBlocks = fileBlocks.filter( + (fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type) + ) + + if (textFileBlocks.length > 0) { + let text = '' + const divider = '\n\n---\n\n' + + for (const fileBlock of textFileBlocks) { + const file = fileBlock.file + const fileContent = (await 
window.api.file.read(file.id + file.ext)).trim()
        const fileNameRow = 'file: ' + file.origin_name + '\n\n'
        text = text + fileNameRow + fileContent + divider
      }

      return text
    }
  }

  return ''
}

/**
 * Convert a file block into a text part.
 * TEXT files are read verbatim; DOCUMENT files are read with forced text extraction.
 * Returns null when the file cannot be read so callers can fall back to other handling.
 */
// NOTE(review): the generic argument of Promise appears lost in extraction (likely Promise<TextPart | null>) — confirm against the repository.
export async function convertFileBlockToTextPart(fileBlock: FileMessageBlock): Promise {
  const file = fileBlock.file

  // Plain text files: read as-is.
  if (file.type === FileTypes.TEXT) {
    try {
      const fileContent = await window.api.file.read(file.id + file.ext)
      return {
        type: 'text',
        text: `${file.origin_name}\n${fileContent.trim()}`
      }
    } catch (error) {
      // Best-effort: log and fall through to the null return below.
      logger.warn('Failed to read text file:', error as Error)
    }
  }

  // Document files (PDF, Word, Excel, ...): extract as text content.
  if (file.type === FileTypes.DOCUMENT) {
    try {
      const fileContent = await window.api.file.read(file.id + file.ext, true) // true forces plain-text extraction
      return {
        type: 'text',
        text: `${file.origin_name}\n${fileContent.trim()}`
      }
    } catch (error) {
      logger.warn(`Failed to extract text from document ${file.origin_name}:`, error as Error)
    }
  }

  return null
}

/**
 * Handle large-file upload to Gemini via the provider file service.
 * Currently always resolves to null (even on successful upload) so callers fall
 * back to text extraction — the AI SDK FilePart format has no URI-reference
 * support yet. The upload is still performed/reused for its side effect.
 */
// NOTE(review): return type generic lost in extraction (likely Promise<FilePart | null>) — confirm.
export async function handleGeminiFileUpload(file: FileMetadata, model: Model): Promise {
  try {
    const provider = getProviderByModel(model)

    // Check whether this file has already been uploaded.
    const fileMetadata = await window.api.fileService.retrieve(provider, file.id)

    if (fileMetadata.status === 'success' && fileMetadata.originalFile?.file) {
      const remoteFile = fileMetadata.originalFile.file as any // temporary assertion; the File type definition may be incomplete
      // Note: the AI SDK FilePart format differs from Gemini's native format and needs adaptation.
      // For now return null to fall back to text handling, or FilePart would need URI support.
      logger.info(`File ${file.origin_name} already uploaded to Gemini with URI: ${remoteFile.uri || 'unknown'}`)
      return null
    }

    // Not uploaded yet: perform the upload.
    const uploadResult = await window.api.fileService.upload(provider, file)
    if (uploadResult.originalFile?.file) {
      const remoteFile = uploadResult.originalFile.file as any // temporary assertion
      logger.info(`File ${file.origin_name} uploaded to Gemini with URI: ${remoteFile.uri || 'unknown'}`)
      // Same as above: URI-style file references are not representable yet.
      return null
    }
  } catch (error) {
    logger.error(`Failed to upload file ${file.origin_name} to Gemini:`, error as Error)
  }

  return null
}

/**
 * Convert a file block into a FilePart for providers with native file support.
 * Applies per-provider size limits; returns null to signal the caller to fall
 * back to text extraction (convertFileBlockToTextPart).
 */
// NOTE(review): return type generic lost in extraction (likely Promise<FilePart | null>) — confirm.
export async function convertFileBlockToFilePart(fileBlock: FileMessageBlock, model: Model): Promise {
  const file = fileBlock.file
  const fileSizeLimit = getFileSizeLimit(model, file.type)

  try {
    // PDF documents, when the provider supports native PDF input.
    if (file.type === FileTypes.DOCUMENT && file.ext === '.pdf' && supportsPdfInput(model)) {
      // Enforce the provider's size limit first.
      if (file.size > fileSizeLimit) {
        // If large-file upload is available (e.g. Gemini File API), try it.
        if (supportsLargeFileUpload(model)) {
          logger.info(`Large PDF file ${file.origin_name} (${file.size} bytes) attempting File API upload`)
          const uploadResult = await handleGeminiFileUpload(file, model)
          if (uploadResult) {
            return uploadResult
          }
          // Upload failed (or unsupported result shape): fall back to text extraction.
          logger.warn(`Failed to upload large PDF ${file.origin_name}, falling back to text extraction`)
          return null
        } else {
          logger.warn(`PDF file ${file.origin_name} exceeds size limit (${file.size} > ${fileSizeLimit})`)
          return null // too large — fall back to text handling
        }
      }

      const base64Data = await window.api.file.base64File(file.id + file.ext)
      return {
        type: 'file',
        data: base64Data.data,
        mediaType: base64Data.mime,
        filename: file.origin_name
      }
    }

    // Image files, when the model accepts image input.
    if (file.type === FileTypes.IMAGE && supportsImageInput(model)) {
      // Enforce size limit.
      if (file.size > fileSizeLimit) {
        logger.warn(`Image file ${file.origin_name} exceeds size limit (${file.size} > ${fileSizeLimit})`)
        return null
      }

      const base64Data = await window.api.file.base64Image(file.id + file.ext)

      // Normalize the MIME type; Anthropic requires image/jpeg rather than image/jpg.
      let mediaType = base64Data.mime
      const provider = getProviderByModel(model)
      const aiSdkId = getAiSdkProviderId(provider)

      if (aiSdkId === 'anthropic' && mediaType === 'image/jpg') {
        mediaType = 'image/jpeg'
      }

      return {
        type: 'file',
        data:
base64Data.base64,
        mediaType: mediaType,
        filename: file.origin_name
      }
    }

    // Other document types (Word, Excel, ...).
    if (file.type === FileTypes.DOCUMENT && file.ext !== '.pdf') {
      // Most providers do not natively handle Word-like formats.
      // Returning null makes the caller use convertFileBlockToTextPart for text extraction,
      // matching the behavior of the Legacy architecture.
      logger.debug(`Document file ${file.origin_name} with extension ${file.ext} will use text extraction fallback`)
      return null
    }
  } catch (error) {
    logger.warn(`Failed to process file ${file.origin_name}:`, error as Error)
  }

  return null
}
diff --git a/src/renderer/src/aiCore/prepareParams/index.ts b/src/renderer/src/aiCore/prepareParams/index.ts
new file mode 100644
index 0000000000..ac0822e76d
--- /dev/null
+++ b/src/renderer/src/aiCore/prepareParams/index.ts
@@ -0,0 +1,22 @@
/**
 * AI SDK parameter conversion module — unified entry point.
 *
 * This module has been refactored; functionality is split across:
 * - modelParameters.ts: basic parameter handling (temperature, topP, timeout)
 * - modelCapabilities.ts: model capability checks (PDF, image, file support)
 * - fileProcessor.ts: file handling logic (conversion, upload)
 * - messageConverter.ts: message conversion core (single-message conversion)
 * - parameterBuilder.ts: parameter builder (final parameter assembly)
 */

// Basic parameter handling
export { getTimeout } from './modelParameters'

// File handling
export { extractFileContent } from './fileProcessor'

// Message conversion
export { convertMessagesToSdkMessages, convertMessageToSdkParam } from './messageConverter'

// Parameter building (primary API)
export { buildGenerateTextParams, buildStreamTextParams } from './parameterBuilder'
diff --git a/src/renderer/src/aiCore/prepareParams/messageConverter.ts b/src/renderer/src/aiCore/prepareParams/messageConverter.ts
new file mode 100644
index 0000000000..d11f25fc2c
--- /dev/null
+++ b/src/renderer/src/aiCore/prepareParams/messageConverter.ts
@@ -0,0 +1,166 @@
/**
 * Message conversion module.
 * Converts Cherry Studio message format into AI SDK message format.
 */

import { loggerService } from '@logger'
import { isVisionModel } from '@renderer/config/models'
import type { Message, Model } from '@renderer/types'
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
import {
findFileBlocks, + findImageBlocks, + findThinkingBlocks, + getMainTextContent +} from '@renderer/utils/messageUtils/find' +import type { AssistantModelMessage, FilePart, ImagePart, ModelMessage, TextPart, UserModelMessage } from 'ai' + +import { convertFileBlockToFilePart, convertFileBlockToTextPart } from './fileProcessor' + +const logger = loggerService.withContext('messageConverter') + +/** + * 转换消息为 AI SDK 参数格式 + * 基于 OpenAI 格式的通用转换,支持文本、图片和文件 + */ +export async function convertMessageToSdkParam( + message: Message, + isVisionModel = false, + model?: Model +): Promise { + const content = getMainTextContent(message) + const fileBlocks = findFileBlocks(message) + const imageBlocks = findImageBlocks(message) + const reasoningBlocks = findThinkingBlocks(message) + if (message.role === 'user' || message.role === 'system') { + return convertMessageToUserModelMessage(content, fileBlocks, imageBlocks, isVisionModel, model) + } else { + return convertMessageToAssistantModelMessage(content, fileBlocks, reasoningBlocks, model) + } +} + +/** + * 转换为用户模型消息 + */ +async function convertMessageToUserModelMessage( + content: string, + fileBlocks: FileMessageBlock[], + imageBlocks: ImageMessageBlock[], + isVisionModel = false, + model?: Model +): Promise { + const parts: Array = [] + if (content) { + parts.push({ type: 'text', text: content }) + } + + // 处理图片(仅在支持视觉的模型中) + if (isVisionModel) { + for (const imageBlock of imageBlocks) { + if (imageBlock.file) { + try { + const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) + parts.push({ + type: 'image', + image: image.base64, + mediaType: image.mime + }) + } catch (error) { + logger.warn('Failed to load image:', error as Error) + } + } else if (imageBlock.url) { + parts.push({ + type: 'image', + image: imageBlock.url + }) + } + } + } + // 处理文件 + for (const fileBlock of fileBlocks) { + const file = fileBlock.file + let processed = false + + // 优先尝试原生文件支持(PDF、图片等) + if (model) { + const filePart 
= await convertFileBlockToFilePart(fileBlock, model) + if (filePart) { + parts.push(filePart) + logger.debug(`File ${file.origin_name} processed as native file format`) + processed = true + } + } + + // 如果原生处理失败,回退到文本提取 + if (!processed) { + const textPart = await convertFileBlockToTextPart(fileBlock) + if (textPart) { + parts.push(textPart) + logger.debug(`File ${file.origin_name} processed as text content`) + } else { + logger.warn(`File ${file.origin_name} could not be processed in any format`) + } + } + } + + return { + role: 'user', + content: parts + } +} + +/** + * 转换为助手模型消息 + */ +async function convertMessageToAssistantModelMessage( + content: string, + fileBlocks: FileMessageBlock[], + thinkingBlocks: ThinkingMessageBlock[], + model?: Model +): Promise { + const parts: Array = [] + if (content) { + parts.push({ type: 'text', text: content }) + } + + for (const thinkingBlock of thinkingBlocks) { + parts.push({ type: 'text', text: thinkingBlock.content }) + } + + for (const fileBlock of fileBlocks) { + // 优先尝试原生文件支持(PDF等) + if (model) { + const filePart = await convertFileBlockToFilePart(fileBlock, model) + if (filePart) { + parts.push(filePart) + continue + } + } + + // 回退到文本处理 + const textPart = await convertFileBlockToTextPart(fileBlock) + if (textPart) { + parts.push(textPart) + } + } + + return { + role: 'assistant', + content: parts + } +} + +/** + * 转换 Cherry Studio 消息数组为 AI SDK 消息数组 + */ +export async function convertMessagesToSdkMessages(messages: Message[], model: Model): Promise { + const sdkMessages: ModelMessage[] = [] + const isVision = isVisionModel(model) + + for (const message of messages) { + const sdkMessage = await convertMessageToSdkParam(message, isVision, model) + sdkMessages.push(sdkMessage) + } + + return sdkMessages +} diff --git a/src/renderer/src/aiCore/prepareParams/modelCapabilities.ts b/src/renderer/src/aiCore/prepareParams/modelCapabilities.ts new file mode 100644 index 0000000000..a70576ff11 --- /dev/null +++ 
b/src/renderer/src/aiCore/prepareParams/modelCapabilities.ts
@@ -0,0 +1,73 @@
/**
 * Model capability check module.
 * Checks which features a model supports (PDF input, image input, large-file upload, ...).
 */

import { isVisionModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model } from '@renderer/types'
import { FileTypes } from '@renderer/types'

import { getAiSdkProviderId } from '../provider/factory'

// AI SDK provider ids of the Google/Gemini family, shared by the checks below.
const GOOGLE_PROVIDER_IDS = ['google', 'google-generative-ai', 'google-vertex']

/**
 * Whether the model's provider supports native PDF input.
 * Per AI SDK documentation, these providers accept PDF file parts.
 */
export function supportsPdfInput(model: Model): boolean {
  const supportedProviders = [
    'openai',
    'azure-openai',
    'anthropic',
    'google',
    'google-generative-ai',
    'google-vertex',
    'bedrock',
    'amazon-bedrock'
  ]

  const provider = getProviderByModel(model)
  const aiSdkId = getAiSdkProviderId(provider)

  // `.includes` replaces the previous `.some` callback, whose `provider`
  // parameter shadowed the outer `provider` constant.
  return supportedProviders.includes(aiSdkId)
}

/**
 * Whether the model supports native image input (delegates to the vision check).
 */
export function supportsImageInput(model: Model): boolean {
  return isVisionModel(model)
}

/**
 * Whether the provider supports large-file upload (e.g. the Gemini File API).
 * Currently only the Gemini family does.
 */
export function supportsLargeFileUpload(model: Model): boolean {
  const provider = getProviderByModel(model)
  const aiSdkId = getAiSdkProviderId(provider)

  return GOOGLE_PROVIDER_IDS.includes(aiSdkId)
}

/**
 * Provider-specific file size limit in bytes for the given file type.
 * Returns Infinity when no explicit limit is known (the provider enforces its
 * own limits, matching the Legacy architecture's behavior).
 */
export function getFileSizeLimit(model: Model, fileType: FileTypes): number {
  const provider = getProviderByModel(model)
  const aiSdkId = getAiSdkProviderId(provider)

  // Anthropic caps PDFs at 32MB.
  if (aiSdkId === 'anthropic' && fileType === FileTypes.DOCUMENT) {
    return 32 * 1024 * 1024 // 32MB
  }

  // Gemini inline files are capped at 20MB (larger files go through the File API).
  if (GOOGLE_PROVIDER_IDS.includes(aiSdkId)) {
    return 20 * 1024 * 1024 // 20MB
  }

  return Infinity
}
diff --git a/src/renderer/src/aiCore/prepareParams/modelParameters.ts
b/src/renderer/src/aiCore/prepareParams/modelParameters.ts
new file mode 100644
index 0000000000..6f78ac2cc4
--- /dev/null
+++ b/src/renderer/src/aiCore/prepareParams/modelParameters.ts
@@ -0,0 +1,51 @@
/**
 * Basic model parameter module.
 * Resolves temperature, topP and timeout settings for a request.
 */

import {
  isClaudeReasoningModel,
  isNotSupportTemperatureAndTopP,
  isSupportedFlexServiceTier
} from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import type { Assistant, Model } from '@renderer/types'
import { defaultTimeout } from '@shared/config/constant'

/**
 * True when sampling parameters (temperature / topP) must be omitted:
 * Claude reasoning models with an explicit reasoning effort, and models that
 * reject temperature/topP altogether. Shared by getTemperature and getTopP,
 * which previously duplicated this guard verbatim.
 */
function shouldOmitSamplingParams(assistant: Assistant, model: Model): boolean {
  if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
    return true
  }
  return isNotSupportTemperatureAndTopP(model)
}

/**
 * Resolve the temperature parameter, or undefined when it must not be sent
 * or is disabled in the assistant settings.
 */
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
  if (shouldOmitSamplingParams(assistant, model)) {
    return undefined
  }
  const assistantSettings = getAssistantSettings(assistant)
  return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined
}

/**
 * Resolve the topP parameter, or undefined when it must not be sent
 * or is disabled in the assistant settings.
 */
export function getTopP(assistant: Assistant, model: Model): number | undefined {
  if (shouldOmitSamplingParams(assistant, model)) {
    return undefined
  }
  const assistantSettings = getAssistantSettings(assistant)
  return assistantSettings?.enableTopP ? assistantSettings?.topP : undefined
}

/**
 * Resolve the request timeout in milliseconds.
 * Flex service tiers get an extended 15-minute timeout.
 */
export function getTimeout(model: Model): number {
  if (isSupportedFlexServiceTier(model)) {
    return 15 * 1000 * 60 // 15 minutes
  }
  return defaultTimeout
}
diff --git a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts
new file mode 100644
index 0000000000..d72010d03b
--- /dev/null
+++ b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts
@@ -0,0 +1,129 @@
/**
 * Parameter building module.
 * Builds streaming and non-streaming AI SDK parameters.
 */

import { loggerService } from '@logger'
import {
  isGenerateImageModel,
  isOpenRouterBuiltInWebSearchModel,
  isReasoningModel,
  isSupportedReasoningEffortModel,
  isSupportedThinkingTokenModel,
  isWebSearchModel
} from '@renderer/config/models'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import type { Assistant, MCPTool, Provider } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import type { ModelMessage } from 'ai'
import { stepCountIs } from 'ai'

import { setupToolsConfig } from '../utils/mcp'
import { buildProviderOptions } from '../utils/options'
import { getTemperature, getTopP } from './modelParameters'

const logger = loggerService.withContext('parameterBuilder')

/**
 * Build AI SDK streaming parameters.
 * This is the main parameter-building function that integrates all conversion logic.
 */
export async function buildStreamTextParams(
  sdkMessages: StreamTextParams['messages'] = [],
  assistant: Assistant,
  provider: Provider,
  options: {
    mcpTools?: MCPTool[]
    webSearchProviderId?: string
    requestOptions?: {
      signal?: AbortSignal
      timeout?: number
      // NOTE(review): type arguments of Record appear lost in extraction (likely Record<string, string>) — confirm.
      headers?: Record
    }
  } = {}
): Promise<{
  params: StreamTextParams
  modelId: string
  capabilities: {
    enableReasoning: boolean
    enableWebSearch: boolean
    enableGenerateImage: boolean
    enableUrlContext: boolean
  }
}> {
  const { mcpTools } = options

  const model = assistant.model || getDefaultModel()

  const { maxTokens } =
getAssistantSettings(assistant)

  // These three flags are passed through to enable plugins/middleware below.
  // They could also be built externally and passed into buildStreamTextParams.
  // FIXME: for qwen3, enableReasoning still ends up true even when thinking is turned off
  const enableReasoning =
    ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
      assistant.settings?.reasoning_effort !== undefined) ||
    (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))

  const enableWebSearch =
    (assistant.enableWebSearch && isWebSearchModel(model)) ||
    isOpenRouterBuiltInWebSearchModel(model) ||
    model.id.includes('sonar') ||
    false

  const enableUrlContext = assistant.enableUrlContext || false

  const enableGenerateImage = !!(isGenerateImageModel(model) && assistant.enableGenerateImage)

  const tools = setupToolsConfig(mcpTools)

  // if (webSearchProviderId) {
  //   tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
  // }

  // Build the real providerOptions.
  const providerOptions = buildProviderOptions(assistant, model, provider, {
    enableReasoning,
    enableWebSearch,
    enableGenerateImage
  })

  // Build the base parameters.
  const params: StreamTextParams = {
    messages: sdkMessages,
    maxOutputTokens: maxTokens,
    temperature: getTemperature(assistant, model),
    topP: getTopP(assistant, model),
    abortSignal: options.requestOptions?.signal,
    headers: options.requestOptions?.headers,
    providerOptions,
    tools,
    stopWhen: stepCountIs(10), // cap tool-call loops at 10 steps
    maxRetries: 0
  }
  if (assistant.prompt) {
    params.system = assistant.prompt
  }
  logger.debug('params', params)
  return {
    params,
    modelId: model.id,
    capabilities: { enableReasoning, enableWebSearch, enableGenerateImage, enableUrlContext }
  }
}

/**
 * Build non-streaming generateText parameters.
 */
// NOTE(review): this reuses buildStreamTextParams and therefore returns its full
// { params, modelId, capabilities } wrapper rather than bare params; the declared
// return type generic was lost in extraction, so confirm callers expect the wrapper.
export async function buildGenerateTextParams(
  messages: ModelMessage[],
  assistant: Assistant,
  provider: Provider,
  options: {
    mcpTools?: MCPTool[]
    enableTools?: boolean
  } = {}
): Promise {
  // Reuse the streaming parameter-building logic.
  return await buildStreamTextParams(messages, assistant, provider, options)
}
diff --git a/src/renderer/src/aiCore/provider/__tests__/integratedRegistry.test.ts b/src/renderer/src/aiCore/provider/__tests__/integratedRegistry.test.ts
new file mode 100644
index 0000000000..e26597e2d1
--- /dev/null
+++ b/src/renderer/src/aiCore/provider/__tests__/integratedRegistry.test.ts
@@ -0,0 +1,103 @@
import type { Provider } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'

import { getAiSdkProviderId } from '../factory'

// Mock the external dependencies
vi.mock('@cherrystudio/ai-core', () => ({
  registerMultipleProviders: vi.fn(() => 4), // Mock successful registration of 4 providers
  getProviderMapping: vi.fn((id: string) => {
    // Mock dynamic mappings
    // NOTE(review): type arguments of Record lost in extraction (likely Record<string, string>) — confirm.
    const mappings: Record = {
      openrouter: 'openrouter',
      'google-vertex': 'google-vertex',
      vertexai: 'google-vertex',
      bedrock: 'bedrock',
      'aws-bedrock': 'bedrock',
      zhipu: 'zhipu'
    }
    return mappings[id]
  }),
  AiCore: {
    isSupported: vi.fn(() => true)
  }
}))

// Mock the provider configs
vi.mock('../providerConfigs', () => ({
  initializeNewProviders: vi.fn()
}))

vi.mock('@logger', () => ({
  loggerService: {
    withContext: () => ({
      info: vi.fn(),
      warn: vi.fn(),
      error: vi.fn()
    })
  }
}))

// Build a minimal Provider fixture; the cast papers over unused required fields.
function createTestProvider(id: string, type: string): Provider {
  return {
    id,
    type,
    name: `Test ${id}`,
    apiKey: 'test-key',
    apiHost: 'test-host'
  } as Provider
}

describe('Integrated Provider Registry', () => {
  describe('Provider ID Resolution', () => {
    it('should resolve openrouter provider correctly', () => {
      const provider = createTestProvider('openrouter', 'openrouter')
      const result = getAiSdkProviderId(provider)
      expect(result).toBe('openrouter')
    })

    it('should resolve google-vertex provider correctly', () => {
      const provider = createTestProvider('google-vertex', 'vertexai')
      const result = getAiSdkProviderId(provider)
      expect(result).toBe('google-vertex')
    })

    it('should resolve bedrock provider correctly', () => {
      const provider = createTestProvider('bedrock', 'aws-bedrock')
      const result = getAiSdkProviderId(provider)
      expect(result).toBe('bedrock')
    })

    it('should resolve zhipu provider correctly', () => {
      const provider = createTestProvider('zhipu', 'zhipu')
      const result = getAiSdkProviderId(provider)
      expect(result).toBe('zhipu')
    })

    it('should resolve provider type mapping correctly', () => {
      const provider = createTestProvider('vertex-test', 'vertexai')
      const result = getAiSdkProviderId(provider)
      expect(result).toBe('google-vertex')
    })

    it('should handle static provider mappings', () => {
      const geminiProvider = createTestProvider('gemini', 'gemini')
      const result = getAiSdkProviderId(geminiProvider)
      expect(result).toBe('google')
    })

    it('should fallback to provider.id for unknown providers', () => {
      const unknownProvider = createTestProvider('unknown-provider', 'unknown-type')
      const result = getAiSdkProviderId(unknownProvider)
      expect(result).toBe('unknown-provider')
    })
  })

  describe('Backward Compatibility', () => {
    it('should maintain compatibility with existing providers', () => {
      const grokProvider = createTestProvider('grok', 'grok')
      const result = getAiSdkProviderId(grokProvider)
      expect(result).toBe('xai')
    })
  })
})
diff --git a/src/renderer/src/aiCore/provider/config/aihubmix.ts b/src/renderer/src/aiCore/provider/config/aihubmix.ts
new file mode 100644
index 0000000000..88453ca38e
--- /dev/null
+++ b/src/renderer/src/aiCore/provider/config/aihubmix.ts
@@ -0,0 +1,57 @@
/**
 * AiHubMix rule set.
 */
import { isOpenAIModel } from '@renderer/config/models'
import { Provider } from '@renderer/types'

import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'

// Attach the AiHubMix app code header to every derived provider.
const extraProviderConfig = (provider: Provider) => {
  return {
    ...provider,
    extra_headers: {
...provider.extra_headers, + 'APP-Code': 'MLTG2087' + } + } +} + +const AIHUBMIX_RULES: RuleSet = { + rules: [ + { + match: startsWith('claude'), + provider: (provider: Provider) => { + return extraProviderConfig({ + ...provider, + type: 'anthropic' + }) + } + }, + { + match: (model) => + (startsWith('gemini')(model) || startsWith('imagen')(model)) && + !model.id.endsWith('-nothink') && + !model.id.endsWith('-search'), + provider: (provider: Provider) => { + return extraProviderConfig({ + ...provider, + type: 'gemini', + apiHost: 'https://aihubmix.com/gemini' + }) + } + }, + { + match: isOpenAIModel, + provider: (provider: Provider) => { + return extraProviderConfig({ + ...provider, + type: 'openai-response' + }) + } + } + ], + fallbackRule: (provider: Provider) => provider +} + +export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES) diff --git a/src/renderer/src/aiCore/provider/config/helper.ts b/src/renderer/src/aiCore/provider/config/helper.ts new file mode 100644 index 0000000000..656911fc76 --- /dev/null +++ b/src/renderer/src/aiCore/provider/config/helper.ts @@ -0,0 +1,22 @@ +import type { Model, Provider } from '@renderer/types' + +import type { RuleSet } from './types' + +export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase()) +export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type + +/** + * 解析模型对应的Provider + * @param ruleSet 规则集对象 + * @param model 模型对象 + * @param provider 原始provider对象 + * @returns 解析出的provider对象 + */ +export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider { + for (const rule of ruleSet.rules) { + if (rule.match(model)) { + return rule.provider(provider) + } + } + return ruleSet.fallbackRule(provider) +} diff --git a/src/renderer/src/aiCore/provider/config/index.ts b/src/renderer/src/aiCore/provider/config/index.ts new file mode 100644 index 0000000000..2f51234cec --- 
/dev/null +++ b/src/renderer/src/aiCore/provider/config/index.ts @@ -0,0 +1,3 @@ +export { aihubmixProviderCreator } from './aihubmix' +export { newApiResolverCreator } from './newApi' +export { vertexAnthropicProviderCreator } from './vertext-anthropic' diff --git a/src/renderer/src/aiCore/provider/config/newApi.ts b/src/renderer/src/aiCore/provider/config/newApi.ts new file mode 100644 index 0000000000..5277495cdb --- /dev/null +++ b/src/renderer/src/aiCore/provider/config/newApi.ts @@ -0,0 +1,51 @@ +/** + * NewAPI规则集 + */ +import { Provider } from '@renderer/types' + +import { endpointIs, provider2Provider } from './helper' +import type { RuleSet } from './types' + +const NEWAPI_RULES: RuleSet = { + rules: [ + { + match: endpointIs('anthropic'), + provider: (provider: Provider) => { + return { + ...provider, + type: 'anthropic' + } + } + }, + { + match: endpointIs('gemini'), + provider: (provider: Provider) => { + return { + ...provider, + type: 'gemini' + } + } + }, + { + match: endpointIs('openai-response'), + provider: (provider: Provider) => { + return { + ...provider, + type: 'openai-response' + } + } + }, + { + match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model), + provider: (provider: Provider) => { + return { + ...provider, + type: 'openai' + } + } + } + ], + fallbackRule: (provider: Provider) => provider +} + +export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES) diff --git a/src/renderer/src/aiCore/provider/config/types.ts b/src/renderer/src/aiCore/provider/config/types.ts new file mode 100644 index 0000000000..f3938b84d1 --- /dev/null +++ b/src/renderer/src/aiCore/provider/config/types.ts @@ -0,0 +1,9 @@ +import type { Model, Provider } from '@renderer/types' + +export interface RuleSet { + rules: Array<{ + match: (model: Model) => boolean + provider: (provider: Provider) => Provider + }> + fallbackRule: (provider: Provider) => Provider +} diff --git 
a/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts b/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts new file mode 100644 index 0000000000..23c8b5185c --- /dev/null +++ b/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts @@ -0,0 +1,19 @@ +import type { Provider } from '@renderer/types' + +import { provider2Provider, startsWith } from './helper' +import type { RuleSet } from './types' + +const VERTEX_ANTHROPIC_RULES: RuleSet = { + rules: [ + { + match: startsWith('claude'), + provider: (provider: Provider) => ({ + ...provider, + id: 'google-vertex-anthropic' + }) + } + ], + fallbackRule: (provider: Provider) => provider +} + +export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES) diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts new file mode 100644 index 0000000000..617758753e --- /dev/null +++ b/src/renderer/src/aiCore/provider/factory.ts @@ -0,0 +1,99 @@ +import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider' +import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' +import { loggerService } from '@logger' +import { Provider } from '@renderer/types' +import type { Provider as AiSdkProvider } from 'ai' + +import { initializeNewProviders } from './providerInitialization' + +const logger = loggerService.withContext('ProviderFactory') + +/** + * 初始化动态Provider系统 + * 在模块加载时自动注册新的providers + */ +;(async () => { + try { + await initializeNewProviders() + } catch (error) { + logger.warn('Failed to initialize new providers:', error as Error) + } +})() + +/** + * 静态Provider映射表 + * 处理Cherry Studio特有的provider ID到AI SDK标准ID的映射 + */ +const STATIC_PROVIDER_MAPPING: Record = { + gemini: 'google', // Google Gemini -> google + 'azure-openai': 'azure', // Azure OpenAI -> azure + 'openai-response': 'openai', // OpenAI Responses -> openai + grok: 'xai' // Grok -> 
xai +} + +/** + * 尝试解析provider标识符(支持静态映射和别名) + */ +function tryResolveProviderId(identifier: string): ProviderId | null { + // 1. 检查静态映射 + const staticMapping = STATIC_PROVIDER_MAPPING[identifier] + if (staticMapping) { + return staticMapping + } + + // 2. 检查AiCore是否支持(包括别名支持) + if (hasProviderConfigByAlias(identifier)) { + // 解析为真实的Provider ID + return resolveProviderConfigId(identifier) as ProviderId + } + + return null +} + +/** + * 获取AI SDK Provider ID + * 简化版:减少重复逻辑,利用通用解析函数 + */ +export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' { + // 1. 尝试解析provider.id + const resolvedFromId = tryResolveProviderId(provider.id) + if (resolvedFromId) { + return resolvedFromId + } + + // 2. 尝试解析provider.type + // 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上 + if (provider.type !== 'openai') { + const resolvedFromType = tryResolveProviderId(provider.type) + if (resolvedFromType) { + return resolvedFromType + } + } + // 3. 最后的fallback(通常会成为openai-compatible) + return provider.id as ProviderId +} + +export async function createAiSdkProvider(config) { + let localProvider: Awaited | null = null + try { + if (config.providerId === 'openai' && config.options?.mode === 'chat') { + config.providerId = `${config.providerId}-chat` + } else if (config.providerId === 'azure' && config.options?.mode === 'responses') { + config.providerId = `${config.providerId}-responses` + } + localProvider = await createProviderCore(config.providerId, config.options) + + logger.debug('Local provider created successfully', { + providerId: config.providerId, + hasOptions: !!config.options, + localProvider: localProvider, + options: config.options + }) + } catch (error) { + logger.error('Failed to create local provider', error as Error, { + providerId: config.providerId + }) + throw error + } + return localProvider +} diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts new file mode 100644 index 
0000000000..63d1b6ed54 --- /dev/null +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -0,0 +1,267 @@ +import { + formatPrivateKey, + hasProviderConfig, + ProviderConfigFactory, + type ProviderId, + type ProviderSettingsMap +} from '@cherrystudio/ai-core/provider' +import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models' +import { + getAwsBedrockAccessKeyId, + getAwsBedrockRegion, + getAwsBedrockSecretAccessKey +} from '@renderer/hooks/useAwsBedrock' +import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI' +import { getProviderByModel } from '@renderer/services/AssistantService' +import { loggerService } from '@renderer/services/LoggerService' +import store from '@renderer/store' +import type { Model, Provider } from '@renderer/types' +import { formatApiHost } from '@renderer/utils/api' +import { cloneDeep, isEmpty } from 'lodash' + +import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config' +import { getAiSdkProviderId } from './factory' + +const logger = loggerService.withContext('ProviderConfigProcessor') + +/** + * 获取轮询的API key + * 复用legacy架构的多key轮询逻辑 + */ +function getRotatedApiKey(provider: Provider): string { + const keys = provider.apiKey.split(',').map((key) => key.trim()) + const keyName = `provider:${provider.id}:last_used_key` + + if (keys.length === 1) { + return keys[0] + } + + const lastUsedKey = window.keyv.get(keyName) + if (!lastUsedKey) { + window.keyv.set(keyName, keys[0]) + return keys[0] + } + + const currentIndex = keys.indexOf(lastUsedKey) + const nextIndex = (currentIndex + 1) % keys.length + const nextKey = keys[nextIndex] + window.keyv.set(keyName, nextKey) + + return nextKey +} + +/** + * 处理特殊provider的转换逻辑 + */ +function handleSpecialProviders(model: Model, provider: Provider): Provider { + // if (provider.type === 'vertexai' && !isVertexProvider(provider)) { + // if (!isVertexAIConfigured()) { + // throw new Error('VertexAI is not 
configured. Please configure project, location and service account credentials.') + // } + // return createVertexProvider(provider) + // } + + if (provider.id === 'aihubmix') { + return aihubmixProviderCreator(model, provider) + } + if (provider.id === 'newapi') { + return newApiResolverCreator(model, provider) + } + if (provider.id === 'vertexai') { + return vertexAnthropicProviderCreator(model, provider) + } + return provider +} + +/** + * 格式化provider的API Host + */ +function formatProviderApiHost(provider: Provider): Provider { + const formatted = { ...provider } + if (formatted.type === 'gemini') { + formatted.apiHost = formatApiHost(formatted.apiHost, 'v1beta') + } else { + formatted.apiHost = formatApiHost(formatted.apiHost) + } + return formatted +} + +/** + * 获取实际的Provider配置 + * 简化版:将逻辑分解为小函数 + */ +export function getActualProvider(model: Model): Provider { + const baseProvider = getProviderByModel(model) + + // 按顺序处理各种转换 + let actualProvider = cloneDeep(baseProvider) + actualProvider = handleSpecialProviders(model, actualProvider) + actualProvider = formatProviderApiHost(actualProvider) + + return actualProvider +} + +/** + * 将 Provider 配置转换为新 AI SDK 格式 + * 简化版:利用新的别名映射系统 + */ +export function providerToAiSdkConfig( + actualProvider: Provider, + model: Model +): { + providerId: ProviderId | 'openai-compatible' + options: ProviderSettingsMap[keyof ProviderSettingsMap] +} { + const aiSdkProviderId = getAiSdkProviderId(actualProvider) + logger.debug('providerToAiSdkConfig', { aiSdkProviderId }) + + // 构建基础配置 + const baseConfig = { + baseURL: actualProvider.apiHost, + apiKey: getRotatedApiKey(actualProvider) + } + // 处理OpenAI模式 + const extraOptions: any = {} + if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) { + extraOptions.mode = 'responses' + } else if (aiSdkProviderId === 'openai') { + extraOptions.mode = 'chat' + } + + // 添加额外headers + if (actualProvider.extra_headers) { + extraOptions.headers = 
actualProvider.extra_headers + // copy from openaiBaseClient/openaiResponseApiClient + if (aiSdkProviderId === 'openai') { + extraOptions.headers = { + ...extraOptions.headers, + 'HTTP-Referer': 'https://cherry-ai.com', + 'X-Title': 'Cherry Studio', + 'X-Api-Key': baseConfig.apiKey + } + } + } + + // copilot + if (actualProvider.id === 'copilot') { + extraOptions.headers = { + ...extraOptions.headers, + 'editor-version': 'vscode/1.97.2', + 'copilot-vision-request': 'true' + } + } + // azure + if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') { + extraOptions.apiVersion = actualProvider.apiVersion + baseConfig.baseURL += '/openai' + if (actualProvider.apiVersion === 'preview') { + extraOptions.mode = 'responses' + } else { + extraOptions.mode = 'chat' + extraOptions.useDeploymentBasedUrls = true + } + } + + // bedrock + if (aiSdkProviderId === 'bedrock') { + extraOptions.region = getAwsBedrockRegion() + extraOptions.accessKeyId = getAwsBedrockAccessKeyId() + extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey() + } + // google-vertex + if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') { + if (!isVertexAIConfigured()) { + throw new Error('VertexAI is not configured. 
Please configure project, location and service account credentials.') + } + const { project, location, googleCredentials } = createVertexProvider(actualProvider) + extraOptions.project = project + extraOptions.location = location + extraOptions.googleCredentials = { + ...googleCredentials, + privateKey: formatPrivateKey(googleCredentials.privateKey) + } + // extraOptions.headers = window.api.vertexAI.getAuthHeaders({ + // projectId: project, + // serviceAccount: { + // privateKey: googleCredentials.privateKey, + // clientEmail: googleCredentials.clientEmail + // } + // }) + if (baseConfig.baseURL.endsWith('/v1/')) { + baseConfig.baseURL = baseConfig.baseURL.slice(0, -4) + } else if (baseConfig.baseURL.endsWith('/v1')) { + baseConfig.baseURL = baseConfig.baseURL.slice(0, -3) + } + baseConfig.baseURL = isEmpty(baseConfig.baseURL) ? '' : baseConfig.baseURL + } + + // 如果AI SDK支持该provider,使用原生配置 + if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { + const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions) + return { + providerId: aiSdkProviderId as ProviderId, + options + } + } + + // 否则fallback到openai-compatible + const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey) + return { + providerId: 'openai-compatible', + options: { + ...options, + name: actualProvider.id, + ...extraOptions + } + } +} + +/** + * 检查是否支持使用新的AI SDK + * 简化版:利用新的别名映射和动态provider系统 + */ +export function isModernSdkSupported(provider: Provider): boolean { + // 特殊检查:vertexai需要配置完整 + if (provider.type === 'vertexai' && !isVertexAIConfigured()) { + return false + } + + // 使用getAiSdkProviderId获取映射后的providerId,然后检查AI SDK是否支持 + const aiSdkProviderId = getAiSdkProviderId(provider) + + // 如果映射到了支持的provider,则支持现代SDK + return hasProviderConfig(aiSdkProviderId) +} + +/** + * 准备特殊provider的配置,主要用于异步处理的配置 + */ +export async function prepareSpecialProviderConfig( + provider: Provider, + config: 
ReturnType +) { + if (provider.id === 'copilot') { + const defaultHeaders = store.getState().copilot.defaultHeaders + const { token } = await window.api.copilot.getToken(defaultHeaders) + config.options.apiKey = token + } + if (provider.id === 'cherryin') { + config.options.fetch = async (url, options) => { + // 在这里对最终参数进行签名 + const signature = await window.api.cherryin.generateSignature({ + method: 'POST', + path: '/chat/completions', + query: '', + body: JSON.parse(options.body) + }) + return fetch(url, { + ...options, + headers: { + ...options.headers, + ...signature + } + }) + } + } + return config +} diff --git a/src/renderer/src/aiCore/provider/providerInitialization.ts b/src/renderer/src/aiCore/provider/providerInitialization.ts new file mode 100644 index 0000000000..cf3366d70a --- /dev/null +++ b/src/renderer/src/aiCore/provider/providerInitialization.ts @@ -0,0 +1,58 @@ +import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider' +import { loggerService } from '@logger' + +const logger = loggerService.withContext('ProviderConfigs') + +/** + * 新Provider配置定义 + * 定义了需要动态注册的AI Providers + */ +export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [ + { + id: 'openrouter', + name: 'OpenRouter', + import: () => import('@openrouter/ai-sdk-provider'), + creatorFunctionName: 'createOpenRouter', + supportsImageGeneration: true, + aliases: ['openrouter'] + }, + { + id: 'google-vertex', + name: 'Google Vertex AI', + import: () => import('@ai-sdk/google-vertex/edge'), + creatorFunctionName: 'createVertex', + supportsImageGeneration: true, + aliases: ['vertexai'] + }, + { + id: 'google-vertex-anthropic', + name: 'Google Vertex AI Anthropic', + import: () => import('@ai-sdk/google-vertex/anthropic/edge'), + creatorFunctionName: 'createVertexAnthropic', + supportsImageGeneration: true, + aliases: ['vertexai-anthropic'] + }, + { + id: 'bedrock', + name: 'Amazon Bedrock', + import: () => import('@ai-sdk/amazon-bedrock'), + 
creatorFunctionName: 'createAmazonBedrock', + supportsImageGeneration: true, + aliases: ['aws-bedrock'] + } +] as const + +/** + * 初始化新的Providers + * 使用aiCore的动态注册功能 + */ +export async function initializeNewProviders(): Promise { + try { + const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS) + if (successCount < NEW_PROVIDER_CONFIGS.length) { + logger.warn('Some providers failed to register. Check previous error logs.') + } + } catch (error) { + logger.error('Failed to initialize new providers:', error as Error) + } +} diff --git a/src/renderer/src/aiCore/tools/KnowledgeSearchTool.ts b/src/renderer/src/aiCore/tools/KnowledgeSearchTool.ts new file mode 100644 index 0000000000..f156061a8c --- /dev/null +++ b/src/renderer/src/aiCore/tools/KnowledgeSearchTool.ts @@ -0,0 +1,138 @@ +import { REFERENCE_PROMPT } from '@renderer/config/prompts' +import { processKnowledgeSearch } from '@renderer/services/KnowledgeService' +import type { Assistant, KnowledgeReference } from '@renderer/types' +import { ExtractResults, KnowledgeExtractResults } from '@renderer/utils/extract' +import { type InferToolInput, type InferToolOutput, tool } from 'ai' +import { isEmpty } from 'lodash' +import { z } from 'zod' + +/** + * 知识库搜索工具 + * 使用预提取关键词,直接使用插件阶段分析的搜索意图,避免重复分析 + */ +export const knowledgeSearchTool = ( + assistant: Assistant, + extractedKeywords: KnowledgeExtractResults, + topicId: string, + userMessage?: string +) => { + return tool({ + name: 'builtin_knowledge_search', + description: `Search the knowledge base for relevant information using pre-analyzed search intent. + +Pre-extracted search queries: "${extractedKeywords.question.join(', ')}" +Rewritten query: "${extractedKeywords.rewrite}" + +This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response. + +Call this tool to execute the search. 
You can optionally provide additional context to refine the search.`, + + inputSchema: z.object({ + additionalContext: z + .string() + .optional() + .describe('Optional additional context or specific focus to enhance the knowledge search') + }), + + execute: async ({ additionalContext }) => { + try { + // 获取助手的知识库配置 + const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) + const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) + const knowledgeRecognition = assistant.knowledgeRecognition || 'on' + + // 检查是否有知识库 + if (!hasKnowledgeBase) { + return { + summary: 'No knowledge base configured for this assistant.', + knowledgeReferences: [], + instructions: '' + } + } + + let finalQueries = [...extractedKeywords.question] + let finalRewrite = extractedKeywords.rewrite + + if (additionalContext?.trim()) { + // 如果大模型提供了额外上下文,使用更具体的描述 + const cleanContext = additionalContext.trim() + if (cleanContext) { + finalQueries = [cleanContext] + finalRewrite = cleanContext + } + } + + // 检查是否需要搜索 + if (finalQueries[0] === 'not_needed') { + return { + summary: 'No search needed based on the query analysis.', + knowledgeReferences: [], + instructions: '' + } + } + + // 构建搜索条件 + let searchCriteria: { question: string[]; rewrite: string } + + if (knowledgeRecognition === 'off') { + // 直接模式:使用用户消息内容 + const directContent = userMessage || finalQueries[0] || 'search' + searchCriteria = { + question: [directContent], + rewrite: directContent + } + } else { + // 自动模式:使用意图识别的结果 + searchCriteria = { + question: finalQueries, + rewrite: finalRewrite + } + } + + // 构建 ExtractResults 对象 + const extractResults: ExtractResults = { + websearch: undefined, + knowledge: searchCriteria + } + + // 执行知识库搜索 + const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds, topicId) + const knowledgeReferencesData = knowledgeReferences.map((ref: KnowledgeReference) => ({ + id: ref.id, + content: ref.content, + sourceUrl: ref.sourceUrl, + type: ref.type, + file: 
ref.file + })) + + // const referenceContent = `\`\`\`json\n${JSON.stringify(knowledgeReferencesData, null, 2)}\n\`\`\`` + // TODO 在工具函数中添加搜索缓存机制 + // const searchCacheKey = `${topicId}-${JSON.stringify(finalQueries)}` + // 可以在插件层面管理已搜索的查询,避免重复搜索 + const fullInstructions = REFERENCE_PROMPT.replace( + '{question}', + "Based on the knowledge references, please answer the user's question with proper citations." + ).replace('{references}', 'knowledgeReferences:') + + // 返回结果 + return { + summary: `Found ${knowledgeReferencesData.length} relevant sources. Use [number] format to cite specific information.`, + knowledgeReferences: knowledgeReferencesData, + instructions: fullInstructions + } + } catch (error) { + // 返回空对象而不是抛出错误,避免中断对话流程 + return { + summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`, + knowledgeReferences: [], + instructions: '' + } + } + } + }) +} + +export type KnowledgeSearchToolInput = InferToolInput> +export type KnowledgeSearchToolOutput = InferToolOutput> + +export default knowledgeSearchTool diff --git a/src/renderer/src/aiCore/tools/MemorySearchTool.ts b/src/renderer/src/aiCore/tools/MemorySearchTool.ts new file mode 100644 index 0000000000..6430692d4c --- /dev/null +++ b/src/renderer/src/aiCore/tools/MemorySearchTool.ts @@ -0,0 +1,151 @@ +import store from '@renderer/store' +import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory' +import type { Assistant } from '@renderer/types' +import { type InferToolInput, type InferToolOutput, tool } from 'ai' +import { z } from 'zod' + +import { MemoryProcessor } from '../../services/MemoryProcessor' + +/** + * 🧠 基础记忆搜索工具 + * AI 可以主动调用的简单记忆搜索 + */ +export const memorySearchTool = () => { + return tool({ + name: 'builtin_memory_search', + description: 'Search through conversation memories and stored facts for relevant context', + inputSchema: z.object({ + query: z.string().describe('Search query to find relevant 
memories'), + limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return') + }), + execute: async ({ query, limit = 5 }) => { + // console.log('🧠 [memorySearchTool] Searching memories:', { query, limit }) + + try { + const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) + if (!globalMemoryEnabled) { + return [] + } + + const memoryConfig = selectMemoryConfig(store.getState()) + if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) { + // console.warn('Memory search skipped: embedding or LLM model not configured') + return [] + } + + const currentUserId = selectCurrentUserId(store.getState()) + const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, 'default', currentUserId) + + const memoryProcessor = new MemoryProcessor() + const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit) + + if (relevantMemories?.length > 0) { + // console.log('🧠 [memorySearchTool] Found memories:', relevantMemories.length) + return relevantMemories + } + return [] + } catch (error) { + // console.error('🧠 [memorySearchTool] Error:', error) + return [] + } + } + }) +} + +// 方案4: 为第二个工具也使用类型断言 +type MessageRole = 'user' | 'assistant' | 'system' +type MessageType = { + content: string + role: MessageRole +} +type MemorySearchWithExtractionInput = { + userMessage: MessageType + lastAnswer?: MessageType +} + +/** + * 🧠 智能记忆搜索工具(带上下文提取) + * 从用户消息和对话历史中自动提取关键词进行记忆搜索 + */ +export const memorySearchToolWithExtraction = (assistant: Assistant) => { + return tool({ + name: 'memory_search_with_extraction', + description: 'Search memories with automatic keyword extraction from conversation context', + inputSchema: z.object({ + userMessage: z.object({ + content: z.string().describe('The main content of the user message'), + role: z.enum(['user', 'assistant', 'system']).describe('Message role') + }), + lastAnswer: z + .object({ + content: z.string().describe('The main content of 
the last assistant response'), + role: z.enum(['user', 'assistant', 'system']).describe('Message role') + }) + .optional() + }) as z.ZodSchema, + execute: async ({ userMessage }) => { + // console.log('🧠 [memorySearchToolWithExtraction] Processing:', { userMessage, lastAnswer }) + + try { + const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) + if (!globalMemoryEnabled || !assistant.enableMemory) { + return { + extractedKeywords: 'Memory search disabled', + searchResults: [] + } + } + + const memoryConfig = selectMemoryConfig(store.getState()) + if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) { + // console.warn('Memory search skipped: embedding or LLM model not configured') + return { + extractedKeywords: 'Memory models not configured', + searchResults: [] + } + } + + // 🔍 使用用户消息内容作为搜索关键词 + const content = userMessage.content + + if (!content) { + return { + extractedKeywords: 'No content to search', + searchResults: [] + } + } + + const currentUserId = selectCurrentUserId(store.getState()) + const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, assistant.id, currentUserId) + + const memoryProcessor = new MemoryProcessor() + const relevantMemories = await memoryProcessor.searchRelevantMemories( + content, + processorConfig, + 5 // Limit to top 5 most relevant memories + ) + + if (relevantMemories?.length > 0) { + // console.log('🧠 [memorySearchToolWithExtraction] Found memories:', relevantMemories.length) + return { + extractedKeywords: content, + searchResults: relevantMemories + } + } + + return { + extractedKeywords: content, + searchResults: [] + } + } catch (error) { + // console.error('🧠 [memorySearchToolWithExtraction] Error:', error) + return { + extractedKeywords: 'Search failed', + searchResults: [] + } + } + } + }) +} +export type MemorySearchToolInput = InferToolInput> +export type MemorySearchToolOutput = InferToolOutput> +export type MemorySearchToolWithExtractionOutput = InferToolOutput> diff 
--git a/src/renderer/src/aiCore/tools/WebSearchTool.ts b/src/renderer/src/aiCore/tools/WebSearchTool.ts new file mode 100644 index 0000000000..81f64dbde5 --- /dev/null +++ b/src/renderer/src/aiCore/tools/WebSearchTool.ts @@ -0,0 +1,210 @@ +import { REFERENCE_PROMPT } from '@renderer/config/prompts' +import WebSearchService from '@renderer/services/WebSearchService' +import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' +import { ExtractResults } from '@renderer/utils/extract' +import { type InferToolInput, type InferToolOutput, tool } from 'ai' +import { z } from 'zod' + +/** + * 使用预提取关键词的网络搜索工具 + * 这个工具直接使用插件阶段分析的搜索意图,避免重复分析 + */ +export const webSearchToolWithPreExtractedKeywords = ( + webSearchProviderId: WebSearchProvider['id'], + extractedKeywords: { + question: string[] + links?: string[] + }, + requestId: string +) => { + const webSearchProvider = WebSearchService.getWebSearchProvider(webSearchProviderId) + + return tool({ + name: 'builtin_web_search', + description: `Search the web and return citable sources using pre-analyzed search intent. + +Pre-extracted search keywords: "${extractedKeywords.question.join(', ')}"${ + extractedKeywords.links + ? ` +Relevant links: ${extractedKeywords.links.join(', ')}` + : '' + } + +This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response. + +Call this tool to execute the search. 
You can optionally provide additional context to refine the search.`, + + inputSchema: z.object({ + additionalContext: z + .string() + .optional() + .describe('Optional additional context, keywords, or specific focus to enhance the search') + }), + + execute: async ({ additionalContext }) => { + let finalQueries = [...extractedKeywords.question] + + if (additionalContext?.trim()) { + // 如果大模型提供了额外上下文,使用更具体的描述 + const cleanContext = additionalContext.trim() + if (cleanContext) { + finalQueries = [cleanContext] + } + } + + let searchResults: WebSearchProviderResponse = { + query: '', + results: [] + } + // 检查是否需要搜索 + if (finalQueries[0] === 'not_needed') { + return { + summary: 'No search needed based on the query analysis.', + searchResults, + sources: '', + instructions: '' + } + } + + try { + // 构建 ExtractResults 结构用于 processWebsearch + const extractResults: ExtractResults = { + websearch: { + question: finalQueries, + links: extractedKeywords.links + } + } + searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId) + } catch (error) { + return { + summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`, + sources: [], + instructions: '' + } + } + if (searchResults.results.length === 0) { + return { + summary: 'No search results found for the given query.', + sources: [], + instructions: '' + } + } + + const results = searchResults.results + const citationData = results.map((result, index) => ({ + number: index + 1, + title: result.title, + content: result.content, + url: result.url + })) + + // 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑 + // const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\`` + + // 构建完整的引用指导文本 + const fullInstructions = REFERENCE_PROMPT.replace( + '{question}', + "Based on the search results, please answer the user's question with proper citations." 
+ ).replace('{references}', 'searchResults:') + + return { + summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`, + searchResults, + instructions: fullInstructions + } + } + }) +} + +// export const webSearchToolWithExtraction = ( +// webSearchProviderId: WebSearchProvider['id'], +// requestId: string, +// assistant: Assistant +// ) => { +// const webSearchService = WebSearchService.getInstance(webSearchProviderId) + +// return tool({ +// name: 'web_search_with_extraction', +// description: 'Search the web for information with automatic keyword extraction from user messages', +// inputSchema: z.object({ +// userMessage: z.object({ +// content: z.string().describe('The main content of the message'), +// role: z.enum(['user', 'assistant', 'system']).describe('Message role') +// }), +// lastAnswer: z.object({ +// content: z.string().describe('The main content of the message'), +// role: z.enum(['user', 'assistant', 'system']).describe('Message role') +// }) +// }), +// outputSchema: z.object({ +// extractedKeywords: z.object({ +// question: z.array(z.string()), +// links: z.array(z.string()).optional() +// }), +// searchResults: z.array( +// z.object({ +// query: z.string(), +// results: WebSearchProviderResult +// }) +// ) +// }), +// execute: async ({ userMessage, lastAnswer }) => { +// const lastUserMessage: Message = { +// id: requestId, +// role: userMessage.role, +// assistantId: assistant.id, +// topicId: 'temp', +// createdAt: new Date().toISOString(), +// status: UserMessageStatus.SUCCESS, +// blocks: [] +// } + +// const lastAnswerMessage: Message | undefined = lastAnswer +// ? 
{ +// id: requestId + '_answer', +// role: lastAnswer.role, +// assistantId: assistant.id, +// topicId: 'temp', +// createdAt: new Date().toISOString(), +// status: UserMessageStatus.SUCCESS, +// blocks: [] +// } +// : undefined + +// const extractResults = await extractSearchKeywords(lastUserMessage, assistant, { +// shouldWebSearch: true, +// shouldKnowledgeSearch: false, +// lastAnswer: lastAnswerMessage +// }) + +// if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') { +// return 'No search needed or extraction failed' +// } + +// const searchQueries = extractResults.websearch.question +// const searchResults: Array<{ query: string; results: any }> = [] + +// for (const query of searchQueries) { +// // 构建单个查询的ExtractResults结构 +// const queryExtractResults: ExtractResults = { +// websearch: { +// question: [query], +// links: extractResults.websearch.links +// } +// } +// const response = await webSearchService.processWebsearch(queryExtractResults, requestId) +// searchResults.push({ +// query, +// results: response +// }) +// } + +// return { extractedKeywords: extractResults.websearch, searchResults } +// } +// }) +// } + +// export type WebSearchToolWithExtractionOutput = InferToolOutput> + +export type WebSearchToolOutput = InferToolOutput> +export type WebSearchToolInput = InferToolInput> diff --git a/src/renderer/src/aiCore/trace/AiSdkSpanAdapter.ts b/src/renderer/src/aiCore/trace/AiSdkSpanAdapter.ts new file mode 100644 index 0000000000..fc844b5fe5 --- /dev/null +++ b/src/renderer/src/aiCore/trace/AiSdkSpanAdapter.ts @@ -0,0 +1,655 @@ +/** + * AI SDK Span Adapter + * + * 将 AI SDK 的 telemetry 数据转换为现有的 SpanEntity 格式 + * 注意 AI SDK 的层级结构:ai.xxx 是一个层级,ai.xxx.xxx 是对应层级下的子集 + */ + +import { loggerService } from '@logger' +import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core' +import { Span, SpanKind, SpanStatusCode } from '@opentelemetry/api' + +const logger = loggerService.withContext('AiSdkSpanAdapter') + +export 
interface AiSdkSpanData { + span: Span + topicId?: string + modelName?: string +} + +// 扩展接口用于访问span的内部数据 +interface SpanWithInternals extends Span { + _spanProcessor?: any + _attributes?: Record + _events?: any[] + name?: string + startTime?: [number, number] + endTime?: [number, number] | null + status?: { code: SpanStatusCode; message?: string } + kind?: SpanKind + ended?: boolean + parentSpanId?: string + links?: any[] +} + +export class AiSdkSpanAdapter { + /** + * 将 AI SDK span 转换为 SpanEntity 格式 + */ + static convertToSpanEntity(spanData: AiSdkSpanData): SpanEntity { + const { span, topicId, modelName } = spanData + const spanContext = span.spanContext() + + // 尝试从不同方式获取span数据 + const spanWithInternals = span as SpanWithInternals + let attributes: Record = {} + let events: any[] = [] + let spanName = 'unknown' + let spanStatus = { code: SpanStatusCode.UNSET } + let spanKind = SpanKind.INTERNAL + let startTime: [number, number] = [0, 0] + let endTime: [number, number] | null = null + let ended = false + let parentSpanId = '' + let links: any[] = [] + + // 详细记录span的结构信息用于调试 + logger.debug('Debugging span structure', { + hasInternalAttributes: !!spanWithInternals._attributes, + hasGetAttributes: typeof (span as any).getAttributes === 'function', + spanKeys: Object.keys(span), + spanInternalKeys: Object.keys(spanWithInternals), + spanContext: span.spanContext(), + // 尝试获取所有可能的属性路径 + attributesPath1: spanWithInternals._attributes, + attributesPath2: (span as any).attributes, + attributesPath3: (span as any)._spanData?.attributes, + attributesPath4: (span as any).resource?.attributes + }) + + // 尝试多种方式获取attributes + if (spanWithInternals._attributes) { + attributes = spanWithInternals._attributes + logger.debug('Found attributes via _attributes', { attributeCount: Object.keys(attributes).length }) + } else if (typeof (span as any).getAttributes === 'function') { + attributes = (span as any).getAttributes() + logger.debug('Found attributes via getAttributes()', { 
attributeCount: Object.keys(attributes).length }) + } else if ((span as any).attributes) { + attributes = (span as any).attributes + logger.debug('Found attributes via direct attributes property', { + attributeCount: Object.keys(attributes).length + }) + } else if ((span as any)._spanData?.attributes) { + attributes = (span as any)._spanData.attributes + logger.debug('Found attributes via _spanData.attributes', { attributeCount: Object.keys(attributes).length }) + } else { + // 尝试从span的其他属性获取 + logger.warn('无法获取span attributes,尝试备用方法', { + availableKeys: Object.keys(span), + spanType: span.constructor.name + }) + } + + // 获取其他属性 + if (spanWithInternals._events) { + events = spanWithInternals._events + } + if (spanWithInternals.name) { + spanName = spanWithInternals.name + } + if (spanWithInternals.status) { + spanStatus = spanWithInternals.status + } + if (spanWithInternals.kind !== undefined) { + spanKind = spanWithInternals.kind + } + if (spanWithInternals.startTime) { + startTime = spanWithInternals.startTime + } + if (spanWithInternals.endTime) { + endTime = spanWithInternals.endTime + } + if (spanWithInternals.ended !== undefined) { + ended = spanWithInternals.ended + } + if (spanWithInternals.parentSpanId) { + parentSpanId = spanWithInternals.parentSpanId + } + // 兜底:尝试从 attributes 中读取我们注入的父信息 + if (!parentSpanId && attributes['trace.parentSpanId']) { + parentSpanId = attributes['trace.parentSpanId'] + } + if (spanWithInternals.links) { + links = spanWithInternals.links + } + + // 提取 AI SDK 特有的数据 + const tokenUsage = this.extractTokenUsage(attributes) + const { inputs, outputs } = this.extractInputsOutputs(attributes) + const formattedSpanName = this.formatSpanName(spanName) + const spanTag = this.extractSpanTag(spanName, attributes) + const typeSpecificData = this.extractSpanTypeSpecificData(attributes) + + // 详细记录转换过程 + const operationId = attributes['ai.operationId'] + logger.info('Converting AI SDK span to SpanEntity', { + spanName: spanName, + 
operationId, + spanTag, + hasTokenUsage: !!tokenUsage, + hasInputs: !!inputs, + hasOutputs: !!outputs, + hasTypeSpecificData: Object.keys(typeSpecificData).length > 0, + attributesCount: Object.keys(attributes).length, + topicId, + modelName, + spanId: spanContext.spanId, + traceId: spanContext.traceId + }) + + if (tokenUsage) { + logger.info('Token usage data found', { + spanName: spanName, + operationId, + usage: tokenUsage, + spanId: spanContext.spanId + }) + } + + if (inputs || outputs) { + logger.info('Input/Output data extracted', { + spanName: spanName, + operationId, + hasInputs: !!inputs, + hasOutputs: !!outputs, + inputKeys: inputs ? Object.keys(inputs) : [], + outputKeys: outputs ? Object.keys(outputs) : [], + spanId: spanContext.spanId + }) + } + + if (Object.keys(typeSpecificData).length > 0) { + logger.info('Type-specific data extracted', { + spanName: spanName, + operationId, + typeSpecificKeys: Object.keys(typeSpecificData), + spanId: spanContext.spanId + }) + } + + // 转换为 SpanEntity 格式 + const spanEntity: SpanEntity = { + id: spanContext.spanId, + name: formattedSpanName, + parentId: parentSpanId, + traceId: spanContext.traceId, + status: this.convertSpanStatus(spanStatus.code), + kind: this.convertSpanKind(spanKind), + attributes: { + ...this.filterRelevantAttributes(attributes), + ...typeSpecificData, + inputs: inputs, + outputs: outputs, + tags: spanTag, + modelName: modelName || this.extractModelFromAttributes(attributes) || '' + }, + isEnd: ended, + events: events, + startTime: this.convertTimestamp(startTime), + endTime: endTime ? 
this.convertTimestamp(endTime) : null, + links: links, + topicId: topicId, + usage: tokenUsage, + modelName: modelName || this.extractModelFromAttributes(attributes) + } + + logger.info('AI SDK span successfully converted to SpanEntity', { + spanName: spanName, + operationId, + spanId: spanContext.spanId, + traceId: spanContext.traceId, + topicId, + entityId: spanEntity.id, + hasUsage: !!spanEntity.usage, + status: spanEntity.status, + tags: spanEntity.attributes?.tags + }) + + return spanEntity + } + + /** + * 从 AI SDK attributes 中提取 token usage + * 支持多种格式: + * - AI SDK 标准格式: ai.usage.completionTokens, ai.usage.promptTokens + * - 完整usage对象格式: ai.usage (JSON字符串或对象) + */ + private static extractTokenUsage(attributes: Record): TokenUsage | undefined { + logger.debug('Extracting token usage from attributes', { + attributeKeys: Object.keys(attributes), + usageRelatedKeys: Object.keys(attributes).filter((key) => key.includes('usage') || key.includes('token')), + fullAttributes: attributes + }) + + const inputsTokenKeys = [ + // base span + 'ai.usage.promptTokens', + // LLM span + 'gen_ai.usage.input_tokens' + ] + const outputTokenKeys = [ + // base span + 'ai.usage.completionTokens', + // LLM span + 'gen_ai.usage.output_tokens' + ] + + const completionTokens = attributes[inputsTokenKeys.find((key) => attributes[key]) || ''] + const promptTokens = attributes[outputTokenKeys.find((key) => attributes[key]) || ''] + + if (completionTokens !== undefined || promptTokens !== undefined) { + const usage: TokenUsage = { + prompt_tokens: Number(promptTokens) || 0, + completion_tokens: Number(completionTokens) || 0, + total_tokens: (Number(promptTokens) || 0) + (Number(completionTokens) || 0) + } + + logger.debug('Extracted token usage from AI SDK standard attributes', { + usage, + foundStandardAttributes: { + 'ai.usage.completionTokens': completionTokens, + 'ai.usage.promptTokens': promptTokens + } + }) + + return usage + } + + // 对于不包含token usage的spans(如tool calls),这是正常的 + 
logger.debug('No token usage found in span attributes (normal for tool calls)', { + availableKeys: Object.keys(attributes), + usageKeys: Object.keys(attributes).filter((key) => key.includes('usage') || key.includes('token')) + }) + + return undefined + } + + /** + * 从 AI SDK attributes 中提取 inputs 和 outputs + * 根据AI SDK文档按不同span类型精确映射 + */ + private static extractInputsOutputs(attributes: Record): { inputs?: any; outputs?: any } { + const operationId = attributes['ai.operationId'] + let inputs: any = undefined + let outputs: any = undefined + + logger.debug('Extracting inputs/outputs by operation type', { + operationId, + availableKeys: Object.keys(attributes).filter( + (key) => key.includes('prompt') || key.includes('response') || key.includes('toolCall') + ) + }) + + // 根据AI SDK文档按操作类型提取数据 + switch (operationId) { + case 'ai.generateText': + case 'ai.streamText': + // 顶层LLM spans: ai.prompt 包含输入 + inputs = { + prompt: this.parseAttributeValue(attributes['ai.prompt']) + } + outputs = this.extractLLMOutputs(attributes) + break + + case 'ai.generateText.doGenerate': + case 'ai.streamText.doStream': + // Provider spans: ai.prompt.messages 包含详细输入 + inputs = { + messages: this.parseAttributeValue(attributes['ai.prompt.messages']), + tools: this.parseAttributeValue(attributes['ai.prompt.tools']), + toolChoice: this.parseAttributeValue(attributes['ai.prompt.toolChoice']) + } + outputs = this.extractProviderOutputs(attributes) + break + + case 'ai.toolCall': + // Tool call spans: 工具参数和结果 + inputs = { + toolName: attributes['ai.toolCall.name'], + toolId: attributes['ai.toolCall.id'], + args: this.parseAttributeValue(attributes['ai.toolCall.args']) + } + outputs = { + result: this.parseAttributeValue(attributes['ai.toolCall.result']) + } + break + + default: + // 回退到通用逻辑 + inputs = this.extractGenericInputs(attributes) + outputs = this.extractGenericOutputs(attributes) + break + } + + logger.debug('Extracted inputs/outputs', { + operationId, + hasInputs: !!inputs, + 
hasOutputs: !!outputs, + inputKeys: inputs ? Object.keys(inputs) : [], + outputKeys: outputs ? Object.keys(outputs) : [] + }) + + return { inputs, outputs } + } + + /** + * 提取LLM顶层spans的输出 + */ + private static extractLLMOutputs(attributes: Record): any { + const outputs: any = {} + + if (attributes['ai.response.text']) { + outputs.text = attributes['ai.response.text'] + } + if (attributes['ai.response.toolCalls']) { + outputs.toolCalls = this.parseAttributeValue(attributes['ai.response.toolCalls']) + } + if (attributes['ai.response.finishReason']) { + outputs.finishReason = attributes['ai.response.finishReason'] + } + if (attributes['ai.settings.maxOutputTokens']) { + outputs.maxOutputTokens = attributes['ai.settings.maxOutputTokens'] + } + + return Object.keys(outputs).length > 0 ? outputs : undefined + } + + /** + * 提取Provider spans的输出 + */ + private static extractProviderOutputs(attributes: Record): any { + const outputs: any = {} + + if (attributes['ai.response.text']) { + outputs.text = attributes['ai.response.text'] + } + if (attributes['ai.response.toolCalls']) { + outputs.toolCalls = this.parseAttributeValue(attributes['ai.response.toolCalls']) + } + if (attributes['ai.response.finishReason']) { + outputs.finishReason = attributes['ai.response.finishReason'] + } + + // doStream特有的性能指标 + if (attributes['ai.response.msToFirstChunk']) { + outputs.msToFirstChunk = attributes['ai.response.msToFirstChunk'] + } + if (attributes['ai.response.msToFinish']) { + outputs.msToFinish = attributes['ai.response.msToFinish'] + } + if (attributes['ai.response.avgCompletionTokensPerSecond']) { + outputs.avgCompletionTokensPerSecond = attributes['ai.response.avgCompletionTokensPerSecond'] + } + + return Object.keys(outputs).length > 0 ? 
outputs : undefined + } + + /** + * 通用输入提取(回退逻辑) + */ + private static extractGenericInputs(attributes: Record): any { + const inputKeys = ['ai.prompt', 'ai.prompt.messages', 'ai.request', 'inputs'] + + for (const key of inputKeys) { + if (attributes[key]) { + return this.parseAttributeValue(attributes[key]) + } + } + return undefined + } + + /** + * 通用输出提取(回退逻辑) + */ + private static extractGenericOutputs(attributes: Record): any { + const outputKeys = ['ai.response.text', 'ai.response', 'ai.output', 'outputs'] + + for (const key of outputKeys) { + if (attributes[key]) { + return this.parseAttributeValue(attributes[key]) + } + } + return undefined + } + + /** + * 解析属性值,处理字符串化的 JSON + */ + private static parseAttributeValue(value: any): any { + if (typeof value === 'string') { + try { + return JSON.parse(value) + } catch (e) { + return value + } + } + return value + } + + /** + * 格式化 span 名称,处理 AI SDK 的层级结构 + */ + private static formatSpanName(name: string): string { + // AI SDK 的 span 名称可能是 ai.generateText, ai.streamText.doStream 等 + // 保持原始名称,但可以添加一些格式化逻辑 + if (name.startsWith('ai.')) { + return name + } + return name + } + + /** + * 从AI SDK operationId中提取精确的span标签 + */ + private static extractSpanTag(name: string, attributes: Record): string { + const operationId = attributes['ai.operationId'] + + logger.debug('Extracting span tag', { + spanName: name, + operationId, + operationName: attributes['operation.name'] + }) + + // 根据AI SDK文档的operationId精确分类 + switch (operationId) { + case 'ai.generateText': + return 'LLM-GENERATE' + case 'ai.streamText': + return 'LLM-STREAM' + case 'ai.generateText.doGenerate': + return 'PROVIDER-GENERATE' + case 'ai.streamText.doStream': + return 'PROVIDER-STREAM' + case 'ai.toolCall': + return 'TOOL-CALL' + case 'ai.generateImage': + return 'IMAGE' + case 'ai.embed': + return 'EMBEDDING' + default: + // 回退逻辑:基于span名称 + if (name.includes('generateText') || name.includes('streamText')) { + return 'LLM' + } + if 
(name.includes('generateImage')) { + return 'IMAGE' + } + if (name.includes('embed')) { + return 'EMBEDDING' + } + if (name.includes('toolCall')) { + return 'TOOL' + } + + // 最终回退 + return attributes['ai.operationType'] || attributes['operation.type'] || 'AI_SDK' + } + } + + /** + * 根据span类型提取特定的额外数据 + */ + private static extractSpanTypeSpecificData(attributes: Record): Record { + const operationId = attributes['ai.operationId'] + const specificData: Record = {} + + switch (operationId) { + case 'ai.generateText': + case 'ai.streamText': + // LLM顶层spans的特定数据 + if (attributes['ai.settings.maxOutputTokens']) { + specificData.maxOutputTokens = attributes['ai.settings.maxOutputTokens'] + } + if (attributes['resource.name']) { + specificData.functionId = attributes['resource.name'] + } + break + + case 'ai.generateText.doGenerate': + case 'ai.streamText.doStream': + // Provider spans的特定数据 + if (attributes['ai.model.id']) { + specificData.providerId = attributes['ai.model.provider'] || 'unknown' + specificData.modelId = attributes['ai.model.id'] + } + + // doStream特有的性能数据 + if (operationId === 'ai.streamText.doStream') { + if (attributes['ai.response.msToFirstChunk']) { + specificData.msToFirstChunk = attributes['ai.response.msToFirstChunk'] + } + if (attributes['ai.response.msToFinish']) { + specificData.msToFinish = attributes['ai.response.msToFinish'] + } + if (attributes['ai.response.avgCompletionTokensPerSecond']) { + specificData.tokensPerSecond = attributes['ai.response.avgCompletionTokensPerSecond'] + } + } + break + + case 'ai.toolCall': + // Tool call spans的特定数据 + specificData.toolName = attributes['ai.toolCall.name'] + specificData.toolId = attributes['ai.toolCall.id'] + + // 根据文档,tool call可能有不同的操作类型 + if (attributes['operation.name']) { + specificData.operationName = attributes['operation.name'] + } + break + + default: + // 通用的AI SDK属性 + if (attributes['ai.telemetry.functionId']) { + specificData.telemetryFunctionId = attributes['ai.telemetry.functionId'] + 
} + if (attributes['ai.telemetry.metadata']) { + specificData.telemetryMetadata = this.parseAttributeValue(attributes['ai.telemetry.metadata']) + } + break + } + + // 添加通用的操作标识 + if (operationId) { + specificData.operationType = operationId + } + if (attributes['operation.name']) { + specificData.operationName = attributes['operation.name'] + } + + logger.debug('Extracted type-specific data', { + operationId, + specificDataKeys: Object.keys(specificData), + specificData + }) + + return specificData + } + + /** + * 从属性中提取模型名称 + */ + private static extractModelFromAttributes(attributes: Record): string | undefined { + return ( + attributes['ai.model.id'] || + attributes['ai.model'] || + attributes['model.id'] || + attributes['model'] || + attributes['modelName'] + ) + } + + /** + * 过滤相关的属性,移除不需要的系统属性 + */ + private static filterRelevantAttributes(attributes: Record): Record { + const filtered: Record = {} + + // 保留有用的属性,过滤掉已经单独处理的属性 + const excludeKeys = ['ai.usage', 'ai.prompt', 'ai.response', 'ai.input', 'ai.output', 'inputs', 'outputs'] + + Object.entries(attributes).forEach(([key, value]) => { + if (!excludeKeys.includes(key)) { + filtered[key] = value + } + }) + + return filtered + } + + /** + * 转换 span 状态 + */ + private static convertSpanStatus(statusCode?: SpanStatusCode): string { + switch (statusCode) { + case SpanStatusCode.OK: + return 'OK' + case SpanStatusCode.ERROR: + return 'ERROR' + case SpanStatusCode.UNSET: + default: + return 'UNSET' + } + } + + /** + * 转换 span 类型 + */ + private static convertSpanKind(kind?: SpanKind): string { + switch (kind) { + case SpanKind.INTERNAL: + return 'INTERNAL' + case SpanKind.CLIENT: + return 'CLIENT' + case SpanKind.SERVER: + return 'SERVER' + case SpanKind.PRODUCER: + return 'PRODUCER' + case SpanKind.CONSUMER: + return 'CONSUMER' + default: + return 'INTERNAL' + } + } + + /** + * 转换时间戳格式 + */ + private static convertTimestamp(timestamp: [number, number] | number): number { + if (Array.isArray(timestamp)) { + // 
OpenTelemetry 高精度时间戳 [seconds, nanoseconds] + return timestamp[0] * 1000 + timestamp[1] / 1000000 + } + return timestamp + } +} diff --git a/src/renderer/src/aiCore/utils/image.ts b/src/renderer/src/aiCore/utils/image.ts new file mode 100644 index 0000000000..7691f9d4b1 --- /dev/null +++ b/src/renderer/src/aiCore/utils/image.ts @@ -0,0 +1,5 @@ +export function buildGeminiGenerateImageParams(): Record { + return { + responseModalities: ['TEXT', 'IMAGE'] + } +} diff --git a/src/renderer/src/aiCore/utils/mcp.ts b/src/renderer/src/aiCore/utils/mcp.ts new file mode 100644 index 0000000000..01f7adbf50 --- /dev/null +++ b/src/renderer/src/aiCore/utils/mcp.ts @@ -0,0 +1,99 @@ +import { loggerService } from '@logger' +// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types' +import { MCPTool, MCPToolResponse } from '@renderer/types' +import { Chunk, ChunkType } from '@renderer/types/chunk' +import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools' +import { requestToolConfirmation } from '@renderer/utils/userConfirmation' +import { type Tool, type ToolSet } from 'ai' +import { jsonSchema, tool } from 'ai' +import { JSONSchema7 } from 'json-schema' + +const logger = loggerService.withContext('MCP-utils') + +// Setup tools configuration based on provided parameters +export function setupToolsConfig(mcpTools?: MCPTool[]): Record | undefined { + let tools: ToolSet = {} + + if (!mcpTools?.length) { + return undefined + } + + tools = convertMcpToolsToAiSdkTools(mcpTools) + + return tools +} + +/** + * 将 MCPTool 转换为 AI SDK 工具格式 + */ +export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet { + const tools: ToolSet = {} + + for (const mcpTool of mcpTools) { + tools[mcpTool.name] = tool({ + description: mcpTool.description || `Tool from ${mcpTool.serverName}`, + inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7), + execute: async (params, { toolCallId, experimental_context }) => { + const { onChunk } 
= experimental_context as { onChunk: (chunk: Chunk) => void } + // 创建适配的 MCPToolResponse 对象 + const toolResponse: MCPToolResponse = { + id: toolCallId, + tool: mcpTool, + arguments: params, + status: 'pending', + toolCallId + } + + try { + // 检查是否启用自动批准 + const server = getMcpServerByTool(mcpTool) + const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server) + + let confirmed = true + if (!isAutoApproveEnabled) { + // 请求用户确认 + logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`) + confirmed = await requestToolConfirmation(toolResponse.id) + } + + if (!confirmed) { + // 用户拒绝执行工具 + logger.debug(`User cancelled tool execution: ${mcpTool.name}`) + return { + content: [ + { + type: 'text', + text: `User declined to execute tool "${mcpTool.name}".` + } + ], + isError: false + } + } + + // 用户确认或自动批准,执行工具 + toolResponse.status = 'invoking' + logger.debug(`Executing tool: ${mcpTool.name}`) + + onChunk({ + type: ChunkType.MCP_TOOL_IN_PROGRESS, + responses: [toolResponse] + }) + + const result = await callMCPTool(toolResponse) + + // 返回结果,AI SDK 会处理序列化 + if (result.isError) { + throw new Error(result.content?.[0]?.text || 'Tool execution failed') + } + // 返回工具执行结果 + return result + } catch (error) { + logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error }) + throw error + } + } + }) + } + + return tools +} diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts new file mode 100644 index 0000000000..a244619264 --- /dev/null +++ b/src/renderer/src/aiCore/utils/options.ts @@ -0,0 +1,303 @@ +import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider' +import { isOpenAIModel, isQwenMTModel, isSupportFlexServiceTierModel } from '@renderer/config/models' +import { isSupportServiceTierProvider } from '@renderer/config/providers' +import { mapLanguageToQwenMTModel } from '@renderer/config/translate' +import { + Assistant, + GroqServiceTiers, + isGroqServiceTier, + 
isOpenAIServiceTier, + isTranslateAssistant, + Model, + OpenAIServiceTiers, + Provider, + SystemProviderIds +} from '@renderer/types' +import { t } from 'i18next' + +import { getAiSdkProviderId } from '../provider/factory' +import { buildGeminiGenerateImageParams } from './image' +import { + getAnthropicReasoningParams, + getCustomParameters, + getGeminiReasoningParams, + getOpenAIReasoningParams, + getReasoningEffort, + getXAIReasoningParams +} from './reasoning' +import { getWebSearchParams } from './websearch' + +// copy from BaseApiClient.ts +const getServiceTier = (model: Model, provider: Provider) => { + const serviceTierSetting = provider.serviceTier + + if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) { + return undefined + } + + // 处理不同供应商需要 fallback 到默认值的情况 + if (provider.id === SystemProviderIds.groq) { + if ( + !isGroqServiceTier(serviceTierSetting) || + (serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model)) + ) { + return undefined + } + } else { + // 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同 + if ( + !isOpenAIServiceTier(serviceTierSetting) || + (serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model)) + ) { + return undefined + } + } + + return serviceTierSetting +} + +/** + * 构建 AI SDK 的 providerOptions + * 按 provider 类型分离,保持类型安全 + * 返回格式:{ 'providerId': providerOptions } + */ +export function buildProviderOptions( + assistant: Assistant, + model: Model, + actualProvider: Provider, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const rawProviderId = getAiSdkProviderId(actualProvider) + // 构建 provider 特定的选项 + let providerSpecificOptions: Record = {} + const serviceTierSetting = getServiceTier(model, actualProvider) + providerSpecificOptions.serviceTier = serviceTierSetting + // 根据 provider 类型分离构建逻辑 + const { data: baseProviderId, success } = 
baseProviderIdSchema.safeParse(rawProviderId) + if (success) { + // 应该覆盖所有类型 + switch (baseProviderId) { + case 'openai': + case 'azure': + providerSpecificOptions = { + ...buildOpenAIProviderOptions(assistant, model, capabilities), + serviceTier: serviceTierSetting + } + break + + case 'anthropic': + providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities) + break + + case 'google': + providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities) + break + + case 'xai': + providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities) + break + case 'deepseek': + case 'openai-compatible': + // 对于其他 provider,使用通用的构建逻辑 + providerSpecificOptions = { + ...buildGenericProviderOptions(assistant, model, capabilities), + serviceTier: serviceTierSetting + } + break + default: + throw new Error(`Unsupported base provider ${baseProviderId}`) + } + } else { + // 处理自定义 provider + const { data: providerId, success, error } = customProviderIdSchema.safeParse(rawProviderId) + if (success) { + switch (providerId) { + // 非 base provider 的单独处理逻辑 + case 'google-vertex': + providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities) + break + default: + // 对于其他 provider,使用通用的构建逻辑 + providerSpecificOptions = { + ...buildGenericProviderOptions(assistant, model, capabilities), + serviceTier: serviceTierSetting + } + } + } else { + throw error + } + } + + // 合并自定义参数到 provider 特定的选项中 + providerSpecificOptions = { + ...providerSpecificOptions, + ...getCustomParameters(assistant) + } + + // 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } + return { + [rawProviderId]: providerSpecificOptions + } +} + +/** + * 构建 OpenAI 特定的 providerOptions + */ +function buildOpenAIProviderOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const { enableReasoning } = capabilities + let 
providerOptions: Record = {} + // OpenAI 推理参数 + if (enableReasoning) { + const reasoningParams = getOpenAIReasoningParams(assistant, model) + providerOptions = { + ...providerOptions, + ...reasoningParams + } + } + return providerOptions +} + +/** + * 构建 Anthropic 特定的 providerOptions + */ +function buildAnthropicProviderOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const { enableReasoning } = capabilities + let providerOptions: Record = {} + + // Anthropic 推理参数 + if (enableReasoning) { + const reasoningParams = getAnthropicReasoningParams(assistant, model) + providerOptions = { + ...providerOptions, + ...reasoningParams + } + } + + return providerOptions +} + +/** + * 构建 Gemini 特定的 providerOptions + */ +function buildGeminiProviderOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const { enableReasoning, enableGenerateImage } = capabilities + let providerOptions: Record = {} + + // Gemini 推理参数 + if (enableReasoning) { + const reasoningParams = getGeminiReasoningParams(assistant, model) + providerOptions = { + ...providerOptions, + ...reasoningParams + } + } + + if (enableGenerateImage) { + providerOptions = { + ...providerOptions, + ...buildGeminiGenerateImageParams() + } + } + + return providerOptions +} + +function buildXAIProviderOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const { enableReasoning } = capabilities + let providerOptions: Record = {} + + if (enableReasoning) { + const reasoningParams = getXAIReasoningParams(assistant, model) + providerOptions = { + ...providerOptions, + ...reasoningParams + } + } + + return providerOptions +} + +/** + * 构建通用的 providerOptions(用于其他 provider) + */ 
+function buildGenericProviderOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const { enableWebSearch } = capabilities + let providerOptions: Record = {} + + const reasoningParams = getReasoningEffort(assistant, model) + providerOptions = { + ...providerOptions, + ...reasoningParams + } + + if (enableWebSearch) { + const webSearchParams = getWebSearchParams(model) + providerOptions = { + ...providerOptions, + ...webSearchParams + } + } + + // 特殊处理 Qwen MT + if (isQwenMTModel(model)) { + if (isTranslateAssistant(assistant)) { + const targetLanguage = assistant.targetLanguage + const translationOptions = { + source_lang: 'auto', + target_lang: mapLanguageToQwenMTModel(targetLanguage) + } as const + if (!translationOptions.target_lang) { + throw new Error(t('translate.error.not_supported', { language: targetLanguage.value })) + } + providerOptions.translation_options = translationOptions + } else { + throw new Error(t('translate.error.chat_qwen_mt')) + } + } + + return providerOptions +} diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts new file mode 100644 index 0000000000..507b2cd9ce --- /dev/null +++ b/src/renderer/src/aiCore/utils/reasoning.ts @@ -0,0 +1,444 @@ +import { loggerService } from '@logger' +import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' +import { + findTokenLimit, + GEMINI_FLASH_MODEL_REGEX, + getThinkModelType, + isDeepSeekHybridInferenceModel, + isDoubaoThinkingAutoModel, + isGrokReasoningModel, + isOpenAIReasoningModel, + isQwenAlwaysThinkModel, + isQwenReasoningModel, + isReasoningModel, + isSupportedReasoningEffortGrokModel, + isSupportedReasoningEffortModel, + isSupportedReasoningEffortOpenAIModel, + isSupportedThinkingTokenClaudeModel, + isSupportedThinkingTokenDoubaoModel, + isSupportedThinkingTokenGeminiModel, + isSupportedThinkingTokenHunyuanModel, + 
isSupportedThinkingTokenModel, + isSupportedThinkingTokenQwenModel, + isSupportedThinkingTokenZhipuModel, + MODEL_SUPPORTED_REASONING_EFFORT +} from '@renderer/config/models' +import { isSupportEnableThinkingProvider } from '@renderer/config/providers' +import { getStoreSetting } from '@renderer/hooks/useSettings' +import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService' +import { SettingsState } from '@renderer/store/settings' +import { Assistant, EFFORT_RATIO, isSystemProvider, Model, SystemProviderIds } from '@renderer/types' +import { ReasoningEffortOptionalParams } from '@renderer/types/sdk' + +const logger = loggerService.withContext('reasoning') + +// The function is only for generic provider. May extract some logics to independent provider +export function getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams { + const provider = getProviderByModel(model) + if (provider.id === 'groq') { + return {} + } + + if (!isReasoningModel(model)) { + return {} + } + const reasoningEffort = assistant?.settings?.reasoning_effort + + if (!reasoningEffort) { + // openrouter: use reasoning + if (model.provider === SystemProviderIds.openrouter) { + // Don't disable reasoning for Gemini models that support thinking tokens + if (isSupportedThinkingTokenGeminiModel(model) && !GEMINI_FLASH_MODEL_REGEX.test(model.id)) { + return {} + } + // Don't disable reasoning for models that require it + if (isGrokReasoningModel(model) || isOpenAIReasoningModel(model)) { + return {} + } + return { reasoning: { enabled: false, exclude: true } } + } + + // providers that use enable_thinking + if ( + isSupportEnableThinkingProvider(provider) && + (isSupportedThinkingTokenQwenModel(model) || + isSupportedThinkingTokenHunyuanModel(model) || + (provider.id === SystemProviderIds.dashscope && isDeepSeekHybridInferenceModel(model))) + ) { + return { enable_thinking: false } + } + + // claude + if 
(isSupportedThinkingTokenClaudeModel(model)) { + return {} + } + + // gemini + if (isSupportedThinkingTokenGeminiModel(model)) { + if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) { + return { + extra_body: { + google: { + thinking_config: { + thinking_budget: 0 + } + } + } + } + } + return {} + } + + // use thinking, doubao, zhipu, etc. + if (isSupportedThinkingTokenDoubaoModel(model) || isSupportedThinkingTokenZhipuModel(model)) { + return { thinking: { type: 'disabled' } } + } + + return {} + } + + // reasoningEffort有效的情况 + // DeepSeek hybrid inference models, v3.1 and maybe more in the future + // 不同的 provider 有不同的思考控制方式,在这里统一解决 + if (isDeepSeekHybridInferenceModel(model)) { + if (isSystemProvider(provider)) { + switch (provider.id) { + case SystemProviderIds.dashscope: + return { + enable_thinking: true, + incremental_output: true + } + case SystemProviderIds.silicon: + return { + enable_thinking: true + } + case SystemProviderIds.doubao: + return { + thinking: { + type: 'enabled' // auto is invalid + } + } + case SystemProviderIds.openrouter: + return { + reasoning: { + enabled: true + } + } + case 'nvidia': + return { + chat_template_kwargs: { + thinking: true + } + } + default: + logger.warn( + `Skipping thinking options for provider ${provider.name} as DeepSeek v3.1 thinking control method is unknown` + ) + } + } + } + + // OpenRouter models + if (model.provider === SystemProviderIds.openrouter) { + if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) { + return { + reasoning: { + effort: reasoningEffort === 'auto' ? 
'medium' : reasoningEffort + } + } + } + } + + // Doubao 思考模式支持 + if (isSupportedThinkingTokenDoubaoModel(model)) { + // reasoningEffort 为空,默认开启 enabled + if (reasoningEffort === 'high') { + return { thinking: { type: 'enabled' } } + } + if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) { + return { thinking: { type: 'auto' } } + } + // 其他情况不带 thinking 字段 + return {} + } + + const effortRatio = EFFORT_RATIO[reasoningEffort] + const budgetTokens = Math.floor( + (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min! + ) + + // OpenRouter models, use thinking + if (model.provider === SystemProviderIds.openrouter) { + if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) { + return { + reasoning: { + effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort + } + } + } + } + + // Qwen models, use enable_thinking + if (isQwenReasoningModel(model)) { + const thinkConfig = { + enable_thinking: isQwenAlwaysThinkModel(model) || !isSupportEnableThinkingProvider(provider) ? 
undefined : true, + thinking_budget: budgetTokens + } + if (provider.id === SystemProviderIds.dashscope) { + return { + ...thinkConfig, + incremental_output: true + } + } + return thinkConfig + } + + // Hunyuan models, use enable_thinking + if (isSupportedThinkingTokenHunyuanModel(model) && isSupportEnableThinkingProvider(provider)) { + return { + enable_thinking: true + } + } + + // Grok models/Perplexity models/OpenAI models, use reasoning_effort + if (isSupportedReasoningEffortModel(model)) { + // 检查模型是否支持所选选项 + const modelType = getThinkModelType(model) + const supportedOptions = MODEL_SUPPORTED_REASONING_EFFORT[modelType] + if (supportedOptions.includes(reasoningEffort)) { + return { + reasoning_effort: reasoningEffort + } + } else { + // 如果不支持,fallback到第一个支持的值 + return { + reasoning_effort: supportedOptions[0] + } + } + } + + // gemini series, openai compatible api + if (isSupportedThinkingTokenGeminiModel(model)) { + if (reasoningEffort === 'auto') { + return { + extra_body: { + google: { + thinking_config: { + thinking_budget: -1, + include_thoughts: true + } + } + } + } + } + return { + extra_body: { + google: { + thinking_config: { + thinking_budget: budgetTokens, + include_thoughts: true + } + } + } + } + } + + // Claude models, openai compatible api + if (isSupportedThinkingTokenClaudeModel(model)) { + const maxTokens = assistant.settings?.maxTokens + return { + thinking: { + type: 'enabled', + budget_tokens: Math.floor( + Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)) + ) + } + } + } + + // Use thinking, doubao, zhipu, etc. 
+ if (isSupportedThinkingTokenDoubaoModel(model)) { + if (assistant.settings?.reasoning_effort === 'high') { + return { + thinking: { + type: 'enabled' + } + } + } + } + if (isSupportedThinkingTokenZhipuModel(model)) { + return { thinking: { type: 'enabled' } } + } + + // Default case: no special thinking settings + return {} +} + +/** + * 获取 OpenAI 推理参数 + * 从 OpenAIResponseAPIClient 和 OpenAIAPIClient 中提取的逻辑 + */ +export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record { + if (!isReasoningModel(model)) { + return {} + } + const openAI = getStoreSetting('openAI') as SettingsState['openAI'] + const summaryText = openAI?.summaryText || 'off' + + let reasoningSummary: string | undefined = undefined + + if (summaryText === 'off' || model.id.includes('o1-pro')) { + reasoningSummary = undefined + } else { + reasoningSummary = summaryText + } + + const reasoningEffort = assistant?.settings?.reasoning_effort + + if (!reasoningEffort) { + return {} + } + + // OpenAI 推理参数 + if (isSupportedReasoningEffortOpenAIModel(model)) { + return { + reasoningEffort, + reasoningSummary + } + } + + return {} +} + +/** + * 获取 Anthropic 推理参数 + * 从 AnthropicAPIClient 中提取的逻辑 + */ +export function getAnthropicReasoningParams(assistant: Assistant, model: Model): Record { + if (!isReasoningModel(model)) { + return {} + } + + const reasoningEffort = assistant?.settings?.reasoning_effort + + if (reasoningEffort === undefined) { + return { + thinking: { + type: 'disabled' + } + } + } + + // Claude 推理参数 + if (isSupportedThinkingTokenClaudeModel(model)) { + const { maxTokens } = getAssistantSettings(assistant) + const effortRatio = EFFORT_RATIO[reasoningEffort] + + const budgetTokens = Math.max( + 1024, + Math.floor( + Math.min( + (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) 
* effortRatio + + findTokenLimit(model.id)?.min!, + (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio + ) + ) + ) + + return { + thinking: { + type: 'enabled', + budgetTokens: budgetTokens + } + } + } + + return {} +} + +/** + * 获取 Gemini 推理参数 + * 从 GeminiAPIClient 中提取的逻辑 + */ +export function getGeminiReasoningParams(assistant: Assistant, model: Model): Record { + if (!isReasoningModel(model)) { + return {} + } + + const reasoningEffort = assistant?.settings?.reasoning_effort + + // Gemini 推理参数 + if (isSupportedThinkingTokenGeminiModel(model)) { + if (reasoningEffort === undefined) { + return { + thinkingConfig: { + includeThoughts: false, + ...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {}) + } + } + } + + const effortRatio = EFFORT_RATIO[reasoningEffort] + + if (effortRatio > 1) { + return { + thinkingConfig: { + includeThoughts: true + } + } + } + + const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 } + const budget = Math.floor((max - min) * effortRatio + min) + + return { + thinkingConfig: { + ...(budget > 0 ? 
{ thinkingBudget: budget } : {}), + includeThoughts: true + } + } + } + + return {} +} + +export function getXAIReasoningParams(assistant: Assistant, model: Model): Record { + if (!isSupportedReasoningEffortGrokModel(model)) { + return {} + } + + const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant) + + return { + reasoningEffort + } +} + +/** + * 获取自定义参数 + * 从 assistant 设置中提取自定义参数 + */ +export function getCustomParameters(assistant: Assistant): Record { + return ( + assistant?.settings?.customParameters?.reduce((acc, param) => { + if (!param.name?.trim()) { + return acc + } + if (param.type === 'json') { + const value = param.value as string + if (value === 'undefined') { + return { ...acc, [param.name]: undefined } + } + try { + return { ...acc, [param.name]: JSON.parse(value) } + } catch { + return { ...acc, [param.name]: value } + } + } + return { + ...acc, + [param.name]: param.value + } + }, {}) || {} + ) +} diff --git a/src/renderer/src/aiCore/utils/websearch.ts b/src/renderer/src/aiCore/utils/websearch.ts new file mode 100644 index 0000000000..d2d0345826 --- /dev/null +++ b/src/renderer/src/aiCore/utils/websearch.ts @@ -0,0 +1,31 @@ +import { isOpenAIWebSearchChatCompletionOnlyModel } from '@renderer/config/models' +import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '@renderer/config/prompts' +import { Model } from '@renderer/types' + +export function getWebSearchParams(model: Model): Record { + if (model.provider === 'hunyuan') { + return { enable_enhancement: true, citation: true, search_info: true } + } + + if (model.provider === 'dashscope') { + return { + enable_search: true, + search_options: { + forced_search: true + } + } + } + + if (isOpenAIWebSearchChatCompletionOnlyModel(model)) { + return { + web_search_options: {} + } + } + + if (model.provider === 'openrouter') { + return { + plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }] + } + } + return {} +} diff --git 
a/src/renderer/src/components/CodeViewer.tsx b/src/renderer/src/components/CodeViewer.tsx index e8080d9518..ac7a14e0ac 100644 --- a/src/renderer/src/components/CodeViewer.tsx +++ b/src/renderer/src/components/CodeViewer.tsx @@ -19,7 +19,6 @@ interface CodeViewerProps { * - Supports shiki aliases: c#/csharp, objective-c++/obj-c++/objc++, etc. */ language: string - /** Fired when the editor height changes. */ onHeightChange?: (scrollHeight: number) => void /** * Height of the scroll container. diff --git a/src/renderer/src/components/Spinner.tsx b/src/renderer/src/components/Spinner.tsx index 5495115056..77b6007438 100644 --- a/src/renderer/src/components/Spinner.tsx +++ b/src/renderer/src/components/Spinner.tsx @@ -3,7 +3,7 @@ import { motion } from 'motion/react' import styled from 'styled-components' interface Props { - text: string + text: React.ReactNode } // Define variants for the spinner animation @@ -33,34 +33,12 @@ export default function Spinner({ text }: Props) { ) } - -// const baseContainer = css` -// display: flex; -// flex-direction: row; -// align-items: center; -// ` - -// const Container = styled.div` -// ${baseContainer} -// background-color: var(--color-background-mute); -// padding: 10px; -// border-radius: 10px; -// margin-bottom: 10px; -// gap: 10px; -// ` - -// const StatusText = styled.div` -// font-size: 14px; -// line-height: 1.6; -// text-decoration: none; -// color: var(--color-text-1); -// ` const SearchWrapper = styled.div` display: flex; align-items: center; gap: 4px; - font-size: 14px; - padding: 10px; - padding-left: 0; + /* font-size: 14px; */ + padding: 0px; + /* padding-left: 0; */ ` const Searching = motion.create(SearchWrapper) diff --git a/src/renderer/src/components/__tests__/__snapshots__/Spinner.test.tsx.snap b/src/renderer/src/components/__tests__/__snapshots__/Spinner.test.tsx.snap index aa374d932f..c05a945d24 100644 --- a/src/renderer/src/components/__tests__/__snapshots__/Spinner.test.tsx.snap +++ 
b/src/renderer/src/components/__tests__/__snapshots__/Spinner.test.tsx.snap @@ -5,9 +5,7 @@ exports[`Spinner > should match snapshot 1`] = ` display: flex; align-items: center; gap: 4px; - font-size: 14px; - padding: 10px; - padding-left: 0; + padding: 0px; }
{ if (!isEnableWebSearch) { return {} @@ -3046,8 +3059,6 @@ export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boole return { tools: webSearchTools } - - return {} } export function isGemmaModel(model?: Model): boolean { diff --git a/src/renderer/src/config/providers.ts b/src/renderer/src/config/providers.ts index 194d4c20b7..e1037c52de 100644 --- a/src/renderer/src/config/providers.ts +++ b/src/renderer/src/config/providers.ts @@ -271,7 +271,7 @@ export const SYSTEM_PROVIDERS_CONFIG: Record = name: 'Anthropic', type: 'anthropic', apiKey: '', - apiHost: 'https://api.anthropic.com/', + apiHost: 'https://api.anthropic.com', models: SYSTEM_MODELS.anthropic, isSystem: true, enabled: false diff --git a/src/renderer/src/databases/upgrades.ts b/src/renderer/src/databases/upgrades.ts index b1d722d3b4..79a36419de 100644 --- a/src/renderer/src/databases/upgrades.ts +++ b/src/renderer/src/databases/upgrades.ts @@ -136,7 +136,7 @@ export async function upgradeToV7(tx: Transaction): Promise { content: mcpTool.response, error: mcpTool.status !== 'done' - ? { message: 'MCP Tool did not complete', originalStatus: mcpTool.status } + ? { message: 'MCP Tool did not complete', originalStatus: mcpTool.status, name: null, stack: null } : undefined, createdAt: oldMessage.createdAt, metadata: { rawMcpToolResponse: mcpTool } @@ -263,10 +263,18 @@ export async function upgradeToV7(tx: Transaction): Promise { // 10. Error Block (Status is ERROR) if (oldMessage.error && typeof oldMessage.error === 'object' && Object.keys(oldMessage.error).length > 0) { if (isEmpty(oldMessage.content)) { - const block = createErrorBlock(oldMessage.id, oldMessage.error, { - createdAt: oldMessage.createdAt, - status: MessageBlockStatus.ERROR // Error block status is ERROR - }) + const block = createErrorBlock( + oldMessage.id, + { + message: oldMessage.error?.message ?? null, + name: oldMessage.error?.name ?? null, + stack: oldMessage.error?.stack ?? 
null + }, + { + createdAt: oldMessage.createdAt, + status: MessageBlockStatus.ERROR // Error block status is ERROR + } + ) blocksToCreate.push(block) messageBlockIds.push(block.id) } diff --git a/src/renderer/src/hooks/useMessageOperations.ts b/src/renderer/src/hooks/useMessageOperations.ts index e1b7ea30ff..b1836f3fa7 100644 --- a/src/renderer/src/hooks/useMessageOperations.ts +++ b/src/renderer/src/hooks/useMessageOperations.ts @@ -1,7 +1,7 @@ import { loggerService } from '@logger' import { createSelector } from '@reduxjs/toolkit' import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' -import { appendTrace, pauseTrace, restartTrace } from '@renderer/services/SpanManagerService' +import { appendMessageTrace, pauseTrace, restartTrace } from '@renderer/services/SpanManagerService' import { estimateUserPromptUsage } from '@renderer/services/TokenService' import store, { type RootState, useAppDispatch, useAppSelector } from '@renderer/store' import { updateOneBlock } from '@renderer/store/messageBlock' @@ -178,7 +178,7 @@ export function useMessageOperations(topic: Topic) { */ const appendAssistantResponse = useCallback( async (existingAssistantMessage: Message, newModel: Model, assistant: Assistant) => { - await appendTrace(existingAssistantMessage, newModel) + await appendMessageTrace(existingAssistantMessage, newModel) if (existingAssistantMessage.role !== 'assistant') { logger.error('appendAssistantResponse should only be called for an existing assistant message.') return diff --git a/src/renderer/src/hooks/useVertexAI.ts b/src/renderer/src/hooks/useVertexAI.ts index 89769c0d54..a1cfd9b130 100644 --- a/src/renderer/src/hooks/useVertexAI.ts +++ b/src/renderer/src/hooks/useVertexAI.ts @@ -5,6 +5,7 @@ import { setVertexAIServiceAccountClientEmail, setVertexAIServiceAccountPrivateKey } from '@renderer/store/llm' +import { Provider, VertexProvider } from '@renderer/types' import { useDispatch } from 'react-redux' export function 
useVertexAISettings() { @@ -20,6 +21,7 @@ export function useVertexAISettings() { } } +// FIXME: 这些redux设置状态被服务层使用,这是不应该的。 export function getVertexAISettings() { return store.getState().llm.settings.vertexai } @@ -35,3 +37,43 @@ export function getVertexAIProjectId() { export function getVertexAIServiceAccount() { return store.getState().llm.settings.vertexai.serviceAccount } + +/** + * 类型守卫:检查 Provider 是否为 VertexProvider + */ +export function isVertexProvider(provider: Provider): provider is VertexProvider { + return provider.type === 'vertexai' && 'googleCredentials' in provider +} + +/** + * 创建 VertexProvider 对象,整合单独的配置 + * @param baseProvider 基础的 provider 配置 + * @returns VertexProvider 对象 + */ +export function createVertexProvider(baseProvider: Provider): VertexProvider { + const settings = getVertexAISettings() + + return { + ...baseProvider, + type: 'vertexai' as const, + googleCredentials: { + clientEmail: settings.serviceAccount.clientEmail, + privateKey: settings.serviceAccount.privateKey + }, + project: settings.projectId, + location: settings.location + } +} + +/** + * 检查 VertexAI 配置是否完整 + */ +export function isVertexAIConfigured(): boolean { + const settings = getVertexAISettings() + return !!( + settings.serviceAccount.clientEmail && + settings.serviceAccount.privateKey && + settings.projectId && + settings.location + ) +} diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index 057abc94ec..41882c82ea 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -747,6 +747,7 @@ "delete": "Delete", "delete_confirm": "Are you sure you want to delete?", "description": "Description", + "detail": "Detail", "disabled": "Disabled", "docs": "Docs", "download": "Download", @@ -830,6 +831,7 @@ "invalid": "Invalid MCP server" } }, + "cause": "Error cause", "chat": { "chunk": { "non_json": "Returned an invalid data format" @@ -839,6 +841,9 @@ "quota_exceeded": "Your daily 
{{quota}} free quota has been exhausted. Please go to the {{provider}} to obtain an API key and configure the API key to continue using.", "response": "Something went wrong. Please check if you have set your API key in the Settings > Providers" }, + "data": "data", + "detail": "Error Details", + "details": "Details", "http": { "400": "Request failed. Please check if the request parameters are correct. If you have changed the model settings, please reset them to the default settings", "401": "Authentication failed. Please check if your API key is correct", @@ -850,11 +855,13 @@ "503": "Service unavailable. Please try again later", "504": "Gateway timeout. Please try again later" }, + "message": "Error Message", "missing_user_message": "Cannot switch model response: The original user message has been deleted. Please send a new message to get a response with this model.", "model": { "exists": "Model already exists", "not_exists": "Model does not exist" }, + "name": "Error name", "no_api_key": "API key is not configured", "pause_placeholder": "Paused", "provider_disabled": "Model provider is not enabled", @@ -862,6 +869,14 @@ "description": "Failed to render message content. Please check if the message content format is correct", "title": "Render Error" }, + "requestBody": "Request Body", + "requestBodyValues": "Request Body Values", + "requestUrl": "Request URL", + "responseBody": "Response Body", + "responseHeaders": "Response Header", + "stack": "Stack Trace", + "status": "Status Code", + "statusCode": "Status code", "unknown": "Unknown error", "user_message_not_found": "Cannot find original user message to resend" }, @@ -1448,7 +1463,7 @@ }, "websearch": { "cutoff": "Truncating search content...", - "fetch_complete": "Completed {{count}} searches...", + "fetch_complete": "{{count}} search result(s)", "rag": "Executing RAG...", "rag_complete": "Keeping {{countAfter}} out of {{countBefore}} results...", "rag_failed": "RAG failed, returning empty results..." 
diff --git a/src/renderer/src/i18n/locales/ja-jp.json b/src/renderer/src/i18n/locales/ja-jp.json index ebf88ef989..90b2eb32b9 100644 --- a/src/renderer/src/i18n/locales/ja-jp.json +++ b/src/renderer/src/i18n/locales/ja-jp.json @@ -747,6 +747,7 @@ "delete": "削除", "delete_confirm": "削除してもよろしいですか?", "description": "説明", + "detail": "詳細", "disabled": "無効", "docs": "ドキュメント", "download": "ダウンロード", @@ -830,6 +831,7 @@ "invalid": "無効なMCPサーバー" } }, + "cause": "エラーの原因", "chat": { "chunk": { "non_json": "無効なデータ形式が返されました" @@ -839,6 +841,9 @@ "quota_exceeded": "本日の{{quota}}無料クォータが使い果たされました。{{provider}}でAPIキーを取得し、APIキーを設定して使用を続けてください。", "response": "エラーが発生しました。APIキーが設定されていない場合は、設定 > プロバイダーでキーを設定してください" }, + "data": "データ", + "detail": "エラーの詳細", + "details": "詳細", "http": { "400": "リクエストに失敗しました。リクエストパラメータが正しいか確認してください。モデルの設定を変更した場合は、デフォルトの設定にリセットしてください", "401": "認証に失敗しました。APIキーが正しいか確認してください", @@ -850,11 +855,13 @@ "503": "サービスが利用できません。後でもう一度試してください", "504": "ゲートウェイタイムアウトが発生しました。後でもう一度試してください" }, + "message": "エラーメッセージ", "missing_user_message": "モデル応答を切り替えられません:元のユーザーメッセージが削除されました。このモデルで応答を得るには、新しいメッセージを送信してください", "model": { "exists": "モデルが既に存在します", "not_exists": "モデルが存在しません" }, + "name": "エラー名", "no_api_key": "APIキーが設定されていません", "pause_placeholder": "応答を一時停止しました", "provider_disabled": "モデルプロバイダーが有効になっていません", @@ -862,6 +869,14 @@ "description": "メッセージの内容のレンダリングに失敗しました。メッセージの内容の形式が正しいか確認してください", "title": "レンダリングエラー" }, + "requestBody": "要求されたコンテンツ", + "requestBodyValues": "リクエストボディ", + "requestUrl": "リクエストパス", + "responseBody": "レスポンス内容", + "responseHeaders": "レスポンスヘッダー", + "stack": "スタック情報", + "status": "ステータスコード", + "statusCode": "ステータスコード", "unknown": "不明なエラー", "user_message_not_found": "元のユーザーメッセージを見つけることができませんでした" }, @@ -1448,7 +1463,7 @@ }, "websearch": { "cutoff": "検索内容を切り詰めています...", - "fetch_complete": "{{count}}回の検索を完了しました...", + "fetch_complete": "{{count}}件の検索結果", "rag": "RAGを実行中...", "rag_complete": "{{countBefore}}個の結果から{{countAfter}}個を保持...", "rag_failed": 
"RAGが失敗しました。空の結果を返します..." diff --git a/src/renderer/src/i18n/locales/ru-ru.json b/src/renderer/src/i18n/locales/ru-ru.json index 1b773c956d..6371fd9efe 100644 --- a/src/renderer/src/i18n/locales/ru-ru.json +++ b/src/renderer/src/i18n/locales/ru-ru.json @@ -747,6 +747,7 @@ "delete": "Удалить", "delete_confirm": "Вы уверены, что хотите удалить?", "description": "Описание", + "detail": "Подробности", "disabled": "Отключено", "docs": "Документы", "download": "Скачать", @@ -830,6 +831,7 @@ "invalid": "Недействительный сервер MCP" } }, + "cause": "Ошибка произошла по следующей причине", "chat": { "chunk": { "non_json": "Вернулся недопустимый формат данных" @@ -839,6 +841,9 @@ "quota_exceeded": "Ваша ежедневная {{quota}} бесплатная квота исчерпана. Пожалуйста, перейдите в {{provider}} для получения ключа API и настройте ключ API для продолжения использования.", "response": "Что-то пошло не так. Пожалуйста, проверьте, установлен ли ваш ключ API в Настройки > Провайдеры" }, + "data": "данные", + "detail": "Детали ошибки", + "details": "Подробности", "http": { "400": "Не удалось выполнить запрос. Пожалуйста, проверьте, правильно ли настроены параметры запроса. Если вы изменили настройки модели, пожалуйста, сбросьте их до значений по умолчанию", "401": "Не удалось пройти аутентификацию. Пожалуйста, проверьте, правильно ли настроен ваш ключ API", @@ -850,11 +855,13 @@ "503": "Серверная ошибка. Пожалуйста, попробуйте позже", "504": "Серверная ошибка. Пожалуйста, попробуйте позже" }, + "message": "Сообщение об ошибке", "missing_user_message": "Невозможно изменить модель ответа: исходное сообщение пользователя было удалено. 
Пожалуйста, отправьте новое сообщение, чтобы получить ответ от этой модели", "model": { "exists": "Модель уже существует", "not_exists": "Модель не существует" }, + "name": "Имя ошибки", "no_api_key": "Ключ API не настроен", "pause_placeholder": "Получение ответа приостановлено", "provider_disabled": "Провайдер моделей не включен", @@ -862,6 +869,14 @@ "description": "Не удалось рендерить содержимое сообщения. Пожалуйста, проверьте, правильно ли формат содержимого сообщения", "title": "Ошибка рендеринга" }, + "requestBody": "Запрашиваемый контент", + "requestBodyValues": "Тело запроса", + "requestUrl": "Путь запроса", + "responseBody": "Содержание ответа", + "responseHeaders": "Заголовки ответа", + "stack": "Информация стека", + "status": "Код статуса", + "statusCode": "Код состояния", "unknown": "Неизвестная ошибка", "user_message_not_found": "Не удалось найти исходное сообщение пользователя" }, @@ -1448,7 +1463,7 @@ }, "websearch": { "cutoff": "Обрезка содержимого поиска...", - "fetch_complete": "Завершено {{count}} поисков...", + "fetch_complete": "{{count}} результатов поиска", "rag": "Выполнение RAG...", "rag_complete": "Сохранено {{countAfter}} из {{countBefore}} результатов...", "rag_failed": "RAG не удалось, возвращается пустой результат..." 
diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index 16d63aab77..71c734ce75 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -747,6 +747,7 @@ "delete": "删除", "delete_confirm": "确定要删除吗?", "description": "描述", + "detail": "详情", "disabled": "已禁用", "docs": "文档", "download": "下载", @@ -830,6 +831,7 @@ "invalid": "无效的MCP服务器" } }, + "cause": "错误原因", "chat": { "chunk": { "non_json": "返回了无效的数据格式" @@ -839,6 +841,9 @@ "quota_exceeded": "您今日免费配额已用尽,请前往 {{provider}} 获取API密钥,配置API密钥后继续使用", "response": "出错了,如果没有配置 API 密钥,请前往设置 > 模型提供商中配置密钥" }, + "data": "数据", + "detail": "错误详情", + "details": "详细信息", "http": { "400": "请求错误,请检查请求参数是否正确。如果修改了模型设置,请重置到默认设置", "401": "身份验证失败,请检查 API 密钥是否正确", @@ -850,11 +855,13 @@ "503": "服务不可用,请稍后再试", "504": "网关超时,请稍后再试" }, + "message": "错误信息", "missing_user_message": "无法切换模型响应:原始用户消息已被删除。请发送新消息以获取此模型的响应", "model": { "exists": "模型已存在", "not_exists": "模型不存在" }, + "name": "错误名称", "no_api_key": "API 密钥未配置", "pause_placeholder": "已中断", "provider_disabled": "模型提供商未启用", @@ -862,6 +869,14 @@ "description": "消息内容渲染失败,请检查消息内容格式是否正确", "title": "渲染错误" }, + "requestBody": "请求内容", + "requestBodyValues": "请求体", + "requestUrl": "请求路径", + "responseBody": "响应内容", + "responseHeaders": "响应首部", + "stack": "堆栈信息", + "status": "状态码", + "statusCode": "状态码", "unknown": "未知错误", "user_message_not_found": "无法找到原始用户消息" }, @@ -1448,7 +1463,7 @@ }, "websearch": { "cutoff": "正在截断搜索内容...", - "fetch_complete": "已完成 {{count}} 次搜索...", + "fetch_complete": "{{count}} 个搜索结果", "rag": "正在执行 RAG...", "rag_complete": "保留 {{countBefore}} 个结果中的 {{countAfter}} 个...", "rag_failed": "RAG 失败,返回空结果..." 
diff --git a/src/renderer/src/i18n/locales/zh-tw.json b/src/renderer/src/i18n/locales/zh-tw.json index af61bc2936..4ab0fa3dae 100644 --- a/src/renderer/src/i18n/locales/zh-tw.json +++ b/src/renderer/src/i18n/locales/zh-tw.json @@ -747,6 +747,7 @@ "delete": "刪除", "delete_confirm": "確定要刪除嗎?", "description": "描述", + "detail": "詳情", "disabled": "已停用", "docs": "文件", "download": "下載", @@ -830,6 +831,7 @@ "invalid": "無效的MCP伺服器" } }, + "cause": "錯誤原因", "chat": { "chunk": { "non_json": "返回了無效的資料格式" }, @@ -839,6 +841,9 @@ "quota_exceeded": "您今日{{quota}}免费配额已用尽,请前往 {{provider}} 获取API密钥,配置API密钥后继续使用", "response": "出現錯誤。如果尚未設定 API 金鑰,請前往設定 > 模型提供者中設定金鑰" }, + "data": "資料", + "detail": "錯誤詳情", + "details": "詳細資訊", "http": { "400": "請求錯誤,請檢查請求參數是否正確。如果修改了模型設定,請重設到預設設定", "401": "身份驗證失敗,請檢查 API 金鑰是否正確", @@ -850,11 +855,13 @@ "503": "服務無法使用,請稍後再試", "504": "閘道器超時,請稍後再試" }, + "message": "錯誤訊息", "missing_user_message": "無法切換模型回應:原始用戶訊息已被刪除。請發送新訊息以獲得此模型回應。", "model": { "exists": "模型已存在", "not_exists": "模型不存在" }, + "name": "錯誤名稱", "no_api_key": "API 金鑰未設定", "pause_placeholder": "回應已暫停", "provider_disabled": "模型供應商未啟用", @@ -862,6 +869,14 @@ "description": "消息內容渲染失敗,請檢查消息內容格式是否正確", "title": "渲染錯誤" }, + "requestBody": "請求內容", + "requestBodyValues": "請求主體", + "requestUrl": "請求路徑", + "responseBody": "回應內容", + "responseHeaders": "回應標頭", + "stack": "堆疊資訊", + "status": "狀態碼", + "statusCode": "狀態碼", "unknown": "未知錯誤", "user_message_not_found": "無法找到原始用戶訊息" }, @@ -1448,7 +1463,7 @@ }, "websearch": { "cutoff": "正在截斷搜尋內容...", - "fetch_complete": "已完成 {{count}} 次搜尋...", + "fetch_complete": "{{count}} 個搜尋結果", "rag": "正在執行 RAG...", "rag_complete": "保留 {{countBefore}} 個結果中的 {{countAfter}} 個...", "rag_failed": "RAG 失敗,返回空結果..." 
diff --git a/src/renderer/src/i18n/translate/el-gr.json b/src/renderer/src/i18n/translate/el-gr.json index 5960f6ac47..187110e47d 100644 --- a/src/renderer/src/i18n/translate/el-gr.json +++ b/src/renderer/src/i18n/translate/el-gr.json @@ -747,6 +747,7 @@ "delete": "Διαγραφή", "delete_confirm": "Είστε βέβαιοι ότι θέλετε να διαγράψετε;", "description": "Περιγραφή", + "detail": "Λεπτομέρειες", "disabled": "Απενεργοποιημένο", "docs": "Έγγραφα", "download": "Λήψη", @@ -830,6 +831,7 @@ "invalid": "Μη έγκυρος διακομιστής MCP" } }, + "cause": "Αιτία σφάλματος", "chat": { "chunk": { "non_json": "Επέστρεψε μη έγκυρη μορφή δεδομένων" @@ -839,6 +841,9 @@ "quota_exceeded": "Η ημερήσια δωρεάν ποσόστωση {{quota}} tokens σας έχει εξαντληθεί. Παρακαλώ μεταβείτε στο {{provider}} για να λάβετε ένα κλειδί API και να ρυθμίσετε το κλειδί API για να συνεχίσετε τη χρήση.", "response": "Σφάλμα. Εάν δεν έχετε ρυθμίσει το κλειδί API, πηγαίνετε στο ρυθμισμένα > παρέχοντας το πρόσωπο του μοντέλου" }, + "data": "δεδομένα", + "detail": "Λεπτομέρειες σφάλματος", + "details": "Λεπτομέρειες", "http": { "400": "Σφάλμα ζητήματος, παρακαλώ ελέγξτε αν τα παράμετρα του ζητήματος είναι σωστά. Εάν έχετε αλλάξει τις ρυθμίσεις του μοντέλου, επαναφέρετε τις προεπιλεγμένες ρυθμίσεις.", "401": "Αποτυχία επιβεβαίωσης ταυτότητας, παρακαλώ ελέγξτε αν η κλειδί API είναι σωστή", @@ -850,11 +855,13 @@ "503": "Η υπηρεσία δεν είναι διαθέσιμη, παρακαλώ δοκιμάστε ξανά", "504": "Υπερχρονισμός φάρων, παρακαλώ δοκιμάστε ξανά" }, + "message": "Μήνυμα σφάλματος", "missing_user_message": "Αδυναμία εναλλαγής απάντησης μοντέλου: το αρχικό μήνυμα χρήστη έχει διαγραφεί. 
Παρακαλούμε στείλτε ένα νέο μήνυμα για να λάβετε απάντηση από αυτό το μοντέλο", "model": { "exists": "Το μοντέλο υπάρχει ήδη", "not_exists": "Το μοντέλο δεν υπάρχει" }, + "name": "Λάθος όνομα", "no_api_key": "Δεν έχετε ρυθμίσει το κλειδί API", "pause_placeholder": "Διακόπηκε", "provider_disabled": "Ο παρεχόμενος παροχός του μοντέλου δεν είναι ενεργοποιημένος", @@ -862,6 +869,14 @@ "description": "Απέτυχε η ώθηση της εξίσωσης, παρακαλώ ελέγξτε το σωστό μορφάτι της", "title": "Σφάλμα Παρασκήνιου" }, + "requestBody": "Περιεχόμενο αιτήματος", + "requestBodyValues": "Σώμα αιτήματος", + "requestUrl": "Μονοπάτι αιτήματος", + "responseBody": "απάντηση περιεχομένου", + "responseHeaders": "Επικεφαλίδες απόκρισης", + "stack": "Πληροφορίες στοίβας", + "status": "Κωδικός κατάστασης", + "statusCode": "Κωδικός κατάστασης", "unknown": "Άγνωστο σφάλμα", "user_message_not_found": "Αδυναμία εύρεσης της αρχικής μηνύματος χρήστη" }, @@ -1448,7 +1463,7 @@ }, "websearch": { "cutoff": "Περικόπτεται η αναζήτηση...", - "fetch_complete": "Ολοκληρώθηκαν {{count}} αναζητήσεις...", + "fetch_complete": "{{count}} αποτελέσματα αναζήτησης", "rag": "Εκτελείται RAG...", "rag_complete": "Διατηρούνται {{countAfter}} από τα {{countBefore}} αποτελέσματα...", "rag_failed": "Το RAG απέτυχε, επιστρέφεται κενό αποτέλεσμα..." 
diff --git a/src/renderer/src/i18n/translate/es-es.json b/src/renderer/src/i18n/translate/es-es.json index 3b8e033172..029b5c5813 100644 --- a/src/renderer/src/i18n/translate/es-es.json +++ b/src/renderer/src/i18n/translate/es-es.json @@ -747,6 +747,7 @@ "delete": "Eliminar", "delete_confirm": "¿Está seguro de que desea eliminarlo?", "description": "Descripción", + "detail": "Detalles", "disabled": "Desactivado", "docs": "Documentos", "download": "Descargar", @@ -830,6 +831,7 @@ "invalid": "Servidor MCP no válido" } }, + "cause": "Causa del error", "chat": { "chunk": { "non_json": "Devuelve un formato de datos no válido" }, @@ -839,6 +841,9 @@ "quota_exceeded": "Su cuota gratuita diaria de {{quota}} tokens se ha agotado. Por favor, vaya a {{provider}} para obtener una clave API y configurar la clave API para continuar usando.", "response": "Ha ocurrido un error, si no ha configurado la clave API, vaya a Configuración > Proveedor de modelos para configurar la clave" }, + "data": "datos", + "detail": "Detalles del error", + "details": "Detalles", "http": { "400": "Error en la solicitud, revise si los parámetros de la solicitud son correctos. Si modificó la configuración del modelo, restablezca a la configuración predeterminada", "401": "Fallo en la autenticación, revise si la clave API es correcta", @@ -850,11 +855,13 @@ "503": "Servicio no disponible, inténtelo de nuevo más tarde", "504": "Tiempo de espera de la puerta de enlace, inténtelo de nuevo más tarde" }, + "message": "Mensaje de error", "missing_user_message": "No se puede cambiar la respuesta del modelo: el mensaje original del usuario ha sido eliminado. 
Envíe un nuevo mensaje para obtener la respuesta de este modelo", "model": { "exists": "El modelo ya existe", "not_exists": "El modelo no existe" }, + "name": "Nombre de error", "no_api_key": "La clave API no está configurada", "pause_placeholder": "Interrumpido", "provider_disabled": "El proveedor de modelos no está habilitado", @@ -862,6 +869,14 @@ "description": "Error al renderizar la fórmula, por favor, compruebe si el formato de la fórmula es correcto", "title": "Error de renderizado" }, + "requestBody": "Contenido de la solicitud", + "requestBodyValues": "Cuerpo de la solicitud", + "requestUrl": "Ruta de solicitud", + "responseBody": "Contenido de la respuesta", + "responseHeaders": "Encabezados de respuesta", + "stack": "Información de la pila", + "status": "código de estado", + "statusCode": "código de estado", "unknown": "Error desconocido", "user_message_not_found": "No se pudo encontrar el mensaje original del usuario" }, @@ -1448,7 +1463,7 @@ }, "websearch": { "cutoff": "Truncando el contenido de búsqueda...", - "fetch_complete": "Búsqueda completada {{count}} veces...", + "fetch_complete": "{{count}} resultados de búsqueda", "rag": "Ejecutando RAG...", "rag_complete": "Conservando {{countAfter}} de los {{countBefore}} resultados...", "rag_failed": "RAG fallido, devolviendo resultados vacíos..." 
diff --git a/src/renderer/src/i18n/translate/fr-fr.json b/src/renderer/src/i18n/translate/fr-fr.json index 181f2b5f75..08af4d6e30 100644 --- a/src/renderer/src/i18n/translate/fr-fr.json +++ b/src/renderer/src/i18n/translate/fr-fr.json @@ -747,6 +747,7 @@ "delete": "Supprimer", "delete_confirm": "Êtes-vous sûr de vouloir supprimer ?", "description": "Description", + "detail": "Détails", "disabled": "Désactivé", "docs": "Documents", "download": "Télécharger", @@ -830,6 +831,7 @@ "invalid": "Serveur MCP invalide" } }, + "cause": "Cause de l'erreur", "chat": { "chunk": { "non_json": "a renvoyé un format de données invalide" }, @@ -839,6 +841,9 @@ "quota_exceeded": "Votre quota gratuit quotidien de {{quota}} tokens a été épuisé. Veuillez vous rendre sur {{provider}} pour obtenir une clé API et configurer la clé API pour continuer à utiliser.", "response": "Une erreur s'est produite, si l'API n'est pas configurée, veuillez aller dans Paramètres > Fournisseurs de modèles pour configurer la clé" }, + "data": "données", + "detail": "Détails de l'erreur", + "details": "Informations détaillées", "http": { "400": "Erreur de requête, veuillez vérifier si les paramètres de la requête sont corrects. Si vous avez modifié les paramètres du modèle, réinitialisez-les aux paramètres par défaut.", "401": "Échec de l'authentification, veuillez vérifier que votre clé API est correcte.", @@ -850,11 +855,13 @@ "503": "Service indisponible, veuillez réessayer plus tard.", "504": "Délai d'expiration de la passerelle, veuillez réessayer plus tard." }, + "message": "Message d'erreur", "missing_user_message": "Impossible de changer de modèle de réponse : le message utilisateur d'origine a été supprimé. 
Veuillez envoyer un nouveau message pour obtenir une réponse de ce modèle.", "model": { "exists": "Le modèle existe déjà", "not_exists": "Le modèle n'existe pas" }, + "name": "Nom d'erreur", "no_api_key": "La clé API n'est pas configurée", "pause_placeholder": "Прервано", "provider_disabled": "Le fournisseur de modèles n'est pas activé", @@ -862,6 +869,14 @@ "description": "La formule n'a pas été rendue avec succès, veuillez vérifier si le format de la formule est correct", "title": "Erreur de rendu" }, + "requestBody": "Contenu de la demande", + "requestBodyValues": "Corps de la requête", + "requestUrl": "Chemin de la requête", + "responseBody": "Contenu de la réponse", + "responseHeaders": "En-têtes de réponse", + "stack": "Informations de la pile", + "status": "Code d'état", + "statusCode": "Code d'état", "unknown": "Неизвестная ошибка", "user_message_not_found": "Impossible de trouver le message d'utilisateur original" }, @@ -1448,7 +1463,7 @@ }, "websearch": { "cutoff": "Troncature du contenu de recherche en cours...", - "fetch_complete": "{{count}} recherches terminées...", + "fetch_complete": "{{count}} résultats de recherche", "rag": "Exécution de la RAG en cours...", "rag_complete": "Conserver {{countAfter}} résultats sur {{countBefore}}...", "rag_failed": "Échec de la RAG, retour d'un résultat vide..." 
diff --git a/src/renderer/src/i18n/translate/pt-pt.json b/src/renderer/src/i18n/translate/pt-pt.json index fa2faa1a24..eff87d6902 100644 --- a/src/renderer/src/i18n/translate/pt-pt.json +++ b/src/renderer/src/i18n/translate/pt-pt.json @@ -747,6 +747,7 @@ "delete": "Excluir", "delete_confirm": "Tem certeza de que deseja excluir?", "description": "Descrição", + "detail": "detalhes", "disabled": "Desativado", "docs": "Documentos", "download": "Baixar", @@ -830,6 +831,7 @@ "invalid": "Servidor MCP inválido" } }, + "cause": "Causa do erro", "chat": { "chunk": { "non_json": "Devolveu um formato de dados inválido" @@ -839,6 +841,9 @@ "quota_exceeded": "Sua cota gratuita diária de {{quota}} tokens foi esgotada. Por favor, vá para {{provider}} para obter uma chave API e configurar a chave API para continuar usando.", "response": "Ocorreu um erro, se a chave da API não foi configurada, por favor vá para Configurações > Provedores de Modelo para configurar a chave" }, + "data": "dados", + "detail": "Detalhes do erro", + "details": "Detalhes", "http": { "400": "Erro na solicitação, por favor verifique se os parâmetros da solicitação estão corretos. Se você alterou as configurações do modelo, redefina para as configurações padrão", "401": "Falha na autenticação, por favor verifique se a chave da API está correta", @@ -850,11 +855,13 @@ "503": "Serviço indisponível, por favor tente novamente mais tarde", "504": "Tempo de espera do gateway excedido, por favor tente novamente mais tarde" }, + "message": "Mensagem de erro", "missing_user_message": "Não é possível alternar a resposta do modelo: a mensagem original do usuário foi excluída. 
Envie uma nova mensagem para obter a resposta deste modelo", "model": { "exists": "O modelo já existe", "not_exists": "O modelo não existe" }, + "name": "Nome do erro", "no_api_key": "A chave da API não foi configurada", "pause_placeholder": "Interrompido", "provider_disabled": "O provedor de modelos está desativado", @@ -862,6 +869,14 @@ "description": "Falha ao renderizar a fórmula, por favor verifique se o formato da fórmula está correto", "title": "Erro de Renderização" }, + "requestBody": "Conteúdo da solicitação", + "requestBodyValues": "Corpo da solicitação", + "requestUrl": "Caminho da solicitação", + "responseBody": "Conteúdo da resposta", + "responseHeaders": "Cabeçalho de resposta", + "stack": "Informações da pilha", + "status": "Código de status", + "statusCode": "Código de status", "unknown": "Erro desconhecido", "user_message_not_found": "Não foi possível encontrar a mensagem original do usuário" }, @@ -1448,7 +1463,7 @@ }, "websearch": { "cutoff": "Truncando o conteúdo da pesquisa...", - "fetch_complete": "Concluída {{count}} busca...", + "fetch_complete": "{{count}} resultados da pesquisa", "rag": "Executando RAG...", "rag_complete": "Mantendo {{countAfter}} dos {{countBefore}} resultados...", "rag_failed": "RAG falhou, retornando resultado vazio..." 
diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index af4428d158..db787c40fd 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -46,6 +46,7 @@ import { getTextFromDropEvent, isSendMessageKeyPressed } from '@renderer/utils/input' +import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools' import { documentExts, imageExts, textExts } from '@shared/config/constant' import { IpcChannel } from '@shared/IpcChannel' import { Button, Tooltip } from 'antd' @@ -230,7 +231,6 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = const uploadedFiles = await FileManager.uploadFiles(files) const baseUserMessage: MessageInputBaseParams = { assistant, topic, content: text } - logger.info('baseUserMessage', baseUserMessage) // getUserMessage() if (uploadedFiles) { @@ -831,6 +831,7 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = const isExpanded = expanded || !!textareaHeight const showThinkingButton = isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model) + const showMcpTools = isSupportedToolUse(assistant) || isPromptToolUse(assistant) if (isMultiSelectMode) { return null @@ -899,7 +900,8 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = extensions={supportedExts} setFiles={setFiles} showThinkingButton={showThinkingButton} - showKnowledgeIcon={showKnowledgeIcon} + showKnowledgeIcon={showKnowledgeIcon && showMcpTools} + showMcpTools={showMcpTools} selectedKnowledgeBases={selectedKnowledgeBases} handleKnowledgeBaseSelect={handleKnowledgeBaseSelect} setText={setText} diff --git a/src/renderer/src/pages/home/Inputbar/InputbarTools.tsx b/src/renderer/src/pages/home/Inputbar/InputbarTools.tsx index 079b7801b1..82a0071ee9 100644 --- a/src/renderer/src/pages/home/Inputbar/InputbarTools.tsx +++ 
b/src/renderer/src/pages/home/Inputbar/InputbarTools.tsx @@ -63,6 +63,7 @@ export interface InputbarToolsProps { extensions: string[] showThinkingButton: boolean showKnowledgeIcon: boolean + showMcpTools: boolean selectedKnowledgeBases: KnowledgeBase[] handleKnowledgeBaseSelect: (bases?: KnowledgeBase[]) => void setText: Dispatch> @@ -105,6 +106,7 @@ const InputbarTools = ({ setFiles, showThinkingButton, showKnowledgeIcon, + showMcpTools, selectedKnowledgeBases, handleKnowledgeBaseSelect, setText, @@ -376,7 +378,8 @@ const InputbarTools = ({ setInputValue={setText} resizeTextArea={resizeTextArea} /> - ) + ), + condition: showMcpTools }, { key: 'generate_image', @@ -480,6 +483,7 @@ const InputbarTools = ({ setFiles, setText, showKnowledgeIcon, + showMcpTools, showThinkingButton, t ]) diff --git a/src/renderer/src/pages/home/Markdown/Markdown.tsx b/src/renderer/src/pages/home/Markdown/Markdown.tsx index af00712ca2..f2d65e9a7c 100644 --- a/src/renderer/src/pages/home/Markdown/Markdown.tsx +++ b/src/renderer/src/pages/home/Markdown/Markdown.tsx @@ -11,8 +11,7 @@ import type { MainTextMessageBlock, ThinkingMessageBlock, TranslationMessageBloc import { removeSvgEmptyLines } from '@renderer/utils/formats' import { processLatexBrackets } from '@renderer/utils/markdown' import { isEmpty } from 'lodash' -import { type FC, memo, useCallback, useEffect, useMemo, useState } from 'react' -import { useRef } from 'react' +import { type FC, memo, useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import ReactMarkdown, { type Components, defaultUrlTransform } from 'react-markdown' import rehypeKatex from 'rehype-katex' diff --git a/src/renderer/src/pages/home/Messages/Blocks/ErrorBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/ErrorBlock.tsx index 514eb96634..f4eb8d1dd6 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/ErrorBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/ErrorBlock.tsx @@ -1,11 
+1,21 @@ +import CodeViewer from '@renderer/components/CodeViewer' import { useTimer } from '@renderer/hooks/useTimer' import { getHttpMessageLabel, getProviderLabel } from '@renderer/i18n/label' import { getProviderById } from '@renderer/services/ProviderService' import { useAppDispatch } from '@renderer/store' import { removeBlocksThunk } from '@renderer/store/thunk/messageThunk' +import { + isSerializedAiSdkAPICallError, + isSerializedAiSdkError, + isSerializedError, + SerializedAiSdkAPICallError, + SerializedAiSdkError, + SerializedError +} from '@renderer/types/error' import type { ErrorMessageBlock, Message } from '@renderer/types/newMessage' -import { Alert as AntdAlert } from 'antd' -import React from 'react' +import { formatAiSdkError, formatError, safeToString } from '@renderer/utils/error' +import { Alert as AntdAlert, Button, Modal } from 'antd' +import React, { useState } from 'react' import { Trans, useTranslation } from 'react-i18next' import { Link } from 'react-router-dom' import styled from 'styled-components' @@ -24,13 +34,16 @@ const ErrorBlock: React.FC = ({ block, message }) => { const ErrorMessage: React.FC<{ block: ErrorMessageBlock }> = ({ block }) => { const { t, i18n } = useTranslation() - const i18nKey = `error.${block.error?.i18nKey}` + const i18nKey = block.error && 'i18nKey' in block.error ? `error.${block.error?.i18nKey}` : '' const errorKey = `error.${block.error?.message}` - const errorStatus = block.error?.status + const errorStatus = + block.error && ('status' in block.error || 'statusCode' in block.error) + ? block.error?.status || block.error?.statusCode + : undefined if (i18n.exists(i18nKey)) { - const providerId = block.error?.providerId - if (providerId) { + const providerId = block.error && 'providerId' in block.error ? 
block.error?.providerId : undefined + if (providerId && typeof providerId === 'string') { return ( = ({ block }) => { return t(errorKey) } - if (HTTP_ERROR_CODES.includes(errorStatus)) { + if (typeof errorStatus === 'number' && HTTP_ERROR_CODES.includes(errorStatus)) { return (
- {getHttpMessageLabel(errorStatus)} {block.error?.message} + {getHttpMessageLabel(errorStatus.toString())} {block.error?.message}
) } @@ -64,27 +77,180 @@ const ErrorMessage: React.FC<{ block: ErrorMessageBlock }> = ({ block }) => { return block.error?.message || '' } -const ErrorDescription: React.FC<{ block: ErrorMessageBlock }> = ({ block }) => { - const { t } = useTranslation() - - if (block.error) { - return - } - - return <>{t('error.chat.response')} -} - const MessageErrorInfo: React.FC<{ block: ErrorMessageBlock; message: Message }> = ({ block, message }) => { const dispatch = useAppDispatch() const { setTimeoutTimer } = useTimer() + const [showDetailModal, setShowDetailModal] = useState(false) + const { t } = useTranslation() const onRemoveBlock = () => { setTimeoutTimer('onRemoveBlock', () => dispatch(removeBlocksThunk(message.topicId, message.id, [block.id])), 350) } - return } type="error" closable onClose={onRemoveBlock} /> + const showErrorDetail = () => { + setShowDetailModal(true) + } + + const getAlertMessage = () => { + const status = + block.error && ('status' in block.error || 'statusCode' in block.error) + ? block.error?.status || block.error?.statusCode + : undefined + if (block.error && typeof status === 'number' && HTTP_ERROR_CODES.includes(status)) { + return block.error.message + } + return null + } + + const getAlertDescription = () => { + const status = + block.error && ('status' in block.error || 'statusCode' in block.error) + ? 
block.error?.status || block.error?.statusCode + : undefined + if (block.error && typeof status === 'number' && HTTP_ERROR_CODES.includes(status)) { + return getHttpMessageLabel(status.toString()) + } + return + } + + return ( + <> + { + e.stopPropagation() + showErrorDetail() + }}> + {t('common.detail')} + + } + /> + setShowDetailModal(false)} error={block.error} /> + + ) } +interface ErrorDetailModalProps { + open: boolean + onClose: () => void + error?: SerializedError +} + +const ErrorDetailModal: React.FC = ({ open, onClose, error }) => { + const { t } = useTranslation() + + const copyErrorDetails = () => { + if (!error) return + let errorText: string + if (isSerializedAiSdkError(error)) { + errorText = formatAiSdkError(error) + } else if (isSerializedError(error)) { + errorText = formatError(error) + } else { + // fallback + errorText = safeToString(error) + } + + navigator.clipboard.writeText(errorText) + window.message.success(t('message.copied')) + } + + const renderErrorDetails = (error?: SerializedError) => { + if (!error) return
{t('error.unknown')}
+ if (isSerializedAiSdkAPICallError(error)) { + return + } + if (isSerializedAiSdkError(error)) { + return + } + return ( + + + + ) + } + + return ( + + {t('common.copy')} + , + + ]} + width={600}> + {renderErrorDetails(error)} + + ) +} + +const ErrorDetailContainer = styled.div` + max-height: 400px; + overflow-y: auto; +` + +const ErrorDetailList = styled.div` + display: flex; + flex-direction: column; + gap: 16px; +` + +const ErrorDetailItem = styled.div` + display: flex; + flex-direction: column; + gap: 8px; +` + +const ErrorDetailLabel = styled.div` + font-weight: 600; + color: var(--color-text); + font-size: 14px; +` + +const ErrorDetailValue = styled.div` + font-family: var(--code-font-family); + font-size: 12px; + padding: 8px; + background: var(--color-code-background); + border-radius: 4px; + border: 1px solid var(--color-border); + word-break: break-word; + color: var(--color-text); +` + +const StackTrace = styled.div` + background: var(--color-background-soft); + border: 1px solid var(--color-error); + border-radius: 6px; + padding: 12px; + + pre { + margin: 0; + white-space: pre-wrap; + word-break: break-word; + font-family: var(--code-font-family); + font-size: 12px; + line-height: 1.4; + color: var(--color-error); + } +` + const Alert = styled(AntdAlert)` margin: 0.5rem 0 !important; padding: 10px; @@ -94,4 +260,110 @@ const Alert = styled(AntdAlert)` } ` +// 作为 base,渲染公共字段,应当在 ErrorDetailList 中渲染 +const BuiltinError = ({ error }: { error: SerializedError }) => { + const { t } = useTranslation() + return ( + <> + {error.name && ( + + {t('error.name')}: + {error.name} + + )} + {error.message && ( + + {t('error.message')}: + {error.message} + + )} + {error.stack && ( + + {t('error.stack')}: + +
{error.stack}
+
+
+ )} + + ) +} + +// 作为 base,渲染公共字段,应当在 ErrorDetailList 中渲染 +const AiSdkError = ({ error }: { error: SerializedAiSdkError }) => { + const { t } = useTranslation() + const cause = error.cause + return ( + <> + + {cause && ( + + {t('error.cause')}: + {error.cause} + + )} + + ) +} + +const AiApiCallError = ({ error }: { error: SerializedAiSdkAPICallError }) => { + const { t } = useTranslation() + + // 这些字段是 unknown 类型,暂且不清楚都可能是什么类型,总之先覆盖下大部分场景 + const requestBodyValues = safeToString(error.requestBodyValues) + const data = safeToString(error.data) + + return ( + + + + {error.url && ( + + {t('error.requestUrl')}: + {error.url} + + )} + + {requestBodyValues && ( + + {t('error.requestBodyValues')}: + + + )} + + {error.statusCode && ( + + {t('error.statusCode')}: + {error.statusCode} + + )} + {error.responseHeaders && ( + + {t('error.responseHeaders')}: + + + )} + + {error.responseBody && ( + + {t('error.responseBody')}: + + + )} + + {data && ( + + {t('error.data')}: + + + )} + + ) +} + export default React.memo(ErrorBlock) diff --git a/src/renderer/src/pages/home/Messages/Blocks/ThinkingBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/ThinkingBlock.tsx index 96f439fa2e..0aee1eccac 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/ThinkingBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/ThinkingBlock.tsx @@ -103,37 +103,41 @@ const ThinkingBlock: React.FC = ({ block }) => { } const ThinkingTimeSeconds = memo( - ({ blockThinkingTime, isThinking }: { blockThinkingTime?: number; isThinking: boolean }) => { + ({ blockThinkingTime, isThinking }: { blockThinkingTime: number; isThinking: boolean }) => { const { t } = useTranslation() - - const [thinkingTime, setThinkingTime] = useState(blockThinkingTime || 0) + // console.log('blockThinkingTime', blockThinkingTime) + // const [thinkingTime, setThinkingTime] = useState(blockThinkingTime || 0) // FIXME: 这里统计的和请求处统计的有一定误差 - useEffect(() => { - let timer: NodeJS.Timeout | null = null - if (isThinking) { - 
timer = setInterval(() => { - setThinkingTime((prev) => prev + 100) - }, 100) - } else if (timer) { - // 立即清除计时器 - clearInterval(timer) - timer = null - } + // useEffect(() => { + // let timer: NodeJS.Timeout | null = null + // if (isThinking) { + // timer = setInterval(() => { + // setThinkingTime((prev) => prev + 100) + // }, 100) + // } else if (timer) { + // // 立即清除计时器 + // clearInterval(timer) + // timer = null + // } - return () => { - if (timer) { - clearInterval(timer) - timer = null - } - } - }, [isThinking]) + // return () => { + // if (timer) { + // clearInterval(timer) + // timer = null + // } + // } + // }, [isThinking]) - const thinkingTimeSeconds = useMemo(() => (thinkingTime / 1000).toFixed(1), [thinkingTime]) + const thinkingTimeSeconds = useMemo(() => (blockThinkingTime / 1000).toFixed(1), [blockThinkingTime]) - return t(isThinking ? 'chat.thinking' : 'chat.deeply_thought', { - seconds: thinkingTimeSeconds - }) + return isThinking + ? t('chat.thinking', { + seconds: thinkingTimeSeconds + }) + : t('chat.deeply_thought', { + seconds: thinkingTimeSeconds + }) } ) diff --git a/src/renderer/src/pages/home/Messages/Blocks/ToolBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/ToolBlock.tsx index 865f9c2948..67ef3ed607 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/ToolBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/ToolBlock.tsx @@ -1,7 +1,7 @@ import type { ToolMessageBlock } from '@renderer/types/newMessage' import React from 'react' -import MessageTools from '../MessageTools' +import MessageTools from '../Tools/MessageTools' interface Props { block: ToolMessageBlock diff --git a/src/renderer/src/pages/home/Messages/Blocks/__tests__/ThinkingBlock.test.tsx b/src/renderer/src/pages/home/Messages/Blocks/__tests__/ThinkingBlock.test.tsx index 15088396df..8db122d948 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/__tests__/ThinkingBlock.test.tsx +++ 
b/src/renderer/src/pages/home/Messages/Blocks/__tests__/ThinkingBlock.test.tsx @@ -1,7 +1,6 @@ import type { ThinkingMessageBlock } from '@renderer/types/newMessage' import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage' import { render, screen } from '@testing-library/react' -import { act } from 'react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import ThinkingBlock from '../ThinkingBlock' @@ -240,28 +239,9 @@ describe('ThinkingBlock', () => { expect(activeTimeText).toHaveTextContent('Thinking...') }) - it('should update thinking time in real-time when active', () => { - const block = createThinkingBlock({ - thinking_millsec: 1000, - status: MessageBlockStatus.STREAMING - }) - renderThinkingBlock(block) - - // Initial state - expect(getThinkingTimeText()).toHaveTextContent('1.0s') - - // After time passes - act(() => { - vi.advanceTimersByTime(500) - }) - - expect(getThinkingTimeText()).toHaveTextContent('1.5s') - }) - it('should handle extreme thinking times correctly', () => { const testCases = [ { thinking_millsec: 0, expectedTime: '0.0s' }, - { thinking_millsec: undefined, expectedTime: '0.0s' }, { thinking_millsec: 86400000, expectedTime: '86400.0s' }, // 1 day { thinking_millsec: 259200000, expectedTime: '259200.0s' } // 3 days ] @@ -276,38 +256,6 @@ describe('ThinkingBlock', () => { unmount() }) }) - - it('should stop timer when thinking status changes to completed', () => { - const block = createThinkingBlock({ - thinking_millsec: 1000, - status: MessageBlockStatus.STREAMING - }) - const { rerender } = renderThinkingBlock(block) - - // Advance timer while thinking - act(() => { - vi.advanceTimersByTime(1000) - }) - expect(getThinkingTimeText()).toHaveTextContent('2.0s') - - // Complete thinking - const completedBlock = createThinkingBlock({ - thinking_millsec: 1000, // Original time doesn't matter - status: MessageBlockStatus.SUCCESS - }) - rerender() - - // Timer should stop - text should change 
from "Thinking..." to "Thought for" - const timeText = getThinkingTimeText() - expect(timeText).toHaveTextContent('Thought for') - expect(timeText).toHaveTextContent('2.0s') - - // Further time advancement shouldn't change the display - act(() => { - vi.advanceTimersByTime(1000) - }) - expect(timeText).toHaveTextContent('2.0s') - }) }) describe('collapse behavior', () => { @@ -413,16 +361,6 @@ describe('ThinkingBlock', () => { expect(screen.queryByText('Markdown: Original thought')).not.toBeInTheDocument() }) - it('should clean up timer on unmount', () => { - const block = createThinkingBlock({ status: MessageBlockStatus.STREAMING }) - const { unmount } = renderThinkingBlock(block) - - const clearIntervalSpy = vi.spyOn(global, 'clearInterval') - unmount() - - expect(clearIntervalSpy).toHaveBeenCalled() - }) - it('should handle rapid status changes gracefully', () => { const block = createThinkingBlock({ status: MessageBlockStatus.STREAMING }) const { rerender } = renderThinkingBlock(block) diff --git a/src/renderer/src/pages/home/Messages/Blocks/index.tsx b/src/renderer/src/pages/home/Messages/Blocks/index.tsx index a7e86164c1..7c44d6c662 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/index.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/index.tsx @@ -3,7 +3,7 @@ import type { RootState } from '@renderer/store' import { messageBlocksSelectors } from '@renderer/store/messageBlock' import type { ImageMessageBlock, MainTextMessageBlock, Message, MessageBlock } from '@renderer/types/newMessage' import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage' -import { AnimatePresence, motion } from 'motion/react' +import { AnimatePresence, motion, type Variants } from 'motion/react' import React, { useMemo } from 'react' import { useSelector } from 'react-redux' import styled from 'styled-components' @@ -25,7 +25,7 @@ interface AnimatedBlockWrapperProps { enableAnimation: boolean } -const blockWrapperVariants = { +const 
blockWrapperVariants: Variants = { visible: { opacity: 1, x: 0, diff --git a/src/renderer/src/pages/home/Messages/Tools/MessageKnowledgeSearch.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageKnowledgeSearch.tsx new file mode 100644 index 0000000000..a7ebeb648c --- /dev/null +++ b/src/renderer/src/pages/home/Messages/Tools/MessageKnowledgeSearch.tsx @@ -0,0 +1,72 @@ +import { KnowledgeSearchToolInput, KnowledgeSearchToolOutput } from '@renderer/aiCore/tools/KnowledgeSearchTool' +import Spinner from '@renderer/components/Spinner' +import i18n from '@renderer/i18n' +import { MCPToolResponse } from '@renderer/types' +import { Typography } from 'antd' +import { FileSearch } from 'lucide-react' +import styled from 'styled-components' + +const { Text } = Typography +export function MessageKnowledgeSearchToolTitle({ toolResponse }: { toolResponse: MCPToolResponse }) { + const toolInput = toolResponse.arguments as KnowledgeSearchToolInput + const toolOutput = toolResponse.response as KnowledgeSearchToolOutput + + return toolResponse.status !== 'done' ? ( + + {i18n.t('message.searching')} + {toolInput?.additionalContext ?? ''} + + } + /> + ) : ( + + + {i18n.t('message.websearch.fetch_complete', { count: toolOutput.knowledgeReferences.length ?? 0 })} + + ) +} + +export function MessageKnowledgeSearchToolBody({ toolResponse }: { toolResponse: MCPToolResponse }) { + const toolOutput = toolResponse.response as KnowledgeSearchToolOutput + + return toolResponse.status === 'done' ? ( + + {toolOutput.knowledgeReferences.map((result) => ( +
  • + {result.id} + {result.content} +
  • + ))} +
    + ) : null +} + +const PrepareToolWrapper = styled.span` + display: flex; + align-items: center; + gap: 4px; + font-size: 14px; + padding-left: 0; +` +const MessageWebSearchToolTitleTextWrapper = styled(Text)` + display: flex; + align-items: center; + gap: 4px; +` + +const MessageWebSearchToolBodyUlWrapper = styled.ul` + display: flex; + flex-direction: column; + gap: 4px; + padding: 0; + > li { + padding: 0; + margin: 0; + max-width: 70%; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + } +` diff --git a/src/renderer/src/pages/home/Messages/MessageTools.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageMcpTool.tsx similarity index 99% rename from src/renderer/src/pages/home/Messages/MessageTools.tsx rename to src/renderer/src/pages/home/Messages/Tools/MessageMcpTool.tsx index 2a5077f48d..312f5bb448 100644 --- a/src/renderer/src/pages/home/Messages/MessageTools.tsx +++ b/src/renderer/src/pages/home/Messages/Tools/MessageMcpTool.tsx @@ -44,7 +44,7 @@ const logger = loggerService.withContext('MessageTools') const COUNTDOWN_TIME = 30 -const MessageTools: FC = ({ block }) => { +const MessageMcpTool: FC = ({ block }) => { const [activeKeys, setActiveKeys] = useState([]) const [copiedMap, setCopiedMap] = useState>({}) const [countdown, setCountdown] = useState(COUNTDOWN_TIME) @@ -750,4 +750,4 @@ const ExpandedResponseContainer = styled.div` } ` -export default memo(MessageTools) +export default memo(MessageMcpTool) diff --git a/src/renderer/src/pages/home/Messages/Tools/MessageMemorySearch.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageMemorySearch.tsx new file mode 100644 index 0000000000..cb86d8a259 --- /dev/null +++ b/src/renderer/src/pages/home/Messages/Tools/MessageMemorySearch.tsx @@ -0,0 +1,41 @@ +import { MemorySearchToolInput, MemorySearchToolOutput } from '@renderer/aiCore/tools/MemorySearchTool' +import Spinner from '@renderer/components/Spinner' +import { MCPToolResponse } from '@renderer/types' +import { Typography } 
from 'antd' +import { ChevronRight } from 'lucide-react' +import { useTranslation } from 'react-i18next' +import styled from 'styled-components' + +const { Text } = Typography + +export const MessageMemorySearchToolTitle = ({ toolResponse }: { toolResponse: MCPToolResponse }) => { + const { t } = useTranslation() + const toolInput = toolResponse.arguments as MemorySearchToolInput + const toolOutput = toolResponse.response as MemorySearchToolOutput + + return toolResponse.status !== 'done' ? ( + + {t('memory.search_placeholder')} + {toolInput?.query ?? ''} + + } + /> + ) : toolOutput?.length ? ( + + + {/* */} + {toolOutput?.length ?? 0} + {t('memory.memory')} + + ) : null +} + +const MessageWebSearchToolTitleTextWrapper = styled(Text)` + display: flex; + align-items: center; + gap: 4px; + padding: 5px; + padding-left: 0; +` diff --git a/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx new file mode 100644 index 0000000000..38ae73e95e --- /dev/null +++ b/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx @@ -0,0 +1,80 @@ +import { MCPToolResponse } from '@renderer/types' +import type { ToolMessageBlock } from '@renderer/types/newMessage' +import { Collapse } from 'antd' + +import { MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch' +import { MessageMemorySearchToolTitle } from './MessageMemorySearch' +import { MessageWebSearchToolTitle } from './MessageWebSearch' + +interface Props { + block: ToolMessageBlock +} +const prefix = 'builtin_' + +const ChooseTool = (toolResponse: MCPToolResponse): { label: React.ReactNode; body: React.ReactNode } | null => { + let toolName = toolResponse.tool.name + if (toolName.startsWith(prefix)) { + toolName = toolName.slice(prefix.length) + } + + switch (toolName) { + case 'web_search': + case 'web_search_preview': + return { + label: , + body: null + } + case 'knowledge_search': + return { + label: , + body: null + } + case 'memory_search': 
+ return { + label: , + body: null + } + default: + return null + } +} + +export default function MessageTool({ block }: Props) { + // FIXME: 语义错误,这里已经不是 MCP tool 了,更改rawMcpToolResponse需要改用户数据, 所以暂时保留 + const toolResponse = block.metadata?.rawMcpToolResponse + + if (!toolResponse) return null + + const toolRenderer = ChooseTool(toolResponse) + + if (!toolRenderer) return null + + return toolRenderer.body ? ( + + ) : ( + toolRenderer.label + ) +} +// const PrepareToolWrapper = styled.span` +// display: flex; +// align-items: center; +// gap: 4px; +// font-size: 14px; +// padding-left: 0; +// ` diff --git a/src/renderer/src/pages/home/Messages/Tools/MessageTools.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageTools.tsx new file mode 100644 index 0000000000..15db1525ee --- /dev/null +++ b/src/renderer/src/pages/home/Messages/Tools/MessageTools.tsx @@ -0,0 +1,20 @@ +import type { ToolMessageBlock } from '@renderer/types/newMessage' + +import MessageMcpTool from './MessageMcpTool' +import MessageTool from './MessageTool' + +interface Props { + block: ToolMessageBlock +} + +export default function MessageTools({ block }: Props) { + const toolResponse = block.metadata?.rawMcpToolResponse + if (!toolResponse) return null + + const tool = toolResponse.tool + if (tool.type === 'mcp') { + return + } + + return +} diff --git a/src/renderer/src/pages/home/Messages/Tools/MessageWebSearch.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageWebSearch.tsx new file mode 100644 index 0000000000..a399c790e8 --- /dev/null +++ b/src/renderer/src/pages/home/Messages/Tools/MessageWebSearch.tsx @@ -0,0 +1,72 @@ +import { WebSearchToolInput, WebSearchToolOutput } from '@renderer/aiCore/tools/WebSearchTool' +import Spinner from '@renderer/components/Spinner' +import { MCPToolResponse } from '@renderer/types' +import { Typography } from 'antd' +import { Search } from 'lucide-react' +import { useTranslation } from 'react-i18next' +import styled from 'styled-components' + +const { 
Text } = Typography + +export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPToolResponse }) => { + const { t } = useTranslation() + const toolInput = toolResponse.arguments as WebSearchToolInput + const toolOutput = toolResponse.response as WebSearchToolOutput + + return toolResponse.status !== 'done' ? ( + + {t('message.searching')} + {toolInput?.additionalContext ?? ''} + + } + /> + ) : ( + + + {t('message.websearch.fetch_complete', { + count: toolOutput?.searchResults?.results?.length ?? 0 + })} + + ) +} + +// export const MessageWebSearchToolBody = ({ toolResponse }: { toolResponse: MCPToolResponse }) => { +// const toolOutput = toolResponse.response as WebSearchToolOutput + +// return toolResponse.status === 'done' +// ? toolOutput?.searchResults?.map((result, index) => ( +// +// {result.results.map((item, index) => ( +//
  • +// {item.title} +//
  • +// ))} +//
    +// )) +// : null +// } + +const PrepareToolWrapper = styled.span` + display: flex; + align-items: center; + gap: 4px; + font-size: 14px; + padding: 5px; + padding-left: 0; +` + +const MessageWebSearchToolTitleTextWrapper = styled(Text)` + display: flex; + align-items: center; + gap: 4px; + padding: 5px; +` + +// const MessageWebSearchToolBodyUlWrapper = styled.ul` +// display: flex; +// flex-direction: column; +// gap: 4px; +// padding: 0; +// ` diff --git a/src/renderer/src/pages/translate/TranslatePage.tsx b/src/renderer/src/pages/translate/TranslatePage.tsx index 6b31bae488..1c9bd94839 100644 --- a/src/renderer/src/pages/translate/TranslatePage.tsx +++ b/src/renderer/src/pages/translate/TranslatePage.tsx @@ -155,7 +155,7 @@ const TranslatePage: FC = () => { } catch (e) { if (!isAbortError(e)) { logger.error('Failed to translate text', e as Error) - window.message.error(t('translate.error.failed' + ': ' + (e as Error).message)) + window.message.error(t('translate.error.failed') + ': ' + formatErrorMessage(e)) } setTranslating(false) return @@ -167,11 +167,11 @@ const TranslatePage: FC = () => { await saveTranslateHistory(text, translated, actualSourceLanguage.langCode, actualTargetLanguage.langCode) } catch (e) { logger.error('Failed to save translate history', e as Error) - window.message.error(t('translate.history.error.save') + ': ' + (e as Error).message) + window.message.error(t('translate.history.error.save') + ': ' + formatErrorMessage(e)) } } catch (e) { logger.error('Failed to translate', e as Error) - window.message.error(t('translate.error.unknown') + ': ' + (e as Error).message) + window.message.error(t('translate.error.unknown') + ': ' + formatErrorMessage(e)) } }, [dispatch, setTranslatedContent, setTranslating, t, translating] diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index fb57a86719..2f69eea8e0 100644 --- a/src/renderer/src/services/ApiService.ts +++ 
b/src/renderer/src/services/ApiService.ts @@ -1,682 +1,151 @@ +/** + * 职责:提供原子化的、无状态的API调用函数 + */ import { loggerService } from '@logger' -import { CompletionsParams } from '@renderer/aiCore/middleware/schemas' -import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant' -import { - isEmbeddingModel, - isGenerateImageModel, - isOpenRouterBuiltInWebSearchModel, - isQwenMTModel, - isReasoningModel, - isSupportedReasoningEffortModel, - isSupportedThinkingTokenModel, - isWebSearchModel -} from '@renderer/config/models' -import { - LANG_DETECT_PROMPT, - SEARCH_SUMMARY_PROMPT, - SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY, - SEARCH_SUMMARY_PROMPT_WEB_ONLY -} from '@renderer/config/prompts' -import { getModel } from '@renderer/hooks/useModel' +import AiProvider from '@renderer/aiCore' +import { CompletionsParams } from '@renderer/aiCore/legacy/middleware/schemas' +import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/AiSdkMiddlewareBuilder' +import { buildStreamTextParams } from '@renderer/aiCore/prepareParams' +import { isDedicatedImageGenerationModel, isEmbeddingModel } from '@renderer/config/models' import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' -import { currentSpan, withSpanResult } from '@renderer/services/SpanManagerService' import store from '@renderer/store' -import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory' -import { - Assistant, - ExternalToolResult, - KnowledgeReference, - MCPTool, - MemoryItem, - Model, - Provider, - WebSearchResponse, - WebSearchSource -} from '@renderer/types' +import type { FetchChatCompletionParams } from '@renderer/types' +import { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types' +import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import { type Chunk, ChunkType } from '@renderer/types/chunk' import { Message } from '@renderer/types/newMessage' import { SdkModel } from 
'@renderer/types/sdk' import { removeSpecialCharactersForTopicName, uuid } from '@renderer/utils' -import { abortCompletion } from '@renderer/utils/abortController' +import { abortCompletion, readyToAbort } from '@renderer/utils/abortController' import { isAbortError } from '@renderer/utils/error' -import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract' import { purifyMarkdownImages } from '@renderer/utils/markdown' -import { filterAdjacentUserMessaegs, filterLastAssistantMessage } from '@renderer/utils/messageUtils/filters' +import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools' import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' -import { - buildSystemPromptWithThinkTool, - buildSystemPromptWithTools, - containsSupportedVariables, - replacePromptVariables -} from '@renderer/utils/prompt' -import { getTranslateOptions } from '@renderer/utils/translate' -import { findLast, isEmpty, takeRight } from 'lodash' +import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt' +import { isEmpty, takeRight } from 'lodash' -import AiProvider from '../aiCore' +import AiProviderNew, { ModernAiProviderConfig } from '../aiCore/index_new' import { - getAssistantProvider, - getAssistantSettings, + // getAssistantProvider, + // getAssistantSettings, getDefaultAssistant, getDefaultModel, getProviderByModel, getQuickModel } from './AssistantService' -import { processKnowledgeSearch } from './KnowledgeService' -import { MemoryProcessor } from './MemoryProcessor' -import { - filterAfterContextClearMessages, - filterEmptyMessages, - filterUsefulMessages, - filterUserRoleStartMessages -} from './MessagesService' -import WebSearchService from './WebSearchService' +// import { processKnowledgeSearch } from './KnowledgeService' +// import { +// filterContextMessages, +// filterEmptyMessages, +// filterUsefulMessages, +// filterUserRoleStartMessages +// } from './MessagesService' 
+// import WebSearchService from './WebSearchService' const logger = loggerService.withContext('ApiService') -// TODO:考虑拆开 -async function fetchExternalTool( - lastUserMessage: Message, - assistant: Assistant, - onChunkReceived: (chunk: Chunk) => void, - lastAnswer?: Message -): Promise { - // 可能会有重复? - const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) - const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) - const knowledgeRecognition = assistant.knowledgeRecognition || 'off' - const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId) - - // 使用外部搜索工具 - const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null - const shouldKnowledgeSearch = hasKnowledgeBase - const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) - const shouldSearchMemory = globalMemoryEnabled && assistant.enableMemory - - // 获取 MCP 工具 - let mcpTools: MCPTool[] = [] +export async function fetchMcpTools(assistant: Assistant) { + // Get MCP tools (Fix duplicate declaration) + let mcpTools: MCPTool[] = [] // Initialize as empty array const allMcpServers = store.getState().mcp.servers || [] const activedMcpServers = allMcpServers.filter((s) => s.isActive) const assistantMcpServers = assistant.mcpServers || [] + const enabledMCPs = activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id)) - const showListTools = enabledMCPs && enabledMCPs.length > 0 - - // 是否使用工具 - const hasAnyTool = shouldWebSearch || shouldKnowledgeSearch || showListTools - - // 在工具链开始时发送进度通知(不包括记忆搜索) - if (hasAnyTool) { - onChunkReceived({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS }) - } - - // --- Keyword/Question Extraction Function --- - const extract = async (): Promise => { - if (!lastUserMessage) return undefined - - // 根据配置决定是否需要提取 - const needWebExtract = shouldWebSearch - const needKnowledgeExtract = hasKnowledgeBase && knowledgeRecognition === 'on' - - if (!needWebExtract && 
!needKnowledgeExtract) return undefined - - let prompt: string - if (needWebExtract && !needKnowledgeExtract) { - prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY - } else if (!needWebExtract && needKnowledgeExtract) { - prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY - } else { - prompt = SEARCH_SUMMARY_PROMPT - } - - const summaryAssistant = getDefaultAssistant() - summaryAssistant.model = getQuickModel() || assistant.model || getDefaultModel() - summaryAssistant.prompt = prompt - - const callSearchSummary = async (params: { messages: Message[]; assistant: Assistant }) => { - return await fetchSearchSummary(params) - } - - const traceParams = { - name: `${summaryAssistant.model?.name}.Summary`, - tag: 'LLM', - topicId: lastUserMessage.topicId, - modelName: summaryAssistant.model.name - } - - const searchSummaryParams = { - messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage], - assistant: summaryAssistant - } + if (enabledMCPs && enabledMCPs.length > 0) { try { - const result = await withSpanResult(callSearchSummary, traceParams, searchSummaryParams) - - if (!result) return getFallbackResult() - - const extracted = extractInfoFromXML(result.getText()) - // 根据需求过滤结果 - return { - websearch: needWebExtract ? extracted?.websearch : undefined, - knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined - } - } catch (e: any) { - logger.error('extract error', e) - if (isAbortError(e)) throw e - return getFallbackResult() - } - } - - const getFallbackResult = (): ExtractResults => { - const fallbackContent = getMainTextContent(lastUserMessage) - return { - websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined, - knowledge: shouldKnowledgeSearch - ? 
{ - question: [fallbackContent || 'search'], - rewrite: fallbackContent - } - : undefined - } - } - - // --- Web Search Function --- - const searchTheWeb = async ( - extractResults: ExtractResults | undefined, - parentSpanId?: string - ): Promise => { - if (!shouldWebSearch) return - - // Add check for extractResults existence early - if (!extractResults?.websearch) { - logger.warn('searchTheWeb called without valid extractResults.websearch') - return - } - - if (extractResults.websearch.question[0] === 'not_needed') return - - // Add check for assistant.model before using it - if (!assistant.model) { - logger.warn('searchTheWeb called without assistant.model') - return undefined - } - - try { - // Use the consolidated processWebsearch function - WebSearchService.createAbortSignal(lastUserMessage.id) - let safeWebSearchProvider = webSearchProvider - if (webSearchProvider) { - safeWebSearchProvider = { - ...webSearchProvider, - topicId: lastUserMessage.topicId, - parentSpanId, - modelName: assistant.model.name - } - } - const webSearchResponse = await WebSearchService.processWebsearch( - safeWebSearchProvider!, - extractResults, - lastUserMessage.id - ) - return { - results: webSearchResponse, - source: WebSearchSource.WEBSEARCH - } - } catch (error) { - if (isAbortError(error)) throw error - logger.error('Web search failed:', error as Error) - return - } - } - - const searchMemory = async (): Promise => { - if (!shouldSearchMemory) return [] - try { - const memoryConfig = selectMemoryConfig(store.getState()) - const content = getMainTextContent(lastUserMessage) - if (!content) { - logger.warn('searchMemory called without valid content in lastUserMessage') - return [] - } - - if (memoryConfig.llmApiClient && memoryConfig.embedderApiClient) { - const currentUserId = selectCurrentUserId(store.getState()) - // Search for relevant memories - const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, assistant.id, currentUserId) - logger.info(`Searching for 
relevant memories with content: ${content}`) - const memoryProcessor = new MemoryProcessor() - const relevantMemories = await memoryProcessor.searchRelevantMemories( - content, - processorConfig, - 5 // Limit to top 5 most relevant memories - ) - - if (relevantMemories?.length > 0) { - logger.info('Found relevant memories:', relevantMemories) - - return relevantMemories - } - return [] - } else { - logger.warn('Memory is enabled but embedding or LLM model is not configured') - return [] - } - } catch (error) { - logger.error('Error processing memory search:', error as Error) - // Continue with conversation even if memory processing fails - return [] - } - } - - // --- Knowledge Base Search Function --- - const searchKnowledgeBase = async ( - extractResults: ExtractResults | undefined, - parentSpanId?: string, - modelName?: string - ): Promise => { - if (!hasKnowledgeBase) return - - // 知识库搜索条件 - let searchCriteria: { question: string[]; rewrite: string } - if (knowledgeRecognition === 'off') { - const directContent = getMainTextContent(lastUserMessage) - searchCriteria = { question: [directContent || 'search'], rewrite: directContent } - } else { - // auto mode - if (!extractResults?.knowledge) { - logger.warn('searchKnowledgeBase: No valid search criteria in auto mode') - return - } - searchCriteria = extractResults.knowledge - } - - if (searchCriteria.question[0] === 'not_needed') return - - try { - const tempExtractResults: ExtractResults = { - websearch: undefined, - knowledge: searchCriteria - } - // Attempt to get knowledgeBaseIds from the main text block - // NOTE: This assumes knowledgeBaseIds are ONLY on the main text block - // NOTE: processKnowledgeSearch needs to handle undefined ids gracefully - // const mainTextBlock = mainTextBlocks - // ?.map((blockId) => store.getState().messageBlocks.entities[blockId]) - // .find((block) => block?.type === MessageBlockType.MAIN_TEXT) as MainTextMessageBlock | undefined - return await processKnowledgeSearch( - 
tempExtractResults, - knowledgeBaseIds, - lastUserMessage.topicId, - parentSpanId, - modelName - ) - } catch (error) { - logger.error('Knowledge base search failed:', error as Error) - return - } - } - - // --- Execute Extraction and Searches --- - let extractResults: ExtractResults | undefined - - try { - // 根据配置决定是否需要提取 - if (shouldWebSearch || hasKnowledgeBase) { - extractResults = await extract() - logger.info('[fetchExternalTool] Extraction results:', extractResults) - } - - let webSearchResponseFromSearch: WebSearchResponse | undefined - let knowledgeReferencesFromSearch: KnowledgeReference[] | undefined - let memorySearchReferences: MemoryItem[] | undefined - - const parentSpanId = currentSpan(lastUserMessage.topicId, assistant.model?.name)?.spanContext().spanId - if (shouldWebSearch) { - webSearchResponseFromSearch = await searchTheWeb(extractResults, parentSpanId) - } - - if (shouldKnowledgeSearch) { - knowledgeReferencesFromSearch = await searchKnowledgeBase(extractResults, parentSpanId, assistant.model?.name) - } - - if (shouldSearchMemory) { - memorySearchReferences = await searchMemory() - } - - if (lastUserMessage) { - if (webSearchResponseFromSearch) { - window.keyv.set(`web-search-${lastUserMessage.id}`, webSearchResponseFromSearch) - } - if (knowledgeReferencesFromSearch) { - window.keyv.set(`knowledge-search-${lastUserMessage.id}`, knowledgeReferencesFromSearch) - } - if (memorySearchReferences) { - window.keyv.set(`memory-search-${lastUserMessage.id}`, memorySearchReferences) - } - } - - if (showListTools) { - try { - const spanContext = currentSpan(lastUserMessage.topicId, assistant.model?.name)?.spanContext() - const toolPromises = enabledMCPs.map>(async (mcpServer) => { - try { - const tools = await window.api.mcp.listTools(mcpServer, spanContext) - return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name)) - } catch (error) { - logger.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error as Error) - 
return [] - } - }) - const results = await Promise.allSettled(toolPromises) - mcpTools = results - .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') - .map((result) => result.value) - .flat() - - // 根据toolUseMode决定如何构建系统提示词 - const basePrompt = assistant.prompt - if (assistant.settings?.toolUseMode === 'prompt' || mcpTools.length > SYSTEM_PROMPT_THRESHOLD) { - // 提示词模式:需要完整的工具定义,思考工具返回会打乱提示词的返回(先去掉) - assistant.prompt = buildSystemPromptWithTools(basePrompt, mcpTools) - } else { - // 原生函数调用模式:仅需要注入思考指令 - assistant.prompt = buildSystemPromptWithThinkTool(basePrompt) - } - } catch (toolError) { - logger.error('Error fetching MCP tools:', toolError as Error) - } - } - - // 发送工具执行完成通知 - if (hasAnyTool) { - onChunkReceived({ - type: ChunkType.EXTERNEL_TOOL_COMPLETE, - external_tool: { - webSearch: webSearchResponseFromSearch, - knowledge: knowledgeReferencesFromSearch, - memories: memorySearchReferences + const toolPromises = enabledMCPs.map(async (mcpServer: MCPServer) => { + try { + const tools = await window.api.mcp.listTools(mcpServer) + return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name)) + } catch (error) { + logger.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error as Error) + return [] } }) + const results = await Promise.allSettled(toolPromises) + mcpTools = results + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value) + .flat() + } catch (toolError) { + logger.error('Error fetching MCP tools:', toolError as Error) } - - return { mcpTools } - } catch (error) { - if (isAbortError(error)) throw error - logger.error('Tool execution failed:', error as Error) - - // 发送错误状态 - const wasAnyToolEnabled = shouldWebSearch || shouldKnowledgeSearch || shouldSearchMemory - if (wasAnyToolEnabled) { - onChunkReceived({ - type: ChunkType.EXTERNEL_TOOL_COMPLETE, - external_tool: { - webSearch: undefined, - knowledge: undefined - } 
- }) - } - - return { mcpTools: [] } } + return mcpTools } export async function fetchChatCompletion({ messages, + prompt, assistant, - onChunkReceived -}: { - messages: Message[] - assistant: Assistant - onChunkReceived: (chunk: Chunk) => void - // TODO - // onChunkStatus: (status: 'searching' | 'processing' | 'success' | 'error') => void -}) { - logger.debug('fetchChatCompletion', messages, assistant) + options, + onChunkReceived, + topicId, + uiMessages +}: FetchChatCompletionParams) { + logger.info('fetchChatCompletion called with detailed context', { + messageCount: messages?.length || 0, + prompt: prompt, + assistantId: assistant.id, + topicId, + hasTopicId: !!topicId, + modelId: assistant.model?.id, + modelName: assistant.model?.name + }) + const AI = new AiProviderNew(assistant.model || getDefaultModel()) + const provider = AI.getActualProvider() - if (assistant.prompt && containsSupportedVariables(assistant.prompt)) { - assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name) + const mcpTools: MCPTool[] = [] + + if (isSupportedToolUse(assistant)) { + mcpTools.push(...(await fetchMcpTools(assistant))) + } + if (prompt) { + messages = [ + { + role: 'user', + content: prompt + } + ] } - const provider = getAssistantProvider(assistant) - const AI = new AiProvider(provider) + // 使用 transformParameters 模块构建参数 + const { + params: aiSdkParams, + modelId, + capabilities + } = await buildStreamTextParams(messages, assistant, provider, { + mcpTools: mcpTools, + webSearchProviderId: assistant.webSearchProviderId, + requestOptions: options + }) - // Make sure that 'Clear Context' works for all scenarios including external tool and normal chat. 
- const filteredMessages1 = filterAfterContextClearMessages(messages) - - const lastUserMessage = findLast(messages, (m) => m.role === 'user') - const lastAnswer = findLast(messages, (m) => m.role === 'assistant') - if (!lastUserMessage) { - logger.error('fetchChatCompletion returning early: Missing lastUserMessage or lastAnswer') - return + const middlewareConfig: AiSdkMiddlewareConfig = { + streamOutput: assistant.settings?.streamOutput ?? true, + onChunk: onChunkReceived, + model: assistant.model, + enableReasoning: capabilities.enableReasoning, + isPromptToolUse: isPromptToolUse(assistant), + isSupportedToolUse: isSupportedToolUse(assistant), + isImageGenerationEndpoint: isDedicatedImageGenerationModel(assistant.model || getDefaultModel()), + enableWebSearch: capabilities.enableWebSearch, + enableGenerateImage: capabilities.enableGenerateImage, + mcpTools, + uiMessages } - // try { - // NOTE: The search results are NOT added to the messages sent to the AI here. - // They will be retrieved and used by the messageThunk later to create CitationBlocks. 
- const { mcpTools } = await fetchExternalTool(lastUserMessage, assistant, onChunkReceived, lastAnswer) - const model = assistant.model || getDefaultModel() - - const { maxTokens, contextCount } = getAssistantSettings(assistant) - - const filteredMessages2 = filterUsefulMessages(filteredMessages1) - - const filteredMessages3 = filterLastAssistantMessage(filteredMessages2) - - const filteredMessages4 = filterAdjacentUserMessaegs(filteredMessages3) - - let _messages = filterUserRoleStartMessages( - filterEmptyMessages(filterAfterContextClearMessages(takeRight(filteredMessages4, contextCount + 2))) // 取原来几个provider的最大值 - ) - - // Fallback: ensure at least the last user message is present to avoid empty payloads - if ((!_messages || _messages.length === 0) && lastUserMessage) { - _messages = [lastUserMessage] - } - - // FIXME: qwen3即使关闭思考仍然会导致enableReasoning的结果为true - const enableReasoning = - ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) && - assistant.settings?.reasoning_effort !== undefined) || - (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model))) - - // NOTE:assistant.enableWebSearch 的语义是是否启用模型内置搜索功能 - const enableWebSearch = - assistant.enableWebSearch || - (assistant.webSearchProviderId && isWebSearchModel(model)) || - isOpenRouterBuiltInWebSearchModel(model) || - model.id.includes('sonar') || - false - - const enableUrlContext = assistant.enableUrlContext || false - - const enableGenerateImage = isGenerateImageModel(model) && assistant.enableGenerateImage // --- Call AI Completions --- onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED }) - const completionsParams: CompletionsParams = { + await AI.completions(modelId, aiSdkParams, { + ...middlewareConfig, + assistant, + topicId, callType: 'chat', - messages: _messages, - assistant, - onChunk: onChunkReceived, - mcpTools: mcpTools, - maxTokens, - streamOutput: assistant.settings?.streamOutput || false, - 
enableReasoning, - enableWebSearch, - enableUrlContext, - enableGenerateImage, - topicId: lastUserMessage.topicId - } - - const requestOptions = { - streamOutput: assistant.settings?.streamOutput || false - } - - // Post-conversation memory processing - const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) - if (globalMemoryEnabled && assistant.enableMemory) { - processConversationMemory(messages, assistant) - } - - return await AI.completionsForTrace(completionsParams, requestOptions) -} - -/** - * Process conversation for memory extraction and storage - */ -async function processConversationMemory(messages: Message[], assistant: Assistant) { - try { - const memoryConfig = selectMemoryConfig(store.getState()) - - // Use assistant's model as fallback for memory processing if not configured - const llmModel = - getModel(memoryConfig.llmApiClient?.model, memoryConfig.llmApiClient?.provider) || - assistant.model || - getDefaultModel() - const embedderModel = - getModel(memoryConfig.embedderApiClient?.model, memoryConfig.embedderApiClient?.provider) || - getFirstEmbeddingModel() - - if (!embedderModel) { - logger.warn( - 'Memory processing skipped: no embedding model available. Please configure an embedding model in memory settings.' 
- ) - return - } - - if (!llmModel) { - logger.warn('Memory processing skipped: LLM model not available') - return - } - - // Convert messages to the format expected by memory processor - const conversationMessages = messages - .filter((msg) => msg.role === 'user' || msg.role === 'assistant') - .map((msg) => ({ - role: msg.role as 'user' | 'assistant', - content: getMainTextContent(msg) || '' - })) - .filter((msg) => msg.content.trim().length > 0) - - // if (conversationMessages.length < 2) { - // Need at least a user message and assistant response - // return - // } - - const currentUserId = selectCurrentUserId(store.getState()) - - // Create updated memory config with resolved models - const updatedMemoryConfig = { - ...memoryConfig, - llmApiClient: { - model: llmModel.id, - provider: llmModel.provider, - apiKey: getProviderByModel(llmModel).apiKey, - baseURL: new AiProvider(getProviderByModel(llmModel)).getBaseURL(), - apiVersion: getProviderByModel(llmModel).apiVersion - }, - embedderApiClient: { - model: embedderModel.id, - provider: embedderModel.provider, - apiKey: getProviderByModel(embedderModel).apiKey, - baseURL: new AiProvider(getProviderByModel(embedderModel)).getBaseURL(), - apiVersion: getProviderByModel(embedderModel).apiVersion - } - } - - const lastUserMessage = findLast(messages, (m) => m.role === 'user') - const processorConfig = MemoryProcessor.getProcessorConfig( - updatedMemoryConfig, - assistant.id, - currentUserId, - lastUserMessage?.id - ) - - // Process the conversation in the background (don't await to avoid blocking UI) - const memoryProcessor = new MemoryProcessor() - memoryProcessor - .processConversation(conversationMessages, processorConfig) - .then((result) => { - logger.debug('Memory processing completed:', result) - if (result.facts.length > 0) { - logger.debug('Extracted facts from conversation:', result.facts) - logger.debug('Memory operations performed:', result.operations) - } else { - logger.debug('No facts extracted from 
conversation') - } - }) - .catch((error) => { - logger.error('Background memory processing failed:', error as Error) - }) - } catch (error) { - logger.error('Error in post-conversation memory processing:', error as Error) - } -} - -interface FetchLanguageDetectionProps { - text: string - onResponse?: (text: string, isComplete: boolean) => void -} - -/** - * 检测文本语言 - * @param params - 参数对象 - * @param {string} params.text - 需要检测语言的文本内容 - * @param {function} [params.onResponse] - 流式响应回调函数,用于实时获取检测结果 - * @returns {Promise} 返回检测到的语言代码,如果检测失败会抛出错误 - * @throws {Error} - */ -export async function fetchLanguageDetection({ text, onResponse }: FetchLanguageDetectionProps) { - const translateLanguageOptions = await getTranslateOptions() - const listLang = translateLanguageOptions.map((item) => item.langCode) - const listLangText = JSON.stringify(listLang) - - const model = getQuickModel() || getDefaultModel() - if (!model) { - throw new Error(i18n.t('error.model.not_exists')) - } - - if (isQwenMTModel(model)) { - logger.info('QwenMT cannot be used for language detection.') - if (isQwenMTModel(model)) { - throw new Error(i18n.t('translate.error.detect.qwen_mt')) - } - } - - const provider = getProviderByModel(model) - - if (!hasApiKey(provider)) { - throw new Error(i18n.t('error.no_api_key')) - } - - const assistant: Assistant = getDefaultAssistant() - - assistant.model = model - assistant.settings = { - temperature: 0.7 - } - assistant.prompt = LANG_DETECT_PROMPT.replace('{{list_lang}}', listLangText).replace('{{input}}', text) - - const isSupportedStreamOutput = () => { - if (!onResponse) { - return false - } - return true - } - - const stream = isSupportedStreamOutput() - - const params: CompletionsParams = { - callType: 'translate-lang-detect', - messages: 'follow system prompt', - assistant, - streamOutput: stream, - enableReasoning: false, - shouldThrow: true, - onResponse - } - - const AI = new AiProvider(provider) - - return (await AI.completions(params)).getText() + 
uiMessages + }) } export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) { @@ -689,16 +158,15 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages: // 总结上下文总是取最后5条消息 const contextMessages = takeRight(messages, 5) - const provider = getProviderByModel(model) if (!hasApiKey(provider)) { return null } - const AI = new AiProvider(provider) + const AI = new AiProviderNew(model) - const topicId = messages?.find((message) => message.topicId)?.topicId || undefined + const topicId = messages?.find((message) => message.topicId)?.topicId || '' // LLM对多条消息的总结有问题,用单条结构化的消息表示会话内容会更好 const structredMessages = contextMessages.map((message) => { @@ -721,28 +189,57 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages: }) const conversation = JSON.stringify(structredMessages) - // 复制 assistant 对象,并强制关闭思考预算 + // // 复制 assistant 对象,并强制关闭思考预算 + // const summaryAssistant = { + // ...assistant, + // settings: { + // ...assistant.settings, + // reasoning_effort: undefined, + // qwenThinkMode: false + // } + // } const summaryAssistant = { ...assistant, settings: { ...assistant.settings, reasoning_effort: undefined, qwenThinkMode: false - } + }, + prompt, + model } - const params: CompletionsParams = { - callType: 'summary', - messages: conversation, - assistant: { ...summaryAssistant, prompt, model }, - maxTokens: 1000, + const llmMessages = { + system: prompt, + prompt: conversation + } + + const middlewareConfig: AiSdkMiddlewareConfig = { streamOutput: false, - topicId, - enableReasoning: false + enableReasoning: false, + isPromptToolUse: false, + isSupportedToolUse: false, + isImageGenerationEndpoint: false, + enableWebSearch: false, + enableGenerateImage: false, + mcpTools: [] } - try { - const { getText } = await AI.completionsForTrace(params) + // 从 messages 中找到有 traceId 的助手消息,用于绑定现有 trace + const messageWithTrace = messages.find((m) => m.role === 'assistant' && 
m.traceId) + + if (messageWithTrace && messageWithTrace.traceId) { + // 导入并调用 appendTrace 来绑定现有 trace,传入summary使用的模型名 + const { appendTrace } = await import('@renderer/services/SpanManagerService') + await appendTrace({ topicId, traceId: messageWithTrace.traceId, model }) + } + + const { getText } = await AI.completions(model.id, llmMessages, { + ...middlewareConfig, + assistant: summaryAssistant, + topicId, + callType: 'summary' + }) const text = getText() return removeSpecialCharactersForTopicName(text) || null } catch (error: any) { @@ -750,28 +247,28 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages: } } -export async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) { - const model = getQuickModel() || assistant.model || getDefaultModel() - const provider = getProviderByModel(model) +// export async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) { +// const model = getQuickModel() || assistant.model || getDefaultModel() +// const provider = getProviderByModel(model) - if (!hasApiKey(provider)) { - return null - } +// if (!hasApiKey(provider)) { +// return null +// } - const topicId = messages?.find((message) => message.topicId)?.topicId || undefined +// const topicId = messages?.find((message) => message.topicId)?.topicId || undefined - const AI = new AiProvider(provider) +// const AI = new AiProvider(provider) - const params: CompletionsParams = { - callType: 'search', - messages: messages, - assistant, - streamOutput: false, - topicId - } +// const params: CompletionsParams = { +// callType: 'search', +// messages: messages, +// assistant, +// streamOutput: false, +// topicId +// } - return await AI.completionsForTrace(params) -} +// return await AI.completionsForTrace(params) +// } export async function fetchGenerate({ prompt, @@ -791,21 +288,42 @@ export async function fetchGenerate({ return '' } - const AI = new 
AiProvider(provider) + const AI = new AiProviderNew(model) const assistant = getDefaultAssistant() assistant.model = model assistant.prompt = prompt - const params: CompletionsParams = { - callType: 'generate', - messages: content, - assistant, - streamOutput: false + // const params: CompletionsParams = { + // callType: 'generate', + // messages: content, + // assistant, + // streamOutput: false + // } + + const middlewareConfig: AiSdkMiddlewareConfig = { + streamOutput: assistant.settings?.streamOutput ?? false, + enableReasoning: false, + isPromptToolUse: false, + isSupportedToolUse: false, + isImageGenerationEndpoint: false, + enableWebSearch: false, + enableGenerateImage: false } try { - const result = await AI.completions(params) + const result = await AI.completions( + model.id, + { + system: prompt, + prompt: content + }, + { + ...middlewareConfig, + assistant, + callType: 'generate' + } + ) return result.getText() || '' } catch (error: any) { return '' @@ -821,21 +339,21 @@ export function hasApiKey(provider: Provider) { /** * Get the first available embedding model from enabled providers */ -function getFirstEmbeddingModel() { - const providers = store.getState().llm.providers.filter((p) => p.enabled) +// function getFirstEmbeddingModel() { +// const providers = store.getState().llm.providers.filter((p) => p.enabled) - for (const provider of providers) { - const embeddingModel = provider.models.find((model) => isEmbeddingModel(model)) - if (embeddingModel) { - return embeddingModel - } - } +// for (const provider of providers) { +// const embeddingModel = provider.models.find((model) => isEmbeddingModel(model)) +// if (embeddingModel) { +// return embeddingModel +// } +// } - return undefined -} +// return undefined +// } export async function fetchModels(provider: Provider): Promise { - const AI = new AiProvider(provider) + const AI = new AiProviderNew(provider) try { return await AI.models() @@ -874,12 +392,11 @@ export function 
checkApiProvider(provider: Provider): void { export async function checkApi(provider: Provider, model: Model, timeout = 15000): Promise { checkApiProvider(provider) - const taskId = uuid() - - const ai = new AiProvider(provider) + const ai = new AiProviderNew(model) const assistant = getDefaultAssistant() assistant.model = model + assistant.prompt = 'test' // 避免部分 provider 空系统提示词会报错 try { if (isEmbeddingModel(model)) { // race 超时 15s @@ -887,42 +404,45 @@ export async function checkApi(provider: Provider, model: Model, timeout = 15000 const timerPromise = new Promise((_, reject) => setTimeout(() => reject('Timeout'), timeout)) await Promise.race([ai.getEmbeddingDimensions(model), timerPromise]) } else { - let streamError: Error | undefined = undefined - - // 15s超时 - const timer = setTimeout(() => { - abortCompletion(taskId) - streamError = new Error('Timeout') - }, timeout) - - const params: CompletionsParams = { - callType: 'check', - messages: 'hi', - assistant, + const abortId = uuid() + const signal = readyToAbort(abortId) + let chunkError + const params: StreamTextParams = { + system: assistant.prompt, + prompt: 'hi', + abortSignal: signal + } + const config: ModernAiProviderConfig = { streamOutput: true, enableReasoning: false, + isSupportedToolUse: false, + isImageGenerationEndpoint: false, + enableWebSearch: false, + enableGenerateImage: false, + isPromptToolUse: false, + assistant, + callType: 'check', onChunk: (chunk: Chunk) => { - if (chunk.type === ChunkType.ERROR && !isAbortError(chunk.error)) { - streamError = new Error(JSON.stringify(chunk.error)) + if (chunk.type === ChunkType.ERROR) { + chunkError = chunk.error + } else { + abortCompletion(abortId) } - abortCompletion(taskId) - }, - shouldThrow: true, - abortKey: taskId + } } // Try streaming check first try { - await ai.completions(params) - } finally { - clearTimeout(timer) - } - if (streamError) { - throw streamError + await ai.completions(model.id, params, config) + } catch (e) { + if 
(!isAbortError(e) && !isAbortError(chunkError)) { + throw e + } } } } catch (error: any) { - // FIXME: 这种判断方法无法严格保证错误是流式引起的 + // 失败回退legacy + const legacyAi = new AiProvider(provider) if (error.message.includes('stream')) { const params: CompletionsParams = { callType: 'check', @@ -931,12 +451,16 @@ export async function checkApi(provider: Provider, model: Model, timeout = 15000 streamOutput: false, shouldThrow: true } - // 超时判断 - const timeoutPromise = new Promise((_, reject) => setTimeout(() => reject('Timeout'), timeout)) - await Promise.race([ai.completions(params), timeoutPromise]) + const result = await legacyAi.completions(params) + if (!result.getText()) { + throw new Error('No response received') + } } else { throw error } + // } finally { + // removeAbortController(taskId, abortFn) + // } } } diff --git a/src/renderer/src/services/AssistantService.ts b/src/renderer/src/services/AssistantService.ts index 567c59991c..a8c95380a0 100644 --- a/src/renderer/src/services/AssistantService.ts +++ b/src/renderer/src/services/AssistantService.ts @@ -175,6 +175,7 @@ export const getAssistantSettings = (assistant: Assistant): AssistantSettings => streamOutput: assistant?.settings?.streamOutput ?? true, toolUseMode: assistant?.settings?.toolUseMode ?? 'prompt', defaultModel: assistant?.defaultModel ?? undefined, + reasoning_effort: assistant?.settings?.reasoning_effort ?? undefined, customParameters: assistant?.settings?.customParameters ?? 
[] } } diff --git a/src/renderer/src/services/ConversationService.ts b/src/renderer/src/services/ConversationService.ts new file mode 100644 index 0000000000..a7f3fab13c --- /dev/null +++ b/src/renderer/src/services/ConversationService.ts @@ -0,0 +1,61 @@ +import { convertMessagesToSdkMessages } from '@renderer/aiCore/prepareParams' +import { Assistant, Message } from '@renderer/types' +import { filterAdjacentUserMessaegs, filterLastAssistantMessage } from '@renderer/utils/messageUtils/filters' +import { ModelMessage } from 'ai' +import { findLast, isEmpty, takeRight } from 'lodash' + +import { getAssistantSettings, getDefaultModel } from './AssistantService' +import { + filterAfterContextClearMessages, + filterEmptyMessages, + filterUsefulMessages, + filterUserRoleStartMessages +} from './MessagesService' + +export class ConversationService { + static async prepareMessagesForModel( + messages: Message[], + assistant: Assistant + ): Promise<{ modelMessages: ModelMessage[]; uiMessages: Message[] }> { + const { contextCount } = getAssistantSettings(assistant) + // This logic is extracted from the original ApiService.fetchChatCompletion + // const contextMessages = filterContextMessages(messages) + const lastUserMessage = findLast(messages, (m) => m.role === 'user') + if (!lastUserMessage) { + return { + modelMessages: [], + uiMessages: [] + } + } + + const filteredMessages1 = filterAfterContextClearMessages(messages) + + const filteredMessages2 = filterUsefulMessages(filteredMessages1) + + const filteredMessages3 = filterLastAssistantMessage(filteredMessages2) + + const filteredMessages4 = filterAdjacentUserMessaegs(filteredMessages3) + + let uiMessages = filterUserRoleStartMessages( + filterEmptyMessages(filterAfterContextClearMessages(takeRight(filteredMessages4, contextCount + 2))) // 取原来几个provider的最大值 + ) + + // Fallback: ensure at least the last user message is present to avoid empty payloads + if ((!uiMessages || uiMessages.length === 0) && lastUserMessage) { + 
uiMessages = [lastUserMessage] + } + + return { + modelMessages: await convertMessagesToSdkMessages(uiMessages, assistant.model || getDefaultModel()), + uiMessages + } + } + + static needsWebSearch(assistant: Assistant): boolean { + return !!assistant.webSearchProviderId + } + + static needsKnowledgeSearch(assistant: Assistant): boolean { + return !isEmpty(assistant.knowledge_bases) + } +} diff --git a/src/renderer/src/services/MemoryProcessor.ts b/src/renderer/src/services/MemoryProcessor.ts index ef7833494b..28656e54ed 100644 --- a/src/renderer/src/services/MemoryProcessor.ts +++ b/src/renderer/src/services/MemoryProcessor.ts @@ -103,8 +103,7 @@ export class MemoryProcessor { if (!memoryConfig.llmApiClient) { throw new Error('No LLM model configured for memory processing') } - - const existingMemoriesResult = window.keyv.get(`memory-search-${lastMessageId}`) as MemoryItem[] | [] + const existingMemoriesResult = (window.keyv.get(`memory-search-${lastMessageId}`) as MemoryItem[]) || [] const existingMemories = existingMemoriesResult.map((memory) => ({ id: memory.id, diff --git a/src/renderer/src/services/OrchestrateService.ts b/src/renderer/src/services/OrchestrateService.ts new file mode 100644 index 0000000000..eef206d14e --- /dev/null +++ b/src/renderer/src/services/OrchestrateService.ts @@ -0,0 +1,90 @@ +import { Assistant, Message } from '@renderer/types' +import { Chunk, ChunkType } from '@renderer/types/chunk' +import { replacePromptVariables } from '@renderer/utils/prompt' + +import { fetchChatCompletion } from './ApiService' +import { ConversationService } from './ConversationService' + +/** + * The request object for handling a user message. 
+ */ +export interface OrchestrationRequest { + messages: Message[] + assistant: Assistant + options: { + signal?: AbortSignal + timeout?: number + headers?: Record<string, string> + } + topicId?: string // 添加 topicId 用于 trace +} + +/** + * The OrchestrationService is responsible for orchestrating the different services + * to handle a user's message. It contains the core logic of the application. + */ +// NOTE:暂时没有用到这个类 +export class OrchestrationService { + constructor() { + // In the future, this could be a singleton, but for now, a new instance is fine. + // this.conversationService = new ConversationService() + } + + /** + * This is the core method to handle user messages. + * It takes the message context and a chunk callback, + * and orchestrates the call to the LLM. + * The logic is moved from `messageThunk.ts`. + * @param request The orchestration request containing messages and assistant info. + * @param onChunkReceived Callback used to report progress and results to the UI layer. + */ + async transformMessagesAndFetch(request: OrchestrationRequest, onChunkReceived: (chunk: Chunk) => void) { + const { messages, assistant } = request + + try { + const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant) + + await fetchChatCompletion({ + messages: modelMessages, + assistant: assistant, + options: request.options, + onChunkReceived, + topicId: request.topicId, + uiMessages: uiMessages + }) + } catch (error: any) { + onChunkReceived({ type: ChunkType.ERROR, error }) + } + } +} + +/** + * 将用户消息转换为LLM可以理解的格式并发送请求 + * @param request - 包含消息内容和助手信息的请求对象 + * @param onChunkReceived - 接收流式响应数据的回调函数 + */ +// 目前先按照函数来写,后续如果有需要到class的地方就改回来 +export async function transformMessagesAndFetch( + request: OrchestrationRequest, + onChunkReceived: (chunk: Chunk) => void +) { + const { messages, assistant } = request + + try { + const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
+ + // replace prompt variables + assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name) + + await fetchChatCompletion({ + messages: modelMessages, + assistant: assistant, + options: request.options, + onChunkReceived, + topicId: request.topicId, + uiMessages + }) + } catch (error: any) { + onChunkReceived({ type: ChunkType.ERROR, error }) + } +} diff --git a/src/renderer/src/services/SpanManagerService.ts b/src/renderer/src/services/SpanManagerService.ts index 14a80e6386..bfb59cb25d 100644 --- a/src/renderer/src/services/SpanManagerService.ts +++ b/src/renderer/src/services/SpanManagerService.ts @@ -3,7 +3,7 @@ import { loggerService } from '@logger' import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core' import { cleanContext, endContext, getContext, startContext } from '@mcp-trace/trace-web' import { Context, context, Span, SpanStatusCode, trace } from '@opentelemetry/api' -import { isAsyncIterable } from '@renderer/aiCore/middleware/utils' +import { isAsyncIterable } from '@renderer/aiCore/legacy/middleware/utils' import { db } from '@renderer/databases' import { getEnableDeveloperMode } from '@renderer/hooks/useSettings' import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' @@ -97,7 +97,7 @@ class SpanManagerService { window.api.trace.openWindow(message.topicId, message.traceId, false, modelName) } - async appendTrace(message: Message, model: Model) { + async appendMessageTrace(message: Message, model: Model) { if (!getEnableDeveloperMode()) { return } @@ -113,6 +113,29 @@ class SpanManagerService { window.api.trace.openWindow(message.topicId, message.traceId, false, model.name) } + async appendTrace({ topicId, traceId, model }: { topicId: string; traceId: string; model: Model }) { + if (!getEnableDeveloperMode()) { + return + } + if (!traceId) { + return + } + + await window.api.trace.cleanHistory(topicId, traceId, model.name) + + // const input = await this._getContentFromMessage(message) + await 
window.api.trace.bindTopic(topicId, traceId) + + // 不使用 _addModelRootSpan,直接创建简单的 span 来避免额外的模型层级 + const entity = this.getModelSpanEntity(topicId, model.name) + const span = webTracer.startSpan('') + span['_spanContext'].traceId = traceId + entity.addSpan(span, true) + this._updateContext(span, topicId, traceId) + + window.api.trace.openWindow(topicId, traceId, false, model.name) + } + private async _getContentFromMessage(message: Message, content?: string): Promise { let _content = content if (!_content) { @@ -349,6 +372,7 @@ export const currentSpan = spanManagerService.getCurrentSpan.bind(spanManagerSer export const addTokenUsage = spanManagerService.addTokenUsage.bind(spanManagerService) export const pauseTrace = spanManagerService.finishModelTrace.bind(spanManagerService) export const appendTrace = spanManagerService.appendTrace.bind(spanManagerService) +export const appendMessageTrace = spanManagerService.appendMessageTrace.bind(spanManagerService) export const restartTrace = spanManagerService.restartTrace.bind(spanManagerService) EventEmitter.on(EVENT_NAMES.SEND_MESSAGE, ({ topicId, traceId }) => { diff --git a/src/renderer/src/services/StreamProcessingService.ts b/src/renderer/src/services/StreamProcessingService.ts index 4861e5b11a..fb92b867fe 100644 --- a/src/renderer/src/services/StreamProcessingService.ts +++ b/src/renderer/src/services/StreamProcessingService.ts @@ -43,6 +43,8 @@ export interface StreamProcessorCallbacks { onError?: (error: any) => void // Called when the entire stream processing is signaled as complete (success or failure) onComplete?: (status: AssistantMessageStatus, response?: Response) => void + // Called when a block is created + onBlockCreated?: () => void } // Function to create a stream processor instance @@ -51,7 +53,7 @@ export function createStreamProcessor(callbacks: StreamProcessorCallbacks = {}) return (chunk: Chunk) => { try { const data = chunk - logger.debug('data: ', data) + // logger.debug('data: ', data) switch 
(data.type) { case ChunkType.BLOCK_COMPLETE: { if (callbacks.onComplete) callbacks.onComplete(AssistantMessageStatus.SUCCESS, data?.response) @@ -136,6 +138,10 @@ export function createStreamProcessor(callbacks: StreamProcessorCallbacks = {}) if (callbacks.onError) callbacks.onError(data.error) break } + case ChunkType.BLOCK_CREATED: { + if (callbacks.onBlockCreated) callbacks.onBlockCreated() + break + } default: { // Handle unknown chunk types or log an error logger.warn(`Unknown chunk type: ${data.type}`) diff --git a/src/renderer/src/services/TranslateService.ts b/src/renderer/src/services/TranslateService.ts index 5d09933d46..c8bdad9406 100644 --- a/src/renderer/src/services/TranslateService.ts +++ b/src/renderer/src/services/TranslateService.ts @@ -1,75 +1,69 @@ import { loggerService } from '@logger' -import AiProvider from '@renderer/aiCore' -import { CompletionsParams } from '@renderer/aiCore/middleware/schemas' -import { - isReasoningModel, - isSupportedReasoningEffortModel, - isSupportedThinkingTokenModel -} from '@renderer/config/models' import { db } from '@renderer/databases' -import { CustomTranslateLanguage, TranslateHistory, TranslateLanguage, TranslateLanguageCode } from '@renderer/types' -import { TranslateAssistant } from '@renderer/types' -import { ChunkType } from '@renderer/types/chunk' +import { + CustomTranslateLanguage, + FetchChatCompletionOptions, + TranslateHistory, + TranslateLanguage, + TranslateLanguageCode +} from '@renderer/types' +import { Chunk, ChunkType } from '@renderer/types/chunk' import { uuid } from '@renderer/utils' +import { readyToAbort } from '@renderer/utils/abortController' import { formatErrorMessage, isAbortError } from '@renderer/utils/error' import { t } from 'i18next' -import { hasApiKey } from './ApiService' -import { getDefaultTranslateAssistant, getProviderByModel } from './AssistantService' +import { fetchChatCompletion } from './ApiService' +import { getDefaultTranslateAssistant } from './AssistantService' 
const logger = loggerService.withContext('TranslateService') -interface FetchTranslateProps { - assistant: TranslateAssistant - onResponse?: (text: string, isComplete: boolean) => void - abortKey?: string -} -async function fetchTranslate({ assistant, onResponse, abortKey }: FetchTranslateProps) { - const model = assistant.model +// async function fetchTranslate({ assistant, onResponse, abortKey }: FetchTranslateProps) { +// const model = assistant.model - const provider = getProviderByModel(model) +// const provider = getProviderByModel(model) - if (!hasApiKey(provider)) { - throw new Error(t('error.no_api_key')) - } +// if (!hasApiKey(provider)) { +// throw new Error(t('error.no_api_key')) +// } - const isSupportedStreamOutput = () => { - if (!onResponse) { - return false - } - return true - } +// const isSupportedStreamOutput = () => { +// if (!onResponse) { +// return false +// } +// return true +// } - const stream = isSupportedStreamOutput() - const enableReasoning = - ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) && - assistant.settings?.reasoning_effort !== undefined) || - (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model))) - let abortError +// const stream = isSupportedStreamOutput() +// const enableReasoning = +// ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) && +// assistant.settings?.reasoning_effort !== undefined) || +// (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model))) - const params: CompletionsParams = { - callType: 'translate', - messages: assistant.content, - assistant, - streamOutput: stream, - enableReasoning, - onResponse, - onChunk: (chunk) => { - if (chunk.type === ChunkType.ERROR && isAbortError(chunk.error)) { - abortError = chunk.error - } - }, - abortKey - } +// // abort control +// const controller = new AbortController() +// const signal = 
controller.signal - const AI = new AiProvider(provider) +// // 使用 transformParameters 模块构建参数 +// const { params, modelId, capabilities } = await buildStreamTextParams(undefined, assistant, provider, { +// requestOptions: { +// signal +// } +// }) - const result = (await AI.completionsForTrace(params)).getText().trim() - if (abortError) { - throw abortError - } - return result -} +// const options: ModernAiProviderConfig = { +// assistant, +// streamOutput: stream, +// enableReasoning, +// model: assistant.model, +// provider: provider +// } + +// const AI = new ModernAiProvider(model, provider) + +// const result = (await AI.completions(modelId, params, options)).getText().trim() +// return result +// } /** * 翻译文本到目标语言 @@ -86,10 +80,38 @@ export const translateText = async ( onResponse?: (text: string, isComplete: boolean) => void, abortKey?: string ) => { + let abortError try { const assistant = getDefaultTranslateAssistant(targetLanguage, text) - const translatedText = await fetchTranslate({ assistant, onResponse, abortKey }) + const signal = abortKey ? 
readyToAbort(abortKey) : undefined + + let translatedText = '' + let completed = false + const onChunk = (chunk: Chunk) => { + if (chunk.type === ChunkType.TEXT_DELTA) { + translatedText = chunk.text + } else if (chunk.type === ChunkType.TEXT_COMPLETE) { + completed = true + } else if (chunk.type === ChunkType.ERROR) { + if (isAbortError(chunk.error)) { + abortError = chunk.error + completed = true + } + } + onResponse?.(translatedText, completed) + } + + const options = { + signal + } satisfies FetchChatCompletionOptions + + await fetchChatCompletion({ + prompt: assistant.content, + assistant, + options, + onChunkReceived: onChunk + }) const trimmedText = translatedText.trim() @@ -101,11 +123,15 @@ export const translateText = async ( } catch (e) { if (isAbortError(e)) { window.message.info(t('translate.info.aborted')) + throw e + } else if (isAbortError(abortError)) { + window.message.info(t('translate.info.aborted')) + throw abortError } else { logger.error('Failed to translate', e as Error) window.message.error(t('translate.error.failed' + ': ' + formatErrorMessage(e))) + throw e } - throw e } } diff --git a/src/renderer/src/services/WebSearchService.ts b/src/renderer/src/services/WebSearchService.ts index 062d4e4ee4..57189ec0da 100644 --- a/src/renderer/src/services/WebSearchService.ts +++ b/src/renderer/src/services/WebSearchService.ts @@ -150,6 +150,7 @@ class WebSearchService { */ public getWebSearchProvider(providerId?: string): WebSearchProvider | undefined { const { providers } = this.getWebSearchState() + logger.debug('providers', providers) const provider = providers.find((provider) => provider.id === providerId) return provider diff --git a/src/renderer/src/services/WebTraceService.ts b/src/renderer/src/services/WebTraceService.ts index 135e75b0a8..119da16d05 100644 --- a/src/renderer/src/services/WebTraceService.ts +++ b/src/renderer/src/services/WebTraceService.ts @@ -1,6 +1,7 @@ import { loggerService } from '@logger' import { 
convertSpanToSpanEntity, FunctionSpanExporter, FunctionSpanProcessor } from '@mcp-trace/trace-core' import { WebTracer } from '@mcp-trace/trace-web' +import { trace } from '@opentelemetry/api' import { ReadableSpan } from '@opentelemetry/sdk-trace-base' const logger = loggerService.withContext('WebTraceService') @@ -33,6 +34,10 @@ class WebTraceService { processor ) } + + getTracer() { + return trace.getTracer(TRACER_NAME, '1.0.0') + } } export const webTraceService = new WebTraceService() diff --git a/src/renderer/src/services/__tests__/ApiService.test.ts b/src/renderer/src/services/__tests__/ApiService.test.ts index 12ac0f95e5..5702561aa2 100644 --- a/src/renderer/src/services/__tests__/ApiService.test.ts +++ b/src/renderer/src/services/__tests__/ApiService.test.ts @@ -9,12 +9,12 @@ import { import { FinishReason, MediaModality } from '@google/genai' import { FunctionCall } from '@google/genai' import AiProvider from '@renderer/aiCore' -import { BaseApiClient, OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/clients' -import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient' -import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory' -import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient' -import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient' -import { GenericChunk } from '@renderer/aiCore/middleware/schemas' +import { BaseApiClient, OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/legacy/clients' +import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient' +import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory' +import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient' +import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient' +import { GenericChunk } 
from '@renderer/aiCore/legacy/middleware/schemas' import { isVisionModel } from '@renderer/config/models' import { LlmState } from '@renderer/store/llm' import { Assistant, MCPCallToolResponse, MCPToolResponse, Model, Provider, WebSearchSource } from '@renderer/types' @@ -42,7 +42,7 @@ import OpenAI from 'openai' import { ChatCompletionChunk } from 'openai/resources' import { beforeEach, describe, expect, it, vi } from 'vitest' // Mock the ApiClientFactory -vi.mock('@renderer/aiCore/clients/ApiClientFactory', () => ({ +vi.mock('@renderer/aiCore/legacy/clients/ApiClientFactory', () => ({ ApiClientFactory: { create: vi.fn() } @@ -2423,7 +2423,8 @@ describe('ApiService', () => { }, description: 'print the name and age', required: ['name', 'age'] - } + }, + type: 'mcp' } ], onChunk, @@ -2514,7 +2515,8 @@ describe('ApiService', () => { }, description: 'print the name and age', required: ['name', 'age'] - } + }, + type: 'mcp' }, toolUseId: 'mcp-tool-1', arguments: { @@ -2546,7 +2548,8 @@ describe('ApiService', () => { }, description: 'print the name and age', required: ['name', 'age'] - } + }, + type: 'mcp' }, toolUseId: 'mcp-tool-1', arguments: { @@ -2577,7 +2580,8 @@ describe('ApiService', () => { }, description: 'print the name and age', required: ['name', 'age'] - } + }, + type: 'mcp' }, response: { content: [ diff --git a/src/renderer/src/services/messageStreaming/BlockManager.ts b/src/renderer/src/services/messageStreaming/BlockManager.ts index bb6ba50dac..fb0d65913a 100644 --- a/src/renderer/src/services/messageStreaming/BlockManager.ts +++ b/src/renderer/src/services/messageStreaming/BlockManager.ts @@ -105,6 +105,7 @@ export class BlockManager { * 处理块转换 */ async handleBlockTransition(newBlock: MessageBlock, newBlockType: MessageBlockType) { + logger.debug('handleBlockTransition', { newBlock, newBlockType }) this._lastBlockType = newBlockType this._activeBlockInfo = { id: newBlock.id, type: newBlockType } // 设置新的活跃块信息 diff --git 
a/src/renderer/src/services/messageStreaming/callbacks/baseCallbacks.ts b/src/renderer/src/services/messageStreaming/callbacks/baseCallbacks.ts index 6a83ec918c..a561ba937e 100644 --- a/src/renderer/src/services/messageStreaming/callbacks/baseCallbacks.ts +++ b/src/renderer/src/services/messageStreaming/callbacks/baseCallbacks.ts @@ -15,10 +15,11 @@ import { PlaceholderMessageBlock } from '@renderer/types/newMessage' import { uuid } from '@renderer/utils' -import { formatErrorMessage, isAbortError } from '@renderer/utils/error' +import { isAbortError, serializeError } from '@renderer/utils/error' import { createBaseMessageBlock, createErrorBlock } from '@renderer/utils/messageUtils/create' import { findAllBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { isFocused, isOnHomePage } from '@renderer/utils/window' +import { AISDKError, NoOutputGeneratedError } from 'ai' import { BlockManager } from '../BlockManager' @@ -68,24 +69,26 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => { }) await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN) }, + // onBlockCreated: async () => { + // if (blockManager.hasInitialPlaceholder) { + // return + // } + // console.log('onBlockCreated') + // const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, { + // status: MessageBlockStatus.PROCESSING + // }) + // await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN) + // }, - onError: async (error: any) => { + onError: async (error: AISDKError) => { logger.debug('onError', error) - const isErrorTypeAbort = isAbortError(error) - let pauseErrorLanguagePlaceholder = '' - if (isErrorTypeAbort) { - pauseErrorLanguagePlaceholder = 'pause_placeholder' + if (NoOutputGeneratedError.isInstance(error)) { + return } - - const serializableError = { - name: error.name, - message: pauseErrorLanguagePlaceholder || 
error.message || formatErrorMessage(error), - originalMessage: error.message, - stack: error.stack, - status: error.status || error.code, - requestId: error.request_id, - providerId: error.providerId, - i18nKey: error.i18nKey + const isErrorTypeAbort = isAbortError(error) + const serializableError = serializeError(error) + if (isErrorTypeAbort) { + serializableError.message = 'pause_placeholder' } const duration = Date.now() - startTime @@ -97,7 +100,7 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => { id: uuid(), type: 'error', title: i18n.t('notification.assistant'), - message: serializableError.message, + message: serializableError.message ?? '', silent: false, timestamp: Date.now(), source: 'assistant' diff --git a/src/renderer/src/services/messageStreaming/callbacks/imageCallbacks.ts b/src/renderer/src/services/messageStreaming/callbacks/imageCallbacks.ts index 9728f564a5..9a61eed552 100644 --- a/src/renderer/src/services/messageStreaming/callbacks/imageCallbacks.ts +++ b/src/renderer/src/services/messageStreaming/callbacks/imageCallbacks.ts @@ -47,7 +47,7 @@ export const createImageCallbacks = (deps: ImageCallbacksDependencies) => { } }, - onImageGenerated: (imageData: any) => { + onImageGenerated: async (imageData: any) => { if (imageBlockId) { if (!imageData) { const changes: Partial = { @@ -65,7 +65,16 @@ export const createImageCallbacks = (deps: ImageCallbacksDependencies) => { } imageBlockId = null } else { - logger.error('[onImageGenerated] Last block was not an Image block or ID is missing.') + if (imageData) { + const imageBlock = createImageBlock(assistantMsgId, { + status: MessageBlockStatus.SUCCESS, + url: imageData.images?.[0] || 'placeholder_image_url', + metadata: { generateImageResponse: imageData } + }) + await blockManager.handleBlockTransition(imageBlock, MessageBlockType.IMAGE) + } else { + logger.error('[onImageGenerated] Last block was not an Image block or ID is missing.') + } } } } diff --git 
a/src/renderer/src/services/messageStreaming/callbacks/index.ts b/src/renderer/src/services/messageStreaming/callbacks/index.ts index 2b6fc5968a..866147be41 100644 --- a/src/renderer/src/services/messageStreaming/callbacks/index.ts +++ b/src/renderer/src/services/messageStreaming/callbacks/index.ts @@ -59,7 +59,8 @@ export const createCallbacks = (deps: CallbacksDependencies) => { blockManager, getState, assistantMsgId, - getCitationBlockId: citationCallbacks.getCitationBlockId + getCitationBlockId: citationCallbacks.getCitationBlockId, + getCitationBlockIdFromTool: toolCallbacks.getCitationBlockId }) // 组合所有回调 diff --git a/src/renderer/src/services/messageStreaming/callbacks/textCallbacks.ts b/src/renderer/src/services/messageStreaming/callbacks/textCallbacks.ts index 657cf6f0f3..3756c88c82 100644 --- a/src/renderer/src/services/messageStreaming/callbacks/textCallbacks.ts +++ b/src/renderer/src/services/messageStreaming/callbacks/textCallbacks.ts @@ -12,10 +12,11 @@ interface TextCallbacksDependencies { getState: any assistantMsgId: string getCitationBlockId: () => string | null + getCitationBlockIdFromTool: () => string | null } export const createTextCallbacks = (deps: TextCallbacksDependencies) => { - const { blockManager, getState, assistantMsgId, getCitationBlockId } = deps + const { blockManager, getState, assistantMsgId, getCitationBlockId, getCitationBlockIdFromTool } = deps // 内部维护的状态 let mainTextBlockId: string | null = null @@ -40,7 +41,7 @@ export const createTextCallbacks = (deps: TextCallbacksDependencies) => { }, onTextChunk: async (text: string) => { - const citationBlockId = getCitationBlockId() + const citationBlockId = getCitationBlockId() || getCitationBlockIdFromTool() const citationBlockSource = citationBlockId ? 
(getState().messageBlocks.entities[citationBlockId] as CitationMessageBlock).response?.source : WebSearchSource.WEBSEARCH diff --git a/src/renderer/src/services/messageStreaming/callbacks/toolCallbacks.ts b/src/renderer/src/services/messageStreaming/callbacks/toolCallbacks.ts index 3782b4f35f..1a5f3df28e 100644 --- a/src/renderer/src/services/messageStreaming/callbacks/toolCallbacks.ts +++ b/src/renderer/src/services/messageStreaming/callbacks/toolCallbacks.ts @@ -1,7 +1,8 @@ import { loggerService } from '@logger' import type { MCPToolResponse } from '@renderer/types' +import { WebSearchSource } from '@renderer/types' import { MessageBlockStatus, MessageBlockType, ToolMessageBlock } from '@renderer/types/newMessage' -import { createToolBlock } from '@renderer/utils/messageUtils/create' +import { createCitationBlock, createToolBlock } from '@renderer/utils/messageUtils/create' import { BlockManager } from '../BlockManager' @@ -18,9 +19,12 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => { // 内部维护的状态 const toolCallIdToBlockIdMap = new Map() let toolBlockId: string | null = null + let citationBlockId: string | null = null return { onToolCallPending: (toolResponse: MCPToolResponse) => { + logger.debug('onToolCallPending', toolResponse) + if (blockManager.hasInitialPlaceholder) { const changes = { type: MessageBlockType.TOOL, @@ -93,10 +97,42 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => { } if (finalStatus === MessageBlockStatus.ERROR) { - changes.error = { message: `Tool execution failed/error`, details: toolResponse.response } + changes.error = { + message: `Tool execution failed/error`, + details: toolResponse.response, + name: null, + stack: null + } } blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true) + + // Handle citation block creation for web search results + if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response?.searchResults) { + const 
citationBlock = createCitationBlock( + assistantMsgId, + { + response: { results: toolResponse.response.searchResults, source: WebSearchSource.WEBSEARCH } + }, + { + status: MessageBlockStatus.SUCCESS + } + ) + citationBlockId = citationBlock.id + blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION) + } + if (toolResponse.tool.name === 'builtin_knowledge_search' && toolResponse.response?.knowledgeReferences) { + const citationBlock = createCitationBlock( + assistantMsgId, + { knowledge: toolResponse.response.knowledgeReferences }, + { + status: MessageBlockStatus.SUCCESS + } + ) + citationBlockId = citationBlock.id + blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION) + } + // TODO: 处理 memory 引用 } else { logger.warn( `[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}` @@ -104,6 +140,9 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => { } toolBlockId = null - } + }, + + // 暴露给 textCallbacks 使用的方法 + getCitationBlockId: () => citationBlockId } } diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index fc5bd93575..5a70c202f8 100644 --- a/src/renderer/src/store/index.ts +++ b/src/renderer/src/store/index.ts @@ -67,7 +67,7 @@ const persistedReducer = persistReducer( { key: 'cherry-studio', storage, - version: 146, + version: 147, blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs'], migrate }, diff --git a/src/renderer/src/store/messageBlock.ts b/src/renderer/src/store/messageBlock.ts index f06470f7cd..cd77c05430 100644 --- a/src/renderer/src/store/messageBlock.ts +++ b/src/renderer/src/store/messageBlock.ts @@ -216,6 +216,15 @@ export const formatCitationsFromBlock = (block: CitationMessageBlock | undefined type: 'websearch' })) || [] break + case WebSearchSource.AISDK: + formattedCitations = + (block.response.results as any[])?.map((result, index) => ({ + number: index + 1, + url: result.url, + title: 
result.title, + providerMetadata: result?.providerMetadata + })) || [] + break } } // 3. Handle Knowledge Base References diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 85ae95752e..0ba27d931e 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -838,6 +838,7 @@ const migrateConfig = { state.llm.providers.forEach((provider) => { if (provider.id === 'qwenlm') { + // @ts-ignore eslint-disable-next-line provider.type = 'qwenlm' } }) @@ -895,6 +896,7 @@ const migrateConfig = { try { state.llm.providers.forEach((provider) => { if (provider.id === 'qwenlm') { + // @ts-ignore eslint-disable-next-line provider.type = 'qwenlm' } }) @@ -2366,6 +2368,21 @@ const migrateConfig = { logger.error('migrate 146 error', error as Error) return state } + }, + '147': (state: RootState) => { + try { + state.llm.providers.forEach((provider) => { + if (provider.id === SystemProviderIds.anthropic) { + if (provider.apiHost.endsWith('/')) { + provider.apiHost = provider.apiHost.slice(0, -1) + } + } + }) + return state + } catch (error) { + logger.error('migrate 147 error', error as Error) + return state + } } } diff --git a/src/renderer/src/store/thunk/__tests__/streamCallback.integration.test.ts b/src/renderer/src/store/thunk/__tests__/streamCallback.integration.test.ts index b7bab08ff8..6c98e3cee4 100644 --- a/src/renderer/src/store/thunk/__tests__/streamCallback.integration.test.ts +++ b/src/renderer/src/store/thunk/__tests__/streamCallback.integration.test.ts @@ -241,7 +241,13 @@ vi.mock('i18next', () => { vi.mock('@renderer/utils/error', () => ({ formatErrorMessage: vi.fn((error) => error.message || 'Unknown error'), - isAbortError: vi.fn((error) => error.name === 'AbortError') + isAbortError: vi.fn((error) => error.name === 'AbortError'), + serializeError: vi.fn((error) => ({ + name: error.name, + message: error.message, + stack: error.stack, + cause: error.cause ? 
String(error.cause) : undefined + })) })) vi.mock('@renderer/utils', () => ({ @@ -434,7 +440,8 @@ describe('streamCallback Integration Tests', () => { type: 'object', title: 'Test Tool Input', properties: {} - } + }, + type: 'mcp' } const chunks: Chunk[] = [ @@ -570,7 +577,8 @@ describe('streamCallback Integration Tests', () => { type: 'object', title: 'Calculator Input', properties: {} - } + }, + type: 'mcp' } const chunks: Chunk[] = [ diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts index f4d009e0e5..6bd2ecd826 100644 --- a/src/renderer/src/store/thunk/messageThunk.ts +++ b/src/renderer/src/store/thunk/messageThunk.ts @@ -1,9 +1,9 @@ import { loggerService } from '@logger' import db from '@renderer/databases' -import { fetchChatCompletion } from '@renderer/services/ApiService' import FileManager from '@renderer/services/FileManager' import { BlockManager } from '@renderer/services/messageStreaming/BlockManager' import { createCallbacks } from '@renderer/services/messageStreaming/callbacks' +import { transformMessagesAndFetch } from '@renderer/services/OrchestrateService' import { endSpan } from '@renderer/services/SpanManagerService' import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/services/StreamProcessingService' import store from '@renderer/store' @@ -12,13 +12,13 @@ import { type Assistant, type FileMetadata, type Model, type Topic } from '@rend import type { FileMessageBlock, ImageMessageBlock, Message, MessageBlock } from '@renderer/types/newMessage' import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage' import { uuid } from '@renderer/utils' +import { addAbortController } from '@renderer/utils/abortController' import { createAssistantMessage, createTranslationBlock, resetAssistantMessage } from '@renderer/utils/messageUtils/create' -import { getTopicQueue } from '@renderer/utils/queue' -import { waitForTopicQueue } from 
'@renderer/utils/queue' +import { getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue' import { t } from 'i18next' import { isEmpty, throttle } from 'lodash' import { LRUCache } from 'lru-cache' @@ -155,6 +155,7 @@ const getBlockThrottler = (id: string) => { */ export const throttledBlockUpdate = (id: string, blockUpdate: any) => { const throttler = getBlockThrottler(id) + // store.dispatch(updateOneBlock({ id, changes: blockUpdate })) throttler(blockUpdate) } @@ -358,28 +359,36 @@ const fetchAndProcessAssistantResponseImpl = async ( }) const streamProcessorCallbacks = createStreamProcessor(callbacks) - // const startTime = Date.now() - const result = await fetchChatCompletion({ - messages: messagesForContext, - assistant: assistant, - onChunkReceived: streamProcessorCallbacks - }) - endSpan({ - topicId, - outputs: result ? result.getText() : '', - modelName: assistant.model?.name, - modelEnded: true - }) + const abortController = new AbortController() + addAbortController(userMessageId!, () => abortController.abort()) + + await transformMessagesAndFetch( + { + messages: messagesForContext, + assistant, + topicId, + options: { + signal: abortController.signal, + timeout: 30000 + } + }, + streamProcessorCallbacks + ) } catch (error: any) { - logger.error('Error fetching chat completion:', error) + logger.error('Error in fetchAndProcessAssistantResponseImpl:', error) endSpan({ topicId, error: error, modelName: assistant.model?.name }) - if (assistantMessage) { - callbacks.onError?.(error) - throw error + // 统一错误处理:确保 loading 状态被正确设置,避免队列任务卡住 + try { + await callbacks.onError?.(error) + } catch (callbackError) { + logger.error('Error in onError callback:', callbackError as Error) + } finally { + // 确保无论如何都设置 loading 为 false(onError 回调中已设置,这里是保险) + dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) } } } @@ -970,11 +979,11 @@ export const appendAssistantResponseThunk = const existingMessageIndex = currentTopicMessageIds.findIndex((id) 
=> id === existingAssistantMessageId) const insertAtIndex = existingMessageIndex !== -1 ? existingMessageIndex + 1 : currentTopicMessageIds.length - dispatch(newMessagesActions.insertMessageAtIndex({ topicId, message: newAssistantStub, index: insertAtIndex })) - // 4. Update Database (Save the stub to the topic's message list) await saveMessageAndBlocksToDB(newAssistantStub, [], insertAtIndex) + dispatch(newMessagesActions.insertMessageAtIndex({ topicId, message: newAssistantStub, index: insertAtIndex })) + // 5. Prepare and queue the processing task const assistantConfigForThisCall = { ...assistant, diff --git a/src/renderer/src/tools/think.ts b/src/renderer/src/tools/think.ts index 3da4c7f972..1e17d3e865 100644 --- a/src/renderer/src/tools/think.ts +++ b/src/renderer/src/tools/think.ts @@ -8,6 +8,7 @@ export const thinkTool: MCPTool = { description: 'Use the tool to think about something. It will not obtain new information or make any changes to the repository, but just log the thought. Use it when complex reasoning or brainstorming is needed. For example, if you explore the repo and discover the source of a bug, call this tool to brainstorm several unique ways of fixing the bug, and assess which change(s) are likely to be simplest and most effective. 
Alternatively, if you receive some test results, call this tool to brainstorm ways to fix the failing tests.', isBuiltIn: true, + type: 'mcp', inputSchema: { type: 'object', title: 'Think Tool Input', diff --git a/src/renderer/src/trace/dataHandler/CommonResultHandler.ts b/src/renderer/src/trace/dataHandler/CommonResultHandler.ts index 4c5ea553c5..c66528abb0 100644 --- a/src/renderer/src/trace/dataHandler/CommonResultHandler.ts +++ b/src/renderer/src/trace/dataHandler/CommonResultHandler.ts @@ -1,6 +1,6 @@ import { TokenUsage } from '@mcp-trace/trace-core' import { Span } from '@opentelemetry/api' -import { CompletionsResult } from '@renderer/aiCore/middleware/schemas' +import { CompletionsResult } from '@renderer/aiCore/legacy/middleware/schemas' import { endSpan } from '@renderer/services/SpanManagerService' export class CompletionsResultHandler { diff --git a/src/renderer/src/types/aiCoreTypes.ts b/src/renderer/src/types/aiCoreTypes.ts new file mode 100644 index 0000000000..614211a5c7 --- /dev/null +++ b/src/renderer/src/types/aiCoreTypes.ts @@ -0,0 +1,32 @@ +import type { AISDKError, APICallError, ImageModel, LanguageModel } from 'ai' +import { generateObject, generateText, ModelMessage, streamObject, streamText } from 'ai' + +export type StreamTextParams = Omit[0], 'model' | 'messages'> & + ( + | { + prompt: string | Array + messages?: never + } + | { + messages: Array + prompt?: never + } + ) +export type GenerateTextParams = Omit[0], 'model' | 'messages'> & + ( + | { + prompt: string | Array + messages?: never + } + | { + messages: Array + prompt?: never + } + ) +export type StreamObjectParams = Omit[0], 'model'> +export type GenerateObjectParams = Omit[0], 'model'> + +export type AiSdkModel = LanguageModel | ImageModel + +// 该类型用于格式化错误信息,目前只处理 APICallError,待扩展 +export type AiSdkErrorUnion = AISDKError | APICallError diff --git a/src/renderer/src/types/chunk.ts b/src/renderer/src/types/chunk.ts index 1fdbbdae6f..6111060bad 100644 --- 
a/src/renderer/src/types/chunk.ts +++ b/src/renderer/src/types/chunk.ts @@ -1,4 +1,12 @@ -import { ExternalToolResult, KnowledgeReference, MCPToolResponse, ToolUseResponse, WebSearchResponse } from '.' +import { + ExternalToolResult, + KnowledgeReference, + MCPTool, + MCPToolResponse, + NormalToolResponse, + ToolUseResponse, + WebSearchResponse +} from '.' import { Response, ResponseError } from './newMessage' import { SdkToolCall } from './sdk' @@ -287,12 +295,12 @@ export interface ExternalToolCompleteChunk { export interface MCPToolCreatedChunk { type: ChunkType.MCP_TOOL_CREATED tool_calls?: SdkToolCall[] // 工具调用 - tool_use_responses?: ToolUseResponse[] // 工具使用响应 + tool_use_responses?: (Omit & { tool: MCPTool })[] // 工具使用响应 } export interface MCPToolPendingChunk { type: ChunkType.MCP_TOOL_PENDING - responses: MCPToolResponse[] + responses: MCPToolResponse[] | NormalToolResponse[] } export interface MCPToolInProgressChunk { @@ -303,14 +311,14 @@ export interface MCPToolInProgressChunk { /** * The tool responses of the chunk */ - responses: MCPToolResponse[] + responses: MCPToolResponse[] | NormalToolResponse[] } export interface MCPToolCompleteChunk { /** * The tool response of the chunk */ - responses: MCPToolResponse[] + responses: MCPToolResponse[] | NormalToolResponse[] /** * The type of the chunk diff --git a/src/renderer/src/types/error.ts b/src/renderer/src/types/error.ts new file mode 100644 index 0000000000..2439eaa123 --- /dev/null +++ b/src/renderer/src/types/error.ts @@ -0,0 +1,32 @@ +import { Serializable } from './serialize' + +export interface SerializedError { + name: string | null + message: string | null + stack: string | null + [key: string]: Serializable +} +export const isSerializedError = (error: Record): error is SerializedAiSdkError => { + return 'name' in error && 'message' in error && 'stack' in error +} +export interface SerializedAiSdkError extends SerializedError { + readonly cause: string | null +} + +export const 
isSerializedAiSdkError = (error: SerializedError): error is SerializedAiSdkError => { + return 'cause' in error +} + +export interface SerializedAiSdkAPICallError extends SerializedAiSdkError { + readonly url: string + readonly requestBodyValues: Serializable + readonly statusCode: number | null + readonly responseHeaders: Record | null + readonly responseBody: string | null + readonly isRetryable: boolean + readonly data: Serializable | null +} + +export const isSerializedAiSdkAPICallError = (error: SerializedError): error is SerializedAiSdkAPICallError => { + return isSerializedAiSdkError(error) && 'url' in error && 'requestBodyValues' in error && 'isRetryable' in error +} diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 95d09120b3..4d3046775b 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -7,9 +7,12 @@ import * as z from 'zod/v4' export * from './file' export * from './note' +import type { StreamTextParams } from './aiCoreTypes' +import type { Chunk } from './chunk' import type { FileMetadata } from './file' import { MCPConfigSample, McpServerType } from './mcp' import type { Message } from './newMessage' +import type { BaseTool, MCPTool } from './tool' export * from './mcp' export * from './ocr' @@ -329,6 +332,15 @@ export type SystemProvider = Provider & { apiOptions?: never } +export type VertexProvider = Provider & { + googleCredentials: { + privateKey: string + clientEmail: string + } + project: string + location: string +} + /** * 判断是否为系统内置的提供商。比直接使用`provider.isSystem`更好,因为该数据字段不会随着版本更新而变化。 * @param provider - Provider对象,包含提供商的信息 @@ -348,6 +360,7 @@ export type ProviderType = | 'vertexai' | 'mistral' | 'aws-bedrock' + | 'vertex-anthropic' export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' | 'rerank' @@ -809,7 +822,8 @@ export enum WebSearchSource { QWEN = 'qwen', HUNYUAN = 'hunyuan', ZHIPU = 'zhipu', - GROK = 'grok' + GROK = 
'grok', + AISDK = 'ai-sdk' } export type WebSearchResponse = { @@ -920,17 +934,6 @@ export const MCPToolOutputSchema = z.object({ required: z.array(z.string()) }) -export interface MCPTool { - id: string - serverId: string - serverName: string - name: string - description?: string - inputSchema: MCPToolInputSchema - outputSchema?: z.infer - isBuiltIn?: boolean // 标识是否为内置工具,内置工具不需要通过MCP协议调用 -} - export interface MCPPromptArguments { name: string description?: string @@ -969,7 +972,7 @@ export type MCPToolResponseStatus = 'pending' | 'cancelled' | 'invoking' | 'done interface BaseToolResponse { id: string // unique id - tool: MCPTool + tool: BaseTool | MCPTool arguments: Record | undefined status: MCPToolResponseStatus response?: any @@ -984,7 +987,17 @@ export interface ToolCallResponse extends BaseToolResponse { toolCallId?: string } -export type MCPToolResponse = ToolUseResponse | ToolCallResponse +// export type MCPToolResponse = ToolUseResponse | ToolCallResponse +export interface MCPToolResponse extends Omit { + tool: MCPTool + toolCallId?: string + toolUseId?: string +} + +export interface NormalToolResponse extends Omit { + tool: BaseTool + toolCallId: string +} export interface MCPToolResultContent { type: 'text' | 'image' | 'audio' | 'resource' @@ -1110,6 +1123,7 @@ export interface ApiServerConfig { port: number apiKey: string } +export * from './tool' // Memory Service Types // ======================================================================== @@ -1281,3 +1295,32 @@ export type HexColor = string export const isHexColor = (value: string): value is HexColor => { return /^#([0-9A-F]{3}){1,2}$/i.test(value) } + +export type FetchChatCompletionOptions = { + signal?: AbortSignal + timeout?: number + headers?: Record +} + +type BaseParams = { + assistant: Assistant + options?: FetchChatCompletionOptions + onChunkReceived: (chunk: Chunk) => void + topicId?: string // 添加 topicId 参数 + uiMessages?: Message[] +} + +type MessagesParams = BaseParams & { + 
messages: StreamTextParams['messages'] + prompt?: never +} + +type PromptParams = BaseParams & { + messages?: never + // prompt: Just use string for convinience. Native prompt type unite more types, including messages type. + // we craete a non-intersecting prompt type to discriminate them. + // see https://github.com/vercel/ai/issues/8363 + prompt: string +} + +export type FetchChatCompletionParams = MessagesParams | PromptParams diff --git a/src/renderer/src/types/newMessage.ts b/src/renderer/src/types/newMessage.ts index e7f4831ec9..a8a037a851 100644 --- a/src/renderer/src/types/newMessage.ts +++ b/src/renderer/src/types/newMessage.ts @@ -15,6 +15,7 @@ import type { WebSearchResponse, WebSearchSource } from '.' +import { SerializedError } from './error' // MessageBlock 类型枚举 - 根据实际API返回特性优化 export enum MessageBlockType { @@ -50,7 +51,7 @@ export interface BaseMessageBlock { status: MessageBlockStatus // 块状态 model?: Model // 使用的模型 metadata?: Record // 通用元数据 - error?: Record // Added optional error field to base + error?: SerializedError // Serializable error object instead of AISDKError } export interface PlaceholderMessageBlock extends BaseMessageBlock { @@ -73,7 +74,7 @@ export interface MainTextMessageBlock extends BaseMessageBlock { export interface ThinkingMessageBlock extends BaseMessageBlock { type: MessageBlockType.THINKING content: string - thinking_millsec?: number + thinking_millsec: number } // 翻译块 diff --git a/src/renderer/src/types/sdk.ts b/src/renderer/src/types/sdk.ts index 8ce94e6f4c..54a7896be4 100644 --- a/src/renderer/src/types/sdk.ts +++ b/src/renderer/src/types/sdk.ts @@ -77,6 +77,7 @@ type OpenAIParamsPurified = Omit _isSerializable(item)) + } + + // 检查是否为纯对象(plain object) + const proto = Object.getPrototypeOf(val) + if (proto !== null && proto !== Object.prototype && proto !== Array.prototype) { + return false // 不是 plain object,比如 class 实例 + } + + // 检查内置对象(如 Date、RegExp、Map、Set 等) + if ( + val instanceof Date || + val instanceof RegExp || 
+ val instanceof Map || + val instanceof Set || + val instanceof Error || + val instanceof File || + val instanceof Blob + ) { + return false + } + + // 递归检查所有属性值 + return Object.values(val).every((v) => _isSerializable(v)) + } + + // function、symbol 不可序列化 + return false + } + + try { + return _isSerializable(value) + } catch { + return false // 如出现循环引用错误等 + } +} diff --git a/src/renderer/src/types/tool.ts b/src/renderer/src/types/tool.ts new file mode 100644 index 0000000000..b552c67685 --- /dev/null +++ b/src/renderer/src/types/tool.ts @@ -0,0 +1,51 @@ +import * as z from 'zod/v4' + +export type ToolType = 'builtin' | 'provider' | 'mcp' + +export interface BaseTool { + id: string + name: string + description?: string + type: ToolType +} + +// export interface ToolCallResponse { +// id: string +// toolName: string +// arguments: Record | undefined +// status: 'invoking' | 'completed' | 'error' +// result?: any // AI SDK的工具执行结果 +// error?: string +// providerExecuted?: boolean // 标识是Provider端执行还是客户端执行 +// } + +export const MCPToolOutputSchema = z.object({ + type: z.literal('object'), + properties: z.record(z.string(), z.unknown()), + required: z.array(z.string()) +}) + +export interface MCPToolInputSchema { + type: string + title: string + description?: string + required?: string[] + properties: Record +} + +export interface BuiltinTool extends BaseTool { + inputSchema: MCPToolInputSchema + type: 'builtin' +} + +export interface MCPTool extends BaseTool { + id: string + serverId: string + serverName: string + name: string + description?: string + inputSchema: MCPToolInputSchema + outputSchema?: z.infer + isBuiltIn?: boolean // 标识是否为内置工具,内置工具不需要通过MCP协议调用 + type: 'mcp' +} diff --git a/src/renderer/src/utils/__tests__/error.test.ts b/src/renderer/src/utils/__tests__/error.test.ts index db09f99952..1f2afab9f7 100644 --- a/src/renderer/src/utils/__tests__/error.test.ts +++ b/src/renderer/src/utils/__tests__/error.test.ts @@ -1,6 +1,6 @@ import { describe, expect, it, vi 
} from 'vitest' -import { formatErrorMessage, formatMessageError, getErrorDetails, getErrorMessage, isAbortError } from '../error' +import { formatErrorMessage, getErrorDetails, isAbortError } from '../error' describe('error', () => { describe('getErrorDetails', () => { @@ -123,79 +123,6 @@ describe('error', () => { }) }) - describe('formatMessageError', () => { - it('should return error details as an object', () => { - const error = new Error('Test error') - const result = formatMessageError(error) - - expect(result.message).toBe('Test error') - expect(result.stack).toBeUndefined() - expect(result.headers).toBeUndefined() - expect(result.request_id).toBeUndefined() - }) - - it('should handle string errors', () => { - const result = formatMessageError('String error') - expect(typeof result).toBe('string') - expect(result).toBe('String error') - }) - - it('should handle formatting errors', () => { - const problematicError = { - get message() { - throw new Error('Cannot access') - }, - toString: () => 'Error object' - } - - const result = formatMessageError(problematicError) - expect(result).toBeTruthy() - }) - - it('should handle completely invalid errors', () => { - let invalidError: any - try { - invalidError = Object.create(null) - Object.defineProperty(invalidError, 'toString', { - get: () => { - throw new Error() - } - }) - } catch (e) { - invalidError = { - toString() { - throw new Error() - } - } - } - - const result = formatMessageError(invalidError) - expect(result).toBeTruthy() - }) - }) - - describe('getErrorMessage', () => { - it('should extract message from Error objects', () => { - const error = new Error('Test message') - expect(getErrorMessage(error)).toBe('Test message') - }) - - it('should handle objects with message property', () => { - const errorObj = { message: 'Object message' } - expect(getErrorMessage(errorObj)).toBe('Object message') - }) - - it('should convert non-Error objects to string', () => { - const obj = { toString: () => 'Custom 
toString' } - expect(getErrorMessage(obj)).toBe('Custom toString') - }) - - it('should return empty string for undefined or null', () => { - expect(getErrorMessage(undefined)).toBe('') - expect(getErrorMessage(null)).toBe('') - }) - }) - describe('isAbortError', () => { it('should identify OpenAI abort errors by message', () => { const openaiError = { message: 'Request was aborted.' } diff --git a/src/renderer/src/utils/__tests__/prompt.test.ts b/src/renderer/src/utils/__tests__/prompt.test.ts index 6d990c16eb..e0a3965e90 100644 --- a/src/renderer/src/utils/__tests__/prompt.test.ts +++ b/src/renderer/src/utils/__tests__/prompt.test.ts @@ -49,7 +49,8 @@ const createMockTool = (id: string, description: string, inputSchema: any = {}): title: `${id}-schema`, properties: {}, ...inputSchema - } + }, + type: 'mcp' }) // Helper to create a mock Assistant diff --git a/src/renderer/src/utils/abortController.ts b/src/renderer/src/utils/abortController.ts index 6593b1c162..11bd7a791d 100644 --- a/src/renderer/src/utils/abortController.ts +++ b/src/renderer/src/utils/abortController.ts @@ -49,3 +49,25 @@ export function createAbortPromise(signal: AbortSignal, finallyPromise: Promi }) }) } + +/** + * 创建一个新的 AbortController 并将其注册到全局的 abort 映射中 + * @param key - 用于标识此 AbortController 的唯一键值 + * @returns AbortSignal - 返回 AbortController 的信号 + * @example + * ```typescript + * const signal = readyToAbort('uniqueKey'); + * fetch('https://api.example.com/data', { signal }) + * .then(response => response.json()) + * .catch(error => { + * if (error.name === 'AbortError') { + * console.log('Fetch aborted'); + * } + * }); + * ``` + */ +export function readyToAbort(key: string) { + const controller = new AbortController() + addAbortController(key, () => controller.abort()) + return controller.signal +} diff --git a/src/renderer/src/utils/api.ts b/src/renderer/src/utils/api.ts index b8a7cc1319..5e9b8f91a6 100644 --- a/src/renderer/src/utils/api.ts +++ b/src/renderer/src/utils/api.ts @@ -16,9 
+16,10 @@ export function formatApiKeys(value: string): string { * - 要加:其余情况。 * * @param {string} host - 需要格式化的 API 主机地址。 + * @param {string} apiVersion - 需要添加的 API 版本。 * @returns {string} 格式化后的 API 主机地址。 */ -export function formatApiHost(host: string): string { +export function formatApiHost(host: string, apiVersion: string = 'v1'): string { const forceUseOriginalHost = () => { if (host.endsWith('/')) { return true @@ -27,7 +28,7 @@ export function formatApiHost(host: string): string { return host.endsWith('volces.com/api/v3') } - return forceUseOriginalHost() ? host : `${host}/v1/` + return forceUseOriginalHost() ? host : `${host}/${apiVersion}/` } /** diff --git a/src/renderer/src/utils/error.ts b/src/renderer/src/utils/error.ts index 296db0da5d..759c1cf512 100644 --- a/src/renderer/src/utils/error.ts +++ b/src/renderer/src/utils/error.ts @@ -1,8 +1,17 @@ -// import { loggerService } from '@logger' +import { loggerService } from '@logger' +import { + isSerializedAiSdkAPICallError, + SerializedAiSdkAPICallError, + SerializedAiSdkError, + SerializedError +} from '@renderer/types/error' +import { AISDKError, APICallError } from 'ai' import { t } from 'i18next' import z from 'zod' -// const logger = loggerService.withContext('Utils:error') +import { safeSerialize } from './serialize' + +const logger = loggerService.withContext('Utils:error') export function getErrorDetails(err: any, seen = new WeakSet()): any { // Handle circular references @@ -52,29 +61,12 @@ export function formatErrorMessage(error: any): string { } } -export function formatMessageError(error: any): Record { - try { - const detailedError = getErrorDetails(error) - delete detailedError?.headers - delete detailedError?.stack - delete detailedError?.request_id - return detailedError - } catch (e) { - try { - return { message: String(error) } - } catch { - return { message: 'Error: Unable to format error message' } - } - } -} - -export function getErrorMessage(error: any): string { - return 
error?.message || error?.toString() || '' -} - export const isAbortError = (error: any): boolean => { + // Convert message to string for consistent checking + const errorMessage = String(error?.message || '') + // 检查错误消息 - if (error?.message === 'Request was aborted.') { + if (errorMessage === 'Request was aborted.') { return true } @@ -87,7 +79,8 @@ export const isAbortError = (error: any): boolean => { if ( error && typeof error === 'object' && - (error.message === 'Request was aborted.' || error?.message?.includes('signal is aborted without reason')) + errorMessage && + (errorMessage === 'Request was aborted.' || errorMessage.includes('signal is aborted without reason')) ) { return true } @@ -102,6 +95,34 @@ export const formatMcpError = (error: any) => { return error.message } +export const serializeError = (error: AISDKError): SerializedError => { + const baseError = { + name: error.name, + message: error.message, + stack: error.stack ?? null, + cause: error.cause ? String(error.cause) : null + } + if (APICallError.isInstance(error)) { + let content = error.message === '' ? error.responseBody || 'Unknown error' : error.message + try { + const obj = JSON.parse(content) + content = obj.error.message + } catch (e: any) { + logger.warn('Error parsing error response body:', e) + } + return { + ...baseError, + url: error.url, + requestBodyValues: safeSerialize(error.requestBodyValues), + statusCode: error.statusCode ?? null, + responseBody: content, + isRetryable: error.isRetryable, + data: safeSerialize(error.data), + responseHeaders: error.responseHeaders ?? null + } satisfies SerializedAiSdkAPICallError + } + return baseError +} /** * 格式化 Zod 验证错误信息为可读的字符串 * @param error - Zod 验证错误对象 @@ -113,3 +134,102 @@ export const formatZodError = (error: z.ZodError, title?: string) => { const errorMessage = readableErrors.join('\n') return title ? 
`${title}: \n${errorMessage}` : errorMessage } + +/** + * 将任意值安全地转换为字符串 + * @param value - 需要转换的值,unknown 类型 + * @returns 转换后的字符串 + * + * @description + * 该函数可以安全地处理以下情况: + * - null 和 undefined 会被转换为 'null' + * - 字符串直接返回 + * - 原始类型(数字、布尔值、bigint等)使用 String() 转换 + * - 对象和数组会尝试使用 JSON.stringify 序列化,并处理循环引用 + * - 如果序列化失败,返回错误信息 + * + * @example + * ```ts + * safeToString(null) // 'null' + * safeToString('test') // 'test' + * safeToString(123) // '123' + * safeToString({a: 1}) // '{"a":1}' + * ``` + */ +export function safeToString(value: unknown): string { + // 处理 null 和 undefined + if (value == null) { + return 'null' + } + + // 字符串直接返回 + if (typeof value === 'string') { + return value + } + + // 数字、布尔值、bigint 等原始类型,安全用 String() + if (typeof value !== 'object' && typeof value !== 'function') { + return String(value) + } + + // 处理对象(包括数组) + if (typeof value === 'object') { + // 处理函数 + if (typeof value === 'function') { + return value.toString() + } + // 其他对象 + try { + return JSON.stringify(value, getCircularReplacer()) + } catch (err) { + return '[Unserializable: ' + err + ']' + } + } + + return String(value) +} + +// 防止循环引用导致的 JSON.stringify 崩溃 +function getCircularReplacer() { + const seen = new WeakSet() + return (_key: string, value: unknown) => { + if (typeof value === 'object' && value !== null) { + if (seen.has(value)) { + return '[Circular]' + } + seen.add(value) + } + return value + } +} + +export function formatError(error: SerializedError): string { + return `${t('error.name')}: ${error.name}\n${t('error.message')}: ${error.message}\n${t('error.stack')}: ${error.stack}` +} + +export function formatAiSdkError(error: SerializedAiSdkError): string { + let text = formatError(error) + '\n' + if (error.cause) { + text += `${t('error.cause')}: ${error.cause}\n` + } + if (isSerializedAiSdkAPICallError(error)) { + if (error.statusCode) { + text += `${t('error.statusCode')}: ${error.statusCode}\n` + } + text += `${t('error.requestUrl')}: ${error.url}\n` + const 
requestBodyValues = safeToString(error.requestBodyValues) + text += `${t('error.requestBodyValues')}: ${requestBodyValues}\n` + if (error.responseHeaders) { + text += `${t('error.responseHeaders')}: ${JSON.stringify(error.responseHeaders, null, 2)}\n` + } + if (error.responseBody) { + text += `${t('error.responseBody')}: ${error.responseBody}\n` + } + if (error.data) { + const data = safeToString(error.data) + text += `${t('error.data')}: ${data}\n` + } + } + + return text.trim() +} diff --git a/src/renderer/src/utils/mcp-tools.ts b/src/renderer/src/utils/mcp-tools.ts index a399e6dc65..c9ee86a76b 100644 --- a/src/renderer/src/utils/mcp-tools.ts +++ b/src/renderer/src/utils/mcp-tools.ts @@ -326,7 +326,11 @@ export function isToolAutoApproved(tool: MCPTool, server?: MCPServer): boolean { return effectiveServer ? !effectiveServer.disabledAutoApproveTools?.includes(tool.name) : false } -export function parseToolUse(content: string, mcpTools: MCPTool[], startIdx: number = 0): ToolUseResponse[] { +export function parseToolUse( + content: string, + mcpTools: MCPTool[], + startIdx: number = 0 +): (Omit & { tool: MCPTool })[] { if (!content || !mcpTools || mcpTools.length === 0) { return [] } @@ -344,7 +348,7 @@ export function parseToolUse(content: string, mcpTools: MCPTool[], startIdx: num const toolUsePattern = /([\s\S]*?)([\s\S]*?)<\/name>([\s\S]*?)([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g - const tools: ToolUseResponse[] = [] + const tools: (Omit & { tool: MCPTool })[] = [] let match let idx = startIdx // Find all tool use blocks @@ -821,10 +825,26 @@ export function mcpToolCallResponseToAwsBedrockMessage( return message } -export function isEnabledToolUse(assistant: Assistant) { +/** + * 是否启用工具使用 + * 1. 如果模型支持函数调用,则启用工具使用 + * 2. 
如果工具使用模式为 prompt,则启用工具使用 + * @param assistant + * @returns 是否启用工具使用 + */ +export function isSupportedToolUse(assistant: Assistant) { if (assistant.model) { return isFunctionCallingModel(assistant.model) && isToolUseModeFunction(assistant) } return false } + +/** + * 是否使用提示词工具使用 + * @param assistant + * @returns 是否使用提示词工具使用 + */ +export function isPromptToolUse(assistant: Assistant) { + return assistant.settings?.toolUseMode === 'prompt' +} diff --git a/src/renderer/src/utils/messageUtils/create.ts b/src/renderer/src/utils/messageUtils/create.ts index e0081c8c6b..99d9c0f095 100644 --- a/src/renderer/src/utils/messageUtils/create.ts +++ b/src/renderer/src/utils/messageUtils/create.ts @@ -1,6 +1,7 @@ import { loggerService } from '@logger' import type { Assistant, FileMetadata, Topic } from '@renderer/types' import { FileTypes } from '@renderer/types' +import { SerializedError } from '@renderer/types/error' import type { BaseMessageBlock, CitationMessageBlock, @@ -135,7 +136,7 @@ export function createThinkingBlock( return { ...baseBlock, content, - thinking_millsec: overrides.thinking_millsec + thinking_millsec: overrides.thinking_millsec || 0 } } @@ -197,7 +198,7 @@ export function createFileBlock( */ export function createErrorBlock( messageId: string, - errorData: Record, + errorData: SerializedError, overrides: Partial> = {} ): ErrorMessageBlock { const baseBlock = createBaseMessageBlock(messageId, MessageBlockType.ERROR, { diff --git a/src/renderer/src/utils/serialize.ts b/src/renderer/src/utils/serialize.ts new file mode 100644 index 0000000000..bc29b994ea --- /dev/null +++ b/src/renderer/src/utils/serialize.ts @@ -0,0 +1,94 @@ +import { isSerializable } from '@renderer/types/serialize' + +/** + * 安全地序列化一个值为 JSON 字符串。 + * 基于 `Serializable` 类型和 `isSerializable` 运行时检查。 + * + * @param value 要序列化的值 + * @param options 配置选项 + * @returns 序列化后的字符串,或 null(如果失败且未抛错) + */ +export function safeSerialize( + value: unknown, + options: { + /** + * 处理不可序列化值的方式: + * - 'error': 
抛出错误 + * - 'omit': 尝试过滤掉非法字段(⚠️ 不支持深度修复,仅顶层判断) + * - 'serialize': 尝试安全转换(如 Date → ISO 字符串) + */ + onError?: 'error' | 'omit' | 'serialize' + + /** + * 是否美化输出 + * @default true + */ + pretty?: boolean + } = {} +): string | null { + const { onError = 'serialize', pretty = true } = options + const space = pretty ? 2 : undefined + + // 1. 如果本身就是合法的 Serializable 值,直接序列化 + if (isSerializable(value)) { + try { + return JSON.stringify(value, null, space) + } catch (err) { + // 理论上不会发生,但以防万一(比如极深嵌套栈溢出) + if (onError === 'error') { + throw new Error(`Failed to stringify serializable value: ${err instanceof Error ? err.message : err}`) + } + return null + } + } + + // 2. 不是可序列化的,根据策略处理 + switch (onError) { + case 'error': + throw new TypeError('Value is not serializable and cannot be safely serialized.') + + case 'omit': + // 注意:这里不能“修复”对象,只能返回 null 表示跳过 + return null + + case 'serialize': { + // 宽容模式:尝试做一些安全转换 + return tryLenientSerialize(value, space) + } + } +} + +/** + * 尽力而为地序列化一个值,即使它不符合 Serializable。 + * 适用于调试、日志等非关键场景。 + */ +function tryLenientSerialize(value: unknown, space?: string | number): string { + const seen = new WeakSet() + + const serialized = JSON.stringify( + value, + (_, val: any) => { + // 处理循环引用 + if (typeof val === 'object' && val !== null) { + if (seen.has(val)) { + return '[Circular]' + } + seen.add(val) + } + + // 处理特殊类型 + if (val instanceof Date) return val.toISOString() + if (val instanceof RegExp) return `{RegExp: "${val.toString()}"}` + if (typeof val === 'function') return `[Function: ${val.name || 'anonymous'}]` + if (typeof val === 'symbol') return `Symbol(${String(val.description)})` + if (val instanceof Map) return Object.fromEntries(val.entries()) + if (val instanceof Set) return Array.from(val) + if (val === undefined) return '[undefined]' + + return val + }, + space + ) + + return serialized +} diff --git a/src/renderer/src/utils/translate.ts b/src/renderer/src/utils/translate.ts index 979f38dc83..89d6726c18 100644 --- 
a/src/renderer/src/utils/translate.ts +++ b/src/renderer/src/utils/translate.ts @@ -1,10 +1,15 @@ import { loggerService } from '@logger' +import { isQwenMTModel } from '@renderer/config/models' +import { LANG_DETECT_PROMPT } from '@renderer/config/prompts' import { builtinLanguages as builtinLanguages, LanguagesEnum, UNKNOWN } from '@renderer/config/translate' import db from '@renderer/databases' -import { fetchLanguageDetection } from '@renderer/services/ApiService' +import i18n from '@renderer/i18n' +import { fetchChatCompletion } from '@renderer/services/ApiService' +import { getDefaultAssistant, getDefaultModel, getQuickModel } from '@renderer/services/AssistantService' import { estimateTextTokens } from '@renderer/services/TokenService' import { getAllCustomLanguages } from '@renderer/services/TranslateService' -import { TranslateLanguage, TranslateLanguageCode } from '@renderer/types' +import { Assistant, TranslateLanguage, TranslateLanguageCode } from '@renderer/types' +import { Chunk, ChunkType } from '@renderer/types/chunk' import { franc } from 'franc-min' import React, { RefObject } from 'react' import { sliceByTokens } from 'tokenx' @@ -55,13 +60,41 @@ export const detectLanguage = async (inputText: string): Promise => { logger.info('Detect language by llm') let detectedLang = '' - await fetchLanguageDetection({ - text: sliceByTokens(inputText, 0, 100), - onResponse: (text) => { - detectedLang = text.replace(/^\s*\n+/g, '') + const text = sliceByTokens(inputText, 0, 100) + + const translateLanguageOptions = await getTranslateOptions() + const listLang = translateLanguageOptions.map((item) => item.langCode) + const listLangText = JSON.stringify(listLang) + + const model = getQuickModel() || getDefaultModel() + if (!model) { + throw new Error(i18n.t('error.model.not_exists')) + } + + if (isQwenMTModel(model)) { + logger.info('QwenMT cannot be used for language detection.') + if (isQwenMTModel(model)) { + throw new 
Error(i18n.t('translate.error.detect.qwen_mt')) } - }) - return detectedLang + } + + const assistant: Assistant = getDefaultAssistant() + + assistant.model = model + assistant.settings = { + temperature: 0.7 + } + assistant.prompt = LANG_DETECT_PROMPT.replace('{{list_lang}}', listLangText).replace('{{input}}', text) + + const onChunk: (chunk: Chunk) => void = (chunk: Chunk) => { + // 你的意思是,虽然写的是delta类型,但其实是完整拼接后的结果? + if (chunk.type === ChunkType.TEXT_DELTA) { + detectedLang = chunk.text + } + } + + await fetchChatCompletion({ prompt: 'follow system prompt', assistant, onChunkReceived: onChunk }) + return detectedLang.trim() } const detectLanguageByFranc = (inputText: string): TranslateLanguageCode => { diff --git a/src/renderer/src/windows/mini/MiniWindowApp.tsx b/src/renderer/src/windows/mini/MiniWindowApp.tsx index 5d19bb16cc..7b7753dba2 100644 --- a/src/renderer/src/windows/mini/MiniWindowApp.tsx +++ b/src/renderer/src/windows/mini/MiniWindowApp.tsx @@ -1,5 +1,6 @@ import '@renderer/databases' +import { ErrorBoundary } from '@renderer/components/ErrorBoundary' import { useSettings } from '@renderer/hooks/useSettings' import store, { persistor } from '@renderer/store' import { message } from 'antd' @@ -44,8 +45,10 @@ function MiniWindow(): React.ReactElement { - {messageContextHolder} - + + {messageContextHolder} + + diff --git a/src/renderer/src/windows/mini/home/HomeWindow.tsx b/src/renderer/src/windows/mini/home/HomeWindow.tsx index a8df8a6387..ca654d9781 100644 --- a/src/renderer/src/windows/mini/home/HomeWindow.tsx +++ b/src/renderer/src/windows/mini/home/HomeWindow.tsx @@ -6,6 +6,7 @@ import { useSettings } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import { fetchChatCompletion } from '@renderer/services/ApiService' import { getDefaultTopic } from '@renderer/services/AssistantService' +import { ConversationService } from '@renderer/services/ConversationService' import { getAssistantMessage, getUserMessage } from 
'@renderer/services/MessagesService' import store, { useAppSelector } from '@renderer/store' import { updateOneBlock, upsertManyBlocks, upsertOneBlock } from '@renderer/store/messageBlock' @@ -262,13 +263,20 @@ const HomeWindow: FC<{ draggable?: boolean }> = ({ draggable = true }) => { } newAssistant.settings.streamOutput = true // 显式关闭这些功能 - // newAssistant.webSearchProviderId = undefined + newAssistant.webSearchProviderId = undefined newAssistant.mcpServers = undefined - // newAssistant.knowledge_bases = undefined + newAssistant.knowledge_bases = undefined + const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel( + messagesForContext, + newAssistant + ) await fetchChatCompletion({ - messages: messagesForContext, + messages: modelMessages, assistant: newAssistant, + options: {}, + topicId, + uiMessages: uiMessages, onChunkReceived: (chunk: Chunk) => { switch (chunk.type) { case ChunkType.THINKING_START: diff --git a/src/renderer/src/windows/selection/action/components/ActionUtils.ts b/src/renderer/src/windows/selection/action/components/ActionUtils.ts index 5f41b9b107..7ef1166258 100644 --- a/src/renderer/src/windows/selection/action/components/ActionUtils.ts +++ b/src/renderer/src/windows/selection/action/components/ActionUtils.ts @@ -1,5 +1,6 @@ import { loggerService } from '@logger' import { fetchChatCompletion } from '@renderer/services/ApiService' +import { ConversationService } from '@renderer/services/ConversationService' import { getAssistantMessage, getUserMessage } from '@renderer/services/MessagesService' import store from '@renderer/store' import { updateOneBlock, upsertManyBlocks, upsertOneBlock } from '@renderer/store/messageBlock' @@ -62,11 +63,14 @@ export const processMessages = async ( // 显式关闭这些功能 newAssistant.webSearchProviderId = undefined newAssistant.mcpServers = undefined - // newAssistant.knowledge_bases = undefined + newAssistant.knowledge_bases = undefined + const { modelMessages, uiMessages } = await 
ConversationService.prepareMessagesForModel([userMessage], newAssistant) await fetchChatCompletion({ - messages: [userMessage], + messages: modelMessages, assistant: newAssistant, + options: {}, + uiMessages: uiMessages, onChunkReceived: (chunk: Chunk) => { if (finished) { return diff --git a/src/renderer/src/windows/selection/action/components/__tests__/ActionUtils.test.ts b/src/renderer/src/windows/selection/action/components/__tests__/ActionUtils.test.ts index 3431dd9440..5e02b813ca 100644 --- a/src/renderer/src/windows/selection/action/components/__tests__/ActionUtils.test.ts +++ b/src/renderer/src/windows/selection/action/components/__tests__/ActionUtils.test.ts @@ -1,7 +1,6 @@ import type { Assistant, Topic } from '@renderer/types' import { ChunkType } from '@renderer/types/chunk' import { AssistantMessageStatus, MessageBlockStatus } from '@renderer/types/newMessage' -import OpenAI from 'openai' import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest' import { processMessages } from '../ActionUtils' @@ -11,6 +10,17 @@ vi.mock('@renderer/services/ApiService', () => ({ fetchChatCompletion: vi.fn() })) +vi.mock('@renderer/services/ConversationService', () => ({ + ConversationService: class { + static async prepareMessagesForModel() { + return { + modelMessages: [{ role: 'user', content: 'test prompt' }], + uiMessages: [{ id: 'user-message-1', role: 'user', content: 'test prompt' }] + } + } + } +})) + vi.mock('@renderer/services/MessagesService', () => ({ getUserMessage: vi.fn(), getAssistantMessage: vi.fn() @@ -51,6 +61,20 @@ vi.mock('@renderer/utils/messageUtils/create', () => ({ createErrorBlock: vi.fn() })) +vi.mock('@renderer/config/models', () => ({ + SYSTEM_MODELS: { + defaultModel: [ + { id: 'gpt-4', name: 'GPT-4' }, + { id: 'gpt-4', name: 'GPT-4' }, + { id: 'gpt-4', name: 'GPT-4' } + ], + silicon: [], + openai: [], + anthropic: [], + gemini: [] + } +})) + // Import mocked modules import { fetchChatCompletion } from 
'@renderer/services/ApiService' import { getAssistantMessage, getUserMessage } from '@renderer/services/MessagesService' @@ -165,28 +189,6 @@ describe('processMessages', () => { for (const chunk of mockChunks) { await onChunkReceived(chunk) } - const rawOutput: OpenAI.ChatCompletion = { - id: 'test-id', - model: 'test-model', - object: 'chat.completion', - created: Date.now(), - choices: [ - { - index: 0, - message: { - role: 'assistant', - content: 'Here is my answer to your question.', - refusal: '' - }, - finish_reason: 'stop', - logprobs: null - } - ] - } - return { - rawOutput, - getText: () => 'Here is my answer to your question.' - } }) await processMessages( @@ -295,28 +297,6 @@ describe('processMessages', () => { for (const chunk of mockChunks) { await onChunkReceived(chunk) } - const rawOutput: OpenAI.ChatCompletion = { - id: 'test-id', - model: 'test-model', - object: 'chat.completion', - created: Date.now(), - choices: [ - { - index: 0, - message: { - role: 'assistant', - content: 'Partial response', - refusal: '' - }, - finish_reason: 'stop', - logprobs: null - } - ] - } - return { - rawOutput, - getText: () => 'Partial response' - } }) await processMessages( @@ -405,28 +385,6 @@ describe('processMessages', () => { for (const chunk of mockChunks) { await onChunkReceived(chunk) } - const rawOutput: OpenAI.ChatCompletion = { - id: 'test-id', - model: 'test-model', - object: 'chat.completion', - created: Date.now(), - choices: [ - { - index: 0, - message: { - role: 'assistant', - content: 'Partial', - refusal: '' - }, - finish_reason: 'stop', - logprobs: null - } - ] - } - return { - rawOutput, - getText: () => 'Partial' - } }) await processMessages( @@ -567,28 +525,6 @@ describe('processMessages', () => { for (const chunk of mockChunks) { await onChunkReceived(chunk) } - const rawOutput: OpenAI.ChatCompletion = { - id: 'test-id', - model: 'test-model', - object: 'chat.completion', - created: Date.now(), - choices: [ - { - index: 0, - message: { - role: 
'assistant', - content: 'Second text', - refusal: '' - }, - finish_reason: 'stop', - logprobs: null - } - ] - } - return { - rawOutput, - getText: () => 'Second text' - } }) await processMessages( diff --git a/tsconfig.web.json b/tsconfig.web.json index cb2ac3c289..5936bfaa03 100644 --- a/tsconfig.web.json +++ b/tsconfig.web.json @@ -7,6 +7,7 @@ "packages/shared/**/*", "tests/__mocks__/**/*", "packages/mcp-trace/**/*", + "packages/aiCore/src/**/*", "src/main/integration/cherryin/index.js", "packages/extension-table-plus/**/*" ], @@ -23,6 +24,10 @@ "@shared/*": ["packages/shared/*"], "@types": ["src/renderer/src/types/index.ts"], "@mcp-trace/*": ["packages/mcp-trace/*"], + "@cherrystudio/ai-core/provider": ["packages/aiCore/src/core/providers/index.ts"], + "@cherrystudio/ai-core/built-in/plugins": ["packages/aiCore/src/core/plugins/built-in/index.ts"], + "@cherrystudio/ai-core/*": ["packages/aiCore/src/*"], + "@cherrystudio/ai-core": ["packages/aiCore/src/index.ts"], "@cherrystudio/extension-table-plus": ["packages/extension-table-plus/src/index.ts"] }, "experimentalDecorators": true, diff --git a/yarn.lock b/yarn.lock index a8fdaab0af..04ea80ce6c 100644 --- a/yarn.lock +++ b/yarn.lock @@ -74,6 +74,248 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/amazon-bedrock@npm:^3.0.0": + version: 3.0.8 + resolution: "@ai-sdk/amazon-bedrock@npm:3.0.8" + dependencies: + "@ai-sdk/anthropic": "npm:2.0.4" + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.3" + "@smithy/eventstream-codec": "npm:^4.0.1" + "@smithy/util-utf8": "npm:^4.0.0" + aws4fetch: "npm:^1.0.20" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/d7b303b8581e9d28e9ac375b3718ef3f7fff3353d18185870f0b90fd542eb9398d029768502981e9e45a6b64137a7029f591993afd0b18e9ef74525f625524f7 + languageName: node + linkType: hard + +"@ai-sdk/anthropic@npm:2.0.4": + version: 2.0.4 + resolution: "@ai-sdk/anthropic@npm:2.0.4" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + 
"@ai-sdk/provider-utils": "npm:3.0.3" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/2e5a997b6e2d9a2964c4681418643fd2f347df78ac1f9677a0cc6a3a3454920d05c663e35521d8922f0a382ec77a25e4b92204b3760a1da05876bf00d41adc39 + languageName: node + linkType: hard + +"@ai-sdk/anthropic@npm:^2.0.5": + version: 2.0.5 + resolution: "@ai-sdk/anthropic@npm:2.0.5" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.4" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/aaca0d4b2e00715c513a7c688d6b6116eaf29d1d37f005c150f1229200713fb1c393c81a8b01ac29af954fb1ee213f3a537861227051865abe51aa547dca364e + languageName: node + linkType: hard + +"@ai-sdk/azure@npm:^2.0.16": + version: 2.0.16 + resolution: "@ai-sdk/azure@npm:2.0.16" + dependencies: + "@ai-sdk/openai": "npm:2.0.16" + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.4" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/49bd9d27cba3104ba5d8a82c70a16dd475572585c5187e5bc29c9d46a30a373338181b29f37dfe9f61f50b5b82e86808139c93da225eb1721cb15e1a8b97cceb + languageName: node + linkType: hard + +"@ai-sdk/deepseek@npm:^1.0.9": + version: 1.0.9 + resolution: "@ai-sdk/deepseek@npm:1.0.9" + dependencies: + "@ai-sdk/openai-compatible": "npm:1.0.9" + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.4" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/b02a000a98a6df9808d472bf63640ee96297f9acce7422de0d198ffda40edcbcadc0946ae383464b80a92ac033a3a61cf71fa1bc640c08cac589bebc8d5623b9 + languageName: node + linkType: hard + +"@ai-sdk/gateway@npm:1.0.15": + version: 1.0.15 + resolution: "@ai-sdk/gateway@npm:1.0.15" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.7" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/cdd09f119d6618f00c363a27f51dc466a8a64f57f01bcdd127030a804825bd143b0fef2dbdb7802530865d474f4b9d55855670fecd7f2e6c615a5d9ac9fd6e3b + languageName: node + linkType: hard + 
+"@ai-sdk/google-vertex@npm:^3.0.0": + version: 3.0.9 + resolution: "@ai-sdk/google-vertex@npm:3.0.9" + dependencies: + "@ai-sdk/anthropic": "npm:2.0.4" + "@ai-sdk/google": "npm:2.0.6" + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.3" + google-auth-library: "npm:^9.15.0" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/c6584b877f9e20a10dd7d92fc4cb1b4a9838510aa89734cf1ff2faa74ba820b976d3359d4eadcb6035c8911973300efb157931fa0d1105abc8db36f94544cc88 + languageName: node + linkType: hard + +"@ai-sdk/google@npm:2.0.6": + version: 2.0.6 + resolution: "@ai-sdk/google@npm:2.0.6" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.3" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/ad54dd4168df62851646bec3ac2e5cf9e39f3def3e9017579aef5c8e8ecdf57c150c67a80cad4d092c3df69cd8539bc1792adb6c311ed095f8261673b7812e98 + languageName: node + linkType: hard + +"@ai-sdk/google@npm:^2.0.7": + version: 2.0.7 + resolution: "@ai-sdk/google@npm:2.0.7" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.4" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/bde4c95a2a167355cda18de9d5b273d562d2a724f650ca69016daa8df2766280487e143cf0cdd96f6654c255d587a680c6a937b280eb734ca2c35d6f9b9e943c + languageName: node + linkType: hard + +"@ai-sdk/mistral@npm:^2.0.0": + version: 2.0.4 + resolution: "@ai-sdk/mistral@npm:2.0.4" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.3" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/cca88cba855d4952551ca0be748e21f0d1b54537d0c7e08f30facdfbdbac7e6894ff4a1ceb53657aaf6e4380bbaa39d3cc37d1f734d777cdc1caba004c87221f + languageName: node + linkType: hard + +"@ai-sdk/openai-compatible@npm:1.0.9, @ai-sdk/openai-compatible@npm:^1.0.9": + version: 1.0.9 + resolution: "@ai-sdk/openai-compatible@npm:1.0.9" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.4" + peerDependencies: 
+ zod: ^3.25.76 || ^4 + checksum: 10c0/a98505438f7a4c0d5c1aee9fb03aae00ff726c1c5ba0eff45d00ddc30ab9f25de634fcfd111a634bd654042150b9f16a131ce3f45887f9661c0241e3807d6ad4 + languageName: node + linkType: hard + +"@ai-sdk/openai@npm:2.0.16": + version: 2.0.16 + resolution: "@ai-sdk/openai@npm:2.0.16" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.4" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/1ea694bd096175a67a383e73fd1f4434eeaa7ddc6c378e44f295333d9a7b4153251d405dac2d8da330f95e4d5ef58641cc8533a3e63ff4d250b3cbc66f9abfea + languageName: node + linkType: hard + +"@ai-sdk/openai@npm:^2.0.19": + version: 2.0.19 + resolution: "@ai-sdk/openai@npm:2.0.19" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.5" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/04db695669d783a810b80283e0cd48f6e7654667fd76ca2d35c7cffae6fdd68fb0473118e4e097ef1352f4432dd7c15c07f873d712b940c72495e5839b0ede98 + languageName: node + linkType: hard + +"@ai-sdk/provider-utils@npm:3.0.3": + version: 3.0.3 + resolution: "@ai-sdk/provider-utils@npm:3.0.3" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@standard-schema/spec": "npm:^1.0.0" + eventsource-parser: "npm:^3.0.3" + zod-to-json-schema: "npm:^3.24.1" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/f02e26a6b85ef728862505b150475ef2e52d60130ca64b23316ff7b952f1817b01f959b9e48819dad64d82a96ba4ad538610d69dbbfe5be4b4b38469c16a6ccf + languageName: node + linkType: hard + +"@ai-sdk/provider-utils@npm:3.0.4, @ai-sdk/provider-utils@npm:^3.0.4": + version: 3.0.4 + resolution: "@ai-sdk/provider-utils@npm:3.0.4" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@standard-schema/spec": "npm:^1.0.0" + eventsource-parser: "npm:^3.0.3" + zod-to-json-schema: "npm:^3.24.1" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 
10c0/6732b99310561d72262cdeef40cc58190afa55248dca0eb3a378ef87fede12086e534c68687e0fe5ef5b092da41f3e745857ce3f9b248a272a78c0dc268dffd4 + languageName: node + linkType: hard + +"@ai-sdk/provider-utils@npm:3.0.5": + version: 3.0.5 + resolution: "@ai-sdk/provider-utils@npm:3.0.5" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@standard-schema/spec": "npm:^1.0.0" + eventsource-parser: "npm:^3.0.3" + zod-to-json-schema: "npm:^3.24.1" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/4057810b320bda149a178dc1bfc9cdd592ca88b736c3c22bd0c1f8111c75ef69beec4a523f363e5d0d120348b876942fd66c0bb4965864da4c12c5cfddee15a3 + languageName: node + linkType: hard + +"@ai-sdk/provider-utils@npm:3.0.7": + version: 3.0.7 + resolution: "@ai-sdk/provider-utils@npm:3.0.7" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@standard-schema/spec": "npm:^1.0.0" + eventsource-parser: "npm:^3.0.5" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/7e709289f9e514a6ba56a9b19764eb124ea1bd36d4b3b3e455a1c05353674c152839a4d3cd061af7a4cc36106bd15859a2346e54d4ed0a861feec3b2c4c21513 + languageName: node + linkType: hard + +"@ai-sdk/provider@npm:2.0.0, @ai-sdk/provider@npm:^2.0.0": + version: 2.0.0 + resolution: "@ai-sdk/provider@npm:2.0.0" + dependencies: + json-schema: "npm:^0.4.0" + checksum: 10c0/e50e520016c9fc0a8b5009cadd47dae2f1c81ec05c1792b9e312d7d15479f024ca8039525813a33425c884e3449019fed21043b1bfabd6a2626152ca9a388199 + languageName: node + linkType: hard + +"@ai-sdk/xai@npm:^2.0.9": + version: 2.0.9 + resolution: "@ai-sdk/xai@npm:2.0.9" + dependencies: + "@ai-sdk/openai-compatible": "npm:1.0.9" + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.4" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/15a3ace8e06b42ee148d8d100cdf946919e0763c45fb1b85454e313d4de43426c6d162c333d07ad338a9de415dc9e68c50411a6ec0305dbc5edb7d623c2023da + languageName: node + linkType: hard + "@ampproject/remapping@npm:^2.2.0, @ampproject/remapping@npm:^2.3.0": version: 2.3.0 
resolution: "@ampproject/remapping@npm:2.3.0" @@ -1950,7 +2192,7 @@ __metadata: languageName: node linkType: hard -"@babel/types@npm:^7.28.2": +"@babel/types@npm:^7.28.1, @babel/types@npm:^7.28.2": version: 7.28.2 resolution: "@babel/types@npm:7.28.2" dependencies: @@ -1993,6 +2235,28 @@ __metadata: languageName: node linkType: hard +"@cherrystudio/ai-core@workspace:*, @cherrystudio/ai-core@workspace:packages/aiCore": + version: 0.0.0-use.local + resolution: "@cherrystudio/ai-core@workspace:packages/aiCore" + dependencies: + "@ai-sdk/anthropic": "npm:^2.0.5" + "@ai-sdk/azure": "npm:^2.0.16" + "@ai-sdk/deepseek": "npm:^1.0.9" + "@ai-sdk/google": "npm:^2.0.7" + "@ai-sdk/openai": "npm:^2.0.19" + "@ai-sdk/openai-compatible": "npm:^1.0.9" + "@ai-sdk/provider": "npm:^2.0.0" + "@ai-sdk/provider-utils": "npm:^3.0.4" + "@ai-sdk/xai": "npm:^2.0.9" + tsdown: "npm:^0.12.9" + typescript: "npm:^5.0.0" + vitest: "npm:^3.2.4" + zod: "npm:^3.25.0" + peerDependencies: + ai: ^5.0.26 + languageName: unknown + linkType: soft + "@cherrystudio/embedjs-interfaces@npm:0.1.30": version: 0.1.30 resolution: "@cherrystudio/embedjs-interfaces@npm:0.1.30" @@ -4921,6 +5185,16 @@ __metadata: languageName: node linkType: hard +"@openrouter/ai-sdk-provider@npm:^1.1.2": + version: 1.1.2 + resolution: "@openrouter/ai-sdk-provider@npm:1.1.2" + peerDependencies: + ai: ^5.0.0 + zod: ^3.24.1 || ^v4 + checksum: 10c0/1ad50804189910d52c2c10e479bec40dfbd2109820e43135d001f4f8706be6ace532d4769a8c30111f5870afdfa97b815c7334b2e4d8d36ca68b1578ce5d9a41 + languageName: node + linkType: hard + "@opentelemetry/api-logs@npm:0.200.0": version: 0.200.0 resolution: "@opentelemetry/api-logs@npm:0.200.0" @@ -4930,7 +5204,7 @@ __metadata: languageName: node linkType: hard -"@opentelemetry/api@npm:^1.3.0, @opentelemetry/api@npm:^1.9.0": +"@opentelemetry/api@npm:1.9.0, @opentelemetry/api@npm:^1.3.0, @opentelemetry/api@npm:^1.9.0": version: 1.9.0 resolution: "@opentelemetry/api@npm:1.9.0" checksum: 
10c0/9aae2fe6e8a3a3eeb6c1fdef78e1939cf05a0f37f8a4fae4d6bf2e09eb1e06f966ece85805626e01ba5fab48072b94f19b835449e58b6d26720ee19a58298add @@ -5113,9 +5387,9 @@ __metadata: linkType: hard "@opentelemetry/semantic-conventions@npm:^1.29.0": - version: 1.34.0 - resolution: "@opentelemetry/semantic-conventions@npm:1.34.0" - checksum: 10c0/a51a32a5cf5c803bd2125a680d0abacbff632f3b255d0fe52379dac191114a0e8d72a34f9c46c5483ccfe91c4061c309f3cf61a19d11347e2a69779e82cfefd0 + version: 1.36.0 + resolution: "@opentelemetry/semantic-conventions@npm:1.36.0" + checksum: 10c0/edc8a6fe3ec4fc0c67ba3a92b86fb3dcc78fe1eb4f19838d8013c3232b9868540a034dd25cfe0afdd5eae752c5f0e9f42272ff46da144a2d5b35c644478e1c62 languageName: node linkType: hard @@ -5826,6 +6100,13 @@ __metadata: languageName: node linkType: hard +"@rolldown/pluginutils@npm:1.0.0-beta.27": + version: 1.0.0-beta.27 + resolution: "@rolldown/pluginutils@npm:1.0.0-beta.27" + checksum: 10c0/9658f235b345201d4f6bfb1f32da9754ca164f892d1cb68154fe5f53c1df42bd675ecd409836dff46884a7847d6c00bdc38af870f7c81e05bba5c2645eb4ab9c + languageName: node + linkType: hard + "@rolldown/pluginutils@npm:1.0.0-beta.34": version: 1.0.0-beta.34 resolution: "@rolldown/pluginutils@npm:1.0.0-beta.34" @@ -6098,6 +6379,18 @@ __metadata: languageName: node linkType: hard +"@smithy/eventstream-codec@npm:^4.0.1": + version: 4.0.5 + resolution: "@smithy/eventstream-codec@npm:4.0.5" + dependencies: + "@aws-crypto/crc32": "npm:5.2.0" + "@smithy/types": "npm:^4.3.2" + "@smithy/util-hex-encoding": "npm:^4.0.0" + tslib: "npm:^2.6.2" + checksum: 10c0/d94928e22468cb6e6d09bdc8a6ee04f05947c141c0b040aa90e95b6edc123ba03a562ff3994b5827c57295981183325ed8e8f6c60448a4eec392227735e86d62 + languageName: node + linkType: hard + "@smithy/eventstream-codec@npm:^4.0.4": version: 4.0.4 resolution: "@smithy/eventstream-codec@npm:4.0.4" @@ -7055,6 +7348,13 @@ __metadata: languageName: node linkType: hard +"@standard-schema/spec@npm:^1.0.0": + version: 1.0.0 + resolution: 
"@standard-schema/spec@npm:1.0.0" + checksum: 10c0/a1ab9a8bdc09b5b47aa8365d0e0ec40cc2df6437be02853696a0e377321653b0d3ac6f079a8c67d5ddbe9821025584b1fb71d9cc041a6666a96f1fadf2ece15f + languageName: node + linkType: hard + "@strongtz/win32-arm64-msvc@npm:^0.4.7": version: 0.4.7 resolution: "@strongtz/win32-arm64-msvc@npm:0.4.7" @@ -7063,92 +7363,92 @@ __metadata: languageName: node linkType: hard -"@swc/core-darwin-arm64@npm:1.11.21": - version: 1.11.21 - resolution: "@swc/core-darwin-arm64@npm:1.11.21" +"@swc/core-darwin-arm64@npm:1.13.3": + version: 1.13.3 + resolution: "@swc/core-darwin-arm64@npm:1.13.3" conditions: os=darwin & cpu=arm64 languageName: node linkType: hard -"@swc/core-darwin-x64@npm:1.11.21": - version: 1.11.21 - resolution: "@swc/core-darwin-x64@npm:1.11.21" +"@swc/core-darwin-x64@npm:1.13.3": + version: 1.13.3 + resolution: "@swc/core-darwin-x64@npm:1.13.3" conditions: os=darwin & cpu=x64 languageName: node linkType: hard -"@swc/core-linux-arm-gnueabihf@npm:1.11.21": - version: 1.11.21 - resolution: "@swc/core-linux-arm-gnueabihf@npm:1.11.21" +"@swc/core-linux-arm-gnueabihf@npm:1.13.3": + version: 1.13.3 + resolution: "@swc/core-linux-arm-gnueabihf@npm:1.13.3" conditions: os=linux & cpu=arm languageName: node linkType: hard -"@swc/core-linux-arm64-gnu@npm:1.11.21": - version: 1.11.21 - resolution: "@swc/core-linux-arm64-gnu@npm:1.11.21" +"@swc/core-linux-arm64-gnu@npm:1.13.3": + version: 1.13.3 + resolution: "@swc/core-linux-arm64-gnu@npm:1.13.3" conditions: os=linux & cpu=arm64 & libc=glibc languageName: node linkType: hard -"@swc/core-linux-arm64-musl@npm:1.11.21": - version: 1.11.21 - resolution: "@swc/core-linux-arm64-musl@npm:1.11.21" +"@swc/core-linux-arm64-musl@npm:1.13.3": + version: 1.13.3 + resolution: "@swc/core-linux-arm64-musl@npm:1.13.3" conditions: os=linux & cpu=arm64 & libc=musl languageName: node linkType: hard -"@swc/core-linux-x64-gnu@npm:1.11.21": - version: 1.11.21 - resolution: "@swc/core-linux-x64-gnu@npm:1.11.21" 
+"@swc/core-linux-x64-gnu@npm:1.13.3": + version: 1.13.3 + resolution: "@swc/core-linux-x64-gnu@npm:1.13.3" conditions: os=linux & cpu=x64 & libc=glibc languageName: node linkType: hard -"@swc/core-linux-x64-musl@npm:1.11.21": - version: 1.11.21 - resolution: "@swc/core-linux-x64-musl@npm:1.11.21" +"@swc/core-linux-x64-musl@npm:1.13.3": + version: 1.13.3 + resolution: "@swc/core-linux-x64-musl@npm:1.13.3" conditions: os=linux & cpu=x64 & libc=musl languageName: node linkType: hard -"@swc/core-win32-arm64-msvc@npm:1.11.21": - version: 1.11.21 - resolution: "@swc/core-win32-arm64-msvc@npm:1.11.21" +"@swc/core-win32-arm64-msvc@npm:1.13.3": + version: 1.13.3 + resolution: "@swc/core-win32-arm64-msvc@npm:1.13.3" conditions: os=win32 & cpu=arm64 languageName: node linkType: hard -"@swc/core-win32-ia32-msvc@npm:1.11.21": - version: 1.11.21 - resolution: "@swc/core-win32-ia32-msvc@npm:1.11.21" +"@swc/core-win32-ia32-msvc@npm:1.13.3": + version: 1.13.3 + resolution: "@swc/core-win32-ia32-msvc@npm:1.13.3" conditions: os=win32 & cpu=ia32 languageName: node linkType: hard -"@swc/core-win32-x64-msvc@npm:1.11.21": - version: 1.11.21 - resolution: "@swc/core-win32-x64-msvc@npm:1.11.21" +"@swc/core-win32-x64-msvc@npm:1.13.3": + version: 1.13.3 + resolution: "@swc/core-win32-x64-msvc@npm:1.13.3" conditions: os=win32 & cpu=x64 languageName: node linkType: hard -"@swc/core@npm:^1.11.21": - version: 1.11.21 - resolution: "@swc/core@npm:1.11.21" +"@swc/core@npm:^1.12.11": + version: 1.13.3 + resolution: "@swc/core@npm:1.13.3" dependencies: - "@swc/core-darwin-arm64": "npm:1.11.21" - "@swc/core-darwin-x64": "npm:1.11.21" - "@swc/core-linux-arm-gnueabihf": "npm:1.11.21" - "@swc/core-linux-arm64-gnu": "npm:1.11.21" - "@swc/core-linux-arm64-musl": "npm:1.11.21" - "@swc/core-linux-x64-gnu": "npm:1.11.21" - "@swc/core-linux-x64-musl": "npm:1.11.21" - "@swc/core-win32-arm64-msvc": "npm:1.11.21" - "@swc/core-win32-ia32-msvc": "npm:1.11.21" - "@swc/core-win32-x64-msvc": "npm:1.11.21" + 
"@swc/core-darwin-arm64": "npm:1.13.3" + "@swc/core-darwin-x64": "npm:1.13.3" + "@swc/core-linux-arm-gnueabihf": "npm:1.13.3" + "@swc/core-linux-arm64-gnu": "npm:1.13.3" + "@swc/core-linux-arm64-musl": "npm:1.13.3" + "@swc/core-linux-x64-gnu": "npm:1.13.3" + "@swc/core-linux-x64-musl": "npm:1.13.3" + "@swc/core-win32-arm64-msvc": "npm:1.13.3" + "@swc/core-win32-ia32-msvc": "npm:1.13.3" + "@swc/core-win32-x64-msvc": "npm:1.13.3" "@swc/counter": "npm:^0.1.3" - "@swc/types": "npm:^0.1.21" + "@swc/types": "npm:^0.1.23" peerDependencies: "@swc/helpers": ">=0.5.17" dependenciesMeta: @@ -7175,7 +7475,7 @@ __metadata: peerDependenciesMeta: "@swc/helpers": optional: true - checksum: 10c0/d37d21bcc8656e1719c262403eb54f3ec7925493642ca17bf4061ddf67cb327ea2718ad1da749b9db0c6e6e3aeb2d9f0e544939688408c4f89d38982c24612d4 + checksum: 10c0/88a04c319082f8ae5e53b7d7a874014600296087cad3e07d0e927088a19ba2e8355cbced7da02476b5f89cc653e26d1e1c44d9f43ef07fb7b74ec4b5f9e95ef6 languageName: node linkType: hard @@ -7186,21 +7486,21 @@ __metadata: languageName: node linkType: hard -"@swc/plugin-styled-components@npm:^7.1.5": - version: 7.1.5 - resolution: "@swc/plugin-styled-components@npm:7.1.5" +"@swc/plugin-styled-components@npm:^8.0.4": + version: 8.0.4 + resolution: "@swc/plugin-styled-components@npm:8.0.4" dependencies: "@swc/counter": "npm:^0.1.3" - checksum: 10c0/abffb0030aeb65bd0ba5be62debd35588e621d50414bd882773d32d8b63b839083cc0c089b1311c84e2068df82d697206d254a1e72ebd1be7a86523dabab98a9 + checksum: 10c0/8c9c133c517133eb1d241cffb44a76bfb526a17c6de6d5d0ddaf9eb5181eebb4798d79e2852947f827c14768327b31d76a41e1f9ff1b4366b3cc7f8435d84b4b languageName: node linkType: hard -"@swc/types@npm:^0.1.21": - version: 0.1.21 - resolution: "@swc/types@npm:0.1.21" +"@swc/types@npm:^0.1.23": + version: 0.1.24 + resolution: "@swc/types@npm:0.1.24" dependencies: "@swc/counter": "npm:^0.1.3" - checksum: 
10c0/2baa89c824426e0de0c84e212278010e2df8dc2d6ffaa6f1e306e1b2930c6404b3d3f8989307e8c42ceb95ac143ab7a80be138af6a014d5c782dce5be94dcd5e + checksum: 10c0/4ca95a338f070f48303e705996bacfc1219f606c45274bed4f6e3488b86b7b20397bd52792e58fdea0fa924fc939695b5eb5ff7f3ff4737382148fe6097e235a languageName: node linkType: hard @@ -9094,13 +9394,14 @@ __metadata: linkType: hard "@vitejs/plugin-react-swc@npm:^3.9.0": - version: 3.9.0 - resolution: "@vitejs/plugin-react-swc@npm:3.9.0" + version: 3.11.0 + resolution: "@vitejs/plugin-react-swc@npm:3.11.0" dependencies: - "@swc/core": "npm:^1.11.21" + "@rolldown/pluginutils": "npm:1.0.0-beta.27" + "@swc/core": "npm:^1.12.11" peerDependencies: - vite: ^4 || ^5 || ^6 - checksum: 10c0/28e99f2833d390982b9ab17ccabe7fa7562c30ea56cf57b34fe97dc401958ce838bfbb8f562f73c2764e041f8672f10ca3872220fd78660fc9cc43b539bb7962 + vite: ^4 || ^5 || ^6 || ^7 + checksum: 10c0/0d12ee81f8c8acb74b35e7acfc45d23ecc2714ab3a0f6060e4bd900a6a739dd5a9be9c6bc842388f3c406f475f2a83e7ff3ade04ec6df9172faa1242e4faa424 languageName: node linkType: hard @@ -9385,12 +9686,16 @@ __metadata: "@agentic/exa": "npm:^7.3.3" "@agentic/searxng": "npm:^7.3.3" "@agentic/tavily": "npm:^7.3.3" + "@ai-sdk/amazon-bedrock": "npm:^3.0.0" + "@ai-sdk/google-vertex": "npm:^3.0.0" + "@ai-sdk/mistral": "npm:^2.0.0" "@ant-design/v5-patch-for-react-19": "npm:^1.0.3" "@anthropic-ai/sdk": "npm:^0.41.0" "@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch" "@aws-sdk/client-bedrock": "npm:^3.840.0" "@aws-sdk/client-bedrock-runtime": "npm:^3.840.0" "@aws-sdk/client-s3": "npm:^3.840.0" + "@cherrystudio/ai-core": "workspace:*" "@cherrystudio/embedjs": "npm:^0.1.31" "@cherrystudio/embedjs-libsql": "npm:^0.1.31" "@cherrystudio/embedjs-loader-csv": "npm:^0.1.31" @@ -9429,6 +9734,7 @@ __metadata: "@mozilla/readability": "npm:^0.6.0" "@napi-rs/system-ocr": 
"patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch" "@notionhq/client": "npm:^2.2.15" + "@openrouter/ai-sdk-provider": "npm:^1.1.2" "@opentelemetry/api": "npm:^1.9.0" "@opentelemetry/core": "npm:2.0.0" "@opentelemetry/exporter-trace-otlp-http": "npm:^0.200.0" @@ -9439,7 +9745,7 @@ __metadata: "@reduxjs/toolkit": "npm:^2.2.5" "@shikijs/markdown-it": "npm:^3.12.0" "@strongtz/win32-arm64-msvc": "npm:^0.4.7" - "@swc/plugin-styled-components": "npm:^7.1.5" + "@swc/plugin-styled-components": "npm:^8.0.4" "@tanstack/react-query": "npm:^5.85.5" "@tanstack/react-virtual": "npm:^3.13.12" "@testing-library/dom": "npm:^10.4.0" @@ -9490,6 +9796,7 @@ __metadata: "@viz-js/lang-dot": "npm:^1.0.5" "@viz-js/viz": "npm:^3.14.0" "@xyflow/react": "npm:^12.4.4" + ai: "npm:^5.0.29" antd: "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch" archiver: "npm:^7.0.1" async-mutex: "npm:^0.5.0" @@ -9715,6 +10022,20 @@ __metadata: languageName: node linkType: hard +"ai@npm:^5.0.29": + version: 5.0.29 + resolution: "ai@npm:5.0.29" + dependencies: + "@ai-sdk/gateway": "npm:1.0.15" + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.7" + "@opentelemetry/api": "npm:1.9.0" + peerDependencies: + zod: ^3.25.76 || ^4 + checksum: 10c0/526cd2fd59b35b19d902665e3dc1ba5a09f2bb1377295d642fb8a33e13a890874e4dd4b49a787de7f31f4ec6b07257be8514efac08f993081daeb430cf2f60ba + languageName: node + linkType: hard + "ajv-formats@npm:^2.1.1": version: 2.1.1 resolution: "ajv-formats@npm:2.1.1" @@ -10225,6 +10546,13 @@ __metadata: languageName: node linkType: hard +"aws4fetch@npm:^1.0.20": + version: 1.0.20 + resolution: "aws4fetch@npm:1.0.20" + checksum: 10c0/a4eac7bd0d1c3e611c17ed1ef41ac0b48c0a8e74a985ad968c071e74d94586d3572edc943b43fa5ca756c686ea73baa2f48e264d657bb8c2e95c8e0037d48a87 + languageName: node + linkType: hard + "axios@npm:^1.7.3": version: 1.8.4 resolution: "axios@npm:1.8.4" @@ -13491,7 +13819,7 @@ __metadata: 
languageName: node linkType: hard -"eventsource-parser@npm:^3.0.0": +"eventsource-parser@npm:^3.0.0, eventsource-parser@npm:^3.0.3": version: 3.0.3 resolution: "eventsource-parser@npm:3.0.3" checksum: 10c0/2594011630efba56cafafc8ed6bd9a50db8f6d5dd62089b0950346e7961828c16efe07a588bdea3ba79e568fd9246c8163824a2ffaade767e1fdb2270c1fae0b @@ -13505,6 +13833,13 @@ __metadata: languageName: node linkType: hard +"eventsource-parser@npm:^3.0.5": + version: 3.0.5 + resolution: "eventsource-parser@npm:3.0.5" + checksum: 10c0/5cb75e3f84ff1cfa1cee6199d4fd430c4544855ab03e953ddbe5927e7b31bc2af3933ab8aba6440ba160ed2c48972b6c317f27b8a1d0764c7b12e34e249de631 + languageName: node + linkType: hard + "eventsource@npm:^3.0.2": version: 3.0.6 resolution: "eventsource@npm:3.0.6" @@ -14526,7 +14861,7 @@ __metadata: languageName: node linkType: hard -"google-auth-library@npm:^9.14.2, google-auth-library@npm:^9.15.1, google-auth-library@npm:^9.4.2": +"google-auth-library@npm:^9.14.2, google-auth-library@npm:^9.15.0, google-auth-library@npm:^9.15.1, google-auth-library@npm:^9.4.2": version: 9.15.1 resolution: "google-auth-library@npm:9.15.1" dependencies: @@ -15796,6 +16131,13 @@ __metadata: languageName: node linkType: hard +"json-schema@npm:^0.4.0": + version: 0.4.0 + resolution: "json-schema@npm:0.4.0" + checksum: 10c0/d4a637ec1d83544857c1c163232f3da46912e971d5bf054ba44fdb88f07d8d359a462b4aec46f2745efbc57053365608d88bc1d7b1729f7b4fc3369765639ed3 + languageName: node + linkType: hard + "json-stable-stringify-without-jsonify@npm:^1.0.1": version: 1.0.1 resolution: "json-stable-stringify-without-jsonify@npm:1.0.1" @@ -21153,6 +21495,34 @@ __metadata: languageName: node linkType: hard +"rolldown-plugin-dts@npm:^0.13.12": + version: 0.13.14 + resolution: "rolldown-plugin-dts@npm:0.13.14" + dependencies: + "@babel/generator": "npm:^7.28.0" + "@babel/parser": "npm:^7.28.0" + "@babel/types": "npm:^7.28.1" + ast-kit: "npm:^2.1.1" + birpc: "npm:^2.5.0" + debug: "npm:^4.4.1" + dts-resolver: 
"npm:^2.1.1" + get-tsconfig: "npm:^4.10.1" + peerDependencies: + "@typescript/native-preview": ">=7.0.0-dev.20250601.1" + rolldown: ^1.0.0-beta.9 + typescript: ^5.0.0 + vue-tsc: ^2.2.0 || ^3.0.0 + peerDependenciesMeta: + "@typescript/native-preview": + optional: true + typescript: + optional: true + vue-tsc: + optional: true + checksum: 10c0/f09da3990a6be11aed07db121439db907251578cd51bff69479186a056524f07a7a4d5b4056ff6bf884c1fdbacc655fcb0c956753f9bf8a73f2c885731bc2e77 + languageName: node + linkType: hard + "rolldown-plugin-dts@npm:^0.15.3": version: 0.15.6 resolution: "rolldown-plugin-dts@npm:0.15.6" @@ -21238,7 +21608,7 @@ __metadata: languageName: node linkType: hard -"rolldown@npm:^1.0.0-beta.31": +"rolldown@npm:^1.0.0-beta.19, rolldown@npm:^1.0.0-beta.31": version: 1.0.0-beta.9-commit.d91dfb5 resolution: "rolldown@npm:1.0.0-beta.9-commit.d91dfb5" dependencies: @@ -22912,6 +23282,46 @@ __metadata: languageName: node linkType: hard +"tsdown@npm:^0.12.9": + version: 0.12.9 + resolution: "tsdown@npm:0.12.9" + dependencies: + ansis: "npm:^4.1.0" + cac: "npm:^6.7.14" + chokidar: "npm:^4.0.3" + debug: "npm:^4.4.1" + diff: "npm:^8.0.2" + empathic: "npm:^2.0.0" + hookable: "npm:^5.5.3" + rolldown: "npm:^1.0.0-beta.19" + rolldown-plugin-dts: "npm:^0.13.12" + semver: "npm:^7.7.2" + tinyexec: "npm:^1.0.1" + tinyglobby: "npm:^0.2.14" + unconfig: "npm:^7.3.2" + peerDependencies: + "@arethetypeswrong/core": ^0.18.1 + publint: ^0.3.0 + typescript: ^5.0.0 + unplugin-lightningcss: ^0.4.0 + unplugin-unused: ^0.5.0 + peerDependenciesMeta: + "@arethetypeswrong/core": + optional: true + publint: + optional: true + typescript: + optional: true + unplugin-lightningcss: + optional: true + unplugin-unused: + optional: true + bin: + tsdown: dist/run.mjs + checksum: 10c0/5dd4842982815181f5a79bc87fff1dd9afc6952aaec065dcececde5ba76887163a01de313272964003ea90df8ac23efdfc8aabb290c5b8f8dae5332e9905c05b + languageName: node + linkType: hard + "tsdown@npm:^0.13.3": version: 0.13.3 resolution: 
"tsdown@npm:0.13.3" @@ -23070,6 +23480,16 @@ __metadata: languageName: node linkType: hard +"typescript@npm:^5.0.0": + version: 5.9.2 + resolution: "typescript@npm:5.9.2" + bin: + tsc: bin/tsc + tsserver: bin/tsserver + checksum: 10c0/cd635d50f02d6cf98ed42de2f76289701c1ec587a363369255f01ed15aaf22be0813226bff3c53e99d971f9b540e0b3cc7583dbe05faded49b1b0bed2f638a18 + languageName: node + linkType: hard + "typescript@npm:^5.4.3, typescript@npm:^5.6.2": version: 5.8.3 resolution: "typescript@npm:5.8.3" @@ -23080,6 +23500,16 @@ __metadata: languageName: node linkType: hard +"typescript@patch:typescript@npm%3A^5.0.0#optional!builtin": + version: 5.9.2 + resolution: "typescript@patch:typescript@npm%3A5.9.2#optional!builtin::version=5.9.2&hash=5786d5" + bin: + tsc: bin/tsc + tsserver: bin/tsserver + checksum: 10c0/34d2a8e23eb8e0d1875072064d5e1d9c102e0bdce56a10a25c0b917b8aa9001a9cf5c225df12497e99da107dc379360bc138163c66b55b95f5b105b50578067e + languageName: node + linkType: hard + "typescript@patch:typescript@npm%3A^5.4.3#optional!builtin, typescript@patch:typescript@npm%3A^5.6.2#optional!builtin": version: 5.8.3 resolution: "typescript@patch:typescript@npm%3A5.8.3#optional!builtin::version=5.8.3&hash=5786d5" @@ -24339,6 +24769,13 @@ __metadata: languageName: node linkType: hard +"zod@npm:^3.25.0": + version: 3.25.76 + resolution: "zod@npm:3.25.76" + checksum: 10c0/5718ec35e3c40b600316c5b4c5e4976f7fee68151bc8f8d90ec18a469be9571f072e1bbaace10f1e85cf8892ea12d90821b200e980ab46916a6166a4260a983c + languageName: node + linkType: hard + "zod@npm:^3.25.74": version: 3.25.74 resolution: "zod@npm:3.25.74"