diff --git a/electron-builder.yml b/electron-builder.yml index e7457ade48..909181956c 100644 --- a/electron-builder.yml +++ b/electron-builder.yml @@ -55,6 +55,9 @@ files: - '!node_modules/selection-hook/prebuilds/**/*' # we rebuild .node, don't use prebuilds - '!node_modules/selection-hook/node_modules' # we don't need what in the node_modules dir - '!node_modules/selection-hook/src' # we don't need source files + - '!node_modules/tesseract.js-core/{tesseract-core.js,tesseract-core.wasm,tesseract-core.wasm.js}' # we don't need source files + - '!node_modules/tesseract.js-core/{tesseract-core-lstm.js,tesseract-core-lstm.wasm,tesseract-core-lstm.wasm.js}' # we don't need source files + - '!node_modules/tesseract.js-core/{tesseract-core-simd-lstm.js,tesseract-core-simd-lstm.wasm,tesseract-core-simd-lstm.wasm.js}' # we don't need source files - '!**/*.{h,iobj,ipdb,tlog,recipe,vcxproj,vcxproj.filters,Makefile,*.Makefile}' # filter .node build files asarUnpack: - resources/** @@ -117,36 +120,25 @@ artifactBuildCompleted: scripts/artifact-build-completed.js releaseInfo: releaseNotes: | 🎉 新增功能: - - 集成全新 AI SDK 架构,提升 AI 提供商管理和工具调用性能 - - 新增 OCR 图像文字识别功能,支持图片转文字和翻译 - - 新增代码工具页面,支持环境变量设置和命令行工具管理 - - 新增内置 Web 搜索工具和记忆搜索工具 - - 新增翻译历史搜索和收藏功能 - - 新增模型筛选和推理缓存功能 - - 新增快速模型检测和语言检测功能 - - 支持多种新模型:Qwen Flash、DeepSeek v3.1、Vertex AI 等 + - 新增错误详情模态框,提供完整的错误信息展示和复制功能 + - 新增错误详情的多语言支持(英语、日语、俄语、中文简繁体) 🔧 优化改进: - - 优化 MCP 服务器列表,新增搜索功能 - - 优化 SVG 预览和 HTML 内容样式渲染 - - 优化代码块识别和语言别名支持 - - 优化拖拽列表组件,移除 antd 依赖 - - 优化选择工具栏样式和验证逻辑 - - 优化历史页面搜索性能和加载状态 - - 优化代理管理器的规则匹配和日志记录 - - 优化翻译服务,减少 token 消耗 + - 升级 AI Core 到 v1.0.0-alpha.11,重构模型解析逻辑 + - 增强温度和 TopP 参数处理,特别针对 Claude 推理努力模型优化 + - 改进提供商配置管理,简化 OpenAI 模式处理和服务层级设置 + - 优化 MCP 工具可见性,增强提示工具支持 + - 重构错误序列化机制,提升类型安全性 + - 优化补全方法,支持开发者模式下的追踪功能 + - 改进提供商初始化逻辑,支持动态注册新的 AI 提供商 🐛 问题修复: - - 修复知识库文档预处理失败问题 - - 修复 Windows 平台代码工具路径空格处理 - - 修复选择文本范围验证和空字符串处理 - - 修复 Web 搜索引用丢失问题 - - 修复全屏模式意外退出问题 - - 修复各种模型 API 兼容性问题 - - 修复文件上传和附件预览相关问题 - - 修复多语言翻译格式和显示问题 + - 
修复错误处理回调中的类型安全问题,使用 AISDKError 类型 + - 修复提供商初始化和配置相关问题 + - 移除过时的模型解析函数,清理废弃代码 + - 修复 Gemini 集成中的提供商配置缺失问题 ⚡ 性能提升: - - 优化消息处理和工具调用性能 - - 改进错误边界和异常处理 - - 优化内存使用和避免内存泄漏 + - 提升模型参数处理效率,优化温度和 TopP 计算逻辑 + - 优化提供商配置加载和初始化性能 + - 改进错误处理性能,减少不必要的错误格式化开销 diff --git a/package.json b/package.json index 4379c5e5af..49ed7eab47 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "CherryStudio", - "version": "1.6.0-beta.1", + "version": "1.6.0-beta.2", "private": true, "description": "A powerful AI assistant for producer.", "main": "./out/main/index.js", @@ -72,6 +72,7 @@ "dependencies": { "@libsql/client": "0.14.0", "@libsql/win32-x64-msvc": "^0.4.7", + "@napi-rs/system-ocr": "^1.0.2", "@strongtz/win32-arm64-msvc": "^0.4.7", "graceful-fs": "^4.2.11", "jsdom": "26.1.0", @@ -123,7 +124,7 @@ "@eslint-react/eslint-plugin": "^1.36.1", "@eslint/js": "^9.22.0", "@google/genai": "patch:@google/genai@npm%3A1.0.1#~/.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch", - "@hello-pangea/dnd": "^16.6.0", + "@hello-pangea/dnd": "^18.0.1", "@kangfenmao/keyv-storage": "^0.1.0", "@langchain/community": "^0.3.36", "@langchain/ollama": "^0.2.1", @@ -140,9 +141,9 @@ "@opentelemetry/sdk-trace-web": "^2.0.0", "@playwright/test": "^1.52.0", "@reduxjs/toolkit": "^2.2.5", - "@shikijs/markdown-it": "^3.9.1", + "@shikijs/markdown-it": "^3.12.0", "@swc/plugin-styled-components": "^8.0.4", - "@tanstack/react-query": "^5.27.0", + "@tanstack/react-query": "^5.85.5", "@tanstack/react-virtual": "^3.13.12", "@testing-library/dom": "^10.4.0", "@testing-library/jest-dom": "^6.6.3", @@ -150,7 +151,6 @@ "@testing-library/user-event": "^14.6.1", "@tryfabric/martian": "^1.2.4", "@types/cli-progress": "^3", - "@types/diff": "^7", "@types/fs-extra": "^11", "@types/lodash": "^4.17.5", "@types/markdown-it": "^14", @@ -188,7 +188,7 @@ "dayjs": "^1.11.11", "dexie": "^4.0.8", "dexie-react-hooks": "^1.1.7", - "diff": "^7.0.0", + "diff": "^8.0.2", "docx": "^9.0.2", "dotenv-cli": "^7.4.2", "electron": 
"37.3.1", @@ -218,14 +218,14 @@ "isbinaryfile": "5.0.4", "jaison": "^2.0.2", "jest-styled-components": "^7.2.0", - "linguist-languages": "^8.0.0", + "linguist-languages": "^8.1.0", "lint-staged": "^15.5.0", "lodash": "^4.17.21", "lru-cache": "^11.1.0", "lucide-react": "^0.525.0", "macos-release": "^3.4.0", "markdown-it": "^14.1.0", - "mermaid": "^11.9.0", + "mermaid": "^11.10.1", "mime": "^4.0.4", "motion": "^12.10.5", "notion-helper": "^1.3.22", @@ -265,7 +265,7 @@ "remove-markdown": "^0.6.2", "rollup-plugin-visualizer": "^5.12.0", "sass": "^1.88.0", - "shiki": "^3.9.1", + "shiki": "^3.12.0", "strict-url-sanitise": "^0.0.1", "string-width": "^7.2.0", "styled-components": "^6.1.11", diff --git a/packages/aiCore/src/core/models/ModelResolver.ts b/packages/aiCore/src/core/models/ModelResolver.ts index 9d336e819e..0f1bde95c6 100644 --- a/packages/aiCore/src/core/models/ModelResolver.ts +++ b/packages/aiCore/src/core/models/ModelResolver.ts @@ -28,8 +28,8 @@ export class ModelResolver { let finalProviderId = fallbackProviderId let model: LanguageModelV2 // 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移) - if (fallbackProviderId === 'openai' && providerOptions?.mode === 'chat') { - finalProviderId = 'openai-chat' + if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') { + finalProviderId = `${fallbackProviderId}-chat` } // 检查是否是命名空间格式 @@ -84,6 +84,7 @@ export class ModelResolver { */ private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 { const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}` + console.log('fullModelId', fullModelId) return globalRegistryManagement.languageModel(fullModelId as any) } diff --git a/packages/aiCore/src/core/providers/registry.ts b/packages/aiCore/src/core/providers/registry.ts index 050333d650..e8cd46770a 100644 --- a/packages/aiCore/src/core/providers/registry.ts +++ b/packages/aiCore/src/core/providers/registry.ts @@ -176,7 +176,7 @@ export function 
registerProvider(providerId: string, provider: any): boolean { // 处理特殊provider逻辑 if (providerId === 'openai') { // 注册默认 openai - globalRegistryManagement.registerProvider('openai', provider, aliases) + globalRegistryManagement.registerProvider(providerId, provider, aliases) // 创建并注册 openai-chat 变体 const openaiChatProvider = customProvider({ @@ -185,7 +185,17 @@ export function registerProvider(providerId: string, provider: any): boolean { languageModel: (modelId: string) => provider.chat(modelId) } }) - globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider) + globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider) + } else if (providerId === 'azure') { + globalRegistryManagement.registerProvider(`${providerId}-chat`, provider, aliases) + // 跟上面相反,creator产出的默认会调用chat + const azureResponsesProvider = customProvider({ + fallbackProvider: { + ...provider, + languageModel: (modelId: string) => provider.responses(modelId) + } + }) + globalRegistryManagement.registerProvider(providerId, azureResponsesProvider) } else { // 其他provider直接注册 globalRegistryManagement.registerProvider(providerId, provider, aliases) diff --git a/packages/aiCore/src/core/providers/schemas.ts b/packages/aiCore/src/core/providers/schemas.ts index f80fdcb3b5..9b5441c8a0 100644 --- a/packages/aiCore/src/core/providers/schemas.ts +++ b/packages/aiCore/src/core/providers/schemas.ts @@ -11,6 +11,39 @@ import { createOpenAICompatible } from '@ai-sdk/openai-compatible' import { createXai } from '@ai-sdk/xai' import * as z from 'zod' +/** + * 基础 Provider IDs + */ +export const baseProviderIds = [ + 'openai', + 'openai-responses', + 'openai-compatible', + 'anthropic', + 'google', + 'xai', + 'azure', + 'deepseek' +] as const + +/** + * 基础 Provider ID Schema + */ +export const baseProviderIdSchema = z.enum(baseProviderIds) + +/** + * 基础 Provider ID + */ +export type BaseProviderId = z.infer + +export const baseProviderSchema = z.object({ + id: 
baseProviderIdSchema, + name: z.string(), + creator: z.function().args(z.any()).returns(z.any()), + supportsImageGeneration: z.boolean() +}) + +export type BaseProvider = z.infer + /** * 基础 Providers 定义 * 作为唯一数据源,避免重复维护 @@ -64,18 +97,7 @@ export const baseProviders = [ creator: createDeepSeek, supportsImageGeneration: false } -] as const - -/** - * 基础 Provider IDs - * 从 baseProviders 动态生成 - */ -export const baseProviderIds = baseProviders.map((p) => p.id) as unknown as readonly [string, ...string[]] - -/** - * 基础 Provider ID Schema - */ -export const baseProviderIdSchema = z.enum(baseProviderIds) +] as const satisfies BaseProvider[] /** * 用户自定义 Provider ID Schema @@ -117,7 +139,6 @@ export const providerConfigSchema = z * Provider ID 类型 - 基于 zod schema 推导 */ export type ProviderId = z.infer -export type BaseProviderId = z.infer export type CustomProviderId = z.infer /** diff --git a/packages/shared/config/languages.ts b/packages/shared/config/languages.ts index 95b8cab587..42a733bc4a 100644 --- a/packages/shared/config/languages.ts +++ b/packages/shared/config/languages.ts @@ -2020,6 +2020,10 @@ export const languages: Record = { extensions: ['.nginx', '.nginxconf', '.vhost'], aliases: ['nginx configuration file'] }, + Nickel: { + type: 'programming', + extensions: ['.ncl'] + }, Nim: { type: 'programming', extensions: ['.nim', '.nim.cfg', '.nimble', '.nimrod', '.nims'] @@ -3061,7 +3065,7 @@ export const languages: Record = { }, SWIG: { type: 'programming', - extensions: ['.i'] + extensions: ['.i', '.swg', '.swig'] }, SystemVerilog: { type: 'programming', diff --git a/src/main/services/CodeToolsService.ts b/src/main/services/CodeToolsService.ts index 6cc8a41b05..256b4dcbd6 100644 --- a/src/main/services/CodeToolsService.ts +++ b/src/main/services/CodeToolsService.ts @@ -421,7 +421,7 @@ end tell` const envPrefix = buildEnvPrefix(false) const command = envPrefix ? 
`${envPrefix} && ${baseCommand}` : baseCommand - const linuxTerminals = ['gnome-terminal', 'konsole', 'xterm', 'x-terminal-emulator'] + const linuxTerminals = ['gnome-terminal', 'konsole', 'deepin-terminal', 'xterm', 'x-terminal-emulator'] let foundTerminal = 'xterm' // Default to xterm for (const terminal of linuxTerminals) { @@ -448,6 +448,9 @@ end tell` } else if (foundTerminal === 'konsole') { terminalCommand = 'konsole' terminalArgs = ['--workdir', directory, '-e', 'bash', '-c', `clear && ${command}; exec bash`] + } else if (foundTerminal === 'deepin-terminal') { + terminalCommand = 'deepin-terminal' + terminalArgs = ['-w', directory, '-e', 'bash', '-c', `clear && ${command}; exec bash`] } else { // Default to xterm terminalCommand = 'xterm' diff --git a/src/main/services/ocr/OcrService.ts b/src/main/services/ocr/OcrService.ts index 6ac8c311e3..0d7383a24a 100644 --- a/src/main/services/ocr/OcrService.ts +++ b/src/main/services/ocr/OcrService.ts @@ -1,7 +1,8 @@ import { loggerService } from '@logger' import { BuiltinOcrProviderIds, OcrHandler, OcrProvider, OcrResult, SupportedOcrFile } from '@types' -import { tesseractService } from './tesseract/TesseractService' +import { systemOcrService } from './builtin/SystemOcrService' +import { tesseractService } from './builtin/TesseractService' const logger = loggerService.withContext('OcrService') @@ -24,7 +25,7 @@ export class OcrService { if (!handler) { throw new Error(`Provider ${provider.id} is not registered`) } - return handler(file) + return handler(file, provider.config) } } @@ -32,3 +33,4 @@ export const ocrService = new OcrService() // Register built-in providers ocrService.register(BuiltinOcrProviderIds.tesseract, tesseractService.ocr.bind(tesseractService)) +ocrService.register(BuiltinOcrProviderIds.system, systemOcrService.ocr.bind(systemOcrService)) diff --git a/src/main/services/ocr/builtin/OcrBaseService.ts b/src/main/services/ocr/builtin/OcrBaseService.ts new file mode 100644 index 
0000000000..9c36e79c3a --- /dev/null +++ b/src/main/services/ocr/builtin/OcrBaseService.ts @@ -0,0 +1,5 @@ +import { OcrHandler } from '@types' + +export abstract class OcrBaseService { + abstract ocr: OcrHandler +} diff --git a/src/main/services/ocr/builtin/SystemOcrService.ts new file mode 100644 index 0000000000..cda52bfec6 --- /dev/null +++ b/src/main/services/ocr/builtin/SystemOcrService.ts @@ -0,0 +1,39 @@ +import { isMac, isWin } from '@main/constant' +import { loadOcrImage } from '@main/utils/ocr' +import { OcrAccuracy, recognize } from '@napi-rs/system-ocr' +import { + ImageFileMetadata, + isImageFileMetadata, + OcrResult, + OcrSystemConfig, + SupportedOcrFile +} from '@types' + +import { OcrBaseService } from './OcrBaseService' + +// const logger = loggerService.withContext('SystemOcrService') +export class SystemOcrService extends OcrBaseService { + constructor() { + super() + if (!isWin && !isMac) { + throw new Error('System OCR is only supported on Windows and macOS') + } + } + + private async ocrImage(file: ImageFileMetadata, options?: OcrSystemConfig): Promise<OcrResult> { + const buffer = await loadOcrImage(file) + const langs = isWin ?
options?.langs : undefined + const result = await recognize(buffer, OcrAccuracy.Accurate, langs) + return { text: result.text } + } + + public ocr = async (file: SupportedOcrFile, options?: OcrSystemConfig): Promise => { + if (isImageFileMetadata(file)) { + return this.ocrImage(file, options) + } else { + throw new Error('Unsupported file type, currently only image files are supported') + } + } +} + +export const systemOcrService = new SystemOcrService() diff --git a/src/main/services/ocr/builtin/TesseractService.ts b/src/main/services/ocr/builtin/TesseractService.ts new file mode 100644 index 0000000000..9fd7bbcf01 --- /dev/null +++ b/src/main/services/ocr/builtin/TesseractService.ts @@ -0,0 +1,115 @@ +import { loggerService } from '@logger' +import { getIpCountry } from '@main/utils/ipService' +import { loadOcrImage } from '@main/utils/ocr' +import { MB } from '@shared/config/constant' +import { ImageFileMetadata, isImageFileMetadata, OcrResult, OcrTesseractConfig, SupportedOcrFile } from '@types' +import { app } from 'electron' +import fs from 'fs' +import { isEqual } from 'lodash' +import path from 'path' +import Tesseract, { createWorker, LanguageCode } from 'tesseract.js' + +import { OcrBaseService } from './OcrBaseService' + +const logger = loggerService.withContext('TesseractService') + +// config +const MB_SIZE_THRESHOLD = 50 +const defaultLangs = ['chi_sim', 'chi_tra', 'eng'] satisfies LanguageCode[] +enum TesseractLangsDownloadUrl { + CN = 'https://gitcode.com/beyondkmp/tessdata-best/releases/download/1.0.0/' +} + +export class TesseractService extends OcrBaseService { + private worker: Tesseract.Worker | null = null + private previousLangs: OcrTesseractConfig['langs'] + + constructor() { + super() + this.previousLangs = {} + } + + async getWorker(options?: OcrTesseractConfig): Promise { + let langsArray: LanguageCode[] + if (options?.langs) { + // TODO: use type safe objectKeys + langsArray = Object.keys(options.langs) as LanguageCode[] + if 
(langsArray.length === 0) { + logger.warn('Empty langs option. Fallback to defaultLangs.') + langsArray = defaultLangs + } + } else { + langsArray = defaultLangs + } + logger.debug('langsArray', langsArray) + if (!this.worker || !isEqual(this.previousLangs, langsArray)) { + if (this.worker) { + await this.dispose() + } + logger.debug('use langsArray to create worker', langsArray) + const langPath = await this._getLangPath() + const cachePath = await this._getCacheDir() + const promise = new Promise((resolve, reject) => { + createWorker(langsArray, undefined, { + langPath, + cachePath, + logger: (m) => logger.debug('From worker', m), + errorHandler: (e) => { + logger.error('Worker Error', e) + reject(e) + } + }) + .then(resolve) + .catch(reject) + }) + this.worker = await promise + } + return this.worker + } + + private async imageOcr(file: ImageFileMetadata, options?: OcrTesseractConfig): Promise { + const worker = await this.getWorker(options) + const stat = await fs.promises.stat(file.path) + if (stat.size > MB_SIZE_THRESHOLD * MB) { + throw new Error(`This image is too large (max ${MB_SIZE_THRESHOLD}MB)`) + } + const buffer = await loadOcrImage(file) + const result = await worker.recognize(buffer) + return { text: result.data.text } + } + + public ocr = async (file: SupportedOcrFile, options?: OcrTesseractConfig): Promise => { + if (!isImageFileMetadata(file)) { + throw new Error('Only image files are supported currently') + } + return this.imageOcr(file, options) + } + + private async _getLangPath(): Promise { + const country = await getIpCountry() + return country.toLowerCase() === 'cn' ? 
TesseractLangsDownloadUrl.CN : '' + } + + private async _getCacheDir(): Promise { + const cacheDir = path.join(app.getPath('userData'), 'tesseract') + // use access to check if the directory exists + if ( + !(await fs.promises + .access(cacheDir, fs.constants.F_OK) + .then(() => true) + .catch(() => false)) + ) { + await fs.promises.mkdir(cacheDir, { recursive: true }) + } + return cacheDir + } + + async dispose(): Promise { + if (this.worker) { + await this.worker.terminate() + this.worker = null + } + } +} + +export const tesseractService = new TesseractService() diff --git a/src/main/services/ocr/tesseract/TesseractService.ts b/src/main/services/ocr/tesseract/TesseractService.ts deleted file mode 100644 index d2ba6d2ed8..0000000000 --- a/src/main/services/ocr/tesseract/TesseractService.ts +++ /dev/null @@ -1,82 +0,0 @@ -import { loggerService } from '@logger' -import { getIpCountry } from '@main/utils/ipService' -import { loadOcrImage } from '@main/utils/ocr' -import { MB } from '@shared/config/constant' -import { ImageFileMetadata, isImageFile, OcrResult, SupportedOcrFile } from '@types' -import { app } from 'electron' -import fs from 'fs' -import path from 'path' -import Tesseract, { createWorker, LanguageCode } from 'tesseract.js' - -const logger = loggerService.withContext('TesseractService') - -// config -const MB_SIZE_THRESHOLD = 50 -const tesseractLangs = ['chi_sim', 'chi_tra', 'eng'] satisfies LanguageCode[] -enum TesseractLangsDownloadUrl { - CN = 'https://gitcode.com/beyondkmp/tessdata/releases/download/4.1.0/', - GLOBAL = 'https://github.com/tesseract-ocr/tessdata/raw/main/' -} - -export class TesseractService { - private worker: Tesseract.Worker | null = null - - async getWorker(): Promise { - if (!this.worker) { - // for now, only support limited languages - this.worker = await createWorker(tesseractLangs, undefined, { - langPath: await this._getLangPath(), - cachePath: await this._getCacheDir(), - gzip: false, - logger: (m) => logger.debug('From 
worker', m) - }) - } - return this.worker - } - - async imageOcr(file: ImageFileMetadata): Promise { - const worker = await this.getWorker() - const stat = await fs.promises.stat(file.path) - if (stat.size > MB_SIZE_THRESHOLD * MB) { - throw new Error(`This image is too large (max ${MB_SIZE_THRESHOLD}MB)`) - } - const buffer = await loadOcrImage(file) - const result = await worker.recognize(buffer) - return { text: result.data.text } - } - - async ocr(file: SupportedOcrFile): Promise { - if (!isImageFile(file)) { - throw new Error('Only image files are supported currently') - } - return this.imageOcr(file) - } - - private async _getLangPath(): Promise { - const country = await getIpCountry() - return country.toLowerCase() === 'cn' ? TesseractLangsDownloadUrl.CN : TesseractLangsDownloadUrl.GLOBAL - } - - private async _getCacheDir(): Promise { - const cacheDir = path.join(app.getPath('userData'), 'tesseract') - // use access to check if the directory exists - if ( - !(await fs.promises - .access(cacheDir, fs.constants.F_OK) - .then(() => true) - .catch(() => false)) - ) { - await fs.promises.mkdir(cacheDir, { recursive: true }) - } - return cacheDir - } - - async dispose(): Promise { - if (this.worker) { - await this.worker.terminate() - this.worker = null - } - } -} - -export const tesseractService = new TesseractService() diff --git a/src/main/utils/ocr.ts b/src/main/utils/ocr.ts index ca63e82f07..446fbe63d6 100644 --- a/src/main/utils/ocr.ts +++ b/src/main/utils/ocr.ts @@ -2,11 +2,12 @@ import { ImageFileMetadata } from '@types' import { readFile } from 'fs/promises' import sharp from 'sharp' -const preprocessImage = async (buffer: Buffer) => { - return await sharp(buffer) +const preprocessImage = async (buffer: Buffer): Promise => { + return sharp(buffer) .grayscale() // 转为灰度 .normalize() .sharpen() + .png({ quality: 100 }) .toBuffer() } @@ -23,5 +24,5 @@ const preprocessImage = async (buffer: Buffer) => { */ export const loadOcrImage = async (file: 
ImageFileMetadata): Promise => { const buffer = await readFile(file.path) - return await preprocessImage(buffer) + return preprocessImage(buffer) } diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts index 0430c750d2..d441f228c1 100644 --- a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts +++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts @@ -28,11 +28,14 @@ export interface CherryStudioChunk { */ export class AiSdkToChunkAdapter { toolCallHandler: ToolCallChunkHandler + private accumulate: boolean | undefined constructor( private onChunk: (chunk: Chunk) => void, - mcpTools: MCPTool[] = [] + mcpTools: MCPTool[] = [], + accumulate?: boolean ) { this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools) + this.accumulate = accumulate } /** @@ -95,7 +98,11 @@ export class AiSdkToChunkAdapter { }) break case 'text-delta': - final.text += chunk.text || '' + if (this.accumulate) { + final.text += chunk.text || '' + } else { + final.text = chunk.text || '' + } this.onChunk({ type: ChunkType.TEXT_DELTA, text: final.text || '' diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index 849ec54c63..9d94fdb544 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -22,19 +22,33 @@ import LegacyAiProvider from './legacy/index' import { CompletionsResult } from './legacy/middleware/schemas' import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder' import { buildPlugins } from './plugins/PluginBuilder' -import { getActualProvider, isModernSdkSupported, providerToAiSdkConfig } from './provider/ProviderConfigProcessor' +import { + getActualProvider, + isModernSdkSupported, + prepareSpecialProviderConfig, + providerToAiSdkConfig +} from './provider/providerConfig' import type { StreamTextParams } from './types' const logger = loggerService.withContext('ModernAiProvider') +export 
type ModernAiProviderConfig = AiSdkMiddlewareConfig & { + assistant: Assistant + // topicId for tracing + topicId?: string + callType: string +} + export default class ModernAiProvider { private legacyProvider: LegacyAiProvider private config: ReturnType private actualProvider: Provider + private model: Model constructor(model: Model, provider?: Provider) { this.actualProvider = provider || getActualProvider(model) this.legacyProvider = new LegacyAiProvider(this.actualProvider) + this.model = model // 只保存配置,不预先创建executor this.config = providerToAiSdkConfig(this.actualProvider, model) @@ -44,16 +58,11 @@ export default class ModernAiProvider { return this.actualProvider } - public async completions( - modelId: string, - params: StreamTextParams, - config: AiSdkMiddlewareConfig & { - assistant: Assistant - // topicId for tracing - topicId?: string - callType: string - } - ) { + public async completions(modelId: string, params: StreamTextParams, config: ModernAiProviderConfig) { + // 准备特殊配置 + await prepareSpecialProviderConfig(this.actualProvider, this.config) + + logger.debug('this.config', this.config) if (config.topicId && getEnableDeveloperMode()) { // TypeScript类型窄化:确保topicId是string类型 const traceConfig = { @@ -69,12 +78,7 @@ export default class ModernAiProvider { private async _completions( modelId: string, params: StreamTextParams, - config: AiSdkMiddlewareConfig & { - assistant: Assistant - // topicId for tracing - topicId?: string - callType: string - } + config: ModernAiProviderConfig ): Promise { // 初始化 provider 到全局管理器 try { @@ -105,12 +109,7 @@ export default class ModernAiProvider { private async _completionsForTrace( modelId: string, params: StreamTextParams, - config: AiSdkMiddlewareConfig & { - assistant: Assistant - // topicId for tracing - topicId: string - callType: string - } + config: ModernAiProviderConfig & { topicId: string } ): Promise { const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}` const traceParams: 
StartSpanParams = { @@ -193,11 +192,7 @@ export default class ModernAiProvider { private async modernCompletions( modelId: string, params: StreamTextParams, - config: AiSdkMiddlewareConfig & { - assistant: Assistant - topicId?: string - callType: string - } + config: ModernAiProviderConfig ): Promise { logger.info('Starting modernCompletions', { modelId, @@ -244,7 +239,8 @@ export default class ModernAiProvider { topicId: config.topicId }) - const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools) + const accumulate = this.model.supported_text_delta !== false // true and undefined + const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate) logger.debug('Final params before streamText', { modelId, @@ -326,12 +322,7 @@ export default class ModernAiProvider { private async modernImageGeneration( modelId: string, params: StreamTextParams, - config: AiSdkMiddlewareConfig & { - assistant: Assistant - // topicId for tracing - topicId?: string - callType: string - } + config: ModernAiProviderConfig ): Promise { const { onChunk } = config diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts index feebf487e0..5649b86984 100644 --- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts @@ -46,6 +46,7 @@ import { EFFORT_RATIO, FileTypes, isSystemProvider, + isTranslateAssistant, MCPCallToolResponse, MCPTool, MCPToolResponse, @@ -54,7 +55,6 @@ import { Provider, SystemProviderIds, ToolCallResponse, - TranslateAssistant, WebSearchSource } from '@renderer/types' import { ChunkType, TextStartChunk, ThinkingStartChunk } from '@renderer/types/chunk' @@ -569,13 +569,18 @@ export class OpenAIAPIClient extends OpenAIBaseClient< const extra_body: Record = {} if (isQwenMTModel(model)) { - const targetLanguage = (assistant as TranslateAssistant).targetLanguage - 
extra_body.translation_options = { - source_lang: 'auto', - target_lang: mapLanguageToQwenMTModel(targetLanguage!) - } - if (!extra_body.translation_options.target_lang) { - throw new Error(t('translate.error.not_supported', { language: targetLanguage?.value })) + if (isTranslateAssistant(assistant)) { + const targetLanguage = assistant.targetLanguage + const translationOptions = { + source_lang: 'auto', + target_lang: mapLanguageToQwenMTModel(targetLanguage) + } as const + if (!translationOptions.target_lang) { + throw new Error(t('translate.error.not_supported', { language: targetLanguage.value })) + } + extra_body.translation_options = translationOptions + } else { + throw new Error(t('translate.error.chat_qwen_mt')) } } diff --git a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts index e57481eda5..3e6ca8cc3d 100644 --- a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts +++ b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts @@ -365,7 +365,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) getMessageContent(userMessage), topicId ) - params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' } + // params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' } } else { // on 模式:根据意图识别结果决定是否添加工具 const needsKnowledgeSearch = diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts index 0aa7583aaf..752450ea0c 100644 --- a/src/renderer/src/aiCore/provider/factory.ts +++ b/src/renderer/src/aiCore/provider/factory.ts @@ -2,7 +2,7 @@ import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } fr import { loggerService } from '@logger' import { Provider } from '@renderer/types' -import { initializeNewProviders } from './providerConfigs' +import { initializeNewProviders } from './providerInitialization' const logger = 
loggerService.withContext('ProviderFactory') diff --git a/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts b/src/renderer/src/aiCore/provider/providerConfig.ts similarity index 79% rename from src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts rename to src/renderer/src/aiCore/provider/providerConfig.ts index a72f7506e5..51f9c52db7 100644 --- a/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -8,6 +8,7 @@ import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { getProviderByModel } from '@renderer/services/AssistantService' import { loggerService } from '@renderer/services/LoggerService' +import store from '@renderer/store' import type { Model, Provider } from '@renderer/types' import { formatApiHost } from '@renderer/utils/api' import { cloneDeep } from 'lodash' @@ -84,7 +85,7 @@ export function providerToAiSdkConfig( baseURL: actualProvider.apiHost, apiKey: actualProvider.apiKey } - // 处理OpenAI模式(简化逻辑) + // 处理OpenAI模式 const extraOptions: any = {} if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) { extraOptions.mode = 'responses' @@ -97,6 +98,26 @@ export function providerToAiSdkConfig( extraOptions.headers = actualProvider.extra_headers } + // copilot + if (actualProvider.id === 'copilot') { + extraOptions.headers = { + ...extraOptions.extra_headers, + 'editor-version': 'vscode/1.97.2', + 'copilot-vision-request': 'true' + } + } + // azure + if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') { + extraOptions.apiVersion = actualProvider.apiVersion + baseConfig.baseURL += '/openai' + if (actualProvider.apiVersion === 'preview') { + extraOptions.mode = 'responses' + } else { + extraOptions.mode = 'chat' + extraOptions.useDeploymentBasedUrls = true + } + } + // 如果AI 
SDK支持该provider,使用原生配置 if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions) @@ -134,3 +155,18 @@ export function isModernSdkSupported(provider: Provider): boolean { // 如果映射到了支持的provider,则支持现代SDK return hasProviderConfig(aiSdkProviderId) } + +/** + * 准备特殊provider的配置,主要用于异步处理的配置 + */ +export async function prepareSpecialProviderConfig( + provider: Provider, + config: ReturnType +) { + if (provider.id === 'copilot') { + const defaultHeaders = store.getState().copilot.defaultHeaders + const { token } = await window.api.copilot.getToken(defaultHeaders) + config.options.apiKey = token + } + return config +} diff --git a/src/renderer/src/aiCore/provider/providerConfigs.ts b/src/renderer/src/aiCore/provider/providerInitialization.ts similarity index 100% rename from src/renderer/src/aiCore/provider/providerConfigs.ts rename to src/renderer/src/aiCore/provider/providerInitialization.ts diff --git a/src/renderer/src/aiCore/tools/KnowledgeSearchTool.ts b/src/renderer/src/aiCore/tools/KnowledgeSearchTool.ts index f0824b1d8d..f156061a8c 100644 --- a/src/renderer/src/aiCore/tools/KnowledgeSearchTool.ts +++ b/src/renderer/src/aiCore/tools/KnowledgeSearchTool.ts @@ -46,9 +46,7 @@ Call this tool to execute the search. You can optionally provide additional cont return { summary: 'No knowledge base configured for this assistant.', knowledgeReferences: [], - sources: '', - instructions: '', - rawResults: [] + instructions: '' } } @@ -69,9 +67,7 @@ Call this tool to execute the search. You can optionally provide additional cont return { summary: 'No search needed based on the query analysis.', knowledgeReferences: [], - sources: '', - instructions: '', - rawResults: [] + instructions: '' } } @@ -109,22 +105,20 @@ Call this tool to execute the search. 
You can optionally provide additional cont file: ref.file })) - const referenceContent = `\`\`\`json\n${JSON.stringify(knowledgeReferencesData, null, 2)}\n\`\`\`` - + // const referenceContent = `\`\`\`json\n${JSON.stringify(knowledgeReferencesData, null, 2)}\n\`\`\`` + // TODO 在工具函数中添加搜索缓存机制 + // const searchCacheKey = `${topicId}-${JSON.stringify(finalQueries)}` + // 可以在插件层面管理已搜索的查询,避免重复搜索 const fullInstructions = REFERENCE_PROMPT.replace( '{question}', "Based on the knowledge references, please answer the user's question with proper citations." - ).replace('{references}', referenceContent) + ).replace('{references}', 'knowledgeReferences:') // 返回结果 return { summary: `Found ${knowledgeReferencesData.length} relevant sources. Use [number] format to cite specific information.`, knowledgeReferences: knowledgeReferencesData, - // sources: citationData - // .map((source) => `[${source.number}] ${source.title}\n${source.content}\nURL: ${source.url}`) - // .join('\n\n'), instructions: fullInstructions - // rawResults: citationData } } catch (error) { // 返回空对象而不是抛出错误,避免中断对话流程 @@ -132,7 +126,6 @@ Call this tool to execute the search. You can optionally provide additional cont summary: `Search failed: ${error instanceof Error ? 
error.message : 'Unknown error'}`, knowledgeReferences: [], instructions: '' - // rawResults: [] } } } diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts index d5fdace254..9e6e1bc8e3 100644 --- a/src/renderer/src/aiCore/transformParameters.ts +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -4,8 +4,8 @@ */ import { loggerService } from '@logger' -import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { + isClaudeReasoningModel, isGenerateImageModel, isNotSupportTemperatureAndTopP, isOpenRouterBuiltInWebSearchModel, @@ -44,15 +44,29 @@ const logger = loggerService.withContext('transformParameters') /** * 获取温度参数 */ -export function getTemperature(assistant: Assistant, model: Model): number | undefined { - return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.temperature +function getTemperature(assistant: Assistant, model: Model): number | undefined { + if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { + return undefined + } + if (isNotSupportTemperatureAndTopP(model)) { + return undefined + } + const assistantSettings = getAssistantSettings(assistant) + return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined } /** * 获取 TopP 参数 */ -export function getTopP(assistant: Assistant, model: Model): number | undefined { - return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP +function getTopP(assistant: Assistant, model: Model): number | undefined { + if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { + return undefined + } + if (isNotSupportTemperatureAndTopP(model)) { + return undefined + } + const assistantSettings = getAssistantSettings(assistant) + return assistantSettings?.enableTopP ? 
assistantSettings?.topP : undefined } /** @@ -360,7 +374,7 @@ export async function buildStreamTextParams( // 构建基础参数 const params: StreamTextParams = { messages: sdkMessages, - maxOutputTokens: maxTokens || DEFAULT_MAX_TOKENS, + maxOutputTokens: maxTokens, temperature: getTemperature(assistant, model), topP: getTopP(assistant, model), abortSignal: options.requestOptions?.signal, @@ -372,7 +386,7 @@ export async function buildStreamTextParams( if (assistant.prompt) { params.system = assistant.prompt } - + logger.debug('params', params) return { params, modelId: model.id, diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index 65890d5ba7..687c9c7a0c 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -1,15 +1,19 @@ -import { isOpenAIModel, isSupportFlexServiceTierModel } from '@renderer/config/models' +import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider' +import { isOpenAIModel, isQwenMTModel, isSupportFlexServiceTierModel } from '@renderer/config/models' import { isSupportServiceTierProvider } from '@renderer/config/providers' +import { mapLanguageToQwenMTModel } from '@renderer/config/translate' import { Assistant, GroqServiceTiers, isGroqServiceTier, isOpenAIServiceTier, + isTranslateAssistant, Model, OpenAIServiceTiers, Provider, SystemProviderIds } from '@renderer/types' +import { t } from 'i18next' import { getAiSdkProviderId } from '../provider/factory' import { buildGeminiGenerateImageParams } from './image' @@ -67,49 +71,77 @@ export function buildProviderOptions( enableGenerateImage: boolean } ): Record { - const providerId = getAiSdkProviderId(actualProvider) - const serviceTierSetting = getServiceTier(model, actualProvider) + const rawProviderId = getAiSdkProviderId(actualProvider) // 构建 provider 特定的选项 let providerSpecificOptions: Record = {} - + const serviceTierSetting = getServiceTier(model, actualProvider) + 
providerSpecificOptions.serviceTier = serviceTierSetting // 根据 provider 类型分离构建逻辑 - switch (providerId) { - case 'openai': - case 'azure': - providerSpecificOptions = { - ...buildOpenAIProviderOptions(assistant, model, capabilities) + const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId) + if (success) { + // 应该覆盖所有类型 + switch (baseProviderId) { + case 'openai': + case 'azure': + providerSpecificOptions = { + ...buildOpenAIProviderOptions(assistant, model, capabilities), + serviceTier: serviceTierSetting + } + break + + case 'anthropic': + providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities) + break + + case 'google': + providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities) + break + + case 'xai': + providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities) + break + case 'deepseek': + case 'openai-compatible': + case 'openai-responses': + // 对于其他 provider,使用通用的构建逻辑 + providerSpecificOptions = { + ...buildGenericProviderOptions(assistant, model, capabilities), + serviceTier: serviceTierSetting + } + break + default: + throw new Error(`Unsupported base provider ${baseProviderId}`) + } + } else { + // 处理自定义 provider + const { data: providerId, success, error } = customProviderIdSchema.safeParse(rawProviderId) + if (success) { + switch (providerId) { + // 非 base provider 的单独处理逻辑 + case 'google-vertex': + providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities) + break + default: + // 对于其他 provider,使用通用的构建逻辑 + providerSpecificOptions = { + ...buildGenericProviderOptions(assistant, model, capabilities), + serviceTier: serviceTierSetting + } } - break - - case 'anthropic': - providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities) - break - - case 'google': - case 'google-vertex': - providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities) - break - - case 'xai': - 
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities) - break - - default: - // 对于其他 provider,使用通用的构建逻辑 - providerSpecificOptions = buildGenericProviderOptions(assistant, model, capabilities) - break + } else { + throw error + } } // 合并自定义参数到 provider 特定的选项中 providerSpecificOptions = { ...providerSpecificOptions, - serviceTier: serviceTierSetting, ...getCustomParameters(assistant) } // 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } return { - [providerId]: providerSpecificOptions + [rawProviderId]: providerSpecificOptions } } @@ -127,7 +159,6 @@ function buildOpenAIProviderOptions( ): Record { const { enableReasoning } = capabilities let providerOptions: Record = {} - // OpenAI 推理参数 if (enableReasoning) { const reasoningParams = getOpenAIReasoningParams(assistant, model) @@ -136,7 +167,6 @@ function buildOpenAIProviderOptions( ...reasoningParams } } - return providerOptions } @@ -253,5 +283,22 @@ function buildGenericProviderOptions( } } + // 特殊处理 Qwen MT + if (isQwenMTModel(model)) { + if (isTranslateAssistant(assistant)) { + const targetLanguage = assistant.targetLanguage + const translationOptions = { + source_lang: 'auto', + target_lang: mapLanguageToQwenMTModel(targetLanguage) + } as const + if (!translationOptions.target_lang) { + throw new Error(t('translate.error.not_supported', { language: targetLanguage.value })) + } + providerOptions.translation_options = translationOptions + } else { + throw new Error(t('translate.error.chat_qwen_mt')) + } + } + return providerOptions } diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts index 07f1958fe9..507b2cd9ce 100644 --- a/src/renderer/src/aiCore/utils/reasoning.ts +++ b/src/renderer/src/aiCore/utils/reasoning.ts @@ -1,8 +1,15 @@ +import { loggerService } from '@logger' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { findTokenLimit, GEMINI_FLASH_MODEL_REGEX, + getThinkModelType, + 
isDeepSeekHybridInferenceModel, isDoubaoThinkingAutoModel, + isGrokReasoningModel, + isOpenAIReasoningModel, + isQwenAlwaysThinkModel, + isQwenReasoningModel, isReasoningModel, isSupportedReasoningEffortGrokModel, isSupportedReasoningEffortModel, @@ -10,15 +17,22 @@ import { isSupportedThinkingTokenClaudeModel, isSupportedThinkingTokenDoubaoModel, isSupportedThinkingTokenGeminiModel, + isSupportedThinkingTokenHunyuanModel, isSupportedThinkingTokenModel, - isSupportedThinkingTokenQwenModel + isSupportedThinkingTokenQwenModel, + isSupportedThinkingTokenZhipuModel, + MODEL_SUPPORTED_REASONING_EFFORT } from '@renderer/config/models' +import { isSupportEnableThinkingProvider } from '@renderer/config/providers' import { getStoreSetting } from '@renderer/hooks/useSettings' import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService' import { SettingsState } from '@renderer/store/settings' -import { Assistant, EFFORT_RATIO, Model } from '@renderer/types' +import { Assistant, EFFORT_RATIO, isSystemProvider, Model, SystemProviderIds } from '@renderer/types' import { ReasoningEffortOptionalParams } from '@renderer/types/sdk' +const logger = loggerService.withContext('reasoning') + +// The function is only for generic provider. 
May extract some logics to independent provider export function getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams { const provider = getProviderByModel(model) if (provider.id === 'groq') { @@ -31,62 +45,35 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin const reasoningEffort = assistant?.settings?.reasoning_effort if (!reasoningEffort) { - if (model.provider === 'openrouter') { - return { reasoning: { enabled: false } } - } - if (isSupportedThinkingTokenQwenModel(model)) { - return { enable_thinking: false } - } - - if (isSupportedThinkingTokenClaudeModel(model)) { - return {} - } - - if (isSupportedThinkingTokenGeminiModel(model)) { - if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) { - return { reasoning_effort: 'none' } - } - return {} - } - - if (isSupportedThinkingTokenDoubaoModel(model)) { - return { thinking: { type: 'disabled' } } - } - - return {} - } - - // Doubao 思考模式支持 - if (isSupportedThinkingTokenDoubaoModel(model)) { - // reasoningEffort 为空,默认开启 enabled - if (!reasoningEffort) { - return { thinking: { type: 'disabled' } } - } - if (reasoningEffort === 'high') { - return { thinking: { type: 'enabled' } } - } - if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) { - return { thinking: { type: 'auto' } } - } - // 其他情况不带 thinking 字段 - return {} - } - - if (!reasoningEffort) { - if (model.provider === 'openrouter') { + // openrouter: use reasoning + if (model.provider === SystemProviderIds.openrouter) { + // Don't disable reasoning for Gemini models that support thinking tokens if (isSupportedThinkingTokenGeminiModel(model) && !GEMINI_FLASH_MODEL_REGEX.test(model.id)) { return {} } + // Don't disable reasoning for models that require it + if (isGrokReasoningModel(model) || isOpenAIReasoningModel(model)) { + return {} + } return { reasoning: { enabled: false, exclude: true } } } - if (isSupportedThinkingTokenQwenModel(model)) { + + // providers that use enable_thinking + if 
( + isSupportEnableThinkingProvider(provider) && + (isSupportedThinkingTokenQwenModel(model) || + isSupportedThinkingTokenHunyuanModel(model) || + (provider.id === SystemProviderIds.dashscope && isDeepSeekHybridInferenceModel(model))) + ) { return { enable_thinking: false } } + // claude if (isSupportedThinkingTokenClaudeModel(model)) { return {} } + // gemini if (isSupportedThinkingTokenGeminiModel(model)) { if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) { return { @@ -102,19 +89,57 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin return {} } - if (isSupportedThinkingTokenDoubaoModel(model)) { + // use thinking, doubao, zhipu, etc. + if (isSupportedThinkingTokenDoubaoModel(model) || isSupportedThinkingTokenZhipuModel(model)) { return { thinking: { type: 'disabled' } } } return {} } - const effortRatio = EFFORT_RATIO[reasoningEffort] - const budgetTokens = Math.floor( - (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min! 
- ) + + // reasoningEffort有效的情况 + // DeepSeek hybrid inference models, v3.1 and maybe more in the future + // 不同的 provider 有不同的思考控制方式,在这里统一解决 + if (isDeepSeekHybridInferenceModel(model)) { + if (isSystemProvider(provider)) { + switch (provider.id) { + case SystemProviderIds.dashscope: + return { + enable_thinking: true, + incremental_output: true + } + case SystemProviderIds.silicon: + return { + enable_thinking: true + } + case SystemProviderIds.doubao: + return { + thinking: { + type: 'enabled' // auto is invalid + } + } + case SystemProviderIds.openrouter: + return { + reasoning: { + enabled: true + } + } + case 'nvidia': + return { + chat_template_kwargs: { + thinking: true + } + } + default: + logger.warn( + `Skipping thinking options for provider ${provider.name} as DeepSeek v3.1 thinking control method is unknown` + ) + } + } + } // OpenRouter models - if (model.provider === 'openrouter') { + if (model.provider === SystemProviderIds.openrouter) { if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) { return { reasoning: { @@ -124,28 +149,75 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin } } - // Qwen models - if (isSupportedThinkingTokenQwenModel(model)) { - return { - enable_thinking: true, + // Doubao 思考模式支持 + if (isSupportedThinkingTokenDoubaoModel(model)) { + // reasoningEffort 为空,默认开启 enabled + if (reasoningEffort === 'high') { + return { thinking: { type: 'enabled' } } + } + if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) { + return { thinking: { type: 'auto' } } + } + // 其他情况不带 thinking 字段 + return {} + } + + const effortRatio = EFFORT_RATIO[reasoningEffort] + const budgetTokens = Math.floor( + (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min! 
+ ) + + // OpenRouter models, use thinking + if (model.provider === SystemProviderIds.openrouter) { + if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) { + return { + reasoning: { + effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort + } + } + } + } + + // Qwen models, use enable_thinking + if (isQwenReasoningModel(model)) { + const thinkConfig = { + enable_thinking: isQwenAlwaysThinkModel(model) || !isSupportEnableThinkingProvider(provider) ? undefined : true, thinking_budget: budgetTokens } + if (provider.id === SystemProviderIds.dashscope) { + return { + ...thinkConfig, + incremental_output: true + } + } + return thinkConfig } - // Grok models - if (isSupportedReasoningEffortGrokModel(model)) { + // Hunyuan models, use enable_thinking + if (isSupportedThinkingTokenHunyuanModel(model) && isSupportEnableThinkingProvider(provider)) { return { - reasoningEffort: reasoningEffort + enable_thinking: true } } - // OpenAI models - if (isSupportedReasoningEffortOpenAIModel(model)) { - return { - reasoningEffort: reasoningEffort + // Grok models/Perplexity models/OpenAI models, use reasoning_effort + if (isSupportedReasoningEffortModel(model)) { + // 检查模型是否支持所选选项 + const modelType = getThinkModelType(model) + const supportedOptions = MODEL_SUPPORTED_REASONING_EFFORT[modelType] + if (supportedOptions.includes(reasoningEffort)) { + return { + reasoning_effort: reasoningEffort + } + } else { + // 如果不支持,fallback到第一个支持的值 + return { + reasoning_effort: supportedOptions[0] + } } } + // gemini series, openai compatible api if (isSupportedThinkingTokenGeminiModel(model)) { if (reasoningEffort === 'auto') { return { @@ -171,7 +243,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin } } - // Claude models + // Claude models, openai compatible api if (isSupportedThinkingTokenClaudeModel(model)) { const maxTokens = assistant.settings?.maxTokens return { @@ -184,7 +256,7 @@ export function 
getReasoningEffort(assistant: Assistant, model: Model): Reasonin } } - // Doubao models + // Use thinking, doubao, zhipu, etc. if (isSupportedThinkingTokenDoubaoModel(model)) { if (assistant.settings?.reasoning_effort === 'high') { return { @@ -194,6 +266,9 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin } } } + if (isSupportedThinkingTokenZhipuModel(model)) { + return { thinking: { type: 'enabled' } } + } // Default case: no special thinking settings return {} diff --git a/src/renderer/src/components/CodeBlockView/HtmlArtifactsCard.tsx b/src/renderer/src/components/CodeBlockView/HtmlArtifactsCard.tsx index acb4a9c4f1..13d13c55a9 100644 --- a/src/renderer/src/components/CodeBlockView/HtmlArtifactsCard.tsx +++ b/src/renderer/src/components/CodeBlockView/HtmlArtifactsCard.tsx @@ -2,7 +2,7 @@ import { CodeOutlined } from '@ant-design/icons' import { loggerService } from '@logger' import { useTheme } from '@renderer/context/ThemeProvider' import { ThemeMode } from '@renderer/types' -import { extractTitle } from '@renderer/utils/formats' +import { extractHtmlTitle, getFileNameFromHtmlTitle } from '@renderer/utils/formats' import { Button } from 'antd' import { Code, DownloadIcon, Globe, LinkIcon, Sparkles } from 'lucide-react' import { FC, useState } from 'react' @@ -28,7 +28,7 @@ const getTerminalStyles = (theme: ThemeMode) => ({ const HtmlArtifactsCard: FC = ({ html, onSave, isStreaming = false }) => { const { t } = useTranslation() - const title = extractTitle(html) || 'HTML Artifacts' + const title = extractHtmlTitle(html) || 'HTML Artifacts' const [isPopupOpen, setIsPopupOpen] = useState(false) const { theme } = useTheme() @@ -48,7 +48,7 @@ const HtmlArtifactsCard: FC = ({ html, onSave, isStreaming = false }) => } const handleDownload = async () => { - const fileName = `${title.replace(/[^a-zA-Z0-9\s]/g, '').replace(/\s+/g, '-') || 'html-artifact'}.html` + const fileName = `${getFileNameFromHtmlTitle(title) || 
'html-artifact'}.html` await window.api.file.save(fileName, htmlContent) window.message.success({ content: t('message.download.success'), key: 'download' }) } diff --git a/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx b/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx index 216e247701..8cdf4e4d45 100644 --- a/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx +++ b/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx @@ -1,9 +1,13 @@ import CodeEditor, { CodeEditorHandles } from '@renderer/components/CodeEditor' +import { CopyIcon, FilePngIcon } from '@renderer/components/Icons' import { isLinux, isMac, isWin } from '@renderer/config/constant' +import { useTemporaryValue } from '@renderer/hooks/useTemporaryValue' import { classNames } from '@renderer/utils' -import { Button, Modal, Splitter, Tooltip, Typography } from 'antd' -import { Code, Eye, Maximize2, Minimize2, SaveIcon, SquareSplitHorizontal, X } from 'lucide-react' -import { useEffect, useRef, useState } from 'react' +import { extractHtmlTitle, getFileNameFromHtmlTitle } from '@renderer/utils/formats' +import { captureScrollableIframeAsBlob, captureScrollableIframeAsDataURL } from '@renderer/utils/image' +import { Button, Dropdown, Modal, Splitter, Tooltip, Typography } from 'antd' +import { Camera, Check, Code, Eye, Maximize2, Minimize2, SaveIcon, SquareSplitHorizontal, X } from 'lucide-react' +import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import styled from 'styled-components' @@ -21,7 +25,9 @@ const HtmlArtifactsPopup: React.FC = ({ open, title, ht const { t } = useTranslation() const [viewMode, setViewMode] = useState('split') const [isFullscreen, setIsFullscreen] = useState(false) + const [saved, setSaved] = useTemporaryValue(false, 2000) const codeEditorRef = useRef(null) + const previewFrameRef = useRef(null) // Prevent body scroll when fullscreen useEffect(() => { @@ -38,8 
+44,32 @@ const HtmlArtifactsPopup: React.FC = ({ open, title, ht const handleSave = () => { codeEditorRef.current?.save?.() + setSaved(true) } + const handleCapture = useCallback( + async (to: 'file' | 'clipboard') => { + const title = extractHtmlTitle(html) + const fileName = getFileNameFromHtmlTitle(title) || 'html-artifact' + + if (to === 'file') { + const dataUrl = await captureScrollableIframeAsDataURL(previewFrameRef) + if (dataUrl) { + window.api.file.saveImage(fileName, dataUrl) + } + } + if (to === 'clipboard') { + await captureScrollableIframeAsBlob(previewFrameRef, async (blob) => { + if (blob) { + await navigator.clipboard.write([new ClipboardItem({ 'image/png': blob })]) + window.message.success(t('message.copy.success')) + } + }) + } + }, + [html, t] + ) + const renderHeader = () => ( setIsFullscreen(!isFullscreen)} className={classNames({ drag: isFullscreen })}> @@ -47,7 +77,7 @@ const HtmlArtifactsPopup: React.FC = ({ open, title, ht - + e.stopPropagation()}> = ({ open, title, ht - + e.stopPropagation()}> + , + onClick: () => handleCapture('file') + }, + { + label: t('html_artifacts.capture.to_clipboard'), + key: 'capture_to_clipboard', + icon: , + onClick: () => handleCapture('clipboard') + } + ] + }}> + +