Merge branch 'main' of github.com:CherryHQ/cherry-studio into wip/data-refactor

Author: fullex
Date: 2025-09-14 17:05:14 +08:00
Commit: 57fd73e51a
356 changed files with 13168 additions and 5263 deletions

View File

@ -1,7 +1,7 @@
name: Auto I18N
env:
API_KEY: ${{ secrets.TRANSLATE_API_KEY}}
API_KEY: ${{ secrets.TRANSLATE_API_KEY }}
MODEL: ${{ vars.MODEL || 'deepseek/deepseek-v3.1'}}
BASE_URL: ${{ vars.BASE_URL || 'https://api.ppinfra.com/openai'}}
@ -35,7 +35,7 @@ jobs:
# Install dependencies in a temporary directory
mkdir -p /tmp/translation-deps
cd /tmp/translation-deps
echo '{"dependencies": {"openai": "^5.12.2", "cli-progress": "^3.12.0", "tsx": "^4.20.3", "prettier": "^3.5.3", "prettier-plugin-sort-json": "^4.1.1"}}' > package.json
echo '{"dependencies": {"openai": "^5.12.2", "cli-progress": "^3.12.0", "tsx": "^4.20.3", "prettier": "^3.5.3", "prettier-plugin-sort-json": "^4.1.1", "prettier-plugin-tailwindcss": "^0.6.14"}}' > package.json
npm install --no-package-lock
# Set NODE_PATH so the project can find these dependencies

View File

@ -2,7 +2,7 @@ name: Claude Code Review
on:
pull_request:
types: [opened, synchronize]
types: [opened]
# Optional: Only run on specific file changes
# paths:
# - "src/**/*.ts"
@ -12,12 +12,11 @@ on:
jobs:
claude-review:
# Optional: Filter by PR author
# if: |
# github.event.pull_request.user.login == 'external-contributor' ||
# github.event.pull_request.user.login == 'new-developer' ||
# github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
# Only trigger code review for PRs from the main repository due to upstream OIDC issues
# https://github.com/anthropics/claude-code-action/issues/542
if: |
(github.event.pull_request.head.repo.full_name == github.repository) &&
(github.event.pull_request.draft == false)
runs-on: ubuntu-latest
permissions:
contents: read
@ -45,6 +44,9 @@ jobs:
- Security concerns
- Test coverage
PR number: ${{ github.event.number }}
Repo: ${{ github.repository }}
Use the repository's CLAUDE.md for guidance on style and conventions. Be constructive and helpful in your feedback.
Use `gh pr comment` with your Bash tool to leave your review as a comment on the PR.

View File

@ -1,6 +1,6 @@
name: English Translator
name: Claude Translator
concurrency:
group: translator-${{ github.event.issue.number }}
group: translator-${{ github.event.comment.id || github.event.issue.number }}
cancel-in-progress: false
on:
@ -12,13 +12,15 @@ on:
jobs:
translate:
if: |
(github.event_name == 'issues' && github.event.issue.author_association == 'COLLABORATOR' && !contains(github.event.issue.body, 'This issue/comment was translated by Claude.')) ||
(github.event_name == 'issue_comment' && github.event.comment.author_association == 'COLLABORATOR' && !contains(github.event.issue.body, 'This issue/comment was translated by Claude.'))
(github.event_name == 'issues') ||
(github.event_name == 'issue_comment' && github.event.sender.type != 'Bot') &&
((github.event_name == 'issue_comment' && github.event.action == 'created' && !contains(github.event.comment.body, 'This issue was translated by Claude')) ||
(github.event_name == 'issue_comment' && github.event.action == 'edited'))
runs-on: ubuntu-latest
permissions:
contents: read
issues: write # edit issues/comments
pull-requests: read
pull-requests: write
id-token: write
steps:
@ -28,11 +30,16 @@ jobs:
fetch-depth: 1
- name: Run Claude for translation
uses: anthropics/claude-code-action@v1
uses: anthropics/claude-code-action@main
id: claude
with:
# Warning: permissions should be controlled by the workflow's `permissions` block.
# `contents: read` keeps file access safe for now, but a fine-grained token could restrict it further.
# See: https://github.com/anthropics/claude-code-action/blob/main/docs/security.md
github_token: ${{ secrets.TOKEN_GITHUB_WRITE }}
allowed_non_write_users: '*'
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
claude_args: '--allowed-tools mcp__github_comment__update_claude_comment,Bash(gh issue:*),Bash(gh api:repos/*/issues:*)'
claude_args: '--allowed-tools Bash(gh issue:*),Bash(gh api:repos/*/issues:*)'
prompt: |
You are a multilingual translation assistant. Please complete the following tasks:
@ -50,7 +57,7 @@ jobs:
---
<details>
<summary>**Original Content:**</summary>
<summary>Original Content</summary>
[Original content]
</details>

View File

@ -3,9 +3,11 @@
"endOfLine": "lf",
"jsonRecursiveSort": true,
"jsonSortOrder": "{\"*\": \"lexical\"}",
"plugins": ["prettier-plugin-sort-json"],
"plugins": ["prettier-plugin-sort-json", "prettier-plugin-tailwindcss"],
"printWidth": 120,
"semi": false,
"singleQuote": true,
"tailwindFunctions": ["clsx"],
"tailwindStylesheet": "./src/renderer/src/assets/styles/tailwind.css",
"trailingComma": "none"
}
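For context, prettier-plugin-tailwindcss sorts Tailwind utility classes into the plugin's recommended order (and `tailwindFunctions: ["clsx"]` extends that sorting to clsx() calls); this is what produces the large class-reordering diffs in the HTML files later in this commit. A minimal illustration, taken from the LICENSE page below:
// before formatting
<div class="max-w-4xl mx-auto px-4 py-8">
// after formatting: sorted into the plugin's canonical order
<div class="mx-auto max-w-4xl px-4 py-8">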

View File

@ -28,6 +28,9 @@
"source.organizeImports": "never"
},
"editor.formatOnSave": true,
"files.associations": {
"*.css": "tailwindcss"
},
"files.eol": "\n",
"i18n-ally.displayLanguage": "zh-cn",
"i18n-ally.enabledFrameworks": ["react-i18next", "i18next"],

View File

@ -92,6 +92,12 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
- **Multi-language Support**: i18n with dynamic loading
- **Theme System**: Light/dark themes with custom CSS variables
### UI Design
The project is migrating from antd and styled-components to HeroUI. Build UI components with HeroUI; antd and styled-components must not be used.
HeroUI Docs: https://www.heroui.com/docs/guide/introduction
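For example, a minimal HeroUI-based component consistent with this guidance (the component choice and props are illustrative, not a prescribed pattern):
import { Button } from '@heroui/react'
// HeroUI for the component itself, Tailwind utilities for spacing
export const SaveButton = ({ onSave }: { onSave: () => void }) => (
  <Button color="primary" className="mt-2" onPress={onSave}>
    Save
  </Button>
)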
### Database Architecture
- **Database**: SQLite (`cherrystudio.sqlite`) + libsql driver

View File

@ -82,7 +82,7 @@ Cherry Studio is a desktop client that supports multiple LLM providers, availabl
1. **Diverse LLM Provider Support**:
- ☁️ Major LLM Cloud Services: OpenAI, Gemini, Anthropic, and more
- 🔗 AI Web Service Integration: Claude, Peplexity, Poe, and others
- 🔗 AI Web Service Integration: Claude, Perplexity, Poe, and others
- 💻 Local Model Support with Ollama, LM Studio
2. **AI Assistants & Conversations**:

components.json Normal file
View File

@ -0,0 +1,21 @@
{
"$schema": "https://ui.shadcn.com/schema.json",
"aliases": {
"components": "@renderer/ui/third-party",
"hooks": "@renderer/hooks",
"lib": "@renderer/lib",
"ui": "@renderer/ui",
"utils": "@renderer/utils"
},
"iconLibrary": "lucide",
"rsc": false,
"style": "new-york",
"tailwind": {
"baseColor": "zinc",
"config": "",
"css": "src/renderer/src/assets/styles/tailwind.css",
"cssVariables": true,
"prefix": ""
},
"tsx": true
}

View File

@ -13,7 +13,7 @@
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=fr">Français</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=de">Deutsch</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=es">Español</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=it">Itapano</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=it">Italiano</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=ru">Русский</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=pt">Português</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=nl">Nederlands</a></p>
@ -89,7 +89,7 @@ https://docs.cherry-ai.com
1. **Diverse LLM Provider Support**
- ☁️ Major LLM cloud services: OpenAI, Gemini, Anthropic, SiliconFlow, and more
- 🔗 Popular AI web service integrations: Claude, Peplexity, Poe, Tencent Yuanbao, Zhihu Zhida, and others
- 🔗 Popular AI web service integrations: Claude, Perplexity, Poe, Tencent Yuanbao, Zhihu Zhida, and others
- 💻 Local model deployment with Ollama and LM Studio
2. **AI Assistants & Conversations**

View File

@ -113,6 +113,10 @@ linux:
StartupWMClass: CherryStudio
mimeTypes:
- x-scheme-handler/cherrystudio
rpm:
# Workaround for electron build issue on rpm package:
# https://github.com/electron/forge/issues/3594
fpm: ['--rpm-rpmbuild-define=_build_id_links none']
publish:
provider: generic
url: https://releases.cherry-ai.com

View File

@ -74,6 +74,7 @@ export default defineConfig({
},
renderer: {
plugins: [
(async () => (await import('@tailwindcss/vite')).default())(),
react({
tsDecorators: true,
plugins: [

View File

@ -123,7 +123,10 @@ export default defineConfig([
'.gitignore',
'scripts/cloudflare-worker.js',
'src/main/integration/nutstore/sso/lib/**',
'src/main/integration/cherryin/index.js'
'src/main/integration/cherryin/index.js',
'src/main/integration/nutstore/sso/lib/**',
'src/renderer/src/ui/**',
'packages/**/dist'
]
}
])

View File

@ -69,6 +69,7 @@
"format:check": "prettier --check .",
"lint": "eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix && yarn typecheck && yarn check:i18n",
"prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky",
"claude": "dotenv -e .env -- claude",
"migrations:generate": "drizzle-kit generate --config ./migrations/sqlite-drizzle.config.ts"
},
"dependencies": {
@ -76,14 +77,18 @@
"@libsql/win32-x64-msvc": "^0.4.7",
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
"@strongtz/win32-arm64-msvc": "^0.4.7",
"express": "^5.1.0",
"faiss-node": "^0.5.1",
"font-list": "^2.0.0",
"graceful-fs": "^4.2.11",
"jsdom": "26.1.0",
"node-stream-zip": "^1.15.0",
"officeparser": "^4.2.0",
"os-proxy-config": "^1.1.2",
"selection-hook": "^1.0.11",
"selection-hook": "^1.0.12",
"sharp": "^0.34.3",
"swagger-jsdoc": "^6.2.8",
"swagger-ui-express": "^5.0.1",
"tesseract.js": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch",
"turndown": "7.2.0"
},
@ -92,8 +97,9 @@
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
"@ai-sdk/amazon-bedrock": "^3.0.0",
"@ai-sdk/google-vertex": "^3.0.0",
"@ai-sdk/google-vertex": "^3.0.25",
"@ai-sdk/mistral": "^2.0.0",
"@ai-sdk/perplexity": "^2.0.8",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
@ -129,13 +135,14 @@
"@eslint/js": "^9.22.0",
"@google/genai": "patch:@google/genai@npm%3A1.0.1#~/.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch",
"@hello-pangea/dnd": "^18.0.1",
"@heroui/react": "^2.8.3",
"@kangfenmao/keyv-storage": "^0.1.0",
"@langchain/community": "^0.3.50",
"@langchain/core": "^0.3.68",
"@langchain/ollama": "^0.2.1",
"@langchain/openai": "^0.6.7",
"@mistralai/mistralai": "^1.7.5",
"@modelcontextprotocol/sdk": "^1.17.0",
"@modelcontextprotocol/sdk": "^1.17.5",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "^1.1.2",
@ -149,6 +156,7 @@
"@reduxjs/toolkit": "^2.2.5",
"@shikijs/markdown-it": "^3.12.0",
"@swc/plugin-styled-components": "^8.0.4",
"@tailwindcss/vite": "^4.1.13",
"@tanstack/react-query": "^5.85.5",
"@tanstack/react-virtual": "^3.13.12",
"@testing-library/dom": "^10.4.0",
@ -174,6 +182,10 @@
"@truto/turndown-plugin-gfm": "^1.0.2",
"@tryfabric/martian": "^1.2.4",
"@types/cli-progress": "^3",
"@types/content-type": "^1.1.9",
"@types/cors": "^2.8.19",
"@types/diff": "^7",
"@types/express": "^5",
"@types/fs-extra": "^11",
"@types/he": "^1",
"@types/html-to-text": "^9",
@ -187,6 +199,9 @@
"@types/react-dom": "^19.0.4",
"@types/react-infinite-scroll-component": "^5.0.0",
"@types/react-transition-group": "^4.4.12",
"@types/react-window": "^1",
"@types/swagger-jsdoc": "^6",
"@types/swagger-ui-express": "^4.1.8",
"@types/tinycolor2": "^1",
"@types/turndown": "^5.0.5",
"@types/word-extractor": "^1",
@ -201,16 +216,18 @@
"@viz-js/lang-dot": "^1.0.5",
"@viz-js/viz": "^3.14.0",
"@xyflow/react": "^12.4.4",
"ai": "^5.0.29",
"ai": "^5.0.38",
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
"axios": "^1.7.3",
"browser-image-compression": "^2.0.2",
"chardet": "^2.1.0",
"check-disk-space": "3.4.0",
"cheerio": "^1.1.2",
"chokidar": "^4.0.3",
"cli-progress": "^3.12.0",
"clsx": "^2.1.1",
"code-inspector-plugin": "^0.20.14",
"color": "^5.0.0",
"concurrently": "^9.2.1",
@ -241,6 +258,7 @@
"fast-diff": "^1.3.0",
"fast-xml-parser": "^5.2.0",
"fetch-socks": "1.3.2",
"framer-motion": "^12.23.12",
"franc-min": "^6.2.0",
"fs-extra": "^11.2.0",
"google-auth-library": "^9.15.1",
@ -275,6 +293,7 @@
"playwright": "^1.52.0",
"prettier": "^3.5.3",
"prettier-plugin-sort-json": "^4.1.1",
"prettier-plugin-tailwindcss": "^0.6.14",
"proxy-agent": "^6.5.0",
"react": "^19.0.0",
"react-dom": "^19.0.0",
@ -304,18 +323,19 @@
"remark-math": "^6.0.0",
"remove-markdown": "^0.6.2",
"rollup-plugin-visualizer": "^5.12.0",
"sass": "^1.88.0",
"shiki": "^3.12.0",
"strict-url-sanitise": "^0.0.1",
"string-width": "^7.2.0",
"striptags": "^3.2.0",
"styled-components": "^6.1.11",
"swr": "^2.3.6",
"tailwindcss": "^4.1.13",
"tar": "^7.4.3",
"tiny-pinyin": "^1.3.2",
"tokenx": "^1.1.0",
"tsx": "^4.20.3",
"turndown-plugin-gfm": "^1.0.2",
"tw-animate-css": "^1.3.8",
"typescript": "^5.6.2",
"undici": "6.21.2",
"unified": "^11.0.5",
@ -331,7 +351,7 @@
"yjs": "^13.6.27",
"youtubei.js": "^15.0.1",
"zipread": "^1.3.3",
"zod": "^3.25.74"
"zod": "^4.1.5"
},
"resolutions": {
"@codemirror/language": "6.11.3",
@ -360,7 +380,7 @@
"prettier --write",
"eslint --fix"
],
"*.{json,yml,yaml,css,scss,html}": [
"*.{json,yml,yaml,css,html}": [
"prettier --write"
]
}

View File

@ -1,103 +0,0 @@
/**
* Hub Provider usage example
*
* Demonstrates using the Hub Provider feature to route requests to multiple underlying providers
*/
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
async function demonstrateHubProvider() {
try {
// 1. Initialize the underlying providers
console.log('📦 Initializing underlying providers...')
initializeProvider('openai', {
apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
})
initializeProvider('anthropic', {
apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
})
// 2. Create the Hub Provider (automatically includes all initialized providers)
console.log('🌐 Creating Hub Provider...')
const aihubmixProvider = createHubProvider({
hubId: 'aihubmix',
debug: true
})
// 3. Register the Hub Provider
providerRegistry.registerProvider('aihubmix', aihubmixProvider)
console.log('✅ Hub Provider "aihubmix" registered successfully')
// 4. Use the Hub Provider to access different models
console.log('\n🚀 Using hub models...')
// Route to OpenAI through the hub
const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
console.log('✓ OpenAI model resolved:', openaiModel.modelId)
// Route to Anthropic through the hub
const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
console.log('✓ Anthropic model resolved:', anthropicModel.modelId)
// 5. Demonstrate error handling
console.log('\n❌ Demonstrating error handling...')
try {
// Try to access an uninitialized provider
providerRegistry.languageModel('aihubmix:google:gemini-pro')
} catch (error) {
console.log('Expected error:', error.message)
}
try {
// Try an invalid model ID format
providerRegistry.languageModel('aihubmix:invalid-format')
} catch (error) {
console.log('Expected error:', error.message)
}
// 6. Multiple Hub Providers
console.log('\n🔄 Creating multiple Hub Providers...')
const localHubProvider = createHubProvider({
hubId: 'local-ai'
})
providerRegistry.registerProvider('local-ai', localHubProvider)
console.log('✅ Hub Provider "local-ai" registered successfully')
console.log('\n🎉 Hub Provider demo complete')
} catch (error) {
console.error('💥 Error during the demo:', error)
}
}
// Demonstrate the simplified usage
function simplifiedUsageExample() {
console.log('\n📝 Simplified usage example:')
console.log(`
// 1. Initialize providers
initializeProvider('openai', { apiKey: 'sk-xxx' })
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
// 2. Create and register the Hub Provider
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
providerRegistry.registerProvider('aihubmix', hubProvider)
// 3. Use directly
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
`)
}
// Run the demo
if (require.main === module) {
demonstrateHubProvider()
simplifiedUsageExample()
}
export { demonstrateHubProvider, simplifiedUsageExample }

View File

@ -1,167 +0,0 @@
/**
* Image Generation Example
* Demonstrates image generation with aiCore
*/
import { createExecutor, generateImage } from '../src/index'
async function main() {
// Approach 1: use an executor instance
console.log('📸 Creating an OpenAI image-generation executor...')
const executor = createExecutor('openai', {
apiKey: process.env.OPENAI_API_KEY!
})
try {
console.log('🎨 Generating an image with the executor...')
const result1 = await executor.generateImage('dall-e-3', {
prompt: 'A futuristic cityscape at sunset with flying cars',
size: '1024x1024',
n: 1
})
console.log('✅ Image generated successfully!')
console.log('📊 Result:', {
imagesCount: result1.images.length,
mediaType: result1.image.mediaType,
hasBase64: !!result1.image.base64,
providerMetadata: result1.providerMetadata
})
} catch (error) {
console.error('❌ Executor generation failed:', error)
}
// Approach 2: call the API directly
try {
console.log('🎨 Generating an image via the direct API...')
const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
prompt: 'A magical forest with glowing mushrooms and fairy lights',
aspectRatio: '16:9',
providerOptions: {
openai: {
quality: 'hd',
style: 'vivid'
}
}
})
console.log('✅ Direct API generation succeeded!')
console.log('📊 Result:', {
imagesCount: result2.images.length,
mediaType: result2.image.mediaType,
hasBase64: !!result2.image.base64
})
} catch (error) {
console.error('❌ Direct API generation failed:', error)
}
// Approach 3: other providers (Google Imagen)
if (process.env.GOOGLE_API_KEY) {
try {
console.log('🎨 Generating an image with Google Imagen...')
const googleExecutor = createExecutor('google', {
apiKey: process.env.GOOGLE_API_KEY!
})
const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
prompt: 'A serene mountain lake at dawn with mist rising from the water',
aspectRatio: '1:1'
})
console.log('✅ Google Imagen generation succeeded!')
console.log('📊 Result:', {
imagesCount: result3.images.length,
mediaType: result3.image.mediaType,
hasBase64: !!result3.image.base64
})
} catch (error) {
console.error('❌ Google Imagen generation failed:', error)
}
}
// Approach 4: the plugin system
const pluginExample = async () => {
console.log('🔌 Demonstrating the plugin system...')
// Create a sample plugin that modifies the prompt
const promptEnhancerPlugin = {
name: 'prompt-enhancer',
transformParams: async (params: any) => {
console.log('🔧 Plugin: enhancing the prompt...')
return {
...params,
prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
}
},
transformResult: async (result: any) => {
console.log('🔧 Plugin: processing the result...')
return {
...result,
enhanced: true
}
}
}
const executorWithPlugin = createExecutor(
'openai',
{
apiKey: process.env.OPENAI_API_KEY!
},
[promptEnhancerPlugin]
)
try {
const result4 = await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A cute robot playing in a garden'
})
console.log('✅ Plugin-based generation succeeded!')
console.log('📊 Result:', {
imagesCount: result4.images.length,
enhanced: (result4 as any).enhanced,
mediaType: result4.image.mediaType
})
} catch (error) {
console.error('❌ Plugin-based generation failed:', error)
}
}
await pluginExample()
}
// Error-handling demo
async function errorHandlingExample() {
console.log('⚠️ Demonstrating error handling...')
try {
const executor = createExecutor('openai', {
apiKey: 'invalid-key'
})
await executor.generateImage('dall-e-3', {
prompt: 'Test image'
})
} catch (error: any) {
console.log('✅ Caught error:', error.constructor.name)
console.log('📋 Error message:', error.message)
console.log('🏷️ Provider ID:', error.providerId)
console.log('🏷️ Model ID:', error.modelId)
}
}
// Run the examples
if (require.main === module) {
main()
.then(() => {
console.log('🎉 All examples complete!')
return errorHandlingExample()
})
.then(() => {
console.log('🎯 Example program finished')
process.exit(0)
})
.catch((error) => {
console.error('💥 Program execution failed:', error)
process.exit(1)
})
}

View File

@ -1,6 +1,6 @@
{
"name": "@cherrystudio/ai-core",
"version": "1.0.0-alpha.11",
"version": "1.0.0-alpha.14",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "dist/index.js",
"module": "dist/index.mjs",
@ -39,13 +39,13 @@
"@ai-sdk/anthropic": "^2.0.5",
"@ai-sdk/azure": "^2.0.16",
"@ai-sdk/deepseek": "^1.0.9",
"@ai-sdk/google": "^2.0.7",
"@ai-sdk/openai": "^2.0.19",
"@ai-sdk/google": "^2.0.13",
"@ai-sdk/openai": "^2.0.26",
"@ai-sdk/openai-compatible": "^1.0.9",
"@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.4",
"@ai-sdk/xai": "^2.0.9",
"zod": "^3.25.0"
"zod": "^4.1.5"
},
"devDependencies": {
"tsdown": "^0.12.9",

View File

@ -84,7 +84,6 @@ export class ModelResolver {
*/
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
console.log('fullModelId', fullModelId)
return globalRegistryManagement.languageModel(fullModelId as any)
}

View File

@ -1,7 +1,7 @@
// copy from @ai-sdk/xai/xai-chat-options.ts
// Delete this file once @ai-sdk/xai exposes xaiProviderOptions
import * as z from 'zod/v4'
import { z } from 'zod'
const webSourceSchema = z.object({
type: z.literal('web'),

View File

@ -0,0 +1,39 @@
import { google } from '@ai-sdk/google'
import { definePlugin } from '../../'
import type { AiRequestContext } from '../../types'
const toolNameMap = {
googleSearch: 'google_search',
urlContext: 'url_context',
codeExecution: 'code_execution'
} as const
type ToolConfigKey = keyof typeof toolNameMap
type ToolConfig = { googleSearch?: boolean; urlContext?: boolean; codeExecution?: boolean }
export const googleToolsPlugin = (config?: ToolConfig) =>
definePlugin({
name: 'googleToolsPlugin',
transformParams: <T>(params: T, context: AiRequestContext): T => {
const { providerId } = context
if (providerId === 'google' && config) {
if (typeof params === 'object' && params !== null) {
const typedParams = params as T & { tools?: Record<string, unknown> }
if (!typedParams.tools) {
typedParams.tools = {}
}
// Iterate over the config keys in a type-safe way
;(Object.keys(config) as ToolConfigKey[]).forEach((key) => {
if (config[key] && key in toolNameMap && key in google.tools) {
const toolName = toolNameMap[key]
typedParams.tools![toolName] = google.tools[key]({})
}
})
}
}
return params
}
})
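A minimal usage sketch of the plugin above, assuming it is passed to `createExecutor` the way other plugins are elsewhere in this repo (the import path is an assumption):
import { createExecutor } from '@cherrystudio/ai-core' // import path is an assumption
// Only activates when providerId === 'google' and a config is given;
// googleSearch maps to the provider tool name 'google_search' via toolNameMap above.
const executor = createExecutor('google', { apiKey: process.env.GOOGLE_API_KEY! }, [
  googleToolsPlugin({ googleSearch: true, urlContext: true })
])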

View File

@ -4,6 +4,7 @@
*/
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
export { googleToolsPlugin } from './googleToolsPlugin'
export { createLoggingPlugin } from './logging'
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type'

View File

@ -27,10 +27,20 @@ export class StreamEventManager {
/**
* Send a step-finish event, accumulating the step's usage
*/
sendStepFinishEvent(controller: StreamController, chunk: any): void {
sendStepFinishEvent(
controller: StreamController,
chunk: any,
context: AiRequestContext,
finishReason: string = 'stop'
): void {
// Accumulate the current step's usage
if (chunk.usage && context.accumulatedUsage) {
this.accumulateUsage(context.accumulatedUsage, chunk.usage)
}
controller.enqueue({
type: 'finish-step',
finishReason: 'stop',
finishReason,
response: chunk.response,
usage: chunk.usage,
providerMetadata: chunk.providerMetadata
@ -43,28 +53,32 @@ export class StreamEventManager {
async handleRecursiveCall(
controller: StreamController,
recursiveParams: any,
context: AiRequestContext,
stepId: string
context: AiRequestContext
): Promise<void> {
try {
console.log('[MCP Prompt] Starting recursive call after tool execution...')
// try {
// Reset the tool-execution state, ready to process the next step
context.hasExecutedToolsInCurrentStep = false
const recursiveResult = await context.recursiveCall(recursiveParams)
const recursiveResult = await context.recursiveCall(recursiveParams)
if (recursiveResult && recursiveResult.fullStream) {
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
} else {
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
}
} catch (error) {
this.handleRecursiveCallError(controller, error, stepId)
if (recursiveResult && recursiveResult.fullStream) {
await this.pipeRecursiveStream(controller, recursiveResult.fullStream, context)
} else {
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
}
// } catch (error) {
// this.handleRecursiveCallError(controller, error, stepId)
// }
}
/**
* Pipe the recursive stream's chunks into the current controller
*/
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
private async pipeRecursiveStream(
controller: StreamController,
recursiveStream: ReadableStream,
context?: AiRequestContext
): Promise<void> {
const reader = recursiveStream.getReader()
try {
while (true) {
@ -73,9 +87,16 @@ export class StreamEventManager {
break
}
if (value.type === 'finish') {
// The nested stream does not emit finish
// The nested stream does not emit finish, but its usage must still be accumulated
if (value.usage && context?.accumulatedUsage) {
this.accumulateUsage(context.accumulatedUsage, value.usage)
}
break
}
// For finish-step chunks, accumulate their usage
if (value.type === 'finish-step' && value.usage && context?.accumulatedUsage) {
this.accumulateUsage(context.accumulatedUsage, value.usage)
}
// Forward the recursive stream's chunks to the current stream
controller.enqueue(value)
}
@ -87,25 +108,25 @@ export class StreamEventManager {
/**
* Handle a recursive-call error (currently disabled)
*/
private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
console.error('[MCP Prompt] Recursive call failed:', error)
// private handleRecursiveCallError(controller: StreamController, error: unknown): void {
// console.error('[MCP Prompt] Recursive call failed:', error)
// Use the AI SDK's standard error format without interrupting the stream
controller.enqueue({
type: 'error',
error: {
message: error instanceof Error ? error.message : String(error),
name: error instanceof Error ? error.name : 'RecursiveCallError'
}
})
// // Use the AI SDK's standard error format without interrupting the stream
// controller.enqueue({
// type: 'error',
// error: {
// message: error instanceof Error ? error.message : String(error),
// name: error instanceof Error ? error.name : 'RecursiveCallError'
// }
// })
// Keep emitting text deltas so the stream stays continuous
controller.enqueue({
type: 'text-delta',
id: stepId,
text: '\n\n[Recursive call after tool execution failed, continuing the conversation...]'
})
}
// // // Keep emitting text deltas so the stream stays continuous
// // controller.enqueue({
// // type: 'text-delta',
// // id: stepId,
// // text: '\n\n[Recursive call after tool execution failed, continuing the conversation...]'
// // })
// }
/**
*
@ -136,4 +157,18 @@ export class StreamEventManager {
return recursiveParams
}
/**
* Accumulate usage counts across steps
*/
private accumulateUsage(target: any, source: any): void {
if (!target || !source) return
// Add up each token type
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
target.reasoningTokens = (target.reasoningTokens || 0) + (source.reasoningTokens || 0)
target.cachedInputTokens = (target.cachedInputTokens || 0) + (source.cachedInputTokens || 0)
}
}
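A worked illustration of what accumulateUsage yields across one tool round trip (the token counts are invented):
const total = { inputTokens: 0, outputTokens: 0, totalTokens: 0, reasoningTokens: 0, cachedInputTokens: 0 }
// step 1 (model emits a tool call):     usage { inputTokens: 120, outputTokens: 40, totalTokens: 160 }
// step 2 (recursive call with results): usage { inputTokens: 310, outputTokens: 85, totalTokens: 395 }
// After both finish-steps are accumulated, total.totalTokens === 555, which is what the
// final 'finish' chunk reports as totalUsage (see promptToolUsePlugin below).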

View File

@ -4,7 +4,7 @@
* Tool execution helpers originally part of
* promptToolUsePlugin.ts
*/
import type { ToolSet } from 'ai'
import type { ToolSet, TypedToolError } from 'ai'
import type { ToolUseResult } from './type'
@ -38,7 +38,6 @@ export class ToolExecutor {
controller: StreamController
): Promise<ExecutedResult[]> {
const executedResults: ExecutedResult[] = []
for (const toolUse of toolUses) {
try {
const tool = tools[toolUse.toolName]
@ -46,17 +45,12 @@ export class ToolExecutor {
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
}
// Send the tool-call start events
this.sendToolStartEvents(controller, toolUse)
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
// Emit the tool-call event
controller.enqueue({
type: 'tool-call',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: tool.inputSchema
input: toolUse.arguments
})
const result = await tool.execute(toolUse.arguments, {
@ -111,45 +105,46 @@ export class ToolExecutor {
/**
* Send tool-call start events (now disabled)
*/
private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
// Emit the tool-input-start event
controller.enqueue({
type: 'tool-input-start',
id: toolUse.id,
toolName: toolUse.toolName
})
}
// private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
// // Emit the tool-input-start event
// controller.enqueue({
// type: 'tool-input-start',
// id: toolUse.id,
// toolName: toolUse.toolName
// })
// }
/**
* Handle a tool-execution error
*/
private handleToolError(
private handleToolError<T extends ToolSet>(
toolUse: ToolUseResult,
error: unknown,
controller: StreamController
// _tools: ToolSet
): ExecutedResult {
// Use the AI SDK's standard error format
// const toolError: TypedToolError<typeof _tools> = {
// type: 'tool-error',
// toolCallId: toolUse.id,
// toolName: toolUse.toolName,
// input: toolUse.arguments,
// error: error instanceof Error ? error.message : String(error)
// }
const toolError: TypedToolError<T> = {
type: 'tool-error',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: toolUse.arguments,
error
}
// controller.enqueue(toolError)
controller.enqueue(toolError)
// Emit a standard error event
controller.enqueue({
type: 'error',
error: error instanceof Error ? error.message : String(error)
})
// controller.enqueue({
// type: 'tool-error',
// toolCallId: toolUse.id,
// error: error instanceof Error ? error.message : String(error),
// input: toolUse.arguments
// })
return {
toolCallId: toolUse.id,
toolName: toolUse.toolName,
result: error instanceof Error ? error.message : String(error),
result: error,
isError: true
}
}

View File

@ -8,9 +8,19 @@ import type { TextStreamPart, ToolSet } from 'ai'
import { definePlugin } from '../../index'
import type { AiRequestContext } from '../../types'
import { StreamEventManager } from './StreamEventManager'
import { type TagConfig, TagExtractor } from './tagExtraction'
import { ToolExecutor } from './ToolExecutor'
import { PromptToolUseConfig, ToolUseResult } from './type'
/**
* Tool-use tag configuration
*/
const TOOL_USE_TAG_CONFIG: TagConfig = {
openingTag: '<tool_use>',
closingTag: '</tool_use>',
separator: '\n'
}
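A sketch of the extraction this config drives, using the TagExtractor API as it appears later in this diff (the sample model output is illustrative):
const extractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
for (const r of extractor.processText('Checking.<tool_use>{"name":"search"}</tool_use> Done.')) {
  // r.isTagContent is true for the payload between the tags, false for 'Checking.' and ' Done.'
  if (!r.isTagContent && r.content) console.log(r.content)
}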
/**
* Cherry Studio's default system prompt
*/
@ -146,8 +156,10 @@ Assistant: The population of Shanghai is 26 million, while Guangzhou has a popul
/**
* Build the available-tools section of Cherry Studio's system prompt
*/
function buildAvailableTools(tools: ToolSet): string {
function buildAvailableTools(tools: ToolSet): string | null {
const availableTools = Object.keys(tools)
if (availableTools.length === 0) return null
const result = availableTools
.map((toolName: string) => {
const tool = tools[toolName]
return `
@ -162,7 +174,7 @@ function buildAvailableTools(tools: ToolSet): string {
})
.join('\n')
return `<tools>
${availableTools}
${result}
</tools>`
}
@ -171,6 +183,7 @@ ${availableTools}
*/
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string {
const availableTools = buildAvailableTools(tools)
if (availableTools === null) return userSystemPrompt
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
.replace('{{ AVAILABLE_TOOLS }}', availableTools)
@ -249,13 +262,11 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
}
context.mcpTools = params.tools
console.log('tools stored in context', params.tools)
// Build the system prompt
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
let systemMessage: string | null = systemPrompt
console.log('config.context', context)
if (config.createSystemMessage) {
// 🎯 Use the custom handler if the user provided one
systemMessage = config.createSystemMessage(systemPrompt, params, context)
@ -268,20 +279,40 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
tools: undefined
}
context.originalParams = transformedParams
console.log('transformedParams', transformedParams)
return transformedParams
},
transformStream: (_: any, context: AiRequestContext) => () => {
let textBuffer = ''
let stepId = ''
// let stepId = ''
if (!context.mcpTools) {
throw new Error('No tools available')
}
// Create the tool executor and stream-event manager
// Get or initialize the usage accumulator on the context
if (!context.accumulatedUsage) {
context.accumulatedUsage = {
inputTokens: 0,
outputTokens: 0,
totalTokens: 0,
reasoningTokens: 0,
cachedInputTokens: 0
}
}
// Create the tool executor, stream-event manager, and tag extractor
const toolExecutor = new ToolExecutor()
const streamEventManager = new StreamEventManager()
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
// Initialize the tool-execution state on the context so it is not lost across recursive calls
if (!context.hasExecutedToolsInCurrentStep) {
context.hasExecutedToolsInCurrentStep = false
}
// Holds the text-start event until non-tool-tag content is confirmed
let pendingTextStart: TextStreamPart<TOOLS> | null = null
let hasStartedText = false
type TOOLS = NonNullable<typeof context.mcpTools>
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
@ -289,83 +320,106 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
chunk: TextStreamPart<TOOLS>,
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
) {
// Collect the text content
if (chunk.type === 'text-delta') {
textBuffer += chunk.text || ''
stepId = chunk.id || ''
controller.enqueue(chunk)
// Hold the text-start event until non-tool-tag content is confirmed
if ((chunk as any).type === 'text-start') {
pendingTextStart = chunk
return
}
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
const tools = context.mcpTools
if (!tools || Object.keys(tools).length === 0) {
controller.enqueue(chunk)
return
}
// During text-delta, collect the text and filter out tool tags
if (chunk.type === 'text-delta') {
textBuffer += chunk.text || ''
// stepId = chunk.id || ''
// Parse tool calls
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
// Use TagExtractor to strip tool tags; only non-tag content reaches the UI layer
const extractionResults = tagExtractor.processText(chunk.text || '')
// If there are no valid tool calls, pass the original event through
if (validToolUses.length === 0) {
controller.enqueue(chunk)
return
}
if (chunk.type === 'text-end') {
controller.enqueue({
type: 'text-end',
id: stepId,
providerMetadata: {
text: {
value: parsedContent
}
for (const result of extractionResults) {
// Only pass non-tag content to the UI layer
if (!result.isTagContent && result.content) {
// If text-start has not been sent and one is pending, send it first
if (!hasStartedText && pendingTextStart) {
controller.enqueue(pendingTextStart)
hasStartedText = true
pendingTextStart = null
}
})
return
const filteredChunk = {
...chunk,
text: result.content
}
controller.enqueue(filteredChunk)
}
}
return
}
if (chunk.type === 'text-end') {
// Only send text-end if a text-start was already sent
if (hasStartedText) {
controller.enqueue(chunk)
}
return
}
if (chunk.type === 'finish-step') {
// Check for and execute tool calls uniformly at the finish-step stage
const tools = context.mcpTools
if (tools && Object.keys(tools).length > 0 && !context.hasExecutedToolsInCurrentStep) {
// Parse the complete textBuffer to detect tool calls
const { results: parsedTools } = parseToolUse(textBuffer, tools)
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
if (validToolUses.length > 0) {
context.hasExecutedToolsInCurrentStep = true
// Execute the tool calls (no manual start-step needed; the outer stream already handles it)
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
// Send the step-finish event with 'tool-calls' as the finishReason
streamEventManager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
// Handle the recursive call
const toolResultsText = toolExecutor.formatToolResults(executedResults)
const recursiveParams = streamEventManager.buildRecursiveParams(
context,
textBuffer,
toolResultsText,
tools
)
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context)
return
}
}
controller.enqueue({
...chunk,
finishReason: 'tool-calls'
})
// Send the step-start event
streamEventManager.sendStepStartEvent(controller)
// Execute the tool calls
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
// Send the step-finish event
streamEventManager.sendStepFinishEvent(controller, chunk)
// Handle the recursive call
if (validToolUses.length > 0) {
const toolResultsText = toolExecutor.formatToolResults(executedResults)
const recursiveParams = streamEventManager.buildRecursiveParams(
context,
textBuffer,
toolResultsText,
tools
)
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
}
// If no tools were executed, pass the original finish-step event through
controller.enqueue(chunk)
// Reset state
textBuffer = ''
return
}
// Pass all other event types through
controller.enqueue(chunk)
// For finish chunks, report the accumulated totalUsage
if (chunk.type === 'finish') {
controller.enqueue({
...chunk,
totalUsage: context.accumulatedUsage
})
return
}
// Pass other event types through (text-start excluded; it is handled above)
if ((chunk as any).type !== 'text-start') {
controller.enqueue(chunk)
}
},
flush() {
// Clean-up when the stream ends
console.log('[MCP Prompt] Stream ended, cleaning up...')
// Clear pending state
pendingTextStart = null
hasStartedText = false
}
})
}

View File

@ -27,7 +27,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
case 'openai': {
if (config.openai) {
if (!params.tools) params.tools = {}
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
params.tools.web_search = openai.tools.webSearch(config.openai)
}
break
}

View File

@ -1,5 +1,8 @@
// Core types and interfaces
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
import type { ImageModelV2 } from '@ai-sdk/provider'
import type { LanguageModel } from 'ai'
import type { ProviderId } from '../providers'
import type { AiPlugin, AiRequestContext } from './types'
@ -9,16 +12,16 @@ export { PluginManager } from './manager'
// Utility functions
export function createContext<T extends ProviderId>(
providerId: T,
modelId: string,
model: LanguageModel | ImageModelV2,
originalParams: any
): AiRequestContext {
return {
providerId,
modelId,
model,
originalParams,
metadata: {},
startTime: Date.now(),
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
requestId: `${providerId}-${typeof model === 'string' ? model : model?.modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
// Placeholder
recursiveCall: () => Promise.resolve(null)
}

View File

@ -14,7 +14,7 @@ export type RecursiveCallFn = (newParams: any) => Promise<any>
*/
export interface AiRequestContext {
providerId: ProviderId
modelId: string
model: LanguageModel | ImageModelV2
originalParams: any
metadata: Record<string, any>
startTime: number

View File

@ -10,8 +10,8 @@ import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import { createXai } from '@ai-sdk/xai'
import { customProvider, type Provider } from 'ai'
import * as z from 'zod'
import { customProvider, Provider } from 'ai'
import { z } from 'zod'
/**
* Provider IDs
@ -38,14 +38,12 @@ export const baseProviderIdSchema = z.enum(baseProviderIds)
*/
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
export const baseProviderSchema = z.object({
id: baseProviderIdSchema,
name: z.string(),
creator: z.function().args(z.any()).returns(z.any()) as z.ZodType<(options: any) => Provider>,
supportsImageGeneration: z.boolean()
})
export type BaseProvider = z.infer<typeof baseProviderSchema>
type BaseProvider = {
id: BaseProviderId
name: string
creator: (options: any) => Provider
supportsImageGeneration: boolean
}
/**
* Providers
@ -148,7 +146,12 @@ export const providerConfigSchema = z
.object({
id: customProviderIdSchema, // only custom IDs are allowed
name: z.string().min(1),
creator: z.function().optional(),
creator: z
.function({
input: z.any(),
output: z.any()
})
.optional(),
import: z.function().optional(),
creatorFunctionName: z.string().optional(),
supportsImageGeneration: z.boolean().default(false),
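This hunk mirrors the zod ^3.25 → ^4.1.5 bump in the dependency lists above: zod v4 replaces the chained function-schema builder with a single configuration object. A minimal sketch of the difference:
import { z } from 'zod'
// zod v3 (removed above): chained builder
// const creator = z.function().args(z.any()).returns(z.any())
// zod v4 (added above): one configuration object
const creator = z.function({ input: z.any(), output: z.any() })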

View File

@ -4,12 +4,12 @@
*/
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import {
experimental_generateImage as generateImage,
generateObject,
generateText,
experimental_generateImage as _generateImage,
generateObject as _generateObject,
generateText as _generateText,
LanguageModel,
streamObject,
streamText
streamObject as _streamObject,
streamText as _streamText
} from 'ai'
import { globalModelResolver } from '../models'
@ -18,7 +18,14 @@ import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
import { type ProviderId } from '../providers'
import { ImageGenerationError, ImageModelResolutionError } from './errors'
import { PluginEngine } from './pluginEngine'
import { type RuntimeConfig } from './types'
import type {
generateImageParams,
generateObjectParams,
generateTextParams,
RuntimeConfig,
streamObjectParams,
streamTextParams
} from './types'
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
public pluginEngine: PluginEngine<T>
@ -75,12 +82,12 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
* Stream text generation
*/
async streamText(
params: Parameters<typeof streamText>[0],
params: streamTextParams,
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof streamText>> {
const { model, ...restParams } = params
): Promise<ReturnType<typeof _streamText>> {
const { model } = params
// Decide the plugin configuration based on the model type
if (typeof model === 'string') {
@ -94,19 +101,16 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
return this.pluginEngine.executeStreamWithPlugins(
'streamText',
model,
restParams,
async (resolvedModel, transformedParams, streamTransforms) => {
params,
(resolvedModel, transformedParams, streamTransforms) => {
const experimental_transform =
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
const finalParams = {
model: resolvedModel,
return _streamText({
...transformedParams,
model: resolvedModel,
experimental_transform
} as Parameters<typeof streamText>[0]
return await streamText(finalParams)
})
}
)
}
@ -117,12 +121,12 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
* Generate text
*/
async generateText(
params: Parameters<typeof generateText>[0],
params: generateTextParams,
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof generateText>> {
const { model, ...restParams } = params
): Promise<ReturnType<typeof _generateText>> {
const { model } = params
// Decide the plugin configuration based on the model type
if (typeof model === 'string') {
@ -134,12 +138,10 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return this.pluginEngine.executeWithPlugins(
return this.pluginEngine.executeWithPlugins<Parameters<typeof _generateText>[0], ReturnType<typeof _generateText>>(
'generateText',
model,
restParams,
async (resolvedModel, transformedParams) =>
generateText({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateText>[0])
params,
(resolvedModel, transformedParams) => _generateText({ ...transformedParams, model: resolvedModel })
)
}
@ -147,12 +149,12 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
* Generate a structured object
*/
async generateObject(
params: Parameters<typeof generateObject>[0],
params: generateObjectParams,
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof generateObject>> {
const { model, ...restParams } = params
): Promise<ReturnType<typeof _generateObject>> {
const { model } = params
// Decide the plugin configuration based on the model type
if (typeof model === 'string') {
@ -164,25 +166,23 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return this.pluginEngine.executeWithPlugins(
return this.pluginEngine.executeWithPlugins<generateObjectParams, ReturnType<typeof _generateObject>>(
'generateObject',
model,
restParams,
async (resolvedModel, transformedParams) =>
generateObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateObject>[0])
params,
async (resolvedModel, transformedParams) => _generateObject({ ...transformedParams, model: resolvedModel })
)
}
/**
* Stream structured object generation
*/
async streamObject(
params: Parameters<typeof streamObject>[0],
streamObject(
params: streamObjectParams,
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof streamObject>> {
const { model, ...restParams } = params
): Promise<ReturnType<typeof _streamObject>> {
const { model } = params
// Decide the plugin configuration based on the model type
if (typeof model === 'string') {
@ -194,23 +194,17 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return this.pluginEngine.executeWithPlugins(
'streamObject',
model,
restParams,
async (resolvedModel, transformedParams) =>
streamObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof streamObject>[0])
return this.pluginEngine.executeStreamWithPlugins('streamObject', params, (resolvedModel, transformedParams) =>
_streamObject({ ...transformedParams, model: resolvedModel })
)
}
/**
* Generate images
*/
async generateImage(
params: Omit<Parameters<typeof generateImage>[0], 'model'> & { model: string | ImageModelV2 }
): Promise<ReturnType<typeof generateImage>> {
generateImage(params: generateImageParams): Promise<ReturnType<typeof _generateImage>> {
try {
const { model, ...restParams } = params
const { model } = params
// Decide the plugin configuration based on the model type
if (typeof model === 'string') {
@ -219,13 +213,8 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return await this.pluginEngine.executeImageWithPlugins(
'generateImage',
model,
restParams,
async (resolvedModel, transformedParams) => {
return await generateImage({ model: resolvedModel, ...transformedParams })
}
return this.pluginEngine.executeImageWithPlugins('generateImage', params, (resolvedModel, transformedParams) =>
_generateImage({ ...transformedParams, model: resolvedModel })
)
} catch (error) {
if (error instanceof Error) {

View File

@ -1,6 +1,6 @@
/* eslint-disable @eslint-react/naming-convention/context-name */
import { ImageModelV2 } from '@ai-sdk/provider'
import { LanguageModel } from 'ai'
import { experimental_generateImage, generateObject, generateText, LanguageModel, streamObject, streamText } from 'ai'
import { type AiPlugin, createContext, PluginManager } from '../plugins'
import { type ProviderId } from '../providers/types'
@ -62,17 +62,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
* Unified execution with plugins
* Used by AiExecutor
*/
async executeWithPlugins<TParams, TResult>(
async executeWithPlugins<
TParams extends Parameters<typeof generateText | typeof generateObject>[0],
TResult extends ReturnType<typeof generateText | typeof generateObject>
>(
methodName: string,
model: LanguageModel,
params: TParams,
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
executor: (model: LanguageModel, transformedParams: TParams) => TResult,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// Resolve the model uniformly
let resolvedModel: LanguageModel | undefined
let modelId: string
const { model } = params
if (typeof model === 'string') {
// 字符串:需要通过插件解析
modelId = model
@ -83,13 +85,13 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
}
// Create the request context with the proper createContext
const context = _context ? _context : createContext(this.providerId, modelId, params)
const context = _context ? _context : createContext(this.providerId, model, params)
// 🔥 Give the context recursive-call capability
context.recursiveCall = async (newParams: any): Promise<TResult> => {
// Recursively invoke itself, re-running the whole plugin pipeline
context.isRecursiveCall = true
const result = await this.executeWithPlugins(methodName, model, newParams, executor, context)
const result = await this.executeWithPlugins(methodName, newParams, executor, context)
context.isRecursiveCall = false
return result
}
@ -138,17 +140,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
* Image generation with plugins
* Used by AiExecutor
*/
async executeImageWithPlugins<TParams, TResult>(
async executeImageWithPlugins<
TParams extends Omit<Parameters<typeof experimental_generateImage>[0], 'model'> & { model: string | ImageModelV2 },
TResult extends ReturnType<typeof experimental_generateImage>
>(
methodName: string,
model: ImageModelV2 | string,
params: TParams,
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
executor: (model: ImageModelV2, transformedParams: TParams) => TResult,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// Resolve the model uniformly
let resolvedModel: ImageModelV2 | undefined
let modelId: string
const { model } = params
if (typeof model === 'string') {
// 字符串:需要通过插件解析
modelId = model
@ -159,13 +163,13 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
}
// Create the request context with the proper createContext
const context = _context ? _context : createContext(this.providerId, modelId, params)
const context = _context ? _context : createContext(this.providerId, model, params)
// 🔥 Give the context recursive-call capability
context.recursiveCall = async (newParams: any): Promise<TResult> => {
// Recursively invoke itself, re-running the whole plugin pipeline
context.isRecursiveCall = true
const result = await this.executeImageWithPlugins(methodName, model, newParams, executor, context)
const result = await this.executeImageWithPlugins(methodName, newParams, executor, context)
context.isRecursiveCall = false
return result
}
@ -214,17 +218,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
* Streaming execution with plugins
* Used by AiExecutor
*/
async executeStreamWithPlugins<TParams, TResult>(
async executeStreamWithPlugins<
TParams extends Parameters<typeof streamText | typeof streamObject>[0],
TResult extends ReturnType<typeof streamText | typeof streamObject>
>(
methodName: string,
model: LanguageModel,
params: TParams,
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => TResult,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// Resolve the model uniformly
let resolvedModel: LanguageModel | undefined
let modelId: string
const { model } = params
if (typeof model === 'string') {
// 字符串:需要通过插件解析
modelId = model
@ -235,13 +241,13 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
}
// Create the request context
const context = _context ? _context : createContext(this.providerId, modelId, params)
const context = _context ? _context : createContext(this.providerId, model, params)
// 🔥 Give the context recursive-call capability
context.recursiveCall = async (newParams: any): Promise<TResult> => {
// Recursively invoke itself, re-running the whole plugin pipeline
context.isRecursiveCall = true
const result = await this.executeStreamWithPlugins(methodName, model, newParams, executor, context)
const result = await this.executeStreamWithPlugins(methodName, newParams, executor, context)
context.isRecursiveCall = false
return result
}

View File

@ -1,6 +1,9 @@
/**
* Runtime type definitions
*/
import { ImageModelV2 } from '@ai-sdk/provider'
import { experimental_generateImage, generateObject, generateText, streamObject, streamText } from 'ai'
import { type ModelConfig } from '../models/types'
import { type AiPlugin } from '../plugins'
import { type ProviderId } from '../providers/types'
@ -13,3 +16,11 @@ export interface RuntimeConfig<T extends ProviderId = ProviderId> {
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
plugins?: AiPlugin[]
}
export type generateImageParams = Omit<Parameters<typeof experimental_generateImage>[0], 'model'> & {
model: string | ImageModelV2
}
export type generateObjectParams = Parameters<typeof generateObject>[0]
export type generateTextParams = Parameters<typeof generateText>[0]
export type streamObjectParams = Parameters<typeof streamObject>[0]
export type streamTextParams = Parameters<typeof streamText>[0]
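A small usage sketch of these aliases (the 'provider:model' string convention follows the ModelResolver shown earlier; the exact separator and model ID are assumptions):
import type { generateTextParams } from './types'
// model may be a resolved LanguageModel or a string routed through the registry
const params: generateTextParams = {
  model: 'openai:gpt-4o-mini',
  prompt: 'Say hello'
}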

View File

@ -35,8 +35,10 @@ export enum IpcChannel {
App_InstallBunBinary = 'app:install-bun-binary',
App_LogToMain = 'app:log-to-main',
App_SaveData = 'app:save-data',
App_GetDiskInfo = 'app:get-disk-info',
App_SetFullScreen = 'app:set-full-screen',
App_IsFullScreen = 'app:is-full-screen',
App_GetSystemFonts = 'app:get-system-fonts',
App_MacIsProcessTrusted = 'app:mac-is-process-trusted',
App_MacRequestProcessTrust = 'app:mac-request-process-trust',
@ -331,6 +333,13 @@ export enum IpcChannel {
TRACE_CLEAN_LOCAL_DATA = 'trace:cleanLocalData',
TRACE_ADD_STREAM_MESSAGE = 'trace:addStreamMessage',
// API Server
ApiServer_Start = 'api-server:start',
ApiServer_Stop = 'api-server:stop',
ApiServer_Restart = 'api-server:restart',
ApiServer_GetStatus = 'api-server:get-status',
ApiServer_GetConfig = 'api-server:get-config',
// Anthropic OAuth
Anthropic_StartOAuthFlow = 'anthropic:start-oauth-flow',
Anthropic_CompleteOAuthWithCode = 'anthropic:complete-oauth-with-code',
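For orientation, a hedged sketch of how a renderer might drive the new API-server channels (only the channel names come from the enum above; the preload bridge shape and import path are assumptions):
import { IpcChannel } from '@shared/IpcChannel' // alias path is an assumption
async function ensureApiServerRunning() {
  const status = await window.electron.ipcRenderer.invoke(IpcChannel.ApiServer_GetStatus)
  if (!status?.running) {
    await window.electron.ipcRenderer.invoke(IpcChannel.ApiServer_Start)
  }
}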

packages/shared/utils.ts Normal file
View File

@ -0,0 +1,6 @@
export const defaultAppHeaders = () => {
return {
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio'
}
}
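A minimal usage sketch (the endpoint is illustrative; HTTP-Referer and X-Title are app-attribution headers of the kind OpenRouter-style APIs accept):
import { defaultAppHeaders } from 'packages/shared/utils' // resolve via your tsconfig alias
async function listModels(apiKey: string) {
  return fetch('https://openrouter.ai/api/v1/models', {
    headers: { ...defaultAppHeaders(), Authorization: `Bearer ${apiKey}` }
  })
}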

View File

@ -8,18 +8,18 @@
</head>
<body class="bg-gray-50">
<div class="max-w-4xl mx-auto px-4 py-8">
<div class="mx-auto max-w-4xl px-4 py-8">
<!-- Chinese version -->
<div class="mb-12">
<h1 class="text-3xl font-bold mb-8 text-gray-900">许可协议</h1>
<h1 class="mb-8 text-3xl font-bold text-gray-900">许可协议</h1>
<p class="mb-6 text-gray-700">
本项目采用<strong>区分用户的双重许可 (User-Segmented Dual Licensing)</strong> 模式。
</p>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">核心原则</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<h2 class="mb-4 text-xl font-semibold text-gray-900">核心原则</h2>
<ul class="list-disc space-y-2 pl-6 text-gray-700">
<li>
<strong>个人用户 和 10人及以下企业/组织:</strong> 默认适用
<strong>GNU Affero 通用公共许可证 v3.0 (AGPLv3)</strong>
@ -32,7 +32,7 @@
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">定义:"10人及以下"</h2>
<h2 class="mb-4 text-xl font-semibold text-gray-900">定义:"10人及以下"</h2>
<p class="text-gray-700">
指在您的组织包括公司、非营利组织、政府机构、教育机构等任何实体能够访问、使用或以任何方式直接或间接受益于本软件Cherry
Studio功能的个人总数不超过10人。这包括但不限于开发者、测试人员、运营人员、最终用户、通过集成系统间接使用者等。
@ -40,10 +40,10 @@
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">
<h2 class="mb-4 text-xl font-semibold text-gray-900">
1. 开源许可证 (Open Source License): AGPLv3 - 适用于个人及10人及以下组织
</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<ul class="list-disc space-y-2 pl-6 text-gray-700">
<li>
如果您是个人用户,或者您的组织满足上述"10人及以下"的定义,您可以在
<strong>AGPLv3</strong> 的条款下自由使用、修改和分发 Cherry Studio。AGPLv3 的完整文本可以访问
@ -62,10 +62,10 @@
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">
<h2 class="mb-4 text-xl font-semibold text-gray-900">
2. 商业许可证 (Commercial License) - 适用于超过10人的组织或希望规避 AGPLv3 义务的用户
</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<ul class="list-disc space-y-2 pl-6 text-gray-700">
<li>
<strong>强制要求:</strong>
如果您的组织<strong></strong>满足上述"10人及以下"的定义即有11人或更多人可以访问、使用或受益于本软件<strong>必须</strong>联系我们获取并签署一份商业许可证才能使用
@ -80,7 +80,7 @@
</li>
<li>
<strong>需要商业许可证的常见情况包括(但不限于):</strong>
<ul class="list-disc pl-6 mt-2 space-y-1">
<ul class="mt-2 list-disc space-y-1 pl-6">
<li>您的组织规模超过10人。</li>
<li>
(无论组织规模)您希望分发修改过的 Cherry Studio 版本,但<strong>不希望</strong>根据 AGPLv3
@ -104,8 +104,8 @@
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. 贡献 (Contributions)</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<h2 class="mb-4 text-xl font-semibold text-gray-900">3. 贡献 (Contributions)</h2>
<ul class="list-disc space-y-2 pl-6 text-gray-700">
<li>
我们欢迎社区对 Cherry Studio 的贡献。所有向本项目提交的贡献都将被视为在
<strong>AGPLv3</strong> 许可证下提供。
@ -119,8 +119,8 @@
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. 其他条款 (Other Terms)</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<h2 class="mb-4 text-xl font-semibold text-gray-900">4. 其他条款 (Other Terms)</h2>
<ul class="list-disc space-y-2 pl-6 text-gray-700">
<li>关于商业许可证的具体条款和条件,以双方签署的正式商业许可协议为准。</li>
<li>
项目维护者保留根据需要更新本许可政策(包括用户规模定义和阈值)的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
@ -133,13 +133,13 @@
<!-- English Version -->
<div>
<h1 class="text-3xl font-bold mb-8 text-gray-900">Licensing</h1>
<h1 class="mb-8 text-3xl font-bold text-gray-900">Licensing</h1>
<p class="mb-6 text-gray-700">This project employs a <strong>User-Segmented Dual Licensing</strong> model.</p>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">Core Principle</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<h2 class="mb-4 text-xl font-semibold text-gray-900">Core Principle</h2>
<ul class="list-disc space-y-2 pl-6 text-gray-700">
<li>
<strong>Individual Users and Organizations with 10 or Fewer Individuals:</strong> Governed by default
under the <strong>GNU Affero General Public License v3.0 (AGPLv3)</strong>.
@ -152,7 +152,7 @@
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">Definition: "10 or Fewer Individuals"</h2>
<h2 class="mb-4 text-xl font-semibold text-gray-900">Definition: "10 or Fewer Individuals"</h2>
<p class="text-gray-700">
Refers to any organization (including companies, non-profits, government agencies, educational institutions,
etc.) where the total number of individuals who can access, use, or in any way directly or indirectly
@ -162,10 +162,10 @@
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">
<h2 class="mb-4 text-xl font-semibold text-gray-900">
1. Open Source License: AGPLv3 - For Individuals and Organizations of 10 or Fewer
</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<ul class="list-disc space-y-2 pl-6 text-gray-700">
<li>
If you are an individual user, or if your organization meets the "10 or Fewer Individuals" definition
above, you are free to use, modify, and distribute Cherry Studio under the terms of the
@ -186,11 +186,11 @@
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">
<h2 class="mb-4 text-xl font-semibold text-gray-900">
2. Commercial License - For Organizations with More Than 10 Individuals, or Users Needing to Avoid AGPLv3
Obligations
</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<ul class="list-disc space-y-2 pl-6 text-gray-700">
<li>
<strong>Mandatory Requirement:</strong> If your organization does <strong>not</strong> meet the "10 or
Fewer Individuals" definition above (i.e., 11 or more individuals can access, use, or benefit from the
@ -207,7 +207,7 @@
</li>
<li>
<strong>Common scenarios requiring a Commercial License include (but are not limited to):</strong>
<ul class="list-disc pl-6 mt-2 space-y-1">
<ul class="mt-2 list-disc space-y-1 pl-6">
<li>
Your organization has more than 10 individuals who can access, use, or benefit from the software.
</li>
@ -236,8 +236,8 @@
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. Contributions</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<h2 class="mb-4 text-xl font-semibold text-gray-900">3. Contributions</h2>
<ul class="list-disc space-y-2 pl-6 text-gray-700">
<li>
We welcome community contributions to Cherry Studio. All contributions submitted to this project are
considered to be offered under the <strong>AGPLv3</strong> license.
@ -255,8 +255,8 @@
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. Other Terms</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<h2 class="mb-4 text-xl font-semibold text-gray-900">4. Other Terms</h2>
<ul class="list-disc space-y-2 pl-6 text-gray-700">
<li>
The specific terms and conditions of the Commercial License are governed by the formal commercial license
agreement signed by both parties.

View File

@ -12,18 +12,18 @@
<body id="app">
<div :class="isDark ? 'dark-bg' : 'bg'" class="min-h-screen">
<div class="max-w-3xl mx-auto py-12 px-4">
<h1 class="text-3xl font-bold mb-8" :class="isDark ? 'text-white' : 'text-gray-900'">Release Timeline</h1>
<div class="mx-auto max-w-3xl px-4 py-12">
<h1 class="mb-8 text-3xl font-bold" :class="isDark ? 'text-white' : 'text-gray-900'">Release Timeline</h1>
<!-- Loading state -->
<div v-if="loading" class="text-center py-8">
<div v-if="loading" class="py-8 text-center">
<div
class="inline-block animate-spin rounded-full h-8 w-8 border-4"
class="inline-block h-8 w-8 animate-spin rounded-full border-4"
:class="isDark ? 'border-gray-700 border-t-blue-500' : 'border-gray-300 border-t-blue-500'"></div>
</div>
<!-- Error state -->
<div v-else-if="error" class="text-red-500 text-center py-8">{{ error }}</div>
<div v-else-if="error" class="py-8 text-center text-red-500">{{ error }}</div>
<!-- Release list -->
<div v-else class="space-y-8">
@ -32,21 +32,21 @@
:key="release.id"
class="relative pl-8"
:class="isDark ? 'border-l-2 border-gray-700' : 'border-l-2 border-gray-200'">
<div class="absolute -left-2 top-0 w-4 h-4 rounded-full bg-green-500"></div>
<div class="absolute top-0 -left-2 h-4 w-4 rounded-full bg-green-500"></div>
<div
class="rounded-lg shadow-sm p-6 transition-shadow"
class="rounded-lg p-6 shadow-sm transition-shadow"
:class="isDark ? 'bg-black hover:shadow-md hover:shadow-black' : 'bg-white hover:shadow-md'">
<div class="flex items-start justify-between mb-4">
<div class="mb-4 flex items-start justify-between">
<div>
<h2 class="text-xl font-semibold" :class="isDark ? 'text-white' : 'text-gray-900'">
{{ release.name || release.tag_name }}
</h2>
<p class="text-sm mt-1" :class="isDark ? 'text-gray-400' : 'text-gray-500'">
<p class="mt-1 text-sm" :class="isDark ? 'text-gray-400' : 'text-gray-500'">
{{ formatDate(release.published_at) }}
</p>
</div>
<span
class="inline-flex items-center px-3 py-1 rounded-full text-sm font-medium"
class="inline-flex items-center rounded-full px-3 py-1 text-sm font-medium"
:class="isDark ? 'bg-green-900 text-green-200' : 'bg-green-100 text-green-800'">
{{ release.tag_name }}
</span>

View File

@ -1,6 +1,6 @@
import { exec } from 'child_process'
import * as fs from 'fs/promises'
import linguistLanguages from 'linguist-languages'
import * as linguistLanguages from 'linguist-languages'
import * as path from 'path'
import { promisify } from 'util'

128
src/main/apiServer/app.ts Normal file
View File

@ -0,0 +1,128 @@
import { loggerService } from '@main/services/LoggerService'
import cors from 'cors'
import express from 'express'
import { v4 as uuidv4 } from 'uuid'
import { authMiddleware } from './middleware/auth'
import { errorHandler } from './middleware/error'
import { setupOpenAPIDocumentation } from './middleware/openapi'
import { chatRoutes } from './routes/chat'
import { mcpRoutes } from './routes/mcp'
import { modelsRoutes } from './routes/models'
const logger = loggerService.withContext('ApiServer')
const app = express()
// Global middleware
app.use((req, res, next) => {
const start = Date.now()
res.on('finish', () => {
const duration = Date.now() - start
logger.info(`${req.method} ${req.path} - ${res.statusCode} - ${duration}ms`)
})
next()
})
app.use((_req, res, next) => {
res.setHeader('X-Request-ID', uuidv4())
next()
})
app.use(
cors({
origin: '*',
allowedHeaders: ['Content-Type', 'Authorization'],
methods: ['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS']
})
)
/**
* @swagger
* /health:
* get:
* summary: Health check endpoint
* description: Check server status (no authentication required)
* tags: [Health]
* security: []
* responses:
* 200:
* description: Server is healthy
* content:
* application/json:
* schema:
* type: object
* properties:
* status:
* type: string
* example: ok
* timestamp:
* type: string
* format: date-time
* version:
* type: string
* example: 1.0.0
*/
app.get('/health', (_req, res) => {
res.json({
status: 'ok',
timestamp: new Date().toISOString(),
version: process.env.npm_package_version || '1.0.0'
})
})
/**
* @swagger
* /:
* get:
* summary: API information
* description: Get basic API information and available endpoints
* tags: [General]
* security: []
* responses:
* 200:
* description: API information
* content:
* application/json:
* schema:
* type: object
* properties:
* name:
* type: string
* example: Cherry Studio API
* version:
* type: string
* example: 1.0.0
* endpoints:
* type: object
*/
app.get('/', (_req, res) => {
res.json({
name: 'Cherry Studio API',
version: '1.0.0',
endpoints: {
health: 'GET /health',
models: 'GET /v1/models',
chat: 'POST /v1/chat/completions',
mcp: 'GET /v1/mcps'
}
})
})
// API v1 routes with auth
const apiRouter = express.Router()
apiRouter.use(authMiddleware)
apiRouter.use(express.json())
// Mount routes
apiRouter.use('/chat', chatRoutes)
apiRouter.use('/mcps', mcpRoutes)
apiRouter.use('/models', modelsRoutes)
app.use('/v1', apiRouter)
// Setup OpenAPI documentation
setupOpenAPIDocumentation(app)
// Error handling (must be last)
app.use(errorHandler)
export { app }
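A quick way to sanity-check the two unauthenticated endpoints above is a Node 18+ script using the global fetch (a minimal sketch, assuming the server is listening on the default localhost:23333):

```ts
// Probe the unauthenticated endpoints (assumes the default host/port from the config module below).
const BASE = 'http://localhost:23333'

async function probe(): Promise<void> {
  const health = await fetch(`${BASE}/health`)
  console.log(await health.json()) // { status: 'ok', timestamp: '...', version: '...' }

  const info = await fetch(`${BASE}/`)
  console.log(await info.json()) // lists the health/models/chat/mcp endpoints
}

probe().catch(console.error)
```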

View File

@ -0,0 +1,65 @@
import { ApiServerConfig } from '@types'
import { v4 as uuidv4 } from 'uuid'
import { loggerService } from '../services/LoggerService'
import { reduxService } from '../services/ReduxService'
const logger = loggerService.withContext('ApiServerConfig')
const defaultHost = 'localhost'
const defaultPort = 23333
class ConfigManager {
private _config: ApiServerConfig | null = null
private generateApiKey(): string {
return `cs-sk-${uuidv4()}`
}
async load(): Promise<ApiServerConfig> {
try {
const settings = await reduxService.select('state.settings')
const serverSettings = settings?.apiServer
let apiKey = serverSettings?.apiKey
if (!apiKey || apiKey.trim() === '') {
apiKey = this.generateApiKey()
await reduxService.dispatch({
type: 'settings/setApiServerApiKey',
payload: apiKey
})
}
this._config = {
enabled: serverSettings?.enabled ?? false,
port: serverSettings?.port ?? defaultPort,
host: defaultHost,
apiKey: apiKey
}
return this._config
} catch (error: any) {
logger.warn('Failed to load config from Redux, using defaults:', error)
this._config = {
enabled: false,
port: defaultPort,
host: defaultHost,
apiKey: this.generateApiKey()
}
return this._config
}
}
async get(): Promise<ApiServerConfig> {
if (!this._config) {
await this.load()
}
if (!this._config) {
throw new Error('Failed to load API server configuration')
}
return this._config
}
async reload(): Promise<ApiServerConfig> {
return await this.load()
}
}
export const config = new ConfigManager()
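A minimal usage sketch for a hypothetical caller in the same main-process context: `get()` lazily triggers `load()`, which falls back to the defaults above and mints a fresh `cs-sk-` key when none is stored in settings.

```ts
// Hypothetical caller inside the Electron main process.
import { config } from './config'

async function describeServer(): Promise<void> {
  const cfg = await config.get() // first call loads from Redux; later calls hit the cache
  // Resulting shape: { enabled: boolean, port: number, host: 'localhost', apiKey: 'cs-sk-<uuid>' }
  console.log(`API server ${cfg.enabled ? 'enabled' : 'disabled'} at http://${cfg.host}:${cfg.port}`)
}
```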

View File

@ -0,0 +1,2 @@
export { config } from './config'
export { apiServer } from './server'

View File

@ -0,0 +1,62 @@
import crypto from 'crypto'
import { NextFunction, Request, Response } from 'express'
import { config } from '../config'
export const authMiddleware = async (req: Request, res: Response, next: NextFunction) => {
const auth = req.header('Authorization') || ''
const xApiKey = req.header('x-api-key') || ''
// Fast rejection if neither credential header provided
if (!auth && !xApiKey) {
return res.status(401).json({ error: 'Unauthorized: missing credentials' })
}
let token: string | undefined
// Prefer Bearer if well-formed
if (auth) {
const trimmed = auth.trim()
const bearerPrefix = /^Bearer\s+/i
if (bearerPrefix.test(trimmed)) {
const candidate = trimmed.replace(bearerPrefix, '').trim()
if (!candidate) {
return res.status(401).json({ error: 'Unauthorized: empty bearer token' })
}
token = candidate
}
}
// Fallback to x-api-key if token still not resolved
if (!token && xApiKey) {
if (!xApiKey.trim()) {
return res.status(401).json({ error: 'Unauthorized: empty x-api-key' })
}
token = xApiKey.trim()
}
if (!token) {
// At this point we had at least one header, but none yielded a usable token
return res.status(401).json({ error: 'Unauthorized: invalid credentials format' })
}
const { apiKey } = await config.get()
if (!apiKey) {
// If server not configured, treat as forbidden (or could be 500). Choose 403 to avoid leaking config state.
return res.status(403).json({ error: 'Forbidden' })
}
// Timing-safe compare when lengths match, else immediate forbidden
if (token.length !== apiKey.length) {
return res.status(403).json({ error: 'Forbidden' })
}
const tokenBuf = Buffer.from(token)
const keyBuf = Buffer.from(apiKey)
if (!crypto.timingSafeEqual(tokenBuf, keyBuf)) {
return res.status(403).json({ error: 'Forbidden' })
}
return next()
}
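Both credential styles the middleware accepts, sketched as client calls (placeholder key; the real one comes from Cherry Studio's settings):

```ts
// Two equivalent ways to authenticate against the /v1 routes (hypothetical key).
const API_KEY = 'cs-sk-00000000-0000-0000-0000-000000000000'

async function authExamples(): Promise<void> {
  // 1) Authorization: Bearer, checked first when well-formed
  await fetch('http://localhost:23333/v1/models', {
    headers: { Authorization: `Bearer ${API_KEY}` }
  })

  // 2) x-api-key, used as a fallback when no usable Bearer token is present
  await fetch('http://localhost:23333/v1/models', {
    headers: { 'x-api-key': API_KEY }
  })
}
```

Note the design choice: mismatched lengths short-circuit to 403 before `timingSafeEqual`, which at worst leaks the key's length, never its content.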

View File

@ -0,0 +1,21 @@
import { NextFunction, Request, Response } from 'express'
import { loggerService } from '../../services/LoggerService'
const logger = loggerService.withContext('ApiServerErrorHandler')
// eslint-disable-next-line @typescript-eslint/no-unused-vars
export const errorHandler = (err: Error, _req: Request, res: Response, _next: NextFunction) => {
logger.error('API Server Error:', err)
// Don't expose internal errors in production
const isDev = process.env.NODE_ENV === 'development'
res.status(500).json({
error: {
message: isDev ? err.message : 'Internal server error',
type: 'server_error',
...(isDev && { stack: err.stack })
}
})
}

View File

@ -0,0 +1,206 @@
import { Express } from 'express'
import swaggerJSDoc from 'swagger-jsdoc'
import swaggerUi from 'swagger-ui-express'
import { loggerService } from '../../services/LoggerService'
const logger = loggerService.withContext('OpenAPIMiddleware')
const swaggerOptions: swaggerJSDoc.Options = {
definition: {
openapi: '3.0.0',
info: {
title: 'Cherry Studio API',
version: '1.0.0',
description: 'OpenAI-compatible API for Cherry Studio with additional Cherry-specific endpoints',
contact: {
name: 'Cherry Studio',
url: 'https://github.com/CherryHQ/cherry-studio'
}
},
servers: [
{
url: 'http://localhost:23333',
description: 'Local development server'
}
],
components: {
securitySchemes: {
BearerAuth: {
type: 'http',
scheme: 'bearer',
bearerFormat: 'JWT',
description: 'Use the API key from Cherry Studio settings'
}
},
schemas: {
Error: {
type: 'object',
properties: {
error: {
type: 'object',
properties: {
message: { type: 'string' },
type: { type: 'string' },
code: { type: 'string' }
}
}
}
},
ChatMessage: {
type: 'object',
properties: {
role: {
type: 'string',
enum: ['system', 'user', 'assistant', 'tool']
},
content: {
oneOf: [
{ type: 'string' },
{
type: 'array',
items: {
type: 'object',
properties: {
type: { type: 'string' },
text: { type: 'string' },
image_url: {
type: 'object',
properties: {
url: { type: 'string' }
}
}
}
}
}
]
},
name: { type: 'string' },
tool_calls: {
type: 'array',
items: {
type: 'object',
properties: {
id: { type: 'string' },
type: { type: 'string' },
function: {
type: 'object',
properties: {
name: { type: 'string' },
arguments: { type: 'string' }
}
}
}
}
}
}
},
ChatCompletionRequest: {
type: 'object',
required: ['model', 'messages'],
properties: {
model: {
type: 'string',
description: 'The model to use for completion, in format provider:model_id'
},
messages: {
type: 'array',
items: { $ref: '#/components/schemas/ChatMessage' }
},
temperature: {
type: 'number',
minimum: 0,
maximum: 2,
default: 1
},
max_tokens: {
type: 'integer',
minimum: 1
},
stream: {
type: 'boolean',
default: false
},
tools: {
type: 'array',
items: {
type: 'object',
properties: {
type: { type: 'string' },
function: {
type: 'object',
properties: {
name: { type: 'string' },
description: { type: 'string' },
parameters: { type: 'object' }
}
}
}
}
}
}
},
Model: {
type: 'object',
properties: {
id: { type: 'string' },
object: { type: 'string', enum: ['model'] },
created: { type: 'integer' },
owned_by: { type: 'string' }
}
},
MCPServer: {
type: 'object',
properties: {
id: { type: 'string' },
name: { type: 'string' },
command: { type: 'string' },
args: {
type: 'array',
items: { type: 'string' }
},
env: { type: 'object' },
disabled: { type: 'boolean' }
}
}
}
},
security: [
{
BearerAuth: []
}
]
},
apis: ['./src/main/apiServer/routes/*.ts', './src/main/apiServer/app.ts']
}
export function setupOpenAPIDocumentation(app: Express) {
try {
const specs = swaggerJSDoc(swaggerOptions)
// Serve OpenAPI JSON
app.get('/api-docs.json', (_req, res) => {
res.setHeader('Content-Type', 'application/json')
res.send(specs)
})
// Serve Swagger UI
app.use(
'/api-docs',
swaggerUi.serve,
swaggerUi.setup(specs, {
customCss: `
.swagger-ui .topbar { display: none; }
.swagger-ui .info .title { color: #1890ff; }
`,
customSiteTitle: 'Cherry Studio API Documentation'
})
)
logger.info('OpenAPI documentation setup complete')
logger.info('Documentation available at /api-docs')
logger.info('OpenAPI spec available at /api-docs.json')
} catch (error) {
logger.error('Failed to setup OpenAPI documentation:', error as Error)
}
}

View File

@ -0,0 +1,225 @@
import express, { Request, Response } from 'express'
import OpenAI from 'openai'
import { ChatCompletionCreateParams } from 'openai/resources'
import { loggerService } from '../../services/LoggerService'
import { chatCompletionService } from '../services/chat-completion'
import { validateModelId } from '../utils'
const logger = loggerService.withContext('ApiServerChatRoutes')
const router = express.Router()
/**
* @swagger
* /v1/chat/completions:
* post:
* summary: Create chat completion
* description: Create a chat completion response, compatible with OpenAI API
* tags: [Chat]
* requestBody:
* required: true
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/ChatCompletionRequest'
* responses:
* 200:
* description: Chat completion response
* content:
* application/json:
* schema:
* type: object
* properties:
* id:
* type: string
* object:
* type: string
* example: chat.completion
* created:
* type: integer
* model:
* type: string
* choices:
* type: array
* items:
* type: object
* properties:
* index:
* type: integer
* message:
* $ref: '#/components/schemas/ChatMessage'
* finish_reason:
* type: string
* usage:
* type: object
* properties:
* prompt_tokens:
* type: integer
* completion_tokens:
* type: integer
* total_tokens:
* type: integer
* text/plain:
* schema:
* type: string
* description: Server-sent events stream (when stream=true)
* 400:
* description: Bad request
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/Error'
* 401:
* description: Unauthorized
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/Error'
* 429:
* description: Rate limit exceeded
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/Error'
* 500:
* description: Internal server error
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/Error'
*/
router.post('/completions', async (req: Request, res: Response) => {
try {
const request: ChatCompletionCreateParams = req.body
if (!request) {
return res.status(400).json({
error: {
message: 'Request body is required',
type: 'invalid_request_error',
code: 'missing_body'
}
})
}
logger.info('Chat completion request:', {
model: request.model,
messageCount: request.messages?.length || 0,
stream: request.stream,
temperature: request.temperature
})
// Validate request
const validation = chatCompletionService.validateRequest(request)
if (!validation.isValid) {
return res.status(400).json({
error: {
message: validation.errors.join('; '),
type: 'invalid_request_error',
code: 'validation_failed'
}
})
}
// Validate model ID and get provider
const modelValidation = await validateModelId(request.model)
if (!modelValidation.valid) {
const error = modelValidation.error!
logger.warn(`Model validation failed for '${request.model}':`, error)
return res.status(400).json({
error: {
message: error.message,
type: 'invalid_request_error',
code: error.code
}
})
}
const provider = modelValidation.provider!
const modelId = modelValidation.modelId!
logger.info('Model validation successful:', {
provider: provider.id,
providerType: provider.type,
modelId: modelId,
fullModelId: request.model
})
// Create OpenAI client
const client = new OpenAI({
baseURL: provider.apiHost,
apiKey: provider.apiKey
})
request.model = modelId
// Handle streaming
if (request.stream) {
const streamResponse = await client.chat.completions.create(request)
res.setHeader('Content-Type', 'text/plain; charset=utf-8')
res.setHeader('Cache-Control', 'no-cache')
res.setHeader('Connection', 'keep-alive')
try {
for await (const chunk of streamResponse as any) {
res.write(`data: ${JSON.stringify(chunk)}\n\n`)
}
res.write('data: [DONE]\n\n')
res.end()
} catch (streamError: any) {
logger.error('Stream error:', streamError)
res.write(
`data: ${JSON.stringify({
error: {
message: 'Stream processing error',
type: 'server_error',
code: 'stream_error'
}
})}\n\n`
)
res.end()
}
return
}
// Handle non-streaming
const response = await client.chat.completions.create(request)
return res.json(response)
} catch (error: any) {
logger.error('Chat completion error:', error)
let statusCode = 500
let errorType = 'server_error'
let errorCode = 'internal_error'
let errorMessage = 'Internal server error'
if (error instanceof Error) {
errorMessage = error.message
if (error.message.includes('API key') || error.message.includes('authentication')) {
statusCode = 401
errorType = 'authentication_error'
errorCode = 'invalid_api_key'
} else if (error.message.includes('rate limit') || error.message.includes('quota')) {
statusCode = 429
errorType = 'rate_limit_error'
errorCode = 'rate_limit_exceeded'
} else if (error.message.includes('timeout') || error.message.includes('connection')) {
statusCode = 502
errorType = 'server_error'
errorCode = 'upstream_error'
}
}
return res.status(statusCode).json({
error: {
message: errorMessage,
type: errorType,
code: errorCode
}
})
}
})
export { router as chatRoutes }
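A client-side sketch of both modes (assumptions: a provider configured with id `my-openai`, a `gpt-4o-mini` model under it, and a placeholder API key):

```ts
// ESM sketch, Node 18+. Model ids are namespaced as "provider:model_id".
const BASE = 'http://localhost:23333'
const headers = {
  'Content-Type': 'application/json',
  Authorization: 'Bearer cs-sk-00000000-0000-0000-0000-000000000000' // placeholder
}

// Non-streaming: a single JSON chat.completion object
const res = await fetch(`${BASE}/v1/chat/completions`, {
  method: 'POST',
  headers,
  body: JSON.stringify({
    model: 'my-openai:gpt-4o-mini', // hypothetical provider:model pair
    messages: [{ role: 'user', content: 'Hello' }]
  })
})
console.log((await res.json()).choices[0].message)

// Streaming: "data: {...}" lines, terminated by "data: [DONE]"
const stream = await fetch(`${BASE}/v1/chat/completions`, {
  method: 'POST',
  headers,
  body: JSON.stringify({
    model: 'my-openai:gpt-4o-mini',
    messages: [{ role: 'user', content: 'Hello' }],
    stream: true
  })
})
for await (const chunk of stream.body!) {
  process.stdout.write(Buffer.from(chunk).toString('utf8'))
}
```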

View File

@ -0,0 +1,153 @@
import express, { Request, Response } from 'express'
import { loggerService } from '../../services/LoggerService'
import { mcpApiService } from '../services/mcp'
const logger = loggerService.withContext('ApiServerMCPRoutes')
const router = express.Router()
/**
* @swagger
* /v1/mcps:
* get:
* summary: List MCP servers
* description: Get a list of all configured Model Context Protocol servers
* tags: [MCP]
* responses:
* 200:
* description: List of MCP servers
* content:
* application/json:
* schema:
* type: object
* properties:
* success:
* type: boolean
* data:
* type: array
* items:
* $ref: '#/components/schemas/MCPServer'
* 503:
* description: Service unavailable
* content:
* application/json:
* schema:
* type: object
* properties:
* success:
* type: boolean
* example: false
* error:
* $ref: '#/components/schemas/Error'
*/
router.get('/', async (req: Request, res: Response) => {
try {
logger.info('Get all MCP servers request received')
const servers = await mcpApiService.getAllServers(req)
return res.json({
success: true,
data: servers
})
} catch (error: any) {
logger.error('Error fetching MCP servers:', error)
return res.status(503).json({
success: false,
error: {
message: `Failed to retrieve MCP servers: ${error.message}`,
type: 'service_unavailable',
code: 'servers_unavailable'
}
})
}
})
/**
* @swagger
* /v1/mcps/{server_id}:
* get:
* summary: Get MCP server info
* description: Get detailed information about a specific MCP server
* tags: [MCP]
* parameters:
* - in: path
* name: server_id
* required: true
* schema:
* type: string
* description: MCP server ID
* responses:
* 200:
* description: MCP server information
* content:
* application/json:
* schema:
* type: object
* properties:
* success:
* type: boolean
* data:
* $ref: '#/components/schemas/MCPServer'
* 404:
* description: MCP server not found
* content:
* application/json:
* schema:
* type: object
* properties:
* success:
* type: boolean
* example: false
* error:
* $ref: '#/components/schemas/Error'
*/
router.get('/:server_id', async (req: Request, res: Response) => {
try {
logger.info('Get MCP server info request received')
const server = await mcpApiService.getServerInfo(req.params.server_id)
if (!server) {
logger.warn('MCP server not found')
return res.status(404).json({
success: false,
error: {
message: 'MCP server not found',
type: 'not_found',
code: 'server_not_found'
}
})
}
return res.json({
success: true,
data: server
})
} catch (error: any) {
logger.error('Error fetching MCP server info:', error)
return res.status(503).json({
success: false,
error: {
message: `Failed to retrieve MCP server info: ${error.message}`,
type: 'service_unavailable',
code: 'server_info_unavailable'
}
})
}
})
// Connect to MCP server
router.all('/:server_id/mcp', async (req: Request, res: Response) => {
const server = await mcpApiService.getServerById(req.params.server_id)
if (!server) {
logger.warn('MCP server not found')
return res.status(404).json({
success: false,
error: {
message: 'MCP server not found',
type: 'not_found',
code: 'server_not_found'
}
})
}
return await mcpApiService.handleRequest(req, res, server)
})
export { router as mcpRoutes }
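A sketch of a client walking these routes (placeholder key; server ids come from the first response):

```ts
// List active MCP servers, then fetch one server's details and tool listing.
const headers = { Authorization: 'Bearer cs-sk-00000000-0000-0000-0000-000000000000' } // placeholder

const list = await fetch('http://localhost:23333/v1/mcps', { headers })
const { data } = await list.json() // { servers: { [id]: { id, name, type, description, url } } }

const firstId = Object.keys(data.servers)[0]
if (firstId) {
  const info = await fetch(`http://localhost:23333/v1/mcps/${firstId}`, { headers })
  console.log(await info.json()) // { success: true, data: { id, name, type, description, tools } }
}
```

The per-server `url` field points at the `/:server_id/mcp` route above, which speaks the MCP streamable-HTTP transport rather than plain JSON.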

View File

@ -0,0 +1,73 @@
import express, { Request, Response } from 'express'
import { loggerService } from '../../services/LoggerService'
import { chatCompletionService } from '../services/chat-completion'
const logger = loggerService.withContext('ApiServerModelsRoutes')
const router = express.Router()
/**
* @swagger
* /v1/models:
* get:
* summary: List available models
* description: Returns a list of available AI models from all configured providers
* tags: [Models]
* responses:
* 200:
* description: List of available models
* content:
* application/json:
* schema:
* type: object
* properties:
* object:
* type: string
* example: list
* data:
* type: array
* items:
* $ref: '#/components/schemas/Model'
* 503:
* description: Service unavailable
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/Error'
*/
router.get('/', async (_req: Request, res: Response) => {
try {
logger.info('Models list request received')
const models = await chatCompletionService.getModels()
if (models.length === 0) {
logger.warn(
'No models available from providers. This may be because no OpenAI providers are configured or enabled.'
)
}
logger.info(`Returning ${models.length} models (OpenAI providers only)`)
logger.debug(
'Model IDs:',
models.map((m) => m.id)
)
return res.json({
object: 'list',
data: models
})
} catch (error: any) {
logger.error('Error fetching models:', error)
return res.status(503).json({
error: {
message: 'Failed to retrieve models from available providers',
type: 'service_unavailable',
code: 'models_unavailable'
}
})
}
})
export { router as modelsRoutes }
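A minimal sketch of consuming this route (placeholder key), showing the namespaced ids:

```ts
// Model ids come back as "provider:model_id"; provider_id/model_id split them out.
const res = await fetch('http://localhost:23333/v1/models', {
  headers: { 'x-api-key': 'cs-sk-00000000-0000-0000-0000-000000000000' } // placeholder
})
const { data } = await res.json()
for (const m of data) {
  console.log(m.id, '->', m.provider_id, '/', m.model_id)
}
```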

View File

@ -0,0 +1,65 @@
import { createServer } from 'node:http'
import { loggerService } from '../services/LoggerService'
import { app } from './app'
import { config } from './config'
const logger = loggerService.withContext('ApiServer')
export class ApiServer {
private server: ReturnType<typeof createServer> | null = null
async start(): Promise<void> {
if (this.server) {
logger.warn('Server already running')
return
}
// Load config
const { port, host, apiKey } = await config.load()
// Create server with Express app
this.server = createServer(app)
// Start server
return new Promise((resolve, reject) => {
this.server!.listen(port, host, () => {
logger.info(`API Server started at http://${host}:${port}`)
logger.info(`API Key: ${apiKey}`)
resolve()
})
this.server!.on('error', reject)
})
}
async stop(): Promise<void> {
if (!this.server) return
return new Promise((resolve) => {
this.server!.close(() => {
logger.info('API Server stopped')
this.server = null
resolve()
})
})
}
async restart(): Promise<void> {
await this.stop()
await config.reload()
await this.start()
}
isRunning(): boolean {
const hasServer = this.server !== null
const isListening = this.server?.listening || false
const result = hasServer && isListening
logger.debug('isRunning check:', { hasServer, isListening, result })
return result
}
}
export const apiServer = new ApiServer()
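A hypothetical lifecycle sketch from main-process code:

```ts
import { apiServer } from './server'

async function cycle(): Promise<void> {
  await apiServer.start() // loads config, binds host:port, logs the API key
  console.log(apiServer.isRunning()) // true once listen() succeeds
  await apiServer.restart() // stop, then config.reload(), then start
  await apiServer.stop()
}
```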

View File

@ -0,0 +1,239 @@
import OpenAI from 'openai'
import { ChatCompletionCreateParams } from 'openai/resources'
import { loggerService } from '../../services/LoggerService'
import {
getProviderByModel,
getRealProviderModel,
listAllAvailableModels,
OpenAICompatibleModel,
transformModelToOpenAI,
validateProvider
} from '../utils'
const logger = loggerService.withContext('ChatCompletionService')
export interface ModelData extends OpenAICompatibleModel {
provider_id: string
model_id: string
name: string
}
export interface ValidationResult {
isValid: boolean
errors: string[]
}
export class ChatCompletionService {
async getModels(): Promise<ModelData[]> {
try {
logger.info('Getting available models from providers')
const models = await listAllAvailableModels()
// Use Map to deduplicate models by their full ID (provider:model_id)
const uniqueModels = new Map<string, ModelData>()
for (const model of models) {
const openAIModel = transformModelToOpenAI(model)
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
// Only add if not already present (first occurrence wins)
if (!uniqueModels.has(fullModelId)) {
uniqueModels.set(fullModelId, {
...openAIModel,
provider_id: model.provider,
model_id: model.id,
name: model.name
})
} else {
logger.debug(`Skipping duplicate model: ${fullModelId}`)
}
}
const modelData = Array.from(uniqueModels.values())
logger.info(`Successfully retrieved ${modelData.length} unique models from ${models.length} total models`)
if (models.length > modelData.length) {
logger.debug(`Filtered out ${models.length - modelData.length} duplicate models`)
}
return modelData
} catch (error: any) {
logger.error('Error getting models:', error)
return []
}
}
validateRequest(request: ChatCompletionCreateParams): ValidationResult {
const errors: string[] = []
// Validate model
if (!request.model) {
errors.push('Model is required')
} else if (typeof request.model !== 'string') {
errors.push('Model must be a string')
} else if (!request.model.includes(':')) {
errors.push('Model must be in format "provider:model_id"')
}
// Validate messages
if (!request.messages) {
errors.push('Messages array is required')
} else if (!Array.isArray(request.messages)) {
errors.push('Messages must be an array')
} else if (request.messages.length === 0) {
errors.push('Messages array cannot be empty')
} else {
// Validate each message
request.messages.forEach((message, index) => {
if (!message.role) {
errors.push(`Message ${index}: role is required`)
}
if (!message.content) {
errors.push(`Message ${index}: content is required`)
}
})
}
// Validate optional parameters
if (request.temperature !== undefined) {
if (typeof request.temperature !== 'number' || request.temperature < 0 || request.temperature > 2) {
errors.push('Temperature must be a number between 0 and 2')
}
}
if (request.max_tokens !== undefined) {
if (typeof request.max_tokens !== 'number' || request.max_tokens < 1) {
errors.push('max_tokens must be a positive number')
}
}
return {
isValid: errors.length === 0,
errors
}
}
async processCompletion(request: ChatCompletionCreateParams): Promise<OpenAI.Chat.Completions.ChatCompletion> {
try {
logger.info('Processing chat completion request:', {
model: request.model,
messageCount: request.messages.length,
stream: request.stream
})
// Validate request
const validation = this.validateRequest(request)
if (!validation.isValid) {
throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
}
// Get provider for the model
const provider = await getProviderByModel(request.model!)
if (!provider) {
throw new Error(`Provider not found for model: ${request.model}`)
}
// Validate provider
if (!validateProvider(provider)) {
throw new Error(`Provider validation failed for: ${provider.id}`)
}
// Extract model ID from the full model string
const modelId = getRealProviderModel(request.model)
// Create OpenAI client for the provider
const client = new OpenAI({
baseURL: provider.apiHost,
apiKey: provider.apiKey
})
// Prepare request with the actual model ID
const providerRequest = {
...request,
model: modelId,
stream: false
}
logger.debug('Sending request to provider:', {
provider: provider.id,
model: modelId,
apiHost: provider.apiHost
})
const response = (await client.chat.completions.create(providerRequest)) as OpenAI.Chat.Completions.ChatCompletion
logger.info('Successfully processed chat completion')
return response
} catch (error: any) {
logger.error('Error processing chat completion:', error)
throw error
}
}
async *processStreamingCompletion(
request: ChatCompletionCreateParams
): AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk> {
try {
logger.info('Processing streaming chat completion request:', {
model: request.model,
messageCount: request.messages.length
})
// Validate request
const validation = this.validateRequest(request)
if (!validation.isValid) {
throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
}
// Get provider for the model
const provider = await getProviderByModel(request.model!)
if (!provider) {
throw new Error(`Provider not found for model: ${request.model}`)
}
// Validate provider
if (!validateProvider(provider)) {
throw new Error(`Provider validation failed for: ${provider.id}`)
}
// Extract model ID from the full model string
const modelId = getRealProviderModel(request.model)
// Create OpenAI client for the provider
const client = new OpenAI({
baseURL: provider.apiHost,
apiKey: provider.apiKey
})
// Prepare streaming request
const streamingRequest = {
...request,
model: modelId,
stream: true as const
}
logger.debug('Sending streaming request to provider:', {
provider: provider.id,
model: modelId,
apiHost: provider.apiHost
})
const stream = await client.chat.completions.create(streamingRequest)
for await (const chunk of stream) {
yield chunk
}
logger.info('Successfully completed streaming chat completion')
} catch (error: any) {
logger.error('Error processing streaming chat completion:', error)
throw error
}
}
}
// Export singleton instance
export const chatCompletionService = new ChatCompletionService()
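A small sketch of the validation step in isolation; malformed payloads are rejected before any provider client is constructed:

```ts
import { chatCompletionService } from './chat-completion'

const check = chatCompletionService.validateRequest({
  model: 'gpt-4', // missing the required "provider:" prefix
  messages: []
})
// check.isValid === false
// check.errors: ['Model must be in format "provider:model_id"', 'Messages array cannot be empty']
console.log(check)
```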

View File

@ -0,0 +1,251 @@
import mcpService from '@main/services/MCPService'
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp'
import {
isJSONRPCRequest,
JSONRPCMessage,
JSONRPCMessageSchema,
MessageExtraInfo
} from '@modelcontextprotocol/sdk/types'
import { MCPServer } from '@types'
import { randomUUID } from 'crypto'
import { EventEmitter } from 'events'
import { Request, Response } from 'express'
import { IncomingMessage, ServerResponse } from 'http'
import { loggerService } from '../../services/LoggerService'
import { reduxService } from '../../services/ReduxService'
import { getMcpServerById } from '../utils/mcp'
const logger = loggerService.withContext('MCPApiService')
const transports: Record<string, StreamableHTTPServerTransport> = {}
interface McpServerDTO {
id: MCPServer['id']
name: MCPServer['name']
type: MCPServer['type']
description: MCPServer['description']
url: string
}
interface McpServersResp {
servers: Record<string, McpServerDTO>
}
/**
* MCPApiService - API layer for MCP server management
*
* This service provides a REST API interface for MCP servers while integrating
* with the existing application architecture:
*
* 1. Uses ReduxService to access the renderer's Redux store directly
* 2. Syncs changes back to the renderer via Redux actions
* 3. Leverages existing MCPService for actual server connections
* 4. Provides session management for API clients
*/
class MCPApiService extends EventEmitter {
private transport: StreamableHTTPServerTransport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID()
})
constructor() {
super()
this.initMcpServer()
logger.silly('MCPApiService initialized')
}
private initMcpServer() {
this.transport.onmessage = this.onMessage
}
/**
* Get servers directly from Redux store
*/
private async getServersFromRedux(): Promise<MCPServer[]> {
try {
logger.silly('Getting servers from Redux store')
// Try to get from cache first (faster)
const cachedServers = reduxService.selectSync<MCPServer[]>('state.mcp.servers')
if (cachedServers && Array.isArray(cachedServers)) {
logger.silly(`Found ${cachedServers.length} servers in Redux cache`)
return cachedServers
}
// If cache is not available, get fresh data
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
return servers || []
} catch (error: any) {
logger.error('Failed to get servers from Redux:', error)
return []
}
}
// get all activated servers
async getAllServers(req: Request): Promise<McpServersResp> {
try {
const servers = await this.getServersFromRedux()
logger.silly(`Returning ${servers.length} servers`)
const resp: McpServersResp = {
servers: {}
}
for (const server of servers) {
if (server.isActive) {
resp.servers[server.id] = {
id: server.id,
name: server.name,
type: 'streamableHttp',
description: server.description,
url: `${req.protocol}://${req.host}/v1/mcps/${server.id}/mcp`
}
}
}
return resp
} catch (error: any) {
logger.error('Failed to get all servers:', error)
throw new Error('Failed to retrieve servers')
}
}
// get server by id
async getServerById(id: string): Promise<MCPServer | null> {
try {
logger.silly(`getServerById called with id: ${id}`)
const servers = await this.getServersFromRedux()
const server = servers.find((s) => s.id === id)
if (!server) {
logger.warn(`Server with id ${id} not found`)
return null
}
logger.silly(`Returning server with id ${id}`)
return server
} catch (error: any) {
logger.error(`Failed to get server with id ${id}:`, error)
throw new Error('Failed to retrieve server')
}
}
async getServerInfo(id: string): Promise<any> {
try {
logger.silly(`getServerInfo called with id: ${id}`)
const server = await this.getServerById(id)
if (!server) {
logger.warn(`Server with id ${id} not found`)
return null
}
logger.silly(`Returning server info for id ${id}`)
const client = await mcpService.initClient(server)
const tools = await client.listTools()
logger.info(`Server with id ${id} info:`, { tools: JSON.stringify(tools) })
// const [version, tools, prompts, resources] = await Promise.all([
// () => {
// try {
// return client.getServerVersion()
// } catch (error) {
// logger.error(`Failed to get server version for id ${id}:`, { error: error })
// return '1.0.0'
// }
// },
// (() => {
// try {
// return client.listTools()
// } catch (error) {
// logger.error(`Failed to list tools for id ${id}:`, { error: error })
// return []
// }
// })(),
// (() => {
// try {
// return client.listPrompts()
// } catch (error) {
// logger.error(`Failed to list prompts for id ${id}:`, { error: error })
// return []
// }
// })(),
// (() => {
// try {
// return client.listResources()
// } catch (error) {
// logger.error(`Failed to list resources for id ${id}:`, { error: error })
// return []
// }
// })()
// ])
return {
id: server.id,
name: server.name,
type: server.type,
description: server.description,
tools
}
} catch (error: any) {
logger.error(`Failed to get server info with id ${id}:`, error)
throw new Error('Failed to retrieve server info')
}
}
async handleRequest(req: Request, res: Response, server: MCPServer) {
const sessionId = req.headers['mcp-session-id'] as string | undefined
logger.silly(`Handling request for server with sessionId ${sessionId}`)
let transport: StreamableHTTPServerTransport
if (sessionId && transports[sessionId]) {
transport = transports[sessionId]
} else {
transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
onsessioninitialized: (sessionId) => {
transports[sessionId] = transport
}
})
transport.onclose = () => {
logger.info(`Transport for sessionId ${sessionId} closed`)
if (transport.sessionId) {
delete transports[transport.sessionId]
}
}
const mcpServer = await getMcpServerById(server.id)
if (mcpServer) {
await mcpServer.connect(transport)
}
}
const jsonpayload = req.body
const messages: JSONRPCMessage[] = []
if (Array.isArray(jsonpayload)) {
for (const payload of jsonpayload) {
const message = JSONRPCMessageSchema.parse(payload)
messages.push(message)
}
} else {
const message = JSONRPCMessageSchema.parse(jsonpayload)
messages.push(message)
}
for (const message of messages) {
if (isJSONRPCRequest(message)) {
if (!message.params) {
message.params = {}
}
if (!message.params._meta) {
message.params._meta = {}
}
message.params._meta.serverId = server.id
}
}
logger.info(`Request body`, { rawBody: req.body, messages: JSON.stringify(messages) })
await transport.handleRequest(req as IncomingMessage, res as ServerResponse, messages)
}
private onMessage(message: JSONRPCMessage, extra?: MessageExtraInfo) {
logger.info(`Received message: ${JSON.stringify(message)}`, extra)
// Handle message here
}
}
export const mcpApiService = new MCPApiService()

View File

@ -0,0 +1,231 @@
import { loggerService } from '@main/services/LoggerService'
import { reduxService } from '@main/services/ReduxService'
import { Model, Provider } from '@types'
const logger = loggerService.withContext('ApiServerUtils')
// OpenAI compatible model format
export interface OpenAICompatibleModel {
id: string
object: 'model'
created: number
owned_by: string
provider?: string
provider_model_id?: string
}
export async function getAvailableProviders(): Promise<Provider[]> {
try {
// Wait for store to be ready before accessing providers
const providers = await reduxService.select('state.llm.providers')
if (!providers || !Array.isArray(providers)) {
logger.warn('No providers found in Redux store, returning empty array')
return []
}
// Only support OpenAI type providers for API server
const openAIProviders = providers.filter((p: Provider) => p.enabled && p.type === 'openai')
logger.info(`Filtered to ${openAIProviders.length} OpenAI providers from ${providers.length} total providers`)
return openAIProviders
} catch (error: any) {
logger.error('Failed to get providers from Redux store:', error)
return []
}
}
export async function listAllAvailableModels(): Promise<Model[]> {
try {
const providers = await getAvailableProviders()
return providers.map((p: Provider) => p.models || []).flat()
} catch (error: any) {
logger.error('Failed to list available models:', error)
return []
}
}
export async function getProviderByModel(model: string): Promise<Provider | undefined> {
try {
if (!model || typeof model !== 'string') {
logger.warn(`Invalid model parameter: ${model}`)
return undefined
}
// Validate model format first
if (!model.includes(':')) {
logger.warn(
`Invalid model format, must contain ':' separator. Expected format "provider:model_id", got: ${model}`
)
return undefined
}
const providers = await getAvailableProviders()
const modelInfo = model.split(':')
if (modelInfo.length < 2 || modelInfo[0].length === 0 || modelInfo[1].length === 0) {
logger.warn(`Invalid model format, expected "provider:model_id" with non-empty parts, got: ${model}`)
return undefined
}
const providerId = modelInfo[0]
const provider = providers.find((p: Provider) => p.id === providerId)
if (!provider) {
logger.warn(
`Provider '${providerId}' not found or not enabled. Available providers: ${providers.map((p) => p.id).join(', ')}`
)
return undefined
}
logger.debug(`Found provider '${providerId}' for model: ${model}`)
return provider
} catch (error: any) {
logger.error('Failed to get provider by model:', error)
return undefined
}
}
export function getRealProviderModel(modelStr: string): string {
return modelStr.split(':').slice(1).join(':')
}
export interface ModelValidationError {
type: 'invalid_format' | 'provider_not_found' | 'model_not_available' | 'unsupported_provider_type'
message: string
code: string
}
export async function validateModelId(
model: string
): Promise<{ valid: boolean; error?: ModelValidationError; provider?: Provider; modelId?: string }> {
try {
if (!model || typeof model !== 'string') {
return {
valid: false,
error: {
type: 'invalid_format',
message: 'Model must be a non-empty string',
code: 'invalid_model_parameter'
}
}
}
if (!model.includes(':')) {
return {
valid: false,
error: {
type: 'invalid_format',
message: "Invalid model format. Expected format: 'provider:model_id' (e.g., 'my-openai:gpt-4')",
code: 'invalid_model_format'
}
}
}
const modelInfo = model.split(':')
if (modelInfo.length < 2 || modelInfo[0].length === 0 || modelInfo[1].length === 0) {
return {
valid: false,
error: {
type: 'invalid_format',
message: "Invalid model format. Both provider and model_id must be non-empty. Expected: 'provider:model_id'",
code: 'invalid_model_format'
}
}
}
const providerId = modelInfo[0]
const modelId = getRealProviderModel(model)
const provider = await getProviderByModel(model)
if (!provider) {
return {
valid: false,
error: {
type: 'provider_not_found',
message: `Provider '${providerId}' not found, not enabled, or not supported. Only OpenAI providers are currently supported.`,
code: 'provider_not_found'
}
}
}
// Check if model exists in provider
const modelExists = provider.models?.some((m) => m.id === modelId)
if (!modelExists) {
const availableModels = provider.models?.map((m) => m.id).join(', ') || 'none'
return {
valid: false,
error: {
type: 'model_not_available',
message: `Model '${modelId}' not available in provider '${providerId}'. Available models: ${availableModels}`,
code: 'model_not_available'
}
}
}
return {
valid: true,
provider,
modelId
}
} catch (error: any) {
logger.error('Error validating model ID:', error)
return {
valid: false,
error: {
type: 'invalid_format',
message: 'Failed to validate model ID',
code: 'validation_error'
}
}
}
}
export function transformModelToOpenAI(model: Model): OpenAICompatibleModel {
return {
id: `${model.provider}:${model.id}`,
object: 'model',
created: Math.floor(Date.now() / 1000),
owned_by: model.owned_by || model.provider,
provider: model.provider,
provider_model_id: model.id
}
}
export function validateProvider(provider: Provider): boolean {
try {
if (!provider) {
return false
}
// Check required fields
if (!provider.id || !provider.type || !provider.apiKey || !provider.apiHost) {
logger.warn('Provider missing required fields:', {
id: !!provider.id,
type: !!provider.type,
apiKey: !!provider.apiKey,
apiHost: !!provider.apiHost
})
return false
}
// Check if provider is enabled
if (!provider.enabled) {
logger.debug(`Provider is disabled: ${provider.id}`)
return false
}
// Only support OpenAI type providers
if (provider.type !== 'openai') {
logger.debug(
`Provider type '${provider.type}' not supported, only 'openai' type is currently supported: ${provider.id}`
)
return false
}
return true
} catch (error: any) {
logger.error('Error validating provider:', error)
return false
}
}
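Two behaviors worth noting, sketched below (the import path is hypothetical): `getRealProviderModel` keeps everything after the first colon, so model ids that themselves contain colons survive intact, and `validateModelId` reports a typed error for unknown providers.

```ts
import { getRealProviderModel, validateModelId } from './utils' // hypothetical path

console.log(getRealProviderModel('my-openai:gpt-4')) // 'gpt-4'
console.log(getRealProviderModel('my-openai:org:custom:model')) // 'org:custom:model'

const result = await validateModelId('unknown:gpt-4')
// { valid: false, error: { type: 'provider_not_found', code: 'provider_not_found', message: '...' } }
console.log(result)
```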

View File

@ -0,0 +1,76 @@
import mcpService from '@main/services/MCPService'
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { CallToolRequestSchema, ListToolsRequestSchema, ListToolsResult } from '@modelcontextprotocol/sdk/types.js'
import { MCPServer } from '@types'
import { loggerService } from '../../services/LoggerService'
import { reduxService } from '../../services/ReduxService'
const logger = loggerService.withContext('MCPApiService')
const cachedServers: Record<string, Server> = {}
async function handleListToolsRequest(request: any, extra: any): Promise<ListToolsResult> {
logger.debug('Handling list tools request', { request: request, extra: extra })
const serverId: string = request.params._meta.serverId
const serverConfig = await getMcpServerConfigById(serverId)
if (!serverConfig) {
throw new Error(`Server not found: ${serverId}`)
}
const client = await mcpService.initClient(serverConfig)
return client.listTools()
}
async function handleCallToolRequest(request: any, extra: any): Promise<any> {
logger.debug('Handling call tool request', { request: request, extra: extra })
const serverId: string = request.params._meta.serverId
const serverConfig = await getMcpServerConfigById(serverId)
if (!serverConfig) {
throw new Error(`Server not found: ${serverId}`)
}
const client = await mcpService.initClient(serverConfig)
return client.callTool(request.params)
}
async function getMcpServerConfigById(id: string): Promise<MCPServer | undefined> {
const servers = await getServersFromRedux()
return servers.find((s) => s.id === id || s.name === id)
}
/**
* Get servers directly from Redux store
*/
async function getServersFromRedux(): Promise<MCPServer[]> {
try {
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
return servers || []
} catch (error: any) {
logger.error('Failed to get servers from Redux:', error)
return []
}
}
export async function getMcpServerById(id: string): Promise<Server> {
const server = cachedServers[id]
if (!server) {
const servers = await getServersFromRedux()
const mcpServer = servers.find((s) => s.id === id || s.name === id)
if (!mcpServer) {
throw new Error(`Server not found: ${id}`)
}
const createMcpServer = (name: string, version: string): Server => {
const server = new Server({ name: name, version }, { capabilities: { tools: {} } })
server.setRequestHandler(ListToolsRequestSchema, handleListToolsRequest)
server.setRequestHandler(CallToolRequestSchema, handleCallToolRequest)
return server
}
const newServer = createMcpServer(mcpServer.name, '0.1.0')
cachedServers[id] = newServer
return newServer
}
logger.silly('getMcpServer ', { server: server })
return server
}

View File

@ -31,6 +31,7 @@ import { TrayService } from './services/TrayService'
import { windowService } from './services/WindowService'
import { dataRefactorMigrateService } from './data/migrate/dataRefactor/DataRefactorMigrateService'
import process from 'node:process'
import { apiServerService } from './services/ApiServerService'
const logger = loggerService.withContext('MainEntry')
@ -219,6 +220,17 @@ if (!app.requestSingleInstanceLock()) {
//start selection assistant service
initSelectionService()
// Start API server if enabled
try {
const config = await apiServerService.getCurrentConfig()
logger.info('API server config:', config)
if (config.enabled) {
await apiServerService.start()
}
} catch (error: any) {
logger.error('Failed to check/start API server:', error)
}
})
registerProtocolClient(app)
@ -264,6 +276,7 @@ if (!app.requestSingleInstanceLock()) {
// Simple resource cleanup; must not block the quit flow
try {
await mcpService.cleanup()
await apiServerService.stop()
} catch (error) {
logger.warn('Error cleaning up MCP service:', error as Error)
}

View File

@ -15,9 +15,12 @@ import { MIN_WINDOW_HEIGHT, MIN_WINDOW_WIDTH } from '@shared/config/constant'
import { UpgradeChannel } from '@shared/data/preferenceTypes'
import { IpcChannel } from '@shared/IpcChannel'
import { FileMetadata, Provider, Shortcut } from '@types'
import checkDiskSpace from 'check-disk-space'
import { BrowserWindow, dialog, ipcMain, ProxyConfig, session, shell, systemPreferences, webContents } from 'electron'
import fontList from 'font-list'
import { Notification } from 'src/renderer/src/types/notification'
import { apiServerService } from './services/ApiServerService'
import appService from './services/AppService'
import AppUpdater from './services/AppUpdater'
import BackupManager from './services/BackupManager'
@ -217,6 +220,17 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
return mainWindow.isFullScreen()
})
// Get System Fonts
ipcMain.handle(IpcChannel.App_GetSystemFonts, async () => {
try {
const fonts = await fontList.getFonts()
return fonts.map((font: string) => font.replace(/^"(.*)"$/, '$1')).filter((font: string) => font.length > 0)
} catch (error) {
logger.error('Failed to get system fonts:', error as Error)
return []
}
})
ipcMain.handle(IpcChannel.Config_Set, (_, key: string, value: any, isNotify: boolean = false) => {
configManager.set(key, value, isNotify)
})
@ -782,6 +796,23 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
addStreamMessage(spanId, modelName, context, msg)
)
ipcMain.handle(IpcChannel.App_GetDiskInfo, async (_, directoryPath: string) => {
try {
const diskSpace = await checkDiskSpace(directoryPath) // { free, size } in bytes
logger.debug('disk space', diskSpace)
const { free, size } = diskSpace
return {
free,
size
}
} catch (error) {
logger.error('check disk space error', error as Error)
return null
}
})
// API Server
apiServerService.registerIpcHandlers()
// Anthropic OAuth
ipcMain.handle(IpcChannel.Anthropic_StartOAuthFlow, () => anthropicService.startOAuthFlow())
ipcMain.handle(IpcChannel.Anthropic_CompleteOAuthWithCode, (_, code: string) =>

View File

@ -1,7 +1,7 @@
import { Embeddings, type EmbeddingsParams } from '@langchain/core/embeddings'
import { chunkArray } from '@langchain/core/utils/chunk_array'
import { getEnvironmentVariable } from '@langchain/core/utils/env'
import z from 'zod/v4'
import { z } from 'zod'
const jinaModelSchema = z.union([
z.literal('jina-clip-v2'),

View File

@ -3,7 +3,7 @@ import { loggerService } from '@logger'
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
import { net } from 'electron'
import * as z from 'zod/v4'
import { z } from 'zod'
const logger = loggerService.withContext('DifyKnowledgeServer')

View File

@ -8,8 +8,8 @@ import TurndownService from 'turndown'
import { z } from 'zod'
export const RequestPayloadSchema = z.object({
url: z.string().url(),
headers: z.record(z.string()).optional()
url: z.url(),
headers: z.record(z.string(), z.string()).optional()
})
export type RequestPayload = z.infer<typeof RequestPayloadSchema>

View File

@ -8,7 +8,7 @@ import fs from 'fs/promises'
import { minimatch } from 'minimatch'
import os from 'os'
import path from 'path'
import * as z from 'zod/v4'
import { z } from 'zod'
const logger = loggerService.withContext('MCP:FileSystemServer')

View File

@ -0,0 +1,108 @@
import { IpcChannel } from '@shared/IpcChannel'
import { ApiServerConfig } from '@types'
import { ipcMain } from 'electron'
import { apiServer } from '../apiServer'
import { config } from '../apiServer/config'
import { loggerService } from './LoggerService'
const logger = loggerService.withContext('ApiServerService')
export class ApiServerService {
constructor() {
// Use the new clean implementation
}
async start(): Promise<void> {
try {
await apiServer.start()
logger.info('API Server started successfully')
} catch (error: any) {
logger.error('Failed to start API Server:', error)
throw error
}
}
async stop(): Promise<void> {
try {
await apiServer.stop()
logger.info('API Server stopped successfully')
} catch (error: any) {
logger.error('Failed to stop API Server:', error)
throw error
}
}
async restart(): Promise<void> {
try {
await apiServer.restart()
logger.info('API Server restarted successfully')
} catch (error: any) {
logger.error('Failed to restart API Server:', error)
throw error
}
}
isRunning(): boolean {
return apiServer.isRunning()
}
async getCurrentConfig(): Promise<ApiServerConfig> {
return config.get()
}
registerIpcHandlers(): void {
// API Server
ipcMain.handle(IpcChannel.ApiServer_Start, async () => {
try {
await this.start()
return { success: true }
} catch (error: any) {
return { success: false, error: error instanceof Error ? error.message : 'Unknown error' }
}
})
ipcMain.handle(IpcChannel.ApiServer_Stop, async () => {
try {
await this.stop()
return { success: true }
} catch (error: any) {
return { success: false, error: error instanceof Error ? error.message : 'Unknown error' }
}
})
ipcMain.handle(IpcChannel.ApiServer_Restart, async () => {
try {
await this.restart()
return { success: true }
} catch (error: any) {
return { success: false, error: error instanceof Error ? error.message : 'Unknown error' }
}
})
ipcMain.handle(IpcChannel.ApiServer_GetStatus, async () => {
try {
const config = await this.getCurrentConfig()
return {
running: this.isRunning(),
config
}
} catch (error: any) {
return {
running: this.isRunning(),
config: null
}
}
})
ipcMain.handle(IpcChannel.ApiServer_GetConfig, async () => {
try {
return this.getCurrentConfig()
} catch (error: any) {
return null
}
})
}
}
// Export singleton instance
export const apiServerService = new ApiServerService()
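On the renderer side these handlers would be reached over IPC; a sketch assuming a preload bridge that exposes `ipcRenderer.invoke` and the same `IpcChannel` constants (the `window.electron` bridge is an assumption, not shown in this diff):

```ts
// Renderer-side sketch (hypothetical preload bridge).
const status = await window.electron.ipcRenderer.invoke(IpcChannel.ApiServer_GetStatus)
// -> { running: boolean, config: ApiServerConfig | null }

if (!status.running) {
  const res = await window.electron.ipcRenderer.invoke(IpcChannel.ApiServer_Start)
  if (!res.success) console.error(res.error)
}
```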

View File

@ -332,14 +332,15 @@ class CodeToolsService {
// macOS - Use osascript to launch terminal and execute command directly, without showing startup command
const envPrefix = buildEnvPrefix(false)
const command = envPrefix ? `${envPrefix} && ${baseCommand}` : baseCommand
// Combine directory change with the main command to ensure they execute in the same shell session
const fullCommand = `cd '${directory.replace(/'/g, "\\'")}' && clear && ${command}`
terminalCommand = 'osascript'
terminalArgs = [
'-e',
`tell application "Terminal"
set newTab to do script "cd '${directory.replace(/'/g, "\\'")}' && clear"
do script "${fullCommand.replace(/"/g, '\\"')}"
activate
do script "${command.replace(/"/g, '\\"')}" in newTab
end tell`
]
break

View File

@ -16,6 +16,7 @@ import {
type StreamableHTTPClientTransportOptions
} from '@modelcontextprotocol/sdk/client/streamableHttp'
import { InMemoryTransport } from '@modelcontextprotocol/sdk/inMemory'
import { McpError, type Tool as SDKTool } from '@modelcontextprotocol/sdk/types'
// Import notification schemas from MCP SDK
import {
CancelledNotificationSchema,
@ -29,6 +30,7 @@ import {
import { nanoid } from '@reduxjs/toolkit'
import { MCPProgressEvent } from '@shared/config/types'
import { IpcChannel } from '@shared/IpcChannel'
import { defaultAppHeaders } from '@shared/utils'
import {
BuiltinMCPServerNames,
type GetResourceResponse,
@ -94,7 +96,7 @@ function getServerLogger(server: MCPServer, extra?: Record<string, any>) {
baseUrl: server?.baseUrl,
type: server?.type || (server?.command ? 'stdio' : server?.baseUrl ? 'http' : 'inmemory')
}
return loggerService.withContext('MCPService', { ...base, ...(extra || {}) })
return loggerService.withContext('MCPService', { ...base, ...extra })
}
/**
@ -193,11 +195,18 @@ class McpService {
return existingClient
}
} catch (error: any) {
getServerLogger(server).error(`Error pinging server`, error as Error)
getServerLogger(server).error(`Error pinging server ${server.name}`, error as Error)
this.clients.delete(serverKey)
}
}
const prepareHeaders = () => {
return {
...defaultAppHeaders(),
...server.headers
}
}
// Create a promise for the initialization process
const initPromise = (async () => {
try {
@ -235,8 +244,11 @@ class McpService {
} else if (server.baseUrl) {
if (server.type === 'streamableHttp') {
const options: StreamableHTTPClientTransportOptions = {
fetch: async (url, init) => {
return net.fetch(typeof url === 'string' ? url : url.toString(), init)
},
requestInit: {
headers: server.headers || {}
headers: prepareHeaders()
},
authProvider
}
@ -249,25 +261,11 @@ class McpService {
const options: SSEClientTransportOptions = {
eventSourceInit: {
fetch: async (url, init) => {
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
// Get tokens from authProvider to make sure using the latest tokens
if (authProvider && typeof authProvider.tokens === 'function') {
try {
const tokens = await authProvider.tokens()
if (tokens && tokens.access_token) {
headers['Authorization'] = `Bearer ${tokens.access_token}`
}
} catch (error) {
getServerLogger(server).error('Failed to fetch tokens:', error as Error)
}
}
return net.fetch(typeof url === 'string' ? url : url.toString(), { ...init, headers })
return net.fetch(typeof url === 'string' ? url : url.toString(), init)
}
},
requestInit: {
headers: server.headers || {}
headers: prepareHeaders()
},
authProvider
}
@ -444,9 +442,9 @@ class McpService {
logger.debug(`Activated server: ${server.name}`)
return client
} catch (error: any) {
getServerLogger(server).error(`Error activating server`, error as Error)
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
} catch (error) {
getServerLogger(server).error(`Error activating server ${server.name}`, error as Error)
throw error
}
} finally {
// Clean up the pending promise when done
@ -614,12 +612,11 @@ class McpService {
}
private async listToolsImpl(server: MCPServer): Promise<MCPTool[]> {
getServerLogger(server).debug(`Listing tools`)
const client = await this.initClient(server)
try {
const { tools } = await client.listTools()
const serverTools: MCPTool[] = []
tools.map((tool: any) => {
tools.map((tool: SDKTool) => {
const serverTool: MCPTool = {
...tool,
id: buildFunctionCallToolName(server.name, tool.name),
@ -628,11 +625,12 @@ class McpService {
type: 'mcp'
}
serverTools.push(serverTool)
getServerLogger(server).debug(`Listing tools`, { tool: serverTool })
})
return serverTools
} catch (error: any) {
} catch (error: unknown) {
getServerLogger(server).error(`Failed to list tools`, error as Error)
return []
throw error
}
}
@ -739,9 +737,9 @@ class McpService {
serverId: server.id,
serverName: server.name
}))
} catch (error: any) {
} catch (error: unknown) {
// -32601 is the code for the method not found
if (error?.code !== -32601) {
if (error instanceof McpError && error.code !== -32601) {
getServerLogger(server).error(`Failed to list prompts`, error as Error)
}
return []
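
prepareHeaders spreads the application defaults first and the per-server headers second, so server-specific values win on key collisions. A hedged sketch of that precedence (the default header shape is an assumption about defaultAppHeaders):

// Hypothetical illustration of the spread order in prepareHeaders():
const defaults = { 'User-Agent': 'CherryStudio/1.0' } // assumed return value of defaultAppHeaders()
const serverHeaders = { Authorization: 'Bearer token', 'User-Agent': 'custom-agent' }
const merged = { ...defaults, ...serverHeaders }
// merged['User-Agent'] === 'custom-agent': the later spread overrides the default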

View File

@ -115,7 +115,7 @@ class KnowledgeService {
const framework = knowledgeFrameworkFactory.getFramework(base)
await framework.initialize(base)
}
public async reset(_: Electron.IpcMainInvokeEvent, { base }: { base: KnowledgeBaseParams }): Promise<void> {
public async reset(_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise<void> {
const framework = knowledgeFrameworkFactory.getFramework(base)
await framework.reset(base)
}

View File

@ -30,7 +30,7 @@ import {
KnowledgeBaseParams,
KnowledgeSearchResult
} from '@types'
import { uuidv4 } from 'zod/v4'
import { uuidv4 } from 'zod'
import { windowService } from '../WindowService'
import {
@ -103,6 +103,8 @@ export class LangChainFramework implements IKnowledgeFramework {
if (fs.existsSync(dbPath)) {
fs.rmSync(dbPath, { recursive: true })
}
// Immediately rebuild an empty index to avoid errors on subsequent loads
await this.createDatabase(base)
}
async delete(id: string): Promise<void> {

View File

@ -3,6 +3,7 @@ import { Provider } from '@types'
import { BaseFileService } from './BaseFileService'
import { GeminiService } from './GeminiService'
import { MistralService } from './MistralService'
import { OpenaiService } from './OpenAIService'
export class FileServiceManager {
private static instance: FileServiceManager
@ -30,6 +31,9 @@ export class FileServiceManager {
case 'mistral':
service = new MistralService(provider)
break
case 'openai':
service = new OpenaiService(provider)
break
default:
throw new Error(`Unsupported service type: ${type}`)
}

View File

@ -0,0 +1,125 @@
import { loggerService } from '@logger'
import { fileStorage } from '@main/services/FileStorage'
import { FileListResponse, FileMetadata, FileUploadResponse, Provider } from '@types'
import * as fs from 'fs'
import OpenAI from 'openai'
import { CacheService } from '../CacheService'
import { BaseFileService } from './BaseFileService'
const logger = loggerService.withContext('OpenAIService')
export class OpenaiService extends BaseFileService {
private static readonly FILE_CACHE_DURATION = 7 * 24 * 60 * 60 * 1000
private static readonly generateUIFileIdCacheKey = (fileId: string) => `ui_file_id_${fileId}`
private readonly client: OpenAI
constructor(provider: Provider) {
super(provider)
this.client = new OpenAI({
apiKey: provider.apiKey,
baseURL: provider.apiHost
})
}
async uploadFile(file: FileMetadata): Promise<FileUploadResponse> {
let fileReadStream: fs.ReadStream | undefined
try {
fileReadStream = fs.createReadStream(fileStorage.getFilePathById(file))
// Restore the file's original name to improve the model's understanding of the file
const fileStreamWithMeta = Object.assign(fileReadStream, {
name: file.origin_name
})
const response = await this.client.files.create({
file: fileStreamWithMeta,
purpose: file.purpose || 'assistants'
})
if (!response.id) {
throw new Error('File id not found in response')
}
// Map the remote file id to the UI file id
CacheService.set<string>(
OpenaiService.generateUIFileIdCacheKey(file.id),
response.id,
OpenaiService.FILE_CACHE_DURATION
)
return {
fileId: response.id,
displayName: file.origin_name,
status: 'success',
originalFile: {
type: 'openai',
file: response
}
}
} catch (error) {
logger.error('Error uploading file:', error as Error)
return {
fileId: '',
displayName: file.origin_name,
status: 'failed'
}
} finally {
// Destroy the file stream
if (fileReadStream) fileReadStream.destroy()
}
}
async listFiles(): Promise<FileListResponse> {
try {
const response = await this.client.files.list()
return {
files: response.data.map((file) => ({
id: file.id,
displayName: file.filename || '',
size: file.bytes,
status: 'success', // All listed files are processed
originalFile: {
type: 'openai',
file
}
}))
}
} catch (error) {
logger.error('Error listing files:', error as Error)
return { files: [] }
}
}
async deleteFile(fileId: string): Promise<void> {
try {
const cachedRemoteFileId = CacheService.get<string>(OpenaiService.generateUIFileIdCacheKey(fileId))
await this.client.files.delete(cachedRemoteFileId || fileId)
logger.debug(`File ${fileId} deleted`)
} catch (error) {
logger.error('Error deleting file:', error as Error)
throw error
}
}
async retrieveFile(fileId: string): Promise<FileUploadResponse> {
try {
// Try to resolve the cached remote file id
const cachedRemoteFileId = CacheService.get<string>(OpenaiService.generateUIFileIdCacheKey(fileId))
const response = await this.client.files.retrieve(cachedRemoteFileId || fileId)
return {
fileId: response.id,
displayName: response.filename,
status: 'success',
originalFile: {
type: 'openai',
file: response
}
}
} catch (error) {
logger.error('Error retrieving file:', error as Error)
return {
fileId: fileId,
displayName: '',
status: 'failed',
originalFile: undefined
}
}
}
}
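
A hedged usage sketch of the id mapping above: uploadFile caches the UI file id against the remote OpenAI id, and retrieveFile/deleteFile resolve through that cache before falling back to the raw id (the provider and file values are hypothetical):

import type { FileMetadata, Provider } from '@types'

declare const provider: Provider // assumed to carry apiKey and apiHost
declare const file: FileMetadata // assumed local file metadata

const service = new OpenaiService(provider)
const uploaded = await service.uploadFile(file) // caches file.id -> remote id for 7 days
if (uploaded.status === 'success') {
  await service.retrieveFile(file.id) // resolves the cached remote id first
  await service.deleteFile(file.id) // same cache-first resolution
}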

View File

@ -398,11 +398,15 @@ export function validateFileName(fileName: string, platform = process.platform):
* @returns
*/
export function checkName(fileName: string): string {
const validation = validateFileName(fileName)
const baseName = path.basename(fileName)
const validation = validateFileName(baseName)
if (!validation.valid) {
throw new Error(`Invalid file name: ${fileName}. ${validation.error}`)
// Auto-sanitize invalid characters instead of throwing an error
const sanitized = sanitizeFilename(baseName)
logger.warn(`File name contains invalid characters, auto-sanitized: "${baseName}" -> "${sanitized}"`)
return sanitized
}
return fileName
return baseName
}
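
The sanitizing path degrades gracefully instead of failing the whole operation; a quick illustration (the exact third result depends on sanitizeFilename, so it is an assumption):

checkName('report.pdf') // -> 'report.pdf'
checkName('/tmp/report.pdf') // -> 'report.pdf' (basename only)
checkName('a<b>:c.txt') // -> a sanitized name such as 'a_b__c.txt', plus a warning log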
/**

View File

@ -42,6 +42,8 @@ export function tracedInvoke(channel: string, spanContext: SpanContext | undefin
// Custom APIs for renderer
const api = {
getAppInfo: () => ipcRenderer.invoke(IpcChannel.App_Info),
getDiskInfo: (directoryPath: string): Promise<{ free: number; size: number } | null> =>
ipcRenderer.invoke(IpcChannel.App_GetDiskInfo, directoryPath),
reload: () => ipcRenderer.invoke(IpcChannel.App_Reload),
setProxy: (proxy: string | undefined, bypassRules?: string) =>
ipcRenderer.invoke(IpcChannel.App_Proxy, proxy, bypassRules),
@ -80,6 +82,7 @@ const api = {
ipcRenderer.invoke(IpcChannel.App_LogToMain, source, level, message, data),
setFullScreen: (value: boolean): Promise<void> => ipcRenderer.invoke(IpcChannel.App_SetFullScreen, value),
isFullScreen: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_IsFullScreen),
getSystemFonts: (): Promise<string[]> => ipcRenderer.invoke(IpcChannel.App_GetSystemFonts),
mac: {
isProcessTrusted: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacIsProcessTrusted),
requestProcessTrust: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacRequestProcessTrust)
@ -445,9 +448,10 @@ const api = {
isMaximized: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.Windows_IsMaximized),
onMaximizedChange: (callback: (isMaximized: boolean) => void): (() => void) => {
const channel = IpcChannel.Windows_MaximizedChanged
ipcRenderer.on(channel, (_, isMaximized: boolean) => callback(isMaximized))
const listener = (_: Electron.IpcRendererEvent, isMaximized: boolean) => callback(isMaximized)
ipcRenderer.on(channel, listener)
return () => {
ipcRenderer.removeAllListeners(channel)
ipcRenderer.removeListener(channel, listener)
}
}
},
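
Keeping a reference to the listener lets the returned unsubscribe detach only its own handler, so two subscribers on the same channel no longer tear each other down. A renderer-side sketch (the window.api.windows exposure path and the hook name are assumptions):

import { useEffect, useState } from 'react'

// Hypothetical hook; each mount owns exactly one listener and its cleanup.
function useWindowMaximized() {
  const [maximized, setMaximized] = useState(false)
  useEffect(() => {
    const unsubscribe = window.api.windows.onMaximizedChange(setMaximized)
    return unsubscribe // removeListener detaches only this hook's callback
  }, [])
  return maximized
}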

View File

@ -1,5 +1,6 @@
import '@renderer/databases'
import { HeroUIProvider } from '@heroui/react'
import { preferenceService } from '@data/PreferenceService'
import { loggerService } from '@logger'
import store, { persistor } from '@renderer/store'
@ -7,6 +8,7 @@ import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { Provider } from 'react-redux'
import { PersistGate } from 'redux-persist/integration/react'
import { ToastPortal } from './components/ToastPortal'
import TopViewContainer from './components/TopView'
import AntdProvider from './context/AntdProvider'
import { CodeStyleProvider } from './context/CodeStyleProvider'
@ -35,21 +37,24 @@ function App(): React.ReactElement {
return (
<Provider store={store}>
<QueryClientProvider client={queryClient}>
<StyleSheetManager>
<ThemeProvider>
<AntdProvider>
<NotificationProvider>
<CodeStyleProvider>
<PersistGate loading={null} persistor={persistor}>
<TopViewContainer>
<Router />
</TopViewContainer>
</PersistGate>
</CodeStyleProvider>
</NotificationProvider>
</AntdProvider>
</ThemeProvider>
</StyleSheetManager>
<HeroUIProvider className="flex h-full w-full flex-1">
<StyleSheetManager>
<ThemeProvider>
<AntdProvider>
<NotificationProvider>
<CodeStyleProvider>
<PersistGate loading={null} persistor={persistor}>
<TopViewContainer>
<Router />
</TopViewContainer>
</PersistGate>
</CodeStyleProvider>
</NotificationProvider>
</AntdProvider>
</ThemeProvider>
</StyleSheetManager>
<ToastPortal />
</HeroUIProvider>
</QueryClientProvider>
</Provider>
)

View File

@ -4,8 +4,9 @@
*/
import { loggerService } from '@logger'
import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
import { AISDKWebSearchResult, MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { convertLinks, flushLinkConverterBuffer } from '@renderer/utils/linkConverter'
import type { TextStreamPart, ToolSet } from 'ai'
import { ToolCallChunkHandler } from './handleToolCallChunk'
@ -29,13 +30,18 @@ export interface CherryStudioChunk {
export class AiSdkToChunkAdapter {
toolCallHandler: ToolCallChunkHandler
private accumulate: boolean | undefined
private isFirstChunk = true
private enableWebSearch: boolean = false
constructor(
private onChunk: (chunk: Chunk) => void,
mcpTools: MCPTool[] = [],
accumulate?: boolean
accumulate?: boolean,
enableWebSearch?: boolean
) {
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
this.accumulate = accumulate
this.enableWebSearch = enableWebSearch || false
}
/**
@ -65,11 +71,24 @@ export class AiSdkToChunkAdapter {
webSearchResults: [],
reasoningId: ''
}
// Reset link converter state at the start of stream
this.isFirstChunk = true
try {
while (true) {
const { done, value } = await reader.read()
if (done) {
// Flush any remaining content from link converter buffer if web search is enabled
if (this.enableWebSearch) {
const remainingText = flushLinkConverterBuffer()
if (remainingText) {
this.onChunk({
type: ChunkType.TEXT_DELTA,
text: remainingText
})
}
}
break
}
@ -87,9 +106,9 @@ export class AiSdkToChunkAdapter {
*/
private convertAndEmitChunk(
chunk: TextStreamPart<any>,
final: { text: string; reasoningContent: string; webSearchResults: any[]; reasoningId: string }
final: { text: string; reasoningContent: string; webSearchResults: AISDKWebSearchResult[]; reasoningId: string }
) {
logger.info(`AI SDK chunk type: ${chunk.type}`, chunk)
logger.silly(`AI SDK chunk type: ${chunk.type}`, chunk)
switch (chunk.type) {
// === Text-related events ===
case 'text-start':
@ -97,17 +116,44 @@ export class AiSdkToChunkAdapter {
type: ChunkType.TEXT_START
})
break
case 'text-delta':
if (this.accumulate) {
final.text += chunk.text || ''
case 'text-delta': {
const processedText = chunk.text || ''
let finalText: string
// Only apply link conversion if web search is enabled
if (this.enableWebSearch) {
const result = convertLinks(processedText, this.isFirstChunk)
if (this.isFirstChunk) {
this.isFirstChunk = false
}
// Handle buffered content
if (result.hasBufferedContent) {
finalText = result.text
} else {
finalText = result.text || processedText
}
} else {
final.text = chunk.text || ''
// Without web search, just use the original text
finalText = processedText
}
if (this.accumulate) {
final.text += finalText
} else {
final.text = finalText
}
// Only emit chunk if there's text to send
if (finalText) {
this.onChunk({
type: ChunkType.TEXT_DELTA,
text: this.accumulate ? final.text : finalText
})
}
this.onChunk({
type: ChunkType.TEXT_DELTA,
text: final.text || ''
})
break
}
case 'text-end':
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
@ -152,12 +198,14 @@ export class AiSdkToChunkAdapter {
// this.toolCallHandler.handleToolCallCreated(chunk)
// break
case 'tool-call':
// Raw tool call (not processed by middleware)
this.toolCallHandler.handleToolCall(chunk)
break
case 'tool-error':
this.toolCallHandler.handleToolError(chunk)
break
case 'tool-result':
// Raw tool result (not processed by middleware)
this.toolCallHandler.handleToolResult(chunk)
break
@ -167,7 +215,6 @@ export class AiSdkToChunkAdapter {
// type: ChunkType.LLM_RESPONSE_CREATED
// })
// break
// TODO: distinguish between request start and step start
// case 'start-step':
// this.onChunk({
// type: ChunkType.BLOCK_CREATED
@ -199,7 +246,7 @@ export class AiSdkToChunkAdapter {
[WebSearchSource.ANTHROPIC]: WebSearchSource.ANTHROPIC,
[WebSearchSource.OPENROUTER]: WebSearchSource.OPENROUTER,
[WebSearchSource.GEMINI]: WebSearchSource.GEMINI,
[WebSearchSource.PERPLEXITY]: WebSearchSource.PERPLEXITY,
// [WebSearchSource.PERPLEXITY]: WebSearchSource.PERPLEXITY,
[WebSearchSource.QWEN]: WebSearchSource.QWEN,
[WebSearchSource.HUNYUAN]: WebSearchSource.HUNYUAN,
[WebSearchSource.ZHIPU]: WebSearchSource.ZHIPU,
@ -267,18 +314,9 @@ export class AiSdkToChunkAdapter {
// === Source and file events ===
case 'source':
if (chunk.sourceType === 'url') {
// if (final.webSearchResults.length === 0) {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { sourceType: _, ...rest } = chunk
final.webSearchResults.push(rest)
// }
// this.onChunk({
// type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
// llm_web_search: {
// source: WebSearchSource.AISDK,
// results: final.webSearchResults
// }
// })
}
break
case 'file':
@ -305,8 +343,6 @@ export class AiSdkToChunkAdapter {
break
default:
// Other chunk types can be ignored or logged
// console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
}
}
}
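
Link conversion only runs when built-in web search is active: each text delta passes through convertLinks, which may hold back an incomplete citation fragment across chunk boundaries, and flushLinkConverterBuffer drains the remainder when the stream ends. A minimal sketch of that contract (the delta strings are hypothetical):

import { convertLinks, flushLinkConverterBuffer } from '@renderer/utils/linkConverter'

// A citation link split across two deltas; the converter buffers the partial fragment.
const deltas = ['See [1](https://exa', 'mple.com) for details']
let out = ''
deltas.forEach((delta, i) => {
  const result = convertLinks(delta, i === 0) // the second argument flags the first chunk of a stream
  out += result.text // may be empty while a fragment is buffered
})
out += flushLinkConverterBuffer() // emits whatever is still buffered at stream end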

View File

@ -8,34 +8,61 @@ import { loggerService } from '@logger'
import { processKnowledgeReferences } from '@renderer/services/KnowledgeService'
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai'
// import type {
// AnthropicSearchOutput,
// WebSearchPluginConfig
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
import type { ToolSet, TypedToolCall, TypedToolError, TypedToolResult } from 'ai'
const logger = loggerService.withContext('ToolCallChunkHandler')
export type ToolcallsMap = {
toolCallId: string
toolName: string
args: any
// mcpTool can now be an MCPTool or the generic type we create for provider tools
tool: BaseTool
}
/**
* Handles tool-call stream chunks and tracks active tool calls in a registry shared across handler instances
*/
export class ToolCallChunkHandler {
// private onChunk: (chunk: Chunk) => void
private activeToolCalls = new Map<
string,
{
toolCallId: string
toolName: string
args: any
// mcpTool can now be an MCPTool or the generic type we create for provider tools
tool: BaseTool
}
>()
private static globalActiveToolCalls = new Map<string, ToolcallsMap>()
private activeToolCalls = ToolCallChunkHandler.globalActiveToolCalls
constructor(
private onChunk: (chunk: Chunk) => void,
private mcpTools: MCPTool[]
) {}
/**
* Register an active tool call if it is not already tracked; returns false when the id exists
*/
private static addActiveToolCallImpl(toolCallId: string, map: ToolcallsMap): boolean {
if (!ToolCallChunkHandler.globalActiveToolCalls.has(toolCallId)) {
ToolCallChunkHandler.globalActiveToolCalls.set(toolCallId, map)
return true
}
return false
}
/**
* Instance-level wrapper around the shared registration logic
*/
private addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
}
/**
* Expose the global map of active tool calls
*/
public static getActiveToolCalls() {
return ToolCallChunkHandler.globalActiveToolCalls
}
/**
* Static entry point so external modules can register tool calls
*/
public static addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
}
// /**
// * Set the onChunk callback
// */
@ -43,103 +70,103 @@ export class ToolCallChunkHandler {
// this.onChunk = callback
// }
handleToolCallCreated(
chunk:
| {
type: 'tool-input-start'
id: string
toolName: string
providerMetadata?: ProviderMetadata
providerExecuted?: boolean
}
| {
type: 'tool-input-end'
id: string
providerMetadata?: ProviderMetadata
}
| {
type: 'tool-input-delta'
id: string
delta: string
providerMetadata?: ProviderMetadata
}
): void {
switch (chunk.type) {
case 'tool-input-start': {
// If it can be found here, it is an mcpTool
// if (this.activeToolCalls.get(chunk.id)) return
// handleToolCallCreated(
// chunk:
// | {
// type: 'tool-input-start'
// id: string
// toolName: string
// providerMetadata?: ProviderMetadata
// providerExecuted?: boolean
// }
// | {
// type: 'tool-input-end'
// id: string
// providerMetadata?: ProviderMetadata
// }
// | {
// type: 'tool-input-delta'
// id: string
// delta: string
// providerMetadata?: ProviderMetadata
// }
// ): void {
// switch (chunk.type) {
// case 'tool-input-start': {
// // If it can be found here, it is an mcpTool
// // if (this.activeToolCalls.get(chunk.id)) return
const tool: BaseTool | MCPTool = {
id: chunk.id,
name: chunk.toolName,
description: chunk.toolName,
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
}
this.activeToolCalls.set(chunk.id, {
toolCallId: chunk.id,
toolName: chunk.toolName,
args: '',
tool
})
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: chunk.id,
tool: tool,
arguments: {},
status: 'pending',
toolCallId: chunk.id
}
this.onChunk({
type: ChunkType.MCP_TOOL_PENDING,
responses: [toolResponse]
})
break
}
case 'tool-input-delta': {
const toolCall = this.activeToolCalls.get(chunk.id)
if (!toolCall) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return
}
toolCall.args += chunk.delta
break
}
case 'tool-input-end': {
const toolCall = this.activeToolCalls.get(chunk.id)
this.activeToolCalls.delete(chunk.id)
if (!toolCall) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return
}
// const toolResponse: ToolCallResponse = {
// id: toolCall.toolCallId,
// tool: toolCall.tool,
// arguments: toolCall.args,
// status: 'pending',
// toolCallId: toolCall.toolCallId
// }
// logger.debug('toolResponse', toolResponse)
// this.onChunk({
// type: ChunkType.MCP_TOOL_PENDING,
// responses: [toolResponse]
// })
break
}
}
// if (!toolCall) {
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
// return
// }
// this.onChunk({
// type: ChunkType.MCP_TOOL_CREATED,
// tool_calls: [
// {
// id: chunk.id,
// name: chunk.toolName,
// status: 'pending'
// }
// ]
// })
}
// const tool: BaseTool | MCPTool = {
// id: chunk.id,
// name: chunk.toolName,
// description: chunk.toolName,
// type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
// }
// this.activeToolCalls.set(chunk.id, {
// toolCallId: chunk.id,
// toolName: chunk.toolName,
// args: '',
// tool
// })
// const toolResponse: MCPToolResponse | NormalToolResponse = {
// id: chunk.id,
// tool: tool,
// arguments: {},
// status: 'pending',
// toolCallId: chunk.id
// }
// this.onChunk({
// type: ChunkType.MCP_TOOL_PENDING,
// responses: [toolResponse]
// })
// break
// }
// case 'tool-input-delta': {
// const toolCall = this.activeToolCalls.get(chunk.id)
// if (!toolCall) {
// logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
// return
// }
// toolCall.args += chunk.delta
// break
// }
// case 'tool-input-end': {
// const toolCall = this.activeToolCalls.get(chunk.id)
// this.activeToolCalls.delete(chunk.id)
// if (!toolCall) {
// logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
// return
// }
// // const toolResponse: ToolCallResponse = {
// // id: toolCall.toolCallId,
// // tool: toolCall.tool,
// // arguments: toolCall.args,
// // status: 'pending',
// // toolCallId: toolCall.toolCallId
// // }
// // logger.debug('toolResponse', toolResponse)
// // this.onChunk({
// // type: ChunkType.MCP_TOOL_PENDING,
// // responses: [toolResponse]
// // })
// break
// }
// }
// // if (!toolCall) {
// // Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
// // return
// // }
// // this.onChunk({
// // type: ChunkType.MCP_TOOL_CREATED,
// // tool_calls: [
// // {
// // id: chunk.id,
// // name: chunk.toolName,
// // status: 'pending'
// // }
// // ]
// // })
// }
/**
* Handle a tool-call chunk emitted by the AI SDK stream
@ -158,7 +185,6 @@ export class ToolCallChunkHandler {
let tool: BaseTool
let mcpTool: MCPTool | undefined
// Branch the handling logic on the providerExecuted flag
if (providerExecuted) {
// If the tool is executed by the provider (e.g. web_search)
@ -196,27 +222,25 @@ export class ToolCallChunkHandler {
}
}
// Track the active tool call
this.activeToolCalls.set(toolCallId, {
this.addActiveToolCall(toolCallId, {
toolCallId,
toolName,
args,
tool
})
// Build the MCPToolResponse payload
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallId,
tool: tool,
arguments: args,
status: 'pending',
status: 'pending', // Uniformly use the pending status
toolCallId: toolCallId
}
// Invoke onChunk
if (this.onChunk) {
this.onChunk({
type: ChunkType.MCP_TOOL_PENDING,
type: ChunkType.MCP_TOOL_PENDING, // Uniformly emit the pending status
responses: [toolResponse]
})
}
@ -257,7 +281,7 @@ export class ToolCallChunkHandler {
// Tool-specific post-processing
switch (toolResponse.tool.name) {
case 'builtin_knowledge_search': {
processKnowledgeReferences(toolResponse.response?.knowledgeReferences, this.onChunk)
processKnowledgeReferences(toolResponse.response, this.onChunk)
break
}
// Post-processing for other tools can be added here in the future
@ -276,4 +300,33 @@ export class ToolCallChunkHandler {
})
}
}
handleToolError(
chunk: {
type: 'tool-error'
} & TypedToolError<ToolSet>
): void {
const { toolCallId, error, input } = chunk
const toolCallInfo = this.activeToolCalls.get(toolCallId)
if (!toolCallInfo) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
return
}
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallId,
tool: toolCallInfo.tool,
arguments: input,
status: 'error',
response: error,
toolCallId: toolCallId
}
this.activeToolCalls.delete(toolCallId)
if (this.onChunk) {
this.onChunk({
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [toolResponse]
})
}
}
}
export const addActiveToolCall = ToolCallChunkHandler.addActiveToolCall.bind(ToolCallChunkHandler)
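
Because the registry is a static map shared by every handler instance, code outside the streaming path can pre-register a call and let handleToolResult or handleToolError resolve it later. A hedged sketch (the ids and tool values are invented):

import { addActiveToolCall } from './handleToolCallChunk'

const registered = addActiveToolCall('call_123', {
  toolCallId: 'call_123',
  toolName: 'builtin_web_search',
  args: { query: 'cherry studio' },
  tool: {
    id: 'call_123',
    name: 'builtin_web_search',
    description: 'builtin_web_search',
    type: 'builtin'
  }
})
// registered === false means 'call_123' was already tracked; otherwise any handler
// instance can resolve it when the matching tool-result or tool-error arrives.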

View File

@ -265,15 +265,15 @@ export default class ModernAiProvider {
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
const modelId = this.model!.id
logger.info('Starting modernCompletions', {
modelId,
providerId: this.config!.providerId,
topicId: config.topicId,
hasOnChunk: !!config.onChunk,
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
toolCount: params.tools ? Object.keys(params.tools).length : 0
})
// const modelId = this.model!.id
// logger.info('Starting modernCompletions', {
// modelId,
// providerId: this.config!.providerId,
// topicId: config.topicId,
// hasOnChunk: !!config.onChunk,
// hasTools: !!params.tools && Object.keys(params.tools).length > 0,
// toolCount: params.tools ? Object.keys(params.tools).length : 0
// })
// Build the plugin array based on the configuration
const plugins = await buildPlugins(config)
@ -284,7 +284,7 @@ export default class ModernAiProvider {
// Create an executor with middleware
if (config.onChunk) {
const accumulate = this.model!.supported_text_delta !== false // true and undefined
const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate)
const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate, config.enableWebSearch)
const streamResult = await executor.streamText({
...params,

View File

@ -45,6 +45,7 @@ import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { defaultTimeout } from '@shared/config/constant'
import { defaultAppHeaders } from '@shared/utils'
import { REFERENCE_PROMPT } from '@shared/config/prompts'
import { isEmpty } from 'lodash'
@ -179,8 +180,7 @@ export abstract class BaseApiClient<
public defaultHeaders() {
return {
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio',
...defaultAppHeaders(),
'X-Api-Key': this.apiKey
}
}

View File

@ -1,6 +1,6 @@
import { loggerService } from '@logger'
import { ChunkType } from '@renderer/types/chunk'
import { flushLinkConverterBuffer, smartLinkConverter } from '@renderer/utils/linkConverter'
import { convertLinks, flushLinkConverterBuffer } from '@renderer/utils/linkConverter'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
@ -28,8 +28,6 @@ export const WebSearchMiddleware: CompletionsMiddleware =
}
// Invoke downstream middleware
const result = await next(ctx, params)
const model = params.assistant?.model!
let isFirstChunk = true
// Post-response processing: record web search events
@ -42,15 +40,9 @@ export const WebSearchMiddleware: CompletionsMiddleware =
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
if (chunk.type === ChunkType.TEXT_DELTA) {
const providerType = model.provider || 'openai'
// Convert links using the currently available web search results
const text = chunk.text
const result = smartLinkConverter(
text,
providerType,
isFirstChunk,
ctx._internal.webSearchState!.results
)
const result = convertLinks(text, isFirstChunk)
if (isFirstChunk) {
isFirstChunk = false
}

View File

@ -20,8 +20,10 @@ export interface AiSdkMiddlewareConfig {
isSupportedToolUse: boolean
// image generation endpoint
isImageGenerationEndpoint: boolean
// Whether built-in web search is enabled
enableWebSearch: boolean
enableGenerateImage: boolean
enableUrlContext: boolean
mcpTools?: MCPTool[]
uiMessages?: Message[]
}
@ -132,7 +134,6 @@ export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageMo
})
}
logger.info('builder.build()', builder.buildNamed())
return builder.build()
}

View File

@ -1,5 +1,5 @@
import { AiPlugin } from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
import { createPromptToolUsePlugin, googleToolsPlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
import { preferenceService } from '@data/PreferenceService'
import { loggerService } from '@logger'
import type { Assistant } from '@renderer/types'
@ -35,7 +35,7 @@ export async function buildPlugins(
plugins.push(webSearchPlugin())
}
// 2. Add the search orchestration plugin when tool calling is supported
if (middlewareConfig.isSupportedToolUse) {
if (middlewareConfig.isSupportedToolUse || middlewareConfig.isPromptToolUse) {
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant, middlewareConfig.topicId || ''))
}
@ -45,12 +45,13 @@ export async function buildPlugins(
}
// 4. Add the tool plugin when prompt-based tool calling is enabled
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
if (middlewareConfig.isPromptToolUse) {
plugins.push(
createPromptToolUsePlugin({
enabled: true,
createSystemMessage: (systemPrompt, params, context) => {
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
const modelId = typeof context.model === 'string' ? context.model : context.model.modelId
if (modelId.includes('o1-mini') || modelId.includes('o1-preview')) {
if (context.isRecursiveCall) {
return null
}
@ -69,10 +70,11 @@ export async function buildPlugins(
)
}
// if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
// plugins.push(createNativeToolUsePlugin())
// }
logger.info(
if (middlewareConfig.enableUrlContext) {
plugins.push(googleToolsPlugin({ urlContext: true }))
}
logger.debug(
'Final plugin list:',
plugins.map((p) => p.name)
)

View File

@ -19,7 +19,8 @@ import {
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
SEARCH_SUMMARY_PROMPT_WEB_ONLY
} from '@shared/config/prompts'
import type { ModelMessage } from 'ai'
import type { LanguageModel, ModelMessage } from 'ai'
import { generateText } from 'ai'
import { isEmpty } from 'lodash'
import { MemoryProcessor } from '../../services/MemoryProcessor'
@ -76,9 +77,7 @@ async function analyzeSearchIntent(
shouldKnowledgeSearch?: boolean
shouldMemorySearch?: boolean
lastAnswer?: ModelMessage
context: AiRequestContext & {
isAnalyzing?: boolean
}
context: AiRequestContext
topicId: string
}
): Promise<ExtractResults | undefined> {
@ -122,9 +121,7 @@ async function analyzeSearchIntent(
logger.error('Provider not found or missing API key')
return getFallbackResult()
}
// console.log('formattedPrompt', schema)
try {
context.isAnalyzing = true
logger.info('Starting intent analysis generateText call', {
modelId: model.id,
topicId: options.topicId,
@ -133,18 +130,16 @@ async function analyzeSearchIntent(
hasKnowledgeSearch: needKnowledgeExtract
})
const { text: result } = await context.executor
.generateText(model.id, {
prompt: formattedPrompt
})
.finally(() => {
context.isAnalyzing = false
logger.info('Intent analysis generateText call completed', {
modelId: model.id,
topicId: options.topicId,
requestId: context.requestId
})
const { text: result } = await generateText({
model: context.model as LanguageModel,
prompt: formattedPrompt
}).finally(() => {
logger.info('Intent analysis generateText call completed', {
modelId: model.id,
topicId: options.topicId,
requestId: context.requestId
})
})
const parsedResult = extractInfoFromXML(result)
logger.debug('Intent analysis result', { parsedResult })
@ -183,7 +178,6 @@ async function storeConversationMemory(
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
if (!globalMemoryEnabled || !assistant.enableMemory) {
// console.log('Memory storage is disabled')
return
}
@ -245,25 +239,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
// Store intent analysis results
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
const userMessages: { [requestId: string]: ModelMessage } = {}
let currentContext: AiRequestContext | null = null
return definePlugin({
name: 'search-orchestration',
enforce: 'pre', // Ensure this runs before other plugins
configureContext: (context: AiRequestContext) => {
if (currentContext) {
context.isAnalyzing = currentContext.isAnalyzing
}
currentContext = context
},
/**
* 🔍 Step 1: Intent recognition phase
*/
onRequestStart: async (context: AiRequestContext) => {
if (context.isAnalyzing) return
// Skip intent analysis when no search capability is enabled
if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return
@ -284,7 +267,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
const shouldWebSearch = !!assistant.webSearchProviderId
const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory
@ -315,7 +297,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
* 🔧 Step 2: Tool configuration phase
*/
transformParams: async (params: any, context: AiRequestContext) => {
if (context.isAnalyzing) return params
// logger.info('🔧 Configuring tools based on intent...', context.requestId)
try {
@ -409,7 +390,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
// context.isAnalyzing = false
// logger.info('context.isAnalyzing', context, result)
// logger.info('💾 Starting memory storage...', context.requestId)
if (context.isAnalyzing) return
try {
const messages = context.originalParams.messages

View File

@ -58,91 +58,6 @@ class AdapterTracer {
})
}
// startSpan(name: string, options?: any, context?: any): Span {
// // If a parent SpanContext was provided and no context was passed explicitly, use the parent context
// const contextToUse = context ?? this.cachedParentContext ?? otelContext.active()
// const span = this.originalTracer.startSpan(name, options, contextToUse)
// // Mark the parent-child relationship so the hierarchy can be rebuilt as a fallback during conversion
// try {
// if (this.parentSpanContext) {
// span.setAttribute('trace.parentSpanId', this.parentSpanContext.spanId)
// span.setAttribute('trace.parentTraceId', this.parentSpanContext.traceId)
// }
// if (this.topicId) {
// span.setAttribute('trace.topicId', this.topicId)
// }
// } catch (e) {
// logger.debug('Failed to set trace parent attributes', e as Error)
// }
// logger.info('AI SDK span created via AdapterTracer', {
// spanName: name,
// spanId: span.spanContext().spanId,
// traceId: span.spanContext().traceId,
// parentTraceId: this.parentSpanContext?.traceId,
// topicId: this.topicId,
// modelName: this.modelName,
// traceIdMatches: this.parentSpanContext ? span.spanContext().traceId === this.parentSpanContext.traceId : undefined
// })
// // Wrap the span's end method to convert data when it finishes
// const originalEnd = span.end.bind(span)
// span.end = (endTime?: any) => {
// logger.info('AI SDK span.end() called - about to convert span', {
// spanName: name,
// spanId: span.spanContext().spanId,
// traceId: span.spanContext().traceId,
// topicId: this.topicId,
// modelName: this.modelName
// })
// // Call the original end method
// originalEnd(endTime)
// // Convert and save the span data
// try {
// logger.info('Converting AI SDK span to SpanEntity', {
// spanName: name,
// spanId: span.spanContext().spanId,
// traceId: span.spanContext().traceId,
// topicId: this.topicId,
// modelName: this.modelName
// })
// logger.info('spanspanspanspanspanspan', span)
// const spanEntity = AiSdkSpanAdapter.convertToSpanEntity({
// span,
// topicId: this.topicId,
// modelName: this.modelName
// })
// // Save the converted data
// window.api.trace.saveEntity(spanEntity)
// logger.info('AI SDK span converted and saved successfully', {
// spanName: name,
// spanId: span.spanContext().spanId,
// traceId: span.spanContext().traceId,
// topicId: this.topicId,
// modelName: this.modelName,
// hasUsage: !!spanEntity.usage,
// usage: spanEntity.usage
// })
// } catch (error) {
// logger.error('Failed to convert AI SDK span', error as Error, {
// spanName: name,
// spanId: span.spanContext().spanId,
// traceId: span.spanContext().traceId,
// topicId: this.topicId,
// modelName: this.modelName
// })
// }
// }
// return span
// }
startActiveSpan<F extends (span: Span) => any>(name: string, fn: F): ReturnType<F>
startActiveSpan<F extends (span: Span) => any>(name: string, options: any, fn: F): ReturnType<F>
startActiveSpan<F extends (span: Span) => any>(name: string, options: any, context: any, fn: F): ReturnType<F>

View File

@ -10,6 +10,7 @@ import { FileTypes } from '@renderer/types'
import { FileMessageBlock } from '@renderer/types/newMessage'
import { findFileBlocks } from '@renderer/utils/messageUtils/find'
import type { FilePart, TextPart } from 'ai'
import type OpenAI from 'openai'
import { getAiSdkProviderId } from '../provider/factory'
import { getFileSizeLimit, supportsImageInput, supportsLargeFileUpload, supportsPdfInput } from './modelCapabilities'
@ -112,6 +113,86 @@ export async function handleGeminiFileUpload(file: FileMetadata, model: Model):
return null
}
/**
* OpenAI large file upload
*/
export async function handleOpenAILargeFileUpload(
file: FileMetadata,
model: Model
): Promise<(FilePart & { id?: string }) | null> {
const provider = getProviderByModel(model)
// For the qwen-long series, the docs require purpose to be 'file-extract'
if (['qwen-long', 'qwen-doc'].some((modelName) => model.name.includes(modelName))) {
file = {
...file,
// This purpose is not in OpenAI's definitions but conforms to the SDK contract, so force the assertion
purpose: 'file-extract' as OpenAI.FilePurpose
}
}
try {
// Check whether the file has already been uploaded
const fileMetadata = await window.api.fileService.retrieve(provider, file.id)
if (fileMetadata.status === 'success' && fileMetadata.originalFile?.file) {
// Assert the OpenAI file object
const remoteFile = fileMetadata.originalFile.file as OpenAI.Files.FileObject
// Check that the purpose matches
if (remoteFile.purpose !== file.purpose) {
logger.warn(`File ${file.origin_name} purpose mismatch: ${remoteFile.purpose} vs ${file.purpose}`)
throw new Error('File purpose mismatch')
}
return {
type: 'file',
filename: file.origin_name,
mediaType: '',
data: `fileid://${remoteFile.id}`
}
}
} catch (error) {
logger.error(`Failed to retrieve file ${file.origin_name}:`, error as Error)
return null
}
try {
// Upload the file if it has not been uploaded yet
const uploadResult = await window.api.fileService.upload(provider, file)
if (uploadResult.originalFile?.file) {
// Assert the OpenAI file object
const remoteFile = uploadResult.originalFile.file as OpenAI.Files.FileObject
logger.info(`File ${file.origin_name} uploaded.`)
return {
type: 'file',
filename: remoteFile.filename,
mediaType: '',
data: `fileid://${remoteFile.id}`
}
}
} catch (error) {
logger.error(`Failed to upload file ${file.origin_name}:`, error as Error)
}
return null
}
/**
* Dispatch large file uploads to the provider-specific implementation
*/
export async function handleLargeFileUpload(
file: FileMetadata,
model: Model
): Promise<(FilePart & { id?: string }) | null> {
const provider = getProviderByModel(model)
const aiSdkId = getAiSdkProviderId(provider)
if (['google', 'google-generative-ai', 'google-vertex'].includes(aiSdkId)) {
return await handleGeminiFileUpload(file, model)
}
if (provider.type === 'openai') {
return await handleOpenAILargeFileUpload(file, model)
}
return null
}
/**
* Convert a file block into a FilePart
*/
@ -127,7 +208,7 @@ export async function convertFileBlockToFilePart(fileBlock: FileMessageBlock, mo
// If large file upload is supported (e.g. the Gemini File API), try the upload
if (supportsLargeFileUpload(model)) {
logger.info(`Large PDF file ${file.origin_name} (${file.size} bytes) attempting File API upload`)
const uploadResult = await handleGeminiFileUpload(file, model)
const uploadResult = await handleLargeFileUpload(file, model)
if (uploadResult) {
return uploadResult
}
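
The OpenAI path is reuse-first: a cached remote file is reused only when its purpose matches, a mismatch bails out, and a cache miss falls through to a fresh upload. A condensed sketch of the possible outcomes (the ids are invented):

const part = await handleOpenAILargeFileUpload(file, model) // file and model assumed in scope
// 1. retrieve succeeds and purposes match -> reuse: { type: 'file', data: 'fileid://file-abc', ... }
// 2. retrieve succeeds but purposes differ -> null (the mismatch error is caught and logged)
// 3. retrieve reports a miss -> upload, then { type: 'file', data: 'fileid://file-def', ... }
// 4. the upload fails too -> null, and the caller falls back to text extraction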

View File

@ -13,7 +13,15 @@ import {
findThinkingBlocks,
getMainTextContent
} from '@renderer/utils/messageUtils/find'
import type { AssistantModelMessage, FilePart, ImagePart, ModelMessage, TextPart, UserModelMessage } from 'ai'
import type {
AssistantModelMessage,
FilePart,
ImagePart,
ModelMessage,
SystemModelMessage,
TextPart,
UserModelMessage
} from 'ai'
import { convertFileBlockToFilePart, convertFileBlockToTextPart } from './fileProcessor'
@ -27,7 +35,7 @@ export async function convertMessageToSdkParam(
message: Message,
isVisionModel = false,
model?: Model
): Promise<ModelMessage> {
): Promise<ModelMessage | ModelMessage[]> {
const content = getMainTextContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
@ -48,7 +56,7 @@ async function convertMessageToUserModelMessage(
imageBlocks: ImageMessageBlock[],
isVisionModel = false,
model?: Model
): Promise<UserModelMessage> {
): Promise<UserModelMessage | (UserModelMessage | SystemModelMessage)[]> {
const parts: Array<TextPart | FilePart | ImagePart> = []
if (content) {
parts.push({ type: 'text', text: content })
@ -85,6 +93,19 @@ async function convertMessageToUserModelMessage(
if (model) {
const filePart = await convertFileBlockToFilePart(fileBlock, model)
if (filePart) {
// Check whether filePart.data is a string file reference
if (typeof filePart.data === 'string' && filePart.data.startsWith('fileid://')) {
return [
{
role: 'system',
content: filePart.data
},
{
role: 'user',
content: parts.length > 0 ? parts : ''
}
]
}
parts.push(filePart)
logger.debug(`File ${file.origin_name} processed as native file format`)
processed = true
@ -159,7 +180,7 @@ export async function convertMessagesToSdkMessages(messages: Message[], model: M
for (const message of messages) {
const sdkMessage = await convertMessageToSdkParam(message, isVision, model)
sdkMessages.push(sdkMessage)
sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage]))
}
return sdkMessages
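
When the file part carries a fileid:// reference, the converter emits a system message with the reference ahead of the user message, and convertMessagesToSdkMessages flattens the pair into the outgoing list. The resulting shape looks roughly like this (the content is hypothetical):

import type { ModelMessage } from 'ai'

// Hypothetical result for one user turn with a qwen-long file attachment:
const sdkMessages: ModelMessage[] = [
  { role: 'system', content: 'fileid://file-abc123' }, // the remote file reference
  { role: 'user', content: [{ type: 'text', text: 'Summarize the attached report.' }] }
]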

View File

@ -10,26 +10,61 @@ import { FileTypes } from '@renderer/types'
import { getAiSdkProviderId } from '../provider/factory'
// Utility: decide feature support based on model name and provider
function modelSupportValidator(
model: Model,
{
supportedModels = [],
unsupportedModels = [],
supportedProviders = [],
unsupportedProviders = []
}: {
supportedModels?: string[]
unsupportedModels?: string[]
supportedProviders?: string[]
unsupportedProviders?: string[]
}
): boolean {
const provider = getProviderByModel(model)
const aiSdkId = getAiSdkProviderId(provider)
// Blacklist: reject immediately when the model is unsupported
if (unsupportedModels.some((name) => model.name.includes(name))) {
return false
}
// Blacklist: reject unsupported providers; often a provider's same-named model lacks features of the original
if (unsupportedProviders.includes(aiSdkId)) {
return false
}
// Whitelist: accept when the model name matches
if (supportedModels.some((name) => model.name.includes(name))) {
return true
}
// Fall back to the provider check
return supportedProviders.includes(aiSdkId)
}
/**
* Whether the model supports PDF input
*/
export function supportsPdfInput(model: Model): boolean {
// Per the AI SDK docs, these providers support PDF input
const supportedProviders = [
'openai',
'azure-openai',
'anthropic',
'google',
'google-generative-ai',
'google-vertex',
'bedrock',
'amazon-bedrock'
]
const provider = getProviderByModel(model)
const aiSdkId = getAiSdkProviderId(provider)
return supportedProviders.some((provider) => aiSdkId === provider)
// Per the AI SDK docs, the following models or providers support PDF input
return modelSupportValidator(model, {
supportedModels: ['qwen-long', 'qwen-doc'],
supportedProviders: [
'openai',
'azure-openai',
'anthropic',
'google',
'google-generative-ai',
'google-vertex',
'bedrock',
'amazon-bedrock'
]
})
}
/**
@ -43,11 +78,11 @@ export function supportsImageInput(model: Model): boolean {
* Whether the model supports large file upload (Gemini File API)
*/
export function supportsLargeFileUpload(model: Model): boolean {
const provider = getProviderByModel(model)
const aiSdkId = getAiSdkProviderId(provider)
// Currently, mainly the Gemini series supports large file upload
return ['google', 'google-generative-ai', 'google-vertex'].includes(aiSdkId)
// Per the AI SDK docs, the following models or providers support large file upload
return modelSupportValidator(model, {
supportedModels: ['qwen-long', 'qwen-doc'],
supportedProviders: ['google', 'google-generative-ai', 'google-vertex']
})
}
/**
@ -67,6 +102,11 @@ export function getFileSizeLimit(model: Model, fileType: FileTypes): number {
return 20 * 1024 * 1024 // 20MB
}
// Dashscope: if the model supports large file upload, prefer the File API
if (aiSdkId === 'dashscope' && supportsLargeFileUpload(model)) {
return 0 // Use a smaller default value
}
// Other providers have no explicit limit; use a large default value
// This matches the legacy implementation: let the provider handle file size itself
return Infinity
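
The validator is blacklist-first: an unsupported model or provider is rejected before either whitelist is consulted, which covers proxies that resell a model without its original capabilities. A usage sketch from inside this module (the proxy provider id is invented):

import type { Model } from '@renderer/types'

declare const model: Model // e.g. a model named 'qwen-long-latest', assumed

const supportsDocUpload = modelSupportValidator(model, {
  supportedModels: ['qwen-long', 'qwen-doc'],
  unsupportedProviders: ['some-proxy'], // hypothetical id, rejected even if the model name matches
  supportedProviders: ['openai']
})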

View File

@ -3,23 +3,28 @@
* Streaming and non-streaming parameters for the AI SDK
*/
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
import { vertex } from '@ai-sdk/google-vertex/edge'
import { loggerService } from '@logger'
import {
isGenerateImageModel,
isOpenRouterBuiltInWebSearchModel,
isReasoningModel,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import type { Assistant, MCPTool, Provider } from '@renderer/types'
import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import type { ModelMessage } from 'ai'
import { stepCountIs } from 'ai'
import { getAiSdkProviderId } from '../provider/factory'
import { setupToolsConfig } from '../utils/mcp'
import { buildProviderOptions } from '../utils/options'
import { getAnthropicThinkingBudget } from '../utils/reasoning'
import { getTemperature, getTopP } from './modelParameters'
const logger = loggerService.withContext('parameterBuilder')
@ -54,8 +59,9 @@ export async function buildStreamTextParams(
const { mcpTools } = options
const model = assistant.model || getDefaultModel()
const aiSdkProviderId = getAiSdkProviderId(provider)
const { maxTokens } = getAssistantSettings(assistant)
let { maxTokens } = getAssistantSettings(assistant)
// These three variables are passed through to enable the plugins/middleware below
// They can also be built externally and passed into buildStreamTextParams
@ -65,17 +71,20 @@ export async function buildStreamTextParams(
assistant.settings?.reasoning_effort !== undefined) ||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
// Decide whether to use built-in web search
// Condition: no external search provider && (the user enabled built-in search || the model forces built-in search)
const hasExternalSearch = !!options.webSearchProviderId
const enableWebSearch =
(assistant.enableWebSearch && isWebSearchModel(model)) ||
isOpenRouterBuiltInWebSearchModel(model) ||
model.id.includes('sonar') ||
false
!hasExternalSearch &&
((assistant.enableWebSearch && isWebSearchModel(model)) ||
isOpenRouterBuiltInWebSearchModel(model) ||
model.id.includes('sonar'))
const enableUrlContext = assistant.enableUrlContext || false
const enableGenerateImage = !!(isGenerateImageModel(model) && assistant.enableGenerateImage)
const tools = setupToolsConfig(mcpTools)
let tools = setupToolsConfig(mcpTools)
// if (webSearchProviderId) {
// tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
@ -88,6 +97,36 @@ export async function buildStreamTextParams(
enableGenerateImage
})
// NOTE: the AI SDK adds maxTokens and budgetTokens together
if (
enableReasoning &&
maxTokens !== undefined &&
isSupportedThinkingTokenClaudeModel(model) &&
(provider.type === 'anthropic' || provider.type === 'aws-bedrock')
) {
maxTokens -= getAnthropicThinkingBudget(assistant, model)
}
// google-vertex | google-vertex-anthropic
if (enableWebSearch) {
if (!tools) {
tools = {}
}
if (aiSdkProviderId === 'google-vertex') {
tools.google_search = vertex.tools.googleSearch({})
} else if (aiSdkProviderId === 'google-vertex-anthropic') {
tools.web_search = vertexAnthropic.tools.webSearch_20250305({})
}
}
// google-vertex
if (enableUrlContext && aiSdkProviderId === 'google-vertex') {
if (!tools) {
tools = {}
}
tools.url_context = vertex.tools.urlContext({})
}
// Build the base parameters
const params: StreamTextParams = {
messages: sdkMessages,
@ -97,10 +136,12 @@ export async function buildStreamTextParams(
abortSignal: options.requestOptions?.signal,
headers: options.requestOptions?.headers,
providerOptions,
tools,
stopWhen: stepCountIs(10),
maxRetries: 0
}
if (tools) {
params.tools = tools
}
if (assistant.prompt) {
params.system = assistant.prompt
}
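
Since the AI SDK adds the thinking budget on top of maxOutputTokens for Claude models, the builder subtracts the budget before handing the value over. A worked example with assumed settings:

// Assumed numbers for illustration:
let maxTokens = 8192 // from the assistant settings
const thinkingBudget = 4096 // assumed result of getAnthropicThinkingBudget(assistant, model)
maxTokens -= thinkingBudget
// The SDK then sends 4096 + 4096, keeping the provider-side total at the intended 8192.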

View File

@ -15,7 +15,7 @@ import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useV
import { getProviderByModel } from '@renderer/services/AssistantService'
import { loggerService } from '@renderer/services/LoggerService'
import store from '@renderer/store'
import type { Model, Provider } from '@renderer/types'
import { isSystemProvider, type Model, type Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { cloneDeep, isEmpty } from 'lodash'
@ -61,14 +61,16 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
// return createVertexProvider(provider)
// }
if (provider.id === 'aihubmix') {
return aihubmixProviderCreator(model, provider)
}
if (provider.id === 'newapi') {
return newApiResolverCreator(model, provider)
}
if (provider.id === 'vertexai') {
return vertexAnthropicProviderCreator(model, provider)
if (isSystemProvider(provider)) {
if (provider.id === 'aihubmix') {
return aihubmixProviderCreator(model, provider)
}
if (provider.id === 'new-api') {
return newApiResolverCreator(model, provider)
}
if (provider.id === 'vertexai') {
return vertexAnthropicProviderCreator(model, provider)
}
}
return provider
}

View File

@ -39,6 +39,14 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
aliases: ['aws-bedrock']
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false,
aliases: ['perplexity']
}
] as const

View File

@ -23,8 +23,6 @@ export const knowledgeSearchTool = (
Pre-extracted search queries: "${extractedKeywords.question.join(', ')}"
Rewritten query: "${extractedKeywords.rewrite}"
This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.
Call this tool to execute the search. You can optionally provide additional context to refine the search.`,
inputSchema: z.object({
@ -35,99 +33,102 @@ Call this tool to execute the search. You can optionally provide additional cont
}),
execute: async ({ additionalContext }) => {
try {
// Get the assistant's knowledge base configuration
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
// try {
// Get the assistant's knowledge base configuration
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
// Check whether a knowledge base is configured
if (!hasKnowledgeBase) {
return {
summary: 'No knowledge base configured for this assistant.',
knowledgeReferences: [],
instructions: ''
// Check whether a knowledge base is configured
if (!hasKnowledgeBase) {
return []
}
let finalQueries = [...extractedKeywords.question]
let finalRewrite = extractedKeywords.rewrite
if (additionalContext?.trim()) {
// If the LLM provided additional context, use the more specific description
const cleanContext = additionalContext.trim()
if (cleanContext) {
finalQueries = [cleanContext]
finalRewrite = cleanContext
}
}
// Check whether a search is needed
if (finalQueries[0] === 'not_needed') {
return []
}
// Build the search criteria
let searchCriteria: { question: string[]; rewrite: string }
if (knowledgeRecognition === 'off') {
// Direct mode: use the user message content
const directContent = userMessage || finalQueries[0] || 'search'
searchCriteria = {
question: [directContent],
rewrite: directContent
}
} else {
// Automatic mode: use the intent recognition results
searchCriteria = {
question: finalQueries,
rewrite: finalRewrite
}
}
// Build the ExtractResults object
const extractResults: ExtractResults = {
websearch: undefined,
knowledge: searchCriteria
}
// Run the knowledge base search
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds, topicId)
const knowledgeReferencesData = knowledgeReferences.map((ref: KnowledgeReference) => ({
id: ref.id,
content: ref.content,
sourceUrl: ref.sourceUrl,
type: ref.type,
file: ref.file,
metadata: ref.metadata
}))
// TODO: add a search cache mechanism in the tool function
// const searchCacheKey = `${topicId}-${JSON.stringify(finalQueries)}`
// Return the results
return knowledgeReferencesData
},
toModelOutput: (results) => {
let summary = 'No search needed based on the query analysis.'
if (results.length > 0) {
summary = `Found ${results.length} relevant sources. Use [number] format to cite specific information.`
}
const referenceContent = `\`\`\`json\n${JSON.stringify(results, null, 2)}\n\`\`\``
const fullInstructions = REFERENCE_PROMPT.replace(
'{question}',
"Based on the knowledge references, please answer the user's question with proper citations."
).replace('{references}', referenceContent)
return {
type: 'content',
value: [
{
type: 'text',
text: 'This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.'
},
{
type: 'text',
text: summary
},
{
type: 'text',
text: fullInstructions
}
}
let finalQueries = [...extractedKeywords.question]
let finalRewrite = extractedKeywords.rewrite
if (additionalContext?.trim()) {
// If the LLM provided additional context, use the more specific description
const cleanContext = additionalContext.trim()
if (cleanContext) {
finalQueries = [cleanContext]
finalRewrite = cleanContext
}
}
// Check whether a search is needed
if (finalQueries[0] === 'not_needed') {
return {
summary: 'No search needed based on the query analysis.',
knowledgeReferences: [],
instructions: ''
}
}
// Build the search criteria
let searchCriteria: { question: string[]; rewrite: string }
if (knowledgeRecognition === 'off') {
// Direct mode: use the user message content
const directContent = userMessage || finalQueries[0] || 'search'
searchCriteria = {
question: [directContent],
rewrite: directContent
}
} else {
// Automatic mode: use the intent recognition results
searchCriteria = {
question: finalQueries,
rewrite: finalRewrite
}
}
// Build the ExtractResults object
const extractResults: ExtractResults = {
websearch: undefined,
knowledge: searchCriteria
}
// Run the knowledge base search
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds, topicId)
const knowledgeReferencesData = knowledgeReferences.map((ref: KnowledgeReference) => ({
id: ref.id,
content: ref.content,
sourceUrl: ref.sourceUrl,
type: ref.type,
file: ref.file,
metadata: ref.metadata
}))
// const referenceContent = `\`\`\`json\n${JSON.stringify(knowledgeReferencesData, null, 2)}\n\`\`\``
// TODO: add a search cache mechanism in the tool function
// const searchCacheKey = `${topicId}-${JSON.stringify(finalQueries)}`
// Searched queries could be managed at the plugin level to avoid duplicate searches
const fullInstructions = REFERENCE_PROMPT.replace(
'{question}',
"Based on the knowledge references, please answer the user's question with proper citations."
).replace('{references}', 'knowledgeReferences:')
// Return the results
return {
summary: `Found ${knowledgeReferencesData.length} relevant sources. Use [number] format to cite specific information.`,
knowledgeReferences: knowledgeReferencesData,
instructions: fullInstructions
}
} catch (error) {
// Return an empty object instead of throwing, to avoid interrupting the conversation flow
return {
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
knowledgeReferences: [],
instructions: ''
}
]
}
}
})
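
execute returns the raw reference array while toModelOutput renders it for the model, so UI consumers and the model each get their own view of the same data. For two hits the model receives roughly the following (the reference contents are hypothetical):

// Hypothetical toModelOutput result:
// {
//   type: 'content',
//   value: [
//     { type: 'text', text: 'This tool searches for relevant information and formats results for easy citation. ...' },
//     { type: 'text', text: 'Found 2 relevant sources. Use [number] format to cite specific information.' },
//     { type: 'text', text: REFERENCE_PROMPT /* with the JSON-serialized references substituted in */ }
//   ]
// }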

View File

@ -1,6 +1,5 @@
import store from '@renderer/store'
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
import type { Assistant } from '@renderer/types'
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
import { z } from 'zod'
@ -19,133 +18,29 @@ export const memorySearchTool = () => {
limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
}),
execute: async ({ query, limit = 5 }) => {
// console.log('🧠 [memorySearchTool] Searching memories:', { query, limit })
try {
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
if (!globalMemoryEnabled) {
return []
}
const memoryConfig = selectMemoryConfig(store.getState())
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
// console.warn('Memory search skipped: embedding or LLM model not configured')
return []
}
const currentUserId = selectCurrentUserId(store.getState())
const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, 'default', currentUserId)
const memoryProcessor = new MemoryProcessor()
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
if (relevantMemories?.length > 0) {
// console.log('🧠 [memorySearchTool] Found memories:', relevantMemories.length)
return relevantMemories
}
return []
} catch (error) {
// console.error('🧠 [memorySearchTool] Error:', error)
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
if (!globalMemoryEnabled) {
return []
}
const memoryConfig = selectMemoryConfig(store.getState())
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
return []
}
const currentUserId = selectCurrentUserId(store.getState())
const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, 'default', currentUserId)
const memoryProcessor = new MemoryProcessor()
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
if (relevantMemories?.length > 0) {
return relevantMemories
}
return []
}
})
}
// Option 4: use a type assertion for the second tool as well
type MessageRole = 'user' | 'assistant' | 'system'
type MessageType = {
content: string
role: MessageRole
}
type MemorySearchWithExtractionInput = {
userMessage: MessageType
lastAnswer?: MessageType
}
/**
* 🧠 Search memories with automatic keyword extraction from the conversation context
*/
export const memorySearchToolWithExtraction = (assistant: Assistant) => {
return tool({
name: 'memory_search_with_extraction',
description: 'Search memories with automatic keyword extraction from conversation context',
inputSchema: z.object({
userMessage: z.object({
content: z.string().describe('The main content of the user message'),
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
}),
lastAnswer: z
.object({
content: z.string().describe('The main content of the last assistant response'),
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
})
.optional()
}) as z.ZodSchema<MemorySearchWithExtractionInput>,
execute: async ({ userMessage }) => {
// console.log('🧠 [memorySearchToolWithExtraction] Processing:', { userMessage, lastAnswer })
try {
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
if (!globalMemoryEnabled || !assistant.enableMemory) {
return {
extractedKeywords: 'Memory search disabled',
searchResults: []
}
}
const memoryConfig = selectMemoryConfig(store.getState())
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
// console.warn('Memory search skipped: embedding or LLM model not configured')
return {
extractedKeywords: 'Memory models not configured',
searchResults: []
}
}
// 🔍 Use the user message content as the search keywords
const content = userMessage.content
if (!content) {
return {
extractedKeywords: 'No content to search',
searchResults: []
}
}
const currentUserId = selectCurrentUserId(store.getState())
const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, assistant.id, currentUserId)
const memoryProcessor = new MemoryProcessor()
const relevantMemories = await memoryProcessor.searchRelevantMemories(
content,
processorConfig,
5 // Limit to top 5 most relevant memories
)
if (relevantMemories?.length > 0) {
// console.log('🧠 [memorySearchToolWithExtraction] Found memories:', relevantMemories.length)
return {
extractedKeywords: content,
searchResults: relevantMemories
}
}
return {
extractedKeywords: content,
searchResults: []
}
} catch (error) {
// console.error('🧠 [memorySearchToolWithExtraction] Error:', error)
return {
extractedKeywords: 'Search failed',
searchResults: []
}
}
}
})
}
export type MemorySearchToolInput = InferToolInput<ReturnType<typeof memorySearchTool>>
export type MemorySearchToolOutput = InferToolOutput<ReturnType<typeof memorySearchTool>>
export type MemorySearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof memorySearchToolWithExtraction>>
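For orientation, these exports plug into the AI SDK's tool-calling loop. A minimal sketch of registering memorySearchTool for one generation turn (assuming the Vercel AI SDK's generateText and a configured LanguageModel; the wrapper name is illustrative):

import { generateText, type LanguageModel } from 'ai'

// Let the model call memory_search while answering the prompt.
async function askWithMemory(model: LanguageModel, prompt: string) {
  return generateText({
    model,
    tools: { memory_search: memorySearchTool() },
    prompt
  })
}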

View File

@ -30,8 +30,6 @@ Relevant links: ${extractedKeywords.links.join(', ')}`
: ''
}
This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.
Call this tool to execute the search. You can optionally provide additional context to refine the search.`,
inputSchema: z.object({
@ -58,40 +56,27 @@ Call this tool to execute the search. You can optionally provide additional cont
}
// Check whether a search is needed
if (finalQueries[0] === 'not_needed') {
return {
summary: 'No search needed based on the query analysis.',
searchResults,
sources: '',
instructions: ''
}
return searchResults
}
try {
// Build the ExtractResults structure for processWebsearch
const extractResults: ExtractResults = {
websearch: {
question: finalQueries,
links: extractedKeywords.links
}
}
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
} catch (error) {
return {
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
sources: [],
instructions: ''
// Build the ExtractResults structure for processWebsearch
const extractResults: ExtractResults = {
websearch: {
question: finalQueries,
links: extractedKeywords.links
}
}
if (searchResults.results.length === 0) {
return {
summary: 'No search results found for the given query.',
sources: [],
instructions: ''
}
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
return searchResults
},
toModelOutput: (results) => {
let summary = 'No search needed based on the query analysis.'
if (results.query && results.results.length > 0) {
summary = `Found ${results.results.length} relevant sources. Use [number] format to cite specific information.`
}
const results = searchResults.results
const citationData = results.map((result, index) => ({
const citationData = results.results.map((result, index) => ({
number: index + 1,
title: result.title,
content: result.content,
@ -99,18 +84,27 @@ Call this tool to execute the search. You can optionally provide additional cont
}))
// 🔑 Return a citation-friendly format, reusing the REFERENCE_PROMPT logic
// const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
// Build the full citation guidance text
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
const fullInstructions = REFERENCE_PROMPT.replace(
'{question}',
"Based on the search results, please answer the user's question with proper citations."
).replace('{references}', 'searchResults:')
).replace('{references}', referenceContent)
return {
summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`,
searchResults,
instructions: fullInstructions
type: 'content',
value: [
{
type: 'text',
text: 'This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.'
},
{
type: 'text',
text: summary
},
{
type: 'text',
text: fullInstructions
}
]
}
}
})
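The toModelOutput hook above compresses raw results into citation-ready text. A self-contained sketch of that formatting step (the REFERENCE_PROMPT template shape is assumed; the real one lives in the project's prompt constants):

// Assumed template with {question} and {references} placeholders.
const REFERENCE_PROMPT = 'Answer: {question}\n\nReferences:\n{references}'

const results = [
  { title: 'Doc A', content: 'alpha...', url: 'https://a.example' },
  { title: 'Doc B', content: 'beta...', url: 'https://b.example' }
]

// Number each source so the model can cite it as [1], [2], ...
const citationData = results.map((result, index) => ({
  number: index + 1,
  title: result.title,
  content: result.content,
  url: result.url
}))

const referenceContent = '```json\n' + JSON.stringify(citationData, null, 2) + '\n```'
const fullInstructions = REFERENCE_PROMPT.replace(
  '{question}',
  "Based on the search results, please answer the user's question with proper citations."
).replace('{references}', referenceContent)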

View File

@ -1,7 +1,5 @@
import { loggerService } from '@logger'
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
import { MCPTool, MCPToolResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools'
import { requestToolConfirmation } from '@renderer/utils/userConfirmation'
import { type Tool, type ToolSet } from 'ai'
@ -33,8 +31,36 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
tools[mcpTool.name] = tool({
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
execute: async (params, { toolCallId, experimental_context }) => {
const { onChunk } = experimental_context as { onChunk: (chunk: Chunk) => void }
execute: async (params, { toolCallId }) => {
// Check whether auto-approval is enabled
const server = getMcpServerByTool(mcpTool)
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
let confirmed = true
if (!isAutoApproveEnabled) {
// Request user confirmation
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
confirmed = await requestToolConfirmation(toolCallId)
}
if (!confirmed) {
// User declined to run the tool
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
return {
content: [
{
type: 'text',
text: `User declined to execute tool "${mcpTool.name}".`
}
],
isError: false
}
}
// User confirmed or auto-approved; execute the tool
logger.debug(`Executing tool: ${mcpTool.name}`)
// Create the adapted MCPToolResponse object
const toolResponse: MCPToolResponse = {
id: toolCallId,
@ -44,53 +70,18 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
toolCallId
}
try {
// Check whether auto-approval is enabled
const server = getMcpServerByTool(mcpTool)
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
const result = await callMCPTool(toolResponse)
let confirmed = true
if (!isAutoApproveEnabled) {
// Request user confirmation
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
confirmed = await requestToolConfirmation(toolResponse.id)
}
if (!confirmed) {
// User declined to run the tool
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
return {
content: [
{
type: 'text',
text: `User declined to execute tool "${mcpTool.name}".`
}
],
isError: false
}
}
// User confirmed or auto-approved; execute the tool
toolResponse.status = 'invoking'
logger.debug(`Executing tool: ${mcpTool.name}`)
onChunk({
type: ChunkType.MCP_TOOL_IN_PROGRESS,
responses: [toolResponse]
})
const result = await callMCPTool(toolResponse)
// Return the result; the AI SDK handles serialization
if (result.isError) {
throw new Error(result.content?.[0]?.text || 'Tool execution failed')
}
// Return the tool execution result
return result
} catch (error) {
logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
throw error
// Return the result; the AI SDK handles serialization
if (result.isError) {
// throw new Error(result.content?.[0]?.text || 'Tool execution failed')
return Promise.reject(result)
}
// Return the tool execution result
return result
// } catch (error) {
// logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
// }
}
})
}
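requestToolConfirmation above suspends the tool call until the user answers. A minimal sketch of such a per-call gate (hypothetical helper names; the project's real implementation lives in @renderer/utils/userConfirmation):

// Pending confirmations keyed by toolCallId.
const pending = new Map<string, (approved: boolean) => void>()

// Awaited from the tool's execute(); resolves once the UI answers.
function requestToolConfirmation(toolCallId: string): Promise<boolean> {
  return new Promise<boolean>((resolve) => pending.set(toolCallId, resolve))
}

// Called by the UI when the user clicks Approve or Decline.
function resolveToolConfirmation(toolCallId: string, approved: boolean): void {
  pending.get(toolCallId)?.(approved)
  pending.delete(toolCallId)
}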

View File

@ -120,6 +120,9 @@ export function buildProviderOptions(
case 'google-vertex':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break
case 'google-vertex-anthropic':
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break
default:
// For other providers, use the generic build logic
providerSpecificOptions = {
@ -137,10 +140,16 @@ export function buildProviderOptions(
...providerSpecificOptions,
...getCustomParameters(assistant)
}
// Vertex provider options must be keyed as google or anthropic
const rawProviderKey =
{
'google-vertex': 'google',
'google-vertex-anthropic': 'anthropic'
}[rawProviderId] || rawProviderId
// Return the shape required by the AI Core SDK: { 'providerId': providerOptions }
return {
[rawProviderId]: providerSpecificOptions
[rawProviderKey]: providerSpecificOptions
}
}
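The remap above ensures options for Claude served via Vertex land under the 'anthropic' key that the downstream provider reads. The lookup-with-fallback idiom in isolation (illustrative only):

// Map a Vertex provider id to the key the SDK reads; fall back to the id itself.
const mapProviderKey = (id: string): string =>
  ({ 'google-vertex': 'google', 'google-vertex-anthropic': 'anthropic' } as Record<string, string>)[id] ?? id

mapProviderKey('google-vertex-anthropic') // => 'anthropic'
mapProviderKey('openai') // => 'openai'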

View File

@ -310,6 +310,26 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
return {}
}
export function getAnthropicThinkingBudget(assistant: Assistant, model: Model): number {
const { maxTokens, reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
if (maxTokens === undefined || reasoningEffort === undefined) {
return 0
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
)
return budgetTokens
}
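A quick worked example of the budget formula (assumed numbers; the real bounds come from findTokenLimit):

// Suppose findTokenLimit(model.id) gives { min: 1024, max: 32000 },
// the effort ratio is 0.5, and the assistant's maxTokens is 8192.
const min = 1024
const max = 32000
const effortRatio = 0.5
const maxTokens = 8192
const budget = Math.max(1024, Math.floor(Math.min((max - min) * effortRatio + min, maxTokens * effortRatio)))
// min(16512, 4096) = 4096, so budget === 4096 (capped by the output-token share).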
/**
 * Anthropic reasoning parameters.
 * The logic follows AnthropicAPIClient.
@ -331,19 +351,7 @@ export function getAnthropicReasoningParams(assistant: Assistant, model: Model):
// Claude 推理参数
if (isSupportedThinkingTokenClaudeModel(model)) {
const { maxTokens } = getAssistantSettings(assistant)
const effortRatio = EFFORT_RATIO[reasoningEffort]
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
)
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
return {
thinking: {

Binary file not shown. (Image; before size: 299 KiB)

View File

@ -0,0 +1,60 @@
.command-list-popover {
/* Base styles are handled inline for theme support */
/* Arrow styles based on placement */
}
.command-list-popover[data-placement^='bottom'] {
transform-origin: top center;
animation: slideDownAndFadeIn 0.15s ease-out;
}
.command-list-popover[data-placement^='top'] {
transform-origin: bottom center;
animation: slideUpAndFadeIn 0.15s ease-out;
}
.command-list-popover[data-placement*='start'] {
transform-origin: left center;
}
.command-list-popover[data-placement*='end'] {
transform-origin: right center;
}
@keyframes slideDownAndFadeIn {
0% {
opacity: 0;
transform: translateY(-8px) scale(0.95);
}
100% {
opacity: 1;
transform: translateY(0) scale(1);
}
}
@keyframes slideUpAndFadeIn {
0% {
opacity: 0;
transform: translateY(8px) scale(0.95);
}
100% {
opacity: 1;
transform: translateY(0) scale(1);
}
}
/* Ensure smooth scrolling in virtual list */
.command-list-popover .dynamic-virtual-list {
scroll-behavior: smooth;
}
/* Better focus indicators */
.command-list-popover [data-index] {
position: relative;
}
.command-list-popover [data-index]:focus-visible {
outline: 2px solid var(--color-primary, #1677ff);
outline-offset: -2px;
}

View File

@ -1,59 +0,0 @@
.command-list-popover {
// Base styles are handled inline for theme support
// Arrow styles based on placement
&[data-placement^='bottom'] {
transform-origin: top center;
animation: slideDownAndFadeIn 0.15s ease-out;
}
&[data-placement^='top'] {
transform-origin: bottom center;
animation: slideUpAndFadeIn 0.15s ease-out;
}
&[data-placement*='start'] {
transform-origin: left center;
}
&[data-placement*='end'] {
transform-origin: right center;
}
}
@keyframes slideDownAndFadeIn {
0% {
opacity: 0;
transform: translateY(-8px) scale(0.95);
}
100% {
opacity: 1;
transform: translateY(0) scale(1);
}
}
@keyframes slideUpAndFadeIn {
0% {
opacity: 0;
transform: translateY(8px) scale(0.95);
}
100% {
opacity: 1;
transform: translateY(0) scale(1);
}
}
// Ensure smooth scrolling in virtual list
.command-list-popover .dynamic-virtual-list {
scroll-behavior: smooth;
}
// Better focus indicators
.command-list-popover [data-index] {
position: relative;
&:focus-visible {
outline: 2px solid var(--color-primary, #1677ff);
outline-offset: -2px;
}
}

View File

@ -10,14 +10,14 @@
}
}
// Electromagnetic wave diffusion effect
/* Electromagnetic wave diffusion effect */
.animation-pulse {
--pulse-color: 59, 130, 246;
--pulse-size: 8px;
animation: animation-pulse 1.5s infinite;
}
// Modal animations
/* Modal animations */
@keyframes animation-move-down-in {
0% {
transform: translate3d(0, 100%, 0);
@ -54,7 +54,7 @@
animation-duration: 0.25s;
}
// Spin animation
/* Spin animation */
@keyframes animation-rotate {
from {
transform: rotate(0deg);
@ -69,7 +69,7 @@
animation: animation-rotate 0.75s linear infinite;
}
// Locate-and-highlight animation
/* Locate-and-highlight animation */
@keyframes animation-locate-highlight {
0% {
background-color: transparent;

View File

@ -0,0 +1,238 @@
@import './container.css';
/* The Modal close button must not be draggable, so that clicks work normally */
.ant-modal-close {
-webkit-app-region: no-drag;
}
/* Regular Drawer content must not be draggable */
.ant-drawer-content {
-webkit-app-region: no-drag;
}
/* minapp-drawer has its own drag rules */
/* Dropdown menus and popover content must not be draggable */
.ant-dropdown,
.ant-dropdown-menu,
.ant-popover-content,
.ant-tooltip-content,
.ant-popconfirm {
-webkit-app-region: no-drag;
}
#inputbar {
resize: none;
}
.ant-image-preview-switch-left {
-webkit-app-region: no-drag;
}
.ant-btn:not(:disabled):focus-visible {
outline: none;
}
/* Align lucide icon in Button */
.ant-btn .ant-btn-icon {
display: inline-flex;
align-items: center;
justify-content: center;
}
.ant-tabs-tabpane:focus-visible {
outline: none;
}
.ant-tabs-tab-btn {
outline: none !important;
}
.ant-segmented-group {
gap: 4px;
}
.minapp-drawer .ant-drawer-content-wrapper {
box-shadow: none;
}
.minapp-drawer .ant-drawer-header {
position: absolute;
-webkit-app-region: drag;
min-height: calc(var(--navbar-height) + 0.5px);
margin-top: -0.5px;
border-bottom: none;
}
.minapp-drawer .ant-drawer-body {
padding: 0;
margin-top: var(--navbar-height);
overflow: hidden;
/* Manually expanded content of @extend #content-container */
background-color: var(--color-background);
}
.minapp-drawer .minapp-mask {
background-color: transparent !important;
}
[navbar-position='left'] .minapp-drawer {
max-width: calc(100vw - var(--sidebar-width));
}
[navbar-position='left'] .minapp-drawer .ant-drawer-header {
width: calc(100vw - var(--sidebar-width));
}
[navbar-position='top'] .minapp-drawer {
max-width: 100vw;
}
[navbar-position='top'] .minapp-drawer .ant-drawer-header {
width: 100vw;
}
.ant-drawer-header {
/* Regular drawer headers must not be draggable, unless overridden by minapp-drawer */
-webkit-app-region: no-drag;
}
.message-attachments .ant-upload-list-item:hover {
background-color: initial !important;
}
.ant-dropdown-menu .ant-dropdown-menu-sub {
max-height: 80vh;
width: max-content;
overflow-y: auto;
overflow-x: hidden;
border: 0.5px solid var(--color-border);
}
.ant-dropdown {
background-color: var(--ant-color-bg-elevated);
overflow: hidden;
border-radius: var(--ant-border-radius-lg);
user-select: none;
}
.ant-dropdown .ant-dropdown-menu {
max-height: 80vh;
overflow-y: auto;
border: 0.5px solid var(--color-border);
}
/* Align lucide icon in dropdown menu item extra */
.ant-dropdown .ant-dropdown-menu .ant-dropdown-menu-submenu-expand-icon,
.ant-dropdown .ant-dropdown-menu .ant-dropdown-menu-item-extra {
display: inline-flex;
align-items: center;
justify-content: center;
}
.ant-dropdown .ant-dropdown-arrow + .ant-dropdown-menu {
border: none;
}
.ant-select-dropdown {
border: 0.5px solid var(--color-border);
}
.ant-dropdown-menu-submenu {
background-color: var(--ant-color-bg-elevated);
overflow: hidden;
border-radius: var(--ant-border-radius-lg);
}
.ant-dropdown-menu-submenu .ant-dropdown-menu-submenu-title {
align-items: center;
}
.ant-popover .ant-popover-inner {
border: 0.5px solid var(--color-border);
}
.ant-popover .ant-popover-inner .ant-popover-inner-content {
max-height: 70vh;
overflow-y: auto;
}
.ant-popover .ant-popover-arrow + .ant-popover-content .ant-popover-inner {
border: none;
}
.ant-modal:not(.ant-modal-confirm) .ant-modal-confirm-body-has-title {
padding: 16px 0 0 0;
}
.ant-modal:not(.ant-modal-confirm) .ant-modal-content {
border-radius: 10px;
border: 0.5px solid var(--color-border);
padding: 0 0 8px 0;
}
.ant-modal:not(.ant-modal-confirm) .ant-modal-content .ant-modal-close {
margin-right: 2px;
}
.ant-modal:not(.ant-modal-confirm) .ant-modal-content .ant-modal-header {
padding: 16px 16px 0 16px;
border-radius: 10px;
}
.ant-modal:not(.ant-modal-confirm) .ant-modal-content .ant-modal-body {
/* Keep the body within the viewport, using a standard max height */
max-height: 80vh;
overflow-y: auto;
padding: 0 16px 0 16px;
}
.ant-modal:not(.ant-modal-confirm) .ant-modal-content .ant-modal-footer {
padding: 0 16px 8px 16px;
}
.ant-modal:not(.ant-modal-confirm) .ant-modal-content .ant-modal-confirm-btns {
margin-bottom: 8px;
}
.ant-modal.ant-modal-confirm.ant-modal-confirm-confirm .ant-modal-content {
padding: 16px;
}
.ant-collapse:not(.ant-collapse-ghost) {
border: 1px solid var(--color-border);
}
.ant-color-picker .ant-collapse:not(.ant-collapse-ghost) {
border: none;
}
.ant-collapse:not(.ant-collapse-ghost) .ant-collapse-content {
border-top: 0.5px solid var(--color-border) !important;
}
.ant-color-picker .ant-collapse:not(.ant-collapse-ghost) .ant-collapse-content {
border-top: none !important;
}
.ant-slider .ant-slider-handle::after {
box-shadow: 0 1px 4px 0px rgb(128 128 128 / 50%) !important;
}
.ant-splitter-bar .ant-splitter-bar-dragger::before {
background-color: var(--color-border) !important;
transition:
background-color 0.15s ease,
width 0.15s ease;
}
.ant-splitter-bar .ant-splitter-bar-dragger:hover::before {
width: 4px !important;
background-color: var(--color-primary) !important;
transition-delay: 0.15s;
}
.ant-splitter-bar .ant-splitter-bar-dragger-active::before {
width: 4px !important;
background-color: var(--color-primary) !important;
}

View File

@ -1,234 +0,0 @@
@use './container.scss';
/* The Modal close button must not be draggable, so that clicks work normally */
.ant-modal-close {
-webkit-app-region: no-drag;
}
/* Regular Drawer content must not be draggable */
.ant-drawer-content {
-webkit-app-region: no-drag;
}
/* minapp-drawer has its own drag rules */
/* Dropdown menus and popover content must not be draggable */
.ant-dropdown,
.ant-dropdown-menu,
.ant-popover-content,
.ant-tooltip-content,
.ant-popconfirm {
-webkit-app-region: no-drag;
}
#inputbar {
resize: none;
}
.ant-image-preview-switch-left {
-webkit-app-region: no-drag;
}
.ant-btn:not(:disabled):focus-visible {
outline: none;
}
// Align lucide icon in Button
.ant-btn .ant-btn-icon {
display: inline-flex;
align-items: center;
justify-content: center;
}
.ant-tabs-tabpane:focus-visible {
outline: none;
}
.ant-tabs-tab-btn {
outline: none !important;
}
.ant-segmented-group {
gap: 4px;
}
.minapp-drawer {
[navbar-position='left'] & {
max-width: calc(100vw - var(--sidebar-width));
.ant-drawer-header {
width: calc(100vw - var(--sidebar-width));
}
}
[navbar-position='top'] & {
max-width: 100vw;
.ant-drawer-header {
width: 100vw;
}
}
.ant-drawer-content-wrapper {
box-shadow: none;
}
.ant-drawer-header {
position: absolute;
-webkit-app-region: drag;
min-height: calc(var(--navbar-height) + 0.5px);
margin-top: -0.5px;
border-bottom: none;
}
.ant-drawer-body {
padding: 0;
margin-top: var(--navbar-height);
overflow: hidden;
@extend #content-container;
}
.minapp-mask {
background-color: transparent !important;
}
}
.ant-drawer-header {
/* Regular drawer headers must not be draggable, unless overridden by minapp-drawer */
-webkit-app-region: no-drag;
}
.message-attachments {
.ant-upload-list-item:hover {
background-color: initial !important;
}
}
.ant-dropdown-menu .ant-dropdown-menu-sub {
max-height: 80vh;
width: max-content;
overflow-y: auto;
overflow-x: hidden;
border: 0.5px solid var(--color-border);
}
.ant-dropdown {
background-color: var(--ant-color-bg-elevated);
overflow: hidden;
border-radius: var(--ant-border-radius-lg);
user-select: none;
.ant-dropdown-menu {
max-height: 80vh;
overflow-y: auto;
border: 0.5px solid var(--color-border);
// Align lucide icon in dropdown menu item extra
.ant-dropdown-menu-submenu-expand-icon,
.ant-dropdown-menu-item-extra {
display: inline-flex;
align-items: center;
justify-content: center;
}
}
.ant-dropdown-arrow + .ant-dropdown-menu {
border: none;
}
}
.ant-select-dropdown {
border: 0.5px solid var(--color-border);
}
.ant-dropdown-menu-submenu {
background-color: var(--ant-color-bg-elevated);
overflow: hidden;
border-radius: var(--ant-border-radius-lg);
.ant-dropdown-menu-submenu-title {
align-items: center;
}
}
.ant-popover {
.ant-popover-inner {
border: 0.5px solid var(--color-border);
.ant-popover-inner-content {
max-height: 70vh;
overflow-y: auto;
}
}
.ant-popover-arrow + .ant-popover-content {
.ant-popover-inner {
border: none;
}
}
}
.ant-modal:not(.ant-modal-confirm) {
.ant-modal-confirm-body-has-title {
padding: 16px 0 0 0;
}
.ant-modal-content {
border-radius: 10px;
border: 0.5px solid var(--color-border);
padding: 0 0 8px 0;
.ant-modal-close {
margin-right: 2px;
}
.ant-modal-header {
padding: 16px 16px 0 16px;
border-radius: 10px;
}
.ant-modal-body {
/* Keep the body within the viewport, using a standard max height */
max-height: 80vh;
overflow-y: auto;
padding: 0 16px 0 16px;
}
.ant-modal-footer {
padding: 0 16px 8px 16px;
}
.ant-modal-confirm-btns {
margin-bottom: 8px;
}
}
}
.ant-modal.ant-modal-confirm.ant-modal-confirm-confirm {
.ant-modal-content {
padding: 16px;
}
}
.ant-collapse:not(.ant-collapse-ghost) {
border: 1px solid var(--color-border);
.ant-color-picker & {
border: none;
}
.ant-collapse-content {
border-top: 0.5px solid var(--color-border) !important;
.ant-color-picker & {
border-top: none !important;
}
}
}
.ant-slider {
.ant-slider-handle::after {
box-shadow: 0 1px 4px 0px rgb(128 128 128 / 50%) !important;
}
}
.ant-splitter-bar {
.ant-splitter-bar-dragger {
&::before {
background-color: var(--color-border) !important;
transition:
background-color 0.15s ease,
width 0.15s ease;
}
&:hover {
&::before {
width: 4px !important;
background-color: var(--color-primary) !important;
transition-delay: 0.15s;
}
}
}
.ant-splitter-bar-dragger-active {
&::before {
width: 4px !important;
background-color: var(--color-primary) !important;
}
}
}

View File

@ -19,7 +19,7 @@
--color-background-soft: var(--color-black-soft);
--color-background-mute: var(--color-black-mute);
--color-background-opacity: rgba(34, 34, 34, 0.7);
--inner-glow-opacity: 0.3; // For the glassmorphism effect in the dropdown menu
--inner-glow-opacity: 0.3; /* For the glassmorphism effect in the dropdown menu */
--color-primary: #00b96b;
--color-primary-soft: #00b96b99;
@ -58,16 +58,6 @@
--navbar-background-mac: rgba(20, 20, 20, 0.55);
--navbar-background: #1f1f1f;
--navbar-height: 44px;
--sidebar-width: 50px;
--status-bar-height: 40px;
--input-bar-height: 100px;
--assistants-width: 275px;
--topic-list-width: 275px;
--settings-width: 250px;
--scrollbar-width: 5px;
--chat-background: transparent;
--chat-background-user: rgba(255, 255, 255, 0.08);
--chat-background-assistant: transparent;
@ -146,11 +136,6 @@
--chat-text-user: var(--color-text);
}
[navbar-position='left'] {
--navbar-height: 42px;
--list-item-border-radius: 20px;
}
[navbar-position='left'][theme-mode='light'] {
--color-list-item: #eee;
--color-list-item-hover: #f5f5f5;

View File

@ -0,0 +1,9 @@
#content-container {
background-color: var(--color-background);
}
[navbar-position='left'] #content-container {
border-top: 0.5px solid var(--color-border);
border-top-left-radius: 10px;
border-left: 0.5px solid var(--color-border);
}

View File

@ -1,11 +0,0 @@
#content-container {
background-color: var(--color-background);
}
[navbar-position='left'] {
#content-container {
border-top: 0.5px solid var(--color-border);
border-top-left-radius: 10px;
border-left: 0.5px solid var(--color-border);
}
}

View File

@ -0,0 +1,24 @@
:root {
--font-family:
var(--user-font-family), Ubuntu, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, Roboto, Oxygen,
Cantarell, 'Open Sans', 'Helvetica Neue', Arial, 'Noto Sans', sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji',
'Segoe UI Symbol', 'Noto Color Emoji';
--font-family-serif:
serif, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, Ubuntu, Roboto, Oxygen, Cantarell, 'Open Sans',
'Helvetica Neue', Arial, 'Noto Sans', 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
--code-font-family: var(--user-code-font-family), 'Cascadia Code', 'Fira Code', 'Consolas', Menlo, Courier, monospace;
}
/* Windows-specific font configuration */
body[os='windows'] {
--font-family:
var(--user-font-family), 'Twemoji Country Flags', Ubuntu, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui,
Roboto, Oxygen, Cantarell, 'Open Sans', 'Helvetica Neue', Arial, 'Noto Sans', sans-serif, 'Apple Color Emoji',
'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
--code-font-family:
var(--user-code-font-family), 'Cascadia Code', 'Fira Code', 'Consolas', 'Sarasa Mono SC', 'Microsoft YaHei UI',
Courier, monospace;
}

View File

@ -1,23 +0,0 @@
:root {
--font-family:
Ubuntu, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, Roboto, Oxygen, Cantarell, 'Open Sans',
'Helvetica Neue', Arial, 'Noto Sans', sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol',
'Noto Color Emoji';
--font-family-serif:
serif, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, Ubuntu, Roboto, Oxygen, Cantarell, 'Open Sans',
'Helvetica Neue', Arial, 'Noto Sans', 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
--code-font-family: 'Cascadia Code', 'Fira Code', 'Consolas', Menlo, Courier, monospace;
}
// Windows-specific font configuration
body[os='windows'] {
--font-family:
'Twemoji Country Flags', Ubuntu, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, Roboto, Oxygen,
Cantarell, 'Open Sans', 'Helvetica Neue', Arial, 'Noto Sans', sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji',
'Segoe UI Symbol', 'Noto Color Emoji';
--code-font-family:
'Cascadia Code', 'Fira Code', 'Consolas', 'Sarasa Mono SC', 'Microsoft YaHei UI', Courier, monospace;
}

View File

@ -1,11 +1,12 @@
@use './color.scss';
@use './font.scss';
@use './markdown.scss';
@use './ant.scss';
@use './scrollbar.scss';
@use './container.scss';
@use './animation.scss';
@use './richtext.scss';
@import './color.css';
@import './font.css';
@import './markdown.css';
@import './ant.css';
@import './scrollbar.css';
@import './container.css';
@import './animation.css';
@import './richtext.css';
@import './responsive.css';
@import '../fonts/icon-fonts/iconfont.css';
@import '../fonts/ubuntu/ubuntu.css';
@import '../fonts/country-flag-fonts/flag.css';
@ -14,7 +15,7 @@
*::before,
*::after {
box-sizing: border-box;
margin: 0;
/* margin: 0; */
font-weight: normal;
}
@ -34,11 +35,11 @@ body,
margin: 0;
}
#root {
/* #root {
display: flex;
flex-direction: row;
flex: 1;
}
} */
body {
display: flex;
@ -50,6 +51,7 @@ body {
font-family: var(--font-family);
text-rendering: optimizeLegibility;
transition: background-color 0.3s linear;
background-color: unset;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
@ -113,62 +115,58 @@ ul {
word-wrap: break-word;
}
.bubble:not(.multi-select-mode) {
.block-wrapper {
display: flow-root;
}
.bubble:not(.multi-select-mode) .block-wrapper {
display: flow-root;
}
.block-wrapper:last-child > *:last-child {
margin-bottom: 0;
}
.bubble:not(.multi-select-mode) .block-wrapper:last-child > *:last-child {
margin-bottom: 0;
}
.message-content-container > *:last-child {
margin-bottom: 0;
}
.bubble:not(.multi-select-mode) .message-content-container > *:last-child {
margin-bottom: 0;
}
.message-thought-container {
margin-top: 8px;
}
.bubble:not(.multi-select-mode) .message-thought-container {
margin-top: 8px;
}
.message-user {
.message-header {
flex-direction: row-reverse;
text-align: right;
.message-header-info-wrap {
flex-direction: row-reverse;
text-align: right;
}
}
.message-content-container {
border-radius: 10px;
padding: 10px 16px 10px 16px;
background-color: var(--chat-background-user);
align-self: self-end;
}
.MessageFooter {
margin-top: 2px;
align-self: self-end;
}
}
.bubble:not(.multi-select-mode) .message-user .message-header {
flex-direction: row-reverse;
text-align: right;
}
.message-assistant {
.message-content-container {
padding-left: 0;
}
.MessageFooter {
margin-left: 0;
}
}
.bubble:not(.multi-select-mode) .message-user .message-header .message-header-info-wrap {
flex-direction: row-reverse;
text-align: right;
}
code {
color: var(--color-text);
}
.markdown {
display: flow-root;
*:last-child {
margin-bottom: 0;
}
}
.bubble:not(.multi-select-mode) .message-user .message-content-container {
border-radius: 10px;
padding: 10px 16px 10px 16px;
background-color: var(--chat-background-user);
align-self: self-end;
}
.bubble:not(.multi-select-mode) .message-user .MessageFooter {
margin-top: 2px;
align-self: self-end;
}
.bubble:not(.multi-select-mode) .message-assistant .message-content-container {
padding-left: 0;
}
.bubble:not(.multi-select-mode) .message-assistant .MessageFooter {
margin-left: 0;
}
.bubble:not(.multi-select-mode) code {
color: var(--color-text);
}
.bubble:not(.multi-select-mode) .markdown {
display: flow-root;
}
.lucide:not(.lucide-custom) {
@ -184,8 +182,6 @@ ul {
background-color: var(--color-background-highlight-accent);
}
textarea {
&::-webkit-resizer {
display: none;
}
textarea::-webkit-resizer {
display: none;
}

View File

@ -0,0 +1,388 @@
.markdown {
color: var(--color-text);
line-height: 1.6;
user-select: text;
word-break: break-word;
}
.markdown h1:first-child,
.markdown h2:first-child,
.markdown h3:first-child,
.markdown h4:first-child,
.markdown h5:first-child,
.markdown h6:first-child {
margin-top: 0;
}
.markdown h1,
.markdown h2,
.markdown h3,
.markdown h4,
.markdown h5,
.markdown h6 {
margin: 1.5em 0 1em 0;
line-height: 1.3;
font-weight: bold;
}
.markdown h1 {
margin-top: 0;
font-size: 2em;
border-bottom: 0.5px solid var(--color-border);
padding-bottom: 0.3em;
}
.markdown h2 {
font-size: 1.5em;
border-bottom: 0.5px solid var(--color-border);
padding-bottom: 0.3em;
}
.markdown h3 {
font-size: 1.2em;
}
.markdown h4 {
font-size: 1em;
}
.markdown h5 {
font-size: 0.9em;
}
.markdown h6 {
font-size: 0.8em;
}
.markdown p {
margin: 1.3em 0;
white-space: pre-wrap;
line-height: 1.6;
}
.markdown p:last-child {
margin-bottom: 5px;
}
.markdown p:first-child {
margin-top: 0;
}
.markdown p:has(+ ul) {
margin-bottom: 0;
}
.markdown ul {
list-style: initial;
}
.markdown ul,
.markdown ol {
padding-left: 1.5em;
margin: 1em 0;
}
.markdown li {
margin-bottom: 0.5em;
}
.markdown li pre {
margin: 1.5em 0 !important;
}
.markdown li::marker {
color: var(--color-text-3);
}
.markdown li > ul,
.markdown li > ol {
margin: 0.5em 0;
}
.markdown hr {
border: none;
border-top: 0.5px solid var(--color-border);
margin: 20px 0;
}
.markdown span {
white-space: pre-wrap;
}
.markdown .katex span {
white-space: pre;
}
.markdown p code,
.markdown li code {
background: var(--color-background-mute);
padding: 3px 5px;
margin: 0 2px;
border-radius: 5px;
word-break: keep-all;
white-space: pre;
}
.markdown code {
font-family: var(--code-font-family);
}
.markdown pre {
border-radius: 8px;
overflow-x: auto;
font-family: var(--code-font-family);
background-color: var(--color-background-mute);
}
.markdown pre:has(.special-preview) {
background-color: transparent;
}
.markdown pre:not(pre pre) > code:not(pre pre > code) {
padding: 15px;
display: block;
}
.markdown pre pre {
margin: 0 !important;
}
.markdown pre pre code {
background: none;
padding: 0;
border-radius: 0;
}
.markdown pre + pre {
margin-top: 10px;
}
.markdown .markdown-alert,
.markdown blockquote {
margin: 1.5em 0;
padding: 1em 1.5em;
background-color: var(--color-background-soft);
border-left: 4px solid var(--color-primary);
border-radius: 0 8px 8px 0;
font-style: italic;
position: relative;
}
.markdown table {
--table-border-radius: 8px;
margin: 2em 0;
font-size: 0.9em;
width: 100%;
border-radius: var(--table-border-radius);
overflow: hidden;
border-collapse: separate;
border: 0.5px solid var(--color-border);
border-spacing: 0;
}
.markdown th,
.markdown td {
border-right: 0.5px solid var(--color-border);
border-bottom: 0.5px solid var(--color-border);
padding: 0.5em;
}
.markdown th:last-child,
.markdown td:last-child {
border-right: none;
}
.markdown tr:last-child td {
border-bottom: none;
}
.markdown th {
background-color: var(--color-background-mute);
font-weight: 600;
text-align: left;
}
.markdown tr:hover {
background-color: var(--color-background-soft);
}
.markdown img {
max-width: 100%;
height: auto;
margin: 1em 0;
}
.markdown a,
.markdown .link {
color: var(--color-link);
text-decoration: none;
cursor: pointer;
}
.markdown a:hover,
.markdown .link:hover {
text-decoration: underline;
}
.markdown strong {
font-weight: bold;
}
.markdown em {
font-style: italic;
}
.markdown del {
text-decoration: line-through;
}
.markdown sup,
.markdown sub {
font-size: 75%;
line-height: 0;
position: relative;
vertical-align: baseline;
}
.markdown sup {
top: -0.5em;
border-radius: 50%;
background-color: var(--color-reference);
color: var(--color-reference-text);
padding: 2px 5px;
zoom: 0.8;
}
.markdown sup > span.link {
color: var(--color-reference-text);
}
.markdown sub {
bottom: -0.25em;
}
.markdown .footnote-ref {
font-size: 0.8em;
vertical-align: super;
line-height: 0;
margin: 0 2px;
color: var(--color-primary);
text-decoration: none;
}
.markdown .footnote-ref:hover {
text-decoration: underline;
}
.footnotes {
margin-top: 1em;
margin-bottom: 1em;
padding-top: 1em;
background-color: var(--color-reference-background);
border-radius: 8px;
padding: 8px 12px;
}
.footnotes h4 {
margin-bottom: 5px;
font-size: 12px;
}
.footnotes a {
color: var(--color-link);
}
.footnotes ol {
padding-left: 1em;
margin: 0;
}
.footnotes ol li:last-child {
margin-bottom: 0;
}
.footnotes li {
font-size: 0.9em;
margin-bottom: 0.5em;
color: var(--color-text-light);
}
.footnotes li p {
display: inline;
margin: 0;
}
.footnotes .footnote-backref {
font-size: 0.8em;
vertical-align: super;
line-height: 0;
margin-left: 5px;
color: var(--color-primary);
text-decoration: none;
}
.footnotes .footnote-backref:hover {
text-decoration: underline;
}
emoji-picker {
--border-size: 0;
}
.block-wrapper + .block-wrapper {
margin-top: 1em;
}
.katex,
mjx-container {
display: inline-block;
overflow-x: auto;
overflow-y: hidden;
overflow-wrap: break-word;
vertical-align: middle;
max-width: 100%;
padding: 1px 2px;
margin-top: -2px;
}
/* Shiki-related styles */
.shiki {
font-family: var(--code-font-family);
/* Keep line-height at its initial value; it is handled inside shiki code blocks */
line-height: initial;
}
/* CodeMirror-related styles */
.cm-editor {
border-radius: inherit;
}
.cm-editor.cm-focused {
outline: none;
}
.cm-editor .cm-scroller {
font-family: var(--code-font-family);
border-radius: inherit;
}
.cm-editor .cm-scroller .cm-gutters {
line-height: 1.6;
border-right: none;
}
.cm-editor .cm-scroller .cm-content {
line-height: 1.6;
padding-left: 0.25em;
}
.cm-editor .cm-scroller .cm-lineWrapping * {
word-wrap: break-word;
white-space: pre-wrap;
}
.cm-editor .cm-announced {
position: absolute;
display: none;
}

View File

@ -1,379 +0,0 @@
.markdown {
color: var(--color-text);
line-height: 1.6;
user-select: text;
word-break: break-word;
h1:first-child,
h2:first-child,
h3:first-child,
h4:first-child,
h5:first-child,
h6:first-child {
margin-top: 0;
}
h1,
h2,
h3,
h4,
h5,
h6 {
margin: 1.5em 0 1em 0;
line-height: 1.3;
font-weight: bold;
}
h1 {
margin-top: 0;
font-size: 2em;
border-bottom: 0.5px solid var(--color-border);
padding-bottom: 0.3em;
}
h2 {
font-size: 1.5em;
border-bottom: 0.5px solid var(--color-border);
padding-bottom: 0.3em;
}
h3 {
font-size: 1.2em;
}
h4 {
font-size: 1em;
}
h5 {
font-size: 0.9em;
}
h6 {
font-size: 0.8em;
}
p {
margin: 1.3em 0;
white-space: pre-wrap;
line-height: 1.6;
&:last-child {
margin-bottom: 5px;
}
&:first-child {
margin-top: 0;
}
&:has(+ ul) {
margin-bottom: 0;
}
}
ul {
list-style: initial;
}
ul,
ol {
padding-left: 1.5em;
margin: 1em 0;
}
li {
margin-bottom: 0.5em;
pre {
margin: 1.5em 0 !important;
}
&::marker {
color: var(--color-text-3);
}
}
li > ul,
li > ol {
margin: 0.5em 0;
}
hr {
border: none;
border-top: 0.5px solid var(--color-border);
margin: 20px 0;
}
span {
white-space: pre-wrap;
}
.katex span {
white-space: pre;
}
p code,
li code {
background: var(--color-background-mute);
padding: 3px 5px;
margin: 0 2px;
border-radius: 5px;
word-break: keep-all;
white-space: pre;
}
code {
font-family: var(--code-font-family);
}
pre {
border-radius: 8px;
overflow-x: auto;
font-family: var(--code-font-family);
background-color: var(--color-background-mute);
&:has(.special-preview) {
background-color: transparent;
}
&:not(pre pre) {
> code:not(pre pre > code) {
padding: 15px;
display: block;
}
}
pre {
margin: 0 !important;
code {
background: none;
padding: 0;
border-radius: 0;
}
}
}
pre + pre {
margin-top: 10px;
}
.markdown-alert,
blockquote {
margin: 1.5em 0;
padding: 1em 1.5em;
background-color: var(--color-background-soft);
border-left: 4px solid var(--color-primary);
border-radius: 0 8px 8px 0;
font-style: italic;
position: relative;
}
table {
--table-border-radius: 8px;
margin: 2em 0;
font-size: 0.9em;
width: 100%;
border-radius: var(--table-border-radius);
overflow: hidden;
border-collapse: separate;
border: 0.5px solid var(--color-border);
border-spacing: 0;
}
th,
td {
border-right: 0.5px solid var(--color-border);
border-bottom: 0.5px solid var(--color-border);
padding: 0.5em;
&:last-child {
border-right: none;
}
}
tr:last-child td {
border-bottom: none;
}
th {
background-color: var(--color-background-mute);
font-weight: 600;
text-align: left;
}
tr:hover {
background-color: var(--color-background-soft);
}
img {
max-width: 100%;
height: auto;
margin: 1em 0;
}
a,
.link {
color: var(--color-link);
text-decoration: none;
cursor: pointer;
&:hover {
text-decoration: underline;
}
}
strong {
font-weight: bold;
}
em {
font-style: italic;
}
del {
text-decoration: line-through;
}
sup,
sub {
font-size: 75%;
line-height: 0;
position: relative;
vertical-align: baseline;
}
sup {
top: -0.5em;
border-radius: 50%;
background-color: var(--color-reference);
color: var(--color-reference-text);
padding: 2px 5px;
zoom: 0.8;
& > span.link {
color: var(--color-reference-text);
}
}
sub {
bottom: -0.25em;
}
.footnote-ref {
font-size: 0.8em;
vertical-align: super;
line-height: 0;
margin: 0 2px;
color: var(--color-primary);
text-decoration: none;
&:hover {
text-decoration: underline;
}
}
}
.footnotes {
margin-top: 1em;
margin-bottom: 1em;
padding-top: 1em;
background-color: var(--color-reference-background);
border-radius: 8px;
padding: 8px 12px;
h4 {
margin-bottom: 5px;
font-size: 12px;
}
a {
color: var(--color-link);
}
ol {
padding-left: 1em;
margin: 0;
li:last-child {
margin-bottom: 0;
}
}
li {
font-size: 0.9em;
margin-bottom: 0.5em;
color: var(--color-text-light);
p {
display: inline;
margin: 0;
}
}
.footnote-backref {
font-size: 0.8em;
vertical-align: super;
line-height: 0;
margin-left: 5px;
color: var(--color-primary);
text-decoration: none;
&:hover {
text-decoration: underline;
}
}
}
emoji-picker {
--border-size: 0;
}
.block-wrapper + .block-wrapper {
margin-top: 1em;
}
.katex,
mjx-container {
display: inline-block;
overflow-x: auto;
overflow-y: hidden;
overflow-wrap: break-word;
vertical-align: middle;
max-width: 100%;
padding: 1px 2px;
margin-top: -2px;
}
/* Shiki-related styles */
.shiki {
font-family: var(--code-font-family);
// Keep line-height at its initial value; it is handled inside shiki code blocks
line-height: initial;
}
/* CodeMirror-related styles */
.cm-editor {
border-radius: inherit;
&.cm-focused {
outline: none;
}
.cm-scroller {
font-family: var(--code-font-family);
border-radius: inherit;
.cm-gutters {
line-height: 1.6;
border-right: none;
}
.cm-content {
line-height: 1.6;
padding-left: 0.25em;
}
.cm-lineWrapping * {
word-wrap: break-word;
white-space: pre-wrap;
}
}
.cm-announced {
position: absolute;
display: none;
}
}

Some files were not shown because too many files have changed in this diff.