mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-24 10:40:07 +08:00
Merge branch 'main' of github.com:CherryHQ/cherry-studio into wip/data-refactor
This commit is contained in:
commit
57fd73e51a
4
.github/workflows/auto-i18n.yml
vendored
4
.github/workflows/auto-i18n.yml
vendored
@ -1,7 +1,7 @@
|
||||
name: Auto I18N
|
||||
|
||||
env:
|
||||
API_KEY: ${{ secrets.TRANSLATE_API_KEY}}
|
||||
API_KEY: ${{ secrets.TRANSLATE_API_KEY }}
|
||||
MODEL: ${{ vars.MODEL || 'deepseek/deepseek-v3.1'}}
|
||||
BASE_URL: ${{ vars.BASE_URL || 'https://api.ppinfra.com/openai'}}
|
||||
|
||||
@ -35,7 +35,7 @@ jobs:
|
||||
# 在临时目录安装依赖
|
||||
mkdir -p /tmp/translation-deps
|
||||
cd /tmp/translation-deps
|
||||
echo '{"dependencies": {"openai": "^5.12.2", "cli-progress": "^3.12.0", "tsx": "^4.20.3", "prettier": "^3.5.3", "prettier-plugin-sort-json": "^4.1.1"}}' > package.json
|
||||
echo '{"dependencies": {"openai": "^5.12.2", "cli-progress": "^3.12.0", "tsx": "^4.20.3", "prettier": "^3.5.3", "prettier-plugin-sort-json": "^4.1.1", "prettier-plugin-tailwindcss": "^0.6.14"}}' > package.json
|
||||
npm install --no-package-lock
|
||||
|
||||
# 设置 NODE_PATH 让项目能找到这些依赖
|
||||
|
||||
16
.github/workflows/claude-code-review.yml
vendored
16
.github/workflows/claude-code-review.yml
vendored
@ -2,7 +2,7 @@ name: Claude Code Review
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
types: [opened]
|
||||
# Optional: Only run on specific file changes
|
||||
# paths:
|
||||
# - "src/**/*.ts"
|
||||
@ -12,12 +12,11 @@ on:
|
||||
|
||||
jobs:
|
||||
claude-review:
|
||||
# Optional: Filter by PR author
|
||||
# if: |
|
||||
# github.event.pull_request.user.login == 'external-contributor' ||
|
||||
# github.event.pull_request.user.login == 'new-developer' ||
|
||||
# github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
|
||||
|
||||
# Only trigger code review for PRs from the main repository due to upstream OIDC issues
|
||||
# https://github.com/anthropics/claude-code-action/issues/542
|
||||
if: |
|
||||
(github.event.pull_request.head.repo.full_name == github.repository) &&
|
||||
(github.event.pull_request.draft == false)
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@ -45,6 +44,9 @@ jobs:
|
||||
- Security concerns
|
||||
- Test coverage
|
||||
|
||||
PR number: ${{ github.event.number }}
|
||||
Repo: ${{ github.repository }}
|
||||
|
||||
Use the repository's CLAUDE.md for guidance on style and conventions. Be constructive and helpful in your feedback.
|
||||
|
||||
Use `gh pr comment` with your Bash tool to leave your review as a comment on the PR.
|
||||
|
||||
23
.github/workflows/claude-translator.yml
vendored
23
.github/workflows/claude-translator.yml
vendored
@ -1,6 +1,6 @@
|
||||
name: English Translator
|
||||
name: Claude Translator
|
||||
concurrency:
|
||||
group: translator-${{ github.event.issue.number }}
|
||||
group: translator-${{ github.event.comment.id || github.event.issue.number }}
|
||||
cancel-in-progress: false
|
||||
|
||||
on:
|
||||
@ -12,13 +12,15 @@ on:
|
||||
jobs:
|
||||
translate:
|
||||
if: |
|
||||
(github.event_name == 'issues' && github.event.issue.author_association == 'COLLABORATOR' && !contains(github.event.issue.body, 'This issue/comment was translated by Claude.')) ||
|
||||
(github.event_name == 'issue_comment' && github.event.comment.author_association == 'COLLABORATOR' && !contains(github.event.issue.body, 'This issue/comment was translated by Claude.'))
|
||||
(github.event_name == 'issues') ||
|
||||
(github.event_name == 'issue_comment' && github.event.sender.type != 'Bot') &&
|
||||
((github.event_name == 'issue_comment' && github.event.action == 'created' && !contains(github.event.comment.body, 'This issue was translated by Claude')) ||
|
||||
(github.event_name == 'issue_comment' && github.event.action == 'edited'))
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write # 编辑issues/comments
|
||||
pull-requests: read
|
||||
pull-requests: write
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
@ -28,11 +30,16 @@ jobs:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Run Claude for translation
|
||||
uses: anthropics/claude-code-action@v1
|
||||
uses: anthropics/claude-code-action@main
|
||||
id: claude
|
||||
with:
|
||||
# Warning: Permissions should have been controlled by workflow permission.
|
||||
# Now `contents: read` is safe for files, but we could make a fine-grained token to control it.
|
||||
# See: https://github.com/anthropics/claude-code-action/blob/main/docs/security.md
|
||||
github_token: ${{ secrets.TOKEN_GITHUB_WRITE }}
|
||||
allowed_non_write_users: '*'
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
claude_args: '--allowed-tools mcp__github_comment__update_claude_comment,Bash(gh issue:*),Bash(gh api:repos/*/issues:*)'
|
||||
claude_args: '--allowed-tools Bash(gh issue:*),Bash(gh api:repos/*/issues:*)'
|
||||
prompt: |
|
||||
你是一个多语言翻译助手。请完成以下任务:
|
||||
|
||||
@ -50,7 +57,7 @@ jobs:
|
||||
|
||||
---
|
||||
<details>
|
||||
<summary>**Original Content:**</summary>
|
||||
<summary>Original Content</summary>
|
||||
[原始内容]
|
||||
</details>
|
||||
|
||||
|
||||
@ -3,9 +3,11 @@
|
||||
"endOfLine": "lf",
|
||||
"jsonRecursiveSort": true,
|
||||
"jsonSortOrder": "{\"*\": \"lexical\"}",
|
||||
"plugins": ["prettier-plugin-sort-json"],
|
||||
"plugins": ["prettier-plugin-sort-json", "prettier-plugin-tailwindcss"],
|
||||
"printWidth": 120,
|
||||
"semi": false,
|
||||
"singleQuote": true,
|
||||
"tailwindFunctions": ["clsx"],
|
||||
"tailwindStylesheet": "./src/renderer/src/assets/styles/tailwind.css",
|
||||
"trailingComma": "none"
|
||||
}
|
||||
|
||||
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@ -28,6 +28,9 @@
|
||||
"source.organizeImports": "never"
|
||||
},
|
||||
"editor.formatOnSave": true,
|
||||
"files.associations": {
|
||||
"*.css": "tailwindcss"
|
||||
},
|
||||
"files.eol": "\n",
|
||||
"i18n-ally.displayLanguage": "zh-cn",
|
||||
"i18n-ally.enabledFrameworks": ["react-i18next", "i18next"],
|
||||
|
||||
@ -92,6 +92,12 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
- **Multi-language Support**: i18n with dynamic loading
|
||||
- **Theme System**: Light/dark themes with custom CSS variables
|
||||
|
||||
### UI Design
|
||||
|
||||
The project is in the process of migrating from antd & styled-components to HeroUI. Please use HeroUI to build UI components. The use of antd and styled-components is prohibited.
|
||||
|
||||
HeroUI Docs: https://www.heroui.com/docs/guide/introduction
|
||||
|
||||
### Database Architecture
|
||||
|
||||
- **Database**: SQLite (`cherrystudio.sqlite`) + libsql driver
|
||||
|
||||
@ -82,7 +82,7 @@ Cherry Studio is a desktop client that supports multiple LLM providers, availabl
|
||||
1. **Diverse LLM Provider Support**:
|
||||
|
||||
- ☁️ Major LLM Cloud Services: OpenAI, Gemini, Anthropic, and more
|
||||
- 🔗 AI Web Service Integration: Claude, Peplexity, Poe, and others
|
||||
- 🔗 AI Web Service Integration: Claude, Perplexity, Poe, and others
|
||||
- 💻 Local Model Support with Ollama, LM Studio
|
||||
|
||||
2. **AI Assistants & Conversations**:
|
||||
|
||||
21
components.json
Normal file
21
components.json
Normal file
@ -0,0 +1,21 @@
|
||||
{
|
||||
"$schema": "https://ui.shadcn.com/schema.json",
|
||||
"aliases": {
|
||||
"components": "@renderer/ui/third-party",
|
||||
"hooks": "@renderer/hooks",
|
||||
"lib": "@renderer/lib",
|
||||
"ui": "@renderer/ui",
|
||||
"utils": "@renderer/utils"
|
||||
},
|
||||
"iconLibrary": "lucide",
|
||||
"rsc": false,
|
||||
"style": "new-york",
|
||||
"tailwind": {
|
||||
"baseColor": "zinc",
|
||||
"config": "",
|
||||
"css": "src/renderer/src/assets/styles/tailwind.css",
|
||||
"cssVariables": true,
|
||||
"prefix": ""
|
||||
},
|
||||
"tsx": true
|
||||
}
|
||||
@ -13,7 +13,7 @@
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=fr">Français</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=de">Deutsch</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=es">Español</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=it">Itapano</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=it">Italiano</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=ru">Русский</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=pt">Português</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=nl">Nederlands</a></p>
|
||||
@ -89,7 +89,7 @@ https://docs.cherry-ai.com
|
||||
1. **多样化 LLM 服务支持**:
|
||||
|
||||
- ☁️ 支持主流 LLM 云服务:OpenAI、Gemini、Anthropic、硅基流动等
|
||||
- 🔗 集成流行 AI Web 服务:Claude、Peplexity、Poe、腾讯元宝、知乎直答等
|
||||
- 🔗 集成流行 AI Web 服务:Claude、Perplexity、Poe、腾讯元宝、知乎直答等
|
||||
- 💻 支持 Ollama、LM Studio 本地模型部署
|
||||
|
||||
2. **智能助手与对话**:
|
||||
|
||||
@ -113,6 +113,10 @@ linux:
|
||||
StartupWMClass: CherryStudio
|
||||
mimeTypes:
|
||||
- x-scheme-handler/cherrystudio
|
||||
rpm:
|
||||
# Workaround for electron build issue on rpm package:
|
||||
# https://github.com/electron/forge/issues/3594
|
||||
fpm: ['--rpm-rpmbuild-define=_build_id_links none']
|
||||
publish:
|
||||
provider: generic
|
||||
url: https://releases.cherry-ai.com
|
||||
|
||||
@ -74,6 +74,7 @@ export default defineConfig({
|
||||
},
|
||||
renderer: {
|
||||
plugins: [
|
||||
(async () => (await import('@tailwindcss/vite')).default())(),
|
||||
react({
|
||||
tsDecorators: true,
|
||||
plugins: [
|
||||
|
||||
@ -123,7 +123,10 @@ export default defineConfig([
|
||||
'.gitignore',
|
||||
'scripts/cloudflare-worker.js',
|
||||
'src/main/integration/nutstore/sso/lib/**',
|
||||
'src/main/integration/cherryin/index.js'
|
||||
'src/main/integration/cherryin/index.js',
|
||||
'src/main/integration/nutstore/sso/lib/**',
|
||||
'src/renderer/src/ui/**',
|
||||
'packages/**/dist'
|
||||
]
|
||||
}
|
||||
])
|
||||
|
||||
34
package.json
34
package.json
@ -69,6 +69,7 @@
|
||||
"format:check": "prettier --check .",
|
||||
"lint": "eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix && yarn typecheck && yarn check:i18n",
|
||||
"prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky",
|
||||
"claude": "dotenv -e .env -- claude",
|
||||
"migrations:generate": "drizzle-kit generate --config ./migrations/sqlite-drizzle.config.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
@ -76,14 +77,18 @@
|
||||
"@libsql/win32-x64-msvc": "^0.4.7",
|
||||
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
|
||||
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
||||
"express": "^5.1.0",
|
||||
"faiss-node": "^0.5.1",
|
||||
"font-list": "^2.0.0",
|
||||
"graceful-fs": "^4.2.11",
|
||||
"jsdom": "26.1.0",
|
||||
"node-stream-zip": "^1.15.0",
|
||||
"officeparser": "^4.2.0",
|
||||
"os-proxy-config": "^1.1.2",
|
||||
"selection-hook": "^1.0.11",
|
||||
"selection-hook": "^1.0.12",
|
||||
"sharp": "^0.34.3",
|
||||
"swagger-jsdoc": "^6.2.8",
|
||||
"swagger-ui-express": "^5.0.1",
|
||||
"tesseract.js": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch",
|
||||
"turndown": "7.2.0"
|
||||
},
|
||||
@ -92,8 +97,9 @@
|
||||
"@agentic/searxng": "^7.3.3",
|
||||
"@agentic/tavily": "^7.3.3",
|
||||
"@ai-sdk/amazon-bedrock": "^3.0.0",
|
||||
"@ai-sdk/google-vertex": "^3.0.0",
|
||||
"@ai-sdk/google-vertex": "^3.0.25",
|
||||
"@ai-sdk/mistral": "^2.0.0",
|
||||
"@ai-sdk/perplexity": "^2.0.8",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
||||
@ -129,13 +135,14 @@
|
||||
"@eslint/js": "^9.22.0",
|
||||
"@google/genai": "patch:@google/genai@npm%3A1.0.1#~/.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch",
|
||||
"@hello-pangea/dnd": "^18.0.1",
|
||||
"@heroui/react": "^2.8.3",
|
||||
"@kangfenmao/keyv-storage": "^0.1.0",
|
||||
"@langchain/community": "^0.3.50",
|
||||
"@langchain/core": "^0.3.68",
|
||||
"@langchain/ollama": "^0.2.1",
|
||||
"@langchain/openai": "^0.6.7",
|
||||
"@mistralai/mistralai": "^1.7.5",
|
||||
"@modelcontextprotocol/sdk": "^1.17.0",
|
||||
"@modelcontextprotocol/sdk": "^1.17.5",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
"@openrouter/ai-sdk-provider": "^1.1.2",
|
||||
@ -149,6 +156,7 @@
|
||||
"@reduxjs/toolkit": "^2.2.5",
|
||||
"@shikijs/markdown-it": "^3.12.0",
|
||||
"@swc/plugin-styled-components": "^8.0.4",
|
||||
"@tailwindcss/vite": "^4.1.13",
|
||||
"@tanstack/react-query": "^5.85.5",
|
||||
"@tanstack/react-virtual": "^3.13.12",
|
||||
"@testing-library/dom": "^10.4.0",
|
||||
@ -174,6 +182,10 @@
|
||||
"@truto/turndown-plugin-gfm": "^1.0.2",
|
||||
"@tryfabric/martian": "^1.2.4",
|
||||
"@types/cli-progress": "^3",
|
||||
"@types/content-type": "^1.1.9",
|
||||
"@types/cors": "^2.8.19",
|
||||
"@types/diff": "^7",
|
||||
"@types/express": "^5",
|
||||
"@types/fs-extra": "^11",
|
||||
"@types/he": "^1",
|
||||
"@types/html-to-text": "^9",
|
||||
@ -187,6 +199,9 @@
|
||||
"@types/react-dom": "^19.0.4",
|
||||
"@types/react-infinite-scroll-component": "^5.0.0",
|
||||
"@types/react-transition-group": "^4.4.12",
|
||||
"@types/react-window": "^1",
|
||||
"@types/swagger-jsdoc": "^6",
|
||||
"@types/swagger-ui-express": "^4.1.8",
|
||||
"@types/tinycolor2": "^1",
|
||||
"@types/turndown": "^5.0.5",
|
||||
"@types/word-extractor": "^1",
|
||||
@ -201,16 +216,18 @@
|
||||
"@viz-js/lang-dot": "^1.0.5",
|
||||
"@viz-js/viz": "^3.14.0",
|
||||
"@xyflow/react": "^12.4.4",
|
||||
"ai": "^5.0.29",
|
||||
"ai": "^5.0.38",
|
||||
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
||||
"archiver": "^7.0.1",
|
||||
"async-mutex": "^0.5.0",
|
||||
"axios": "^1.7.3",
|
||||
"browser-image-compression": "^2.0.2",
|
||||
"chardet": "^2.1.0",
|
||||
"check-disk-space": "3.4.0",
|
||||
"cheerio": "^1.1.2",
|
||||
"chokidar": "^4.0.3",
|
||||
"cli-progress": "^3.12.0",
|
||||
"clsx": "^2.1.1",
|
||||
"code-inspector-plugin": "^0.20.14",
|
||||
"color": "^5.0.0",
|
||||
"concurrently": "^9.2.1",
|
||||
@ -241,6 +258,7 @@
|
||||
"fast-diff": "^1.3.0",
|
||||
"fast-xml-parser": "^5.2.0",
|
||||
"fetch-socks": "1.3.2",
|
||||
"framer-motion": "^12.23.12",
|
||||
"franc-min": "^6.2.0",
|
||||
"fs-extra": "^11.2.0",
|
||||
"google-auth-library": "^9.15.1",
|
||||
@ -275,6 +293,7 @@
|
||||
"playwright": "^1.52.0",
|
||||
"prettier": "^3.5.3",
|
||||
"prettier-plugin-sort-json": "^4.1.1",
|
||||
"prettier-plugin-tailwindcss": "^0.6.14",
|
||||
"proxy-agent": "^6.5.0",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
@ -304,18 +323,19 @@
|
||||
"remark-math": "^6.0.0",
|
||||
"remove-markdown": "^0.6.2",
|
||||
"rollup-plugin-visualizer": "^5.12.0",
|
||||
"sass": "^1.88.0",
|
||||
"shiki": "^3.12.0",
|
||||
"strict-url-sanitise": "^0.0.1",
|
||||
"string-width": "^7.2.0",
|
||||
"striptags": "^3.2.0",
|
||||
"styled-components": "^6.1.11",
|
||||
"swr": "^2.3.6",
|
||||
"tailwindcss": "^4.1.13",
|
||||
"tar": "^7.4.3",
|
||||
"tiny-pinyin": "^1.3.2",
|
||||
"tokenx": "^1.1.0",
|
||||
"tsx": "^4.20.3",
|
||||
"turndown-plugin-gfm": "^1.0.2",
|
||||
"tw-animate-css": "^1.3.8",
|
||||
"typescript": "^5.6.2",
|
||||
"undici": "6.21.2",
|
||||
"unified": "^11.0.5",
|
||||
@ -331,7 +351,7 @@
|
||||
"yjs": "^13.6.27",
|
||||
"youtubei.js": "^15.0.1",
|
||||
"zipread": "^1.3.3",
|
||||
"zod": "^3.25.74"
|
||||
"zod": "^4.1.5"
|
||||
},
|
||||
"resolutions": {
|
||||
"@codemirror/language": "6.11.3",
|
||||
@ -360,7 +380,7 @@
|
||||
"prettier --write",
|
||||
"eslint --fix"
|
||||
],
|
||||
"*.{json,yml,yaml,css,scss,html}": [
|
||||
"*.{json,yml,yaml,css,html}": [
|
||||
"prettier --write"
|
||||
]
|
||||
}
|
||||
|
||||
@ -1,103 +0,0 @@
|
||||
/**
|
||||
* Hub Provider 使用示例
|
||||
*
|
||||
* 演示如何使用简化后的Hub Provider功能来路由到多个底层provider
|
||||
*/
|
||||
|
||||
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
|
||||
|
||||
async function demonstrateHubProvider() {
|
||||
try {
|
||||
// 1. 初始化底层providers
|
||||
console.log('📦 初始化底层providers...')
|
||||
|
||||
initializeProvider('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
|
||||
})
|
||||
|
||||
initializeProvider('anthropic', {
|
||||
apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
|
||||
})
|
||||
|
||||
// 2. 创建Hub Provider(自动包含所有已初始化的providers)
|
||||
console.log('🌐 创建Hub Provider...')
|
||||
|
||||
const aihubmixProvider = createHubProvider({
|
||||
hubId: 'aihubmix',
|
||||
debug: true
|
||||
})
|
||||
|
||||
// 3. 注册Hub Provider
|
||||
providerRegistry.registerProvider('aihubmix', aihubmixProvider)
|
||||
|
||||
console.log('✅ Hub Provider "aihubmix" 注册成功')
|
||||
|
||||
// 4. 使用Hub Provider访问不同的模型
|
||||
console.log('\n🚀 使用Hub模型...')
|
||||
|
||||
// 通过Hub路由到OpenAI
|
||||
const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||
console.log('✓ OpenAI模型已获取:', openaiModel.modelId)
|
||||
|
||||
// 通过Hub路由到Anthropic
|
||||
const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||
console.log('✓ Anthropic模型已获取:', anthropicModel.modelId)
|
||||
|
||||
// 5. 演示错误处理
|
||||
console.log('\n❌ 演示错误处理...')
|
||||
|
||||
try {
|
||||
// 尝试访问未初始化的provider
|
||||
providerRegistry.languageModel('aihubmix:google:gemini-pro')
|
||||
} catch (error) {
|
||||
console.log('预期错误:', error.message)
|
||||
}
|
||||
|
||||
try {
|
||||
// 尝试使用错误的模型ID格式
|
||||
providerRegistry.languageModel('aihubmix:invalid-format')
|
||||
} catch (error) {
|
||||
console.log('预期错误:', error.message)
|
||||
}
|
||||
|
||||
// 6. 多个Hub Provider示例
|
||||
console.log('\n🔄 创建多个Hub Provider...')
|
||||
|
||||
const localHubProvider = createHubProvider({
|
||||
hubId: 'local-ai'
|
||||
})
|
||||
|
||||
providerRegistry.registerProvider('local-ai', localHubProvider)
|
||||
console.log('✅ Hub Provider "local-ai" 注册成功')
|
||||
|
||||
console.log('\n🎉 Hub Provider演示完成!')
|
||||
} catch (error) {
|
||||
console.error('💥 演示过程中发生错误:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 演示简化的使用方式
|
||||
function simplifiedUsageExample() {
|
||||
console.log('\n📝 简化使用示例:')
|
||||
console.log(`
|
||||
// 1. 初始化providers
|
||||
initializeProvider('openai', { apiKey: 'sk-xxx' })
|
||||
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
|
||||
|
||||
// 2. 创建并注册Hub Provider
|
||||
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
|
||||
providerRegistry.registerProvider('aihubmix', hubProvider)
|
||||
|
||||
// 3. 直接使用
|
||||
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||
`)
|
||||
}
|
||||
|
||||
// 运行演示
|
||||
if (require.main === module) {
|
||||
demonstrateHubProvider()
|
||||
simplifiedUsageExample()
|
||||
}
|
||||
|
||||
export { demonstrateHubProvider, simplifiedUsageExample }
|
||||
@ -1,167 +0,0 @@
|
||||
/**
|
||||
* Image Generation Example
|
||||
* 演示如何使用 aiCore 的文生图功能
|
||||
*/
|
||||
|
||||
import { createExecutor, generateImage } from '../src/index'
|
||||
|
||||
async function main() {
|
||||
// 方式1: 使用执行器实例
|
||||
console.log('📸 创建 OpenAI 图像生成执行器...')
|
||||
const executor = createExecutor('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY!
|
||||
})
|
||||
|
||||
try {
|
||||
console.log('🎨 使用执行器生成图像...')
|
||||
const result1 = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A futuristic cityscape at sunset with flying cars',
|
||||
size: '1024x1024',
|
||||
n: 1
|
||||
})
|
||||
|
||||
console.log('✅ 图像生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result1.images.length,
|
||||
mediaType: result1.image.mediaType,
|
||||
hasBase64: !!result1.image.base64,
|
||||
providerMetadata: result1.providerMetadata
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 执行器生成失败:', error)
|
||||
}
|
||||
|
||||
// 方式2: 使用直接调用 API
|
||||
try {
|
||||
console.log('🎨 使用直接 API 生成图像...')
|
||||
const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
|
||||
prompt: 'A magical forest with glowing mushrooms and fairy lights',
|
||||
aspectRatio: '16:9',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: 'hd',
|
||||
style: 'vivid'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
console.log('✅ 直接 API 生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result2.images.length,
|
||||
mediaType: result2.image.mediaType,
|
||||
hasBase64: !!result2.image.base64
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 直接 API 生成失败:', error)
|
||||
}
|
||||
|
||||
// 方式3: 支持其他提供商 (Google Imagen)
|
||||
if (process.env.GOOGLE_API_KEY) {
|
||||
try {
|
||||
console.log('🎨 使用 Google Imagen 生成图像...')
|
||||
const googleExecutor = createExecutor('google', {
|
||||
apiKey: process.env.GOOGLE_API_KEY!
|
||||
})
|
||||
|
||||
const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
||||
prompt: 'A serene mountain lake at dawn with mist rising from the water',
|
||||
aspectRatio: '1:1'
|
||||
})
|
||||
|
||||
console.log('✅ Google Imagen 生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result3.images.length,
|
||||
mediaType: result3.image.mediaType,
|
||||
hasBase64: !!result3.image.base64
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ Google Imagen 生成失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 方式4: 支持插件系统
|
||||
const pluginExample = async () => {
|
||||
console.log('🔌 演示插件系统...')
|
||||
|
||||
// 创建一个示例插件,用于修改提示词
|
||||
const promptEnhancerPlugin = {
|
||||
name: 'prompt-enhancer',
|
||||
transformParams: async (params: any) => {
|
||||
console.log('🔧 插件: 增强提示词...')
|
||||
return {
|
||||
...params,
|
||||
prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
|
||||
}
|
||||
},
|
||||
transformResult: async (result: any) => {
|
||||
console.log('🔧 插件: 处理结果...')
|
||||
return {
|
||||
...result,
|
||||
enhanced: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const executorWithPlugin = createExecutor(
|
||||
'openai',
|
||||
{
|
||||
apiKey: process.env.OPENAI_API_KEY!
|
||||
},
|
||||
[promptEnhancerPlugin]
|
||||
)
|
||||
|
||||
try {
|
||||
const result4 = await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A cute robot playing in a garden'
|
||||
})
|
||||
|
||||
console.log('✅ 插件系统生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result4.images.length,
|
||||
enhanced: (result4 as any).enhanced,
|
||||
mediaType: result4.image.mediaType
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 插件系统生成失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
await pluginExample()
|
||||
}
|
||||
|
||||
// 错误处理演示
|
||||
async function errorHandlingExample() {
|
||||
console.log('⚠️ 演示错误处理...')
|
||||
|
||||
try {
|
||||
const executor = createExecutor('openai', {
|
||||
apiKey: 'invalid-key'
|
||||
})
|
||||
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'Test image'
|
||||
})
|
||||
} catch (error: any) {
|
||||
console.log('✅ 成功捕获错误:', error.constructor.name)
|
||||
console.log('📋 错误信息:', error.message)
|
||||
console.log('🏷️ 提供商ID:', error.providerId)
|
||||
console.log('🏷️ 模型ID:', error.modelId)
|
||||
}
|
||||
}
|
||||
|
||||
// 运行示例
|
||||
if (require.main === module) {
|
||||
main()
|
||||
.then(() => {
|
||||
console.log('🎉 所有示例完成!')
|
||||
return errorHandlingExample()
|
||||
})
|
||||
.then(() => {
|
||||
console.log('🎯 示例程序结束')
|
||||
process.exit(0)
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('💥 程序执行出错:', error)
|
||||
process.exit(1)
|
||||
})
|
||||
}
|
||||
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@cherrystudio/ai-core",
|
||||
"version": "1.0.0-alpha.11",
|
||||
"version": "1.0.0-alpha.14",
|
||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.mjs",
|
||||
@ -39,13 +39,13 @@
|
||||
"@ai-sdk/anthropic": "^2.0.5",
|
||||
"@ai-sdk/azure": "^2.0.16",
|
||||
"@ai-sdk/deepseek": "^1.0.9",
|
||||
"@ai-sdk/google": "^2.0.7",
|
||||
"@ai-sdk/openai": "^2.0.19",
|
||||
"@ai-sdk/google": "^2.0.13",
|
||||
"@ai-sdk/openai": "^2.0.26",
|
||||
"@ai-sdk/openai-compatible": "^1.0.9",
|
||||
"@ai-sdk/provider": "^2.0.0",
|
||||
"@ai-sdk/provider-utils": "^3.0.4",
|
||||
"@ai-sdk/xai": "^2.0.9",
|
||||
"zod": "^3.25.0"
|
||||
"zod": "^4.1.5"
|
||||
},
|
||||
"devDependencies": {
|
||||
"tsdown": "^0.12.9",
|
||||
|
||||
@ -84,7 +84,6 @@ export class ModelResolver {
|
||||
*/
|
||||
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
console.log('fullModelId', fullModelId)
|
||||
return globalRegistryManagement.languageModel(fullModelId as any)
|
||||
}
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
// copy from @ai-sdk/xai/xai-chat-options.ts
|
||||
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
|
||||
|
||||
import * as z from 'zod/v4'
|
||||
import { z } from 'zod'
|
||||
|
||||
const webSourceSchema = z.object({
|
||||
type: z.literal('web'),
|
||||
|
||||
@ -0,0 +1,39 @@
|
||||
import { google } from '@ai-sdk/google'
|
||||
|
||||
import { definePlugin } from '../../'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
|
||||
const toolNameMap = {
|
||||
googleSearch: 'google_search',
|
||||
urlContext: 'url_context',
|
||||
codeExecution: 'code_execution'
|
||||
} as const
|
||||
|
||||
type ToolConfigKey = keyof typeof toolNameMap
|
||||
type ToolConfig = { googleSearch?: boolean; urlContext?: boolean; codeExecution?: boolean }
|
||||
|
||||
export const googleToolsPlugin = (config?: ToolConfig) =>
|
||||
definePlugin({
|
||||
name: 'googleToolsPlugin',
|
||||
transformParams: <T>(params: T, context: AiRequestContext): T => {
|
||||
const { providerId } = context
|
||||
if (providerId === 'google' && config) {
|
||||
if (typeof params === 'object' && params !== null) {
|
||||
const typedParams = params as T & { tools?: Record<string, unknown> }
|
||||
|
||||
if (!typedParams.tools) {
|
||||
typedParams.tools = {}
|
||||
}
|
||||
|
||||
// 使用类型安全的方式遍历配置
|
||||
;(Object.keys(config) as ToolConfigKey[]).forEach((key) => {
|
||||
if (config[key] && key in toolNameMap && key in google.tools) {
|
||||
const toolName = toolNameMap[key]
|
||||
typedParams.tools![toolName] = google.tools[key]({})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
})
|
||||
@ -4,6 +4,7 @@
|
||||
*/
|
||||
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
|
||||
|
||||
export { googleToolsPlugin } from './googleToolsPlugin'
|
||||
export { createLoggingPlugin } from './logging'
|
||||
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
|
||||
export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type'
|
||||
|
||||
@ -27,10 +27,20 @@ export class StreamEventManager {
|
||||
/**
|
||||
* 发送步骤完成事件
|
||||
*/
|
||||
sendStepFinishEvent(controller: StreamController, chunk: any): void {
|
||||
sendStepFinishEvent(
|
||||
controller: StreamController,
|
||||
chunk: any,
|
||||
context: AiRequestContext,
|
||||
finishReason: string = 'stop'
|
||||
): void {
|
||||
// 累加当前步骤的 usage
|
||||
if (chunk.usage && context.accumulatedUsage) {
|
||||
this.accumulateUsage(context.accumulatedUsage, chunk.usage)
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
type: 'finish-step',
|
||||
finishReason: 'stop',
|
||||
finishReason,
|
||||
response: chunk.response,
|
||||
usage: chunk.usage,
|
||||
providerMetadata: chunk.providerMetadata
|
||||
@ -43,28 +53,32 @@ export class StreamEventManager {
|
||||
async handleRecursiveCall(
|
||||
controller: StreamController,
|
||||
recursiveParams: any,
|
||||
context: AiRequestContext,
|
||||
stepId: string
|
||||
context: AiRequestContext
|
||||
): Promise<void> {
|
||||
try {
|
||||
console.log('[MCP Prompt] Starting recursive call after tool execution...')
|
||||
// try {
|
||||
// 重置工具执行状态,准备处理新的步骤
|
||||
context.hasExecutedToolsInCurrentStep = false
|
||||
|
||||
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||
|
||||
if (recursiveResult && recursiveResult.fullStream) {
|
||||
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
|
||||
} else {
|
||||
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
|
||||
}
|
||||
} catch (error) {
|
||||
this.handleRecursiveCallError(controller, error, stepId)
|
||||
if (recursiveResult && recursiveResult.fullStream) {
|
||||
await this.pipeRecursiveStream(controller, recursiveResult.fullStream, context)
|
||||
} else {
|
||||
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
|
||||
}
|
||||
// } catch (error) {
|
||||
// this.handleRecursiveCallError(controller, error, stepId)
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
* 将递归流的数据传递到当前流
|
||||
*/
|
||||
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
|
||||
private async pipeRecursiveStream(
|
||||
controller: StreamController,
|
||||
recursiveStream: ReadableStream,
|
||||
context?: AiRequestContext
|
||||
): Promise<void> {
|
||||
const reader = recursiveStream.getReader()
|
||||
try {
|
||||
while (true) {
|
||||
@ -73,9 +87,16 @@ export class StreamEventManager {
|
||||
break
|
||||
}
|
||||
if (value.type === 'finish') {
|
||||
// 迭代的流不发finish
|
||||
// 迭代的流不发finish,但需要累加其 usage
|
||||
if (value.usage && context?.accumulatedUsage) {
|
||||
this.accumulateUsage(context.accumulatedUsage, value.usage)
|
||||
}
|
||||
break
|
||||
}
|
||||
// 对于 finish-step 类型,累加其 usage
|
||||
if (value.type === 'finish-step' && value.usage && context?.accumulatedUsage) {
|
||||
this.accumulateUsage(context.accumulatedUsage, value.usage)
|
||||
}
|
||||
// 将递归流的数据传递到当前流
|
||||
controller.enqueue(value)
|
||||
}
|
||||
@ -87,25 +108,25 @@ export class StreamEventManager {
|
||||
/**
|
||||
* 处理递归调用错误
|
||||
*/
|
||||
private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
|
||||
console.error('[MCP Prompt] Recursive call failed:', error)
|
||||
// private handleRecursiveCallError(controller: StreamController, error: unknown): void {
|
||||
// console.error('[MCP Prompt] Recursive call failed:', error)
|
||||
|
||||
// 使用 AI SDK 标准错误格式,但不中断流
|
||||
controller.enqueue({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: error instanceof Error ? error.message : String(error),
|
||||
name: error instanceof Error ? error.name : 'RecursiveCallError'
|
||||
}
|
||||
})
|
||||
// // 使用 AI SDK 标准错误格式,但不中断流
|
||||
// controller.enqueue({
|
||||
// type: 'error',
|
||||
// error: {
|
||||
// message: error instanceof Error ? error.message : String(error),
|
||||
// name: error instanceof Error ? error.name : 'RecursiveCallError'
|
||||
// }
|
||||
// })
|
||||
|
||||
// 继续发送文本增量,保持流的连续性
|
||||
controller.enqueue({
|
||||
type: 'text-delta',
|
||||
id: stepId,
|
||||
text: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||
})
|
||||
}
|
||||
// // // 继续发送文本增量,保持流的连续性
|
||||
// // controller.enqueue({
|
||||
// // type: 'text-delta',
|
||||
// // id: stepId,
|
||||
// // text: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||
// // })
|
||||
// }
|
||||
|
||||
/**
|
||||
* 构建递归调用的参数
|
||||
@ -136,4 +157,18 @@ export class StreamEventManager {
|
||||
|
||||
return recursiveParams
|
||||
}
|
||||
|
||||
/**
|
||||
* 累加 usage 数据
|
||||
*/
|
||||
private accumulateUsage(target: any, source: any): void {
|
||||
if (!target || !source) return
|
||||
|
||||
// 累加各种 token 类型
|
||||
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
|
||||
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
|
||||
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
|
||||
target.reasoningTokens = (target.reasoningTokens || 0) + (source.reasoningTokens || 0)
|
||||
target.cachedInputTokens = (target.cachedInputTokens || 0) + (source.cachedInputTokens || 0)
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
* 负责工具的执行、结果格式化和相关事件发送
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { ToolSet } from 'ai'
|
||||
import type { ToolSet, TypedToolError } from 'ai'
|
||||
|
||||
import type { ToolUseResult } from './type'
|
||||
|
||||
@ -38,7 +38,6 @@ export class ToolExecutor {
|
||||
controller: StreamController
|
||||
): Promise<ExecutedResult[]> {
|
||||
const executedResults: ExecutedResult[] = []
|
||||
|
||||
for (const toolUse of toolUses) {
|
||||
try {
|
||||
const tool = tools[toolUse.toolName]
|
||||
@ -46,17 +45,12 @@ export class ToolExecutor {
|
||||
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
||||
}
|
||||
|
||||
// 发送工具调用开始事件
|
||||
this.sendToolStartEvents(controller, toolUse)
|
||||
|
||||
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
|
||||
|
||||
// 发送 tool-call 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-call',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: tool.inputSchema
|
||||
input: toolUse.arguments
|
||||
})
|
||||
|
||||
const result = await tool.execute(toolUse.arguments, {
|
||||
@ -111,45 +105,46 @@ export class ToolExecutor {
|
||||
/**
|
||||
* 发送工具调用开始相关事件
|
||||
*/
|
||||
private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
|
||||
// 发送 tool-input-start 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-input-start',
|
||||
id: toolUse.id,
|
||||
toolName: toolUse.toolName
|
||||
})
|
||||
}
|
||||
// private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
|
||||
// // 发送 tool-input-start 事件
|
||||
// controller.enqueue({
|
||||
// type: 'tool-input-start',
|
||||
// id: toolUse.id,
|
||||
// toolName: toolUse.toolName
|
||||
// })
|
||||
// }
|
||||
|
||||
/**
|
||||
* 处理工具执行错误
|
||||
*/
|
||||
private handleToolError(
|
||||
private handleToolError<T extends ToolSet>(
|
||||
toolUse: ToolUseResult,
|
||||
error: unknown,
|
||||
controller: StreamController
|
||||
// _tools: ToolSet
|
||||
): ExecutedResult {
|
||||
// 使用 AI SDK 标准错误格式
|
||||
// const toolError: TypedToolError<typeof _tools> = {
|
||||
// type: 'tool-error',
|
||||
// toolCallId: toolUse.id,
|
||||
// toolName: toolUse.toolName,
|
||||
// input: toolUse.arguments,
|
||||
// error: error instanceof Error ? error.message : String(error)
|
||||
// }
|
||||
const toolError: TypedToolError<T> = {
|
||||
type: 'tool-error',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: toolUse.arguments,
|
||||
error
|
||||
}
|
||||
|
||||
// controller.enqueue(toolError)
|
||||
controller.enqueue(toolError)
|
||||
|
||||
// 发送标准错误事件
|
||||
controller.enqueue({
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
// controller.enqueue({
|
||||
// type: 'tool-error',
|
||||
// toolCallId: toolUse.id,
|
||||
// error: error instanceof Error ? error.message : String(error),
|
||||
// input: toolUse.arguments
|
||||
// })
|
||||
|
||||
return {
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result: error instanceof Error ? error.message : String(error),
|
||||
result: error,
|
||||
isError: true
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,9 +8,19 @@ import type { TextStreamPart, ToolSet } from 'ai'
|
||||
import { definePlugin } from '../../index'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { StreamEventManager } from './StreamEventManager'
|
||||
import { type TagConfig, TagExtractor } from './tagExtraction'
|
||||
import { ToolExecutor } from './ToolExecutor'
|
||||
import { PromptToolUseConfig, ToolUseResult } from './type'
|
||||
|
||||
/**
|
||||
* 工具使用标签配置
|
||||
*/
|
||||
const TOOL_USE_TAG_CONFIG: TagConfig = {
|
||||
openingTag: '<tool_use>',
|
||||
closingTag: '</tool_use>',
|
||||
separator: '\n'
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认系统提示符模板(提取自 Cherry Studio)
|
||||
*/
|
||||
@ -146,8 +156,10 @@ Assistant: The population of Shanghai is 26 million, while Guangzhou has a popul
|
||||
/**
|
||||
* 构建可用工具部分(提取自 Cherry Studio)
|
||||
*/
|
||||
function buildAvailableTools(tools: ToolSet): string {
|
||||
function buildAvailableTools(tools: ToolSet): string | null {
|
||||
const availableTools = Object.keys(tools)
|
||||
if (availableTools.length === 0) return null
|
||||
const result = availableTools
|
||||
.map((toolName: string) => {
|
||||
const tool = tools[toolName]
|
||||
return `
|
||||
@ -162,7 +174,7 @@ function buildAvailableTools(tools: ToolSet): string {
|
||||
})
|
||||
.join('\n')
|
||||
return `<tools>
|
||||
${availableTools}
|
||||
${result}
|
||||
</tools>`
|
||||
}
|
||||
|
||||
@ -171,6 +183,7 @@ ${availableTools}
|
||||
*/
|
||||
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string {
|
||||
const availableTools = buildAvailableTools(tools)
|
||||
if (availableTools === null) return userSystemPrompt
|
||||
|
||||
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
|
||||
.replace('{{ AVAILABLE_TOOLS }}', availableTools)
|
||||
@ -249,13 +262,11 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
}
|
||||
|
||||
context.mcpTools = params.tools
|
||||
console.log('tools stored in context', params.tools)
|
||||
|
||||
// 构建系统提示符
|
||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
|
||||
let systemMessage: string | null = systemPrompt
|
||||
console.log('config.context', context)
|
||||
if (config.createSystemMessage) {
|
||||
// 🎯 如果用户提供了自定义处理函数,使用它
|
||||
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
||||
@ -268,20 +279,40 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
tools: undefined
|
||||
}
|
||||
context.originalParams = transformedParams
|
||||
console.log('transformedParams', transformedParams)
|
||||
return transformedParams
|
||||
},
|
||||
transformStream: (_: any, context: AiRequestContext) => () => {
|
||||
let textBuffer = ''
|
||||
let stepId = ''
|
||||
// let stepId = ''
|
||||
|
||||
if (!context.mcpTools) {
|
||||
throw new Error('No tools available')
|
||||
}
|
||||
|
||||
// 创建工具执行器和流事件管理器
|
||||
// 从 context 中获取或初始化 usage 累加器
|
||||
if (!context.accumulatedUsage) {
|
||||
context.accumulatedUsage = {
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalTokens: 0,
|
||||
reasoningTokens: 0,
|
||||
cachedInputTokens: 0
|
||||
}
|
||||
}
|
||||
|
||||
// 创建工具执行器、流事件管理器和标签提取器
|
||||
const toolExecutor = new ToolExecutor()
|
||||
const streamEventManager = new StreamEventManager()
|
||||
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
||||
|
||||
// 在context中初始化工具执行状态,避免递归调用时状态丢失
|
||||
if (!context.hasExecutedToolsInCurrentStep) {
|
||||
context.hasExecutedToolsInCurrentStep = false
|
||||
}
|
||||
|
||||
// 用于hold text-start事件,直到确认有非工具标签内容
|
||||
let pendingTextStart: TextStreamPart<TOOLS> | null = null
|
||||
let hasStartedText = false
|
||||
|
||||
type TOOLS = NonNullable<typeof context.mcpTools>
|
||||
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
|
||||
@ -289,83 +320,106 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
chunk: TextStreamPart<TOOLS>,
|
||||
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
|
||||
) {
|
||||
// 收集文本内容
|
||||
if (chunk.type === 'text-delta') {
|
||||
textBuffer += chunk.text || ''
|
||||
stepId = chunk.id || ''
|
||||
controller.enqueue(chunk)
|
||||
// Hold住text-start事件,直到确认有非工具标签内容
|
||||
if ((chunk as any).type === 'text-start') {
|
||||
pendingTextStart = chunk
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
|
||||
const tools = context.mcpTools
|
||||
if (!tools || Object.keys(tools).length === 0) {
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
// text-delta阶段:收集文本内容并过滤工具标签
|
||||
if (chunk.type === 'text-delta') {
|
||||
textBuffer += chunk.text || ''
|
||||
// stepId = chunk.id || ''
|
||||
|
||||
// 解析工具调用
|
||||
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
|
||||
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||
// 使用TagExtractor过滤工具标签,只传递非标签内容到UI层
|
||||
const extractionResults = tagExtractor.processText(chunk.text || '')
|
||||
|
||||
// 如果没有有效的工具调用,直接传递原始事件
|
||||
if (validToolUses.length === 0) {
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'text-end') {
|
||||
controller.enqueue({
|
||||
type: 'text-end',
|
||||
id: stepId,
|
||||
providerMetadata: {
|
||||
text: {
|
||||
value: parsedContent
|
||||
}
|
||||
for (const result of extractionResults) {
|
||||
// 只传递非标签内容到UI层
|
||||
if (!result.isTagContent && result.content) {
|
||||
// 如果还没有发送text-start且有pending的text-start,先发送它
|
||||
if (!hasStartedText && pendingTextStart) {
|
||||
controller.enqueue(pendingTextStart)
|
||||
hasStartedText = true
|
||||
pendingTextStart = null
|
||||
}
|
||||
})
|
||||
return
|
||||
|
||||
const filteredChunk = {
|
||||
...chunk,
|
||||
text: result.content
|
||||
}
|
||||
controller.enqueue(filteredChunk)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'text-end') {
|
||||
// 只有当已经发送了text-start时才发送text-end
|
||||
if (hasStartedText) {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'finish-step') {
|
||||
// 统一在finish-step阶段检查并执行工具调用
|
||||
const tools = context.mcpTools
|
||||
if (tools && Object.keys(tools).length > 0 && !context.hasExecutedToolsInCurrentStep) {
|
||||
// 解析完整的textBuffer来检测工具调用
|
||||
const { results: parsedTools } = parseToolUse(textBuffer, tools)
|
||||
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||
|
||||
if (validToolUses.length > 0) {
|
||||
context.hasExecutedToolsInCurrentStep = true
|
||||
|
||||
// 执行工具调用(不需要手动发送 start-step,外部流已经处理)
|
||||
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
|
||||
|
||||
// 发送步骤完成事件,使用 tool-calls 作为 finishReason
|
||||
streamEventManager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
|
||||
|
||||
// 处理递归调用
|
||||
const toolResultsText = toolExecutor.formatToolResults(executedResults)
|
||||
const recursiveParams = streamEventManager.buildRecursiveParams(
|
||||
context,
|
||||
textBuffer,
|
||||
toolResultsText,
|
||||
tools
|
||||
)
|
||||
|
||||
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
finishReason: 'tool-calls'
|
||||
})
|
||||
|
||||
// 发送步骤开始事件
|
||||
streamEventManager.sendStepStartEvent(controller)
|
||||
|
||||
// 执行工具调用
|
||||
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
|
||||
|
||||
// 发送步骤完成事件
|
||||
streamEventManager.sendStepFinishEvent(controller, chunk)
|
||||
|
||||
// 处理递归调用
|
||||
if (validToolUses.length > 0) {
|
||||
const toolResultsText = toolExecutor.formatToolResults(executedResults)
|
||||
const recursiveParams = streamEventManager.buildRecursiveParams(
|
||||
context,
|
||||
textBuffer,
|
||||
toolResultsText,
|
||||
tools
|
||||
)
|
||||
|
||||
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
|
||||
}
|
||||
// 如果没有执行工具调用,直接传递原始finish-step事件
|
||||
controller.enqueue(chunk)
|
||||
|
||||
// 清理状态
|
||||
textBuffer = ''
|
||||
return
|
||||
}
|
||||
|
||||
// 对于其他类型的事件,直接传递
|
||||
controller.enqueue(chunk)
|
||||
// 处理 finish 类型,使用累加后的 totalUsage
|
||||
if (chunk.type === 'finish') {
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
totalUsage: context.accumulatedUsage
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 对于其他类型的事件,直接传递(不包括text-start,已在上面处理)
|
||||
if ((chunk as any).type !== 'text-start') {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
},
|
||||
|
||||
flush() {
|
||||
// 流结束时的清理工作
|
||||
console.log('[MCP Prompt] Stream ended, cleaning up...')
|
||||
// 清理pending状态
|
||||
pendingTextStart = null
|
||||
hasStartedText = false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -27,7 +27,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
||||
case 'openai': {
|
||||
if (config.openai) {
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
|
||||
params.tools.web_search = openai.tools.webSearch(config.openai)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
// 核心类型和接口
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
||||
import type { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import type { LanguageModel } from 'ai'
|
||||
|
||||
import type { ProviderId } from '../providers'
|
||||
import type { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
@ -9,16 +12,16 @@ export { PluginManager } from './manager'
|
||||
// 工具函数
|
||||
export function createContext<T extends ProviderId>(
|
||||
providerId: T,
|
||||
modelId: string,
|
||||
model: LanguageModel | ImageModelV2,
|
||||
originalParams: any
|
||||
): AiRequestContext {
|
||||
return {
|
||||
providerId,
|
||||
modelId,
|
||||
model,
|
||||
originalParams,
|
||||
metadata: {},
|
||||
startTime: Date.now(),
|
||||
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
requestId: `${providerId}-${typeof model === 'string' ? model : model?.modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
// 占位
|
||||
recursiveCall: () => Promise.resolve(null)
|
||||
}
|
||||
|
||||
@ -14,7 +14,7 @@ export type RecursiveCallFn = (newParams: any) => Promise<any>
|
||||
*/
|
||||
export interface AiRequestContext {
|
||||
providerId: ProviderId
|
||||
modelId: string
|
||||
model: LanguageModel | ImageModelV2
|
||||
originalParams: any
|
||||
metadata: Record<string, any>
|
||||
startTime: number
|
||||
|
||||
@ -10,8 +10,8 @@ import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
||||
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
||||
import { createXai } from '@ai-sdk/xai'
|
||||
import { customProvider, type Provider } from 'ai'
|
||||
import * as z from 'zod'
|
||||
import { customProvider, Provider } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
/**
|
||||
* 基础 Provider IDs
|
||||
@ -38,14 +38,12 @@ export const baseProviderIdSchema = z.enum(baseProviderIds)
|
||||
*/
|
||||
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
|
||||
|
||||
export const baseProviderSchema = z.object({
|
||||
id: baseProviderIdSchema,
|
||||
name: z.string(),
|
||||
creator: z.function().args(z.any()).returns(z.any()) as z.ZodType<(options: any) => Provider>,
|
||||
supportsImageGeneration: z.boolean()
|
||||
})
|
||||
|
||||
export type BaseProvider = z.infer<typeof baseProviderSchema>
|
||||
type BaseProvider = {
|
||||
id: BaseProviderId
|
||||
name: string
|
||||
creator: (options: any) => Provider
|
||||
supportsImageGeneration: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 基础 Providers 定义
|
||||
@ -148,7 +146,12 @@ export const providerConfigSchema = z
|
||||
.object({
|
||||
id: customProviderIdSchema, // 只允许自定义ID
|
||||
name: z.string().min(1),
|
||||
creator: z.function().optional(),
|
||||
creator: z
|
||||
.function({
|
||||
input: z.any(),
|
||||
output: z.any()
|
||||
})
|
||||
.optional(),
|
||||
import: z.function().optional(),
|
||||
creatorFunctionName: z.string().optional(),
|
||||
supportsImageGeneration: z.boolean().default(false),
|
||||
|
||||
@ -4,12 +4,12 @@
|
||||
*/
|
||||
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import {
|
||||
experimental_generateImage as generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
experimental_generateImage as _generateImage,
|
||||
generateObject as _generateObject,
|
||||
generateText as _generateText,
|
||||
LanguageModel,
|
||||
streamObject,
|
||||
streamText
|
||||
streamObject as _streamObject,
|
||||
streamText as _streamText
|
||||
} from 'ai'
|
||||
|
||||
import { globalModelResolver } from '../models'
|
||||
@ -18,7 +18,14 @@ import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||
import { PluginEngine } from './pluginEngine'
|
||||
import { type RuntimeConfig } from './types'
|
||||
import type {
|
||||
generateImageParams,
|
||||
generateObjectParams,
|
||||
generateTextParams,
|
||||
RuntimeConfig,
|
||||
streamObjectParams,
|
||||
streamTextParams
|
||||
} from './types'
|
||||
|
||||
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
public pluginEngine: PluginEngine<T>
|
||||
@ -75,12 +82,12 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
* 流式文本生成
|
||||
*/
|
||||
async streamText(
|
||||
params: Parameters<typeof streamText>[0],
|
||||
params: streamTextParams,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamText>> {
|
||||
const { model, ...restParams } = params
|
||||
): Promise<ReturnType<typeof _streamText>> {
|
||||
const { model } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
@ -94,19 +101,16 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
|
||||
return this.pluginEngine.executeStreamWithPlugins(
|
||||
'streamText',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams, streamTransforms) => {
|
||||
params,
|
||||
(resolvedModel, transformedParams, streamTransforms) => {
|
||||
const experimental_transform =
|
||||
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
|
||||
|
||||
const finalParams = {
|
||||
model: resolvedModel,
|
||||
return _streamText({
|
||||
...transformedParams,
|
||||
model: resolvedModel,
|
||||
experimental_transform
|
||||
} as Parameters<typeof streamText>[0]
|
||||
|
||||
return await streamText(finalParams)
|
||||
})
|
||||
}
|
||||
)
|
||||
}
|
||||
@ -117,12 +121,12 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
* 生成文本
|
||||
*/
|
||||
async generateText(
|
||||
params: Parameters<typeof generateText>[0],
|
||||
params: generateTextParams,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateText>> {
|
||||
const { model, ...restParams } = params
|
||||
): Promise<ReturnType<typeof _generateText>> {
|
||||
const { model } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
@ -134,12 +138,10 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
return this.pluginEngine.executeWithPlugins<Parameters<typeof _generateText>[0], ReturnType<typeof _generateText>>(
|
||||
'generateText',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
generateText({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateText>[0])
|
||||
params,
|
||||
(resolvedModel, transformedParams) => _generateText({ ...transformedParams, model: resolvedModel })
|
||||
)
|
||||
}
|
||||
|
||||
@ -147,12 +149,12 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
* 生成结构化对象
|
||||
*/
|
||||
async generateObject(
|
||||
params: Parameters<typeof generateObject>[0],
|
||||
params: generateObjectParams,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateObject>> {
|
||||
const { model, ...restParams } = params
|
||||
): Promise<ReturnType<typeof _generateObject>> {
|
||||
const { model } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
@ -164,25 +166,23 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
return this.pluginEngine.executeWithPlugins<generateObjectParams, ReturnType<typeof _generateObject>>(
|
||||
'generateObject',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
generateObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateObject>[0])
|
||||
params,
|
||||
async (resolvedModel, transformedParams) => _generateObject({ ...transformedParams, model: resolvedModel })
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式生成结构化对象
|
||||
*/
|
||||
async streamObject(
|
||||
params: Parameters<typeof streamObject>[0],
|
||||
streamObject(
|
||||
params: streamObjectParams,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamObject>> {
|
||||
const { model, ...restParams } = params
|
||||
): Promise<ReturnType<typeof _streamObject>> {
|
||||
const { model } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
@ -194,23 +194,17 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'streamObject',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
streamObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof streamObject>[0])
|
||||
return this.pluginEngine.executeStreamWithPlugins('streamObject', params, (resolvedModel, transformedParams) =>
|
||||
_streamObject({ ...transformedParams, model: resolvedModel })
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成图像
|
||||
*/
|
||||
async generateImage(
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'> & { model: string | ImageModelV2 }
|
||||
): Promise<ReturnType<typeof generateImage>> {
|
||||
generateImage(params: generateImageParams): Promise<ReturnType<typeof _generateImage>> {
|
||||
try {
|
||||
const { model, ...restParams } = params
|
||||
const { model } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
@ -219,13 +213,8 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return await this.pluginEngine.executeImageWithPlugins(
|
||||
'generateImage',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) => {
|
||||
return await generateImage({ model: resolvedModel, ...transformedParams })
|
||||
}
|
||||
return this.pluginEngine.executeImageWithPlugins('generateImage', params, (resolvedModel, transformedParams) =>
|
||||
_generateImage({ ...transformedParams, model: resolvedModel })
|
||||
)
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
/* eslint-disable @eslint-react/naming-convention/context-name */
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { LanguageModel } from 'ai'
|
||||
import { experimental_generateImage, generateObject, generateText, LanguageModel, streamObject, streamText } from 'ai'
|
||||
|
||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
@ -62,17 +62,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 执行带插件的操作(非流式)
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeWithPlugins<TParams, TResult>(
|
||||
async executeWithPlugins<
|
||||
TParams extends Parameters<typeof generateText | typeof generateObject>[0],
|
||||
TResult extends ReturnType<typeof generateText | typeof generateObject>
|
||||
>(
|
||||
methodName: string,
|
||||
model: LanguageModel,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
|
||||
executor: (model: LanguageModel, transformedParams: TParams) => TResult,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
let modelId: string
|
||||
|
||||
const { model } = params
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
@ -83,13 +85,13 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeWithPlugins(methodName, model, newParams, executor, context)
|
||||
const result = await this.executeWithPlugins(methodName, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
@ -138,17 +140,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 执行带插件的图像生成操作
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeImageWithPlugins<TParams, TResult>(
|
||||
async executeImageWithPlugins<
|
||||
TParams extends Omit<Parameters<typeof experimental_generateImage>[0], 'model'> & { model: string | ImageModelV2 },
|
||||
TResult extends ReturnType<typeof experimental_generateImage>
|
||||
>(
|
||||
methodName: string,
|
||||
model: ImageModelV2 | string,
|
||||
params: TParams,
|
||||
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
|
||||
executor: (model: ImageModelV2, transformedParams: TParams) => TResult,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: ImageModelV2 | undefined
|
||||
let modelId: string
|
||||
|
||||
const { model } = params
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
@ -159,13 +163,13 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeImageWithPlugins(methodName, model, newParams, executor, context)
|
||||
const result = await this.executeImageWithPlugins(methodName, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
@ -214,17 +218,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 执行流式调用的通用逻辑(支持流转换器)
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeStreamWithPlugins<TParams, TResult>(
|
||||
async executeStreamWithPlugins<
|
||||
TParams extends Parameters<typeof streamText | typeof streamObject>[0],
|
||||
TResult extends ReturnType<typeof streamText | typeof streamObject>
|
||||
>(
|
||||
methodName: string,
|
||||
model: LanguageModel,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => TResult,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
let modelId: string
|
||||
|
||||
const { model } = params
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
@ -235,13 +241,13 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
|
||||
// 创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeStreamWithPlugins(methodName, model, newParams, executor, context)
|
||||
const result = await this.executeStreamWithPlugins(methodName, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
/**
|
||||
* Runtime 层类型定义
|
||||
*/
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { experimental_generateImage, generateObject, generateText, streamObject, streamText } from 'ai'
|
||||
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
@ -13,3 +16,11 @@ export interface RuntimeConfig<T extends ProviderId = ProviderId> {
|
||||
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
|
||||
export type generateImageParams = Omit<Parameters<typeof experimental_generateImage>[0], 'model'> & {
|
||||
model: string | ImageModelV2
|
||||
}
|
||||
export type generateObjectParams = Parameters<typeof generateObject>[0]
|
||||
export type generateTextParams = Parameters<typeof generateText>[0]
|
||||
export type streamObjectParams = Parameters<typeof streamObject>[0]
|
||||
export type streamTextParams = Parameters<typeof streamText>[0]
|
||||
|
||||
@ -35,8 +35,10 @@ export enum IpcChannel {
|
||||
App_InstallBunBinary = 'app:install-bun-binary',
|
||||
App_LogToMain = 'app:log-to-main',
|
||||
App_SaveData = 'app:save-data',
|
||||
App_GetDiskInfo = 'app:get-disk-info',
|
||||
App_SetFullScreen = 'app:set-full-screen',
|
||||
App_IsFullScreen = 'app:is-full-screen',
|
||||
App_GetSystemFonts = 'app:get-system-fonts',
|
||||
|
||||
App_MacIsProcessTrusted = 'app:mac-is-process-trusted',
|
||||
App_MacRequestProcessTrust = 'app:mac-request-process-trust',
|
||||
@ -331,6 +333,13 @@ export enum IpcChannel {
|
||||
TRACE_CLEAN_LOCAL_DATA = 'trace:cleanLocalData',
|
||||
TRACE_ADD_STREAM_MESSAGE = 'trace:addStreamMessage',
|
||||
|
||||
// API Server
|
||||
ApiServer_Start = 'api-server:start',
|
||||
ApiServer_Stop = 'api-server:stop',
|
||||
ApiServer_Restart = 'api-server:restart',
|
||||
ApiServer_GetStatus = 'api-server:get-status',
|
||||
ApiServer_GetConfig = 'api-server:get-config',
|
||||
|
||||
// Anthropic OAuth
|
||||
Anthropic_StartOAuthFlow = 'anthropic:start-oauth-flow',
|
||||
Anthropic_CompleteOAuthWithCode = 'anthropic:complete-oauth-with-code',
|
||||
|
||||
6
packages/shared/utils.ts
Normal file
6
packages/shared/utils.ts
Normal file
@ -0,0 +1,6 @@
|
||||
export const defaultAppHeaders = () => {
|
||||
return {
|
||||
'HTTP-Referer': 'https://cherry-ai.com',
|
||||
'X-Title': 'Cherry Studio'
|
||||
}
|
||||
}
|
||||
@ -8,18 +8,18 @@
|
||||
</head>
|
||||
|
||||
<body class="bg-gray-50">
|
||||
<div class="max-w-4xl mx-auto px-4 py-8">
|
||||
<div class="mx-auto max-w-4xl px-4 py-8">
|
||||
<!-- 中文版本 -->
|
||||
<div class="mb-12">
|
||||
<h1 class="text-3xl font-bold mb-8 text-gray-900">许可协议</h1>
|
||||
<h1 class="mb-8 text-3xl font-bold text-gray-900">许可协议</h1>
|
||||
|
||||
<p class="mb-6 text-gray-700">
|
||||
本项目采用<strong>区分用户的双重许可 (User-Segmented Dual Licensing)</strong> 模式。
|
||||
</p>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">核心原则</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">核心原则</h2>
|
||||
<ul class="list-disc space-y-2 pl-6 text-gray-700">
|
||||
<li>
|
||||
<strong>个人用户 和 10人及以下企业/组织:</strong> 默认适用
|
||||
<strong>GNU Affero 通用公共许可证 v3.0 (AGPLv3)</strong>。
|
||||
@ -32,7 +32,7 @@
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">定义:"10人及以下"</h2>
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">定义:"10人及以下"</h2>
|
||||
<p class="text-gray-700">
|
||||
指在您的组织(包括公司、非营利组织、政府机构、教育机构等任何实体)中,能够访问、使用或以任何方式直接或间接受益于本软件(Cherry
|
||||
Studio)功能的个人总数不超过10人。这包括但不限于开发者、测试人员、运营人员、最终用户、通过集成系统间接使用者等。
|
||||
@ -40,10 +40,10 @@
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">
|
||||
1. 开源许可证 (Open Source License): AGPLv3 - 适用于个人及10人及以下组织
|
||||
</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<ul class="list-disc space-y-2 pl-6 text-gray-700">
|
||||
<li>
|
||||
如果您是个人用户,或者您的组织满足上述"10人及以下"的定义,您可以在
|
||||
<strong>AGPLv3</strong> 的条款下自由使用、修改和分发 Cherry Studio。AGPLv3 的完整文本可以访问
|
||||
@ -62,10 +62,10 @@
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">
|
||||
2. 商业许可证 (Commercial License) - 适用于超过10人的组织,或希望规避 AGPLv3 义务的用户
|
||||
</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<ul class="list-disc space-y-2 pl-6 text-gray-700">
|
||||
<li>
|
||||
<strong>强制要求:</strong>
|
||||
如果您的组织<strong>不</strong>满足上述"10人及以下"的定义(即有11人或更多人可以访问、使用或受益于本软件),您<strong>必须</strong>联系我们获取并签署一份商业许可证才能使用
|
||||
@ -80,7 +80,7 @@
|
||||
</li>
|
||||
<li>
|
||||
<strong>需要商业许可证的常见情况包括(但不限于):</strong>
|
||||
<ul class="list-disc pl-6 mt-2 space-y-1">
|
||||
<ul class="mt-2 list-disc space-y-1 pl-6">
|
||||
<li>您的组织规模超过10人。</li>
|
||||
<li>
|
||||
(无论组织规模)您希望分发修改过的 Cherry Studio 版本,但<strong>不希望</strong>根据 AGPLv3
|
||||
@ -104,8 +104,8 @@
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. 贡献 (Contributions)</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">3. 贡献 (Contributions)</h2>
|
||||
<ul class="list-disc space-y-2 pl-6 text-gray-700">
|
||||
<li>
|
||||
我们欢迎社区对 Cherry Studio 的贡献。所有向本项目提交的贡献都将被视为在
|
||||
<strong>AGPLv3</strong> 许可证下提供。
|
||||
@ -119,8 +119,8 @@
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. 其他条款 (Other Terms)</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">4. 其他条款 (Other Terms)</h2>
|
||||
<ul class="list-disc space-y-2 pl-6 text-gray-700">
|
||||
<li>关于商业许可证的具体条款和条件,以双方签署的正式商业许可协议为准。</li>
|
||||
<li>
|
||||
项目维护者保留根据需要更新本许可政策(包括用户规模定义和阈值)的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
|
||||
@ -133,13 +133,13 @@
|
||||
|
||||
<!-- English Version -->
|
||||
<div>
|
||||
<h1 class="text-3xl font-bold mb-8 text-gray-900">Licensing</h1>
|
||||
<h1 class="mb-8 text-3xl font-bold text-gray-900">Licensing</h1>
|
||||
|
||||
<p class="mb-6 text-gray-700">This project employs a <strong>User-Segmented Dual Licensing</strong> model.</p>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">Core Principle</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">Core Principle</h2>
|
||||
<ul class="list-disc space-y-2 pl-6 text-gray-700">
|
||||
<li>
|
||||
<strong>Individual Users and Organizations with 10 or Fewer Individuals:</strong> Governed by default
|
||||
under the <strong>GNU Affero General Public License v3.0 (AGPLv3)</strong>.
|
||||
@ -152,7 +152,7 @@
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">Definition: "10 or Fewer Individuals"</h2>
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">Definition: "10 or Fewer Individuals"</h2>
|
||||
<p class="text-gray-700">
|
||||
Refers to any organization (including companies, non-profits, government agencies, educational institutions,
|
||||
etc.) where the total number of individuals who can access, use, or in any way directly or indirectly
|
||||
@ -162,10 +162,10 @@
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">
|
||||
1. Open Source License: AGPLv3 - For Individuals and Organizations of 10 or Fewer
|
||||
</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<ul class="list-disc space-y-2 pl-6 text-gray-700">
|
||||
<li>
|
||||
If you are an individual user, or if your organization meets the "10 or Fewer Individuals" definition
|
||||
above, you are free to use, modify, and distribute Cherry Studio under the terms of the
|
||||
@ -186,11 +186,11 @@
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">
|
||||
2. Commercial License - For Organizations with More Than 10 Individuals, or Users Needing to Avoid AGPLv3
|
||||
Obligations
|
||||
</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<ul class="list-disc space-y-2 pl-6 text-gray-700">
|
||||
<li>
|
||||
<strong>Mandatory Requirement:</strong> If your organization does <strong>not</strong> meet the "10 or
|
||||
Fewer Individuals" definition above (i.e., 11 or more individuals can access, use, or benefit from the
|
||||
@ -207,7 +207,7 @@
|
||||
</li>
|
||||
<li>
|
||||
<strong>Common scenarios requiring a Commercial License include (but are not limited to):</strong>
|
||||
<ul class="list-disc pl-6 mt-2 space-y-1">
|
||||
<ul class="mt-2 list-disc space-y-1 pl-6">
|
||||
<li>
|
||||
Your organization has more than 10 individuals who can access, use, or benefit from the software.
|
||||
</li>
|
||||
@ -236,8 +236,8 @@
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. Contributions</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">3. Contributions</h2>
|
||||
<ul class="list-disc space-y-2 pl-6 text-gray-700">
|
||||
<li>
|
||||
We welcome community contributions to Cherry Studio. All contributions submitted to this project are
|
||||
considered to be offered under the <strong>AGPLv3</strong> license.
|
||||
@ -255,8 +255,8 @@
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. Other Terms</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<h2 class="mb-4 text-xl font-semibold text-gray-900">4. Other Terms</h2>
|
||||
<ul class="list-disc space-y-2 pl-6 text-gray-700">
|
||||
<li>
|
||||
The specific terms and conditions of the Commercial License are governed by the formal commercial license
|
||||
agreement signed by both parties.
|
||||
|
||||
@ -12,18 +12,18 @@
|
||||
|
||||
<body id="app">
|
||||
<div :class="isDark ? 'dark-bg' : 'bg'" class="min-h-screen">
|
||||
<div class="max-w-3xl mx-auto py-12 px-4">
|
||||
<h1 class="text-3xl font-bold mb-8" :class="isDark ? 'text-white' : 'text-gray-900'">Release Timeline</h1>
|
||||
<div class="mx-auto max-w-3xl px-4 py-12">
|
||||
<h1 class="mb-8 text-3xl font-bold" :class="isDark ? 'text-white' : 'text-gray-900'">Release Timeline</h1>
|
||||
|
||||
<!-- Loading状态 -->
|
||||
<div v-if="loading" class="text-center py-8">
|
||||
<div v-if="loading" class="py-8 text-center">
|
||||
<div
|
||||
class="inline-block animate-spin rounded-full h-8 w-8 border-4"
|
||||
class="inline-block h-8 w-8 animate-spin rounded-full border-4"
|
||||
:class="isDark ? 'border-gray-700 border-t-blue-500' : 'border-gray-300 border-t-blue-500'"></div>
|
||||
</div>
|
||||
|
||||
<!-- Error 状态 -->
|
||||
<div v-else-if="error" class="text-red-500 text-center py-8">{{ error }}</div>
|
||||
<div v-else-if="error" class="py-8 text-center text-red-500">{{ error }}</div>
|
||||
|
||||
<!-- Release 列表 -->
|
||||
<div v-else class="space-y-8">
|
||||
@ -32,21 +32,21 @@
|
||||
:key="release.id"
|
||||
class="relative pl-8"
|
||||
:class="isDark ? 'border-l-2 border-gray-700' : 'border-l-2 border-gray-200'">
|
||||
<div class="absolute -left-2 top-0 w-4 h-4 rounded-full bg-green-500"></div>
|
||||
<div class="absolute top-0 -left-2 h-4 w-4 rounded-full bg-green-500"></div>
|
||||
<div
|
||||
class="rounded-lg shadow-sm p-6 transition-shadow"
|
||||
class="rounded-lg p-6 shadow-sm transition-shadow"
|
||||
:class="isDark ? 'bg-black hover:shadow-md hover:shadow-black' : 'bg-white hover:shadow-md'">
|
||||
<div class="flex items-start justify-between mb-4">
|
||||
<div class="mb-4 flex items-start justify-between">
|
||||
<div>
|
||||
<h2 class="text-xl font-semibold" :class="isDark ? 'text-white' : 'text-gray-900'">
|
||||
{{ release.name || release.tag_name }}
|
||||
</h2>
|
||||
<p class="text-sm mt-1" :class="isDark ? 'text-gray-400' : 'text-gray-500'">
|
||||
<p class="mt-1 text-sm" :class="isDark ? 'text-gray-400' : 'text-gray-500'">
|
||||
{{ formatDate(release.published_at) }}
|
||||
</p>
|
||||
</div>
|
||||
<span
|
||||
class="inline-flex items-center px-3 py-1 rounded-full text-sm font-medium"
|
||||
class="inline-flex items-center rounded-full px-3 py-1 text-sm font-medium"
|
||||
:class="isDark ? 'bg-green-900 text-green-200' : 'bg-green-100 text-green-800'">
|
||||
{{ release.tag_name }}
|
||||
</span>
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { exec } from 'child_process'
|
||||
import * as fs from 'fs/promises'
|
||||
import linguistLanguages from 'linguist-languages'
|
||||
import * as linguistLanguages from 'linguist-languages'
|
||||
import * as path from 'path'
|
||||
import { promisify } from 'util'
|
||||
|
||||
|
||||
128
src/main/apiServer/app.ts
Normal file
128
src/main/apiServer/app.ts
Normal file
@ -0,0 +1,128 @@
|
||||
import { loggerService } from '@main/services/LoggerService'
|
||||
import cors from 'cors'
|
||||
import express from 'express'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
import { authMiddleware } from './middleware/auth'
|
||||
import { errorHandler } from './middleware/error'
|
||||
import { setupOpenAPIDocumentation } from './middleware/openapi'
|
||||
import { chatRoutes } from './routes/chat'
|
||||
import { mcpRoutes } from './routes/mcp'
|
||||
import { modelsRoutes } from './routes/models'
|
||||
|
||||
const logger = loggerService.withContext('ApiServer')
|
||||
|
||||
const app = express()
|
||||
|
||||
// Global middleware
|
||||
app.use((req, res, next) => {
|
||||
const start = Date.now()
|
||||
res.on('finish', () => {
|
||||
const duration = Date.now() - start
|
||||
logger.info(`${req.method} ${req.path} - ${res.statusCode} - ${duration}ms`)
|
||||
})
|
||||
next()
|
||||
})
|
||||
|
||||
app.use((_req, res, next) => {
|
||||
res.setHeader('X-Request-ID', uuidv4())
|
||||
next()
|
||||
})
|
||||
|
||||
app.use(
|
||||
cors({
|
||||
origin: '*',
|
||||
allowedHeaders: ['Content-Type', 'Authorization'],
|
||||
methods: ['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS']
|
||||
})
|
||||
)
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /health:
|
||||
* get:
|
||||
* summary: Health check endpoint
|
||||
* description: Check server status (no authentication required)
|
||||
* tags: [Health]
|
||||
* security: []
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Server is healthy
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* status:
|
||||
* type: string
|
||||
* example: ok
|
||||
* timestamp:
|
||||
* type: string
|
||||
* format: date-time
|
||||
* version:
|
||||
* type: string
|
||||
* example: 1.0.0
|
||||
*/
|
||||
app.get('/health', (_req, res) => {
|
||||
res.json({
|
||||
status: 'ok',
|
||||
timestamp: new Date().toISOString(),
|
||||
version: process.env.npm_package_version || '1.0.0'
|
||||
})
|
||||
})
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /:
|
||||
* get:
|
||||
* summary: API information
|
||||
* description: Get basic API information and available endpoints
|
||||
* tags: [General]
|
||||
* security: []
|
||||
* responses:
|
||||
* 200:
|
||||
* description: API information
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* name:
|
||||
* type: string
|
||||
* example: Cherry Studio API
|
||||
* version:
|
||||
* type: string
|
||||
* example: 1.0.0
|
||||
* endpoints:
|
||||
* type: object
|
||||
*/
|
||||
app.get('/', (_req, res) => {
|
||||
res.json({
|
||||
name: 'Cherry Studio API',
|
||||
version: '1.0.0',
|
||||
endpoints: {
|
||||
health: 'GET /health',
|
||||
models: 'GET /v1/models',
|
||||
chat: 'POST /v1/chat/completions',
|
||||
mcp: 'GET /v1/mcps'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// API v1 routes with auth
|
||||
const apiRouter = express.Router()
|
||||
apiRouter.use(authMiddleware)
|
||||
apiRouter.use(express.json())
|
||||
// Mount routes
|
||||
apiRouter.use('/chat', chatRoutes)
|
||||
apiRouter.use('/mcps', mcpRoutes)
|
||||
apiRouter.use('/models', modelsRoutes)
|
||||
app.use('/v1', apiRouter)
|
||||
|
||||
// Setup OpenAPI documentation
|
||||
setupOpenAPIDocumentation(app)
|
||||
|
||||
// Error handling (must be last)
|
||||
app.use(errorHandler)
|
||||
|
||||
export { app }
|
||||
65
src/main/apiServer/config.ts
Normal file
65
src/main/apiServer/config.ts
Normal file
@ -0,0 +1,65 @@
|
||||
import { ApiServerConfig } from '@types'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
import { loggerService } from '../services/LoggerService'
|
||||
import { reduxService } from '../services/ReduxService'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerConfig')
|
||||
|
||||
const defaultHost = 'localhost'
|
||||
const defaultPort = 23333
|
||||
|
||||
class ConfigManager {
|
||||
private _config: ApiServerConfig | null = null
|
||||
|
||||
private generateApiKey(): string {
|
||||
return `cs-sk-${uuidv4()}`
|
||||
}
|
||||
|
||||
async load(): Promise<ApiServerConfig> {
|
||||
try {
|
||||
const settings = await reduxService.select('state.settings')
|
||||
const serverSettings = settings?.apiServer
|
||||
let apiKey = serverSettings?.apiKey
|
||||
if (!apiKey || apiKey.trim() === '') {
|
||||
apiKey = this.generateApiKey()
|
||||
await reduxService.dispatch({
|
||||
type: 'settings/setApiServerApiKey',
|
||||
payload: apiKey
|
||||
})
|
||||
}
|
||||
this._config = {
|
||||
enabled: serverSettings?.enabled ?? false,
|
||||
port: serverSettings?.port ?? defaultPort,
|
||||
host: defaultHost,
|
||||
apiKey: apiKey
|
||||
}
|
||||
return this._config
|
||||
} catch (error: any) {
|
||||
logger.warn('Failed to load config from Redux, using defaults:', error)
|
||||
this._config = {
|
||||
enabled: false,
|
||||
port: defaultPort,
|
||||
host: defaultHost,
|
||||
apiKey: this.generateApiKey()
|
||||
}
|
||||
return this._config
|
||||
}
|
||||
}
|
||||
|
||||
async get(): Promise<ApiServerConfig> {
|
||||
if (!this._config) {
|
||||
await this.load()
|
||||
}
|
||||
if (!this._config) {
|
||||
throw new Error('Failed to load API server configuration')
|
||||
}
|
||||
return this._config
|
||||
}
|
||||
|
||||
async reload(): Promise<ApiServerConfig> {
|
||||
return await this.load()
|
||||
}
|
||||
}
|
||||
|
||||
export const config = new ConfigManager()
|
||||
2
src/main/apiServer/index.ts
Normal file
2
src/main/apiServer/index.ts
Normal file
@ -0,0 +1,2 @@
|
||||
export { config } from './config'
|
||||
export { apiServer } from './server'
|
||||
62
src/main/apiServer/middleware/auth.ts
Normal file
62
src/main/apiServer/middleware/auth.ts
Normal file
@ -0,0 +1,62 @@
|
||||
import crypto from 'crypto'
|
||||
import { NextFunction, Request, Response } from 'express'
|
||||
|
||||
import { config } from '../config'
|
||||
|
||||
export const authMiddleware = async (req: Request, res: Response, next: NextFunction) => {
|
||||
const auth = req.header('Authorization') || ''
|
||||
const xApiKey = req.header('x-api-key') || ''
|
||||
|
||||
// Fast rejection if neither credential header provided
|
||||
if (!auth && !xApiKey) {
|
||||
return res.status(401).json({ error: 'Unauthorized: missing credentials' })
|
||||
}
|
||||
|
||||
let token: string | undefined
|
||||
|
||||
// Prefer Bearer if well‑formed
|
||||
if (auth) {
|
||||
const trimmed = auth.trim()
|
||||
const bearerPrefix = /^Bearer\s+/i
|
||||
if (bearerPrefix.test(trimmed)) {
|
||||
const candidate = trimmed.replace(bearerPrefix, '').trim()
|
||||
if (!candidate) {
|
||||
return res.status(401).json({ error: 'Unauthorized: empty bearer token' })
|
||||
}
|
||||
token = candidate
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to x-api-key if token still not resolved
|
||||
if (!token && xApiKey) {
|
||||
if (!xApiKey.trim()) {
|
||||
return res.status(401).json({ error: 'Unauthorized: empty x-api-key' })
|
||||
}
|
||||
token = xApiKey.trim()
|
||||
}
|
||||
|
||||
if (!token) {
|
||||
// At this point we had at least one header, but none yielded a usable token
|
||||
return res.status(401).json({ error: 'Unauthorized: invalid credentials format' })
|
||||
}
|
||||
|
||||
const { apiKey } = await config.get()
|
||||
|
||||
if (!apiKey) {
|
||||
// If server not configured, treat as forbidden (or could be 500). Choose 403 to avoid leaking config state.
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
|
||||
// Timing-safe compare when lengths match, else immediate forbidden
|
||||
if (token.length !== apiKey.length) {
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
|
||||
const tokenBuf = Buffer.from(token)
|
||||
const keyBuf = Buffer.from(apiKey)
|
||||
if (!crypto.timingSafeEqual(tokenBuf, keyBuf)) {
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
|
||||
return next()
|
||||
}
|
||||
21
src/main/apiServer/middleware/error.ts
Normal file
21
src/main/apiServer/middleware/error.ts
Normal file
@ -0,0 +1,21 @@
|
||||
import { NextFunction, Request, Response } from 'express'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerErrorHandler')
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
export const errorHandler = (err: Error, _req: Request, res: Response, _next: NextFunction) => {
|
||||
logger.error('API Server Error:', err)
|
||||
|
||||
// Don't expose internal errors in production
|
||||
const isDev = process.env.NODE_ENV === 'development'
|
||||
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: isDev ? err.message : 'Internal server error',
|
||||
type: 'server_error',
|
||||
...(isDev && { stack: err.stack })
|
||||
}
|
||||
})
|
||||
}
|
||||
206
src/main/apiServer/middleware/openapi.ts
Normal file
206
src/main/apiServer/middleware/openapi.ts
Normal file
@ -0,0 +1,206 @@
|
||||
import { Express } from 'express'
|
||||
import swaggerJSDoc from 'swagger-jsdoc'
|
||||
import swaggerUi from 'swagger-ui-express'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
|
||||
const logger = loggerService.withContext('OpenAPIMiddleware')
|
||||
|
||||
const swaggerOptions: swaggerJSDoc.Options = {
|
||||
definition: {
|
||||
openapi: '3.0.0',
|
||||
info: {
|
||||
title: 'Cherry Studio API',
|
||||
version: '1.0.0',
|
||||
description: 'OpenAI-compatible API for Cherry Studio with additional Cherry-specific endpoints',
|
||||
contact: {
|
||||
name: 'Cherry Studio',
|
||||
url: 'https://github.com/CherryHQ/cherry-studio'
|
||||
}
|
||||
},
|
||||
servers: [
|
||||
{
|
||||
url: 'http://localhost:23333',
|
||||
description: 'Local development server'
|
||||
}
|
||||
],
|
||||
components: {
|
||||
securitySchemes: {
|
||||
BearerAuth: {
|
||||
type: 'http',
|
||||
scheme: 'bearer',
|
||||
bearerFormat: 'JWT',
|
||||
description: 'Use the API key from Cherry Studio settings'
|
||||
}
|
||||
},
|
||||
schemas: {
|
||||
Error: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
error: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
message: { type: 'string' },
|
||||
type: { type: 'string' },
|
||||
code: { type: 'string' }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
ChatMessage: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
role: {
|
||||
type: 'string',
|
||||
enum: ['system', 'user', 'assistant', 'tool']
|
||||
},
|
||||
content: {
|
||||
oneOf: [
|
||||
{ type: 'string' },
|
||||
{
|
||||
type: 'array',
|
||||
items: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
type: { type: 'string' },
|
||||
text: { type: 'string' },
|
||||
image_url: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
url: { type: 'string' }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
name: { type: 'string' },
|
||||
tool_calls: {
|
||||
type: 'array',
|
||||
items: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
id: { type: 'string' },
|
||||
type: { type: 'string' },
|
||||
function: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
name: { type: 'string' },
|
||||
arguments: { type: 'string' }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
ChatCompletionRequest: {
|
||||
type: 'object',
|
||||
required: ['model', 'messages'],
|
||||
properties: {
|
||||
model: {
|
||||
type: 'string',
|
||||
description: 'The model to use for completion, in format provider:model-id'
|
||||
},
|
||||
messages: {
|
||||
type: 'array',
|
||||
items: { $ref: '#/components/schemas/ChatMessage' }
|
||||
},
|
||||
temperature: {
|
||||
type: 'number',
|
||||
minimum: 0,
|
||||
maximum: 2,
|
||||
default: 1
|
||||
},
|
||||
max_tokens: {
|
||||
type: 'integer',
|
||||
minimum: 1
|
||||
},
|
||||
stream: {
|
||||
type: 'boolean',
|
||||
default: false
|
||||
},
|
||||
tools: {
|
||||
type: 'array',
|
||||
items: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
type: { type: 'string' },
|
||||
function: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
name: { type: 'string' },
|
||||
description: { type: 'string' },
|
||||
parameters: { type: 'object' }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Model: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
id: { type: 'string' },
|
||||
object: { type: 'string', enum: ['model'] },
|
||||
created: { type: 'integer' },
|
||||
owned_by: { type: 'string' }
|
||||
}
|
||||
},
|
||||
MCPServer: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
id: { type: 'string' },
|
||||
name: { type: 'string' },
|
||||
command: { type: 'string' },
|
||||
args: {
|
||||
type: 'array',
|
||||
items: { type: 'string' }
|
||||
},
|
||||
env: { type: 'object' },
|
||||
disabled: { type: 'boolean' }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
security: [
|
||||
{
|
||||
BearerAuth: []
|
||||
}
|
||||
]
|
||||
},
|
||||
apis: ['./src/main/apiServer/routes/*.ts', './src/main/apiServer/app.ts']
|
||||
}
|
||||
|
||||
export function setupOpenAPIDocumentation(app: Express) {
|
||||
try {
|
||||
const specs = swaggerJSDoc(swaggerOptions)
|
||||
|
||||
// Serve OpenAPI JSON
|
||||
app.get('/api-docs.json', (_req, res) => {
|
||||
res.setHeader('Content-Type', 'application/json')
|
||||
res.send(specs)
|
||||
})
|
||||
|
||||
// Serve Swagger UI
|
||||
app.use(
|
||||
'/api-docs',
|
||||
swaggerUi.serve,
|
||||
swaggerUi.setup(specs, {
|
||||
customCss: `
|
||||
.swagger-ui .topbar { display: none; }
|
||||
.swagger-ui .info .title { color: #1890ff; }
|
||||
`,
|
||||
customSiteTitle: 'Cherry Studio API Documentation'
|
||||
})
|
||||
)
|
||||
|
||||
logger.info('OpenAPI documentation setup complete')
|
||||
logger.info('Documentation available at /api-docs')
|
||||
logger.info('OpenAPI spec available at /api-docs.json')
|
||||
} catch (error) {
|
||||
logger.error('Failed to setup OpenAPI documentation:', error as Error)
|
||||
}
|
||||
}
|
||||
225
src/main/apiServer/routes/chat.ts
Normal file
225
src/main/apiServer/routes/chat.ts
Normal file
@ -0,0 +1,225 @@
|
||||
import express, { Request, Response } from 'express'
|
||||
import OpenAI from 'openai'
|
||||
import { ChatCompletionCreateParams } from 'openai/resources'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { chatCompletionService } from '../services/chat-completion'
|
||||
import { validateModelId } from '../utils'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerChatRoutes')
|
||||
|
||||
const router = express.Router()
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/chat/completions:
|
||||
* post:
|
||||
* summary: Create chat completion
|
||||
* description: Create a chat completion response, compatible with OpenAI API
|
||||
* tags: [Chat]
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ChatCompletionRequest'
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Chat completion response
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* id:
|
||||
* type: string
|
||||
* object:
|
||||
* type: string
|
||||
* example: chat.completion
|
||||
* created:
|
||||
* type: integer
|
||||
* model:
|
||||
* type: string
|
||||
* choices:
|
||||
* type: array
|
||||
* items:
|
||||
* type: object
|
||||
* properties:
|
||||
* index:
|
||||
* type: integer
|
||||
* message:
|
||||
* $ref: '#/components/schemas/ChatMessage'
|
||||
* finish_reason:
|
||||
* type: string
|
||||
* usage:
|
||||
* type: object
|
||||
* properties:
|
||||
* prompt_tokens:
|
||||
* type: integer
|
||||
* completion_tokens:
|
||||
* type: integer
|
||||
* total_tokens:
|
||||
* type: integer
|
||||
* text/plain:
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Server-sent events stream (when stream=true)
|
||||
* 400:
|
||||
* description: Bad request
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 401:
|
||||
* description: Unauthorized
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 429:
|
||||
* description: Rate limit exceeded
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 500:
|
||||
* description: Internal server error
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
router.post('/completions', async (req: Request, res: Response) => {
|
||||
try {
|
||||
const request: ChatCompletionCreateParams = req.body
|
||||
|
||||
if (!request) {
|
||||
return res.status(400).json({
|
||||
error: {
|
||||
message: 'Request body is required',
|
||||
type: 'invalid_request_error',
|
||||
code: 'missing_body'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Chat completion request:', {
|
||||
model: request.model,
|
||||
messageCount: request.messages?.length || 0,
|
||||
stream: request.stream,
|
||||
temperature: request.temperature
|
||||
})
|
||||
|
||||
// Validate request
|
||||
const validation = chatCompletionService.validateRequest(request)
|
||||
if (!validation.isValid) {
|
||||
return res.status(400).json({
|
||||
error: {
|
||||
message: validation.errors.join('; '),
|
||||
type: 'invalid_request_error',
|
||||
code: 'validation_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Validate model ID and get provider
|
||||
const modelValidation = await validateModelId(request.model)
|
||||
if (!modelValidation.valid) {
|
||||
const error = modelValidation.error!
|
||||
logger.warn(`Model validation failed for '${request.model}':`, error)
|
||||
return res.status(400).json({
|
||||
error: {
|
||||
message: error.message,
|
||||
type: 'invalid_request_error',
|
||||
code: error.code
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const provider = modelValidation.provider!
|
||||
const modelId = modelValidation.modelId!
|
||||
|
||||
logger.info('Model validation successful:', {
|
||||
provider: provider.id,
|
||||
providerType: provider.type,
|
||||
modelId: modelId,
|
||||
fullModelId: request.model
|
||||
})
|
||||
|
||||
// Create OpenAI client
|
||||
const client = new OpenAI({
|
||||
baseURL: provider.apiHost,
|
||||
apiKey: provider.apiKey
|
||||
})
|
||||
request.model = modelId
|
||||
|
||||
// Handle streaming
|
||||
if (request.stream) {
|
||||
const streamResponse = await client.chat.completions.create(request)
|
||||
|
||||
res.setHeader('Content-Type', 'text/plain; charset=utf-8')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
|
||||
try {
|
||||
for await (const chunk of streamResponse as any) {
|
||||
res.write(`data: ${JSON.stringify(chunk)}\n\n`)
|
||||
}
|
||||
res.write('data: [DONE]\n\n')
|
||||
res.end()
|
||||
} catch (streamError: any) {
|
||||
logger.error('Stream error:', streamError)
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
error: {
|
||||
message: 'Stream processing error',
|
||||
type: 'server_error',
|
||||
code: 'stream_error'
|
||||
}
|
||||
})}\n\n`
|
||||
)
|
||||
res.end()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Handle non-streaming
|
||||
const response = await client.chat.completions.create(request)
|
||||
return res.json(response)
|
||||
} catch (error: any) {
|
||||
logger.error('Chat completion error:', error)
|
||||
|
||||
let statusCode = 500
|
||||
let errorType = 'server_error'
|
||||
let errorCode = 'internal_error'
|
||||
let errorMessage = 'Internal server error'
|
||||
|
||||
if (error instanceof Error) {
|
||||
errorMessage = error.message
|
||||
|
||||
if (error.message.includes('API key') || error.message.includes('authentication')) {
|
||||
statusCode = 401
|
||||
errorType = 'authentication_error'
|
||||
errorCode = 'invalid_api_key'
|
||||
} else if (error.message.includes('rate limit') || error.message.includes('quota')) {
|
||||
statusCode = 429
|
||||
errorType = 'rate_limit_error'
|
||||
errorCode = 'rate_limit_exceeded'
|
||||
} else if (error.message.includes('timeout') || error.message.includes('connection')) {
|
||||
statusCode = 502
|
||||
errorType = 'server_error'
|
||||
errorCode = 'upstream_error'
|
||||
}
|
||||
}
|
||||
|
||||
return res.status(statusCode).json({
|
||||
error: {
|
||||
message: errorMessage,
|
||||
type: errorType,
|
||||
code: errorCode
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
export { router as chatRoutes }
|
||||
153
src/main/apiServer/routes/mcp.ts
Normal file
153
src/main/apiServer/routes/mcp.ts
Normal file
@ -0,0 +1,153 @@
|
||||
import express, { Request, Response } from 'express'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { mcpApiService } from '../services/mcp'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerMCPRoutes')
|
||||
|
||||
const router = express.Router()
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/mcps:
|
||||
* get:
|
||||
* summary: List MCP servers
|
||||
* description: Get a list of all configured Model Context Protocol servers
|
||||
* tags: [MCP]
|
||||
* responses:
|
||||
* 200:
|
||||
* description: List of MCP servers
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* success:
|
||||
* type: boolean
|
||||
* data:
|
||||
* type: array
|
||||
* items:
|
||||
* $ref: '#/components/schemas/MCPServer'
|
||||
* 503:
|
||||
* description: Service unavailable
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* success:
|
||||
* type: boolean
|
||||
* example: false
|
||||
* error:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
router.get('/', async (req: Request, res: Response) => {
|
||||
try {
|
||||
logger.info('Get all MCP servers request received')
|
||||
const servers = await mcpApiService.getAllServers(req)
|
||||
return res.json({
|
||||
success: true,
|
||||
data: servers
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error('Error fetching MCP servers:', error)
|
||||
return res.status(503).json({
|
||||
success: false,
|
||||
error: {
|
||||
message: `Failed to retrieve MCP servers: ${error.message}`,
|
||||
type: 'service_unavailable',
|
||||
code: 'servers_unavailable'
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/mcps/{server_id}:
|
||||
* get:
|
||||
* summary: Get MCP server info
|
||||
* description: Get detailed information about a specific MCP server
|
||||
* tags: [MCP]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: server_id
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: MCP server ID
|
||||
* responses:
|
||||
* 200:
|
||||
* description: MCP server information
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* success:
|
||||
* type: boolean
|
||||
* data:
|
||||
* $ref: '#/components/schemas/MCPServer'
|
||||
* 404:
|
||||
* description: MCP server not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* success:
|
||||
* type: boolean
|
||||
* example: false
|
||||
* error:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
router.get('/:server_id', async (req: Request, res: Response) => {
|
||||
try {
|
||||
logger.info('Get MCP server info request received')
|
||||
const server = await mcpApiService.getServerInfo(req.params.server_id)
|
||||
if (!server) {
|
||||
logger.warn('MCP server not found')
|
||||
return res.status(404).json({
|
||||
success: false,
|
||||
error: {
|
||||
message: 'MCP server not found',
|
||||
type: 'not_found',
|
||||
code: 'server_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
return res.json({
|
||||
success: true,
|
||||
data: server
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error('Error fetching MCP server info:', error)
|
||||
return res.status(503).json({
|
||||
success: false,
|
||||
error: {
|
||||
message: `Failed to retrieve MCP server info: ${error.message}`,
|
||||
type: 'service_unavailable',
|
||||
code: 'server_info_unavailable'
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Connect to MCP server
|
||||
router.all('/:server_id/mcp', async (req: Request, res: Response) => {
|
||||
const server = await mcpApiService.getServerById(req.params.server_id)
|
||||
if (!server) {
|
||||
logger.warn('MCP server not found')
|
||||
return res.status(404).json({
|
||||
success: false,
|
||||
error: {
|
||||
message: 'MCP server not found',
|
||||
type: 'not_found',
|
||||
code: 'server_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
return await mcpApiService.handleRequest(req, res, server)
|
||||
})
|
||||
|
||||
export { router as mcpRoutes }
|
||||
73
src/main/apiServer/routes/models.ts
Normal file
73
src/main/apiServer/routes/models.ts
Normal file
@ -0,0 +1,73 @@
|
||||
import express, { Request, Response } from 'express'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { chatCompletionService } from '../services/chat-completion'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerModelsRoutes')
|
||||
|
||||
const router = express.Router()
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/models:
|
||||
* get:
|
||||
* summary: List available models
|
||||
* description: Returns a list of available AI models from all configured providers
|
||||
* tags: [Models]
|
||||
* responses:
|
||||
* 200:
|
||||
* description: List of available models
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* object:
|
||||
* type: string
|
||||
* example: list
|
||||
* data:
|
||||
* type: array
|
||||
* items:
|
||||
* $ref: '#/components/schemas/Model'
|
||||
* 503:
|
||||
* description: Service unavailable
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
router.get('/', async (_req: Request, res: Response) => {
|
||||
try {
|
||||
logger.info('Models list request received')
|
||||
|
||||
const models = await chatCompletionService.getModels()
|
||||
|
||||
if (models.length === 0) {
|
||||
logger.warn(
|
||||
'No models available from providers. This may be because no OpenAI providers are configured or enabled.'
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`Returning ${models.length} models (OpenAI providers only)`)
|
||||
logger.debug(
|
||||
'Model IDs:',
|
||||
models.map((m) => m.id)
|
||||
)
|
||||
|
||||
return res.json({
|
||||
object: 'list',
|
||||
data: models
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error('Error fetching models:', error)
|
||||
return res.status(503).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models from available providers',
|
||||
type: 'service_unavailable',
|
||||
code: 'models_unavailable'
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
export { router as modelsRoutes }
|
||||
65
src/main/apiServer/server.ts
Normal file
65
src/main/apiServer/server.ts
Normal file
@ -0,0 +1,65 @@
|
||||
import { createServer } from 'node:http'
|
||||
|
||||
import { loggerService } from '../services/LoggerService'
|
||||
import { app } from './app'
|
||||
import { config } from './config'
|
||||
|
||||
const logger = loggerService.withContext('ApiServer')
|
||||
|
||||
export class ApiServer {
|
||||
private server: ReturnType<typeof createServer> | null = null
|
||||
|
||||
async start(): Promise<void> {
|
||||
if (this.server) {
|
||||
logger.warn('Server already running')
|
||||
return
|
||||
}
|
||||
|
||||
// Load config
|
||||
const { port, host, apiKey } = await config.load()
|
||||
|
||||
// Create server with Express app
|
||||
this.server = createServer(app)
|
||||
|
||||
// Start server
|
||||
return new Promise((resolve, reject) => {
|
||||
this.server!.listen(port, host, () => {
|
||||
logger.info(`API Server started at http://${host}:${port}`)
|
||||
logger.info(`API Key: ${apiKey}`)
|
||||
resolve()
|
||||
})
|
||||
|
||||
this.server!.on('error', reject)
|
||||
})
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
if (!this.server) return
|
||||
|
||||
return new Promise((resolve) => {
|
||||
this.server!.close(() => {
|
||||
logger.info('API Server stopped')
|
||||
this.server = null
|
||||
resolve()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
async restart(): Promise<void> {
|
||||
await this.stop()
|
||||
await config.reload()
|
||||
await this.start()
|
||||
}
|
||||
|
||||
isRunning(): boolean {
|
||||
const hasServer = this.server !== null
|
||||
const isListening = this.server?.listening || false
|
||||
const result = hasServer && isListening
|
||||
|
||||
logger.debug('isRunning check:', { hasServer, isListening, result })
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
export const apiServer = new ApiServer()
|
||||
239
src/main/apiServer/services/chat-completion.ts
Normal file
239
src/main/apiServer/services/chat-completion.ts
Normal file
@ -0,0 +1,239 @@
|
||||
import OpenAI from 'openai'
|
||||
import { ChatCompletionCreateParams } from 'openai/resources'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import {
|
||||
getProviderByModel,
|
||||
getRealProviderModel,
|
||||
listAllAvailableModels,
|
||||
OpenAICompatibleModel,
|
||||
transformModelToOpenAI,
|
||||
validateProvider
|
||||
} from '../utils'
|
||||
|
||||
const logger = loggerService.withContext('ChatCompletionService')
|
||||
|
||||
export interface ModelData extends OpenAICompatibleModel {
|
||||
provider_id: string
|
||||
model_id: string
|
||||
name: string
|
||||
}
|
||||
|
||||
export interface ValidationResult {
|
||||
isValid: boolean
|
||||
errors: string[]
|
||||
}
|
||||
|
||||
export class ChatCompletionService {
|
||||
async getModels(): Promise<ModelData[]> {
|
||||
try {
|
||||
logger.info('Getting available models from providers')
|
||||
|
||||
const models = await listAllAvailableModels()
|
||||
|
||||
// Use Map to deduplicate models by their full ID (provider:model_id)
|
||||
const uniqueModels = new Map<string, ModelData>()
|
||||
|
||||
for (const model of models) {
|
||||
const openAIModel = transformModelToOpenAI(model)
|
||||
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
||||
|
||||
// Only add if not already present (first occurrence wins)
|
||||
if (!uniqueModels.has(fullModelId)) {
|
||||
uniqueModels.set(fullModelId, {
|
||||
...openAIModel,
|
||||
provider_id: model.provider,
|
||||
model_id: model.id,
|
||||
name: model.name
|
||||
})
|
||||
} else {
|
||||
logger.debug(`Skipping duplicate model: ${fullModelId}`)
|
||||
}
|
||||
}
|
||||
|
||||
const modelData = Array.from(uniqueModels.values())
|
||||
|
||||
logger.info(`Successfully retrieved ${modelData.length} unique models from ${models.length} total models`)
|
||||
|
||||
if (models.length > modelData.length) {
|
||||
logger.debug(`Filtered out ${models.length - modelData.length} duplicate models`)
|
||||
}
|
||||
|
||||
return modelData
|
||||
} catch (error: any) {
|
||||
logger.error('Error getting models:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
validateRequest(request: ChatCompletionCreateParams): ValidationResult {
|
||||
const errors: string[] = []
|
||||
|
||||
// Validate model
|
||||
if (!request.model) {
|
||||
errors.push('Model is required')
|
||||
} else if (typeof request.model !== 'string') {
|
||||
errors.push('Model must be a string')
|
||||
} else if (!request.model.includes(':')) {
|
||||
errors.push('Model must be in format "provider:model_id"')
|
||||
}
|
||||
|
||||
// Validate messages
|
||||
if (!request.messages) {
|
||||
errors.push('Messages array is required')
|
||||
} else if (!Array.isArray(request.messages)) {
|
||||
errors.push('Messages must be an array')
|
||||
} else if (request.messages.length === 0) {
|
||||
errors.push('Messages array cannot be empty')
|
||||
} else {
|
||||
// Validate each message
|
||||
request.messages.forEach((message, index) => {
|
||||
if (!message.role) {
|
||||
errors.push(`Message ${index}: role is required`)
|
||||
}
|
||||
if (!message.content) {
|
||||
errors.push(`Message ${index}: content is required`)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Validate optional parameters
|
||||
if (request.temperature !== undefined) {
|
||||
if (typeof request.temperature !== 'number' || request.temperature < 0 || request.temperature > 2) {
|
||||
errors.push('Temperature must be a number between 0 and 2')
|
||||
}
|
||||
}
|
||||
|
||||
if (request.max_tokens !== undefined) {
|
||||
if (typeof request.max_tokens !== 'number' || request.max_tokens < 1) {
|
||||
errors.push('max_tokens must be a positive number')
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: errors.length === 0,
|
||||
errors
|
||||
}
|
||||
}
|
||||
|
||||
async processCompletion(request: ChatCompletionCreateParams): Promise<OpenAI.Chat.Completions.ChatCompletion> {
|
||||
try {
|
||||
logger.info('Processing chat completion request:', {
|
||||
model: request.model,
|
||||
messageCount: request.messages.length,
|
||||
stream: request.stream
|
||||
})
|
||||
|
||||
// Validate request
|
||||
const validation = this.validateRequest(request)
|
||||
if (!validation.isValid) {
|
||||
throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
|
||||
}
|
||||
|
||||
// Get provider for the model
|
||||
const provider = await getProviderByModel(request.model!)
|
||||
if (!provider) {
|
||||
throw new Error(`Provider not found for model: ${request.model}`)
|
||||
}
|
||||
|
||||
// Validate provider
|
||||
if (!validateProvider(provider)) {
|
||||
throw new Error(`Provider validation failed for: ${provider.id}`)
|
||||
}
|
||||
|
||||
// Extract model ID from the full model string
|
||||
const modelId = getRealProviderModel(request.model)
|
||||
|
||||
// Create OpenAI client for the provider
|
||||
const client = new OpenAI({
|
||||
baseURL: provider.apiHost,
|
||||
apiKey: provider.apiKey
|
||||
})
|
||||
|
||||
// Prepare request with the actual model ID
|
||||
const providerRequest = {
|
||||
...request,
|
||||
model: modelId,
|
||||
stream: false
|
||||
}
|
||||
|
||||
logger.debug('Sending request to provider:', {
|
||||
provider: provider.id,
|
||||
model: modelId,
|
||||
apiHost: provider.apiHost
|
||||
})
|
||||
|
||||
const response = (await client.chat.completions.create(providerRequest)) as OpenAI.Chat.Completions.ChatCompletion
|
||||
|
||||
logger.info('Successfully processed chat completion')
|
||||
return response
|
||||
} catch (error: any) {
|
||||
logger.error('Error processing chat completion:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
async *processStreamingCompletion(
|
||||
request: ChatCompletionCreateParams
|
||||
): AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk> {
|
||||
try {
|
||||
logger.info('Processing streaming chat completion request:', {
|
||||
model: request.model,
|
||||
messageCount: request.messages.length
|
||||
})
|
||||
|
||||
// Validate request
|
||||
const validation = this.validateRequest(request)
|
||||
if (!validation.isValid) {
|
||||
throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
|
||||
}
|
||||
|
||||
// Get provider for the model
|
||||
const provider = await getProviderByModel(request.model!)
|
||||
if (!provider) {
|
||||
throw new Error(`Provider not found for model: ${request.model}`)
|
||||
}
|
||||
|
||||
// Validate provider
|
||||
if (!validateProvider(provider)) {
|
||||
throw new Error(`Provider validation failed for: ${provider.id}`)
|
||||
}
|
||||
|
||||
// Extract model ID from the full model string
|
||||
const modelId = getRealProviderModel(request.model)
|
||||
|
||||
// Create OpenAI client for the provider
|
||||
const client = new OpenAI({
|
||||
baseURL: provider.apiHost,
|
||||
apiKey: provider.apiKey
|
||||
})
|
||||
|
||||
// Prepare streaming request
|
||||
const streamingRequest = {
|
||||
...request,
|
||||
model: modelId,
|
||||
stream: true as const
|
||||
}
|
||||
|
||||
logger.debug('Sending streaming request to provider:', {
|
||||
provider: provider.id,
|
||||
model: modelId,
|
||||
apiHost: provider.apiHost
|
||||
})
|
||||
|
||||
const stream = await client.chat.completions.create(streamingRequest)
|
||||
|
||||
for await (const chunk of stream) {
|
||||
yield chunk
|
||||
}
|
||||
|
||||
logger.info('Successfully completed streaming chat completion')
|
||||
} catch (error: any) {
|
||||
logger.error('Error processing streaming chat completion:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Export singleton instance
|
||||
export const chatCompletionService = new ChatCompletionService()
|
||||
251
src/main/apiServer/services/mcp.ts
Normal file
251
src/main/apiServer/services/mcp.ts
Normal file
@ -0,0 +1,251 @@
|
||||
import mcpService from '@main/services/MCPService'
|
||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp'
|
||||
import {
|
||||
isJSONRPCRequest,
|
||||
JSONRPCMessage,
|
||||
JSONRPCMessageSchema,
|
||||
MessageExtraInfo
|
||||
} from '@modelcontextprotocol/sdk/types'
|
||||
import { MCPServer } from '@types'
|
||||
import { randomUUID } from 'crypto'
|
||||
import { EventEmitter } from 'events'
|
||||
import { Request, Response } from 'express'
|
||||
import { IncomingMessage, ServerResponse } from 'http'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { reduxService } from '../../services/ReduxService'
|
||||
import { getMcpServerById } from '../utils/mcp'
|
||||
|
||||
const logger = loggerService.withContext('MCPApiService')
|
||||
const transports: Record<string, StreamableHTTPServerTransport> = {}
|
||||
|
||||
interface McpServerDTO {
|
||||
id: MCPServer['id']
|
||||
name: MCPServer['name']
|
||||
type: MCPServer['type']
|
||||
description: MCPServer['description']
|
||||
url: string
|
||||
}
|
||||
|
||||
interface McpServersResp {
|
||||
servers: Record<string, McpServerDTO>
|
||||
}
|
||||
|
||||
/**
|
||||
* MCPApiService - API layer for MCP server management
|
||||
*
|
||||
* This service provides a REST API interface for MCP servers while integrating
|
||||
* with the existing application architecture:
|
||||
*
|
||||
* 1. Uses ReduxService to access the renderer's Redux store directly
|
||||
* 2. Syncs changes back to the renderer via Redux actions
|
||||
* 3. Leverages existing MCPService for actual server connections
|
||||
* 4. Provides session management for API clients
|
||||
*/
|
||||
class MCPApiService extends EventEmitter {
|
||||
private transport: StreamableHTTPServerTransport = new StreamableHTTPServerTransport({
|
||||
sessionIdGenerator: () => randomUUID()
|
||||
})
|
||||
|
||||
constructor() {
|
||||
super()
|
||||
this.initMcpServer()
|
||||
logger.silly('MCPApiService initialized')
|
||||
}
|
||||
|
||||
private initMcpServer() {
|
||||
this.transport.onmessage = this.onMessage
|
||||
}
|
||||
|
||||
/**
|
||||
* Get servers directly from Redux store
|
||||
*/
|
||||
private async getServersFromRedux(): Promise<MCPServer[]> {
|
||||
try {
|
||||
logger.silly('Getting servers from Redux store')
|
||||
|
||||
// Try to get from cache first (faster)
|
||||
const cachedServers = reduxService.selectSync<MCPServer[]>('state.mcp.servers')
|
||||
if (cachedServers && Array.isArray(cachedServers)) {
|
||||
logger.silly(`Found ${cachedServers.length} servers in Redux cache`)
|
||||
return cachedServers
|
||||
}
|
||||
|
||||
// If cache is not available, get fresh data
|
||||
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
|
||||
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
|
||||
return servers || []
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get servers from Redux:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
// get all activated servers
|
||||
async getAllServers(req: Request): Promise<McpServersResp> {
|
||||
try {
|
||||
const servers = await this.getServersFromRedux()
|
||||
logger.silly(`Returning ${servers.length} servers`)
|
||||
const resp: McpServersResp = {
|
||||
servers: {}
|
||||
}
|
||||
for (const server of servers) {
|
||||
if (server.isActive) {
|
||||
resp.servers[server.id] = {
|
||||
id: server.id,
|
||||
name: server.name,
|
||||
type: 'streamableHttp',
|
||||
description: server.description,
|
||||
url: `${req.protocol}://${req.host}/v1/mcps/${server.id}/mcp`
|
||||
}
|
||||
}
|
||||
}
|
||||
return resp
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get all servers:', error)
|
||||
throw new Error('Failed to retrieve servers')
|
||||
}
|
||||
}
|
||||
|
||||
// get server by id
|
||||
async getServerById(id: string): Promise<MCPServer | null> {
|
||||
try {
|
||||
logger.silly(`getServerById called with id: ${id}`)
|
||||
const servers = await this.getServersFromRedux()
|
||||
const server = servers.find((s) => s.id === id)
|
||||
if (!server) {
|
||||
logger.warn(`Server with id ${id} not found`)
|
||||
return null
|
||||
}
|
||||
logger.silly(`Returning server with id ${id}`)
|
||||
return server
|
||||
} catch (error: any) {
|
||||
logger.error(`Failed to get server with id ${id}:`, error)
|
||||
throw new Error('Failed to retrieve server')
|
||||
}
|
||||
}
|
||||
|
||||
async getServerInfo(id: string): Promise<any> {
|
||||
try {
|
||||
logger.silly(`getServerInfo called with id: ${id}`)
|
||||
const server = await this.getServerById(id)
|
||||
if (!server) {
|
||||
logger.warn(`Server with id ${id} not found`)
|
||||
return null
|
||||
}
|
||||
logger.silly(`Returning server info for id ${id}`)
|
||||
|
||||
const client = await mcpService.initClient(server)
|
||||
const tools = await client.listTools()
|
||||
|
||||
logger.info(`Server with id ${id} info:`, { tools: JSON.stringify(tools) })
|
||||
|
||||
// const [version, tools, prompts, resources] = await Promise.all([
|
||||
// () => {
|
||||
// try {
|
||||
// return client.getServerVersion()
|
||||
// } catch (error) {
|
||||
// logger.error(`Failed to get server version for id ${id}:`, { error: error })
|
||||
// return '1.0.0'
|
||||
// }
|
||||
// },
|
||||
// (() => {
|
||||
// try {
|
||||
// return client.listTools()
|
||||
// } catch (error) {
|
||||
// logger.error(`Failed to list tools for id ${id}:`, { error: error })
|
||||
// return []
|
||||
// }
|
||||
// })(),
|
||||
// (() => {
|
||||
// try {
|
||||
// return client.listPrompts()
|
||||
// } catch (error) {
|
||||
// logger.error(`Failed to list prompts for id ${id}:`, { error: error })
|
||||
// return []
|
||||
// }
|
||||
// })(),
|
||||
// (() => {
|
||||
// try {
|
||||
// return client.listResources()
|
||||
// } catch (error) {
|
||||
// logger.error(`Failed to list resources for id ${id}:`, { error: error })
|
||||
// return []
|
||||
// }
|
||||
// })()
|
||||
// ])
|
||||
|
||||
return {
|
||||
id: server.id,
|
||||
name: server.name,
|
||||
type: server.type,
|
||||
description: server.description,
|
||||
tools
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error(`Failed to get server info with id ${id}:`, error)
|
||||
throw new Error('Failed to retrieve server info')
|
||||
}
|
||||
}
|
||||
|
||||
async handleRequest(req: Request, res: Response, server: MCPServer) {
|
||||
const sessionId = req.headers['mcp-session-id'] as string | undefined
|
||||
logger.silly(`Handling request for server with sessionId ${sessionId}`)
|
||||
let transport: StreamableHTTPServerTransport
|
||||
if (sessionId && transports[sessionId]) {
|
||||
transport = transports[sessionId]
|
||||
} else {
|
||||
transport = new StreamableHTTPServerTransport({
|
||||
sessionIdGenerator: () => randomUUID(),
|
||||
onsessioninitialized: (sessionId) => {
|
||||
transports[sessionId] = transport
|
||||
}
|
||||
})
|
||||
|
||||
transport.onclose = () => {
|
||||
logger.info(`Transport for sessionId ${sessionId} closed`)
|
||||
if (transport.sessionId) {
|
||||
delete transports[transport.sessionId]
|
||||
}
|
||||
}
|
||||
const mcpServer = await getMcpServerById(server.id)
|
||||
if (mcpServer) {
|
||||
await mcpServer.connect(transport)
|
||||
}
|
||||
}
|
||||
const jsonpayload = req.body
|
||||
const messages: JSONRPCMessage[] = []
|
||||
|
||||
if (Array.isArray(jsonpayload)) {
|
||||
for (const payload of jsonpayload) {
|
||||
const message = JSONRPCMessageSchema.parse(payload)
|
||||
messages.push(message)
|
||||
}
|
||||
} else {
|
||||
const message = JSONRPCMessageSchema.parse(jsonpayload)
|
||||
messages.push(message)
|
||||
}
|
||||
|
||||
for (const message of messages) {
|
||||
if (isJSONRPCRequest(message)) {
|
||||
if (!message.params) {
|
||||
message.params = {}
|
||||
}
|
||||
if (!message.params._meta) {
|
||||
message.params._meta = {}
|
||||
}
|
||||
message.params._meta.serverId = server.id
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`Request body`, { rawBody: req.body, messages: JSON.stringify(messages) })
|
||||
await transport.handleRequest(req as IncomingMessage, res as ServerResponse, messages)
|
||||
}
|
||||
|
||||
private onMessage(message: JSONRPCMessage, extra?: MessageExtraInfo) {
|
||||
logger.info(`Received message: ${JSON.stringify(message)}`, extra)
|
||||
// Handle message here
|
||||
}
|
||||
}
|
||||
|
||||
export const mcpApiService = new MCPApiService()
|
||||
231
src/main/apiServer/utils/index.ts
Normal file
231
src/main/apiServer/utils/index.ts
Normal file
@ -0,0 +1,231 @@
|
||||
import { loggerService } from '@main/services/LoggerService'
|
||||
import { reduxService } from '@main/services/ReduxService'
|
||||
import { Model, Provider } from '@types'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerUtils')
|
||||
|
||||
// OpenAI compatible model format
|
||||
export interface OpenAICompatibleModel {
|
||||
id: string
|
||||
object: 'model'
|
||||
created: number
|
||||
owned_by: string
|
||||
provider?: string
|
||||
provider_model_id?: string
|
||||
}
|
||||
|
||||
export async function getAvailableProviders(): Promise<Provider[]> {
|
||||
try {
|
||||
// Wait for store to be ready before accessing providers
|
||||
const providers = await reduxService.select('state.llm.providers')
|
||||
if (!providers || !Array.isArray(providers)) {
|
||||
logger.warn('No providers found in Redux store, returning empty array')
|
||||
return []
|
||||
}
|
||||
|
||||
// Only support OpenAI type providers for API server
|
||||
const openAIProviders = providers.filter((p: Provider) => p.enabled && p.type === 'openai')
|
||||
|
||||
logger.info(`Filtered to ${openAIProviders.length} OpenAI providers from ${providers.length} total providers`)
|
||||
|
||||
return openAIProviders
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get providers from Redux store:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
export async function listAllAvailableModels(): Promise<Model[]> {
|
||||
try {
|
||||
const providers = await getAvailableProviders()
|
||||
return providers.map((p: Provider) => p.models || []).flat()
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to list available models:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
export async function getProviderByModel(model: string): Promise<Provider | undefined> {
|
||||
try {
|
||||
if (!model || typeof model !== 'string') {
|
||||
logger.warn(`Invalid model parameter: ${model}`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
// Validate model format first
|
||||
if (!model.includes(':')) {
|
||||
logger.warn(
|
||||
`Invalid model format, must contain ':' separator. Expected format "provider:model_id", got: ${model}`
|
||||
)
|
||||
return undefined
|
||||
}
|
||||
|
||||
const providers = await getAvailableProviders()
|
||||
const modelInfo = model.split(':')
|
||||
|
||||
if (modelInfo.length < 2 || modelInfo[0].length === 0 || modelInfo[1].length === 0) {
|
||||
logger.warn(`Invalid model format, expected "provider:model_id" with non-empty parts, got: ${model}`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
const providerId = modelInfo[0]
|
||||
const provider = providers.find((p: Provider) => p.id === providerId)
|
||||
|
||||
if (!provider) {
|
||||
logger.warn(
|
||||
`Provider '${providerId}' not found or not enabled. Available providers: ${providers.map((p) => p.id).join(', ')}`
|
||||
)
|
||||
return undefined
|
||||
}
|
||||
|
||||
logger.debug(`Found provider '${providerId}' for model: ${model}`)
|
||||
return provider
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get provider by model:', error)
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
export function getRealProviderModel(modelStr: string): string {
|
||||
return modelStr.split(':').slice(1).join(':')
|
||||
}
|
||||
|
||||
export interface ModelValidationError {
|
||||
type: 'invalid_format' | 'provider_not_found' | 'model_not_available' | 'unsupported_provider_type'
|
||||
message: string
|
||||
code: string
|
||||
}
|
||||
|
||||
export async function validateModelId(
|
||||
model: string
|
||||
): Promise<{ valid: boolean; error?: ModelValidationError; provider?: Provider; modelId?: string }> {
|
||||
try {
|
||||
if (!model || typeof model !== 'string') {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
type: 'invalid_format',
|
||||
message: 'Model must be a non-empty string',
|
||||
code: 'invalid_model_parameter'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!model.includes(':')) {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
type: 'invalid_format',
|
||||
message: "Invalid model format. Expected format: 'provider:model_id' (e.g., 'my-openai:gpt-4')",
|
||||
code: 'invalid_model_format'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const modelInfo = model.split(':')
|
||||
if (modelInfo.length < 2 || modelInfo[0].length === 0 || modelInfo[1].length === 0) {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
type: 'invalid_format',
|
||||
message: "Invalid model format. Both provider and model_id must be non-empty. Expected: 'provider:model_id'",
|
||||
code: 'invalid_model_format'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const providerId = modelInfo[0]
|
||||
const modelId = getRealProviderModel(model)
|
||||
const provider = await getProviderByModel(model)
|
||||
|
||||
if (!provider) {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
type: 'provider_not_found',
|
||||
message: `Provider '${providerId}' not found, not enabled, or not supported. Only OpenAI providers are currently supported.`,
|
||||
code: 'provider_not_found'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if model exists in provider
|
||||
const modelExists = provider.models?.some((m) => m.id === modelId)
|
||||
if (!modelExists) {
|
||||
const availableModels = provider.models?.map((m) => m.id).join(', ') || 'none'
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
type: 'model_not_available',
|
||||
message: `Model '${modelId}' not available in provider '${providerId}'. Available models: ${availableModels}`,
|
||||
code: 'model_not_available'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
valid: true,
|
||||
provider,
|
||||
modelId
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error('Error validating model ID:', error)
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
type: 'invalid_format',
|
||||
message: 'Failed to validate model ID',
|
||||
code: 'validation_error'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function transformModelToOpenAI(model: Model): OpenAICompatibleModel {
|
||||
return {
|
||||
id: `${model.provider}:${model.id}`,
|
||||
object: 'model',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
owned_by: model.owned_by || model.provider,
|
||||
provider: model.provider,
|
||||
provider_model_id: model.id
|
||||
}
|
||||
}
|
||||
|
||||
export function validateProvider(provider: Provider): boolean {
|
||||
try {
|
||||
if (!provider) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check required fields
|
||||
if (!provider.id || !provider.type || !provider.apiKey || !provider.apiHost) {
|
||||
logger.warn('Provider missing required fields:', {
|
||||
id: !!provider.id,
|
||||
type: !!provider.type,
|
||||
apiKey: !!provider.apiKey,
|
||||
apiHost: !!provider.apiHost
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if provider is enabled
|
||||
if (!provider.enabled) {
|
||||
logger.debug(`Provider is disabled: ${provider.id}`)
|
||||
return false
|
||||
}
|
||||
|
||||
// Only support OpenAI type providers
|
||||
if (provider.type !== 'openai') {
|
||||
logger.debug(
|
||||
`Provider type '${provider.type}' not supported, only 'openai' type is currently supported: ${provider.id}`
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error: any) {
|
||||
logger.error('Error validating provider:', error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
76
src/main/apiServer/utils/mcp.ts
Normal file
76
src/main/apiServer/utils/mcp.ts
Normal file
@ -0,0 +1,76 @@
|
||||
import mcpService from '@main/services/MCPService'
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import { CallToolRequestSchema, ListToolsRequestSchema, ListToolsResult } from '@modelcontextprotocol/sdk/types.js'
|
||||
import { MCPServer } from '@types'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { reduxService } from '../../services/ReduxService'
|
||||
|
||||
const logger = loggerService.withContext('MCPApiService')
|
||||
|
||||
const cachedServers: Record<string, Server> = {}
|
||||
|
||||
async function handleListToolsRequest(request: any, extra: any): Promise<ListToolsResult> {
|
||||
logger.debug('Handling list tools request', { request: request, extra: extra })
|
||||
const serverId: string = request.params._meta.serverId
|
||||
const serverConfig = await getMcpServerConfigById(serverId)
|
||||
if (!serverConfig) {
|
||||
throw new Error(`Server not found: ${serverId}`)
|
||||
}
|
||||
const client = await mcpService.initClient(serverConfig)
|
||||
return client.listTools()
|
||||
}
|
||||
|
||||
async function handleCallToolRequest(request: any, extra: any): Promise<any> {
|
||||
logger.debug('Handling call tool request', { request: request, extra: extra })
|
||||
const serverId: string = request.params._meta.serverId
|
||||
const serverConfig = await getMcpServerConfigById(serverId)
|
||||
if (!serverConfig) {
|
||||
throw new Error(`Server not found: ${serverId}`)
|
||||
}
|
||||
const client = await mcpService.initClient(serverConfig)
|
||||
return client.callTool(request.params)
|
||||
}
|
||||
|
||||
async function getMcpServerConfigById(id: string): Promise<MCPServer | undefined> {
|
||||
const servers = await getServersFromRedux()
|
||||
return servers.find((s) => s.id === id || s.name === id)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get servers directly from Redux store
|
||||
*/
|
||||
async function getServersFromRedux(): Promise<MCPServer[]> {
|
||||
try {
|
||||
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
|
||||
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
|
||||
return servers || []
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get servers from Redux:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
export async function getMcpServerById(id: string): Promise<Server> {
|
||||
const server = cachedServers[id]
|
||||
if (!server) {
|
||||
const servers = await getServersFromRedux()
|
||||
const mcpServer = servers.find((s) => s.id === id || s.name === id)
|
||||
if (!mcpServer) {
|
||||
throw new Error(`Server not found: ${id}`)
|
||||
}
|
||||
|
||||
const createMcpServer = (name: string, version: string): Server => {
|
||||
const server = new Server({ name: name, version }, { capabilities: { tools: {} } })
|
||||
server.setRequestHandler(ListToolsRequestSchema, handleListToolsRequest)
|
||||
server.setRequestHandler(CallToolRequestSchema, handleCallToolRequest)
|
||||
return server
|
||||
}
|
||||
|
||||
const newServer = createMcpServer(mcpServer.name, '0.1.0')
|
||||
cachedServers[id] = newServer
|
||||
return newServer
|
||||
}
|
||||
logger.silly('getMcpServer ', { server: server })
|
||||
return server
|
||||
}
|
||||
@ -31,6 +31,7 @@ import { TrayService } from './services/TrayService'
|
||||
import { windowService } from './services/WindowService'
|
||||
import { dataRefactorMigrateService } from './data/migrate/dataRefactor/DataRefactorMigrateService'
|
||||
import process from 'node:process'
|
||||
import { apiServerService } from './services/ApiServerService'
|
||||
|
||||
const logger = loggerService.withContext('MainEntry')
|
||||
|
||||
@ -219,6 +220,17 @@ if (!app.requestSingleInstanceLock()) {
|
||||
|
||||
//start selection assistant service
|
||||
initSelectionService()
|
||||
|
||||
// Start API server if enabled
|
||||
try {
|
||||
const config = await apiServerService.getCurrentConfig()
|
||||
logger.info('API server config:', config)
|
||||
if (config.enabled) {
|
||||
await apiServerService.start()
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to check/start API server:', error)
|
||||
}
|
||||
})
|
||||
|
||||
registerProtocolClient(app)
|
||||
@ -264,6 +276,7 @@ if (!app.requestSingleInstanceLock()) {
|
||||
// 简单的资源清理,不阻塞退出流程
|
||||
try {
|
||||
await mcpService.cleanup()
|
||||
await apiServerService.stop()
|
||||
} catch (error) {
|
||||
logger.warn('Error cleaning up MCP service:', error as Error)
|
||||
}
|
||||
|
||||
@ -15,9 +15,12 @@ import { MIN_WINDOW_HEIGHT, MIN_WINDOW_WIDTH } from '@shared/config/constant'
|
||||
import { UpgradeChannel } from '@shared/data/preferenceTypes'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { FileMetadata, Provider, Shortcut } from '@types'
|
||||
import checkDiskSpace from 'check-disk-space'
|
||||
import { BrowserWindow, dialog, ipcMain, ProxyConfig, session, shell, systemPreferences, webContents } from 'electron'
|
||||
import fontList from 'font-list'
|
||||
import { Notification } from 'src/renderer/src/types/notification'
|
||||
|
||||
import { apiServerService } from './services/ApiServerService'
|
||||
import appService from './services/AppService'
|
||||
import AppUpdater from './services/AppUpdater'
|
||||
import BackupManager from './services/BackupManager'
|
||||
@ -217,6 +220,17 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
return mainWindow.isFullScreen()
|
||||
})
|
||||
|
||||
// Get System Fonts
|
||||
ipcMain.handle(IpcChannel.App_GetSystemFonts, async () => {
|
||||
try {
|
||||
const fonts = await fontList.getFonts()
|
||||
return fonts.map((font: string) => font.replace(/^"(.*)"$/, '$1')).filter((font: string) => font.length > 0)
|
||||
} catch (error) {
|
||||
logger.error('Failed to get system fonts:', error as Error)
|
||||
return []
|
||||
}
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Config_Set, (_, key: string, value: any, isNotify: boolean = false) => {
|
||||
configManager.set(key, value, isNotify)
|
||||
})
|
||||
@ -782,6 +796,23 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
addStreamMessage(spanId, modelName, context, msg)
|
||||
)
|
||||
|
||||
ipcMain.handle(IpcChannel.App_GetDiskInfo, async (_, directoryPath: string) => {
|
||||
try {
|
||||
const diskSpace = await checkDiskSpace(directoryPath) // { free, size } in bytes
|
||||
logger.debug('disk space', diskSpace)
|
||||
const { free, size } = diskSpace
|
||||
return {
|
||||
free,
|
||||
size
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('check disk space error', error as Error)
|
||||
return null
|
||||
}
|
||||
})
|
||||
// API Server
|
||||
apiServerService.registerIpcHandlers()
|
||||
|
||||
// Anthropic OAuth
|
||||
ipcMain.handle(IpcChannel.Anthropic_StartOAuthFlow, () => anthropicService.startOAuthFlow())
|
||||
ipcMain.handle(IpcChannel.Anthropic_CompleteOAuthWithCode, (_, code: string) =>
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import { Embeddings, type EmbeddingsParams } from '@langchain/core/embeddings'
|
||||
import { chunkArray } from '@langchain/core/utils/chunk_array'
|
||||
import { getEnvironmentVariable } from '@langchain/core/utils/env'
|
||||
import z from 'zod/v4'
|
||||
import { z } from 'zod'
|
||||
|
||||
const jinaModelSchema = z.union([
|
||||
z.literal('jina-clip-v2'),
|
||||
|
||||
@ -3,7 +3,7 @@ import { loggerService } from '@logger'
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
|
||||
import { net } from 'electron'
|
||||
import * as z from 'zod/v4'
|
||||
import { z } from 'zod'
|
||||
|
||||
const logger = loggerService.withContext('DifyKnowledgeServer')
|
||||
|
||||
|
||||
@ -8,8 +8,8 @@ import TurndownService from 'turndown'
|
||||
import { z } from 'zod'
|
||||
|
||||
export const RequestPayloadSchema = z.object({
|
||||
url: z.string().url(),
|
||||
headers: z.record(z.string()).optional()
|
||||
url: z.url(),
|
||||
headers: z.record(z.string(), z.string()).optional()
|
||||
})
|
||||
|
||||
export type RequestPayload = z.infer<typeof RequestPayloadSchema>
|
||||
|
||||
@ -8,7 +8,7 @@ import fs from 'fs/promises'
|
||||
import { minimatch } from 'minimatch'
|
||||
import os from 'os'
|
||||
import path from 'path'
|
||||
import * as z from 'zod/v4'
|
||||
import { z } from 'zod'
|
||||
|
||||
const logger = loggerService.withContext('MCP:FileSystemServer')
|
||||
|
||||
|
||||
108
src/main/services/ApiServerService.ts
Normal file
108
src/main/services/ApiServerService.ts
Normal file
@ -0,0 +1,108 @@
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { ApiServerConfig } from '@types'
|
||||
import { ipcMain } from 'electron'
|
||||
|
||||
import { apiServer } from '../apiServer'
|
||||
import { config } from '../apiServer/config'
|
||||
import { loggerService } from './LoggerService'
|
||||
const logger = loggerService.withContext('ApiServerService')
|
||||
|
||||
export class ApiServerService {
|
||||
constructor() {
|
||||
// Use the new clean implementation
|
||||
}
|
||||
|
||||
async start(): Promise<void> {
|
||||
try {
|
||||
await apiServer.start()
|
||||
logger.info('API Server started successfully')
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to start API Server:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
try {
|
||||
await apiServer.stop()
|
||||
logger.info('API Server stopped successfully')
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to stop API Server:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
async restart(): Promise<void> {
|
||||
try {
|
||||
await apiServer.restart()
|
||||
logger.info('API Server restarted successfully')
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to restart API Server:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
isRunning(): boolean {
|
||||
return apiServer.isRunning()
|
||||
}
|
||||
|
||||
async getCurrentConfig(): Promise<ApiServerConfig> {
|
||||
return config.get()
|
||||
}
|
||||
|
||||
registerIpcHandlers(): void {
|
||||
// API Server
|
||||
ipcMain.handle(IpcChannel.ApiServer_Start, async () => {
|
||||
try {
|
||||
await this.start()
|
||||
return { success: true }
|
||||
} catch (error: any) {
|
||||
return { success: false, error: error instanceof Error ? error.message : 'Unknown error' }
|
||||
}
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.ApiServer_Stop, async () => {
|
||||
try {
|
||||
await this.stop()
|
||||
return { success: true }
|
||||
} catch (error: any) {
|
||||
return { success: false, error: error instanceof Error ? error.message : 'Unknown error' }
|
||||
}
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.ApiServer_Restart, async () => {
|
||||
try {
|
||||
await this.restart()
|
||||
return { success: true }
|
||||
} catch (error: any) {
|
||||
return { success: false, error: error instanceof Error ? error.message : 'Unknown error' }
|
||||
}
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.ApiServer_GetStatus, async () => {
|
||||
try {
|
||||
const config = await this.getCurrentConfig()
|
||||
return {
|
||||
running: this.isRunning(),
|
||||
config
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
running: this.isRunning(),
|
||||
config: null
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.ApiServer_GetConfig, async () => {
|
||||
try {
|
||||
return this.getCurrentConfig()
|
||||
} catch (error: any) {
|
||||
return null
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Export singleton instance
|
||||
export const apiServerService = new ApiServerService()
|
||||
@ -332,14 +332,15 @@ class CodeToolsService {
|
||||
// macOS - Use osascript to launch terminal and execute command directly, without showing startup command
|
||||
const envPrefix = buildEnvPrefix(false)
|
||||
const command = envPrefix ? `${envPrefix} && ${baseCommand}` : baseCommand
|
||||
// Combine directory change with the main command to ensure they execute in the same shell session
|
||||
const fullCommand = `cd '${directory.replace(/'/g, "\\'")}' && clear && ${command}`
|
||||
|
||||
terminalCommand = 'osascript'
|
||||
terminalArgs = [
|
||||
'-e',
|
||||
`tell application "Terminal"
|
||||
set newTab to do script "cd '${directory.replace(/'/g, "\\'")}' && clear"
|
||||
do script "${fullCommand.replace(/"/g, '\\"')}"
|
||||
activate
|
||||
do script "${command.replace(/"/g, '\\"')}" in newTab
|
||||
end tell`
|
||||
]
|
||||
break
|
||||
|
||||
@ -16,6 +16,7 @@ import {
|
||||
type StreamableHTTPClientTransportOptions
|
||||
} from '@modelcontextprotocol/sdk/client/streamableHttp'
|
||||
import { InMemoryTransport } from '@modelcontextprotocol/sdk/inMemory'
|
||||
import { McpError, type Tool as SDKTool } from '@modelcontextprotocol/sdk/types'
|
||||
// Import notification schemas from MCP SDK
|
||||
import {
|
||||
CancelledNotificationSchema,
|
||||
@ -29,6 +30,7 @@ import {
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import { MCPProgressEvent } from '@shared/config/types'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import {
|
||||
BuiltinMCPServerNames,
|
||||
type GetResourceResponse,
|
||||
@ -94,7 +96,7 @@ function getServerLogger(server: MCPServer, extra?: Record<string, any>) {
|
||||
baseUrl: server?.baseUrl,
|
||||
type: server?.type || (server?.command ? 'stdio' : server?.baseUrl ? 'http' : 'inmemory')
|
||||
}
|
||||
return loggerService.withContext('MCPService', { ...base, ...(extra || {}) })
|
||||
return loggerService.withContext('MCPService', { ...base, ...extra })
|
||||
}
|
||||
|
||||
/**
|
||||
@ -193,11 +195,18 @@ class McpService {
|
||||
return existingClient
|
||||
}
|
||||
} catch (error: any) {
|
||||
getServerLogger(server).error(`Error pinging server`, error as Error)
|
||||
getServerLogger(server).error(`Error pinging server ${server.name}`, error as Error)
|
||||
this.clients.delete(serverKey)
|
||||
}
|
||||
}
|
||||
|
||||
const prepareHeaders = () => {
|
||||
return {
|
||||
...defaultAppHeaders(),
|
||||
...server.headers
|
||||
}
|
||||
}
|
||||
|
||||
// Create a promise for the initialization process
|
||||
const initPromise = (async () => {
|
||||
try {
|
||||
@ -235,8 +244,11 @@ class McpService {
|
||||
} else if (server.baseUrl) {
|
||||
if (server.type === 'streamableHttp') {
|
||||
const options: StreamableHTTPClientTransportOptions = {
|
||||
fetch: async (url, init) => {
|
||||
return net.fetch(typeof url === 'string' ? url : url.toString(), init)
|
||||
},
|
||||
requestInit: {
|
||||
headers: server.headers || {}
|
||||
headers: prepareHeaders()
|
||||
},
|
||||
authProvider
|
||||
}
|
||||
@ -249,25 +261,11 @@ class McpService {
|
||||
const options: SSEClientTransportOptions = {
|
||||
eventSourceInit: {
|
||||
fetch: async (url, init) => {
|
||||
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
|
||||
|
||||
// Get tokens from authProvider to make sure using the latest tokens
|
||||
if (authProvider && typeof authProvider.tokens === 'function') {
|
||||
try {
|
||||
const tokens = await authProvider.tokens()
|
||||
if (tokens && tokens.access_token) {
|
||||
headers['Authorization'] = `Bearer ${tokens.access_token}`
|
||||
}
|
||||
} catch (error) {
|
||||
getServerLogger(server).error('Failed to fetch tokens:', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
return net.fetch(typeof url === 'string' ? url : url.toString(), { ...init, headers })
|
||||
return net.fetch(typeof url === 'string' ? url : url.toString(), init)
|
||||
}
|
||||
},
|
||||
requestInit: {
|
||||
headers: server.headers || {}
|
||||
headers: prepareHeaders()
|
||||
},
|
||||
authProvider
|
||||
}
|
||||
@ -444,9 +442,9 @@ class McpService {
|
||||
|
||||
logger.debug(`Activated server: ${server.name}`)
|
||||
return client
|
||||
} catch (error: any) {
|
||||
getServerLogger(server).error(`Error activating server`, error as Error)
|
||||
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
|
||||
} catch (error) {
|
||||
getServerLogger(server).error(`Error activating server ${server.name}`, error as Error)
|
||||
throw error
|
||||
}
|
||||
} finally {
|
||||
// Clean up the pending promise when done
|
||||
@ -614,12 +612,11 @@ class McpService {
|
||||
}
|
||||
|
||||
private async listToolsImpl(server: MCPServer): Promise<MCPTool[]> {
|
||||
getServerLogger(server).debug(`Listing tools`)
|
||||
const client = await this.initClient(server)
|
||||
try {
|
||||
const { tools } = await client.listTools()
|
||||
const serverTools: MCPTool[] = []
|
||||
tools.map((tool: any) => {
|
||||
tools.map((tool: SDKTool) => {
|
||||
const serverTool: MCPTool = {
|
||||
...tool,
|
||||
id: buildFunctionCallToolName(server.name, tool.name),
|
||||
@ -628,11 +625,12 @@ class McpService {
|
||||
type: 'mcp'
|
||||
}
|
||||
serverTools.push(serverTool)
|
||||
getServerLogger(server).debug(`Listing tools`, { tool: serverTool })
|
||||
})
|
||||
return serverTools
|
||||
} catch (error: any) {
|
||||
} catch (error: unknown) {
|
||||
getServerLogger(server).error(`Failed to list tools`, error as Error)
|
||||
return []
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@ -739,9 +737,9 @@ class McpService {
|
||||
serverId: server.id,
|
||||
serverName: server.name
|
||||
}))
|
||||
} catch (error: any) {
|
||||
} catch (error: unknown) {
|
||||
// -32601 is the code for the method not found
|
||||
if (error?.code !== -32601) {
|
||||
if (error instanceof McpError && error.code !== -32601) {
|
||||
getServerLogger(server).error(`Failed to list prompts`, error as Error)
|
||||
}
|
||||
return []
|
||||
|
||||
@ -115,7 +115,7 @@ class KnowledgeService {
|
||||
const framework = knowledgeFrameworkFactory.getFramework(base)
|
||||
await framework.initialize(base)
|
||||
}
|
||||
public async reset(_: Electron.IpcMainInvokeEvent, { base }: { base: KnowledgeBaseParams }): Promise<void> {
|
||||
public async reset(_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise<void> {
|
||||
const framework = knowledgeFrameworkFactory.getFramework(base)
|
||||
await framework.reset(base)
|
||||
}
|
||||
|
||||
@ -30,7 +30,7 @@ import {
|
||||
KnowledgeBaseParams,
|
||||
KnowledgeSearchResult
|
||||
} from '@types'
|
||||
import { uuidv4 } from 'zod/v4'
|
||||
import { uuidv4 } from 'zod'
|
||||
|
||||
import { windowService } from '../WindowService'
|
||||
import {
|
||||
@ -103,6 +103,8 @@ export class LangChainFramework implements IKnowledgeFramework {
|
||||
if (fs.existsSync(dbPath)) {
|
||||
fs.rmSync(dbPath, { recursive: true })
|
||||
}
|
||||
// 立即重建空索引,避免随后加载时报错
|
||||
await this.createDatabase(base)
|
||||
}
|
||||
|
||||
async delete(id: string): Promise<void> {
|
||||
|
||||
@ -3,6 +3,7 @@ import { Provider } from '@types'
|
||||
import { BaseFileService } from './BaseFileService'
|
||||
import { GeminiService } from './GeminiService'
|
||||
import { MistralService } from './MistralService'
|
||||
import { OpenaiService } from './OpenAIService'
|
||||
|
||||
export class FileServiceManager {
|
||||
private static instance: FileServiceManager
|
||||
@ -30,6 +31,9 @@ export class FileServiceManager {
|
||||
case 'mistral':
|
||||
service = new MistralService(provider)
|
||||
break
|
||||
case 'openai':
|
||||
service = new OpenaiService(provider)
|
||||
break
|
||||
default:
|
||||
throw new Error(`Unsupported service type: ${type}`)
|
||||
}
|
||||
|
||||
125
src/main/services/remotefile/OpenAIService.ts
Normal file
125
src/main/services/remotefile/OpenAIService.ts
Normal file
@ -0,0 +1,125 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { fileStorage } from '@main/services/FileStorage'
|
||||
import { FileListResponse, FileMetadata, FileUploadResponse, Provider } from '@types'
|
||||
import * as fs from 'fs'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { CacheService } from '../CacheService'
|
||||
import { BaseFileService } from './BaseFileService'
|
||||
|
||||
const logger = loggerService.withContext('OpenAIService')
|
||||
|
||||
export class OpenaiService extends BaseFileService {
|
||||
private static readonly FILE_CACHE_DURATION = 7 * 24 * 60 * 60 * 1000
|
||||
private static readonly generateUIFileIdCacheKey = (fileId: string) => `ui_file_id_${fileId}`
|
||||
private readonly client: OpenAI
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
this.client = new OpenAI({
|
||||
apiKey: provider.apiKey,
|
||||
baseURL: provider.apiHost
|
||||
})
|
||||
}
|
||||
|
||||
async uploadFile(file: FileMetadata): Promise<FileUploadResponse> {
|
||||
let fileReadStream: fs.ReadStream | undefined
|
||||
try {
|
||||
fileReadStream = fs.createReadStream(fileStorage.getFilePathById(file))
|
||||
// 还原文件原始名,以提高模型对文件的理解
|
||||
const fileStreamWithMeta = Object.assign(fileReadStream, {
|
||||
name: file.origin_name
|
||||
})
|
||||
const response = await this.client.files.create({
|
||||
file: fileStreamWithMeta,
|
||||
purpose: file.purpose || 'assistants'
|
||||
})
|
||||
if (!response.id) {
|
||||
throw new Error('File id not found in response')
|
||||
}
|
||||
// 映射RemoteFileId到UIFileId上
|
||||
CacheService.set<string>(
|
||||
OpenaiService.generateUIFileIdCacheKey(file.id),
|
||||
response.id,
|
||||
OpenaiService.FILE_CACHE_DURATION
|
||||
)
|
||||
return {
|
||||
fileId: response.id,
|
||||
displayName: file.origin_name,
|
||||
status: 'success',
|
||||
originalFile: {
|
||||
type: 'openai',
|
||||
file: response
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error uploading file:', error as Error)
|
||||
return {
|
||||
fileId: '',
|
||||
displayName: file.origin_name,
|
||||
status: 'failed'
|
||||
}
|
||||
} finally {
|
||||
// 销毁文件流
|
||||
if (fileReadStream) fileReadStream.destroy()
|
||||
}
|
||||
}
|
||||
|
||||
async listFiles(): Promise<FileListResponse> {
|
||||
try {
|
||||
const response = await this.client.files.list()
|
||||
return {
|
||||
files: response.data.map((file) => ({
|
||||
id: file.id,
|
||||
displayName: file.filename || '',
|
||||
size: file.bytes,
|
||||
status: 'success', // All listed files are processed,
|
||||
originalFile: {
|
||||
type: 'openai',
|
||||
file
|
||||
}
|
||||
}))
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error listing files:', error as Error)
|
||||
return { files: [] }
|
||||
}
|
||||
}
|
||||
|
||||
async deleteFile(fileId: string): Promise<void> {
|
||||
try {
|
||||
const cachedRemoteFileId = CacheService.get<string>(OpenaiService.generateUIFileIdCacheKey(fileId))
|
||||
await this.client.files.delete(cachedRemoteFileId || fileId)
|
||||
logger.debug(`File ${fileId} deleted`)
|
||||
} catch (error) {
|
||||
logger.error('Error deleting file:', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
async retrieveFile(fileId: string): Promise<FileUploadResponse> {
|
||||
try {
|
||||
// 尝试反映射RemoteFileId
|
||||
const cachedRemoteFileId = CacheService.get<string>(OpenaiService.generateUIFileIdCacheKey(fileId))
|
||||
const response = await this.client.files.retrieve(cachedRemoteFileId || fileId)
|
||||
|
||||
return {
|
||||
fileId: response.id,
|
||||
displayName: response.filename,
|
||||
status: 'success',
|
||||
originalFile: {
|
||||
type: 'openai',
|
||||
file: response
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error retrieving file:', error as Error)
|
||||
return {
|
||||
fileId: fileId,
|
||||
displayName: '',
|
||||
status: 'failed',
|
||||
originalFile: undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -398,11 +398,15 @@ export function validateFileName(fileName: string, platform = process.platform):
|
||||
* @returns 合法的文件名
|
||||
*/
|
||||
export function checkName(fileName: string): string {
|
||||
const validation = validateFileName(fileName)
|
||||
const baseName = path.basename(fileName)
|
||||
const validation = validateFileName(baseName)
|
||||
if (!validation.valid) {
|
||||
throw new Error(`Invalid file name: ${fileName}. ${validation.error}`)
|
||||
// 自动清理非法字符,而不是抛出错误
|
||||
const sanitized = sanitizeFilename(baseName)
|
||||
logger.warn(`File name contains invalid characters, auto-sanitized: "${baseName}" -> "${sanitized}"`)
|
||||
return sanitized
|
||||
}
|
||||
return fileName
|
||||
return baseName
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -42,6 +42,8 @@ export function tracedInvoke(channel: string, spanContext: SpanContext | undefin
|
||||
// Custom APIs for renderer
|
||||
const api = {
|
||||
getAppInfo: () => ipcRenderer.invoke(IpcChannel.App_Info),
|
||||
getDiskInfo: (directoryPath: string): Promise<{ free: number; size: number } | null> =>
|
||||
ipcRenderer.invoke(IpcChannel.App_GetDiskInfo, directoryPath),
|
||||
reload: () => ipcRenderer.invoke(IpcChannel.App_Reload),
|
||||
setProxy: (proxy: string | undefined, bypassRules?: string) =>
|
||||
ipcRenderer.invoke(IpcChannel.App_Proxy, proxy, bypassRules),
|
||||
@ -80,6 +82,7 @@ const api = {
|
||||
ipcRenderer.invoke(IpcChannel.App_LogToMain, source, level, message, data),
|
||||
setFullScreen: (value: boolean): Promise<void> => ipcRenderer.invoke(IpcChannel.App_SetFullScreen, value),
|
||||
isFullScreen: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_IsFullScreen),
|
||||
getSystemFonts: (): Promise<string[]> => ipcRenderer.invoke(IpcChannel.App_GetSystemFonts),
|
||||
mac: {
|
||||
isProcessTrusted: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacIsProcessTrusted),
|
||||
requestProcessTrust: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacRequestProcessTrust)
|
||||
@ -445,9 +448,10 @@ const api = {
|
||||
isMaximized: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.Windows_IsMaximized),
|
||||
onMaximizedChange: (callback: (isMaximized: boolean) => void): (() => void) => {
|
||||
const channel = IpcChannel.Windows_MaximizedChanged
|
||||
ipcRenderer.on(channel, (_, isMaximized: boolean) => callback(isMaximized))
|
||||
const listener = (_: Electron.IpcRendererEvent, isMaximized: boolean) => callback(isMaximized)
|
||||
ipcRenderer.on(channel, listener)
|
||||
return () => {
|
||||
ipcRenderer.removeAllListeners(channel)
|
||||
ipcRenderer.removeListener(channel, listener)
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import '@renderer/databases'
|
||||
|
||||
import { HeroUIProvider } from '@heroui/react'
|
||||
import { preferenceService } from '@data/PreferenceService'
|
||||
import { loggerService } from '@logger'
|
||||
import store, { persistor } from '@renderer/store'
|
||||
@ -7,6 +8,7 @@ import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { Provider } from 'react-redux'
|
||||
import { PersistGate } from 'redux-persist/integration/react'
|
||||
|
||||
import { ToastPortal } from './components/ToastPortal'
|
||||
import TopViewContainer from './components/TopView'
|
||||
import AntdProvider from './context/AntdProvider'
|
||||
import { CodeStyleProvider } from './context/CodeStyleProvider'
|
||||
@ -35,21 +37,24 @@ function App(): React.ReactElement {
|
||||
return (
|
||||
<Provider store={store}>
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<StyleSheetManager>
|
||||
<ThemeProvider>
|
||||
<AntdProvider>
|
||||
<NotificationProvider>
|
||||
<CodeStyleProvider>
|
||||
<PersistGate loading={null} persistor={persistor}>
|
||||
<TopViewContainer>
|
||||
<Router />
|
||||
</TopViewContainer>
|
||||
</PersistGate>
|
||||
</CodeStyleProvider>
|
||||
</NotificationProvider>
|
||||
</AntdProvider>
|
||||
</ThemeProvider>
|
||||
</StyleSheetManager>
|
||||
<HeroUIProvider className="flex h-full w-full flex-1">
|
||||
<StyleSheetManager>
|
||||
<ThemeProvider>
|
||||
<AntdProvider>
|
||||
<NotificationProvider>
|
||||
<CodeStyleProvider>
|
||||
<PersistGate loading={null} persistor={persistor}>
|
||||
<TopViewContainer>
|
||||
<Router />
|
||||
</TopViewContainer>
|
||||
</PersistGate>
|
||||
</CodeStyleProvider>
|
||||
</NotificationProvider>
|
||||
</AntdProvider>
|
||||
</ThemeProvider>
|
||||
</StyleSheetManager>
|
||||
<ToastPortal />
|
||||
</HeroUIProvider>
|
||||
</QueryClientProvider>
|
||||
</Provider>
|
||||
)
|
||||
|
||||
@ -4,8 +4,9 @@
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
|
||||
import { AISDKWebSearchResult, MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { convertLinks, flushLinkConverterBuffer } from '@renderer/utils/linkConverter'
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { ToolCallChunkHandler } from './handleToolCallChunk'
|
||||
@ -29,13 +30,18 @@ export interface CherryStudioChunk {
|
||||
export class AiSdkToChunkAdapter {
|
||||
toolCallHandler: ToolCallChunkHandler
|
||||
private accumulate: boolean | undefined
|
||||
private isFirstChunk = true
|
||||
private enableWebSearch: boolean = false
|
||||
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
mcpTools: MCPTool[] = [],
|
||||
accumulate?: boolean
|
||||
accumulate?: boolean,
|
||||
enableWebSearch?: boolean
|
||||
) {
|
||||
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
|
||||
this.accumulate = accumulate
|
||||
this.enableWebSearch = enableWebSearch || false
|
||||
}
|
||||
|
||||
/**
|
||||
@ -65,11 +71,24 @@ export class AiSdkToChunkAdapter {
|
||||
webSearchResults: [],
|
||||
reasoningId: ''
|
||||
}
|
||||
// Reset link converter state at the start of stream
|
||||
this.isFirstChunk = true
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
|
||||
if (done) {
|
||||
// Flush any remaining content from link converter buffer if web search is enabled
|
||||
if (this.enableWebSearch) {
|
||||
const remainingText = flushLinkConverterBuffer()
|
||||
if (remainingText) {
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: remainingText
|
||||
})
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
@ -87,9 +106,9 @@ export class AiSdkToChunkAdapter {
|
||||
*/
|
||||
private convertAndEmitChunk(
|
||||
chunk: TextStreamPart<any>,
|
||||
final: { text: string; reasoningContent: string; webSearchResults: any[]; reasoningId: string }
|
||||
final: { text: string; reasoningContent: string; webSearchResults: AISDKWebSearchResult[]; reasoningId: string }
|
||||
) {
|
||||
logger.info(`AI SDK chunk type: ${chunk.type}`, chunk)
|
||||
logger.silly(`AI SDK chunk type: ${chunk.type}`, chunk)
|
||||
switch (chunk.type) {
|
||||
// === 文本相关事件 ===
|
||||
case 'text-start':
|
||||
@ -97,17 +116,44 @@ export class AiSdkToChunkAdapter {
|
||||
type: ChunkType.TEXT_START
|
||||
})
|
||||
break
|
||||
case 'text-delta':
|
||||
if (this.accumulate) {
|
||||
final.text += chunk.text || ''
|
||||
case 'text-delta': {
|
||||
const processedText = chunk.text || ''
|
||||
let finalText: string
|
||||
|
||||
// Only apply link conversion if web search is enabled
|
||||
if (this.enableWebSearch) {
|
||||
const result = convertLinks(processedText, this.isFirstChunk)
|
||||
|
||||
if (this.isFirstChunk) {
|
||||
this.isFirstChunk = false
|
||||
}
|
||||
|
||||
// Handle buffered content
|
||||
if (result.hasBufferedContent) {
|
||||
finalText = result.text
|
||||
} else {
|
||||
finalText = result.text || processedText
|
||||
}
|
||||
} else {
|
||||
final.text = chunk.text || ''
|
||||
// Without web search, just use the original text
|
||||
finalText = processedText
|
||||
}
|
||||
|
||||
if (this.accumulate) {
|
||||
final.text += finalText
|
||||
} else {
|
||||
final.text = finalText
|
||||
}
|
||||
|
||||
// Only emit chunk if there's text to send
|
||||
if (finalText) {
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: this.accumulate ? final.text : finalText
|
||||
})
|
||||
}
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: final.text || ''
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'text-end':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
@ -152,12 +198,14 @@ export class AiSdkToChunkAdapter {
|
||||
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||
// break
|
||||
case 'tool-call':
|
||||
// 原始的工具调用(未被中间件处理)
|
||||
this.toolCallHandler.handleToolCall(chunk)
|
||||
break
|
||||
|
||||
case 'tool-error':
|
||||
this.toolCallHandler.handleToolError(chunk)
|
||||
break
|
||||
|
||||
case 'tool-result':
|
||||
// 原始的工具调用结果(未被中间件处理)
|
||||
this.toolCallHandler.handleToolResult(chunk)
|
||||
break
|
||||
|
||||
@ -167,7 +215,6 @@ export class AiSdkToChunkAdapter {
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
// TODO: 需要区分接口开始和步骤开始
|
||||
// case 'start-step':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.BLOCK_CREATED
|
||||
@ -199,7 +246,7 @@ export class AiSdkToChunkAdapter {
|
||||
[WebSearchSource.ANTHROPIC]: WebSearchSource.ANTHROPIC,
|
||||
[WebSearchSource.OPENROUTER]: WebSearchSource.OPENROUTER,
|
||||
[WebSearchSource.GEMINI]: WebSearchSource.GEMINI,
|
||||
[WebSearchSource.PERPLEXITY]: WebSearchSource.PERPLEXITY,
|
||||
// [WebSearchSource.PERPLEXITY]: WebSearchSource.PERPLEXITY,
|
||||
[WebSearchSource.QWEN]: WebSearchSource.QWEN,
|
||||
[WebSearchSource.HUNYUAN]: WebSearchSource.HUNYUAN,
|
||||
[WebSearchSource.ZHIPU]: WebSearchSource.ZHIPU,
|
||||
@ -267,18 +314,9 @@ export class AiSdkToChunkAdapter {
|
||||
// === 源和文件相关事件 ===
|
||||
case 'source':
|
||||
if (chunk.sourceType === 'url') {
|
||||
// if (final.webSearchResults.length === 0) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const { sourceType: _, ...rest } = chunk
|
||||
final.webSearchResults.push(rest)
|
||||
// }
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
// llm_web_search: {
|
||||
// source: WebSearchSource.AISDK,
|
||||
// results: final.webSearchResults
|
||||
// }
|
||||
// })
|
||||
}
|
||||
break
|
||||
case 'file':
|
||||
@ -305,8 +343,6 @@ export class AiSdkToChunkAdapter {
|
||||
break
|
||||
|
||||
default:
|
||||
// 其他类型的 chunk 可以忽略或记录日志
|
||||
// console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,34 +8,61 @@ import { loggerService } from '@logger'
|
||||
import { processKnowledgeReferences } from '@renderer/services/KnowledgeService'
|
||||
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai'
|
||||
// import type {
|
||||
// AnthropicSearchOutput,
|
||||
// WebSearchPluginConfig
|
||||
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
|
||||
import type { ToolSet, TypedToolCall, TypedToolError, TypedToolResult } from 'ai'
|
||||
|
||||
const logger = loggerService.withContext('ToolCallChunkHandler')
|
||||
|
||||
export type ToolcallsMap = {
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
args: any
|
||||
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
|
||||
tool: BaseTool
|
||||
}
|
||||
/**
|
||||
* 工具调用处理器类
|
||||
*/
|
||||
export class ToolCallChunkHandler {
|
||||
// private onChunk: (chunk: Chunk) => void
|
||||
private activeToolCalls = new Map<
|
||||
string,
|
||||
{
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
args: any
|
||||
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
|
||||
tool: BaseTool
|
||||
}
|
||||
>()
|
||||
private static globalActiveToolCalls = new Map<string, ToolcallsMap>()
|
||||
|
||||
private activeToolCalls = ToolCallChunkHandler.globalActiveToolCalls
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
private mcpTools: MCPTool[]
|
||||
) {}
|
||||
|
||||
/**
|
||||
* 内部静态方法:添加活跃工具调用的核心逻辑
|
||||
*/
|
||||
private static addActiveToolCallImpl(toolCallId: string, map: ToolcallsMap): boolean {
|
||||
if (!ToolCallChunkHandler.globalActiveToolCalls.has(toolCallId)) {
|
||||
ToolCallChunkHandler.globalActiveToolCalls.set(toolCallId, map)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* 实例方法:添加活跃工具调用
|
||||
*/
|
||||
private addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
|
||||
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取全局活跃的工具调用
|
||||
*/
|
||||
public static getActiveToolCalls() {
|
||||
return ToolCallChunkHandler.globalActiveToolCalls
|
||||
}
|
||||
|
||||
/**
|
||||
* 静态方法:添加活跃工具调用(外部访问)
|
||||
*/
|
||||
public static addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
|
||||
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
|
||||
}
|
||||
|
||||
// /**
|
||||
// * 设置 onChunk 回调
|
||||
// */
|
||||
@ -43,103 +70,103 @@ export class ToolCallChunkHandler {
|
||||
// this.onChunk = callback
|
||||
// }
|
||||
|
||||
handleToolCallCreated(
|
||||
chunk:
|
||||
| {
|
||||
type: 'tool-input-start'
|
||||
id: string
|
||||
toolName: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
providerExecuted?: boolean
|
||||
}
|
||||
| {
|
||||
type: 'tool-input-end'
|
||||
id: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
}
|
||||
| {
|
||||
type: 'tool-input-delta'
|
||||
id: string
|
||||
delta: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
}
|
||||
): void {
|
||||
switch (chunk.type) {
|
||||
case 'tool-input-start': {
|
||||
// 能拿到说明是mcpTool
|
||||
// if (this.activeToolCalls.get(chunk.id)) return
|
||||
// handleToolCallCreated(
|
||||
// chunk:
|
||||
// | {
|
||||
// type: 'tool-input-start'
|
||||
// id: string
|
||||
// toolName: string
|
||||
// providerMetadata?: ProviderMetadata
|
||||
// providerExecuted?: boolean
|
||||
// }
|
||||
// | {
|
||||
// type: 'tool-input-end'
|
||||
// id: string
|
||||
// providerMetadata?: ProviderMetadata
|
||||
// }
|
||||
// | {
|
||||
// type: 'tool-input-delta'
|
||||
// id: string
|
||||
// delta: string
|
||||
// providerMetadata?: ProviderMetadata
|
||||
// }
|
||||
// ): void {
|
||||
// switch (chunk.type) {
|
||||
// case 'tool-input-start': {
|
||||
// // 能拿到说明是mcpTool
|
||||
// // if (this.activeToolCalls.get(chunk.id)) return
|
||||
|
||||
const tool: BaseTool | MCPTool = {
|
||||
id: chunk.id,
|
||||
name: chunk.toolName,
|
||||
description: chunk.toolName,
|
||||
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
|
||||
}
|
||||
this.activeToolCalls.set(chunk.id, {
|
||||
toolCallId: chunk.id,
|
||||
toolName: chunk.toolName,
|
||||
args: '',
|
||||
tool
|
||||
})
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: chunk.id,
|
||||
tool: tool,
|
||||
arguments: {},
|
||||
status: 'pending',
|
||||
toolCallId: chunk.id
|
||||
}
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'tool-input-delta': {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
if (!toolCall) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
toolCall.args += chunk.delta
|
||||
break
|
||||
}
|
||||
case 'tool-input-end': {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
this.activeToolCalls.delete(chunk.id)
|
||||
if (!toolCall) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
// const toolResponse: ToolCallResponse = {
|
||||
// id: toolCall.toolCallId,
|
||||
// tool: toolCall.tool,
|
||||
// arguments: toolCall.args,
|
||||
// status: 'pending',
|
||||
// toolCallId: toolCall.toolCallId
|
||||
// }
|
||||
// logger.debug('toolResponse', toolResponse)
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_PENDING,
|
||||
// responses: [toolResponse]
|
||||
// })
|
||||
break
|
||||
}
|
||||
}
|
||||
// if (!toolCall) {
|
||||
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
// return
|
||||
// }
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_CREATED,
|
||||
// tool_calls: [
|
||||
// {
|
||||
// id: chunk.id,
|
||||
// name: chunk.toolName,
|
||||
// status: 'pending'
|
||||
// }
|
||||
// ]
|
||||
// })
|
||||
}
|
||||
// const tool: BaseTool | MCPTool = {
|
||||
// id: chunk.id,
|
||||
// name: chunk.toolName,
|
||||
// description: chunk.toolName,
|
||||
// type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
|
||||
// }
|
||||
// this.activeToolCalls.set(chunk.id, {
|
||||
// toolCallId: chunk.id,
|
||||
// toolName: chunk.toolName,
|
||||
// args: '',
|
||||
// tool
|
||||
// })
|
||||
// const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
// id: chunk.id,
|
||||
// tool: tool,
|
||||
// arguments: {},
|
||||
// status: 'pending',
|
||||
// toolCallId: chunk.id
|
||||
// }
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_PENDING,
|
||||
// responses: [toolResponse]
|
||||
// })
|
||||
// break
|
||||
// }
|
||||
// case 'tool-input-delta': {
|
||||
// const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
// if (!toolCall) {
|
||||
// logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
// return
|
||||
// }
|
||||
// toolCall.args += chunk.delta
|
||||
// break
|
||||
// }
|
||||
// case 'tool-input-end': {
|
||||
// const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
// this.activeToolCalls.delete(chunk.id)
|
||||
// if (!toolCall) {
|
||||
// logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
// return
|
||||
// }
|
||||
// // const toolResponse: ToolCallResponse = {
|
||||
// // id: toolCall.toolCallId,
|
||||
// // tool: toolCall.tool,
|
||||
// // arguments: toolCall.args,
|
||||
// // status: 'pending',
|
||||
// // toolCallId: toolCall.toolCallId
|
||||
// // }
|
||||
// // logger.debug('toolResponse', toolResponse)
|
||||
// // this.onChunk({
|
||||
// // type: ChunkType.MCP_TOOL_PENDING,
|
||||
// // responses: [toolResponse]
|
||||
// // })
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// // if (!toolCall) {
|
||||
// // Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
// // return
|
||||
// // }
|
||||
// // this.onChunk({
|
||||
// // type: ChunkType.MCP_TOOL_CREATED,
|
||||
// // tool_calls: [
|
||||
// // {
|
||||
// // id: chunk.id,
|
||||
// // name: chunk.toolName,
|
||||
// // status: 'pending'
|
||||
// // }
|
||||
// // ]
|
||||
// // })
|
||||
// }
|
||||
|
||||
/**
|
||||
* 处理工具调用事件
|
||||
@ -158,7 +185,6 @@ export class ToolCallChunkHandler {
|
||||
|
||||
let tool: BaseTool
|
||||
let mcpTool: MCPTool | undefined
|
||||
|
||||
// 根据 providerExecuted 标志区分处理逻辑
|
||||
if (providerExecuted) {
|
||||
// 如果是 Provider 执行的工具(如 web_search)
|
||||
@ -196,27 +222,25 @@ export class ToolCallChunkHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// 记录活跃的工具调用
|
||||
this.activeToolCalls.set(toolCallId, {
|
||||
this.addActiveToolCall(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args,
|
||||
tool
|
||||
})
|
||||
|
||||
// 创建 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: tool,
|
||||
arguments: args,
|
||||
status: 'pending',
|
||||
status: 'pending', // 统一使用 pending 状态
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
|
||||
// 调用 onChunk
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
type: ChunkType.MCP_TOOL_PENDING, // 统一发送 pending 状态
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
@ -257,7 +281,7 @@ export class ToolCallChunkHandler {
|
||||
// 工具特定的后处理
|
||||
switch (toolResponse.tool.name) {
|
||||
case 'builtin_knowledge_search': {
|
||||
processKnowledgeReferences(toolResponse.response?.knowledgeReferences, this.onChunk)
|
||||
processKnowledgeReferences(toolResponse.response, this.onChunk)
|
||||
break
|
||||
}
|
||||
// 未来可以在这里添加其他工具的后处理逻辑
|
||||
@ -276,4 +300,33 @@ export class ToolCallChunkHandler {
|
||||
})
|
||||
}
|
||||
}
|
||||
handleToolError(
|
||||
chunk: {
|
||||
type: 'tool-error'
|
||||
} & TypedToolError<ToolSet>
|
||||
): void {
|
||||
const { toolCallId, error, input } = chunk
|
||||
const toolCallInfo = this.activeToolCalls.get(toolCallId)
|
||||
if (!toolCallInfo) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
|
||||
return
|
||||
}
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: toolCallInfo.tool,
|
||||
arguments: input,
|
||||
status: 'error',
|
||||
response: error,
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
this.activeToolCalls.delete(toolCallId)
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const addActiveToolCall = ToolCallChunkHandler.addActiveToolCall.bind(ToolCallChunkHandler)
|
||||
|
||||
@ -265,15 +265,15 @@ export default class ModernAiProvider {
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
const modelId = this.model!.id
|
||||
logger.info('Starting modernCompletions', {
|
||||
modelId,
|
||||
providerId: this.config!.providerId,
|
||||
topicId: config.topicId,
|
||||
hasOnChunk: !!config.onChunk,
|
||||
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
toolCount: params.tools ? Object.keys(params.tools).length : 0
|
||||
})
|
||||
// const modelId = this.model!.id
|
||||
// logger.info('Starting modernCompletions', {
|
||||
// modelId,
|
||||
// providerId: this.config!.providerId,
|
||||
// topicId: config.topicId,
|
||||
// hasOnChunk: !!config.onChunk,
|
||||
// hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
// toolCount: params.tools ? Object.keys(params.tools).length : 0
|
||||
// })
|
||||
|
||||
// 根据条件构建插件数组
|
||||
const plugins = await buildPlugins(config)
|
||||
@ -284,7 +284,7 @@ export default class ModernAiProvider {
|
||||
// 创建带有中间件的执行器
|
||||
if (config.onChunk) {
|
||||
const accumulate = this.model!.supported_text_delta !== false // true and undefined
|
||||
const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate)
|
||||
const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate, config.enableWebSearch)
|
||||
|
||||
const streamResult = await executor.streamText({
|
||||
...params,
|
||||
|
||||
@ -45,6 +45,7 @@ import { isJSON, parseJSON } from '@renderer/utils'
|
||||
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import { REFERENCE_PROMPT } from '@shared/config/prompts'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
@ -179,8 +180,7 @@ export abstract class BaseApiClient<
|
||||
|
||||
public defaultHeaders() {
|
||||
return {
|
||||
'HTTP-Referer': 'https://cherry-ai.com',
|
||||
'X-Title': 'Cherry Studio',
|
||||
...defaultAppHeaders(),
|
||||
'X-Api-Key': this.apiKey
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { flushLinkConverterBuffer, smartLinkConverter } from '@renderer/utils/linkConverter'
|
||||
import { convertLinks, flushLinkConverterBuffer } from '@renderer/utils/linkConverter'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
@ -28,8 +28,6 @@ export const WebSearchMiddleware: CompletionsMiddleware =
|
||||
}
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
const model = params.assistant?.model!
|
||||
let isFirstChunk = true
|
||||
|
||||
// 响应后处理:记录Web搜索事件
|
||||
@ -42,15 +40,9 @@ export const WebSearchMiddleware: CompletionsMiddleware =
|
||||
new TransformStream<GenericChunk, GenericChunk>({
|
||||
transform(chunk: GenericChunk, controller) {
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
const providerType = model.provider || 'openai'
|
||||
// 使用当前可用的Web搜索结果进行链接转换
|
||||
const text = chunk.text
|
||||
const result = smartLinkConverter(
|
||||
text,
|
||||
providerType,
|
||||
isFirstChunk,
|
||||
ctx._internal.webSearchState!.results
|
||||
)
|
||||
const result = convertLinks(text, isFirstChunk)
|
||||
if (isFirstChunk) {
|
||||
isFirstChunk = false
|
||||
}
|
||||
|
||||
@ -20,8 +20,10 @@ export interface AiSdkMiddlewareConfig {
|
||||
isSupportedToolUse: boolean
|
||||
// image generation endpoint
|
||||
isImageGenerationEndpoint: boolean
|
||||
// 是否开启内置搜索
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
enableUrlContext: boolean
|
||||
mcpTools?: MCPTool[]
|
||||
uiMessages?: Message[]
|
||||
}
|
||||
@ -132,7 +134,6 @@ export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageMo
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('builder.build()', builder.buildNamed())
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { AiPlugin } from '@cherrystudio/ai-core'
|
||||
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { createPromptToolUsePlugin, googleToolsPlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { preferenceService } from '@data/PreferenceService'
|
||||
import { loggerService } from '@logger'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
@ -35,7 +35,7 @@ export async function buildPlugins(
|
||||
plugins.push(webSearchPlugin())
|
||||
}
|
||||
// 2. 支持工具调用时添加搜索插件
|
||||
if (middlewareConfig.isSupportedToolUse) {
|
||||
if (middlewareConfig.isSupportedToolUse || middlewareConfig.isPromptToolUse) {
|
||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant, middlewareConfig.topicId || ''))
|
||||
}
|
||||
|
||||
@ -45,12 +45,13 @@ export async function buildPlugins(
|
||||
}
|
||||
|
||||
// 4. 启用Prompt工具调用时添加工具插件
|
||||
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
if (middlewareConfig.isPromptToolUse) {
|
||||
plugins.push(
|
||||
createPromptToolUsePlugin({
|
||||
enabled: true,
|
||||
createSystemMessage: (systemPrompt, params, context) => {
|
||||
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
||||
const modelId = typeof context.model === 'string' ? context.model : context.model.modelId
|
||||
if (modelId.includes('o1-mini') || modelId.includes('o1-preview')) {
|
||||
if (context.isRecursiveCall) {
|
||||
return null
|
||||
}
|
||||
@ -69,10 +70,11 @@ export async function buildPlugins(
|
||||
)
|
||||
}
|
||||
|
||||
// if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
// plugins.push(createNativeToolUsePlugin())
|
||||
// }
|
||||
logger.info(
|
||||
if (middlewareConfig.enableUrlContext) {
|
||||
plugins.push(googleToolsPlugin({ urlContext: true }))
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
'Final plugin list:',
|
||||
plugins.map((p) => p.name)
|
||||
)
|
||||
|
||||
@ -19,7 +19,8 @@ import {
|
||||
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
|
||||
SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
} from '@shared/config/prompts'
|
||||
import type { ModelMessage } from 'ai'
|
||||
import type { LanguageModel, ModelMessage } from 'ai'
|
||||
import { generateText } from 'ai'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||
@ -76,9 +77,7 @@ async function analyzeSearchIntent(
|
||||
shouldKnowledgeSearch?: boolean
|
||||
shouldMemorySearch?: boolean
|
||||
lastAnswer?: ModelMessage
|
||||
context: AiRequestContext & {
|
||||
isAnalyzing?: boolean
|
||||
}
|
||||
context: AiRequestContext
|
||||
topicId: string
|
||||
}
|
||||
): Promise<ExtractResults | undefined> {
|
||||
@ -122,9 +121,7 @@ async function analyzeSearchIntent(
|
||||
logger.error('Provider not found or missing API key')
|
||||
return getFallbackResult()
|
||||
}
|
||||
// console.log('formattedPrompt', schema)
|
||||
try {
|
||||
context.isAnalyzing = true
|
||||
logger.info('Starting intent analysis generateText call', {
|
||||
modelId: model.id,
|
||||
topicId: options.topicId,
|
||||
@ -133,18 +130,16 @@ async function analyzeSearchIntent(
|
||||
hasKnowledgeSearch: needKnowledgeExtract
|
||||
})
|
||||
|
||||
const { text: result } = await context.executor
|
||||
.generateText(model.id, {
|
||||
prompt: formattedPrompt
|
||||
})
|
||||
.finally(() => {
|
||||
context.isAnalyzing = false
|
||||
logger.info('Intent analysis generateText call completed', {
|
||||
modelId: model.id,
|
||||
topicId: options.topicId,
|
||||
requestId: context.requestId
|
||||
})
|
||||
const { text: result } = await generateText({
|
||||
model: context.model as LanguageModel,
|
||||
prompt: formattedPrompt
|
||||
}).finally(() => {
|
||||
logger.info('Intent analysis generateText call completed', {
|
||||
modelId: model.id,
|
||||
topicId: options.topicId,
|
||||
requestId: context.requestId
|
||||
})
|
||||
})
|
||||
const parsedResult = extractInfoFromXML(result)
|
||||
logger.debug('Intent analysis result', { parsedResult })
|
||||
|
||||
@ -183,7 +178,6 @@ async function storeConversationMemory(
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
|
||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||
// console.log('Memory storage is disabled')
|
||||
return
|
||||
}
|
||||
|
||||
@ -245,25 +239,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
||||
// 存储意图分析结果
|
||||
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
|
||||
const userMessages: { [requestId: string]: ModelMessage } = {}
|
||||
let currentContext: AiRequestContext | null = null
|
||||
|
||||
return definePlugin({
|
||||
name: 'search-orchestration',
|
||||
enforce: 'pre', // 确保在其他插件之前执行
|
||||
|
||||
configureContext: (context: AiRequestContext) => {
|
||||
if (currentContext) {
|
||||
context.isAnalyzing = currentContext.isAnalyzing
|
||||
}
|
||||
currentContext = context
|
||||
},
|
||||
|
||||
/**
|
||||
* 🔍 Step 1: 意图识别阶段
|
||||
*/
|
||||
onRequestStart: async (context: AiRequestContext) => {
|
||||
if (context.isAnalyzing) return
|
||||
|
||||
// 没开启任何搜索则不进行意图分析
|
||||
if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return
|
||||
|
||||
@ -284,7 +267,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
|
||||
const shouldWebSearch = !!assistant.webSearchProviderId
|
||||
const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
|
||||
const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory
|
||||
@ -315,7 +297,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
||||
* 🔧 Step 2: 工具配置阶段
|
||||
*/
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
if (context.isAnalyzing) return params
|
||||
// logger.info('🔧 Configuring tools based on intent...', context.requestId)
|
||||
|
||||
try {
|
||||
@ -409,7 +390,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
||||
// context.isAnalyzing = false
|
||||
// logger.info('context.isAnalyzing', context, result)
|
||||
// logger.info('💾 Starting memory storage...', context.requestId)
|
||||
if (context.isAnalyzing) return
|
||||
try {
|
||||
const messages = context.originalParams.messages
|
||||
|
||||
|
||||
@ -58,91 +58,6 @@ class AdapterTracer {
|
||||
})
|
||||
}
|
||||
|
||||
// startSpan(name: string, options?: any, context?: any): Span {
|
||||
// // 如果提供了父 SpanContext 且未显式传入 context,则使用父上下文
|
||||
// const contextToUse = context ?? this.cachedParentContext ?? otelContext.active()
|
||||
|
||||
// const span = this.originalTracer.startSpan(name, options, contextToUse)
|
||||
|
||||
// // 标记父子关系,便于在转换阶段兜底重建层级
|
||||
// try {
|
||||
// if (this.parentSpanContext) {
|
||||
// span.setAttribute('trace.parentSpanId', this.parentSpanContext.spanId)
|
||||
// span.setAttribute('trace.parentTraceId', this.parentSpanContext.traceId)
|
||||
// }
|
||||
// if (this.topicId) {
|
||||
// span.setAttribute('trace.topicId', this.topicId)
|
||||
// }
|
||||
// } catch (e) {
|
||||
// logger.debug('Failed to set trace parent attributes', e as Error)
|
||||
// }
|
||||
|
||||
// logger.info('AI SDK span created via AdapterTracer', {
|
||||
// spanName: name,
|
||||
// spanId: span.spanContext().spanId,
|
||||
// traceId: span.spanContext().traceId,
|
||||
// parentTraceId: this.parentSpanContext?.traceId,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName,
|
||||
// traceIdMatches: this.parentSpanContext ? span.spanContext().traceId === this.parentSpanContext.traceId : undefined
|
||||
// })
|
||||
|
||||
// // 包装 span 的 end 方法,在结束时进行数据转换
|
||||
// const originalEnd = span.end.bind(span)
|
||||
// span.end = (endTime?: any) => {
|
||||
// logger.info('AI SDK span.end() called - about to convert span', {
|
||||
// spanName: name,
|
||||
// spanId: span.spanContext().spanId,
|
||||
// traceId: span.spanContext().traceId,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName
|
||||
// })
|
||||
|
||||
// // 调用原始 end 方法
|
||||
// originalEnd(endTime)
|
||||
|
||||
// // 转换并保存 span 数据
|
||||
// try {
|
||||
// logger.info('Converting AI SDK span to SpanEntity', {
|
||||
// spanName: name,
|
||||
// spanId: span.spanContext().spanId,
|
||||
// traceId: span.spanContext().traceId,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName
|
||||
// })
|
||||
// logger.info('spanspanspanspanspanspan', span)
|
||||
// const spanEntity = AiSdkSpanAdapter.convertToSpanEntity({
|
||||
// span,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName
|
||||
// })
|
||||
|
||||
// // 保存转换后的数据
|
||||
// window.api.trace.saveEntity(spanEntity)
|
||||
|
||||
// logger.info('AI SDK span converted and saved successfully', {
|
||||
// spanName: name,
|
||||
// spanId: span.spanContext().spanId,
|
||||
// traceId: span.spanContext().traceId,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName,
|
||||
// hasUsage: !!spanEntity.usage,
|
||||
// usage: spanEntity.usage
|
||||
// })
|
||||
// } catch (error) {
|
||||
// logger.error('Failed to convert AI SDK span', error as Error, {
|
||||
// spanName: name,
|
||||
// spanId: span.spanContext().spanId,
|
||||
// traceId: span.spanContext().traceId,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
// return span
|
||||
// }
|
||||
|
||||
startActiveSpan<F extends (span: Span) => any>(name: string, fn: F): ReturnType<F>
|
||||
startActiveSpan<F extends (span: Span) => any>(name: string, options: any, fn: F): ReturnType<F>
|
||||
startActiveSpan<F extends (span: Span) => any>(name: string, options: any, context: any, fn: F): ReturnType<F>
|
||||
|
||||
@ -10,6 +10,7 @@ import { FileTypes } from '@renderer/types'
|
||||
import { FileMessageBlock } from '@renderer/types/newMessage'
|
||||
import { findFileBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import type { FilePart, TextPart } from 'ai'
|
||||
import type OpenAI from 'openai'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import { getFileSizeLimit, supportsImageInput, supportsLargeFileUpload, supportsPdfInput } from './modelCapabilities'
|
||||
@ -112,6 +113,86 @@ export async function handleGeminiFileUpload(file: FileMetadata, model: Model):
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理OpenAI大文件上传
|
||||
*/
|
||||
export async function handleOpenAILargeFileUpload(
|
||||
file: FileMetadata,
|
||||
model: Model
|
||||
): Promise<(FilePart & { id?: string }) | null> {
|
||||
const provider = getProviderByModel(model)
|
||||
// 如果模型为qwen-long系列,文档中要求purpose需要为'file-extract'
|
||||
if (['qwen-long', 'qwen-doc'].some((modelName) => model.name.includes(modelName))) {
|
||||
file = {
|
||||
...file,
|
||||
// 该类型并不在OpenAI定义中,但符合sdk规范,强制断言
|
||||
purpose: 'file-extract' as OpenAI.FilePurpose
|
||||
}
|
||||
}
|
||||
try {
|
||||
// 检查文件是否已经上传过
|
||||
const fileMetadata = await window.api.fileService.retrieve(provider, file.id)
|
||||
if (fileMetadata.status === 'success' && fileMetadata.originalFile?.file) {
|
||||
// 断言OpenAIFile对象
|
||||
const remoteFile = fileMetadata.originalFile.file as OpenAI.Files.FileObject
|
||||
// 判断用途是否一致
|
||||
if (remoteFile.purpose !== file.purpose) {
|
||||
logger.warn(`File ${file.origin_name} purpose mismatch: ${remoteFile.purpose} vs ${file.purpose}`)
|
||||
throw new Error('File purpose mismatch')
|
||||
}
|
||||
return {
|
||||
type: 'file',
|
||||
filename: file.origin_name,
|
||||
mediaType: '',
|
||||
data: `fileid://${remoteFile.id}`
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to retrieve file ${file.origin_name}:`, error as Error)
|
||||
return null
|
||||
}
|
||||
try {
|
||||
// 如果文件未上传,执行上传
|
||||
const uploadResult = await window.api.fileService.upload(provider, file)
|
||||
if (uploadResult.originalFile?.file) {
|
||||
// 断言OpenAIFile对象
|
||||
const remoteFile = uploadResult.originalFile.file as OpenAI.Files.FileObject
|
||||
logger.info(`File ${file.origin_name} uploaded.`)
|
||||
return {
|
||||
type: 'file',
|
||||
filename: remoteFile.filename,
|
||||
mediaType: '',
|
||||
data: `fileid://${remoteFile.id}`
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to upload file ${file.origin_name}:`, error as Error)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 大文件上传路由函数
|
||||
*/
|
||||
export async function handleLargeFileUpload(
|
||||
file: FileMetadata,
|
||||
model: Model
|
||||
): Promise<(FilePart & { id?: string }) | null> {
|
||||
const provider = getProviderByModel(model)
|
||||
const aiSdkId = getAiSdkProviderId(provider)
|
||||
|
||||
if (['google', 'google-generative-ai', 'google-vertex'].includes(aiSdkId)) {
|
||||
return await handleGeminiFileUpload(file, model)
|
||||
}
|
||||
|
||||
if (provider.type === 'openai') {
|
||||
return await handleOpenAILargeFileUpload(file, model)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 将文件块转换为FilePart(用于原生文件支持)
|
||||
*/
|
||||
@ -127,7 +208,7 @@ export async function convertFileBlockToFilePart(fileBlock: FileMessageBlock, mo
|
||||
// 如果支持大文件上传(如Gemini File API),尝试上传
|
||||
if (supportsLargeFileUpload(model)) {
|
||||
logger.info(`Large PDF file ${file.origin_name} (${file.size} bytes) attempting File API upload`)
|
||||
const uploadResult = await handleGeminiFileUpload(file, model)
|
||||
const uploadResult = await handleLargeFileUpload(file, model)
|
||||
if (uploadResult) {
|
||||
return uploadResult
|
||||
}
|
||||
|
||||
@ -13,7 +13,15 @@ import {
|
||||
findThinkingBlocks,
|
||||
getMainTextContent
|
||||
} from '@renderer/utils/messageUtils/find'
|
||||
import type { AssistantModelMessage, FilePart, ImagePart, ModelMessage, TextPart, UserModelMessage } from 'ai'
|
||||
import type {
|
||||
AssistantModelMessage,
|
||||
FilePart,
|
||||
ImagePart,
|
||||
ModelMessage,
|
||||
SystemModelMessage,
|
||||
TextPart,
|
||||
UserModelMessage
|
||||
} from 'ai'
|
||||
|
||||
import { convertFileBlockToFilePart, convertFileBlockToTextPart } from './fileProcessor'
|
||||
|
||||
@ -27,7 +35,7 @@ export async function convertMessageToSdkParam(
|
||||
message: Message,
|
||||
isVisionModel = false,
|
||||
model?: Model
|
||||
): Promise<ModelMessage> {
|
||||
): Promise<ModelMessage | ModelMessage[]> {
|
||||
const content = getMainTextContent(message)
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
@ -48,7 +56,7 @@ async function convertMessageToUserModelMessage(
|
||||
imageBlocks: ImageMessageBlock[],
|
||||
isVisionModel = false,
|
||||
model?: Model
|
||||
): Promise<UserModelMessage> {
|
||||
): Promise<UserModelMessage | (UserModelMessage | SystemModelMessage)[]> {
|
||||
const parts: Array<TextPart | FilePart | ImagePart> = []
|
||||
if (content) {
|
||||
parts.push({ type: 'text', text: content })
|
||||
@ -85,6 +93,19 @@ async function convertMessageToUserModelMessage(
|
||||
if (model) {
|
||||
const filePart = await convertFileBlockToFilePart(fileBlock, model)
|
||||
if (filePart) {
|
||||
// 判断filePart是否为string
|
||||
if (typeof filePart.data === 'string' && filePart.data.startsWith('fileid://')) {
|
||||
return [
|
||||
{
|
||||
role: 'system',
|
||||
content: filePart.data
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: parts.length > 0 ? parts : ''
|
||||
}
|
||||
]
|
||||
}
|
||||
parts.push(filePart)
|
||||
logger.debug(`File ${file.origin_name} processed as native file format`)
|
||||
processed = true
|
||||
@ -159,7 +180,7 @@ export async function convertMessagesToSdkMessages(messages: Message[], model: M
|
||||
|
||||
for (const message of messages) {
|
||||
const sdkMessage = await convertMessageToSdkParam(message, isVision, model)
|
||||
sdkMessages.push(sdkMessage)
|
||||
sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage]))
|
||||
}
|
||||
|
||||
return sdkMessages
|
||||
|
||||
@ -10,26 +10,61 @@ import { FileTypes } from '@renderer/types'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
|
||||
// 工具函数:基于模型名和提供商判断是否支持某特性
|
||||
function modelSupportValidator(
|
||||
model: Model,
|
||||
{
|
||||
supportedModels = [],
|
||||
unsupportedModels = [],
|
||||
supportedProviders = [],
|
||||
unsupportedProviders = []
|
||||
}: {
|
||||
supportedModels?: string[]
|
||||
unsupportedModels?: string[]
|
||||
supportedProviders?: string[]
|
||||
unsupportedProviders?: string[]
|
||||
}
|
||||
): boolean {
|
||||
const provider = getProviderByModel(model)
|
||||
const aiSdkId = getAiSdkProviderId(provider)
|
||||
|
||||
// 黑名单:命中不支持的模型直接拒绝
|
||||
if (unsupportedModels.some((name) => model.name.includes(name))) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 黑名单:命中不支持的提供商直接拒绝,常用于某些提供商的同名模型并不具备原模型的某些特性
|
||||
if (unsupportedProviders.includes(aiSdkId)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 白名单:命中支持的模型名
|
||||
if (supportedModels.some((name) => model.name.includes(name))) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 回退到提供商判断
|
||||
return supportedProviders.includes(aiSdkId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查模型是否支持原生PDF输入
|
||||
*/
|
||||
export function supportsPdfInput(model: Model): boolean {
|
||||
// 基于AI SDK文档,这些提供商支持PDF输入
|
||||
const supportedProviders = [
|
||||
'openai',
|
||||
'azure-openai',
|
||||
'anthropic',
|
||||
'google',
|
||||
'google-generative-ai',
|
||||
'google-vertex',
|
||||
'bedrock',
|
||||
'amazon-bedrock'
|
||||
]
|
||||
|
||||
const provider = getProviderByModel(model)
|
||||
const aiSdkId = getAiSdkProviderId(provider)
|
||||
|
||||
return supportedProviders.some((provider) => aiSdkId === provider)
|
||||
// 基于AI SDK文档,以下模型或提供商支持PDF输入
|
||||
return modelSupportValidator(model, {
|
||||
supportedModels: ['qwen-long', 'qwen-doc'],
|
||||
supportedProviders: [
|
||||
'openai',
|
||||
'azure-openai',
|
||||
'anthropic',
|
||||
'google',
|
||||
'google-generative-ai',
|
||||
'google-vertex',
|
||||
'bedrock',
|
||||
'amazon-bedrock'
|
||||
]
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
@ -43,11 +78,11 @@ export function supportsImageInput(model: Model): boolean {
|
||||
* 检查提供商是否支持大文件上传(如Gemini File API)
|
||||
*/
|
||||
export function supportsLargeFileUpload(model: Model): boolean {
|
||||
const provider = getProviderByModel(model)
|
||||
const aiSdkId = getAiSdkProviderId(provider)
|
||||
|
||||
// 目前主要是Gemini系列支持大文件上传
|
||||
return ['google', 'google-generative-ai', 'google-vertex'].includes(aiSdkId)
|
||||
// 基于AI SDK文档,以下模型或提供商支持大文件上传
|
||||
return modelSupportValidator(model, {
|
||||
supportedModels: ['qwen-long', 'qwen-doc'],
|
||||
supportedProviders: ['google', 'google-generative-ai', 'google-vertex']
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
@ -67,6 +102,11 @@ export function getFileSizeLimit(model: Model, fileType: FileTypes): number {
|
||||
return 20 * 1024 * 1024 // 20MB
|
||||
}
|
||||
|
||||
// Dashscope如果模型支持大文件上传优先使用File API上传
|
||||
if (aiSdkId === 'dashscope' && supportsLargeFileUpload(model)) {
|
||||
return 0 // 使用较小的默认值
|
||||
}
|
||||
|
||||
// 其他提供商没有明确限制,使用较大的默认值
|
||||
// 这与Legacy架构中的实现一致,让提供商自行处理文件大小
|
||||
return Infinity
|
||||
|
||||
@ -3,23 +3,28 @@
|
||||
* 构建AI SDK的流式和非流式参数
|
||||
*/
|
||||
|
||||
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
|
||||
import { vertex } from '@ai-sdk/google-vertex/edge'
|
||||
import { loggerService } from '@logger'
|
||||
import {
|
||||
isGenerateImageModel,
|
||||
isOpenRouterBuiltInWebSearchModel,
|
||||
isReasoningModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedThinkingTokenClaudeModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, MCPTool, Provider } from '@renderer/types'
|
||||
import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import type { ModelMessage } from 'ai'
|
||||
import { stepCountIs } from 'ai'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import { setupToolsConfig } from '../utils/mcp'
|
||||
import { buildProviderOptions } from '../utils/options'
|
||||
import { getAnthropicThinkingBudget } from '../utils/reasoning'
|
||||
import { getTemperature, getTopP } from './modelParameters'
|
||||
|
||||
const logger = loggerService.withContext('parameterBuilder')
|
||||
@ -54,8 +59,9 @@ export async function buildStreamTextParams(
|
||||
const { mcpTools } = options
|
||||
|
||||
const model = assistant.model || getDefaultModel()
|
||||
const aiSdkProviderId = getAiSdkProviderId(provider)
|
||||
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
let { maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
// 这三个变量透传出来,交给下面启用插件/中间件
|
||||
// 也可以在外部构建好再传入buildStreamTextParams
|
||||
@ -65,17 +71,20 @@ export async function buildStreamTextParams(
|
||||
assistant.settings?.reasoning_effort !== undefined) ||
|
||||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
|
||||
|
||||
// 判断是否使用内置搜索
|
||||
// 条件:没有外部搜索提供商 && (用户开启了内置搜索 || 模型强制使用内置搜索)
|
||||
const hasExternalSearch = !!options.webSearchProviderId
|
||||
const enableWebSearch =
|
||||
(assistant.enableWebSearch && isWebSearchModel(model)) ||
|
||||
isOpenRouterBuiltInWebSearchModel(model) ||
|
||||
model.id.includes('sonar') ||
|
||||
false
|
||||
!hasExternalSearch &&
|
||||
((assistant.enableWebSearch && isWebSearchModel(model)) ||
|
||||
isOpenRouterBuiltInWebSearchModel(model) ||
|
||||
model.id.includes('sonar'))
|
||||
|
||||
const enableUrlContext = assistant.enableUrlContext || false
|
||||
|
||||
const enableGenerateImage = !!(isGenerateImageModel(model) && assistant.enableGenerateImage)
|
||||
|
||||
const tools = setupToolsConfig(mcpTools)
|
||||
let tools = setupToolsConfig(mcpTools)
|
||||
|
||||
// if (webSearchProviderId) {
|
||||
// tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
|
||||
@ -88,6 +97,36 @@ export async function buildStreamTextParams(
|
||||
enableGenerateImage
|
||||
})
|
||||
|
||||
// NOTE: ai-sdk会把maxToken和budgetToken加起来
|
||||
if (
|
||||
enableReasoning &&
|
||||
maxTokens !== undefined &&
|
||||
isSupportedThinkingTokenClaudeModel(model) &&
|
||||
(provider.type === 'anthropic' || provider.type === 'aws-bedrock')
|
||||
) {
|
||||
maxTokens -= getAnthropicThinkingBudget(assistant, model)
|
||||
}
|
||||
|
||||
// google-vertex | google-vertex-anthropic
|
||||
if (enableWebSearch) {
|
||||
if (!tools) {
|
||||
tools = {}
|
||||
}
|
||||
if (aiSdkProviderId === 'google-vertex') {
|
||||
tools.google_search = vertex.tools.googleSearch({})
|
||||
} else if (aiSdkProviderId === 'google-vertex-anthropic') {
|
||||
tools.web_search = vertexAnthropic.tools.webSearch_20250305({})
|
||||
}
|
||||
}
|
||||
|
||||
// google-vertex
|
||||
if (enableUrlContext && aiSdkProviderId === 'google-vertex') {
|
||||
if (!tools) {
|
||||
tools = {}
|
||||
}
|
||||
tools.url_context = vertex.tools.urlContext({})
|
||||
}
|
||||
|
||||
// 构建基础参数
|
||||
const params: StreamTextParams = {
|
||||
messages: sdkMessages,
|
||||
@ -97,10 +136,12 @@ export async function buildStreamTextParams(
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers: options.requestOptions?.headers,
|
||||
providerOptions,
|
||||
tools,
|
||||
stopWhen: stepCountIs(10),
|
||||
maxRetries: 0
|
||||
}
|
||||
if (tools) {
|
||||
params.tools = tools
|
||||
}
|
||||
if (assistant.prompt) {
|
||||
params.system = assistant.prompt
|
||||
}
|
||||
|
||||
@ -15,7 +15,7 @@ import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useV
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { loggerService } from '@renderer/services/LoggerService'
|
||||
import store from '@renderer/store'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import { isSystemProvider, type Model, type Provider } from '@renderer/types'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { cloneDeep, isEmpty } from 'lodash'
|
||||
|
||||
@ -61,14 +61,16 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
|
||||
// return createVertexProvider(provider)
|
||||
// }
|
||||
|
||||
if (provider.id === 'aihubmix') {
|
||||
return aihubmixProviderCreator(model, provider)
|
||||
}
|
||||
if (provider.id === 'newapi') {
|
||||
return newApiResolverCreator(model, provider)
|
||||
}
|
||||
if (provider.id === 'vertexai') {
|
||||
return vertexAnthropicProviderCreator(model, provider)
|
||||
if (isSystemProvider(provider)) {
|
||||
if (provider.id === 'aihubmix') {
|
||||
return aihubmixProviderCreator(model, provider)
|
||||
}
|
||||
if (provider.id === 'new-api') {
|
||||
return newApiResolverCreator(model, provider)
|
||||
}
|
||||
if (provider.id === 'vertexai') {
|
||||
return vertexAnthropicProviderCreator(model, provider)
|
||||
}
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
@ -39,6 +39,14 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
|
||||
creatorFunctionName: 'createAmazonBedrock',
|
||||
supportsImageGeneration: true,
|
||||
aliases: ['aws-bedrock']
|
||||
},
|
||||
{
|
||||
id: 'perplexity',
|
||||
name: 'Perplexity',
|
||||
import: () => import('@ai-sdk/perplexity'),
|
||||
creatorFunctionName: 'createPerplexity',
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['perplexity']
|
||||
}
|
||||
] as const
|
||||
|
||||
|
||||
@ -23,8 +23,6 @@ export const knowledgeSearchTool = (
|
||||
Pre-extracted search queries: "${extractedKeywords.question.join(', ')}"
|
||||
Rewritten query: "${extractedKeywords.rewrite}"
|
||||
|
||||
This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.
|
||||
|
||||
Call this tool to execute the search. You can optionally provide additional context to refine the search.`,
|
||||
|
||||
inputSchema: z.object({
|
||||
@ -35,99 +33,102 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
}),
|
||||
|
||||
execute: async ({ additionalContext }) => {
|
||||
try {
|
||||
// 获取助手的知识库配置
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
// try {
|
||||
// 获取助手的知识库配置
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
|
||||
// 检查是否有知识库
|
||||
if (!hasKnowledgeBase) {
|
||||
return {
|
||||
summary: 'No knowledge base configured for this assistant.',
|
||||
knowledgeReferences: [],
|
||||
instructions: ''
|
||||
// 检查是否有知识库
|
||||
if (!hasKnowledgeBase) {
|
||||
return []
|
||||
}
|
||||
|
||||
let finalQueries = [...extractedKeywords.question]
|
||||
let finalRewrite = extractedKeywords.rewrite
|
||||
|
||||
if (additionalContext?.trim()) {
|
||||
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||
const cleanContext = additionalContext.trim()
|
||||
if (cleanContext) {
|
||||
finalQueries = [cleanContext]
|
||||
finalRewrite = cleanContext
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return []
|
||||
}
|
||||
|
||||
// 构建搜索条件
|
||||
let searchCriteria: { question: string[]; rewrite: string }
|
||||
|
||||
if (knowledgeRecognition === 'off') {
|
||||
// 直接模式:使用用户消息内容
|
||||
const directContent = userMessage || finalQueries[0] || 'search'
|
||||
searchCriteria = {
|
||||
question: [directContent],
|
||||
rewrite: directContent
|
||||
}
|
||||
} else {
|
||||
// 自动模式:使用意图识别的结果
|
||||
searchCriteria = {
|
||||
question: finalQueries,
|
||||
rewrite: finalRewrite
|
||||
}
|
||||
}
|
||||
|
||||
// 构建 ExtractResults 对象
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: undefined,
|
||||
knowledge: searchCriteria
|
||||
}
|
||||
|
||||
// 执行知识库搜索
|
||||
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds, topicId)
|
||||
const knowledgeReferencesData = knowledgeReferences.map((ref: KnowledgeReference) => ({
|
||||
id: ref.id,
|
||||
content: ref.content,
|
||||
sourceUrl: ref.sourceUrl,
|
||||
type: ref.type,
|
||||
file: ref.file,
|
||||
metadata: ref.metadata
|
||||
}))
|
||||
|
||||
// TODO 在工具函数中添加搜索缓存机制
|
||||
// const searchCacheKey = `${topicId}-${JSON.stringify(finalQueries)}`
|
||||
|
||||
// 返回结果
|
||||
return knowledgeReferencesData
|
||||
},
|
||||
toModelOutput: (results) => {
|
||||
let summary = 'No search needed based on the query analysis.'
|
||||
if (results.length > 0) {
|
||||
summary = `Found ${results.length} relevant sources. Use [number] format to cite specific information.`
|
||||
}
|
||||
const referenceContent = `\`\`\`json\n${JSON.stringify(results, null, 2)}\n\`\`\``
|
||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||
'{question}',
|
||||
"Based on the knowledge references, please answer the user's question with proper citations."
|
||||
).replace('{references}', referenceContent)
|
||||
|
||||
return {
|
||||
type: 'content',
|
||||
value: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.'
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: summary
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: fullInstructions
|
||||
}
|
||||
}
|
||||
|
||||
let finalQueries = [...extractedKeywords.question]
|
||||
let finalRewrite = extractedKeywords.rewrite
|
||||
|
||||
if (additionalContext?.trim()) {
|
||||
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||
const cleanContext = additionalContext.trim()
|
||||
if (cleanContext) {
|
||||
finalQueries = [cleanContext]
|
||||
finalRewrite = cleanContext
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return {
|
||||
summary: 'No search needed based on the query analysis.',
|
||||
knowledgeReferences: [],
|
||||
instructions: ''
|
||||
}
|
||||
}
|
||||
|
||||
// 构建搜索条件
|
||||
let searchCriteria: { question: string[]; rewrite: string }
|
||||
|
||||
if (knowledgeRecognition === 'off') {
|
||||
// 直接模式:使用用户消息内容
|
||||
const directContent = userMessage || finalQueries[0] || 'search'
|
||||
searchCriteria = {
|
||||
question: [directContent],
|
||||
rewrite: directContent
|
||||
}
|
||||
} else {
|
||||
// 自动模式:使用意图识别的结果
|
||||
searchCriteria = {
|
||||
question: finalQueries,
|
||||
rewrite: finalRewrite
|
||||
}
|
||||
}
|
||||
|
||||
// 构建 ExtractResults 对象
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: undefined,
|
||||
knowledge: searchCriteria
|
||||
}
|
||||
|
||||
// 执行知识库搜索
|
||||
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds, topicId)
|
||||
const knowledgeReferencesData = knowledgeReferences.map((ref: KnowledgeReference) => ({
|
||||
id: ref.id,
|
||||
content: ref.content,
|
||||
sourceUrl: ref.sourceUrl,
|
||||
type: ref.type,
|
||||
file: ref.file,
|
||||
metadata: ref.metadata
|
||||
}))
|
||||
|
||||
// const referenceContent = `\`\`\`json\n${JSON.stringify(knowledgeReferencesData, null, 2)}\n\`\`\``
|
||||
// TODO 在工具函数中添加搜索缓存机制
|
||||
// const searchCacheKey = `${topicId}-${JSON.stringify(finalQueries)}`
|
||||
// 可以在插件层面管理已搜索的查询,避免重复搜索
|
||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||
'{question}',
|
||||
"Based on the knowledge references, please answer the user's question with proper citations."
|
||||
).replace('{references}', 'knowledgeReferences:')
|
||||
|
||||
// 返回结果
|
||||
return {
|
||||
summary: `Found ${knowledgeReferencesData.length} relevant sources. Use [number] format to cite specific information.`,
|
||||
knowledgeReferences: knowledgeReferencesData,
|
||||
instructions: fullInstructions
|
||||
}
|
||||
} catch (error) {
|
||||
// 返回空对象而不是抛出错误,避免中断对话流程
|
||||
return {
|
||||
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
knowledgeReferences: [],
|
||||
instructions: ''
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import store from '@renderer/store'
|
||||
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
@ -19,133 +18,29 @@ export const memorySearchTool = () => {
|
||||
limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
|
||||
}),
|
||||
execute: async ({ query, limit = 5 }) => {
|
||||
// console.log('🧠 [memorySearchTool] Searching memories:', { query, limit })
|
||||
|
||||
try {
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
if (!globalMemoryEnabled) {
|
||||
return []
|
||||
}
|
||||
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
||||
// console.warn('Memory search skipped: embedding or LLM model not configured')
|
||||
return []
|
||||
}
|
||||
|
||||
const currentUserId = selectCurrentUserId(store.getState())
|
||||
const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, 'default', currentUserId)
|
||||
|
||||
const memoryProcessor = new MemoryProcessor()
|
||||
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
|
||||
|
||||
if (relevantMemories?.length > 0) {
|
||||
// console.log('🧠 [memorySearchTool] Found memories:', relevantMemories.length)
|
||||
return relevantMemories
|
||||
}
|
||||
return []
|
||||
} catch (error) {
|
||||
// console.error('🧠 [memorySearchTool] Error:', error)
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
if (!globalMemoryEnabled) {
|
||||
return []
|
||||
}
|
||||
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
||||
return []
|
||||
}
|
||||
|
||||
const currentUserId = selectCurrentUserId(store.getState())
|
||||
const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, 'default', currentUserId)
|
||||
|
||||
const memoryProcessor = new MemoryProcessor()
|
||||
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
|
||||
|
||||
if (relevantMemories?.length > 0) {
|
||||
return relevantMemories
|
||||
}
|
||||
return []
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 方案4: 为第二个工具也使用类型断言
|
||||
type MessageRole = 'user' | 'assistant' | 'system'
|
||||
type MessageType = {
|
||||
content: string
|
||||
role: MessageRole
|
||||
}
|
||||
type MemorySearchWithExtractionInput = {
|
||||
userMessage: MessageType
|
||||
lastAnswer?: MessageType
|
||||
}
|
||||
|
||||
/**
|
||||
* 🧠 智能记忆搜索工具(带上下文提取)
|
||||
* 从用户消息和对话历史中自动提取关键词进行记忆搜索
|
||||
*/
|
||||
export const memorySearchToolWithExtraction = (assistant: Assistant) => {
|
||||
return tool({
|
||||
name: 'memory_search_with_extraction',
|
||||
description: 'Search memories with automatic keyword extraction from conversation context',
|
||||
inputSchema: z.object({
|
||||
userMessage: z.object({
|
||||
content: z.string().describe('The main content of the user message'),
|
||||
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
}),
|
||||
lastAnswer: z
|
||||
.object({
|
||||
content: z.string().describe('The main content of the last assistant response'),
|
||||
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
})
|
||||
.optional()
|
||||
}) as z.ZodSchema<MemorySearchWithExtractionInput>,
|
||||
execute: async ({ userMessage }) => {
|
||||
// console.log('🧠 [memorySearchToolWithExtraction] Processing:', { userMessage, lastAnswer })
|
||||
|
||||
try {
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||
return {
|
||||
extractedKeywords: 'Memory search disabled',
|
||||
searchResults: []
|
||||
}
|
||||
}
|
||||
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
||||
// console.warn('Memory search skipped: embedding or LLM model not configured')
|
||||
return {
|
||||
extractedKeywords: 'Memory models not configured',
|
||||
searchResults: []
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 使用用户消息内容作为搜索关键词
|
||||
const content = userMessage.content
|
||||
|
||||
if (!content) {
|
||||
return {
|
||||
extractedKeywords: 'No content to search',
|
||||
searchResults: []
|
||||
}
|
||||
}
|
||||
|
||||
const currentUserId = selectCurrentUserId(store.getState())
|
||||
const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, assistant.id, currentUserId)
|
||||
|
||||
const memoryProcessor = new MemoryProcessor()
|
||||
const relevantMemories = await memoryProcessor.searchRelevantMemories(
|
||||
content,
|
||||
processorConfig,
|
||||
5 // Limit to top 5 most relevant memories
|
||||
)
|
||||
|
||||
if (relevantMemories?.length > 0) {
|
||||
// console.log('🧠 [memorySearchToolWithExtraction] Found memories:', relevantMemories.length)
|
||||
return {
|
||||
extractedKeywords: content,
|
||||
searchResults: relevantMemories
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
extractedKeywords: content,
|
||||
searchResults: []
|
||||
}
|
||||
} catch (error) {
|
||||
// console.error('🧠 [memorySearchToolWithExtraction] Error:', error)
|
||||
return {
|
||||
extractedKeywords: 'Search failed',
|
||||
searchResults: []
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
export type MemorySearchToolInput = InferToolInput<ReturnType<typeof memorySearchTool>>
|
||||
export type MemorySearchToolOutput = InferToolOutput<ReturnType<typeof memorySearchTool>>
|
||||
export type MemorySearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof memorySearchToolWithExtraction>>
|
||||
|
||||
@ -30,8 +30,6 @@ Relevant links: ${extractedKeywords.links.join(', ')}`
|
||||
: ''
|
||||
}
|
||||
|
||||
This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.
|
||||
|
||||
Call this tool to execute the search. You can optionally provide additional context to refine the search.`,
|
||||
|
||||
inputSchema: z.object({
|
||||
@ -58,40 +56,27 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
}
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return {
|
||||
summary: 'No search needed based on the query analysis.',
|
||||
searchResults,
|
||||
sources: '',
|
||||
instructions: ''
|
||||
}
|
||||
return searchResults
|
||||
}
|
||||
|
||||
try {
|
||||
// 构建 ExtractResults 结构用于 processWebsearch
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: {
|
||||
question: finalQueries,
|
||||
links: extractedKeywords.links
|
||||
}
|
||||
}
|
||||
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
} catch (error) {
|
||||
return {
|
||||
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
sources: [],
|
||||
instructions: ''
|
||||
// 构建 ExtractResults 结构用于 processWebsearch
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: {
|
||||
question: finalQueries,
|
||||
links: extractedKeywords.links
|
||||
}
|
||||
}
|
||||
if (searchResults.results.length === 0) {
|
||||
return {
|
||||
summary: 'No search results found for the given query.',
|
||||
sources: [],
|
||||
instructions: ''
|
||||
}
|
||||
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
|
||||
return searchResults
|
||||
},
|
||||
toModelOutput: (results) => {
|
||||
let summary = 'No search needed based on the query analysis.'
|
||||
if (results.query && results.results.length > 0) {
|
||||
summary = `Found ${results.results.length} relevant sources. Use [number] format to cite specific information.`
|
||||
}
|
||||
|
||||
const results = searchResults.results
|
||||
const citationData = results.map((result, index) => ({
|
||||
const citationData = results.results.map((result, index) => ({
|
||||
number: index + 1,
|
||||
title: result.title,
|
||||
content: result.content,
|
||||
@ -99,18 +84,27 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
}))
|
||||
|
||||
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
|
||||
// const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
||||
|
||||
// 构建完整的引用指导文本
|
||||
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||
'{question}',
|
||||
"Based on the search results, please answer the user's question with proper citations."
|
||||
).replace('{references}', 'searchResults:')
|
||||
|
||||
).replace('{references}', referenceContent)
|
||||
return {
|
||||
summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`,
|
||||
searchResults,
|
||||
instructions: fullInstructions
|
||||
type: 'content',
|
||||
value: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.'
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: summary
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: fullInstructions
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import { loggerService } from '@logger'
|
||||
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
||||
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools'
|
||||
import { requestToolConfirmation } from '@renderer/utils/userConfirmation'
|
||||
import { type Tool, type ToolSet } from 'ai'
|
||||
@ -33,8 +31,36 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
|
||||
tools[mcpTool.name] = tool({
|
||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||
execute: async (params, { toolCallId, experimental_context }) => {
|
||||
const { onChunk } = experimental_context as { onChunk: (chunk: Chunk) => void }
|
||||
execute: async (params, { toolCallId }) => {
|
||||
// 检查是否启用自动批准
|
||||
const server = getMcpServerByTool(mcpTool)
|
||||
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
|
||||
|
||||
let confirmed = true
|
||||
|
||||
if (!isAutoApproveEnabled) {
|
||||
// 请求用户确认
|
||||
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
|
||||
confirmed = await requestToolConfirmation(toolCallId)
|
||||
}
|
||||
|
||||
if (!confirmed) {
|
||||
// 用户拒绝执行工具
|
||||
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `User declined to execute tool "${mcpTool.name}".`
|
||||
}
|
||||
],
|
||||
isError: false
|
||||
}
|
||||
}
|
||||
|
||||
// 用户确认或自动批准,执行工具
|
||||
logger.debug(`Executing tool: ${mcpTool.name}`)
|
||||
|
||||
// 创建适配的 MCPToolResponse 对象
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: toolCallId,
|
||||
@ -44,53 +70,18 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
|
||||
toolCallId
|
||||
}
|
||||
|
||||
try {
|
||||
// 检查是否启用自动批准
|
||||
const server = getMcpServerByTool(mcpTool)
|
||||
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
|
||||
const result = await callMCPTool(toolResponse)
|
||||
|
||||
let confirmed = true
|
||||
if (!isAutoApproveEnabled) {
|
||||
// 请求用户确认
|
||||
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
|
||||
confirmed = await requestToolConfirmation(toolResponse.id)
|
||||
}
|
||||
|
||||
if (!confirmed) {
|
||||
// 用户拒绝执行工具
|
||||
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `User declined to execute tool "${mcpTool.name}".`
|
||||
}
|
||||
],
|
||||
isError: false
|
||||
}
|
||||
}
|
||||
|
||||
// 用户确认或自动批准,执行工具
|
||||
toolResponse.status = 'invoking'
|
||||
logger.debug(`Executing tool: ${mcpTool.name}`)
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
|
||||
const result = await callMCPTool(toolResponse)
|
||||
|
||||
// 返回结果,AI SDK 会处理序列化
|
||||
if (result.isError) {
|
||||
throw new Error(result.content?.[0]?.text || 'Tool execution failed')
|
||||
}
|
||||
// 返回工具执行结果
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
|
||||
throw error
|
||||
// 返回结果,AI SDK 会处理序列化
|
||||
if (result.isError) {
|
||||
// throw new Error(result.content?.[0]?.text || 'Tool execution failed')
|
||||
return Promise.reject(result)
|
||||
}
|
||||
// 返回工具执行结果
|
||||
return result
|
||||
// } catch (error) {
|
||||
// logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
|
||||
// }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -120,6 +120,9 @@ export function buildProviderOptions(
|
||||
case 'google-vertex':
|
||||
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
case 'google-vertex-anthropic':
|
||||
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
default:
|
||||
// 对于其他 provider,使用通用的构建逻辑
|
||||
providerSpecificOptions = {
|
||||
@ -137,10 +140,16 @@ export function buildProviderOptions(
|
||||
...providerSpecificOptions,
|
||||
...getCustomParameters(assistant)
|
||||
}
|
||||
// vertex需要映射到google或anthropic
|
||||
const rawProviderKey =
|
||||
{
|
||||
'google-vertex': 'google',
|
||||
'google-vertex-anthropic': 'anthropic'
|
||||
}[rawProviderId] || rawProviderId
|
||||
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
|
||||
return {
|
||||
[rawProviderId]: providerSpecificOptions
|
||||
[rawProviderKey]: providerSpecificOptions
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -310,6 +310,26 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
|
||||
return {}
|
||||
}
|
||||
|
||||
export function getAnthropicThinkingBudget(assistant: Assistant, model: Model): number {
|
||||
const { maxTokens, reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
|
||||
if (maxTokens === undefined || reasoningEffort === undefined) {
|
||||
return 0
|
||||
}
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
const budgetTokens = Math.max(
|
||||
1024,
|
||||
Math.floor(
|
||||
Math.min(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||
findTokenLimit(model.id)?.min!,
|
||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||
)
|
||||
)
|
||||
)
|
||||
return budgetTokens
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Anthropic 推理参数
|
||||
* 从 AnthropicAPIClient 中提取的逻辑
|
||||
@ -331,19 +351,7 @@ export function getAnthropicReasoningParams(assistant: Assistant, model: Model):
|
||||
|
||||
// Claude 推理参数
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
const budgetTokens = Math.max(
|
||||
1024,
|
||||
Math.floor(
|
||||
Math.min(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||
findTokenLimit(model.id)?.min!,
|
||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||
)
|
||||
)
|
||||
)
|
||||
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
|
||||
|
||||
return {
|
||||
thinking: {
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 299 KiB |
60
src/renderer/src/assets/styles/CommandListPopover.css
Normal file
60
src/renderer/src/assets/styles/CommandListPopover.css
Normal file
@ -0,0 +1,60 @@
|
||||
.command-list-popover {
|
||||
/* Base styles are handled inline for theme support */
|
||||
|
||||
/* Arrow styles based on placement */
|
||||
}
|
||||
|
||||
.command-list-popover[data-placement^='bottom'] {
|
||||
transform-origin: top center;
|
||||
animation: slideDownAndFadeIn 0.15s ease-out;
|
||||
}
|
||||
|
||||
.command-list-popover[data-placement^='top'] {
|
||||
transform-origin: bottom center;
|
||||
animation: slideUpAndFadeIn 0.15s ease-out;
|
||||
}
|
||||
|
||||
.command-list-popover[data-placement*='start'] {
|
||||
transform-origin: left center;
|
||||
}
|
||||
|
||||
.command-list-popover[data-placement*='end'] {
|
||||
transform-origin: right center;
|
||||
}
|
||||
|
||||
@keyframes slideDownAndFadeIn {
|
||||
0% {
|
||||
opacity: 0;
|
||||
transform: translateY(-8px) scale(0.95);
|
||||
}
|
||||
100% {
|
||||
opacity: 1;
|
||||
transform: translateY(0) scale(1);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes slideUpAndFadeIn {
|
||||
0% {
|
||||
opacity: 0;
|
||||
transform: translateY(8px) scale(0.95);
|
||||
}
|
||||
100% {
|
||||
opacity: 1;
|
||||
transform: translateY(0) scale(1);
|
||||
}
|
||||
}
|
||||
|
||||
/* Ensure smooth scrolling in virtual list */
|
||||
.command-list-popover .dynamic-virtual-list {
|
||||
scroll-behavior: smooth;
|
||||
}
|
||||
|
||||
/* Better focus indicators */
|
||||
.command-list-popover [data-index] {
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.command-list-popover [data-index]:focus-visible {
|
||||
outline: 2px solid var(--color-primary, #1677ff);
|
||||
outline-offset: -2px;
|
||||
}
|
||||
@ -1,59 +0,0 @@
|
||||
.command-list-popover {
|
||||
// Base styles are handled inline for theme support
|
||||
|
||||
// Arrow styles based on placement
|
||||
&[data-placement^='bottom'] {
|
||||
transform-origin: top center;
|
||||
animation: slideDownAndFadeIn 0.15s ease-out;
|
||||
}
|
||||
|
||||
&[data-placement^='top'] {
|
||||
transform-origin: bottom center;
|
||||
animation: slideUpAndFadeIn 0.15s ease-out;
|
||||
}
|
||||
|
||||
&[data-placement*='start'] {
|
||||
transform-origin: left center;
|
||||
}
|
||||
|
||||
&[data-placement*='end'] {
|
||||
transform-origin: right center;
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes slideDownAndFadeIn {
|
||||
0% {
|
||||
opacity: 0;
|
||||
transform: translateY(-8px) scale(0.95);
|
||||
}
|
||||
100% {
|
||||
opacity: 1;
|
||||
transform: translateY(0) scale(1);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes slideUpAndFadeIn {
|
||||
0% {
|
||||
opacity: 0;
|
||||
transform: translateY(8px) scale(0.95);
|
||||
}
|
||||
100% {
|
||||
opacity: 1;
|
||||
transform: translateY(0) scale(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure smooth scrolling in virtual list
|
||||
.command-list-popover .dynamic-virtual-list {
|
||||
scroll-behavior: smooth;
|
||||
}
|
||||
|
||||
// Better focus indicators
|
||||
.command-list-popover [data-index] {
|
||||
position: relative;
|
||||
|
||||
&:focus-visible {
|
||||
outline: 2px solid var(--color-primary, #1677ff);
|
||||
outline-offset: -2px;
|
||||
}
|
||||
}
|
||||
@ -10,14 +10,14 @@
|
||||
}
|
||||
}
|
||||
|
||||
// 电磁波扩散效果
|
||||
/* 电磁波扩散效果 */
|
||||
.animation-pulse {
|
||||
--pulse-color: 59, 130, 246;
|
||||
--pulse-size: 8px;
|
||||
animation: animation-pulse 1.5s infinite;
|
||||
}
|
||||
|
||||
// Modal动画
|
||||
/* Modal动画 */
|
||||
@keyframes animation-move-down-in {
|
||||
0% {
|
||||
transform: translate3d(0, 100%, 0);
|
||||
@ -54,7 +54,7 @@
|
||||
animation-duration: 0.25s;
|
||||
}
|
||||
|
||||
// 旋转动画
|
||||
/* 旋转动画 */
|
||||
@keyframes animation-rotate {
|
||||
from {
|
||||
transform: rotate(0deg);
|
||||
@ -69,7 +69,7 @@
|
||||
animation: animation-rotate 0.75s linear infinite;
|
||||
}
|
||||
|
||||
// 定位高亮动画
|
||||
/* 定位高亮动画 */
|
||||
@keyframes animation-locate-highlight {
|
||||
0% {
|
||||
background-color: transparent;
|
||||
238
src/renderer/src/assets/styles/ant.css
Normal file
238
src/renderer/src/assets/styles/ant.css
Normal file
@ -0,0 +1,238 @@
|
||||
@import './container.css';
|
||||
|
||||
/* Modal 关闭按钮不应该可拖拽,以确保点击正常 */
|
||||
.ant-modal-close {
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
/* 普通 Drawer 内容不应该可拖拽 */
|
||||
.ant-drawer-content {
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
/* minapp-drawer 有自己的拖拽规则 */
|
||||
|
||||
/* 下拉菜单和弹出框内容不应该可拖拽 */
|
||||
.ant-dropdown,
|
||||
.ant-dropdown-menu,
|
||||
.ant-popover-content,
|
||||
.ant-tooltip-content,
|
||||
.ant-popconfirm {
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
#inputbar {
|
||||
resize: none;
|
||||
}
|
||||
|
||||
.ant-image-preview-switch-left {
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
.ant-btn:not(:disabled):focus-visible {
|
||||
outline: none;
|
||||
}
|
||||
|
||||
/* Align lucide icon in Button */
|
||||
.ant-btn .ant-btn-icon {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.ant-tabs-tabpane:focus-visible {
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.ant-tabs-tab-btn {
|
||||
outline: none !important;
|
||||
}
|
||||
|
||||
.ant-segmented-group {
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.minapp-drawer .ant-drawer-content-wrapper {
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
.minapp-drawer .ant-drawer-header {
|
||||
position: absolute;
|
||||
-webkit-app-region: drag;
|
||||
min-height: calc(var(--navbar-height) + 0.5px);
|
||||
margin-top: -0.5px;
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
.minapp-drawer .ant-drawer-body {
|
||||
padding: 0;
|
||||
margin-top: var(--navbar-height);
|
||||
overflow: hidden;
|
||||
/* 手动展开 @extend #content-container 的内容 */
|
||||
background-color: var(--color-background);
|
||||
}
|
||||
|
||||
.minapp-drawer .minapp-mask {
|
||||
background-color: transparent !important;
|
||||
}
|
||||
|
||||
[navbar-position='left'] .minapp-drawer {
|
||||
max-width: calc(100vw - var(--sidebar-width));
|
||||
}
|
||||
|
||||
[navbar-position='left'] .minapp-drawer .ant-drawer-header {
|
||||
width: calc(100vw - var(--sidebar-width));
|
||||
}
|
||||
|
||||
[navbar-position='top'] .minapp-drawer {
|
||||
max-width: 100vw;
|
||||
}
|
||||
|
||||
[navbar-position='top'] .minapp-drawer .ant-drawer-header {
|
||||
width: 100vw;
|
||||
}
|
||||
|
||||
.ant-drawer-header {
|
||||
/* 普通 drawer header 不应该可拖拽,除非被 minapp-drawer 覆盖 */
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
.message-attachments .ant-upload-list-item:hover {
|
||||
background-color: initial !important;
|
||||
}
|
||||
|
||||
.ant-dropdown-menu .ant-dropdown-menu-sub {
|
||||
max-height: 80vh;
|
||||
width: max-content;
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
border: 0.5px solid var(--color-border);
|
||||
}
|
||||
|
||||
.ant-dropdown {
|
||||
background-color: var(--ant-color-bg-elevated);
|
||||
overflow: hidden;
|
||||
border-radius: var(--ant-border-radius-lg);
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.ant-dropdown .ant-dropdown-menu {
|
||||
max-height: 80vh;
|
||||
overflow-y: auto;
|
||||
border: 0.5px solid var(--color-border);
|
||||
}
|
||||
|
||||
/* Align lucide icon in dropdown menu item extra */
|
||||
.ant-dropdown .ant-dropdown-menu .ant-dropdown-menu-submenu-expand-icon,
|
||||
.ant-dropdown .ant-dropdown-menu .ant-dropdown-menu-item-extra {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.ant-dropdown .ant-dropdown-arrow + .ant-dropdown-menu {
|
||||
border: none;
|
||||
}
|
||||
|
||||
.ant-select-dropdown {
|
||||
border: 0.5px solid var(--color-border);
|
||||
}
|
||||
|
||||
.ant-dropdown-menu-submenu {
|
||||
background-color: var(--ant-color-bg-elevated);
|
||||
overflow: hidden;
|
||||
border-radius: var(--ant-border-radius-lg);
|
||||
}
|
||||
|
||||
.ant-dropdown-menu-submenu .ant-dropdown-menu-submenu-title {
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.ant-popover .ant-popover-inner {
|
||||
border: 0.5px solid var(--color-border);
|
||||
}
|
||||
|
||||
.ant-popover .ant-popover-inner .ant-popover-inner-content {
|
||||
max-height: 70vh;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.ant-popover .ant-popover-arrow + .ant-popover-content .ant-popover-inner {
|
||||
border: none;
|
||||
}
|
||||
|
||||
.ant-modal:not(.ant-modal-confirm) .ant-modal-confirm-body-has-title {
|
||||
padding: 16px 0 0 0;
|
||||
}
|
||||
|
||||
.ant-modal:not(.ant-modal-confirm) .ant-modal-content {
|
||||
border-radius: 10px;
|
||||
border: 0.5px solid var(--color-border);
|
||||
padding: 0 0 8px 0;
|
||||
}
|
||||
|
||||
.ant-modal:not(.ant-modal-confirm) .ant-modal-content .ant-modal-close {
|
||||
margin-right: 2px;
|
||||
}
|
||||
|
||||
.ant-modal:not(.ant-modal-confirm) .ant-modal-content .ant-modal-header {
|
||||
padding: 16px 16px 0 16px;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.ant-modal:not(.ant-modal-confirm) .ant-modal-content .ant-modal-body {
|
||||
/* 保持 body 在视口内,使用标准的最大高度 */
|
||||
max-height: 80vh;
|
||||
overflow-y: auto;
|
||||
padding: 0 16px 0 16px;
|
||||
}
|
||||
|
||||
.ant-modal:not(.ant-modal-confirm) .ant-modal-content .ant-modal-footer {
|
||||
padding: 0 16px 8px 16px;
|
||||
}
|
||||
|
||||
.ant-modal:not(.ant-modal-confirm) .ant-modal-content .ant-modal-confirm-btns {
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.ant-modal.ant-modal-confirm.ant-modal-confirm-confirm .ant-modal-content {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.ant-collapse:not(.ant-collapse-ghost) {
|
||||
border: 1px solid var(--color-border);
|
||||
}
|
||||
|
||||
.ant-color-picker .ant-collapse:not(.ant-collapse-ghost) {
|
||||
border: none;
|
||||
}
|
||||
|
||||
.ant-collapse:not(.ant-collapse-ghost) .ant-collapse-content {
|
||||
border-top: 0.5px solid var(--color-border) !important;
|
||||
}
|
||||
|
||||
.ant-color-picker .ant-collapse:not(.ant-collapse-ghost) .ant-collapse-content {
|
||||
border-top: none !important;
|
||||
}
|
||||
|
||||
.ant-slider .ant-slider-handle::after {
|
||||
box-shadow: 0 1px 4px 0px rgb(128 128 128 / 50%) !important;
|
||||
}
|
||||
|
||||
.ant-splitter-bar .ant-splitter-bar-dragger::before {
|
||||
background-color: var(--color-border) !important;
|
||||
transition:
|
||||
background-color 0.15s ease,
|
||||
width 0.15s ease;
|
||||
}
|
||||
|
||||
.ant-splitter-bar .ant-splitter-bar-dragger:hover::before {
|
||||
width: 4px !important;
|
||||
background-color: var(--color-primary) !important;
|
||||
transition-delay: 0.15s;
|
||||
}
|
||||
|
||||
.ant-splitter-bar .ant-splitter-bar-dragger-active::before {
|
||||
width: 4px !important;
|
||||
background-color: var(--color-primary) !important;
|
||||
}
|
||||
@ -1,234 +0,0 @@
|
||||
@use './container.scss';
|
||||
|
||||
/* Modal 关闭按钮不应该可拖拽,以确保点击正常 */
|
||||
.ant-modal-close {
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
/* 普通 Drawer 内容不应该可拖拽 */
|
||||
.ant-drawer-content {
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
/* minapp-drawer 有自己的拖拽规则 */
|
||||
|
||||
/* 下拉菜单和弹出框内容不应该可拖拽 */
|
||||
.ant-dropdown,
|
||||
.ant-dropdown-menu,
|
||||
.ant-popover-content,
|
||||
.ant-tooltip-content,
|
||||
.ant-popconfirm {
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
#inputbar {
|
||||
resize: none;
|
||||
}
|
||||
|
||||
.ant-image-preview-switch-left {
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
.ant-btn:not(:disabled):focus-visible {
|
||||
outline: none;
|
||||
}
|
||||
|
||||
// Align lucide icon in Button
|
||||
.ant-btn .ant-btn-icon {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.ant-tabs-tabpane:focus-visible {
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.ant-tabs-tab-btn {
|
||||
outline: none !important;
|
||||
}
|
||||
|
||||
.ant-segmented-group {
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.minapp-drawer {
|
||||
[navbar-position='left'] & {
|
||||
max-width: calc(100vw - var(--sidebar-width));
|
||||
.ant-drawer-header {
|
||||
width: calc(100vw - var(--sidebar-width));
|
||||
}
|
||||
}
|
||||
[navbar-position='top'] & {
|
||||
max-width: 100vw;
|
||||
.ant-drawer-header {
|
||||
width: 100vw;
|
||||
}
|
||||
}
|
||||
.ant-drawer-content-wrapper {
|
||||
box-shadow: none;
|
||||
}
|
||||
.ant-drawer-header {
|
||||
position: absolute;
|
||||
-webkit-app-region: drag;
|
||||
min-height: calc(var(--navbar-height) + 0.5px);
|
||||
margin-top: -0.5px;
|
||||
border-bottom: none;
|
||||
}
|
||||
.ant-drawer-body {
|
||||
padding: 0;
|
||||
margin-top: var(--navbar-height);
|
||||
overflow: hidden;
|
||||
@extend #content-container;
|
||||
}
|
||||
.minapp-mask {
|
||||
background-color: transparent !important;
|
||||
}
|
||||
}
|
||||
|
||||
.ant-drawer-header {
|
||||
/* 普通 drawer header 不应该可拖拽,除非被 minapp-drawer 覆盖 */
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
.message-attachments {
|
||||
.ant-upload-list-item:hover {
|
||||
background-color: initial !important;
|
||||
}
|
||||
}
|
||||
|
||||
.ant-dropdown-menu .ant-dropdown-menu-sub {
|
||||
max-height: 80vh;
|
||||
width: max-content;
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
border: 0.5px solid var(--color-border);
|
||||
}
|
||||
.ant-dropdown {
|
||||
background-color: var(--ant-color-bg-elevated);
|
||||
overflow: hidden;
|
||||
border-radius: var(--ant-border-radius-lg);
|
||||
user-select: none;
|
||||
.ant-dropdown-menu {
|
||||
max-height: 80vh;
|
||||
overflow-y: auto;
|
||||
border: 0.5px solid var(--color-border);
|
||||
|
||||
// Align lucide icon in dropdown menu item extra
|
||||
.ant-dropdown-menu-submenu-expand-icon,
|
||||
.ant-dropdown-menu-item-extra {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
}
|
||||
.ant-dropdown-arrow + .ant-dropdown-menu {
|
||||
border: none;
|
||||
}
|
||||
}
|
||||
.ant-select-dropdown {
|
||||
border: 0.5px solid var(--color-border);
|
||||
}
|
||||
.ant-dropdown-menu-submenu {
|
||||
background-color: var(--ant-color-bg-elevated);
|
||||
overflow: hidden;
|
||||
border-radius: var(--ant-border-radius-lg);
|
||||
|
||||
.ant-dropdown-menu-submenu-title {
|
||||
align-items: center;
|
||||
}
|
||||
}
|
||||
|
||||
.ant-popover {
|
||||
.ant-popover-inner {
|
||||
border: 0.5px solid var(--color-border);
|
||||
.ant-popover-inner-content {
|
||||
max-height: 70vh;
|
||||
overflow-y: auto;
|
||||
}
|
||||
}
|
||||
.ant-popover-arrow + .ant-popover-content {
|
||||
.ant-popover-inner {
|
||||
border: none;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.ant-modal:not(.ant-modal-confirm) {
|
||||
.ant-modal-confirm-body-has-title {
|
||||
padding: 16px 0 0 0;
|
||||
}
|
||||
.ant-modal-content {
|
||||
border-radius: 10px;
|
||||
border: 0.5px solid var(--color-border);
|
||||
padding: 0 0 8px 0;
|
||||
.ant-modal-close {
|
||||
margin-right: 2px;
|
||||
}
|
||||
.ant-modal-header {
|
||||
padding: 16px 16px 0 16px;
|
||||
border-radius: 10px;
|
||||
}
|
||||
.ant-modal-body {
|
||||
/* 保持 body 在视口内,使用标准的最大高度 */
|
||||
max-height: 80vh;
|
||||
overflow-y: auto;
|
||||
padding: 0 16px 0 16px;
|
||||
}
|
||||
.ant-modal-footer {
|
||||
padding: 0 16px 8px 16px;
|
||||
}
|
||||
.ant-modal-confirm-btns {
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
}
|
||||
}
|
||||
.ant-modal.ant-modal-confirm.ant-modal-confirm-confirm {
|
||||
.ant-modal-content {
|
||||
padding: 16px;
|
||||
}
|
||||
}
|
||||
|
||||
.ant-collapse:not(.ant-collapse-ghost) {
|
||||
border: 1px solid var(--color-border);
|
||||
.ant-color-picker & {
|
||||
border: none;
|
||||
}
|
||||
.ant-collapse-content {
|
||||
border-top: 0.5px solid var(--color-border) !important;
|
||||
.ant-color-picker & {
|
||||
border-top: none !important;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.ant-slider {
|
||||
.ant-slider-handle::after {
|
||||
box-shadow: 0 1px 4px 0px rgb(128 128 128 / 50%) !important;
|
||||
}
|
||||
}
|
||||
|
||||
.ant-splitter-bar {
|
||||
.ant-splitter-bar-dragger {
|
||||
&::before {
|
||||
background-color: var(--color-border) !important;
|
||||
transition:
|
||||
background-color 0.15s ease,
|
||||
width 0.15s ease;
|
||||
}
|
||||
&:hover {
|
||||
&::before {
|
||||
width: 4px !important;
|
||||
background-color: var(--color-primary) !important;
|
||||
transition-delay: 0.15s;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.ant-splitter-bar-dragger-active {
|
||||
&::before {
|
||||
width: 4px !important;
|
||||
background-color: var(--color-primary) !important;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -19,7 +19,7 @@
|
||||
--color-background-soft: var(--color-black-soft);
|
||||
--color-background-mute: var(--color-black-mute);
|
||||
--color-background-opacity: rgba(34, 34, 34, 0.7);
|
||||
--inner-glow-opacity: 0.3; // For the glassmorphism effect in the dropdown menu
|
||||
--inner-glow-opacity: 0.3; /* For the glassmorphism effect in the dropdown menu */
|
||||
|
||||
--color-primary: #00b96b;
|
||||
--color-primary-soft: #00b96b99;
|
||||
@ -58,16 +58,6 @@
|
||||
--navbar-background-mac: rgba(20, 20, 20, 0.55);
|
||||
--navbar-background: #1f1f1f;
|
||||
|
||||
--navbar-height: 44px;
|
||||
--sidebar-width: 50px;
|
||||
--status-bar-height: 40px;
|
||||
--input-bar-height: 100px;
|
||||
|
||||
--assistants-width: 275px;
|
||||
--topic-list-width: 275px;
|
||||
--settings-width: 250px;
|
||||
--scrollbar-width: 5px;
|
||||
|
||||
--chat-background: transparent;
|
||||
--chat-background-user: rgba(255, 255, 255, 0.08);
|
||||
--chat-background-assistant: transparent;
|
||||
@ -146,11 +136,6 @@
|
||||
--chat-text-user: var(--color-text);
|
||||
}
|
||||
|
||||
[navbar-position='left'] {
|
||||
--navbar-height: 42px;
|
||||
--list-item-border-radius: 20px;
|
||||
}
|
||||
|
||||
[navbar-position='left'][theme-mode='light'] {
|
||||
--color-list-item: #eee;
|
||||
--color-list-item-hover: #f5f5f5;
|
||||
9
src/renderer/src/assets/styles/container.css
Normal file
9
src/renderer/src/assets/styles/container.css
Normal file
@ -0,0 +1,9 @@
|
||||
#content-container {
|
||||
background-color: var(--color-background);
|
||||
}
|
||||
|
||||
[navbar-position='left'] #content-container {
|
||||
border-top: 0.5px solid var(--color-border);
|
||||
border-top-left-radius: 10px;
|
||||
border-left: 0.5px solid var(--color-border);
|
||||
}
|
||||
@ -1,11 +0,0 @@
|
||||
#content-container {
|
||||
background-color: var(--color-background);
|
||||
}
|
||||
|
||||
[navbar-position='left'] {
|
||||
#content-container {
|
||||
border-top: 0.5px solid var(--color-border);
|
||||
border-top-left-radius: 10px;
|
||||
border-left: 0.5px solid var(--color-border);
|
||||
}
|
||||
}
|
||||
24
src/renderer/src/assets/styles/font.css
Normal file
24
src/renderer/src/assets/styles/font.css
Normal file
@ -0,0 +1,24 @@
|
||||
:root {
|
||||
--font-family:
|
||||
var(--user-font-family), Ubuntu, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, Roboto, Oxygen,
|
||||
Cantarell, 'Open Sans', 'Helvetica Neue', Arial, 'Noto Sans', sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji',
|
||||
'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
|
||||
--font-family-serif:
|
||||
serif, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, Ubuntu, Roboto, Oxygen, Cantarell, 'Open Sans',
|
||||
'Helvetica Neue', Arial, 'Noto Sans', 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
|
||||
--code-font-family: var(--user-code-font-family), 'Cascadia Code', 'Fira Code', 'Consolas', Menlo, Courier, monospace;
|
||||
}
|
||||
|
||||
/* Windows系统专用字体配置 */
|
||||
body[os='windows'] {
|
||||
--font-family:
|
||||
var(--user-font-family), 'Twemoji Country Flags', Ubuntu, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui,
|
||||
Roboto, Oxygen, Cantarell, 'Open Sans', 'Helvetica Neue', Arial, 'Noto Sans', sans-serif, 'Apple Color Emoji',
|
||||
'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
|
||||
--code-font-family:
|
||||
var(--user-code-font-family), 'Cascadia Code', 'Fira Code', 'Consolas', 'Sarasa Mono SC', 'Microsoft YaHei UI',
|
||||
Courier, monospace;
|
||||
}
|
||||
@ -1,23 +0,0 @@
|
||||
:root {
|
||||
--font-family:
|
||||
Ubuntu, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, Roboto, Oxygen, Cantarell, 'Open Sans',
|
||||
'Helvetica Neue', Arial, 'Noto Sans', sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol',
|
||||
'Noto Color Emoji';
|
||||
|
||||
--font-family-serif:
|
||||
serif, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, Ubuntu, Roboto, Oxygen, Cantarell, 'Open Sans',
|
||||
'Helvetica Neue', Arial, 'Noto Sans', 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
|
||||
--code-font-family: 'Cascadia Code', 'Fira Code', 'Consolas', Menlo, Courier, monospace;
|
||||
}
|
||||
|
||||
// Windows系统专用字体配置
|
||||
body[os='windows'] {
|
||||
--font-family:
|
||||
'Twemoji Country Flags', Ubuntu, -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, Roboto, Oxygen,
|
||||
Cantarell, 'Open Sans', 'Helvetica Neue', Arial, 'Noto Sans', sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji',
|
||||
'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
|
||||
--code-font-family:
|
||||
'Cascadia Code', 'Fira Code', 'Consolas', 'Sarasa Mono SC', 'Microsoft YaHei UI', Courier, monospace;
|
||||
}
|
||||
@ -1,11 +1,12 @@
|
||||
@use './color.scss';
|
||||
@use './font.scss';
|
||||
@use './markdown.scss';
|
||||
@use './ant.scss';
|
||||
@use './scrollbar.scss';
|
||||
@use './container.scss';
|
||||
@use './animation.scss';
|
||||
@use './richtext.scss';
|
||||
@import './color.css';
|
||||
@import './font.css';
|
||||
@import './markdown.css';
|
||||
@import './ant.css';
|
||||
@import './scrollbar.css';
|
||||
@import './container.css';
|
||||
@import './animation.css';
|
||||
@import './richtext.css';
|
||||
@import './responsive.css';
|
||||
@import '../fonts/icon-fonts/iconfont.css';
|
||||
@import '../fonts/ubuntu/ubuntu.css';
|
||||
@import '../fonts/country-flag-fonts/flag.css';
|
||||
@ -14,7 +15,7 @@
|
||||
*::before,
|
||||
*::after {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
/* margin: 0; */
|
||||
font-weight: normal;
|
||||
}
|
||||
|
||||
@ -34,11 +35,11 @@ body,
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
#root {
|
||||
/* #root {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
flex: 1;
|
||||
}
|
||||
} */
|
||||
|
||||
body {
|
||||
display: flex;
|
||||
@ -50,6 +51,7 @@ body {
|
||||
font-family: var(--font-family);
|
||||
text-rendering: optimizeLegibility;
|
||||
transition: background-color 0.3s linear;
|
||||
background-color: unset;
|
||||
|
||||
-webkit-font-smoothing: antialiased;
|
||||
-moz-osx-font-smoothing: grayscale;
|
||||
@ -113,62 +115,58 @@ ul {
|
||||
word-wrap: break-word;
|
||||
}
|
||||
|
||||
.bubble:not(.multi-select-mode) {
|
||||
.block-wrapper {
|
||||
display: flow-root;
|
||||
}
|
||||
.bubble:not(.multi-select-mode) .block-wrapper {
|
||||
display: flow-root;
|
||||
}
|
||||
|
||||
.block-wrapper:last-child > *:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
.bubble:not(.multi-select-mode) .block-wrapper:last-child > *:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.message-content-container > *:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
.bubble:not(.multi-select-mode) .message-content-container > *:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.message-thought-container {
|
||||
margin-top: 8px;
|
||||
}
|
||||
.bubble:not(.multi-select-mode) .message-thought-container {
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
.message-user {
|
||||
.message-header {
|
||||
flex-direction: row-reverse;
|
||||
text-align: right;
|
||||
.message-header-info-wrap {
|
||||
flex-direction: row-reverse;
|
||||
text-align: right;
|
||||
}
|
||||
}
|
||||
.message-content-container {
|
||||
border-radius: 10px;
|
||||
padding: 10px 16px 10px 16px;
|
||||
background-color: var(--chat-background-user);
|
||||
align-self: self-end;
|
||||
}
|
||||
.MessageFooter {
|
||||
margin-top: 2px;
|
||||
align-self: self-end;
|
||||
}
|
||||
}
|
||||
.bubble:not(.multi-select-mode) .message-user .message-header {
|
||||
flex-direction: row-reverse;
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
.message-assistant {
|
||||
.message-content-container {
|
||||
padding-left: 0;
|
||||
}
|
||||
.MessageFooter {
|
||||
margin-left: 0;
|
||||
}
|
||||
}
|
||||
.bubble:not(.multi-select-mode) .message-user .message-header .message-header-info-wrap {
|
||||
flex-direction: row-reverse;
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
code {
|
||||
color: var(--color-text);
|
||||
}
|
||||
.markdown {
|
||||
display: flow-root;
|
||||
*:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
}
|
||||
.bubble:not(.multi-select-mode) .message-user .message-content-container {
|
||||
border-radius: 10px;
|
||||
padding: 10px 16px 10px 16px;
|
||||
background-color: var(--chat-background-user);
|
||||
align-self: self-end;
|
||||
}
|
||||
|
||||
.bubble:not(.multi-select-mode) .message-user .MessageFooter {
|
||||
margin-top: 2px;
|
||||
align-self: self-end;
|
||||
}
|
||||
|
||||
.bubble:not(.multi-select-mode) .message-assistant .message-content-container {
|
||||
padding-left: 0;
|
||||
}
|
||||
|
||||
.bubble:not(.multi-select-mode) .message-assistant .MessageFooter {
|
||||
margin-left: 0;
|
||||
}
|
||||
|
||||
.bubble:not(.multi-select-mode) code {
|
||||
color: var(--color-text);
|
||||
}
|
||||
|
||||
.bubble:not(.multi-select-mode) .markdown {
|
||||
display: flow-root;
|
||||
}
|
||||
|
||||
.lucide:not(.lucide-custom) {
|
||||
@ -184,8 +182,6 @@ ul {
|
||||
background-color: var(--color-background-highlight-accent);
|
||||
}
|
||||
|
||||
textarea {
|
||||
&::-webkit-resizer {
|
||||
display: none;
|
||||
}
|
||||
textarea::-webkit-resizer {
|
||||
display: none;
|
||||
}
|
||||
388
src/renderer/src/assets/styles/markdown.css
Normal file
388
src/renderer/src/assets/styles/markdown.css
Normal file
@ -0,0 +1,388 @@
|
||||
.markdown {
|
||||
color: var(--color-text);
|
||||
line-height: 1.6;
|
||||
user-select: text;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.markdown h1:first-child,
|
||||
.markdown h2:first-child,
|
||||
.markdown h3:first-child,
|
||||
.markdown h4:first-child,
|
||||
.markdown h5:first-child,
|
||||
.markdown h6:first-child {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.markdown h1,
|
||||
.markdown h2,
|
||||
.markdown h3,
|
||||
.markdown h4,
|
||||
.markdown h5,
|
||||
.markdown h6 {
|
||||
margin: 1.5em 0 1em 0;
|
||||
line-height: 1.3;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.markdown h1 {
|
||||
margin-top: 0;
|
||||
font-size: 2em;
|
||||
border-bottom: 0.5px solid var(--color-border);
|
||||
padding-bottom: 0.3em;
|
||||
}
|
||||
|
||||
.markdown h2 {
|
||||
font-size: 1.5em;
|
||||
border-bottom: 0.5px solid var(--color-border);
|
||||
padding-bottom: 0.3em;
|
||||
}
|
||||
|
||||
.markdown h3 {
|
||||
font-size: 1.2em;
|
||||
}
|
||||
|
||||
.markdown h4 {
|
||||
font-size: 1em;
|
||||
}
|
||||
|
||||
.markdown h5 {
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
.markdown h6 {
|
||||
font-size: 0.8em;
|
||||
}
|
||||
|
||||
.markdown p {
|
||||
margin: 1.3em 0;
|
||||
white-space: pre-wrap;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.markdown p:last-child {
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
|
||||
.markdown p:first-child {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.markdown p:has(+ ul) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.markdown ul {
|
||||
list-style: initial;
|
||||
}
|
||||
|
||||
.markdown ul,
|
||||
.markdown ol {
|
||||
padding-left: 1.5em;
|
||||
margin: 1em 0;
|
||||
}
|
||||
|
||||
.markdown li {
|
||||
margin-bottom: 0.5em;
|
||||
}
|
||||
|
||||
.markdown li pre {
|
||||
margin: 1.5em 0 !important;
|
||||
}
|
||||
|
||||
.markdown li::marker {
|
||||
color: var(--color-text-3);
|
||||
}
|
||||
|
||||
.markdown li > ul,
|
||||
.markdown li > ol {
|
||||
margin: 0.5em 0;
|
||||
}
|
||||
|
||||
.markdown hr {
|
||||
border: none;
|
||||
border-top: 0.5px solid var(--color-border);
|
||||
margin: 20px 0;
|
||||
}
|
||||
|
||||
.markdown span {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.markdown .katex span {
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.markdown p code,
|
||||
.markdown li code {
|
||||
background: var(--color-background-mute);
|
||||
padding: 3px 5px;
|
||||
margin: 0 2px;
|
||||
border-radius: 5px;
|
||||
word-break: keep-all;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.markdown code {
|
||||
font-family: var(--code-font-family);
|
||||
}
|
||||
|
||||
.markdown pre {
|
||||
border-radius: 8px;
|
||||
overflow-x: auto;
|
||||
font-family: var(--code-font-family);
|
||||
background-color: var(--color-background-mute);
|
||||
}
|
||||
|
||||
.markdown pre:has(.special-preview) {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
.markdown pre:not(pre pre) > code:not(pre pre > code) {
|
||||
padding: 15px;
|
||||
display: block;
|
||||
}
|
||||
|
||||
.markdown pre pre {
|
||||
margin: 0 !important;
|
||||
}
|
||||
|
||||
.markdown pre pre code {
|
||||
background: none;
|
||||
padding: 0;
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
.markdown pre + pre {
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.markdown .markdown-alert,
|
||||
.markdown blockquote {
|
||||
margin: 1.5em 0;
|
||||
padding: 1em 1.5em;
|
||||
background-color: var(--color-background-soft);
|
||||
border-left: 4px solid var(--color-primary);
|
||||
border-radius: 0 8px 8px 0;
|
||||
font-style: italic;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.markdown table {
|
||||
--table-border-radius: 8px;
|
||||
margin: 2em 0;
|
||||
font-size: 0.9em;
|
||||
width: 100%;
|
||||
border-radius: var(--table-border-radius);
|
||||
overflow: hidden;
|
||||
border-collapse: separate;
|
||||
border: 0.5px solid var(--color-border);
|
||||
border-spacing: 0;
|
||||
}
|
||||
|
||||
.markdown th,
|
||||
.markdown td {
|
||||
border-right: 0.5px solid var(--color-border);
|
||||
border-bottom: 0.5px solid var(--color-border);
|
||||
padding: 0.5em;
|
||||
}
|
||||
|
||||
.markdown th:last-child,
|
||||
.markdown td:last-child {
|
||||
border-right: none;
|
||||
}
|
||||
|
||||
.markdown tr:last-child td {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
.markdown th {
|
||||
background-color: var(--color-background-mute);
|
||||
font-weight: 600;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.markdown tr:hover {
|
||||
background-color: var(--color-background-soft);
|
||||
}
|
||||
|
||||
.markdown img {
|
||||
max-width: 100%;
|
||||
height: auto;
|
||||
margin: 1em 0;
|
||||
}
|
||||
|
||||
.markdown a,
|
||||
.markdown .link {
|
||||
color: var(--color-link);
|
||||
text-decoration: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.markdown a:hover,
|
||||
.markdown .link:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.markdown strong {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.markdown em {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.markdown del {
|
||||
text-decoration: line-through;
|
||||
}
|
||||
|
||||
.markdown sup,
|
||||
.markdown sub {
|
||||
font-size: 75%;
|
||||
line-height: 0;
|
||||
position: relative;
|
||||
vertical-align: baseline;
|
||||
}
|
||||
|
||||
.markdown sup {
|
||||
top: -0.5em;
|
||||
border-radius: 50%;
|
||||
background-color: var(--color-reference);
|
||||
color: var(--color-reference-text);
|
||||
padding: 2px 5px;
|
||||
zoom: 0.8;
|
||||
}
|
||||
|
||||
.markdown sup > span.link {
|
||||
color: var(--color-reference-text);
|
||||
}
|
||||
|
||||
.markdown sub {
|
||||
bottom: -0.25em;
|
||||
}
|
||||
|
||||
.markdown .footnote-ref {
|
||||
font-size: 0.8em;
|
||||
vertical-align: super;
|
||||
line-height: 0;
|
||||
margin: 0 2px;
|
||||
color: var(--color-primary);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.markdown .footnote-ref:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.footnotes {
|
||||
margin-top: 1em;
|
||||
margin-bottom: 1em;
|
||||
padding-top: 1em;
|
||||
|
||||
background-color: var(--color-reference-background);
|
||||
border-radius: 8px;
|
||||
padding: 8px 12px;
|
||||
}
|
||||
|
||||
.footnotes h4 {
|
||||
margin-bottom: 5px;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.footnotes a {
|
||||
color: var(--color-link);
|
||||
}
|
||||
|
||||
.footnotes ol {
|
||||
padding-left: 1em;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.footnotes ol li:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.footnotes li {
|
||||
font-size: 0.9em;
|
||||
margin-bottom: 0.5em;
|
||||
color: var(--color-text-light);
|
||||
}
|
||||
|
||||
.footnotes li p {
|
||||
display: inline;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.footnotes .footnote-backref {
|
||||
font-size: 0.8em;
|
||||
vertical-align: super;
|
||||
line-height: 0;
|
||||
margin-left: 5px;
|
||||
color: var(--color-primary);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.footnotes .footnote-backref:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
emoji-picker {
|
||||
--border-size: 0;
|
||||
}
|
||||
|
||||
.block-wrapper + .block-wrapper {
|
||||
margin-top: 1em;
|
||||
}
|
||||
|
||||
.katex,
|
||||
mjx-container {
|
||||
display: inline-block;
|
||||
overflow-x: auto;
|
||||
overflow-y: hidden;
|
||||
overflow-wrap: break-word;
|
||||
vertical-align: middle;
|
||||
max-width: 100%;
|
||||
padding: 1px 2px;
|
||||
margin-top: -2px;
|
||||
}
|
||||
|
||||
/* Shiki 相关样式 */
|
||||
.shiki {
|
||||
font-family: var(--code-font-family);
|
||||
/* 保持行高为初始值,在 shiki 代码块中处理 */
|
||||
line-height: initial;
|
||||
}
|
||||
|
||||
/* CodeMirror 相关样式 */
|
||||
.cm-editor {
|
||||
border-radius: inherit;
|
||||
}
|
||||
|
||||
.cm-editor.cm-focused {
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.cm-editor .cm-scroller {
|
||||
font-family: var(--code-font-family);
|
||||
border-radius: inherit;
|
||||
}
|
||||
|
||||
.cm-editor .cm-scroller .cm-gutters {
|
||||
line-height: 1.6;
|
||||
border-right: none;
|
||||
}
|
||||
|
||||
.cm-editor .cm-scroller .cm-content {
|
||||
line-height: 1.6;
|
||||
padding-left: 0.25em;
|
||||
}
|
||||
|
||||
.cm-editor .cm-scroller .cm-lineWrapping * {
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.cm-editor .cm-announced {
|
||||
position: absolute;
|
||||
display: none;
|
||||
}
|
||||
@ -1,379 +0,0 @@
|
||||
.markdown {
|
||||
color: var(--color-text);
|
||||
line-height: 1.6;
|
||||
user-select: text;
|
||||
word-break: break-word;
|
||||
|
||||
h1:first-child,
|
||||
h2:first-child,
|
||||
h3:first-child,
|
||||
h4:first-child,
|
||||
h5:first-child,
|
||||
h6:first-child {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
h1,
|
||||
h2,
|
||||
h3,
|
||||
h4,
|
||||
h5,
|
||||
h6 {
|
||||
margin: 1.5em 0 1em 0;
|
||||
line-height: 1.3;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
h1 {
|
||||
margin-top: 0;
|
||||
font-size: 2em;
|
||||
border-bottom: 0.5px solid var(--color-border);
|
||||
padding-bottom: 0.3em;
|
||||
}
|
||||
|
||||
h2 {
|
||||
font-size: 1.5em;
|
||||
border-bottom: 0.5px solid var(--color-border);
|
||||
padding-bottom: 0.3em;
|
||||
}
|
||||
|
||||
h3 {
|
||||
font-size: 1.2em;
|
||||
}
|
||||
|
||||
h4 {
|
||||
font-size: 1em;
|
||||
}
|
||||
|
||||
h5 {
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
h6 {
|
||||
font-size: 0.8em;
|
||||
}
|
||||
|
||||
p {
|
||||
margin: 1.3em 0;
|
||||
white-space: pre-wrap;
|
||||
line-height: 1.6;
|
||||
|
||||
&:last-child {
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
|
||||
&:first-child {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
&:has(+ ul) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
}
|
||||
|
||||
ul {
|
||||
list-style: initial;
|
||||
}
|
||||
|
||||
ul,
|
||||
ol {
|
||||
padding-left: 1.5em;
|
||||
margin: 1em 0;
|
||||
}
|
||||
|
||||
li {
|
||||
margin-bottom: 0.5em;
|
||||
pre {
|
||||
margin: 1.5em 0 !important;
|
||||
}
|
||||
&::marker {
|
||||
color: var(--color-text-3);
|
||||
}
|
||||
}
|
||||
|
||||
li > ul,
|
||||
li > ol {
|
||||
margin: 0.5em 0;
|
||||
}
|
||||
|
||||
hr {
|
||||
border: none;
|
||||
border-top: 0.5px solid var(--color-border);
|
||||
margin: 20px 0;
|
||||
}
|
||||
|
||||
span {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.katex span {
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
p code,
|
||||
li code {
|
||||
background: var(--color-background-mute);
|
||||
padding: 3px 5px;
|
||||
margin: 0 2px;
|
||||
border-radius: 5px;
|
||||
word-break: keep-all;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
code {
|
||||
font-family: var(--code-font-family);
|
||||
}
|
||||
|
||||
pre {
|
||||
border-radius: 8px;
|
||||
overflow-x: auto;
|
||||
font-family: var(--code-font-family);
|
||||
background-color: var(--color-background-mute);
|
||||
&:has(.special-preview) {
|
||||
background-color: transparent;
|
||||
}
|
||||
&:not(pre pre) {
|
||||
> code:not(pre pre > code) {
|
||||
padding: 15px;
|
||||
display: block;
|
||||
}
|
||||
}
|
||||
pre {
|
||||
margin: 0 !important;
|
||||
code {
|
||||
background: none;
|
||||
padding: 0;
|
||||
border-radius: 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pre + pre {
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.markdown-alert,
|
||||
blockquote {
|
||||
margin: 1.5em 0;
|
||||
padding: 1em 1.5em;
|
||||
background-color: var(--color-background-soft);
|
||||
border-left: 4px solid var(--color-primary);
|
||||
border-radius: 0 8px 8px 0;
|
||||
font-style: italic;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
table {
|
||||
--table-border-radius: 8px;
|
||||
margin: 2em 0;
|
||||
font-size: 0.9em;
|
||||
width: 100%;
|
||||
border-radius: var(--table-border-radius);
|
||||
overflow: hidden;
|
||||
border-collapse: separate;
|
||||
border: 0.5px solid var(--color-border);
|
||||
border-spacing: 0;
|
||||
}
|
||||
|
||||
th,
|
||||
td {
|
||||
border-right: 0.5px solid var(--color-border);
|
||||
border-bottom: 0.5px solid var(--color-border);
|
||||
padding: 0.5em;
|
||||
&:last-child {
|
||||
border-right: none;
|
||||
}
|
||||
}
|
||||
|
||||
tr:last-child td {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
th {
|
||||
background-color: var(--color-background-mute);
|
||||
font-weight: 600;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
tr:hover {
|
||||
background-color: var(--color-background-soft);
|
||||
}
|
||||
|
||||
img {
|
||||
max-width: 100%;
|
||||
height: auto;
|
||||
margin: 1em 0;
|
||||
}
|
||||
|
||||
a,
|
||||
.link {
|
||||
color: var(--color-link);
|
||||
text-decoration: none;
|
||||
cursor: pointer;
|
||||
|
||||
&:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
}
|
||||
|
||||
strong {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
em {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
del {
|
||||
text-decoration: line-through;
|
||||
}
|
||||
|
||||
sup,
|
||||
sub {
|
||||
font-size: 75%;
|
||||
line-height: 0;
|
||||
position: relative;
|
||||
vertical-align: baseline;
|
||||
}
|
||||
|
||||
sup {
|
||||
top: -0.5em;
|
||||
border-radius: 50%;
|
||||
background-color: var(--color-reference);
|
||||
color: var(--color-reference-text);
|
||||
padding: 2px 5px;
|
||||
zoom: 0.8;
|
||||
& > span.link {
|
||||
color: var(--color-reference-text);
|
||||
}
|
||||
}
|
||||
|
||||
sub {
|
||||
bottom: -0.25em;
|
||||
}
|
||||
|
||||
.footnote-ref {
|
||||
font-size: 0.8em;
|
||||
vertical-align: super;
|
||||
line-height: 0;
|
||||
margin: 0 2px;
|
||||
color: var(--color-primary);
|
||||
text-decoration: none;
|
||||
|
||||
&:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.footnotes {
|
||||
margin-top: 1em;
|
||||
margin-bottom: 1em;
|
||||
padding-top: 1em;
|
||||
|
||||
background-color: var(--color-reference-background);
|
||||
border-radius: 8px;
|
||||
padding: 8px 12px;
|
||||
|
||||
h4 {
|
||||
margin-bottom: 5px;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
a {
|
||||
color: var(--color-link);
|
||||
}
|
||||
|
||||
ol {
|
||||
padding-left: 1em;
|
||||
margin: 0;
|
||||
li:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
}
|
||||
|
||||
li {
|
||||
font-size: 0.9em;
|
||||
margin-bottom: 0.5em;
|
||||
color: var(--color-text-light);
|
||||
|
||||
p {
|
||||
display: inline;
|
||||
margin: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.footnote-backref {
|
||||
font-size: 0.8em;
|
||||
vertical-align: super;
|
||||
line-height: 0;
|
||||
margin-left: 5px;
|
||||
color: var(--color-primary);
|
||||
text-decoration: none;
|
||||
|
||||
&:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
emoji-picker {
|
||||
--border-size: 0;
|
||||
}
|
||||
|
||||
.block-wrapper + .block-wrapper {
|
||||
margin-top: 1em;
|
||||
}
|
||||
|
||||
.katex,
|
||||
mjx-container {
|
||||
display: inline-block;
|
||||
overflow-x: auto;
|
||||
overflow-y: hidden;
|
||||
overflow-wrap: break-word;
|
||||
vertical-align: middle;
|
||||
max-width: 100%;
|
||||
padding: 1px 2px;
|
||||
margin-top: -2px;
|
||||
}
|
||||
|
||||
/* Shiki 相关样式 */
|
||||
.shiki {
|
||||
font-family: var(--code-font-family);
|
||||
// 保持行高为初始值,在 shiki 代码块中处理
|
||||
line-height: initial;
|
||||
}
|
||||
|
||||
/* CodeMirror 相关样式 */
|
||||
.cm-editor {
|
||||
border-radius: inherit;
|
||||
|
||||
&.cm-focused {
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.cm-scroller {
|
||||
font-family: var(--code-font-family);
|
||||
border-radius: inherit;
|
||||
|
||||
.cm-gutters {
|
||||
line-height: 1.6;
|
||||
border-right: none;
|
||||
}
|
||||
|
||||
.cm-content {
|
||||
line-height: 1.6;
|
||||
padding-left: 0.25em;
|
||||
}
|
||||
|
||||
.cm-lineWrapping * {
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
}
|
||||
|
||||
.cm-announced {
|
||||
position: absolute;
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user