Merge branch 'main' of github.com:CherryHQ/cherry-studio into wip/data-refactor

This commit is contained in:
fullex 2025-09-06 18:20:53 +08:00
commit 2931e558b3
387 changed files with 33330 additions and 12709 deletions

View File

@ -1,94 +0,0 @@
name: 🐛 错误报告 (中文)
description: 创建一个报告以帮助我们改进
title: '[错误]: '
labels: ['BUG']
body:
- type: markdown
attributes:
value: |
感谢您花时间填写此错误报告!
在提交此问题之前,请确保您已经了解了[常见问题](https://docs.cherry-ai.com/question-contact/questions)和[知识科普](https://docs.cherry-ai.com/question-contact/knowledge)
- type: checkboxes
id: checklist
attributes:
label: 提交前检查
description: |
在提交 Issue 前请确保您已经完成了以下所有步骤
options:
- label: 我理解 Issue 是用于反馈和解决问题的,而非吐槽评论区,将尽可能提供更多信息帮助问题解决。
required: true
- label: 我的问题不是 [常见问题](https://github.com/CherryHQ/cherry-studio/issues/3881) 中的内容。
required: true
- label: 我已经查看了 **置顶 Issue** 并搜索了现有的 [开放Issue](https://github.com/CherryHQ/cherry-studio/issues)和[已关闭Issue](https://github.com/CherryHQ/cherry-studio/issues?q=is%3Aissue%20state%3Aclosed%20),没有找到类似的问题。
required: true
- label: 我填写了简短且清晰明确的标题,以便开发者在翻阅 Issue 列表时能快速确定大致问题。而不是“一个建议”、“卡住了”等。
required: true
- label: 我确认我正在使用最新版本的 Cherry Studio。
required: true
- type: dropdown
id: platform
attributes:
label: 平台
description: 您正在使用哪个平台?
options:
- Windows
- macOS
- Linux
validations:
required: true
- type: input
id: version
attributes:
label: 版本
description: 您正在运行的 Cherry Studio 版本是什么?
placeholder: 例如 v1.0.0
validations:
required: true
- type: textarea
id: description
attributes:
label: 错误描述
description: 描述问题时请尽可能详细。请尽可能提供截图或屏幕录制,以帮助我们更好地理解问题。
placeholder: 告诉我们发生了什么...(记得附上截图/录屏,如果适用)
validations:
required: true
- type: textarea
id: reproduction
attributes:
label: 重现步骤
description: 提供详细的重现步骤,以便于我们的开发人员可以准确地重现问题。请尽可能为每个步骤提供截图或屏幕录制。
placeholder: |
1. 转到 '...'
2. 点击 '....'
3. 向下滚动到 '....'
4. 看到错误
记得尽可能为每个步骤附上截图/录屏!
validations:
required: true
- type: textarea
id: expected
attributes:
label: 预期行为
description: 清晰简洁地描述您期望发生的事情
validations:
required: true
- type: textarea
id: logs
attributes:
label: 相关日志输出
description: 请复制并粘贴任何相关的日志输出
render: shell
- type: textarea
id: additional
attributes:
label: 附加信息
description: 任何能让我们对你所遇到的问题有更多了解的东西

View File

@ -1,76 +0,0 @@
name: 💡 功能建议 (中文)
description: 为项目提出新的想法
title: '[功能]: '
labels: ['feature']
body:
- type: markdown
attributes:
value: |
感谢您花时间提出新的功能建议!
在提交此问题之前,请确保您已经了解了[项目规划](https://docs.cherry-ai.com/cherrystudio/planning)和[功能介绍](https://docs.cherry-ai.com/cherrystudio/preview)
- type: checkboxes
id: checklist
attributes:
label: 提交前检查
description: |
在提交 Issue 前请确保您已经完成了以下所有步骤
options:
- label: 我理解 Issue 是用于反馈和解决问题的,而非吐槽评论区,将尽可能提供更多信息帮助问题解决。
required: true
- label: 我已经查看了置顶 Issue 并搜索了现有的 [开放Issue](https://github.com/CherryHQ/cherry-studio/issues)和[已关闭Issue](https://github.com/CherryHQ/cherry-studio/issues?q=is%3Aissue%20state%3Aclosed%20),没有找到类似的建议。
required: true
- label: 我填写了简短且清晰明确的标题,以便开发者在翻阅 Issue 列表时能快速确定大致问题。而不是“一个建议”、“卡住了”等。
required: true
- label: 最新的 Cherry Studio 版本没有实现我所提出的功能。
required: true
- type: dropdown
id: platform
attributes:
label: 平台
description: 您正在使用哪个平台?
options:
- Windows
- macOS
- Linux
validations:
required: true
- type: input
id: version
attributes:
label: 版本
description: 您正在运行的 Cherry Studio 版本是什么?
placeholder: 例如 v1.0.0
validations:
required: true
- type: textarea
id: problem
attributes:
label: 您的功能建议是否与某个问题/issue相关?
description: 请简明扼要地描述您遇到的问题
placeholder: 我总是感到沮丧,因为...
validations:
required: true
- type: textarea
id: solution
attributes:
label: 请描述您希望实现的解决方案
description: 请简明扼要地描述您希望发生的情况
validations:
required: true
- type: textarea
id: alternatives
attributes:
label: 请描述您考虑过的其他方案
description: 请简明扼要地描述您考虑过的任何其他解决方案或功能
- type: textarea
id: additional
attributes:
label: 其他补充信息
description: 在此添加任何其他与功能建议相关的上下文或截图

View File

@ -1,77 +0,0 @@
name: ❓ 提问 & 讨论 (中文)
description: 寻求帮助、讨论问题、提出疑问等...
title: '[讨论]: '
labels: ['discussion', 'help wanted']
body:
- type: markdown
attributes:
value: |
感谢您的提问!请尽可能详细地描述您的问题,这样我们才能更好地帮助您。
- type: checkboxes
id: checklist
attributes:
label: Issue 检查清单
description: |
在提交 Issue 前请确保您已经完成了以下所有步骤
options:
- label: 我理解 Issue 是用于反馈和解决问题的,而非吐槽评论区,将尽可能提供更多信息帮助问题解决。
required: true
- label: 我确认自己需要的是提出问题并且讨论问题,而不是 Bug 反馈或需求建议。
required: true
- type: dropdown
id: platform
attributes:
label: 平台
description: 您正在使用哪个平台?
options:
- Windows
- macOS
- Linux
validations:
required: true
- type: input
id: version
attributes:
label: 版本
description: 您正在运行的 Cherry Studio 版本是什么?
placeholder: 例如 v1.0.0
validations:
required: true
- type: textarea
id: question
attributes:
label: 您的问题
description: 请详细描述您的问题
placeholder: 请尽可能清楚地说明您的问题...
validations:
required: true
- type: textarea
id: context
attributes:
label: 相关背景
description: 请提供一些背景信息,帮助我们更好地理解您的问题
placeholder: 例如:使用场景、已尝试的解决方案等
- type: textarea
id: additional
attributes:
label: 补充信息
description: 任何其他相关的信息、截图或代码示例
render: shell
- type: dropdown
id: priority
attributes:
label: 优先级
description: 这个问题对您来说有多紧急?
options:
- 低 (有空再看)
- 中 (希望尽快得到答复)
- 高 (阻碍工作进行)
validations:
required: true

View File

@ -1,76 +0,0 @@
name: 🤔 其他问题 (中文)
description: 提交不属于错误报告或功能需求的问题
title: '[其他]: '
body:
- type: markdown
attributes:
value: |
感谢您花时间提出问题!
在提交此问题之前,请确保您已经了解了[常见问题](https://docs.cherry-ai.com/question-contact/questions)和[知识科普](https://docs.cherry-ai.com/question-contact/knowledge)
- type: checkboxes
id: checklist
attributes:
label: 提交前检查
description: |
在提交 Issue 前请确保您已经完成了以下所有步骤
options:
- label: 我理解 Issue 是用于反馈和解决问题的,而非吐槽评论区,将尽可能提供更多信息帮助问题解决。
required: true
- label: 我已经查看了置顶 Issue 并搜索了现有的 [开放Issue](https://github.com/CherryHQ/cherry-studio/issues)和[已关闭Issue](https://github.com/CherryHQ/cherry-studio/issues?q=is%3Aissue%20state%3Aclosed%20),没有找到类似的问题。
required: true
- label: 我填写了简短且清晰明确的标题,以便开发者在翻阅 Issue 列表时能快速确定大致问题。而不是"一个问题"、"求助"等。
required: true
- label: 我的问题不属于错误报告或功能需求类别。
required: true
- type: dropdown
id: platform
attributes:
label: 平台
description: 您正在使用哪个平台?
options:
- Windows
- macOS
- Linux
validations:
required: true
- type: input
id: version
attributes:
label: 版本
description: 您正在运行的 Cherry Studio 版本是什么?
placeholder: 例如 v1.0.0
validations:
required: true
- type: textarea
id: question
attributes:
label: 问题描述
description: 请详细描述您的问题或疑问
placeholder: 我想了解有关...的更多信息
validations:
required: true
- type: textarea
id: context
attributes:
label: 相关背景
description: 请提供与您的问题相关的任何背景信息或上下文
placeholder: 我尝试实现...时遇到了疑问
validations:
required: true
- type: textarea
id: attempts
attributes:
label: 您已尝试的方法
description: 请描述您为解决问题已经尝试过的方法(如果有)
- type: textarea
id: additional
attributes:
label: 附加信息
description: 任何能让我们对您的问题有更多了解的信息,包括截图或相关链接

66
.github/workflows/auto-i18n.yml vendored Normal file
View File

@ -0,0 +1,66 @@
name: Auto I18N
env:
API_KEY: ${{ secrets.TRANSLATE_API_KEY}}
MODEL: ${{ vars.MODEL || 'deepseek/deepseek-v3.1'}}
BASE_URL: ${{ vars.BASE_URL || 'https://api.ppinfra.com/openai'}}
on:
pull_request:
types: [opened, synchronize, reopened]
workflow_dispatch:
jobs:
auto-i18n:
runs-on: ubuntu-latest
if: github.event.pull_request.head.repo.full_name == 'CherryHQ/cherry-studio'
name: Auto I18N
permissions:
contents: write
pull-requests: write
steps:
- name: 🐈‍⬛ Checkout
uses: actions/checkout@v5
with:
ref: ${{ github.event.pull_request.head.ref }}
- name: 📦 Setting Node.js
uses: actions/setup-node@v4
with:
node-version: 20
- name: 📦 Install dependencies in isolated directory
run: |
# 在临时目录安装依赖
mkdir -p /tmp/translation-deps
cd /tmp/translation-deps
echo '{"dependencies": {"openai": "^5.12.2", "cli-progress": "^3.12.0", "tsx": "^4.20.3", "prettier": "^3.5.3", "prettier-plugin-sort-json": "^4.1.1"}}' > package.json
npm install --no-package-lock
# 设置 NODE_PATH 让项目能找到这些依赖
echo "NODE_PATH=/tmp/translation-deps/node_modules" >> $GITHUB_ENV
- name: 🏃‍♀️ Translate
run: npx tsx scripts/auto-translate-i18n.ts
- name: 🔍 Format
run: cd /tmp/translation-deps && npx prettier --write --config /home/runner/work/cherry-studio/cherry-studio/.prettierrc /home/runner/work/cherry-studio/cherry-studio/src/renderer/src/i18n/
- name: 🔄 Commit changes
run: |
git config --local user.email "action@github.com"
git config --local user.name "GitHub Action"
git add .
git reset -- package.json yarn.lock # 不提交 package.json 和 yarn.lock 的更改
if git diff --cached --quiet; then
echo "No changes to commit"
else
git commit -m "fix(i18n): Auto update translations for PR #${{ github.event.pull_request.number }}"
fi
- name: 🚀 Push changes
uses: ad-m/github-push-action@master
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
branch: ${{ github.event.pull_request.head.ref }}

View File

@ -0,0 +1,54 @@
name: Claude Code Review
on:
pull_request:
types: [opened, synchronize]
# Optional: Only run on specific file changes
# paths:
# - "src/**/*.ts"
# - "src/**/*.tsx"
# - "src/**/*.js"
# - "src/**/*.jsx"
jobs:
claude-review:
# Optional: Filter by PR author
# if: |
# github.event.pull_request.user.login == 'external-contributor' ||
# github.event.pull_request.user.login == 'new-developer' ||
# github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
issues: read
id-token: write
actions: read
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Run Claude Code Review
id: claude-review
uses: anthropics/claude-code-action@v1
with:
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
prompt: |
Please review this pull request and provide feedback on:
- Code quality and best practices
- Potential bugs or issues
- Performance considerations
- Security concerns
- Test coverage
Use the repository's CLAUDE.md for guidance on style and conventions. Be constructive and helpful in your feedback.
Use `gh pr comment` with your Bash tool to leave your review as a comment on the PR.
# See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
# or https://docs.anthropic.com/en/docs/claude-code/sdk#command-line for available options
claude_args: '--allowed-tools "Bash(gh issue view:*),Bash(gh search:*),Bash(gh issue list:*),Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*)"'

69
.github/workflows/claude-translator.yml vendored Normal file
View File

@ -0,0 +1,69 @@
name: English Translator
concurrency:
group: translator-${{ github.event.issue.number }}
cancel-in-progress: false
on:
issues:
types: [opened]
issue_comment:
types: [created, edited]
jobs:
translate:
if: |
(github.event_name == 'issues' && github.event.issue.author_association == 'COLLABORATOR' && !contains(github.event.issue.body, 'This issue/comment was translated by Claude.')) ||
(github.event_name == 'issue_comment' && github.event.comment.author_association == 'COLLABORATOR' && !contains(github.event.issue.body, 'This issue/comment was translated by Claude.'))
runs-on: ubuntu-latest
permissions:
contents: read
issues: write # 编辑issues/comments
pull-requests: read
id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Run Claude for translation
uses: anthropics/claude-code-action@v1
id: claude
with:
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
claude_args: '--allowed-tools mcp__github_comment__update_claude_comment,Bash(gh issue:*),Bash(gh api:repos/*/issues:*)'
prompt: |
你是一个多语言翻译助手。请完成以下任务:
1. 获取当前issue/comment的完整信息
2. 智能检测内容。
1. 如果是已经遵循格式要求翻译过的issue/comment检查翻译内容和原始内容是否匹配。若不匹配则重新翻译一次令其匹配并遵循格式要求若匹配则跳过任务。
2. 如果是未翻译过的issue/comment检查其内容语言。若不是英文则翻译成英文若已经是英文则跳过任务。
3. 格式要求:
- 标题:英文翻译(如果非英文)
- 内容格式:
> [!NOTE]
> This issue/comment was translated by Claude.
[英文翻译内容]
---
<details>
<summary>**Original Content:**</summary>
[原始内容]
</details>
4. 使用gh工具更新
- 根据环境信息中的Event类型选择正确的命令
- 如果Event是'issues'gh issue edit [ISSUE_NUMBER] --title "[英文标题]" --body "[翻译内容 + 原始内容]"
- 如果Event是'issue_comment'gh api -X PATCH /repos/[REPO]/issues/comments/[COMMENT_ID] -f body="[翻译内容 + 原始内容]"
环境信息:
- Event: ${{ github.event_name }}
- Issue Number: ${{ github.event.issue.number }}
- Repository: ${{ github.repository }}
- Comment ID: ${{ github.event.comment.id || 'N/A' }} (only available for comment events)
使用以下命令获取完整信息:
gh issue view ${{ github.event.issue.number }} --json title,body,comments

60
.github/workflows/claude.yml vendored Normal file
View File

@ -0,0 +1,60 @@
name: Claude Code
on:
issue_comment:
types: [created]
pull_request_review_comment:
types: [created]
issues:
types: [opened]
pull_request_review:
types: [submitted]
jobs:
claude:
if: |
(github.event_name == 'issue_comment'
&& contains(github.event.comment.body, '@claude')
&& contains(fromJSON('["COLLABORATOR","MEMBER","OWNER"]'), github.event.comment.author_association))
||
(github.event_name == 'pull_request_review_comment'
&& contains(github.event.comment.body, '@claude')
&& contains(fromJSON('["COLLABORATOR","MEMBER","OWNER"]'), github.event.comment.author_association))
||
(github.event_name == 'pull_request_review'
&& contains(github.event.review.body, '@claude')
&& contains(fromJSON('["COLLABORATOR","MEMBER","OWNER"]'), github.event.review.author_association))
||
(github.event_name == 'issues'
&& (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))
&& contains(fromJSON('["COLLABORATOR","MEMBER","OWNER"]'), github.event.issue.author_association))
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: read
issues: read
id-token: write
actions: read # Required for Claude to read CI results on PRs
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Run Claude Code
id: claude
uses: anthropics/claude-code-action@v1
with:
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
# This is an optional setting that allows Claude to read CI results on PRs
additional_permissions: |
actions: read
# Optional: Give a custom prompt to Claude. If this is not specified, Claude will perform the instructions specified in the comment that tagged it.
# prompt: 'Update the pull request description to include a summary of changes.'
# Optional: Add claude_args to customize behavior and configuration
# See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
# or https://docs.anthropic.com/en/docs/claude-code/sdk#command-line for available options
# claude_args: '--model claude-opus-4-1-20250805 --allowed-tools Bash(gh pr:*)'

View File

@ -45,6 +45,9 @@ jobs:
- name: Install Dependencies
run: yarn install
- name: Format Check
run: yarn format:check
- name: Lint Check
run: yarn test:lint

View File

@ -7,4 +7,6 @@ tsconfig.*.json
CHANGELOG*.md
agents.json
src/renderer/src/integration/nutstore/sso/lib
src/main/integration/cherryin/index.js
AGENT.md
src/main/integration/
.yarn/releases/

45
LICENSE
View File

@ -1,48 +1,3 @@
**许可协议 (Licensing)**
本项目采用**区分用户的双重许可 (User-Segmented Dual Licensing)** 模式。
**核心原则:**
* **个人用户 和 10人及以下企业/组织:** 默认适用 **GNU Affero 通用公共许可证 v3.0 (AGPLv3)**。
* **超过10人的企业/组织:** **必须** 获取 **商业许可证 (Commercial License)**。
定义“10人及以下”
指在您的组织包括公司、非营利组织、政府机构、教育机构等任何实体能够访问、使用或以任何方式直接或间接受益于本软件Cherry Studio功能的个人总数不超过10人。这包括但不限于开发者、测试人员、运营人员、最终用户、通过集成系统间接使用者等。
---
**1. 开源许可证 (Open Source License): AGPLv3 - 适用于个人及10人及以下组织**
* 如果您是个人用户或者您的组织满足上述“10人及以下”的定义您可以在 **AGPLv3** 的条款下自由使用、修改和分发 Cherry Studio。AGPLv3 的完整文本可以访问 [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html) 获取。
* **核心义务:** AGPLv3 的一个关键要求是,如果您修改了 Cherry Studio 并通过网络提供服务,或者分发了修改后的版本,您必须以 AGPLv3 许可证向接收者提供相应的**完整源代码**。即使您符合“10人及以下”的标准如果您希望避免此源代码公开义务您也需要考虑获取商业许可证见下文
* 使用前请务必仔细阅读并理解 AGPLv3 的所有条款。
**2. 商业许可证 (Commercial License) - 适用于超过10人的组织或希望规避 AGPLv3 义务的用户**
* **强制要求:** 如果您的组织**不**满足上述“10人及以下”的定义即有11人或更多人可以访问、使用或受益于本软件您**必须**联系我们获取并签署一份商业许可证才能使用 Cherry Studio。
* **自愿选择:** 即使您的组织满足“10人及以下”的条件但如果您的使用场景**无法满足 AGPLv3 的条款要求**(特别是关于**源代码公开**的义务),或者您需要 AGPLv3 **未提供**的特定商业条款(如保证、赔偿、无 Copyleft 限制等),您也**必须**联系我们获取并签署一份商业许可证。
* **需要商业许可证的常见情况包括(但不限于):**
* 您的组织规模超过10人。
* (无论组织规模)您希望分发修改过的 Cherry Studio 版本,但**不希望**根据 AGPLv3 公开您修改部分的源代码。
* (无论组织规模)您希望基于修改过的 Cherry Studio 提供网络服务SaaS但**不希望**根据 AGPLv3 向服务使用者提供修改后的源代码。
* (无论组织规模)您的公司政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件,或要求闭源分发及保密。
* 商业许可证将为您提供豁免 AGPLv3 义务(如源代码公开)的权利,并可能包含额外的商业保障条款。
* **获取商业许可:** 请通过邮箱 **bd@cherry-ai.com** 联系 Cherry Studio 开发团队洽谈商业授权事宜。
**3. 贡献 (Contributions)**
* 我们欢迎社区对 Cherry Studio 的贡献。所有向本项目提交的贡献都将被视为在 **AGPLv3** 许可证下提供。
* 通过向本项目提交贡献(例如通过 Pull Request即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
* 您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 Cherry Studio 版本中。
**4. 其他条款 (Other Terms)**
* 关于商业许可证的具体条款和条件,以双方签署的正式商业许可协议为准。
* 项目维护者保留根据需要更新本许可政策(包括用户规模定义和阈值)的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
---
**Licensing**
This project employs a **User-Segmented Dual Licensing** model.

View File

@ -8,9 +8,9 @@
| 字段名 | 类型 | 是否主键 | 索引 | 说明 |
| ---------- | ------ | -------- | ---- | ------------------------------------------------------------------------ |
| `id` | string | ✅ 是 | ✅ | 唯一标识符,主键 |
| `langCode` | string | ❌ 否 | ✅ | 语言代码(如:`zh-cn`, `en-us`, `ja-jp` 等,均为小写),支持普通索引查询 |
| `value` | string | ❌ 否 | ❌ | 语言的名称,用户输入 |
| `emoji` | string | ❌ 否 | ❌ | 语言的emoji用户输入 |
| `id` | string | ✅ 是 | ✅ | 唯一标识符,主键 |
| `langCode` | string | ❌ 否 | ✅ | 语言代码(如:`zh-cn`, `en-us`, `ja-jp` 等,均为小写),支持普通索引查询 |
| `value` | string | ❌ 否 | ❌ | 语言的名称,用户输入 |
| `emoji` | string | ❌ 否 | ❌ | 语言的emoji用户输入 |
> `langCode` 虽非主键,但在业务层应当避免重复插入相同语言代码。

View File

@ -124,24 +124,25 @@ afterSign: scripts/notarize.js
artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
✨ 重要更新:
- 新增笔记模块,支持富文本编辑和管理
- 内置 GLM-4.5-Flash 免费模型(由智谱开放平台提供)
- 内置 Qwen3-8B 免费模型(由硅基流动提供)
- 新增 Nano BananaGemini 2.5 Flash Image模型支持
- 新增系统 OCR 功能 (macOS & Windows)
- 新增图片 OCR 识别和翻译功能
- 模型切换支持通过标签筛选
- 翻译功能增强:历史搜索和收藏
✨ 新功能:
- 重构知识库模块,提升文档处理能力和搜索性能
- 新增 PaddleOCR 支持,增强文档识别能力
- 支持自定义窗口控制按钮样式
- 新增 AI SDK 包,扩展 AI 能力集成
- 支持标签页拖拽重排序功能
- 增强笔记编辑器的同步和日志功能
🔧 性能优化:
- 优化历史页面搜索性能
- 优化拖拽列表组件交互
- 升级 Electron 到 37.4.0
- 优化 MCP 服务的日志记录和错误处理
- 改进 WebView 服务的 User-Agent 处理
- 优化迷你应用的标题栏样式和状态栏适配
- 重构依赖管理,清理和优化 package.json
🐛 修复问题:
- 修复知识库加密 PDF 文档处理
- 修复导航栏在左侧时笔记侧边栏按钮缺失
- 修复多个模型兼容性问题
- 修复 MCP 相关问题
- 其他稳定性改进
🐛 问题修复:
- 修复输入栏无限状态更新循环问题
- 修复窗口控制提示框的鼠标悬停延迟
- 修复翻译输入框粘贴多内容源的处理
- 修复导航服务初始化时序问题
- 修复 MCP 通过 JSON 添加时的参数转换
- 修复模型作用域服务器同步时的 URL 格式
- 标准化工具提示图标样式

View File

@ -99,6 +99,9 @@ export default defineConfig({
'@data': resolve('src/renderer/src/data'),
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
'@mcp-trace/trace-web': resolve('packages/mcp-trace/trace-web'),
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
'@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'),
'@cherrystudio/ai-core': resolve('packages/aiCore/src'),
'@cherrystudio/extension-table-plus': resolve('packages/extension-table-plus/src')
}
},

View File

@ -1,6 +1,6 @@
{
"name": "CherryStudio",
"version": "1.5.9",
"version": "1.6.0-beta.7",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
@ -66,6 +66,7 @@
"test:lint": "eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts",
"test:scripts": "vitest scripts",
"format": "prettier --write .",
"format:check": "prettier --check .",
"lint": "eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix && yarn typecheck && yarn check:i18n",
"prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky",
"migrations:generate": "drizzle-kit generate --config ./migrations/sqlite-drizzle.config.ts"
@ -75,6 +76,7 @@
"@libsql/win32-x64-msvc": "^0.4.7",
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
"@strongtz/win32-arm64-msvc": "^0.4.7",
"faiss-node": "^0.5.1",
"graceful-fs": "^4.2.11",
"jsdom": "26.1.0",
"node-stream-zip": "^1.15.0",
@ -89,12 +91,16 @@
"@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
"@ai-sdk/amazon-bedrock": "^3.0.0",
"@ai-sdk/google-vertex": "^3.0.0",
"@ai-sdk/mistral": "^2.0.0",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
"@aws-sdk/client-bedrock": "^3.840.0",
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
"@aws-sdk/client-s3": "^3.840.0",
"@cherrystudio/ai-core": "workspace:*",
"@cherrystudio/embedjs": "^0.1.31",
"@cherrystudio/embedjs-libsql": "^0.1.31",
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
@ -124,12 +130,15 @@
"@google/genai": "patch:@google/genai@npm%3A1.0.1#~/.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch",
"@hello-pangea/dnd": "^18.0.1",
"@kangfenmao/keyv-storage": "^0.1.0",
"@langchain/community": "^0.3.36",
"@langchain/community": "^0.3.50",
"@langchain/core": "^0.3.68",
"@langchain/ollama": "^0.2.1",
"@langchain/openai": "^0.6.7",
"@mistralai/mistralai": "^1.7.5",
"@modelcontextprotocol/sdk": "^1.17.0",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "^1.1.2",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/core": "2.0.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
@ -139,7 +148,7 @@
"@playwright/test": "^1.52.0",
"@reduxjs/toolkit": "^2.2.5",
"@shikijs/markdown-it": "^3.12.0",
"@swc/plugin-styled-components": "^7.1.5",
"@swc/plugin-styled-components": "^8.0.4",
"@tanstack/react-query": "^5.85.5",
"@tanstack/react-virtual": "^3.13.12",
"@testing-library/dom": "^10.4.0",
@ -167,9 +176,11 @@
"@types/cli-progress": "^3",
"@types/fs-extra": "^11",
"@types/he": "^1",
"@types/html-to-text": "^9",
"@types/lodash": "^4.17.5",
"@types/markdown-it": "^14",
"@types/md5": "^2.3.5",
"@types/mime-types": "^3",
"@types/node": "^22.17.1",
"@types/pako": "^1.0.2",
"@types/react": "^19.0.12",
@ -190,12 +201,14 @@
"@viz-js/lang-dot": "^1.0.5",
"@viz-js/viz": "^3.14.0",
"@xyflow/react": "^12.4.4",
"ai": "^5.0.29",
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
"axios": "^1.7.3",
"browser-image-compression": "^2.0.2",
"chardet": "^2.1.0",
"cheerio": "^1.1.2",
"chokidar": "^4.0.3",
"cli-progress": "^3.12.0",
"code-inspector-plugin": "^0.20.14",
@ -234,6 +247,7 @@
"he": "^1.2.0",
"html-tags": "^5.1.0",
"html-to-image": "^1.11.13",
"html-to-text": "^9.0.5",
"htmlparser2": "^10.0.0",
"husky": "^9.1.7",
"i18next": "^23.11.5",
@ -250,12 +264,14 @@
"markdown-it": "^14.1.0",
"mermaid": "^11.10.1",
"mime": "^4.0.4",
"mime-types": "^3.0.1",
"motion": "^12.10.5",
"notion-helper": "^1.3.22",
"npx-scope-finder": "^1.2.0",
"openai": "patch:openai@npm%3A5.12.2#~/.yarn/patches/openai-npm-5.12.2-30b075401c.patch",
"p-queue": "^8.1.0",
"pdf-lib": "^1.17.1",
"pdf-parse": "^1.1.1",
"playwright": "^1.52.0",
"prettier": "^3.5.3",
"prettier-plugin-sort-json": "^4.1.1",
@ -268,6 +284,7 @@
"react-infinite-scroll-component": "^6.1.0",
"react-json-view": "^1.21.3",
"react-markdown": "^10.1.0",
"react-player": "^3.3.1",
"react-redux": "^9.1.2",
"react-router": "6",
"react-router-dom": "6",
@ -309,7 +326,9 @@
"winston-daily-rotate-file": "^5.0.0",
"word-extractor": "^1.0.4",
"y-protocols": "^1.0.6",
"yaml": "^2.8.1",
"yjs": "^13.6.27",
"youtubei.js": "^15.0.1",
"zipread": "^1.3.3",
"zod": "^3.25.74"
},
@ -340,7 +359,7 @@
"prettier --write",
"eslint --fix"
],
"*.{json,md,yml,yaml,css,scss,html}": [
"*.{json,yml,yaml,css,scss,html}": [
"prettier --write"
]
}

View File

@ -0,0 +1,514 @@
# AI Core 基于 Vercel AI SDK 的技术架构
## 1. 架构设计理念
### 1.1 设计目标
- **简化分层**`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
- **统一接口**:使用 Vercel AI SDK 统一不同 AI Provider 的接口差异
- **动态导入**:通过动态导入实现按需加载,减少打包体积
- **最小包装**:直接使用 AI SDK 的类型和接口,避免重复定义
- **插件系统**:基于钩子的通用插件架构,支持请求全生命周期扩展
- **类型安全**:利用 TypeScript 和 AI SDK 的类型系统确保类型安全
- **轻量级**:专注核心功能,保持包的轻量和高效
- **包级独立**:作为独立包管理,便于复用和维护
- **Agent就绪**:为将来集成 OpenAI Agents SDK 预留扩展空间
### 1.2 核心优势
- **标准化**AI SDK 提供统一的模型接口,减少适配工作
- **简化设计**函数式API避免过度抽象
- **更好的开发体验**:完整的 TypeScript 支持和丰富的生态系统
- **性能优化**AI SDK 内置优化和最佳实践
- **模块化设计**:独立包结构,支持跨项目复用
- **可扩展插件**:通用的流转换和参数处理插件系统
- **面向未来**:为 OpenAI Agents SDK 集成做好准备
## 2. 整体架构图
```mermaid
graph TD
subgraph "用户应用 (如 Cherry Studio)"
UI["用户界面"]
Components["应用组件"]
end
subgraph "packages/aiCore (AI Core 包)"
subgraph "Runtime Layer (运行时层)"
RuntimeExecutor["RuntimeExecutor (运行时执行器)"]
PluginEngine["PluginEngine (插件引擎)"]
RuntimeAPI["Runtime API (便捷函数)"]
end
subgraph "Models Layer (模型层)"
ModelFactory["createModel() (模型工厂)"]
ProviderCreator["ProviderCreator (提供商创建器)"]
end
subgraph "Core Systems (核心系统)"
subgraph "Plugins (插件)"
PluginManager["PluginManager (插件管理)"]
BuiltInPlugins["Built-in Plugins (内置插件)"]
StreamTransforms["Stream Transforms (流转换)"]
end
subgraph "Middleware (中间件)"
MiddlewareWrapper["wrapModelWithMiddlewares() (中间件包装)"]
end
subgraph "Providers (提供商)"
Registry["Provider Registry (注册表)"]
Factory["Provider Factory (工厂)"]
end
end
end
subgraph "Vercel AI SDK"
AICore["ai (核心库)"]
OpenAI["@ai-sdk/openai"]
Anthropic["@ai-sdk/anthropic"]
Google["@ai-sdk/google"]
XAI["@ai-sdk/xai"]
Others["其他 19+ Providers"]
end
subgraph "Future: OpenAI Agents SDK"
AgentSDK["@openai/agents (未来集成)"]
AgentExtensions["Agent Extensions (预留)"]
end
UI --> RuntimeAPI
Components --> RuntimeExecutor
RuntimeAPI --> RuntimeExecutor
RuntimeExecutor --> PluginEngine
RuntimeExecutor --> ModelFactory
PluginEngine --> PluginManager
ModelFactory --> ProviderCreator
ModelFactory --> MiddlewareWrapper
ProviderCreator --> Registry
Registry --> Factory
Factory --> OpenAI
Factory --> Anthropic
Factory --> Google
Factory --> XAI
Factory --> Others
RuntimeExecutor --> AICore
AICore --> streamText
AICore --> generateText
AICore --> streamObject
AICore --> generateObject
PluginManager --> StreamTransforms
PluginManager --> BuiltInPlugins
%% 未来集成路径
RuntimeExecutor -.-> AgentSDK
AgentSDK -.-> AgentExtensions
```
## 3. 包结构设计
### 3.1 新架构文件结构
```
packages/aiCore/
├── src/
│ ├── core/ # 核心层 - 内部实现
│ │ ├── models/ # 模型层 - 模型创建和配置
│ │ │ ├── factory.ts # 模型工厂函数 ✅
│ │ │ ├── ModelCreator.ts # 模型创建器 ✅
│ │ │ ├── ConfigManager.ts # 配置管理器 ✅
│ │ │ ├── types.ts # 模型类型定义 ✅
│ │ │ └── index.ts # 模型层导出 ✅
│ │ ├── runtime/ # 运行时层 - 执行和用户API
│ │ │ ├── executor.ts # 运行时执行器 ✅
│ │ │ ├── pluginEngine.ts # 插件引擎 ✅
│ │ │ ├── types.ts # 运行时类型定义 ✅
│ │ │ └── index.ts # 运行时导出 ✅
│ │ ├── middleware/ # 中间件系统
│ │ │ ├── wrapper.ts # 模型包装器 ✅
│ │ │ ├── manager.ts # 中间件管理器 ✅
│ │ │ ├── types.ts # 中间件类型 ✅
│ │ │ └── index.ts # 中间件导出 ✅
│ │ ├── plugins/ # 插件系统
│ │ │ ├── types.ts # 插件类型定义 ✅
│ │ │ ├── manager.ts # 插件管理器 ✅
│ │ │ ├── built-in/ # 内置插件 ✅
│ │ │ │ ├── logging.ts # 日志插件 ✅
│ │ │ │ ├── webSearchPlugin/ # 网络搜索插件 ✅
│ │ │ │ ├── toolUsePlugin/ # 工具使用插件 ✅
│ │ │ │ └── index.ts # 内置插件导出 ✅
│ │ │ ├── README.md # 插件文档 ✅
│ │ │ └── index.ts # 插件导出 ✅
│ │ ├── providers/ # 提供商管理
│ │ │ ├── registry.ts # 提供商注册表 ✅
│ │ │ ├── factory.ts # 提供商工厂 ✅
│ │ │ ├── creator.ts # 提供商创建器 ✅
│ │ │ ├── types.ts # 提供商类型 ✅
│ │ │ ├── utils.ts # 工具函数 ✅
│ │ │ └── index.ts # 提供商导出 ✅
│ │ ├── options/ # 配置选项
│ │ │ ├── factory.ts # 选项工厂 ✅
│ │ │ ├── types.ts # 选项类型 ✅
│ │ │ ├── xai.ts # xAI 选项 ✅
│ │ │ ├── openrouter.ts # OpenRouter 选项 ✅
│ │ │ ├── examples.ts # 示例配置 ✅
│ │ │ └── index.ts # 选项导出 ✅
│ │ └── index.ts # 核心层导出 ✅
│ ├── types.ts # 全局类型定义 ✅
│ └── index.ts # 包主入口文件 ✅
├── package.json # 包配置文件 ✅
├── tsconfig.json # TypeScript 配置 ✅
├── README.md # 包说明文档 ✅
└── AI_SDK_ARCHITECTURE.md # 本文档 ✅
```
## 4. 架构分层详解
### 4.1 Models Layer (模型层)
**职责**:统一的模型创建和配置管理
**核心文件**
- `factory.ts`: 模型工厂函数 (`createModel`, `createModels`)
- `ProviderCreator.ts`: 底层提供商创建和模型实例化
- `types.ts`: 模型配置类型定义
**设计特点**
- 函数式设计,避免不必要的类抽象
- 统一的模型配置接口
- 自动处理中间件应用
- 支持批量模型创建
**核心API**
```typescript
// 模型配置接口
export interface ModelConfig {
providerId: ProviderId
modelId: string
options: ProviderSettingsMap[ProviderId]
middlewares?: LanguageModelV1Middleware[]
}
// 核心模型创建函数
export async function createModel(config: ModelConfig): Promise<LanguageModel>
export async function createModels(configs: ModelConfig[]): Promise<LanguageModel[]>
```
### 4.2 Runtime Layer (运行时层)
**职责**运行时执行器和用户面向的API接口
**核心组件**
- `executor.ts`: 运行时执行器类
- `plugin-engine.ts`: 插件引擎原PluginEnabledAiClient
- `index.ts`: 便捷函数和工厂方法
**设计特点**
- 提供三种使用方式:类实例、静态工厂、函数式调用
- 自动集成模型创建和插件处理
- 完整的类型安全支持
- 为 OpenAI Agents SDK 预留扩展接口
**核心API**
```typescript
// 运行时执行器
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
static create<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
plugins?: AiPlugin[]
): RuntimeExecutor<T>
async streamText(modelId: string, params: StreamTextParams): Promise<StreamTextResult>
async generateText(modelId: string, params: GenerateTextParams): Promise<GenerateTextResult>
async streamObject(modelId: string, params: StreamObjectParams): Promise<StreamObjectResult>
async generateObject(modelId: string, params: GenerateObjectParams): Promise<GenerateObjectResult>
}
// 便捷函数式API
export async function streamText<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
modelId: string,
params: StreamTextParams,
plugins?: AiPlugin[]
): Promise<StreamTextResult>
```
### 4.3 Plugin System (插件系统)
**职责**:可扩展的插件架构
**核心组件**
- `PluginManager`: 插件生命周期管理
- `built-in/`: 内置插件集合
- 流转换收集和应用
**设计特点**
- 借鉴 Rollup 的钩子分类设计
- 支持流转换 (`experimental_transform`)
- 内置常用插件(日志、计数等)
- 完整的生命周期钩子
**插件接口**
```typescript
export interface AiPlugin {
name: string
enforce?: 'pre' | 'post'
// 【First】首个钩子 - 只执行第一个返回值的插件
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
// 【Sequential】串行钩子 - 链式执行,支持数据转换
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
// 【Stream】流处理
transformStream?: () => TransformStream
}
```
### 4.4 Middleware System (中间件系统)
**职责**AI SDK原生中间件支持
**核心组件**
- `ModelWrapper.ts`: 模型包装函数
**设计哲学**
- 直接使用AI SDK的 `wrapLanguageModel`
- 与插件系统分离,职责明确
- 函数式设计,简化使用
```typescript
export function wrapModelWithMiddlewares(model: LanguageModel, middlewares: LanguageModelV1Middleware[]): LanguageModel
```
### 4.5 Provider System (提供商系统)
**职责**AI Provider注册表和动态导入
**核心组件**
- `registry.ts`: 19+ Provider配置和类型
- `factory.ts`: Provider配置工厂
**支持的Providers**
- OpenAI, Anthropic, Google, XAI
- Azure OpenAI, Amazon Bedrock, Google Vertex
- Groq, Together.ai, Fireworks, DeepSeek
- 等19+ AI SDK官方支持的providers
## 5. 使用方式
### 5.1 函数式调用 (推荐 - 简单场景)
```typescript
import { streamText, generateText } from '@cherrystudio/ai-core/runtime'
// 直接函数调用
const stream = await streamText(
'anthropic',
{ apiKey: 'your-api-key' },
'claude-3',
{ messages: [{ role: 'user', content: 'Hello!' }] },
[loggingPlugin]
)
```
### 5.2 执行器实例 (推荐 - 复杂场景)
```typescript
import { createExecutor } from '@cherrystudio/ai-core/runtime'
// 创建可复用的执行器
const executor = createExecutor('openai', { apiKey: 'your-api-key' }, [plugin1, plugin2])
// 多次使用
const stream = await executor.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello!' }]
})
const result = await executor.generateText('gpt-4', {
messages: [{ role: 'user', content: 'How are you?' }]
})
```
### 5.3 静态工厂方法
```typescript
import { RuntimeExecutor } from '@cherrystudio/ai-core/runtime'
// 静态创建
const executor = RuntimeExecutor.create('anthropic', { apiKey: 'your-api-key' })
await executor.streamText('claude-3', { messages: [...] })
```
### 5.4 直接模型创建 (高级用法)
```typescript
import { createModel } from '@cherrystudio/ai-core/models'
import { streamText } from 'ai'
// 直接创建模型使用
const model = await createModel({
providerId: 'openai',
modelId: 'gpt-4',
options: { apiKey: 'your-api-key' },
middlewares: [middleware1, middleware2]
})
// 直接使用 AI SDK
const result = await streamText({ model, messages: [...] })
```
## 6. 为 OpenAI Agents SDK 预留的设计
### 6.1 架构兼容性
当前架构完全兼容 OpenAI Agents SDK 的集成需求:
```typescript
// 当前的模型创建
const model = await createModel({
providerId: 'anthropic',
modelId: 'claude-3',
options: { apiKey: 'xxx' }
})
// 将来可以直接用于 OpenAI Agents SDK
import { Agent, run } from '@openai/agents'
const agent = new Agent({
model, // ✅ 直接兼容 LanguageModel 接口
name: 'Assistant',
instructions: '...',
tools: [tool1, tool2]
})
const result = await run(agent, 'user input')
```
### 6.2 预留的扩展点
1. **runtime/agents/** 目录预留
2. **AgentExecutor** 类预留
3. **Agent工具转换插件** 预留
4. **多Agent编排** 预留
### 6.3 未来架构扩展
```
packages/aiCore/src/core/
├── runtime/
│ ├── agents/ # 🚀 未来添加
│ │ ├── AgentExecutor.ts
│ │ ├── WorkflowManager.ts
│ │ └── ConversationManager.ts
│ ├── executor.ts
│ └── index.ts
```
## 7. 架构优势
### 7.1 简化设计
- **移除过度抽象**删除了orchestration层和creation层的复杂包装
- **函数式优先**models层使用函数而非类
- **直接明了**runtime层直接提供用户API
### 7.2 职责清晰
- **Models**: 专注模型创建和配置
- **Runtime**: 专注执行和用户API
- **Plugins**: 专注扩展功能
- **Providers**: 专注AI Provider管理
### 7.3 类型安全
- 完整的 TypeScript 支持
- AI SDK 类型的直接复用
- 避免类型重复定义
### 7.4 灵活使用
- 三种使用模式满足不同需求
- 从简单函数调用到复杂执行器
- 支持直接AI SDK使用
### 7.5 面向未来
- 为 OpenAI Agents SDK 集成做好准备
- 清晰的扩展点和架构边界
- 模块化设计便于功能添加
## 8. 技术决策记录
### 8.1 为什么选择简化的两层架构?
- **职责分离**models专注创建runtime专注执行
- **模块化**:每层都有清晰的边界和职责
- **扩展性**为Agent功能预留了清晰的扩展空间
### 8.2 为什么选择函数式设计?
- **简洁性**:避免不必要的类设计
- **性能**:减少对象创建开销
- **易用性**:函数调用更直观
### 8.3 为什么分离插件和中间件?
- **职责明确**: 插件处理应用特定需求
- **原生支持**: 中间件使用AI SDK原生功能
- **灵活性**: 两套系统可以独立演进
## 9. 总结
AI Core架构实现了
### 9.1 核心特点
- ✅ **简化架构**: 2层核心架构职责清晰
- ✅ **函数式设计**: models层完全函数化
- ✅ **类型安全**: 统一的类型定义和AI SDK类型复用
- ✅ **插件扩展**: 强大的插件系统
- ✅ **多种使用方式**: 满足不同复杂度需求
- ✅ **Agent就绪**: 为OpenAI Agents SDK集成做好准备
### 9.2 核心价值
- **统一接口**: 一套API支持19+ AI providers
- **灵活使用**: 函数式、实例式、静态工厂式
- **强类型**: 完整的TypeScript支持
- **可扩展**: 插件和中间件双重扩展能力
- **高性能**: 最小化包装直接使用AI SDK
- **面向未来**: Agent SDK集成架构就绪
### 9.3 未来发展
这个架构提供了:
- **优秀的开发体验**: 简洁的API和清晰的使用模式
- **强大的扩展能力**: 为Agent功能预留了完整的架构空间
- **良好的维护性**: 职责分离明确,代码易于维护
- **广泛的适用性**: 既适合简单调用也适合复杂应用

433
packages/aiCore/README.md Normal file
View File

@ -0,0 +1,433 @@
# @cherrystudio/ai-core
Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口包,为 AI 应用提供强大的抽象层和插件化架构。
## ✨ 核心亮点
### 🏗️ 优雅的架构设计
- **简化分层**`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
- **函数式优先**:避免过度抽象,提供简洁直观的 API
- **类型安全**:完整的 TypeScript 支持,直接复用 AI SDK 类型系统
- **最小包装**:直接使用 AI SDK 的接口,避免重复定义和性能损耗
### 🔌 强大的插件系统
- **生命周期钩子**:支持请求全生命周期的扩展点
- **流转换支持**:基于 AI SDK 的 `experimental_transform` 实现流处理
- **插件分类**First、Sequential、Parallel 三种钩子类型,满足不同场景
- **内置插件**webSearch、logging、toolUse 等开箱即用的功能
### 🌐 统一多 Provider 接口
- **扩展注册**:支持自定义 Provider 注册,无限扩展能力
- **配置统一**:统一的配置接口,简化多 Provider 管理
### 🚀 多种使用方式
- **函数式调用**:适合简单场景的直接函数调用
- **执行器实例**:适合复杂场景的可复用执行器
- **静态工厂**:便捷的静态创建方法
- **原生兼容**:完全兼容 AI SDK 原生 Provider Registry
### 🔮 面向未来
- **Agent 就绪**:为 OpenAI Agents SDK 集成预留架构空间
- **模块化设计**:独立包结构,支持跨项目复用
- **渐进式迁移**:可以逐步从现有 AI SDK 代码迁移
## 特性
- 🚀 统一的 AI Provider 接口
- 🔄 动态导入支持
- 🛠️ TypeScript 支持
- 📦 强大的插件系统
- 🌍 内置webSearch(Openai,Google,Anthropic,xAI)
- 🎯 多种使用模式(函数式/实例式/静态工厂)
- 🔌 可扩展的 Provider 注册系统
- 🧩 完整的中间件支持
- 📊 插件统计和调试功能
## 支持的 Providers
基于 [AI SDK 官方支持的 providers](https://ai-sdk.dev/providers/ai-sdk-providers)
**核心 Providers内置支持:**
- OpenAI
- Anthropic
- Google Generative AI
- OpenAI-Compatible
- xAI (Grok)
- Azure OpenAI
- DeepSeek
**扩展 Providers通过注册API支持:**
- Google Vertex AI
- ...
- 自定义 Provider
## 安装
```bash
npm install @cherrystudio/ai-core ai
```
### React Native
如果你在 React Native 项目中使用此包,需要在 `metro.config.js` 中添加以下配置:
```javascript
// metro.config.js
const { getDefaultConfig } = require('expo/metro-config')
const config = getDefaultConfig(__dirname)
// 添加对 @cherrystudio/ai-core 的支持
config.resolver.resolverMainFields = ['react-native', 'browser', 'main']
config.resolver.platforms = ['ios', 'android', 'native', 'web']
module.exports = config
```
还需要安装你要使用的 AI SDK provider:
```bash
npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google
```
## 使用示例
### 基础用法
```typescript
import { AiCore } from '@cherrystudio/ai-core'
// 创建 OpenAI executor
const executor = AiCore.create('openai', {
apiKey: 'your-api-key'
})
// 流式生成
const result = await executor.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello!' }]
})
// 非流式生成
const response = await executor.generateText('gpt-4', {
messages: [{ role: 'user', content: 'Hello!' }]
})
```
### 便捷函数
```typescript
import { createOpenAIExecutor } from '@cherrystudio/ai-core'
// 快速创建 OpenAI executor
const executor = createOpenAIExecutor({
apiKey: 'your-api-key'
})
// 使用 executor
const result = await executor.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello!' }]
})
```
### 多 Provider 支持
```typescript
import { AiCore } from '@cherrystudio/ai-core'
// 支持多种 AI providers
const openaiExecutor = AiCore.create('openai', { apiKey: 'openai-key' })
const anthropicExecutor = AiCore.create('anthropic', { apiKey: 'anthropic-key' })
const googleExecutor = AiCore.create('google', { apiKey: 'google-key' })
const xaiExecutor = AiCore.create('xai', { apiKey: 'xai-key' })
```
### 扩展 Provider 注册
对于非内置的 providers可以通过注册 API 扩展支持:
```typescript
import { registerProvider, AiCore } from '@cherrystudio/ai-core'
// 方式一:导入并注册第三方 provider
import { createGroq } from '@ai-sdk/groq'
registerProvider({
id: 'groq',
name: 'Groq',
creator: createGroq,
supportsImageGeneration: false
})
// 现在可以使用 Groq
const groqExecutor = AiCore.create('groq', { apiKey: 'groq-key' })
// 方式二:动态导入方式注册
registerProvider({
id: 'mistral',
name: 'Mistral AI',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral'
})
const mistralExecutor = AiCore.create('mistral', { apiKey: 'mistral-key' })
```
## 🔌 插件系统
AI Core 提供了强大的插件系统,支持请求全生命周期的扩展。
### 内置插件
#### webSearchPlugin - 网络搜索插件
为不同 AI Provider 提供统一的网络搜索能力:
```typescript
import { webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
webSearchPlugin({
openai: {
/* OpenAI 搜索配置 */
},
anthropic: { maxUses: 5 },
google: {
/* Google 搜索配置 */
},
xai: {
mode: 'on',
returnCitations: true,
maxSearchResults: 5,
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
}
})
])
```
#### loggingPlugin - 日志插件
提供详细的请求日志记录:
```typescript
import { createLoggingPlugin } from '@cherrystudio/ai-core/built-in/plugins'
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
createLoggingPlugin({
logLevel: 'info',
includeParams: true,
includeResult: false
})
])
```
#### promptToolUsePlugin - 提示工具使用插件
为不支持原生 Function Call 的模型提供 prompt 方式的工具调用:
```typescript
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/built-in/plugins'
// 对于不支持 function call 的模型
const executor = AiCore.create(
'providerId',
{
apiKey: 'your-key',
baseURL: 'https://your-model-endpoint'
},
[
createPromptToolUsePlugin({
enabled: true,
// 可选:自定义系统提示符构建
buildSystemPrompt: (userPrompt, tools) => {
return `${userPrompt}\n\nAvailable tools: ${Object.keys(tools).join(', ')}`
}
})
]
)
```
### 自定义插件
创建自定义插件非常简单:
```typescript
import { definePlugin } from '@cherrystudio/ai-core'
const customPlugin = definePlugin({
name: 'custom-plugin',
enforce: 'pre', // 'pre' | 'post' | undefined
// 在请求开始时记录日志
onRequestStart: async (context) => {
console.log(`Starting request for model: ${context.modelId}`)
},
// 转换请求参数
transformParams: async (params, context) => {
// 添加自定义系统消息
if (params.messages) {
params.messages.unshift({
role: 'system',
content: 'You are a helpful assistant.'
})
}
return params
},
// 处理响应结果
transformResult: async (result, context) => {
// 添加元数据
if (result.text) {
result.metadata = {
processedAt: new Date().toISOString(),
modelId: context.modelId
}
}
return result
}
})
// 使用自定义插件
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [customPlugin])
```
### 使用 AI SDK 原生 Provider 注册表
> https://ai-sdk.dev/docs/reference/ai-sdk-core/provider-registry
除了使用内建的 provider 管理,你还可以使用 AI SDK 原生的 `createProviderRegistry` 来构建自己的 provider 注册表。
#### 基本用法示例
```typescript
import { createClient } from '@cherrystudio/ai-core'
import { createProviderRegistry } from 'ai'
import { createOpenAI } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
// 1. 创建 AI SDK 原生注册表
export const registry = createProviderRegistry({
// register provider with prefix and default setup:
anthropic,
// register provider with prefix and custom setup:
openai: createOpenAI({
apiKey: process.env.OPENAI_API_KEY
})
})
// 2. 创建client,'openai'可以传空或者传providerId(内建的provider)
const client = PluginEnabledAiClient.create('openai', {
apiKey: process.env.OPENAI_API_KEY
})
// 3. 方式1使用内建逻辑传统方式
const result1 = await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello with built-in logic!' }]
})
// 4. 方式2使用自定义注册表灵活方式
const result2 = await client.streamText({
model: registry.languageModel('openai:gpt-4'),
messages: [{ role: 'user', content: 'Hello with custom registry!' }]
})
// 5. 支持的重载方法
await client.generateObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ name: z.string() }),
messages: [{ role: 'user', content: 'Generate a user' }]
})
await client.streamObject({
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
schema: z.object({ items: z.array(z.string()) }),
messages: [{ role: 'user', content: 'Generate a list' }]
})
```
#### 与插件系统配合使用
更强大的是,你还可以将自定义注册表与 Cherry Studio 的插件系统结合使用:
```typescript
import { PluginEnabledAiClient } from '@cherrystudio/ai-core'
import { createProviderRegistry } from 'ai'
import { createOpenAI } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
// 1. 创建带插件的客户端
const client = PluginEnabledAiClient.create(
'openai',
{
apiKey: process.env.OPENAI_API_KEY
},
[LoggingPlugin, RetryPlugin]
)
// 2. 创建自定义注册表
const registry = createProviderRegistry({
openai: createOpenAI({ apiKey: process.env.OPENAI_API_KEY }),
anthropic: anthropic({ apiKey: process.env.ANTHROPIC_API_KEY })
})
// 3. 方式1使用内建逻辑 + 完整插件系统
await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello with plugins!' }]
})
// 4. 方式2使用自定义注册表 + 有限插件支持
await client.streamText({
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
messages: [{ role: 'user', content: 'Hello from Claude!' }]
})
// 5. 支持的方法
await client.generateObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ name: z.string() }),
messages: [{ role: 'user', content: 'Generate a user' }]
})
await client.streamObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ items: z.array(z.string()) }),
messages: [{ role: 'user', content: 'Generate a list' }]
})
```
#### 混合使用的优势
- **灵活性**:可以根据需要选择使用内建逻辑或自定义注册表
- **兼容性**:完全兼容 AI SDK 的 `createProviderRegistry` API
- **渐进式**:可以逐步迁移现有代码,无需一次性重构
- **插件支持**:自定义注册表仍可享受插件系统的部分功能
- **最佳实践**:结合两种方式的优点,既有动态加载的性能优势,又有统一注册表的便利性
## 📚 相关资源
- [Vercel AI SDK 文档](https://ai-sdk.dev/)
- [Cherry Studio 项目](https://github.com/CherryHQ/cherry-studio)
- [AI SDK Providers](https://ai-sdk.dev/providers/ai-sdk-providers)
## 未来版本
- 🔮 多 Agent 编排
- 🔮 可视化插件配置
- 🔮 实时监控和分析
- 🔮 云端插件同步
## 📄 License
MIT License - 详见 [LICENSE](https://github.com/CherryHQ/cherry-studio/blob/main/LICENSE) 文件
---
**Cherry Studio AI Core** - 让 AI 开发更简单、更强大、更灵活 🚀

View File

@ -0,0 +1,103 @@
/**
* Hub Provider 使
*
* 使Hub Provider功能来路由到多个底层provider
*/
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
async function demonstrateHubProvider() {
try {
// 1. 初始化底层providers
console.log('📦 初始化底层providers...')
initializeProvider('openai', {
apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
})
initializeProvider('anthropic', {
apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
})
// 2. 创建Hub Provider自动包含所有已初始化的providers
console.log('🌐 创建Hub Provider...')
const aihubmixProvider = createHubProvider({
hubId: 'aihubmix',
debug: true
})
// 3. 注册Hub Provider
providerRegistry.registerProvider('aihubmix', aihubmixProvider)
console.log('✅ Hub Provider "aihubmix" 注册成功')
// 4. 使用Hub Provider访问不同的模型
console.log('\n🚀 使用Hub模型...')
// 通过Hub路由到OpenAI
const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
console.log('✓ OpenAI模型已获取:', openaiModel.modelId)
// 通过Hub路由到Anthropic
const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
console.log('✓ Anthropic模型已获取:', anthropicModel.modelId)
// 5. 演示错误处理
console.log('\n❌ 演示错误处理...')
try {
// 尝试访问未初始化的provider
providerRegistry.languageModel('aihubmix:google:gemini-pro')
} catch (error) {
console.log('预期错误:', error.message)
}
try {
// 尝试使用错误的模型ID格式
providerRegistry.languageModel('aihubmix:invalid-format')
} catch (error) {
console.log('预期错误:', error.message)
}
// 6. 多个Hub Provider示例
console.log('\n🔄 创建多个Hub Provider...')
const localHubProvider = createHubProvider({
hubId: 'local-ai'
})
providerRegistry.registerProvider('local-ai', localHubProvider)
console.log('✅ Hub Provider "local-ai" 注册成功')
console.log('\n🎉 Hub Provider演示完成')
} catch (error) {
console.error('💥 演示过程中发生错误:', error)
}
}
// 演示简化的使用方式
function simplifiedUsageExample() {
console.log('\n📝 简化使用示例:')
console.log(`
// 1. 初始化providers
initializeProvider('openai', { apiKey: 'sk-xxx' })
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
// 2. 创建并注册Hub Provider
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
providerRegistry.registerProvider('aihubmix', hubProvider)
// 3. 直接使用
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
`)
}
// 运行演示
if (require.main === module) {
demonstrateHubProvider()
simplifiedUsageExample()
}
export { demonstrateHubProvider, simplifiedUsageExample }

View File

@ -0,0 +1,167 @@
/**
* Image Generation Example
* 使 aiCore
*/
import { createExecutor, generateImage } from '../src/index'
async function main() {
// 方式1: 使用执行器实例
console.log('📸 创建 OpenAI 图像生成执行器...')
const executor = createExecutor('openai', {
apiKey: process.env.OPENAI_API_KEY!
})
try {
console.log('🎨 使用执行器生成图像...')
const result1 = await executor.generateImage('dall-e-3', {
prompt: 'A futuristic cityscape at sunset with flying cars',
size: '1024x1024',
n: 1
})
console.log('✅ 图像生成成功!')
console.log('📊 结果:', {
imagesCount: result1.images.length,
mediaType: result1.image.mediaType,
hasBase64: !!result1.image.base64,
providerMetadata: result1.providerMetadata
})
} catch (error) {
console.error('❌ 执行器生成失败:', error)
}
// 方式2: 使用直接调用 API
try {
console.log('🎨 使用直接 API 生成图像...')
const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
prompt: 'A magical forest with glowing mushrooms and fairy lights',
aspectRatio: '16:9',
providerOptions: {
openai: {
quality: 'hd',
style: 'vivid'
}
}
})
console.log('✅ 直接 API 生成成功!')
console.log('📊 结果:', {
imagesCount: result2.images.length,
mediaType: result2.image.mediaType,
hasBase64: !!result2.image.base64
})
} catch (error) {
console.error('❌ 直接 API 生成失败:', error)
}
// 方式3: 支持其他提供商 (Google Imagen)
if (process.env.GOOGLE_API_KEY) {
try {
console.log('🎨 使用 Google Imagen 生成图像...')
const googleExecutor = createExecutor('google', {
apiKey: process.env.GOOGLE_API_KEY!
})
const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
prompt: 'A serene mountain lake at dawn with mist rising from the water',
aspectRatio: '1:1'
})
console.log('✅ Google Imagen 生成成功!')
console.log('📊 结果:', {
imagesCount: result3.images.length,
mediaType: result3.image.mediaType,
hasBase64: !!result3.image.base64
})
} catch (error) {
console.error('❌ Google Imagen 生成失败:', error)
}
}
// 方式4: 支持插件系统
const pluginExample = async () => {
console.log('🔌 演示插件系统...')
// 创建一个示例插件,用于修改提示词
const promptEnhancerPlugin = {
name: 'prompt-enhancer',
transformParams: async (params: any) => {
console.log('🔧 插件: 增强提示词...')
return {
...params,
prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
}
},
transformResult: async (result: any) => {
console.log('🔧 插件: 处理结果...')
return {
...result,
enhanced: true
}
}
}
const executorWithPlugin = createExecutor(
'openai',
{
apiKey: process.env.OPENAI_API_KEY!
},
[promptEnhancerPlugin]
)
try {
const result4 = await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A cute robot playing in a garden'
})
console.log('✅ 插件系统生成成功!')
console.log('📊 结果:', {
imagesCount: result4.images.length,
enhanced: (result4 as any).enhanced,
mediaType: result4.image.mediaType
})
} catch (error) {
console.error('❌ 插件系统生成失败:', error)
}
}
await pluginExample()
}
// 错误处理演示
async function errorHandlingExample() {
console.log('⚠️ 演示错误处理...')
try {
const executor = createExecutor('openai', {
apiKey: 'invalid-key'
})
await executor.generateImage('dall-e-3', {
prompt: 'Test image'
})
} catch (error: any) {
console.log('✅ 成功捕获错误:', error.constructor.name)
console.log('📋 错误信息:', error.message)
console.log('🏷️ 提供商ID:', error.providerId)
console.log('🏷️ 模型ID:', error.modelId)
}
}
// 运行示例
if (require.main === module) {
main()
.then(() => {
console.log('🎉 所有示例完成!')
return errorHandlingExample()
})
.then(() => {
console.log('🎯 示例程序结束')
process.exit(0)
})
.catch((error) => {
console.error('💥 程序执行出错:', error)
process.exit(1)
})
}

View File

@ -0,0 +1,85 @@
{
"name": "@cherrystudio/ai-core",
"version": "1.0.0-alpha.11",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "dist/index.js",
"module": "dist/index.mjs",
"types": "dist/index.d.ts",
"react-native": "dist/index.js",
"scripts": {
"build": "tsdown",
"dev": "tsc -w",
"clean": "rm -rf dist",
"test": "vitest run",
"test:watch": "vitest"
},
"keywords": [
"ai",
"sdk",
"openai",
"anthropic",
"google",
"cherry-studio",
"vercel-ai-sdk"
],
"author": "Cherry Studio",
"license": "MIT",
"repository": {
"type": "git",
"url": "git+https://github.com/CherryHQ/cherry-studio.git"
},
"bugs": {
"url": "https://github.com/CherryHQ/cherry-studio/issues"
},
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
"peerDependencies": {
"ai": "^5.0.26"
},
"dependencies": {
"@ai-sdk/anthropic": "^2.0.5",
"@ai-sdk/azure": "^2.0.16",
"@ai-sdk/deepseek": "^1.0.9",
"@ai-sdk/google": "^2.0.7",
"@ai-sdk/openai": "^2.0.19",
"@ai-sdk/openai-compatible": "^1.0.9",
"@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.4",
"@ai-sdk/xai": "^2.0.9",
"zod": "^3.25.0"
},
"devDependencies": {
"tsdown": "^0.12.9",
"typescript": "^5.0.0",
"vitest": "^3.2.4"
},
"sideEffects": false,
"engines": {
"node": ">=18.0.0"
},
"files": [
"dist"
],
"exports": {
".": {
"types": "./dist/index.d.ts",
"react-native": "./dist/index.js",
"import": "./dist/index.mjs",
"require": "./dist/index.js",
"default": "./dist/index.js"
},
"./built-in/plugins": {
"types": "./dist/built-in/plugins/index.d.ts",
"react-native": "./dist/built-in/plugins/index.js",
"import": "./dist/built-in/plugins/index.mjs",
"require": "./dist/built-in/plugins/index.js",
"default": "./dist/built-in/plugins/index.js"
},
"./provider": {
"types": "./dist/provider/index.d.ts",
"react-native": "./dist/provider/index.js",
"import": "./dist/provider/index.mjs",
"require": "./dist/provider/index.js",
"default": "./dist/provider/index.js"
}
}
}

View File

@ -0,0 +1,2 @@
// 模拟 Vite SSR helper避免 Node 环境找不到时报错
;(globalThis as any).__vite_ssr_exportName__ = (name: string, value: any) => value

View File

@ -0,0 +1,3 @@
# @cherryStudio-aiCore
Core

View File

@ -0,0 +1,17 @@
/**
* Core
* 使
*/
// 中间件系统
export type { NamedMiddleware } from './middleware'
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
// 创建管理
export { globalModelResolver, ModelResolver } from './models'
export type { ModelConfig as ModelConfigType } from './models/types'
// 执行管理
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
export type { RuntimeConfig } from './runtime/types'

View File

@ -0,0 +1,8 @@
/**
* Middleware
*
*/
export { createMiddlewares } from './manager'
export type { NamedMiddleware } from './types'
export { wrapModelWithMiddlewares } from './wrapper'

View File

@ -0,0 +1,16 @@
/**
*
* AI SDK
*/
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
/**
*
*
*/
export function createMiddlewares(userMiddlewares: LanguageModelV2Middleware[] = []): LanguageModelV2Middleware[] {
// 未来可以在这里添加默认的中间件
const defaultMiddlewares: LanguageModelV2Middleware[] = []
return [...defaultMiddlewares, ...userMiddlewares]
}

View File

@ -0,0 +1,12 @@
/**
*
*/
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
/**
*
*/
export interface NamedMiddleware {
name: string
middleware: LanguageModelV2Middleware
}

View File

@ -0,0 +1,23 @@
/**
*
* LanguageModel上
*/
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import { wrapLanguageModel } from 'ai'
/**
* 使
*/
export function wrapModelWithMiddlewares(
model: LanguageModelV2,
middlewares: LanguageModelV2Middleware[]
): LanguageModelV2 {
if (middlewares.length === 0) {
return model
}
return wrapLanguageModel({
model,
middleware: middlewares
})
}

View File

@ -0,0 +1,125 @@
/**
* - models模块的核心
* modelId解析为AI SDK的LanguageModel实例
*
* ModelCreator
*/
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
export class ModelResolver {
/**
* modelId为语言模型
*
* @param modelId ID 'gpt-4' 'anthropic>claude-3'
* @param fallbackProviderId modelId为传统格式时使用的providerId
* @param providerOptions provider配置选项OpenAI模式选择等
* @param middlewares
*/
async resolveLanguageModel(
modelId: string,
fallbackProviderId: string,
providerOptions?: any,
middlewares?: LanguageModelV2Middleware[]
): Promise<LanguageModelV2> {
let finalProviderId = fallbackProviderId
let model: LanguageModelV2
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') {
finalProviderId = `${fallbackProviderId}-chat`
}
// 检查是否是命名空间格式
if (modelId.includes(DEFAULT_SEPARATOR)) {
model = this.resolveNamespacedModel(modelId)
} else {
// 传统格式:使用处理后的 providerId + modelId
model = this.resolveTraditionalModel(finalProviderId, modelId)
}
// 🎯 应用中间件(如果有)
if (middlewares && middlewares.length > 0) {
model = wrapModelWithMiddlewares(model, middlewares)
}
return model
}
/**
*
*/
async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV2<string>> {
if (modelId.includes(DEFAULT_SEPARATOR)) {
return this.resolveNamespacedEmbeddingModel(modelId)
}
return this.resolveTraditionalEmbeddingModel(fallbackProviderId, modelId)
}
/**
*
*/
async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV2> {
if (modelId.includes(DEFAULT_SEPARATOR)) {
return this.resolveNamespacedImageModel(modelId)
}
return this.resolveTraditionalImageModel(fallbackProviderId, modelId)
}
/**
*
* aihubmix:anthropic:claude-3 -> globalRegistryManagement.languageModel('aihubmix:anthropic:claude-3')
*/
private resolveNamespacedModel(modelId: string): LanguageModelV2 {
return globalRegistryManagement.languageModel(modelId as any)
}
/**
*
* providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4')
*/
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
console.log('fullModelId', fullModelId)
return globalRegistryManagement.languageModel(fullModelId as any)
}
/**
*
*/
private resolveNamespacedEmbeddingModel(modelId: string): EmbeddingModelV2<string> {
return globalRegistryManagement.textEmbeddingModel(modelId as any)
}
/**
*
*/
private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV2<string> {
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
return globalRegistryManagement.textEmbeddingModel(fullModelId as any)
}
/**
*
*/
private resolveNamespacedImageModel(modelId: string): ImageModelV2 {
return globalRegistryManagement.imageModel(modelId as any)
}
/**
*
*/
private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV2 {
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
return globalRegistryManagement.imageModel(fullModelId as any)
}
}
/**
*
*/
export const globalModelResolver = new ModelResolver()

View File

@ -0,0 +1,9 @@
/**
* Models -
*/
// 核心模型解析器
export { globalModelResolver, ModelResolver } from './ModelResolver'
// 保留的类型定义(可能被其他地方使用)
export type { ModelConfig as ModelConfigType } from './types'

View File

@ -0,0 +1,15 @@
/**
* Creation
*/
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import type { ProviderId, ProviderSettingsMap } from '../providers/types'
export interface ModelConfig<T extends ProviderId = ProviderId> {
providerId: T
modelId: string
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
middlewares?: LanguageModelV2Middleware[]
// 额外模型参数
extraModelConfig?: Record<string, any>
}

View File

@ -0,0 +1,87 @@
import { streamText } from 'ai'
import {
createAnthropicOptions,
createGenericProviderOptions,
createGoogleOptions,
createOpenAIOptions,
mergeProviderOptions
} from './factory'
// 示例1: 使用已知供应商的严格类型约束
export function exampleOpenAIWithOptions() {
const openaiOptions = createOpenAIOptions({
reasoningEffort: 'medium'
})
// 这里会有类型检查确保选项符合OpenAI的设置
return streamText({
model: {} as any, // 实际使用时替换为真实模型
prompt: 'Hello',
providerOptions: openaiOptions
})
}
// 示例2: 使用Anthropic供应商选项
export function exampleAnthropicWithOptions() {
const anthropicOptions = createAnthropicOptions({
thinking: {
type: 'enabled',
budgetTokens: 1000
}
})
return streamText({
model: {} as any,
prompt: 'Hello',
providerOptions: anthropicOptions
})
}
// 示例3: 使用Google供应商选项
export function exampleGoogleWithOptions() {
const googleOptions = createGoogleOptions({
thinkingConfig: {
includeThoughts: true,
thinkingBudget: 1000
}
})
return streamText({
model: {} as any,
prompt: 'Hello',
providerOptions: googleOptions
})
}
// 示例4: 使用未知供应商(通用类型)
export function exampleUnknownProviderWithOptions() {
const customProviderOptions = createGenericProviderOptions('custom-provider', {
temperature: 0.7,
customSetting: 'value',
anotherOption: true
})
return streamText({
model: {} as any,
prompt: 'Hello',
providerOptions: customProviderOptions
})
}
// 示例5: 合并多个供应商选项
export function exampleMergedOptions() {
const openaiOptions = createOpenAIOptions({})
const customOptions = createGenericProviderOptions('custom', {
customParam: 'value'
})
const mergedOptions = mergeProviderOptions(openaiOptions, customOptions)
return streamText({
model: {} as any,
prompt: 'Hello',
providerOptions: mergedOptions
})
}

View File

@ -0,0 +1,71 @@
import { ExtractProviderOptions, ProviderOptionsMap, TypedProviderOptions } from './types'
/**
*
* @param provider
* @param options
* @returns provider options
*/
export function createProviderOptions<T extends keyof ProviderOptionsMap>(
provider: T,
options: ExtractProviderOptions<T>
): Record<T, ExtractProviderOptions<T>> {
return { [provider]: options } as Record<T, ExtractProviderOptions<T>>
}
/**
*
* @param provider
* @param options
* @returns provider options
*/
export function createGenericProviderOptions<T extends string>(
provider: T,
options: Record<string, any>
): Record<T, Record<string, any>> {
return { [provider]: options } as Record<T, Record<string, any>>
}
/**
* options
* @param optionsMap
* @returns TypedProviderOptions
*/
export function mergeProviderOptions(...optionsMap: Partial<TypedProviderOptions>[]): TypedProviderOptions {
return Object.assign({}, ...optionsMap)
}
/**
* OpenAI供应商选项的便捷函数
*/
export function createOpenAIOptions(options: ExtractProviderOptions<'openai'>) {
return createProviderOptions('openai', options)
}
/**
* Anthropic供应商选项的便捷函数
*/
export function createAnthropicOptions(options: ExtractProviderOptions<'anthropic'>) {
return createProviderOptions('anthropic', options)
}
/**
* Google供应商选项的便捷函数
*/
export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
return createProviderOptions('google', options)
}
/**
* OpenRouter供应商选项的便捷函数
*/
export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) {
return createProviderOptions('openrouter', options)
}
/**
* XAI供应商选项的便捷函数
*/
export function createXaiOptions(options: ExtractProviderOptions<'xai'>) {
return createProviderOptions('xai', options)
}

View File

@ -0,0 +1,2 @@
export * from './factory'
export * from './types'

View File

@ -0,0 +1,38 @@
export type OpenRouterProviderOptions = {
models?: string[]
/**
* https://openrouter.ai/docs/use-cases/reasoning-tokens
* One of `max_tokens` or `effort` is required.
* If `exclude` is true, reasoning will be removed from the response. Default is false.
*/
reasoning?: {
exclude?: boolean
} & (
| {
max_tokens: number
}
| {
effort: 'high' | 'medium' | 'low'
}
)
/**
* A unique identifier representing your end-user, which can
* help OpenRouter to monitor and detect abuse.
*/
user?: string
extraBody?: Record<string, unknown>
/**
* Enable usage accounting to get detailed token usage information.
* https://openrouter.ai/docs/use-cases/usage-accounting
*/
usage?: {
/**
* When true, includes token usage information in the response.
*/
include: boolean
}
}

View File

@ -0,0 +1,33 @@
import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
import { type GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'
import { type OpenRouterProviderOptions } from './openrouter'
import { type XaiProviderOptions } from './xai'
export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]
/**
* map中没有
*/
export type ProviderOptionsMap = {
openai: OpenAIResponsesProviderOptions
anthropic: AnthropicProviderOptions
google: GoogleGenerativeAIProviderOptions
openrouter: OpenRouterProviderOptions
xai: XaiProviderOptions
}
// 工具类型用于从ProviderOptionsMap中提取特定供应商的选项类型
export type ExtractProviderOptions<T extends keyof ProviderOptionsMap> = ProviderOptionsMap[T]
/**
* ProviderOptions
* 使Record<string, JSONValue>
*/
export type TypedProviderOptions = {
[K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]
} & {
[K in string]?: Record<string, any>
} & SharedV2ProviderMetadata

View File

@ -0,0 +1,86 @@
// copy from @ai-sdk/xai/xai-chat-options.ts
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
import * as z from 'zod/v4'
const webSourceSchema = z.object({
type: z.literal('web'),
country: z.string().length(2).optional(),
excludedWebsites: z.array(z.string()).max(5).optional(),
allowedWebsites: z.array(z.string()).max(5).optional(),
safeSearch: z.boolean().optional()
})
const xSourceSchema = z.object({
type: z.literal('x'),
xHandles: z.array(z.string()).optional()
})
const newsSourceSchema = z.object({
type: z.literal('news'),
country: z.string().length(2).optional(),
excludedWebsites: z.array(z.string()).max(5).optional(),
safeSearch: z.boolean().optional()
})
const rssSourceSchema = z.object({
type: z.literal('rss'),
links: z.array(z.url()).max(1) // currently only supports one RSS link
})
const searchSourceSchema = z.discriminatedUnion('type', [
webSourceSchema,
xSourceSchema,
newsSourceSchema,
rssSourceSchema
])
export const xaiProviderOptions = z.object({
/**
* reasoning effort for reasoning models
* only supported by grok-3-mini and grok-3-mini-fast models
*/
reasoningEffort: z.enum(['low', 'high']).optional(),
searchParameters: z
.object({
/**
* search mode preference
* - "off": disables search completely
* - "auto": model decides whether to search (default)
* - "on": always enables search
*/
mode: z.enum(['off', 'auto', 'on']),
/**
* whether to return citations in the response
* defaults to true
*/
returnCitations: z.boolean().optional(),
/**
* start date for search data (ISO8601 format: YYYY-MM-DD)
*/
fromDate: z.string().optional(),
/**
* end date for search data (ISO8601 format: YYYY-MM-DD)
*/
toDate: z.string().optional(),
/**
* maximum number of search results to consider
* defaults to 20
*/
maxSearchResults: z.number().min(1).max(50).optional(),
/**
* data sources to search from
* defaults to ["web", "x"] if not specified
*/
sources: z.array(searchSourceSchema).optional()
})
.optional()
})
export type XaiProviderOptions = z.infer<typeof xaiProviderOptions>

View File

@ -0,0 +1,257 @@
# AI Core 插件系统
支持四种钩子类型:**First**、**Sequential**、**Parallel** 和 **Stream**
## 🎯 设计理念
- **语义清晰**:不同钩子有不同的执行语义
- **类型安全**TypeScript 完整支持
- **性能优化**First 短路、Parallel 并发、Sequential 链式
- **易于扩展**`enforce` 排序 + 功能分类
## 📋 钩子类型
### 🥇 First 钩子 - 首个有效结果
```typescript
// 只执行第一个返回值的插件,用于解析和查找
resolveModel?: (modelId: string, context: AiRequestContext) => string | null
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null
```
### 🔄 Sequential 钩子 - 链式数据转换
```typescript
// 按顺序链式执行,每个插件可以修改数据
transformParams?: (params: any, context: AiRequestContext) => any
transformResult?: (result: any, context: AiRequestContext) => any
```
### ⚡ Parallel 钩子 - 并行副作用
```typescript
// 并发执行,用于日志、监控等副作用
onRequestStart?: (context: AiRequestContext) => void
onRequestEnd?: (context: AiRequestContext, result: any) => void
onError?: (error: Error, context: AiRequestContext) => void
```
### 🌊 Stream 钩子 - 流处理
```typescript
// 直接使用 AI SDK 的 TransformStream
transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamPart>
```
## 🚀 快速开始
### 基础用法
```typescript
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
// 创建插件管理器
const pluginManager = new PluginManager()
// 添加插件
pluginManager.use({
name: 'my-plugin',
async transformParams(params, context) {
return { ...params, temperature: 0.7 }
}
})
// 使用插件
const context = createContext('openai', 'gpt-4', { messages: [] })
const transformedParams = await pluginManager.executeSequential(
'transformParams',
{ messages: [{ role: 'user', content: 'Hello' }] },
context
)
```
### 完整示例
```typescript
import {
PluginManager,
ModelAliasPlugin,
LoggingPlugin,
ParamsValidationPlugin,
createContext
} from '@cherrystudio/ai-core/middleware'
// 创建插件管理器
const manager = new PluginManager([
ModelAliasPlugin, // 模型别名解析
ParamsValidationPlugin, // 参数验证
LoggingPlugin // 日志记录
])
// AI 请求流程
async function aiRequest(providerId: string, modelId: string, params: any) {
const context = createContext(providerId, modelId, params)
try {
// 1. 【并行】触发请求开始事件
await manager.executeParallel('onRequestStart', context)
// 2. 【首个】解析模型别名
const resolvedModel = await manager.executeFirst('resolveModel', modelId, context)
context.modelId = resolvedModel || modelId
// 3. 【串行】转换请求参数
const transformedParams = await manager.executeSequential('transformParams', params, context)
// 4. 【流处理】收集流转换器AI SDK 原生支持数组)
const streamTransforms = manager.collectStreamTransforms()
// 5. 调用 AI SDK这里省略具体实现
const result = await callAiSdk(transformedParams, streamTransforms)
// 6. 【串行】转换响应结果
const transformedResult = await manager.executeSequential('transformResult', result, context)
// 7. 【并行】触发请求完成事件
await manager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 8. 【并行】触发错误事件
await manager.executeParallel('onError', context, undefined, error)
throw error
}
}
```
## 🔧 自定义插件
### 模型别名插件
```typescript
const ModelAliasPlugin = definePlugin({
name: 'model-alias',
enforce: 'pre', // 最先执行
async resolveModel(modelId) {
const aliases = {
gpt4: 'gpt-4-turbo-preview',
claude: 'claude-3-sonnet-20240229'
}
return aliases[modelId] || null
}
})
```
### 参数验证插件
```typescript
const ValidationPlugin = definePlugin({
name: 'validation',
async transformParams(params) {
if (!params.messages) {
throw new Error('messages is required')
}
return {
...params,
temperature: params.temperature ?? 0.7,
max_tokens: params.max_tokens ?? 4096
}
}
})
```
### 监控插件
```typescript
const MonitoringPlugin = definePlugin({
name: 'monitoring',
enforce: 'post', // 最后执行
async onRequestEnd(context, result) {
const duration = Date.now() - context.startTime
console.log(`请求耗时: ${duration}ms`)
}
})
```
### 内容过滤插件
```typescript
const FilterPlugin = definePlugin({
name: 'content-filter',
transformStream() {
return () =>
new TransformStream({
transform(chunk, controller) {
if (chunk.type === 'text-delta') {
const filtered = chunk.textDelta.replace(/敏感词/g, '***')
controller.enqueue({ ...chunk, textDelta: filtered })
} else {
controller.enqueue(chunk)
}
}
})
}
})
```
## 📊 执行顺序
### 插件排序
```
enforce: 'pre' → normal → enforce: 'post'
```
### 钩子执行流程
```mermaid
graph TD
A[请求开始] --> B[onRequestStart 并行执行]
B --> C[resolveModel 首个有效]
C --> D[loadTemplate 首个有效]
D --> E[transformParams 串行执行]
E --> F[collectStreamTransforms]
F --> G[AI SDK 调用]
G --> H[transformResult 串行执行]
H --> I[onRequestEnd 并行执行]
G --> J[异常处理]
J --> K[onError 并行执行]
```
## 💡 最佳实践
1. **功能单一**:每个插件专注一个功能
2. **幂等性**:插件应该是幂等的,重复执行不会产生副作用
3. **错误处理**:插件内部处理异常,不要让异常向上传播
4. **性能优化**使用合适的钩子类型First vs Sequential vs Parallel
5. **命名规范**:使用语义化的插件名称
## 🔍 调试工具
```typescript
// 查看插件统计信息
const stats = manager.getStats()
console.log('插件统计:', stats)
// 查看所有插件
const plugins = manager.getPlugins()
console.log(
'已注册插件:',
plugins.map((p) => p.name)
)
```
## ⚡ 性能优势
- **First 钩子**:一旦找到结果立即停止,避免无效计算
- **Parallel 钩子**:真正并发执行,不阻塞主流程
- **Sequential 钩子**:保证数据转换的顺序性
- **Stream 钩子**:直接集成 AI SDK零开销
这个设计兼顾了简洁性和强大功能,为 AI Core 提供了灵活而高效的扩展机制。

View File

@ -0,0 +1,10 @@
/**
*
* 'built-in:'
*/
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
export { createLoggingPlugin } from './logging'
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type'
export { webSearchPlugin } from './webSearchPlugin'

View File

@ -0,0 +1,86 @@
/**
*
* AI调用的关键信息
*/
import { definePlugin } from '../index'
import type { AiRequestContext } from '../types'
export interface LoggingConfig {
// 日志级别
level?: 'debug' | 'info' | 'warn' | 'error'
// 是否记录参数
logParams?: boolean
// 是否记录结果
logResult?: boolean
// 是否记录性能数据
logPerformance?: boolean
// 自定义日志函数
logger?: (level: string, message: string, data?: any) => void
}
/**
*
*/
export function createLoggingPlugin(config: LoggingConfig = {}) {
const { level = 'info', logParams = true, logResult = false, logPerformance = true, logger = console.log } = config
const startTimes = new Map<string, number>()
return definePlugin({
name: 'built-in:logging',
onRequestStart: (context: AiRequestContext) => {
const requestId = context.requestId
startTimes.set(requestId, Date.now())
logger(level, `🚀 AI Request Started`, {
requestId,
providerId: context.providerId,
modelId: context.modelId,
originalParams: logParams ? context.originalParams : '[hidden]'
})
},
onRequestEnd: (context: AiRequestContext, result: any) => {
const requestId = context.requestId
const startTime = startTimes.get(requestId)
const duration = startTime ? Date.now() - startTime : undefined
startTimes.delete(requestId)
const logData: any = {
requestId,
providerId: context.providerId,
modelId: context.modelId
}
if (logPerformance && duration) {
logData.duration = `${duration}ms`
}
if (logResult) {
logData.result = result
}
logger(level, `✅ AI Request Completed`, logData)
},
onError: (error: Error, context: AiRequestContext) => {
const requestId = context.requestId
const startTime = startTimes.get(requestId)
const duration = startTime ? Date.now() - startTime : undefined
startTimes.delete(requestId)
logger('error', `❌ AI Request Failed`, {
requestId,
providerId: context.providerId,
modelId: context.modelId,
duration: duration ? `${duration}ms` : undefined,
error: {
name: error.name,
message: error.message,
stack: error.stack
}
})
}
})
}

View File

@ -0,0 +1,139 @@
/**
*
*
* AI SDK
* promptToolUsePlugin.ts
*/
import type { ModelMessage } from 'ai'
import type { AiRequestContext } from '../../types'
import type { StreamController } from './ToolExecutor'
/**
*
*/
export class StreamEventManager {
/**
*
*/
sendStepStartEvent(controller: StreamController): void {
controller.enqueue({
type: 'start-step',
request: {},
warnings: []
})
}
/**
*
*/
sendStepFinishEvent(controller: StreamController, chunk: any): void {
controller.enqueue({
type: 'finish-step',
finishReason: 'stop',
response: chunk.response,
usage: chunk.usage,
providerMetadata: chunk.providerMetadata
})
}
/**
*
*/
async handleRecursiveCall(
controller: StreamController,
recursiveParams: any,
context: AiRequestContext,
stepId: string
): Promise<void> {
try {
console.log('[MCP Prompt] Starting recursive call after tool execution...')
const recursiveResult = await context.recursiveCall(recursiveParams)
if (recursiveResult && recursiveResult.fullStream) {
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
} else {
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
}
} catch (error) {
this.handleRecursiveCallError(controller, error, stepId)
}
}
/**
*
*/
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
const reader = recursiveStream.getReader()
try {
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
if (value.type === 'finish') {
// 迭代的流不发finish
break
}
// 将递归流的数据传递到当前流
controller.enqueue(value)
}
} finally {
reader.releaseLock()
}
}
/**
*
*/
private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
console.error('[MCP Prompt] Recursive call failed:', error)
// 使用 AI SDK 标准错误格式,但不中断流
controller.enqueue({
type: 'error',
error: {
message: error instanceof Error ? error.message : String(error),
name: error instanceof Error ? error.name : 'RecursiveCallError'
}
})
// 继续发送文本增量,保持流的连续性
controller.enqueue({
type: 'text-delta',
id: stepId,
text: '\n\n[工具执行后递归调用失败,继续对话...]'
})
}
/**
*
*/
buildRecursiveParams(context: AiRequestContext, textBuffer: string, toolResultsText: string, tools: any): any {
// 构建新的对话消息
const newMessages: ModelMessage[] = [
...(context.originalParams.messages || []),
{
role: 'assistant',
content: textBuffer
},
{
role: 'user',
content: toolResultsText
}
]
// 递归调用,继续对话,重新传递 tools
const recursiveParams = {
...context.originalParams,
messages: newMessages,
tools: tools
}
// 更新上下文中的消息
context.originalParams.messages = newMessages
return recursiveParams
}
}

View File

@ -0,0 +1,156 @@
/**
*
*
*
* promptToolUsePlugin.ts
*/
import type { ToolSet } from 'ai'
import type { ToolUseResult } from './type'
/**
*
*/
export interface ExecutedResult {
toolCallId: string
toolName: string
result: any
isError?: boolean
}
/**
* AI SDK
*/
export interface StreamController {
enqueue(chunk: any): void
}
/**
*
*/
export class ToolExecutor {
/**
*
*/
async executeTools(
toolUses: ToolUseResult[],
tools: ToolSet,
controller: StreamController
): Promise<ExecutedResult[]> {
const executedResults: ExecutedResult[] = []
for (const toolUse of toolUses) {
try {
const tool = tools[toolUse.toolName]
if (!tool || typeof tool.execute !== 'function') {
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
}
// 发送工具调用开始事件
this.sendToolStartEvents(controller, toolUse)
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
// 发送 tool-call 事件
controller.enqueue({
type: 'tool-call',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: tool.inputSchema
})
const result = await tool.execute(toolUse.arguments, {
toolCallId: toolUse.id,
messages: [],
abortSignal: new AbortController().signal
})
// 发送 tool-result 事件
controller.enqueue({
type: 'tool-result',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: toolUse.arguments,
output: result
})
executedResults.push({
toolCallId: toolUse.id,
toolName: toolUse.toolName,
result,
isError: false
})
} catch (error) {
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
// 处理错误情况
const errorResult = this.handleToolError(toolUse, error, controller)
executedResults.push(errorResult)
}
}
return executedResults
}
/**
* Cherry Studio
*/
formatToolResults(executedResults: ExecutedResult[]): string {
return executedResults
.map((tr) => {
if (!tr.isError) {
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
} else {
const error = tr.result || 'Unknown error'
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <error>${error}</error>\n</tool_use_result>`
}
})
.join('\n\n')
}
/**
*
*/
private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
// 发送 tool-input-start 事件
controller.enqueue({
type: 'tool-input-start',
id: toolUse.id,
toolName: toolUse.toolName
})
}
/**
*
*/
private handleToolError(
toolUse: ToolUseResult,
error: unknown,
controller: StreamController
// _tools: ToolSet
): ExecutedResult {
// 使用 AI SDK 标准错误格式
// const toolError: TypedToolError<typeof _tools> = {
// type: 'tool-error',
// toolCallId: toolUse.id,
// toolName: toolUse.toolName,
// input: toolUse.arguments,
// error: error instanceof Error ? error.message : String(error)
// }
// controller.enqueue(toolError)
// 发送标准错误事件
controller.enqueue({
type: 'error',
error: error instanceof Error ? error.message : String(error)
})
return {
toolCallId: toolUse.id,
toolName: toolUse.toolName,
result: error instanceof Error ? error.message : String(error),
isError: true
}
}
}

View File

@ -0,0 +1,373 @@
/**
* MCP Prompt
* Function Call prompt
*
*/
import type { TextStreamPart, ToolSet } from 'ai'
import { definePlugin } from '../../index'
import type { AiRequestContext } from '../../types'
import { StreamEventManager } from './StreamEventManager'
import { ToolExecutor } from './ToolExecutor'
import { PromptToolUseConfig, ToolUseResult } from './type'
/**
* Cherry Studio
*/
const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \\
You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
## Tool Use Formatting
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
<tool_use>
<name>{tool_name}</name>
<arguments>{json_arguments}</arguments>
</tool_use>
The tool name should be the exact name of the tool you are using, and the arguments should be a JSON object containing the parameters required by that tool. For example:
<tool_use>
<name>python_interpreter</name>
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
</tool_use>
The user will respond with the result of the tool use, which should be formatted as follows:
<tool_use_result>
<name>{tool_name}</name>
<result>{result}</result>
</tool_use_result>
The result should be a string, which can represent a file or any other output type. You can use this result as input for the next action.
For example, if the result of the tool use is an image file, you can use it in the next action like this:
<tool_use>
<name>image_transformer</name>
<arguments>{"image": "image_1.jpg"}</arguments>
</tool_use>
Always adhere to this format for the tool use to ensure proper parsing and execution.
## Tool Use Examples
{{ TOOL_USE_EXAMPLES }}
## Tool Use Available Tools
Above example were using notional tools that might not exist for you. You only have access to these tools:
{{ AVAILABLE_TOOLS }}
## Tool Use Rules
Here are the rules you should always follow to solve your task:
1. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead.
2. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
3. If no tool call is needed, just answer the question directly.
4. Never re-do a tool call that you previously did with the exact same parameters.
5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format.
# User Instructions
{{ USER_SYSTEM_PROMPT }}
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.`
/**
* 使 Cherry Studio
*/
const DEFAULT_TOOL_USE_EXAMPLES = `
Here are a few examples using notional tools:
---
User: Generate an image of the oldest person in this document.
A: I can use the document_qa tool to find out who the oldest person is in the document.
<tool_use>
<name>document_qa</name>
<arguments>{"document": "document.pdf", "question": "Who is the oldest person mentioned?"}</arguments>
</tool_use>
User: <tool_use_result>
<name>document_qa</name>
<result>John Doe, a 55 year old lumberjack living in Newfoundland.</result>
</tool_use_result>
A: I can use the image_generator tool to create a portrait of John Doe.
<tool_use>
<name>image_generator</name>
<arguments>{"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}</arguments>
</tool_use>
User: <tool_use_result>
<name>image_generator</name>
<result>image.png</result>
</tool_use_result>
A: the image is generated as image.png
---
User: "What is the result of the following operation: 5 + 3 + 1294.678?"
A: I can use the python_interpreter tool to calculate the result of the operation.
<tool_use>
<name>python_interpreter</name>
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
</tool_use>
User: <tool_use_result>
<name>python_interpreter</name>
<result>1302.678</result>
</tool_use_result>
A: The result of the operation is 1302.678.
---
User: "Which city has the highest population , Guangzhou or Shanghai?"
A: I can use the search tool to find the population of Guangzhou.
<tool_use>
<name>search</name>
<arguments>{"query": "Population Guangzhou"}</arguments>
</tool_use>
User: <tool_use_result>
<name>search</name>
<result>Guangzhou has a population of 15 million inhabitants as of 2021.</result>
</tool_use_result>
A: I can use the search tool to find the population of Shanghai.
<tool_use>
<name>search</name>
<arguments>{"query": "Population Shanghai"}</arguments>
</tool_use>
User: <tool_use_result>
<name>search</name>
<result>26 million (2019)</result>
</tool_use_result>
Assistant: The population of Shanghai is 26 million, while Guangzhou has a population of 15 million. Therefore, Shanghai has the highest population.`
/**
* Cherry Studio
*/
function buildAvailableTools(tools: ToolSet): string {
const availableTools = Object.keys(tools)
.map((toolName: string) => {
const tool = tools[toolName]
return `
<tool>
<name>${toolName}</name>
<description>${tool.description || ''}</description>
<arguments>
${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}
</arguments>
</tool>
`
})
.join('\n')
return `<tools>
${availableTools}
</tools>`
}
/**
* Cherry Studio
*/
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string {
const availableTools = buildAvailableTools(tools)
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
.replace('{{ AVAILABLE_TOOLS }}', availableTools)
.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt || '')
return fullPrompt
}
/**
* Cherry Studio
* XML
*/
function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUseResult[]; content: string } {
if (!content || !tools || Object.keys(tools).length === 0) {
return { results: [], content: content }
}
// 支持两种格式:
// 1. 完整的 <tool_use></tool_use> 标签包围的内容
// 2. 只有内部内容(从 TagExtractor 提取出来的)
let contentToProcess = content
// 如果内容不包含 <tool_use> 标签,说明是从 TagExtractor 提取的内部内容,需要包装
if (!content.includes('<tool_use>')) {
contentToProcess = `<tool_use>\n${content}\n</tool_use>`
}
const toolUsePattern =
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
const results: ToolUseResult[] = []
let match
let idx = 0
// Find all tool use blocks
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
const fullMatch = match[0]
const toolName = match[2].trim()
const toolArgs = match[4].trim()
// Try to parse the arguments as JSON
let parsedArgs
try {
parsedArgs = JSON.parse(toolArgs)
} catch (error) {
// If parsing fails, use the string as is
parsedArgs = toolArgs
}
// Find the corresponding tool
const tool = tools[toolName]
if (!tool) {
console.warn(`Tool "${toolName}" not found in available tools`)
continue
}
// Add to results array
results.push({
id: `${toolName}-${idx++}`, // Unique ID for each tool use
toolName: toolName,
arguments: parsedArgs,
status: 'pending'
})
contentToProcess = contentToProcess.replace(fullMatch, '')
}
return { results, content: contentToProcess }
}
export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config
return definePlugin({
name: 'built-in:prompt-tool-use',
transformParams: (params: any, context: AiRequestContext) => {
if (!enabled || !params.tools || typeof params.tools !== 'object') {
return params
}
context.mcpTools = params.tools
console.log('tools stored in context', params.tools)
// 构建系统提示符
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
let systemMessage: string | null = systemPrompt
console.log('config.context', context)
if (config.createSystemMessage) {
// 🎯 如果用户提供了自定义处理函数,使用它
systemMessage = config.createSystemMessage(systemPrompt, params, context)
}
// 移除 tools改为 prompt 模式
const transformedParams = {
...params,
...(systemMessage ? { system: systemMessage } : {}),
tools: undefined
}
context.originalParams = transformedParams
console.log('transformedParams', transformedParams)
return transformedParams
},
transformStream: (_: any, context: AiRequestContext) => () => {
let textBuffer = ''
let stepId = ''
if (!context.mcpTools) {
throw new Error('No tools available')
}
// 创建工具执行器和流事件管理器
const toolExecutor = new ToolExecutor()
const streamEventManager = new StreamEventManager()
type TOOLS = NonNullable<typeof context.mcpTools>
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
async transform(
chunk: TextStreamPart<TOOLS>,
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
) {
// 收集文本内容
if (chunk.type === 'text-delta') {
textBuffer += chunk.text || ''
stepId = chunk.id || ''
controller.enqueue(chunk)
return
}
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
const tools = context.mcpTools
if (!tools || Object.keys(tools).length === 0) {
controller.enqueue(chunk)
return
}
// 解析工具调用
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
// 如果没有有效的工具调用,直接传递原始事件
if (validToolUses.length === 0) {
controller.enqueue(chunk)
return
}
if (chunk.type === 'text-end') {
controller.enqueue({
type: 'text-end',
id: stepId,
providerMetadata: {
text: {
value: parsedContent
}
}
})
return
}
controller.enqueue({
...chunk,
finishReason: 'tool-calls'
})
// 发送步骤开始事件
streamEventManager.sendStepStartEvent(controller)
// 执行工具调用
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
// 发送步骤完成事件
streamEventManager.sendStepFinishEvent(controller, chunk)
// 处理递归调用
if (validToolUses.length > 0) {
const toolResultsText = toolExecutor.formatToolResults(executedResults)
const recursiveParams = streamEventManager.buildRecursiveParams(
context,
textBuffer,
toolResultsText,
tools
)
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
}
// 清理状态
textBuffer = ''
return
}
// 对于其他类型的事件,直接传递
controller.enqueue(chunk)
},
flush() {
// 流结束时的清理工作
console.log('[MCP Prompt] Stream ended, cleaning up...')
}
})
}
})
}

View File

@ -0,0 +1,196 @@
// Copied from https://github.com/vercel/ai/blob/main/packages/ai/core/util/get-potential-start-index.ts
/**
* Returns the index of the start of the searchedText in the text, or null if it
* is not found.
*/
export function getPotentialStartIndex(text: string, searchedText: string): number | null {
// Return null immediately if searchedText is empty.
if (searchedText.length === 0) {
return null
}
// Check if the searchedText exists as a direct substring of text.
const directIndex = text.indexOf(searchedText)
if (directIndex !== -1) {
return directIndex
}
// Otherwise, look for the largest suffix of "text" that matches
// a prefix of "searchedText". We go from the end of text inward.
for (let i = text.length - 1; i >= 0; i--) {
const suffix = text.substring(i)
if (searchedText.startsWith(suffix)) {
return i
}
}
return null
}
export interface TagConfig {
openingTag: string
closingTag: string
separator?: string
}
export interface TagExtractionState {
textBuffer: string
isInsideTag: boolean
isFirstTag: boolean
isFirstText: boolean
afterSwitch: boolean
accumulatedTagContent: string
hasTagContent: boolean
}
export interface TagExtractionResult {
content: string
isTagContent: boolean
complete: boolean
tagContentExtracted?: string
}
/**
*
* <think>...</think>, <tool_use>...</tool_use>
*/
export class TagExtractor {
private config: TagConfig
private state: TagExtractionState
constructor(config: TagConfig) {
this.config = config
this.state = {
textBuffer: '',
isInsideTag: false,
isFirstTag: true,
isFirstText: true,
afterSwitch: false,
accumulatedTagContent: '',
hasTagContent: false
}
}
/**
*
*/
processText(newText: string): TagExtractionResult[] {
this.state.textBuffer += newText
const results: TagExtractionResult[] = []
// 处理标签提取逻辑
while (true) {
const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)
if (startIndex == null) {
const content = this.state.textBuffer
if (content.length > 0) {
results.push({
content: this.addPrefix(content),
isTagContent: this.state.isInsideTag,
complete: false
})
if (this.state.isInsideTag) {
this.state.accumulatedTagContent += this.addPrefix(content)
this.state.hasTagContent = true
}
}
this.state.textBuffer = ''
break
}
// 处理标签前的内容
const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
if (contentBeforeTag.length > 0) {
results.push({
content: this.addPrefix(contentBeforeTag),
isTagContent: this.state.isInsideTag,
complete: false
})
if (this.state.isInsideTag) {
this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
this.state.hasTagContent = true
}
}
const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length
if (foundFullMatch) {
// 如果找到完整的标签
this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)
// 如果刚刚结束一个标签内容,生成完整的标签内容结果
if (this.state.isInsideTag && this.state.hasTagContent) {
results.push({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: this.state.accumulatedTagContent
})
this.state.accumulatedTagContent = ''
this.state.hasTagContent = false
}
this.state.isInsideTag = !this.state.isInsideTag
this.state.afterSwitch = true
if (this.state.isInsideTag) {
this.state.isFirstTag = false
} else {
this.state.isFirstText = false
}
} else {
this.state.textBuffer = this.state.textBuffer.slice(startIndex)
break
}
}
return results
}
/**
*
*/
finalize(): TagExtractionResult | null {
if (this.state.hasTagContent && this.state.accumulatedTagContent) {
const result = {
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: this.state.accumulatedTagContent
}
this.state.accumulatedTagContent = ''
this.state.hasTagContent = false
return result
}
return null
}
private addPrefix(text: string): string {
const needsPrefix =
this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)
const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
this.state.afterSwitch = false
return prefix + text
}
/**
*
*/
reset(): void {
this.state = {
textBuffer: '',
isInsideTag: false,
isFirstTag: true,
isFirstText: true,
afterSwitch: false,
accumulatedTagContent: '',
hasTagContent: false
}
}
}

View File

@ -0,0 +1,33 @@
import { ToolSet } from 'ai'
import { AiRequestContext } from '../..'
/**
*
* AI响应中解析出的工具使用意图
*/
export interface ToolUseResult {
id: string
toolName: string
arguments: any
status: 'pending' | 'invoking' | 'done' | 'error'
}
export interface BaseToolUsePluginConfig {
enabled?: boolean
}
export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
// 自定义系统提示符构建函数(可选,有默认实现)
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
// 自定义工具解析函数(可选,有默认实现)
parseToolUse?: (content: string, tools: ToolSet) => { results: ToolUseResult[]; content: string }
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
}
/**
* AI MCP
*/
export interface ToolUseRequestContext extends AiRequestContext {
mcpTools: ToolSet
}

View File

@ -0,0 +1,67 @@
import { anthropic } from '@ai-sdk/anthropic'
import { google } from '@ai-sdk/google'
import { openai } from '@ai-sdk/openai'
import { ProviderOptionsMap } from '../../../options/types'
/**
* AI SDK
*/
type OpenAISearchConfig = Parameters<typeof openai.tools.webSearchPreview>[0]
type AnthropicSearchConfig = Parameters<typeof anthropic.tools.webSearch_20250305>[0]
type GoogleSearchConfig = Parameters<typeof google.tools.googleSearch>[0]
/**
*
*
* ProviderOptions 便
*/
export interface WebSearchPluginConfig {
openai?: OpenAISearchConfig
anthropic?: AnthropicSearchConfig
xai?: ProviderOptionsMap['xai']['searchParameters']
google?: GoogleSearchConfig
'google-vertex'?: GoogleSearchConfig
}
/**
*
*/
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
google: {},
'google-vertex': {},
openai: {},
xai: {
mode: 'on',
returnCitations: true,
maxSearchResults: 5,
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
},
anthropic: {
maxUses: 5
}
}
export type WebSearchToolOutputSchema = {
// Anthropic 工具 - 手动定义
anthropicWebSearch: Array<{
url: string
title: string
pageAge: string | null
encryptedContent: string
type: string
}>
// OpenAI 工具 - 基于实际输出
openaiWebSearch: {
status: 'completed' | 'failed'
}
// Google 工具
googleSearch: {
webSearchQueries?: string[]
groundingChunks?: Array<{
web?: { uri: string; title: string }
}>
}
}

View File

@ -0,0 +1,69 @@
/**
* Web Search Plugin
* AI Provider
*/
import { anthropic } from '@ai-sdk/anthropic'
import { google } from '@ai-sdk/google'
import { openai } from '@ai-sdk/openai'
import { createXaiOptions, mergeProviderOptions } from '../../../options'
import { definePlugin } from '../../'
import type { AiRequestContext } from '../../types'
import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper'
/**
*
*
* @param config -
*/
export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) =>
definePlugin({
name: 'webSearch',
enforce: 'pre',
transformParams: async (params: any, context: AiRequestContext) => {
const { providerId } = context
switch (providerId) {
case 'openai': {
if (config.openai) {
if (!params.tools) params.tools = {}
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
}
break
}
case 'anthropic': {
if (config.anthropic) {
if (!params.tools) params.tools = {}
params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
}
break
}
case 'google': {
// case 'google-vertex':
if (!params.tools) params.tools = {}
params.tools.web_search = google.tools.googleSearch(config.google || {})
break
}
case 'xai': {
if (config.xai) {
const searchOptions = createXaiOptions({
searchParameters: { ...config.xai, mode: 'on' }
})
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
}
break
}
}
return params
}
})
// 导出类型定义供开发者使用
export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
// 默认导出
export default webSearchPlugin

View File

@ -0,0 +1,32 @@
// 核心类型和接口
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
import type { ProviderId } from '../providers'
import type { AiPlugin, AiRequestContext } from './types'
// 插件管理器
export { PluginManager } from './manager'
// 工具函数
export function createContext<T extends ProviderId>(
providerId: T,
modelId: string,
originalParams: any
): AiRequestContext {
return {
providerId,
modelId,
originalParams,
metadata: {},
startTime: Date.now(),
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
// 占位
recursiveCall: () => Promise.resolve(null)
}
}
// 插件构建器 - 便于创建插件
export function definePlugin(plugin: AiPlugin): AiPlugin
export function definePlugin<T extends (...args: any[]) => AiPlugin>(pluginFactory: T): T
export function definePlugin(plugin: AiPlugin | ((...args: any[]) => AiPlugin)) {
return plugin
}

View File

@ -0,0 +1,184 @@
import { AiPlugin, AiRequestContext } from './types'
/**
*
*/
export class PluginManager {
private plugins: AiPlugin[] = []
constructor(plugins: AiPlugin[] = []) {
this.plugins = this.sortPlugins(plugins)
}
/**
*
*/
use(plugin: AiPlugin): this {
this.plugins = this.sortPlugins([...this.plugins, plugin])
return this
}
/**
*
*/
remove(pluginName: string): this {
this.plugins = this.plugins.filter((p) => p.name !== pluginName)
return this
}
/**
* pre -> normal -> post
*/
private sortPlugins(plugins: AiPlugin[]): AiPlugin[] {
const pre: AiPlugin[] = []
const normal: AiPlugin[] = []
const post: AiPlugin[] = []
plugins.forEach((plugin) => {
if (plugin.enforce === 'pre') {
pre.push(plugin)
} else if (plugin.enforce === 'post') {
post.push(plugin)
} else {
normal.push(plugin)
}
})
return [...pre, ...normal, ...post]
}
/**
* First -
*/
async executeFirst<T>(
hookName: 'resolveModel' | 'loadTemplate',
arg: any,
context: AiRequestContext
): Promise<T | null> {
for (const plugin of this.plugins) {
const hook = plugin[hookName]
if (hook) {
const result = await hook(arg, context)
if (result !== null && result !== undefined) {
return result as T
}
}
}
return null
}
/**
* Sequential -
*/
async executeSequential<T>(
hookName: 'transformParams' | 'transformResult',
initialValue: T,
context: AiRequestContext
): Promise<T> {
let result = initialValue
for (const plugin of this.plugins) {
const hook = plugin[hookName]
if (hook) {
result = await hook<T>(result, context)
}
}
return result
}
/**
* ConfigureContext -
*/
async executeConfigureContext(context: AiRequestContext): Promise<void> {
for (const plugin of this.plugins) {
const hook = plugin.configureContext
if (hook) {
await hook(context)
}
}
}
/**
* Parallel -
*/
async executeParallel(
hookName: 'onRequestStart' | 'onRequestEnd' | 'onError',
context: AiRequestContext,
result?: any,
error?: Error
): Promise<void> {
const promises = this.plugins
.map((plugin) => {
const hook = plugin[hookName]
if (!hook) return null
if (hookName === 'onError' && error) {
return (hook as any)(error, context)
} else if (hookName === 'onRequestEnd' && result !== undefined) {
return (hook as any)(context, result)
} else if (hookName === 'onRequestStart') {
return (hook as any)(context)
}
return null
})
.filter(Boolean)
// 使用 Promise.all 而不是 allSettled让插件错误能够抛出
await Promise.all(promises)
}
/**
* AI SDK
*/
collectStreamTransforms(params: any, context: AiRequestContext) {
return this.plugins
.filter((plugin) => plugin.transformStream)
.map((plugin) => plugin.transformStream?.(params, context))
}
/**
*
*/
getPlugins(): AiPlugin[] {
return [...this.plugins]
}
/**
*
*/
getStats() {
const stats = {
total: this.plugins.length,
pre: 0,
normal: 0,
post: 0,
hooks: {
resolveModel: 0,
loadTemplate: 0,
transformParams: 0,
transformResult: 0,
onRequestStart: 0,
onRequestEnd: 0,
onError: 0,
transformStream: 0
}
}
this.plugins.forEach((plugin) => {
// 统计 enforce 类型
if (plugin.enforce === 'pre') stats.pre++
else if (plugin.enforce === 'post') stats.post++
else stats.normal++
// 统计钩子数量
Object.keys(stats.hooks).forEach((hookName) => {
if (plugin[hookName as keyof AiPlugin]) {
stats.hooks[hookName as keyof typeof stats.hooks]++
}
})
})
return stats
}
}

View File

@ -0,0 +1,79 @@
import type { ImageModelV2 } from '@ai-sdk/provider'
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
import { type ProviderId } from '../providers/types'
/**
*
* 使 any
*/
export type RecursiveCallFn = (newParams: any) => Promise<any>
/**
* AI
*/
export interface AiRequestContext {
providerId: ProviderId
modelId: string
originalParams: any
metadata: Record<string, any>
startTime: number
requestId: string
recursiveCall: RecursiveCallFn
isRecursiveCall?: boolean
mcpTools?: ToolSet
[key: string]: any
}
/**
*
*/
export interface AiPlugin {
name: string
enforce?: 'pre' | 'post'
// 【First】首个钩子 - 只执行第一个返回值的插件
resolveModel?: (
modelId: string,
context: AiRequestContext
) => Promise<LanguageModel | ImageModelV2 | null> | LanguageModel | ImageModelV2 | null
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
// 【Sequential】串行钩子 - 链式执行,支持数据转换
configureContext?: (context: AiRequestContext) => void | Promise<void>
transformParams?: <T>(params: T, context: AiRequestContext) => T | Promise<T>
transformResult?: <T>(result: T, context: AiRequestContext) => T | Promise<T>
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
// 【Stream】流处理 - 直接使用 AI SDK
transformStream?: (
params: any,
context: AiRequestContext
) => <TOOLS extends ToolSet>(options?: {
tools: TOOLS
stopStream: () => void
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
// AI SDK 原生中间件
// aiSdkMiddlewares?: LanguageModelV1Middleware[]
}
/**
*
*/
export interface PluginManagerConfig {
plugins: AiPlugin[]
context: Partial<AiRequestContext>
}
/**
*
*/
export interface HookResult<T = any> {
value: T
stop?: boolean
}

View File

@ -0,0 +1,101 @@
/**
* Hub Provider - provider
*
* 支持格式: hubId:providerId:modelId
* 例如: aihubmix:anthropic:claude-3.5-sonnet
*/
import { ProviderV2 } from '@ai-sdk/provider'
import { customProvider } from 'ai'
import { globalRegistryManagement } from './RegistryManagement'
import type { AiSdkMethodName, AiSdkModelReturn, AiSdkModelType } from './types'
export interface HubProviderConfig {
/** Hub的唯一标识符 */
hubId: string
/** 是否启用调试日志 */
debug?: boolean
}
export class HubProviderError extends Error {
constructor(
message: string,
public readonly hubId: string,
public readonly providerId?: string,
public readonly originalError?: Error
) {
super(message)
this.name = 'HubProviderError'
}
}
/**
* Hub模型ID
*/
function parseHubModelId(modelId: string): { provider: string; actualModelId: string } {
const parts = modelId.split(':')
if (parts.length !== 2) {
throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown')
}
return {
provider: parts[0],
actualModelId: parts[1]
}
}
/**
* Hub Provider
*/
export function createHubProvider(config: HubProviderConfig): ProviderV2 {
const { hubId } = config
function getTargetProvider(providerId: string): ProviderV2 {
// 从全局注册表获取provider实例
try {
const provider = globalRegistryManagement.getProvider(providerId)
if (!provider) {
throw new HubProviderError(
`Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`,
hubId,
providerId
)
}
return provider
} catch (error) {
throw new HubProviderError(
`Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
hubId,
providerId,
error instanceof Error ? error : undefined
)
}
}
function resolveModel<T extends AiSdkModelType>(
modelId: string,
modelType: T,
methodName: AiSdkMethodName<T>
): AiSdkModelReturn<T> {
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
const fn = targetProvider[methodName] as (id: string) => AiSdkModelReturn<T>
if (!fn) {
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
}
return fn(actualModelId)
}
return customProvider({
fallbackProvider: {
languageModel: (modelId: string) => resolveModel(modelId, 'text', 'languageModel'),
textEmbeddingModel: (modelId: string) => resolveModel(modelId, 'embedding', 'textEmbeddingModel'),
imageModel: (modelId: string) => resolveModel(modelId, 'image', 'imageModel'),
transcriptionModel: (modelId: string) => resolveModel(modelId, 'transcription', 'transcriptionModel'),
speechModel: (modelId: string) => resolveModel(modelId, 'speech', 'speechModel')
}
})
}

View File

@ -0,0 +1,221 @@
/**
* Provider
* provider
* AI SDK createProviderRegistry
*/
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, ProviderV2 } from '@ai-sdk/provider'
import { createProviderRegistry, type ProviderRegistryProvider } from 'ai'
type PROVIDERS = Record<string, ProviderV2>
export const DEFAULT_SEPARATOR = '|'
// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}`
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
private providers: PROVIDERS = {}
private aliases: Set<string> = new Set() // 记录哪些key是别名
private separator: SEPARATOR
private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null
constructor(options: { separator: SEPARATOR } = { separator: DEFAULT_SEPARATOR as SEPARATOR }) {
this.separator = options.separator
}
/**
* provider
*/
registerProvider(id: string, provider: ProviderV2, aliases?: string[]): this {
// 注册主provider
this.providers[id] = provider
// 注册别名都指向同一个provider实例
if (aliases) {
aliases.forEach((alias) => {
this.providers[alias] = provider // 直接存储引用
this.aliases.add(alias) // 标记为别名
})
}
this.rebuildRegistry()
return this
}
/**
* provider实例
*/
getProvider(id: string): ProviderV2 | undefined {
return this.providers[id]
}
/**
* providers
*/
registerProviders(providers: Record<string, ProviderV2>): this {
Object.assign(this.providers, providers)
this.rebuildRegistry()
return this
}
/**
* provider
*/
unregisterProvider(id: string): this {
const provider = this.providers[id]
if (!provider) return this
// 如果移除的是真实ID需要清理所有指向它的别名
if (!this.aliases.has(id)) {
// 找到所有指向此provider的别名并删除
const aliasesToRemove: string[] = []
this.aliases.forEach((alias) => {
if (this.providers[alias] === provider) {
aliasesToRemove.push(alias)
}
})
aliasesToRemove.forEach((alias) => {
delete this.providers[alias]
this.aliases.delete(alias)
})
} else {
// 如果移除的是别名,只删除别名记录
this.aliases.delete(id)
}
delete this.providers[id]
this.rebuildRegistry()
return this
}
/**
* registry -
*/
private rebuildRegistry(): void {
if (Object.keys(this.providers).length === 0) {
this.registry = null
return
}
this.registry = createProviderRegistry<PROVIDERS, SEPARATOR>(this.providers, {
separator: this.separator
})
}
/**
* - AI SDK
*/
languageModel(id: `${string}${SEPARATOR}${string}`): LanguageModelV2 {
if (!this.registry) {
throw new Error('No providers registered')
}
return this.registry.languageModel(id)
}
/**
* - AI SDK
*/
textEmbeddingModel(id: `${string}${SEPARATOR}${string}`): EmbeddingModelV2<string> {
if (!this.registry) {
throw new Error('No providers registered')
}
return this.registry.textEmbeddingModel(id)
}
/**
* - AI SDK
*/
imageModel(id: `${string}${SEPARATOR}${string}`): ImageModelV2 {
if (!this.registry) {
throw new Error('No providers registered')
}
return this.registry.imageModel(id)
}
/**
* - AI SDK
*/
transcriptionModel(id: `${string}${SEPARATOR}${string}`): any {
if (!this.registry) {
throw new Error('No providers registered')
}
return this.registry.transcriptionModel(id)
}
/**
* - AI SDK
*/
speechModel(id: `${string}${SEPARATOR}${string}`): any {
if (!this.registry) {
throw new Error('No providers registered')
}
return this.registry.speechModel(id)
}
/**
* provider
*/
getRegisteredProviders(): string[] {
return Object.keys(this.providers)
}
/**
* providers
*/
hasProviders(): boolean {
return Object.keys(this.providers).length > 0
}
/**
* providers
*/
clear(): this {
this.providers = {}
this.aliases.clear()
this.registry = null
return this
}
/**
* Provider IDgetAiSdkProviderId使用
* Provider ID
* ID
*/
resolveProviderId(id: string): string {
if (!this.aliases.has(id)) return id // 不是别名,直接返回
// 是别名找到真实ID
const targetProvider = this.providers[id]
for (const [realId, provider] of Object.entries(this.providers)) {
if (provider === targetProvider && !this.aliases.has(realId)) {
return realId
}
}
return id
}
/**
*
*/
isAlias(id: string): boolean {
return this.aliases.has(id)
}
/**
*
*/
getAllAliases(): Record<string, string> {
const result: Record<string, string> = {}
this.aliases.forEach((alias) => {
result[alias] = this.resolveProviderId(alias)
})
return result
}
}
/**
*
* 使 | 因为 : 会和 :free suffix冲突
*/
export const globalRegistryManagement = new RegistryManagement()

View File

@ -0,0 +1,632 @@
/**
* AiProviderRegistry
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'
// 模拟 AI SDK
vi.mock('@ai-sdk/openai', () => ({
createOpenAI: vi.fn(() => ({ name: 'openai-mock' }))
}))
vi.mock('@ai-sdk/anthropic', () => ({
createAnthropic: vi.fn(() => ({ name: 'anthropic-mock' }))
}))
vi.mock('@ai-sdk/azure', () => ({
createAzure: vi.fn(() => ({ name: 'azure-mock' }))
}))
vi.mock('@ai-sdk/deepseek', () => ({
createDeepSeek: vi.fn(() => ({ name: 'deepseek-mock' }))
}))
vi.mock('@ai-sdk/google', () => ({
createGoogleGenerativeAI: vi.fn(() => ({ name: 'google-mock' }))
}))
vi.mock('@ai-sdk/openai-compatible', () => ({
createOpenAICompatible: vi.fn(() => ({ name: 'openai-compatible-mock' }))
}))
vi.mock('@ai-sdk/xai', () => ({
createXai: vi.fn(() => ({ name: 'xai-mock' }))
}))
import {
cleanup,
clearAllProviders,
createAndRegisterProvider,
createProvider,
getAllProviderConfigAliases,
getAllProviderConfigs,
getInitializedProviders,
getLanguageModel,
getProviderConfig,
getProviderConfigByAlias,
getSupportedProviders,
hasInitializedProviders,
hasProviderConfig,
hasProviderConfigByAlias,
isProviderConfigAlias,
ProviderInitializationError,
providerRegistry,
registerMultipleProviderConfigs,
registerProvider,
registerProviderConfig,
resolveProviderConfigId
} from '../registry'
import type { ProviderConfig } from '../schemas'
describe('Provider Registry 功能测试', () => {
beforeEach(() => {
// 清理状态
cleanup()
})
describe('基础功能', () => {
it('能够获取支持的 providers 列表', () => {
const providers = getSupportedProviders()
expect(Array.isArray(providers)).toBe(true)
expect(providers.length).toBeGreaterThan(0)
// 检查返回的数据结构
providers.forEach((provider) => {
expect(provider).toHaveProperty('id')
expect(provider).toHaveProperty('name')
expect(typeof provider.id).toBe('string')
expect(typeof provider.name).toBe('string')
})
// 包含基础 providers
const providerIds = providers.map((p) => p.id)
expect(providerIds).toContain('openai')
expect(providerIds).toContain('anthropic')
expect(providerIds).toContain('google')
})
it('能够获取已初始化的 providers', () => {
// 初始状态下没有已初始化的 providers
expect(getInitializedProviders()).toEqual([])
expect(hasInitializedProviders()).toBe(false)
})
it('能够访问全局注册管理器', () => {
expect(providerRegistry).toBeDefined()
expect(typeof providerRegistry.clear).toBe('function')
expect(typeof providerRegistry.getRegisteredProviders).toBe('function')
expect(typeof providerRegistry.hasProviders).toBe('function')
})
it('能够获取语言模型', () => {
// 在没有注册 provider 的情况下,这个函数应该会抛出错误
expect(() => getLanguageModel('non-existent')).toThrow('No providers registered')
})
})
describe('Provider 配置注册', () => {
it('能够注册自定义 provider 配置', () => {
const config: ProviderConfig = {
id: 'custom-provider',
name: 'Custom Provider',
creator: vi.fn(() => ({ name: 'custom' })),
supportsImageGeneration: false
}
const success = registerProviderConfig(config)
expect(success).toBe(true)
expect(hasProviderConfig('custom-provider')).toBe(true)
expect(getProviderConfig('custom-provider')).toEqual(config)
})
it('能够注册带别名的 provider 配置', () => {
const config: ProviderConfig = {
id: 'custom-provider-with-aliases',
name: 'Custom Provider with Aliases',
creator: vi.fn(() => ({ name: 'custom-aliased' })),
supportsImageGeneration: false,
aliases: ['alias-1', 'alias-2']
}
const success = registerProviderConfig(config)
expect(success).toBe(true)
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
expect(getProviderConfigByAlias('alias-1')).toEqual(config)
expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases')
})
it('拒绝无效的配置', () => {
// 缺少必要字段
const invalidConfig = {
id: 'invalid-provider'
// 缺少 name, creator 等
}
const success = registerProviderConfig(invalidConfig as any)
expect(success).toBe(false)
})
it('能够批量注册 provider 配置', () => {
const configs: ProviderConfig[] = [
{
id: 'provider-1',
name: 'Provider 1',
creator: vi.fn(() => ({ name: 'provider-1' })),
supportsImageGeneration: false
},
{
id: 'provider-2',
name: 'Provider 2',
creator: vi.fn(() => ({ name: 'provider-2' })),
supportsImageGeneration: true
},
{
id: '', // 无效配置
name: 'Invalid Provider',
creator: vi.fn(() => ({ name: 'invalid' })),
supportsImageGeneration: false
} as any
]
const successCount = registerMultipleProviderConfigs(configs)
expect(successCount).toBe(2) // 只有前两个成功
expect(hasProviderConfig('provider-1')).toBe(true)
expect(hasProviderConfig('provider-2')).toBe(true)
expect(hasProviderConfig('')).toBe(false)
})
it('能够获取所有配置和别名信息', () => {
// 注册一些配置
registerProviderConfig({
id: 'test-provider',
name: 'Test Provider',
creator: vi.fn(),
supportsImageGeneration: false,
aliases: ['test-alias']
})
const allConfigs = getAllProviderConfigs()
expect(Array.isArray(allConfigs)).toBe(true)
expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true)
const aliases = getAllProviderConfigAliases()
expect(aliases['test-alias']).toBe('test-provider')
expect(isProviderConfigAlias('test-alias')).toBe(true)
})
})
describe('Provider 创建和注册', () => {
it('能够创建 provider 实例', async () => {
const config: ProviderConfig = {
id: 'test-create-provider',
name: 'Test Create Provider',
creator: vi.fn(() => ({ name: 'test-created' })),
supportsImageGeneration: false
}
// 先注册配置
registerProviderConfig(config)
// 创建 provider 实例
const provider = await createProvider('test-create-provider', { apiKey: 'test' })
expect(provider).toBeDefined()
expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' })
})
it('能够注册 provider 到全局管理器', () => {
const mockProvider = { name: 'mock-provider' }
const config: ProviderConfig = {
id: 'test-register-provider',
name: 'Test Register Provider',
creator: vi.fn(() => mockProvider),
supportsImageGeneration: false
}
// 先注册配置
registerProviderConfig(config)
// 注册 provider 到全局管理器
const success = registerProvider('test-register-provider', mockProvider)
expect(success).toBe(true)
// 验证注册成功
const registeredProviders = getInitializedProviders()
expect(registeredProviders).toContain('test-register-provider')
expect(hasInitializedProviders()).toBe(true)
})
it('能够一步完成创建和注册', async () => {
const config: ProviderConfig = {
id: 'test-create-and-register',
name: 'Test Create and Register',
creator: vi.fn(() => ({ name: 'test-both' })),
supportsImageGeneration: false
}
// 先注册配置
registerProviderConfig(config)
// 一步完成创建和注册
const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' })
expect(success).toBe(true)
// 验证注册成功
const registeredProviders = getInitializedProviders()
expect(registeredProviders).toContain('test-create-and-register')
})
})
describe('Registry 管理', () => {
it('能够清理所有配置和注册的 providers', () => {
// 注册一些配置
registerProviderConfig({
id: 'temp-provider',
name: 'Temp Provider',
creator: vi.fn(() => ({ name: 'temp' })),
supportsImageGeneration: false
})
expect(hasProviderConfig('temp-provider')).toBe(true)
// 清理
cleanup()
expect(hasProviderConfig('temp-provider')).toBe(false)
// 但基础配置应该重新加载
expect(hasProviderConfig('openai')).toBe(true) // 基础 providers 会重新初始化
})
it('能够单独清理已注册的 providers', () => {
// 清理所有 providers
clearAllProviders()
expect(getInitializedProviders()).toEqual([])
expect(hasInitializedProviders()).toBe(false)
})
it('ProviderInitializationError 错误类工作正常', () => {
const error = new ProviderInitializationError('Test error', 'test-provider')
expect(error.message).toBe('Test error')
expect(error.providerId).toBe('test-provider')
expect(error.name).toBe('ProviderInitializationError')
})
})
describe('错误处理', () => {
it('优雅处理空配置', () => {
const success = registerProviderConfig(null as any)
expect(success).toBe(false)
})
it('优雅处理未定义配置', () => {
const success = registerProviderConfig(undefined as any)
expect(success).toBe(false)
})
it('处理空字符串 ID', () => {
const config = {
id: '',
name: 'Empty ID Provider',
creator: vi.fn(() => ({ name: 'empty' })),
supportsImageGeneration: false
}
const success = registerProviderConfig(config)
expect(success).toBe(false)
})
it('处理创建不存在配置的 provider', async () => {
await expect(createProvider('non-existent-provider', {})).rejects.toThrow(
'ProviderConfig not found for id: non-existent-provider'
)
})
it('处理注册不存在配置的 provider', () => {
const mockProvider = { name: 'mock' }
const success = registerProvider('non-existent-provider', mockProvider)
expect(success).toBe(false)
})
it('处理获取不存在配置的情况', () => {
expect(getProviderConfig('non-existent')).toBeUndefined()
expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined()
expect(hasProviderConfig('non-existent')).toBe(false)
expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false)
})
it('处理批量注册时的部分失败', () => {
const mixedConfigs: ProviderConfig[] = [
{
id: 'valid-provider-1',
name: 'Valid Provider 1',
creator: vi.fn(() => ({ name: 'valid-1' })),
supportsImageGeneration: false
},
{
id: '', // 无效配置
name: 'Invalid Provider',
creator: vi.fn(() => ({ name: 'invalid' })),
supportsImageGeneration: false
} as any,
{
id: 'valid-provider-2',
name: 'Valid Provider 2',
creator: vi.fn(() => ({ name: 'valid-2' })),
supportsImageGeneration: true
}
]
const successCount = registerMultipleProviderConfigs(mixedConfigs)
expect(successCount).toBe(2) // 只有两个有效配置成功
expect(hasProviderConfig('valid-provider-1')).toBe(true)
expect(hasProviderConfig('valid-provider-2')).toBe(true)
expect(hasProviderConfig('')).toBe(false)
})
it('处理动态导入失败的情况', async () => {
const config: ProviderConfig = {
id: 'import-test-provider',
name: 'Import Test Provider',
import: vi.fn().mockRejectedValue(new Error('Import failed')),
creatorFunctionName: 'createTest',
supportsImageGeneration: false
}
registerProviderConfig(config)
await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed')
})
})
describe('集成测试', () => {
it('正确处理复杂的配置、创建、注册和清理场景', async () => {
// 初始状态验证
const initialConfigs = getAllProviderConfigs()
expect(initialConfigs.length).toBeGreaterThan(0) // 有基础配置
expect(getInitializedProviders()).toEqual([])
// 注册多个带别名的 provider 配置
const configs: ProviderConfig[] = [
{
id: 'integration-provider-1',
name: 'Integration Provider 1',
creator: vi.fn(() => ({ name: 'integration-1' })),
supportsImageGeneration: false,
aliases: ['alias-1', 'short-name-1']
},
{
id: 'integration-provider-2',
name: 'Integration Provider 2',
creator: vi.fn(() => ({ name: 'integration-2' })),
supportsImageGeneration: true,
aliases: ['alias-2', 'short-name-2']
}
]
const successCount = registerMultipleProviderConfigs(configs)
expect(successCount).toBe(2)
// 验证配置注册成功
expect(hasProviderConfig('integration-provider-1')).toBe(true)
expect(hasProviderConfig('integration-provider-2')).toBe(true)
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
// 验证别名映射
const aliases = getAllProviderConfigAliases()
expect(aliases['alias-1']).toBe('integration-provider-1')
expect(aliases['alias-2']).toBe('integration-provider-2')
// 创建和注册 providers
const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' })
const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' })
expect(success1).toBe(true)
expect(success2).toBe(true)
// 验证注册成功
const registeredProviders = getInitializedProviders()
expect(registeredProviders).toContain('integration-provider-1')
expect(registeredProviders).toContain('integration-provider-2')
expect(hasInitializedProviders()).toBe(true)
// 清理
cleanup()
// 验证清理后的状态
expect(getInitializedProviders()).toEqual([])
expect(hasProviderConfig('integration-provider-1')).toBe(false)
expect(hasProviderConfig('integration-provider-2')).toBe(false)
expect(getAllProviderConfigAliases()).toEqual({})
// 基础配置应该重新加载
expect(hasProviderConfig('openai')).toBe(true)
})
it('正确处理动态导入配置的 provider', async () => {
const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) }
const dynamicImportConfig: ProviderConfig = {
id: 'dynamic-import-provider',
name: 'Dynamic Import Provider',
import: vi.fn().mockResolvedValue(mockModule),
creatorFunctionName: 'createCustomProvider',
supportsImageGeneration: false
}
// 注册配置
const configSuccess = registerProviderConfig(dynamicImportConfig)
expect(configSuccess).toBe(true)
// 创建和注册 provider
const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' })
expect(registerSuccess).toBe(true)
// 验证导入函数被调用
expect(dynamicImportConfig.import).toHaveBeenCalled()
expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' })
// 验证注册成功
expect(getInitializedProviders()).toContain('dynamic-import-provider')
})
it('正确处理大量配置的注册和管理', () => {
const largeConfigList: ProviderConfig[] = []
// 生成50个配置
for (let i = 0; i < 50; i++) {
largeConfigList.push({
id: `bulk-provider-${i}`,
name: `Bulk Provider ${i}`,
creator: vi.fn(() => ({ name: `bulk-${i}` })),
supportsImageGeneration: i % 2 === 0, // 偶数支持图像生成
aliases: [`alias-${i}`, `short-${i}`]
})
}
const successCount = registerMultipleProviderConfigs(largeConfigList)
expect(successCount).toBe(50)
// 验证所有配置都被正确注册
const allConfigs = getAllProviderConfigs()
expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50)
// 验证别名数量
const aliases = getAllProviderConfigAliases()
const bulkAliases = Object.keys(aliases).filter(
(alias) => alias.startsWith('alias-') || alias.startsWith('short-')
)
expect(bulkAliases.length).toBe(100) // 每个 provider 有2个别名
// 随机验证几个配置
expect(hasProviderConfig('bulk-provider-0')).toBe(true)
expect(hasProviderConfig('bulk-provider-25')).toBe(true)
expect(hasProviderConfig('bulk-provider-49')).toBe(true)
// 验证别名工作正常
expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25')
expect(isProviderConfigAlias('short-30')).toBe(true)
// 清理能正确处理大量数据
cleanup()
const cleanupAliases = getAllProviderConfigAliases()
expect(
Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-'))
).toEqual([])
})
})
describe('边界测试', () => {
it('处理包含特殊字符的 provider IDs', () => {
const specialCharsConfigs: ProviderConfig[] = [
{
id: 'provider-with-dashes',
name: 'Provider With Dashes',
creator: vi.fn(() => ({ name: 'dashes' })),
supportsImageGeneration: false
},
{
id: 'provider_with_underscores',
name: 'Provider With Underscores',
creator: vi.fn(() => ({ name: 'underscores' })),
supportsImageGeneration: false
},
{
id: 'provider.with.dots',
name: 'Provider With Dots',
creator: vi.fn(() => ({ name: 'dots' })),
supportsImageGeneration: false
}
]
const successCount = registerMultipleProviderConfigs(specialCharsConfigs)
expect(successCount).toBeGreaterThan(0) // 至少有一些成功
// 验证支持的特殊字符格式
if (hasProviderConfig('provider-with-dashes')) {
expect(getProviderConfig('provider-with-dashes')).toBeDefined()
}
if (hasProviderConfig('provider_with_underscores')) {
expect(getProviderConfig('provider_with_underscores')).toBeDefined()
}
})
it('处理空的批量注册', () => {
const successCount = registerMultipleProviderConfigs([])
expect(successCount).toBe(0)
// 确保没有额外的配置被添加
const configsBefore = getAllProviderConfigs().length
expect(configsBefore).toBeGreaterThan(0) // 应该有基础配置
})
it('处理重复的配置注册', () => {
const config: ProviderConfig = {
id: 'duplicate-test-provider',
name: 'Duplicate Test Provider',
creator: vi.fn(() => ({ name: 'duplicate' })),
supportsImageGeneration: false
}
// 第一次注册成功
expect(registerProviderConfig(config)).toBe(true)
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
// 重复注册相同的配置(允许覆盖)
const updatedConfig: ProviderConfig = {
...config,
name: 'Updated Duplicate Test Provider'
}
expect(registerProviderConfig(updatedConfig)).toBe(true)
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
// 验证配置被更新
const retrievedConfig = getProviderConfig('duplicate-test-provider')
expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider')
})
it('处理极长的 ID 和名称', () => {
const longId = 'very-long-provider-id-' + 'x'.repeat(100)
const longName = 'Very Long Provider Name ' + 'Y'.repeat(100)
const config: ProviderConfig = {
id: longId,
name: longName,
creator: vi.fn(() => ({ name: 'long-test' })),
supportsImageGeneration: false
}
const success = registerProviderConfig(config)
expect(success).toBe(true)
expect(hasProviderConfig(longId)).toBe(true)
const retrievedConfig = getProviderConfig(longId)
expect(retrievedConfig?.name).toBe(longName)
})
it('处理大量别名的配置', () => {
const manyAliases = Array.from({ length: 50 }, (_, i) => `alias-${i}`)
const config: ProviderConfig = {
id: 'provider-with-many-aliases',
name: 'Provider With Many Aliases',
creator: vi.fn(() => ({ name: 'many-aliases' })),
supportsImageGeneration: false,
aliases: manyAliases
}
const success = registerProviderConfig(config)
expect(success).toBe(true)
// 验证所有别名都能正确解析
manyAliases.forEach((alias) => {
expect(hasProviderConfigByAlias(alias)).toBe(true)
expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases')
expect(isProviderConfigAlias(alias)).toBe(true)
})
})
})
})

View File

@ -0,0 +1,264 @@
import { describe, expect, it, vi } from 'vitest'
import {
type BaseProviderId,
baseProviderIds,
baseProviderIdSchema,
baseProviders,
type CustomProviderId,
customProviderIdSchema,
providerConfigSchema,
type ProviderId,
providerIdSchema
} from '../schemas'
describe('Provider Schemas', () => {
describe('baseProviders', () => {
it('包含所有预期的基础 providers', () => {
expect(baseProviders).toBeDefined()
expect(Array.isArray(baseProviders)).toBe(true)
expect(baseProviders.length).toBeGreaterThan(0)
const expectedIds = [
'openai',
'openai-responses',
'openai-compatible',
'anthropic',
'google',
'xai',
'azure',
'deepseek'
]
const actualIds = baseProviders.map((p) => p.id)
expectedIds.forEach((id) => {
expect(actualIds).toContain(id)
})
})
it('每个基础 provider 有必要的属性', () => {
baseProviders.forEach((provider) => {
expect(provider).toHaveProperty('id')
expect(provider).toHaveProperty('name')
expect(provider).toHaveProperty('creator')
expect(provider).toHaveProperty('supportsImageGeneration')
expect(typeof provider.id).toBe('string')
expect(typeof provider.name).toBe('string')
expect(typeof provider.creator).toBe('function')
expect(typeof provider.supportsImageGeneration).toBe('boolean')
})
})
it('provider ID 是唯一的', () => {
const ids = baseProviders.map((p) => p.id)
const uniqueIds = [...new Set(ids)]
expect(ids).toEqual(uniqueIds)
})
})
describe('baseProviderIds', () => {
it('正确提取所有基础 provider IDs', () => {
expect(baseProviderIds).toBeDefined()
expect(Array.isArray(baseProviderIds)).toBe(true)
expect(baseProviderIds.length).toBe(baseProviders.length)
baseProviders.forEach((provider) => {
expect(baseProviderIds).toContain(provider.id)
})
})
})
describe('baseProviderIdSchema', () => {
it('验证有效的基础 provider IDs', () => {
baseProviderIds.forEach((id) => {
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
})
})
it('拒绝无效的基础 provider IDs', () => {
const invalidIds = ['invalid', 'not-exists', '']
invalidIds.forEach((id) => {
expect(baseProviderIdSchema.safeParse(id).success).toBe(false)
})
})
})
describe('customProviderIdSchema', () => {
it('接受有效的自定义 provider IDs', () => {
const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
validIds.forEach((id) => {
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
})
})
it('拒绝与基础 provider IDs 冲突的 IDs', () => {
baseProviderIds.forEach((id) => {
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
})
})
it('拒绝空字符串', () => {
expect(customProviderIdSchema.safeParse('').success).toBe(false)
})
})
describe('providerIdSchema', () => {
it('接受基础 provider IDs', () => {
baseProviderIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(true)
})
})
it('接受有效的自定义 provider IDs', () => {
const validCustomIds = ['custom-provider', 'my-ai-service']
validCustomIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(true)
})
})
it('拒绝无效的 IDs', () => {
const invalidIds = ['', undefined, null, 123]
invalidIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(false)
})
})
})
describe('providerConfigSchema', () => {
it('验证带有 creator 的有效配置', () => {
const validConfig = {
id: 'custom-provider',
name: 'Custom Provider',
creator: vi.fn(),
supportsImageGeneration: true
}
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
})
it('验证带有 import 配置的有效配置', () => {
const validConfig = {
id: 'custom-provider',
name: 'Custom Provider',
import: vi.fn(),
creatorFunctionName: 'createCustom',
supportsImageGeneration: false
}
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
})
it('拒绝既没有 creator 也没有 import 配置的配置', () => {
const invalidConfig = {
id: 'invalid',
name: 'Invalid Provider',
supportsImageGeneration: false
}
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
})
it('为 supportsImageGeneration 设置默认值', () => {
const config = {
id: 'test',
name: 'Test',
creator: vi.fn()
}
const result = providerConfigSchema.safeParse(config)
expect(result.success).toBe(true)
if (result.success) {
expect(result.data.supportsImageGeneration).toBe(false)
}
})
it('拒绝使用基础 provider ID 的配置', () => {
const invalidConfig = {
id: 'openai', // 基础 provider ID
name: 'Should Fail',
creator: vi.fn()
}
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
})
it('拒绝缺少必需字段的配置', () => {
const invalidConfigs = [
{ name: 'Missing ID', creator: vi.fn() },
{ id: 'missing-name', creator: vi.fn() },
{ id: '', name: 'Empty ID', creator: vi.fn() },
{ id: 'valid-custom', name: '', creator: vi.fn() }
]
invalidConfigs.forEach((config) => {
expect(providerConfigSchema.safeParse(config).success).toBe(false)
})
})
})
describe('Schema 验证功能', () => {
it('baseProviderIdSchema 正确验证基础 provider IDs', () => {
baseProviderIds.forEach((id) => {
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
})
expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false)
})
it('customProviderIdSchema 正确验证自定义 provider IDs', () => {
const customIds = ['custom-provider', 'my-service', 'company-llm']
customIds.forEach((id) => {
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
})
// 拒绝基础 provider IDs
baseProviderIds.forEach((id) => {
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
})
})
it('providerIdSchema 接受基础和自定义 provider IDs', () => {
// 基础 IDs
baseProviderIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(true)
})
// 自定义 IDs
const customIds = ['custom-provider', 'my-service']
customIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(true)
})
})
it('providerConfigSchema 验证完整的 provider 配置', () => {
const validConfig = {
id: 'custom-provider',
name: 'Custom Provider',
creator: vi.fn(),
supportsImageGeneration: true
}
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
const invalidConfig = {
id: 'openai', // 不允许基础 provider ID
name: 'OpenAI',
creator: vi.fn()
}
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
})
})
describe('类型推导', () => {
it('BaseProviderId 类型正确', () => {
const id: BaseProviderId = 'openai'
expect(baseProviderIds).toContain(id)
})
it('CustomProviderId 类型是字符串', () => {
const id: CustomProviderId = 'custom-provider'
expect(typeof id).toBe('string')
})
it('ProviderId 类型支持基础和自定义 IDs', () => {
const baseId: ProviderId = 'openai'
const customId: ProviderId = 'custom-provider'
expect(typeof baseId).toBe('string')
expect(typeof customId).toBe('string')
})
})
})

View File

@ -0,0 +1,291 @@
/**
* AI Provider
* Provider
*/
import type { ProviderId, ProviderSettingsMap } from './types'
/**
* Provider
*/
export interface BaseProviderConfig {
apiKey?: string
baseURL?: string
timeout?: number
headers?: Record<string, string>
fetch?: typeof globalThis.fetch
}
/**
* AI SDK Provider
*/
type CompleteProviderConfig<T extends ProviderId> = BaseProviderConfig & Partial<ProviderSettingsMap[T]>
type ConfigHandler<T extends ProviderId> = (
builder: ProviderConfigBuilder<T>,
provider: CompleteProviderConfig<T>
) => void
const configHandlers: {
[K in ProviderId]?: ConfigHandler<K>
} = {
azure: (builder, provider) => {
const azureBuilder = builder as ProviderConfigBuilder<'azure'>
const azureProvider = provider as CompleteProviderConfig<'azure'>
azureBuilder.withAzureConfig({
apiVersion: azureProvider.apiVersion,
resourceName: azureProvider.resourceName
})
}
}
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
private config: CompleteProviderConfig<T> = {} as CompleteProviderConfig<T>
constructor(private providerId: T) {}
/**
* API Key
*/
withApiKey(apiKey: string): this
withApiKey(apiKey: string, options: T extends 'openai' ? { organization?: string; project?: string } : never): this
withApiKey(apiKey: string, options?: any): this {
this.config.apiKey = apiKey
// 类型安全的 OpenAI 特定配置
if (this.providerId === 'openai' && options) {
const openaiConfig = this.config as CompleteProviderConfig<'openai'>
if (options.organization) {
openaiConfig.organization = options.organization
}
if (options.project) {
openaiConfig.project = options.project
}
}
return this
}
/**
* URL
*/
withBaseURL(baseURL: string) {
this.config.baseURL = baseURL
return this
}
/**
*
*/
withRequestConfig(options: { headers?: Record<string, string>; fetch?: typeof fetch }): this {
if (options.headers) {
this.config.headers = { ...this.config.headers, ...options.headers }
}
if (options.fetch) {
this.config.fetch = options.fetch
}
return this
}
/**
* Azure OpenAI
*/
withAzureConfig(options: { apiVersion?: string; resourceName?: string }): T extends 'azure' ? this : never
withAzureConfig(options: any): any {
if (this.providerId === 'azure') {
const azureConfig = this.config as CompleteProviderConfig<'azure'>
if (options.apiVersion) {
azureConfig.apiVersion = options.apiVersion
}
if (options.resourceName) {
azureConfig.resourceName = options.resourceName
}
}
return this
}
/**
*
*/
withCustomParams(params: Record<string, any>) {
Object.assign(this.config, params)
return this
}
/**
*
*/
build(): ProviderSettingsMap[T] {
return this.config as ProviderSettingsMap[T]
}
}
/**
* Provider
* 便
*/
export class ProviderConfigFactory {
/**
*
*/
static builder<T extends ProviderId>(providerId: T): ProviderConfigBuilder<T> {
return new ProviderConfigBuilder(providerId)
}
/**
* Provider对象创建配置 - 使
*/
static fromProvider<T extends ProviderId>(
providerId: T,
provider: CompleteProviderConfig<T>,
options?: {
headers?: Record<string, string>
[key: string]: any
}
): ProviderSettingsMap[T] {
const builder = new ProviderConfigBuilder<T>(providerId)
// 设置基本配置
if (provider.apiKey) {
builder.withApiKey(provider.apiKey)
}
if (provider.baseURL) {
builder.withBaseURL(provider.baseURL)
}
// 设置请求配置
if (options?.headers) {
builder.withRequestConfig({
headers: options.headers
})
}
// 使用配置处理器模式 - 更加优雅和可扩展
const handler = configHandlers[providerId]
if (handler) {
handler(builder, provider)
}
// 添加其他自定义参数
if (options) {
const customOptions = { ...options }
delete customOptions.headers // 已经处理过了
if (Object.keys(customOptions).length > 0) {
builder.withCustomParams(customOptions)
}
}
return builder.build()
}
/**
* OpenAI
*/
static createOpenAI(
apiKey: string,
options?: {
baseURL?: string
organization?: string
project?: string
}
) {
const builder = this.builder('openai')
// 使用类型安全的重载
if (options?.organization || options?.project) {
builder.withApiKey(apiKey, {
organization: options.organization,
project: options.project
})
} else {
builder.withApiKey(apiKey)
}
return builder.withBaseURL(options?.baseURL || 'https://api.openai.com').build()
}
/**
* Anthropic
*/
static createAnthropic(
apiKey: string,
options?: {
baseURL?: string
}
) {
return this.builder('anthropic')
.withApiKey(apiKey)
.withBaseURL(options?.baseURL || 'https://api.anthropic.com')
.build()
}
/**
* Azure OpenAI
*/
static createAzureOpenAI(
apiKey: string,
options: {
baseURL: string
apiVersion?: string
resourceName?: string
}
) {
return this.builder('azure')
.withApiKey(apiKey)
.withBaseURL(options.baseURL)
.withAzureConfig({
apiVersion: options.apiVersion,
resourceName: options.resourceName
})
.build()
}
/**
* Google
*/
static createGoogle(
apiKey: string,
options?: {
baseURL?: string
projectId?: string
location?: string
}
) {
return this.builder('google')
.withApiKey(apiKey)
.withBaseURL(options?.baseURL || 'https://generativelanguage.googleapis.com')
.build()
}
/**
* Vertex AI
*/
static createVertexAI() {
// credentials: {
// clientEmail: string
// privateKey: string
// },
// options?: {
// project?: string
// location?: string
// }
// return this.builder('google-vertex')
// .withGoogleCredentials(credentials)
// .withGoogleVertexConfig({
// project: options?.project,
// location: options?.location
// })
// .build()
}
static createOpenAICompatible(baseURL: string, apiKey: string) {
return this.builder('openai-compatible').withBaseURL(baseURL).withApiKey(apiKey).build()
}
}
/**
* 便
*/
export const createProviderConfig = ProviderConfigFactory.fromProvider
export const providerConfigBuilder = ProviderConfigFactory.builder

View File

@ -0,0 +1,83 @@
/**
* Providers - Provider包
*/
// ==================== 核心管理器 ====================
// Provider 注册表管理器
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'
// Provider 核心功能
export {
// 状态管理
cleanup,
clearAllProviders,
createAndRegisterProvider,
createProvider,
getAllProviderConfigAliases,
getAllProviderConfigs,
getImageModel,
// 工具函数
getInitializedProviders,
getLanguageModel,
getProviderConfig,
getProviderConfigByAlias,
getSupportedProviders,
getTextEmbeddingModel,
hasInitializedProviders,
// 工具函数
hasProviderConfig,
// 别名支持
hasProviderConfigByAlias,
isProviderConfigAlias,
// 错误类型
ProviderInitializationError,
// 全局访问
providerRegistry,
registerMultipleProviderConfigs,
registerProvider,
// 统一Provider系统
registerProviderConfig,
resolveProviderConfigId
} from './registry'
// ==================== 基础数据和类型 ====================
// 基础Provider数据源
export { baseProviderIds, baseProviders } from './schemas'
// 类型定义和Schema
export type {
BaseProviderId,
CustomProviderId,
DynamicProviderRegistration,
ProviderConfig,
ProviderId
} from './schemas' // 从 schemas 导出的类型
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出
export type {
DynamicProviderRegistry,
ExtensibleProviderSettingsMap,
ProviderError,
ProviderSettingsMap,
ProviderTypeRegistrar
} from './types'
// ==================== 工具函数 ====================
// Provider配置工厂
export {
type BaseProviderConfig,
createProviderConfig,
ProviderConfigBuilder,
providerConfigBuilder,
ProviderConfigFactory
} from './factory'
// 工具函数
export { formatPrivateKey } from './utils'
// ==================== 扩展功能 ====================
// Hub Provider 功能
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'

View File

@ -0,0 +1,320 @@
/**
* Provider
* providers
* ModelCreator
*/
import { customProvider } from 'ai'
import { globalRegistryManagement } from './RegistryManagement'
import { baseProviders, type ProviderConfig } from './schemas'
/**
* Provider
*/
class ProviderInitializationError extends Error {
constructor(
message: string,
public providerId?: string,
public cause?: Error
) {
super(message)
this.name = 'ProviderInitializationError'
}
}
// ==================== 全局管理器导出 ====================
export { globalRegistryManagement as providerRegistry }
// ==================== 便捷访问方法 ====================
export const getLanguageModel = (id: string) => globalRegistryManagement.languageModel(id as any)
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.textEmbeddingModel(id as any)
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
// ==================== 工具函数 ====================
/**
* Providers
*/
export function getSupportedProviders(): Array<{
id: string
name: string
}> {
return baseProviders.map((provider) => ({
id: provider.id,
name: provider.name
}))
}
/**
* providers
*/
export function getInitializedProviders(): string[] {
return globalRegistryManagement.getRegisteredProviders()
}
/**
* providers
*/
export function hasInitializedProviders(): boolean {
return globalRegistryManagement.hasProviders()
}
// ==================== 统一Provider配置系统 ====================
// 全局Provider配置存储
const providerConfigs = new Map<string, ProviderConfig>()
// 全局ProviderConfig别名映射 - 借鉴RegistryManagement模式
const providerConfigAliases = new Map<string, string>() // alias -> realId
/**
* - baseProviders转换为统一格式
*/
function initializeBuiltInConfigs(): void {
baseProviders.forEach((provider) => {
const config: ProviderConfig = {
id: provider.id,
name: provider.name,
creator: provider.creator as any, // 类型转换以兼容多种creator签名
supportsImageGeneration: provider.supportsImageGeneration || false
}
providerConfigs.set(provider.id, config)
})
}
// 启动时自动注册内置配置
initializeBuiltInConfigs()
/**
* 步骤1: 注册Provider配置 -
*/
export function registerProviderConfig(config: ProviderConfig): boolean {
try {
// 验证配置
if (!config || !config.id || !config.name) {
return false
}
// 检查是否与已有配置冲突(包括内置配置)
if (providerConfigs.has(config.id)) {
console.warn(`ProviderConfig "${config.id}" already exists, will override`)
}
// 存储配置(内置和用户配置统一处理)
providerConfigs.set(config.id, config)
// 处理别名
if (config.aliases && config.aliases.length > 0) {
config.aliases.forEach((alias) => {
if (providerConfigAliases.has(alias)) {
console.warn(`ProviderConfig alias "${alias}" already exists, will override`)
}
providerConfigAliases.set(alias, config.id)
})
}
return true
} catch (error) {
console.error(`Failed to register ProviderConfig:`, error)
return false
}
}
/**
* 步骤2: 创建Provider -
*/
export async function createProvider(providerId: string, options: any): Promise<any> {
// 支持通过别名查找配置
const config = getProviderConfigByAlias(providerId)
if (!config) {
throw new Error(`ProviderConfig not found for id: ${providerId}`)
}
try {
let creator: (options: any) => any
if (config.creator) {
// 方式1: 直接执行 creator
creator = config.creator
} else if (config.import && config.creatorFunctionName) {
// 方式2: 动态导入并执行
const module = await config.import()
creator = (module as any)[config.creatorFunctionName]
if (!creator || typeof creator !== 'function') {
throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`)
}
} else {
throw new Error('No valid creator method provided in ProviderConfig')
}
// 使用真实配置创建provider实例
return creator(options)
} catch (error) {
console.error(`Failed to create provider "${providerId}":`, error)
throw error
}
}
/**
* 步骤3: 注册Provider到全局管理器
*/
export function registerProvider(providerId: string, provider: any): boolean {
try {
const config = providerConfigs.get(providerId)
if (!config) {
console.error(`ProviderConfig not found for id: ${providerId}`)
return false
}
// 获取aliases配置
const aliases = config.aliases
// 处理特殊provider逻辑
if (providerId === 'openai') {
// 注册默认 openai
globalRegistryManagement.registerProvider(providerId, provider, aliases)
// 创建并注册 openai-chat 变体
const openaiChatProvider = customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.chat(modelId)
}
})
globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider)
} else if (providerId === 'azure') {
globalRegistryManagement.registerProvider(`${providerId}-chat`, provider, aliases)
// 跟上面相反,creator产出的默认会调用chat
const azureResponsesProvider = customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.responses(modelId)
}
})
globalRegistryManagement.registerProvider(providerId, azureResponsesProvider)
} else {
// 其他provider直接注册
globalRegistryManagement.registerProvider(providerId, provider, aliases)
}
return true
} catch (error) {
console.error(`Failed to register provider "${providerId}" to global registry:`, error)
return false
}
}
/**
* 便捷函数: 一次性完成创建+
*/
export async function createAndRegisterProvider(providerId: string, options: any): Promise<boolean> {
try {
// 步骤2: 创建provider
const provider = await createProvider(providerId, options)
// 步骤3: 注册到全局管理器
return registerProvider(providerId, provider)
} catch (error) {
console.error(`Failed to create and register provider "${providerId}":`, error)
return false
}
}
/**
* Provider配置
*/
export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number {
let successCount = 0
configs.forEach((config) => {
if (registerProviderConfig(config)) {
successCount++
}
})
return successCount
}
/**
* Provider配置
*/
export function hasProviderConfig(providerId: string): boolean {
return providerConfigs.has(providerId)
}
/**
* ID检查是否有对应的Provider配置
*/
export function hasProviderConfigByAlias(aliasOrId: string): boolean {
const realId = resolveProviderConfigId(aliasOrId)
return providerConfigs.has(realId)
}
/**
* Provider配置
*/
export function getAllProviderConfigs(): ProviderConfig[] {
return Array.from(providerConfigs.values())
}
/**
* ID获取Provider配置
*/
export function getProviderConfig(providerId: string): ProviderConfig | undefined {
return providerConfigs.get(providerId)
}
/**
* ID获取Provider配置
*/
export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined {
// 先检查是否为别名如果是则解析为真实ID
const realId = providerConfigAliases.get(aliasOrId) || aliasOrId
return providerConfigs.get(realId)
}
/**
* ProviderConfig ID
*/
export function resolveProviderConfigId(aliasOrId: string): string {
return providerConfigAliases.get(aliasOrId) || aliasOrId
}
/**
* ProviderConfig别名
*/
export function isProviderConfigAlias(id: string): boolean {
return providerConfigAliases.has(id)
}
/**
* ProviderConfig别名映射关系
*/
export function getAllProviderConfigAliases(): Record<string, string> {
const result: Record<string, string> = {}
providerConfigAliases.forEach((realId, alias) => {
result[alias] = realId
})
return result
}
/**
* Provider配置和已注册的providers
*/
export function cleanup(): void {
providerConfigs.clear()
providerConfigAliases.clear() // 清理别名映射
globalRegistryManagement.clear()
// 重新初始化内置配置
initializeBuiltInConfigs()
}
export function clearAllProviders(): void {
globalRegistryManagement.clear()
}
// ==================== 导出错误类型 ====================
export { ProviderInitializationError }

View File

@ -0,0 +1,178 @@
/**
* Provider Config
*/
import { createAnthropic } from '@ai-sdk/anthropic'
import { createAzure } from '@ai-sdk/azure'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { createDeepSeek } from '@ai-sdk/deepseek'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import { createXai } from '@ai-sdk/xai'
import { customProvider, type Provider } from 'ai'
import * as z from 'zod'
/**
* Provider IDs
*/
export const baseProviderIds = [
'openai',
'openai-chat',
'openai-compatible',
'anthropic',
'google',
'xai',
'azure',
'azure-responses',
'deepseek'
] as const
/**
* Provider ID Schema
*/
export const baseProviderIdSchema = z.enum(baseProviderIds)
/**
* Provider ID
*/
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
export const baseProviderSchema = z.object({
id: baseProviderIdSchema,
name: z.string(),
creator: z.function().args(z.any()).returns(z.any()) as z.ZodType<(options: any) => Provider>,
supportsImageGeneration: z.boolean()
})
export type BaseProvider = z.infer<typeof baseProviderSchema>
/**
* Providers
*
*/
export const baseProviders = [
{
id: 'openai',
name: 'OpenAI',
creator: createOpenAI,
supportsImageGeneration: true
},
{
id: 'openai-chat',
name: 'OpenAI Chat',
creator: (options: OpenAIProviderSettings) => {
const provider = createOpenAI(options)
return customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.chat(modelId)
}
})
},
supportsImageGeneration: true
},
{
id: 'openai-compatible',
name: 'OpenAI Compatible',
creator: createOpenAICompatible,
supportsImageGeneration: true
},
{
id: 'anthropic',
name: 'Anthropic',
creator: createAnthropic,
supportsImageGeneration: false
},
{
id: 'google',
name: 'Google Generative AI',
creator: createGoogleGenerativeAI,
supportsImageGeneration: true
},
{
id: 'xai',
name: 'xAI (Grok)',
creator: createXai,
supportsImageGeneration: true
},
{
id: 'azure',
name: 'Azure OpenAI',
creator: createAzure,
supportsImageGeneration: true
},
{
id: 'azure-responses',
name: 'Azure OpenAI Responses',
creator: (options: AzureOpenAIProviderSettings) => {
const provider = createAzure(options)
return customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.responses(modelId)
}
})
},
supportsImageGeneration: true
},
{
id: 'deepseek',
name: 'DeepSeek',
creator: createDeepSeek,
supportsImageGeneration: false
}
] as const satisfies BaseProvider[]
/**
* Provider ID Schema
* provider IDs
*/
export const customProviderIdSchema = z
.string()
.min(1)
.refine((id) => !baseProviderIds.includes(id as any), {
message: 'Custom provider ID cannot conflict with base provider IDs'
})
/**
* Provider ID Schema -
*/
export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema])
/**
* Provider Schema
* Provider的配置验证
*/
export const providerConfigSchema = z
.object({
id: customProviderIdSchema, // 只允许自定义ID
name: z.string().min(1),
creator: z.function().optional(),
import: z.function().optional(),
creatorFunctionName: z.string().optional(),
supportsImageGeneration: z.boolean().default(false),
imageCreator: z.function().optional(),
validateOptions: z.function().optional(),
aliases: z.array(z.string()).optional()
})
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
message: 'Must provide either creator function or import configuration'
})
/**
* Provider ID - zod schema
*/
export type ProviderId = z.infer<typeof providerIdSchema>
export type CustomProviderId = z.infer<typeof customProviderIdSchema>
/**
* Provider
*/
export type ProviderConfig = z.infer<typeof providerConfigSchema>
/**
*
* @deprecated 使 ProviderConfig
*/
export type DynamicProviderRegistration = ProviderConfig

View File

@ -0,0 +1,96 @@
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import {
EmbeddingModelV2 as EmbeddingModel,
ImageModelV2 as ImageModel,
LanguageModelV2 as LanguageModel,
ProviderV2,
SpeechModelV2 as SpeechModel,
TranscriptionModelV2 as TranscriptionModel
} from '@ai-sdk/provider'
import { type XaiProviderSettings } from '@ai-sdk/xai'
// 导入基于 Zod 的 ProviderId 类型
import { type ProviderId as ZodProviderId } from './schemas'
export interface ExtensibleProviderSettingsMap {
// 基础的静态providers
openai: OpenAIProviderSettings
'openai-responses': OpenAIProviderSettings
'openai-compatible': OpenAICompatibleProviderSettings
anthropic: AnthropicProviderSettings
google: GoogleGenerativeAIProviderSettings
xai: XaiProviderSettings
azure: AzureOpenAIProviderSettings
deepseek: DeepSeekProviderSettings
}
// 动态扩展的provider类型注册表
export interface DynamicProviderRegistry {
[key: string]: any
}
// 合并基础和动态provider类型
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
// 错误类型
export class ProviderError extends Error {
constructor(
message: string,
public providerId: string,
public code?: string,
public cause?: Error
) {
super(message)
this.name = 'ProviderError'
}
}
// 动态ProviderId类型 - 基于 Zod Schema支持运行时扩展和验证
export type ProviderId = ZodProviderId
export interface ProviderTypeRegistrar {
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
getProviderSettings<T extends string>(providerId: T): any
}
// 重新导出所有类型供外部使用
export type {
AnthropicProviderSettings,
AzureOpenAIProviderSettings,
DeepSeekProviderSettings,
GoogleGenerativeAIProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
XaiProviderSettings
}
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel<string> | TranscriptionModel | SpeechModel
export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech'
export const METHOD_MAP = {
text: 'languageModel',
image: 'imageModel',
embedding: 'textEmbeddingModel',
transcription: 'transcriptionModel',
speech: 'speechModel'
} as const satisfies Record<AiSdkModelType, keyof ProviderV2>
export type AiSdkModelMethodMap = Record<AiSdkModelType, keyof ProviderV2>
export type AiSdkModelReturnMap = {
text: LanguageModel
image: ImageModel
embedding: EmbeddingModel<string>
transcription: TranscriptionModel
speech: SpeechModel
}
export type AiSdkMethodName<T extends AiSdkModelType> = (typeof METHOD_MAP)[T]
export type AiSdkModelReturn<T extends AiSdkModelType> = AiSdkModelReturnMap[T]

View File

@ -0,0 +1,86 @@
/**
* PEM头部和尾部
*/
export function formatPrivateKey(privateKey: string): string {
if (!privateKey || typeof privateKey !== 'string') {
throw new Error('Private key must be a non-empty string')
}
// 先处理 JSON 字符串中的转义换行符
const key = privateKey.replace(/\\n/g, '\n')
// 检查是否已经是正确格式的 PEM 私钥
const hasBeginMarker = key.includes('-----BEGIN PRIVATE KEY-----')
const hasEndMarker = key.includes('-----END PRIVATE KEY-----')
if (hasBeginMarker && hasEndMarker) {
// 已经是 PEM 格式,但可能格式不规范,重新格式化
return normalizePemFormat(key)
}
// 如果没有完整的 PEM 头尾,尝试重新构建
return reconstructPemKey(key)
}
/**
* PEM
*/
function normalizePemFormat(pemKey: string): string {
// 分离头部、内容和尾部
const lines = pemKey
.split('\n')
.map((line) => line.trim())
.filter((line) => line.length > 0)
let keyContent = ''
let foundBegin = false
let foundEnd = false
for (const line of lines) {
if (line === '-----BEGIN PRIVATE KEY-----') {
foundBegin = true
continue
}
if (line === '-----END PRIVATE KEY-----') {
foundEnd = true
break
}
if (foundBegin && !foundEnd) {
keyContent += line
}
}
if (!foundBegin || !foundEnd || !keyContent) {
throw new Error('Invalid PEM format: missing BEGIN/END markers or key content')
}
// 重新格式化为 64 字符一行
const formattedContent = keyContent.match(/.{1,64}/g)?.join('\n') || keyContent
return `-----BEGIN PRIVATE KEY-----\n${formattedContent}\n-----END PRIVATE KEY-----`
}
/**
* PEM
*/
function reconstructPemKey(key: string): string {
// 移除所有空白字符和可能存在的不完整头尾
let cleanKey = key.replace(/\s+/g, '')
cleanKey = cleanKey.replace(/-----BEGIN[^-]*-----/g, '')
cleanKey = cleanKey.replace(/-----END[^-]*-----/g, '')
// 确保私钥内容不为空
if (!cleanKey) {
throw new Error('Private key content is empty after cleaning')
}
// 验证是否是有效的 Base64 字符
if (!/^[A-Za-z0-9+/=]+$/.test(cleanKey)) {
throw new Error('Private key contains invalid characters (not valid Base64)')
}
// 格式化为 64 字符一行
const formattedKey = cleanKey.match(/.{1,64}/g)?.join('\n') || cleanKey
return `-----BEGIN PRIVATE KEY-----\n${formattedKey}\n-----END PRIVATE KEY-----`
}

View File

@ -0,0 +1,523 @@
import { ImageModelV2 } from '@ai-sdk/provider'
import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { type AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { ImageGenerationError, ImageModelResolutionError } from '../errors'
import { RuntimeExecutor } from '../executor'
// Mock dependencies
vi.mock('ai', () => ({
experimental_generateImage: vi.fn(),
NoImageGeneratedError: class NoImageGeneratedError extends Error {
static isInstance = vi.fn()
constructor() {
super('No image generated')
this.name = 'NoImageGeneratedError'
}
}
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
imageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.generateImage', () => {
let executor: RuntimeExecutor<'openai'>
let mockImageModel: ImageModelV2
let mockGenerateImageResult: any
beforeEach(() => {
// Reset all mocks
vi.clearAllMocks()
// Create executor instance
executor = RuntimeExecutor.create('openai', {
apiKey: 'test-key'
})
// Mock image model
mockImageModel = {
modelId: 'dall-e-3',
provider: 'openai'
} as ImageModelV2
// Mock generateImage result
mockGenerateImageResult = {
image: {
base64: 'base64-encoded-image-data',
uint8Array: new Uint8Array([1, 2, 3]),
mediaType: 'image/png'
},
images: [
{
base64: 'base64-encoded-image-data',
uint8Array: new Uint8Array([1, 2, 3]),
mediaType: 'image/png'
}
],
warnings: [],
providerMetadata: {
openai: {
images: [{ revisedPrompt: 'A detailed prompt' }]
}
},
responses: []
}
// Setup mocks to avoid "No providers registered" error
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
})
describe('Basic functionality', () => {
it('should generate a single image with minimal parameters', async () => {
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A futuristic cityscape at sunset'
})
expect(result).toEqual(mockGenerateImageResult)
})
it('should generate image with pre-created model', async () => {
const result = await executor.generateImage({
model: mockImageModel,
prompt: 'A beautiful landscape'
})
// Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A beautiful landscape'
})
expect(result).toEqual(mockGenerateImageResult)
})
it('should support multiple images generation', async () => {
await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape', n: 3 })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A futuristic cityscape',
n: 3
})
})
it('should support size specification', async () => {
await executor.generateImage({ model: 'dall-e-3', prompt: 'A beautiful sunset', size: '1024x1024' })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A beautiful sunset',
size: '1024x1024'
})
})
it('should support aspect ratio specification', async () => {
await executor.generateImage({ model: 'dall-e-3', prompt: 'A mountain landscape', aspectRatio: '16:9' })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A mountain landscape',
aspectRatio: '16:9'
})
})
it('should support seed for consistent output', async () => {
await executor.generateImage({ model: 'dall-e-3', prompt: 'A cat in space', seed: 1234567890 })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A cat in space',
seed: 1234567890
})
})
it('should support abort signal', async () => {
const abortController = new AbortController()
await executor.generateImage({ model: 'dall-e-3', prompt: 'A cityscape', abortSignal: abortController.signal })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A cityscape',
abortSignal: abortController.signal
})
})
it('should support provider-specific options', async () => {
await executor.generateImage({
model: 'dall-e-3',
prompt: 'A space station',
providerOptions: {
openai: {
quality: 'hd',
style: 'vivid'
}
}
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A space station',
providerOptions: {
openai: {
quality: 'hd',
style: 'vivid'
}
}
})
})
it('should support custom headers', async () => {
await executor.generateImage({
model: 'dall-e-3',
prompt: 'A robot',
headers: {
'X-Custom-Header': 'test-value'
}
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A robot',
headers: {
'X-Custom-Header': 'test-value'
}
})
})
})
describe('Plugin integration', () => {
it('should execute plugins in correct order', async () => {
const pluginCallOrder: string[] = []
const testPlugin: AiPlugin = {
name: 'test-plugin',
onRequestStart: vi.fn(async () => {
pluginCallOrder.push('onRequestStart')
}),
transformParams: vi.fn(async (params) => {
pluginCallOrder.push('transformParams')
return { ...params, size: '512x512' }
}),
transformResult: vi.fn(async (result) => {
pluginCallOrder.push('transformResult')
return { ...result, processed: true }
}),
onRequestEnd: vi.fn(async () => {
pluginCallOrder.push('onRequestEnd')
})
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[testPlugin]
)
const result = await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
expect(testPlugin.transformParams).toHaveBeenCalledWith(
{ prompt: 'A test image' },
expect.objectContaining({
providerId: 'openai',
modelId: 'dall-e-3'
})
)
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A test image',
size: '512x512' // Should be transformed by plugin
})
expect(result).toEqual({
...mockGenerateImageResult,
processed: true // Should be transformed by plugin
})
})
it('should handle model resolution through plugins', async () => {
const customImageModel = {
modelId: 'custom-model',
provider: 'openai'
} as ImageModelV2
const modelResolutionPlugin: AiPlugin = {
name: 'model-resolver',
resolveModel: vi.fn(async () => customImageModel)
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[modelResolutionPlugin]
)
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
'dall-e-3',
expect.objectContaining({
providerId: 'openai',
modelId: 'dall-e-3'
})
)
expect(aiGenerateImage).toHaveBeenCalledWith({
model: customImageModel,
prompt: 'A test image'
})
})
it('should support recursive calls from plugins', async () => {
const recursivePlugin: AiPlugin = {
name: 'recursive-plugin',
transformParams: vi.fn(async (params, context) => {
if (!context.isRecursiveCall && params.prompt === 'original') {
// Make a recursive call with modified prompt
await context.recursiveCall({
model: 'dall-e-3',
prompt: 'modified'
})
}
return params
})
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[recursivePlugin]
)
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'original' })
expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2)
expect(aiGenerateImage).toHaveBeenCalledTimes(2)
})
})
describe('Error handling', () => {
it('should handle model creation errors', async () => {
const modelError = new Error('Failed to get image model')
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
throw modelError
})
await expect(executor.generateImage({ model: 'invalid-model', prompt: 'A test image' })).rejects.toThrow(
ImageGenerationError
)
})
it('should handle ImageModelResolutionError correctly', async () => {
const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found'))
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
throw resolutionError
})
const thrownError = await executor
.generateImage({ model: 'invalid-model', prompt: 'A test image' })
.catch((error) => error)
expect(thrownError).toBeInstanceOf(ImageGenerationError)
expect(thrownError.message).toContain('Failed to generate image:')
expect(thrownError.providerId).toBe('openai')
expect(thrownError.modelId).toBe('invalid-model')
expect(thrownError.cause).toBeInstanceOf(ImageModelResolutionError)
expect(thrownError.cause.message).toContain('Failed to resolve image model: invalid-model')
})
it('should handle ImageModelResolutionError without provider', async () => {
const resolutionError = new ImageModelResolutionError('unknown-model')
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
throw resolutionError
})
await expect(executor.generateImage({ model: 'unknown-model', prompt: 'A test image' })).rejects.toThrow(
ImageGenerationError
)
})
it('should handle image generation API errors', async () => {
const apiError = new Error('API request failed')
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
'Failed to generate image:'
)
})
it('should handle NoImageGeneratedError', async () => {
const noImageError = new NoImageGeneratedError({
cause: new Error('No image generated'),
responses: []
})
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
'Failed to generate image:'
)
})
it('should execute onError plugin hook on failure', async () => {
const error = new Error('Generation failed')
vi.mocked(aiGenerateImage).mockRejectedValue(error)
const errorPlugin: AiPlugin = {
name: 'error-handler',
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[errorPlugin]
)
await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
'Failed to generate image:'
)
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
modelId: 'dall-e-3'
})
)
})
it('should handle abort signal timeout', async () => {
const abortError = new Error('Operation was aborted')
abortError.name = 'AbortError'
vi.mocked(aiGenerateImage).mockRejectedValue(abortError)
const abortController = new AbortController()
setTimeout(() => abortController.abort(), 10)
await expect(
executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal })
).rejects.toThrow('Failed to generate image:')
})
})
describe('Multiple providers support', () => {
it('should work with different providers', async () => {
const googleExecutor = RuntimeExecutor.create('google', {
apiKey: 'google-key'
})
await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002')
})
it('should support xAI Grok image models', async () => {
const xaiExecutor = RuntimeExecutor.create('xai', {
apiKey: 'xai-key'
})
await xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai|grok-2-image')
})
})
describe('Advanced features', () => {
it('should support batch image generation with maxImagesPerCall', async () => {
await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', n: 10, maxImagesPerCall: 5 })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A test image',
n: 10,
maxImagesPerCall: 5
})
})
it('should support retries with maxRetries', async () => {
await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', maxRetries: 3 })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A test image',
maxRetries: 3
})
})
it('should handle warnings from the model', async () => {
const resultWithWarnings = {
...mockGenerateImageResult,
warnings: [
{
type: 'unsupported-setting',
message: 'Size parameter not supported for this model'
}
]
}
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings)
const result = await executor.generateImage({
model: 'dall-e-3',
prompt: 'A test image',
size: '2048x2048' // Unsupported size
})
expect(result.warnings).toHaveLength(1)
expect(result.warnings[0].type).toBe('unsupported-setting')
})
it('should provide access to provider metadata', async () => {
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
expect(result.providerMetadata).toBeDefined()
expect(result.providerMetadata.openai).toBeDefined()
})
it('should provide response metadata', async () => {
const resultWithMetadata = {
...mockGenerateImageResult,
responses: [
{
timestamp: new Date(),
modelId: 'dall-e-3',
headers: { 'x-request-id': 'test-123' }
}
]
}
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata)
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
expect(result.responses).toHaveLength(1)
expect(result.responses[0].modelId).toBe('dall-e-3')
expect(result.responses[0].headers).toEqual({ 'x-request-id': 'test-123' })
})
})
})

View File

@ -0,0 +1,38 @@
/**
* Error classes for runtime operations
*/
/**
* Error thrown when image generation fails
*/
export class ImageGenerationError extends Error {
constructor(
message: string,
public providerId?: string,
public modelId?: string,
public cause?: Error
) {
super(message)
this.name = 'ImageGenerationError'
// Maintain proper stack trace (for V8 engines)
if (Error.captureStackTrace) {
Error.captureStackTrace(this, ImageGenerationError)
}
}
}
/**
* Error thrown when model resolution fails during image generation
*/
export class ImageModelResolutionError extends ImageGenerationError {
constructor(modelId: string, providerId?: string, cause?: Error) {
super(
`Failed to resolve image model: ${modelId}${providerId ? ` for provider: ${providerId}` : ''}`,
providerId,
modelId,
cause
)
this.name = 'ImageModelResolutionError'
}
}

View File

@ -0,0 +1,321 @@
/**
*
* AI调用处理
*/
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import {
experimental_generateImage as generateImage,
generateObject,
generateText,
LanguageModel,
streamObject,
streamText
} from 'ai'
import { globalModelResolver } from '../models'
import { type ModelConfig } from '../models/types'
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
import { type ProviderId } from '../providers'
import { ImageGenerationError, ImageModelResolutionError } from './errors'
import { PluginEngine } from './pluginEngine'
import { type RuntimeConfig } from './types'
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
public pluginEngine: PluginEngine<T>
// private options: ProviderSettingsMap[T]
private config: RuntimeConfig<T>
constructor(config: RuntimeConfig<T>) {
// if (!isProviderSupported(config.providerId)) {
// throw new Error(`Unsupported provider: ${config.providerId}`)
// }
// 存储options供后续使用
// this.options = config.options
this.config = config
// 创建插件客户端
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
}
private createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
return definePlugin({
name: '_internal_resolveModel',
enforce: 'post',
resolveModel: async (modelId: string) => {
// 注意extraModelConfig 暂时不支持,已在新架构中移除
return await this.resolveModel(modelId, middlewares)
}
})
}
private createResolveImageModelPlugin() {
return definePlugin({
name: '_internal_resolveImageModel',
enforce: 'post',
resolveModel: async (modelId: string) => {
return await this.resolveImageModel(modelId)
}
})
}
private createConfigureContextPlugin() {
return definePlugin({
name: '_internal_configureContext',
configureContext: async (context: AiRequestContext) => {
context.executor = this
}
})
}
// === 高阶重载:直接使用模型 ===
/**
*
*/
async streamText(
params: Parameters<typeof streamText>[0],
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof streamText>> {
const { model, ...restParams } = params
// 根据 model 类型决定插件配置
if (typeof model === 'string') {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
} else {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return this.pluginEngine.executeStreamWithPlugins(
'streamText',
model,
restParams,
async (resolvedModel, transformedParams, streamTransforms) => {
const experimental_transform =
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
const finalParams = {
model: resolvedModel,
...transformedParams,
experimental_transform
} as Parameters<typeof streamText>[0]
return await streamText(finalParams)
}
)
}
// === 其他方法的重载 ===
/**
*
*/
async generateText(
params: Parameters<typeof generateText>[0],
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof generateText>> {
const { model, ...restParams } = params
// 根据 model 类型决定插件配置
if (typeof model === 'string') {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
} else {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return this.pluginEngine.executeWithPlugins(
'generateText',
model,
restParams,
async (resolvedModel, transformedParams) =>
generateText({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateText>[0])
)
}
/**
*
*/
async generateObject(
params: Parameters<typeof generateObject>[0],
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof generateObject>> {
const { model, ...restParams } = params
// 根据 model 类型决定插件配置
if (typeof model === 'string') {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
} else {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return this.pluginEngine.executeWithPlugins(
'generateObject',
model,
restParams,
async (resolvedModel, transformedParams) =>
generateObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateObject>[0])
)
}
/**
*
*/
async streamObject(
params: Parameters<typeof streamObject>[0],
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof streamObject>> {
const { model, ...restParams } = params
// 根据 model 类型决定插件配置
if (typeof model === 'string') {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
} else {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return this.pluginEngine.executeWithPlugins(
'streamObject',
model,
restParams,
async (resolvedModel, transformedParams) =>
streamObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof streamObject>[0])
)
}
/**
*
*/
async generateImage(
params: Omit<Parameters<typeof generateImage>[0], 'model'> & { model: string | ImageModelV2 }
): Promise<ReturnType<typeof generateImage>> {
try {
const { model, ...restParams } = params
// 根据 model 类型决定插件配置
if (typeof model === 'string') {
this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
} else {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return await this.pluginEngine.executeImageWithPlugins(
'generateImage',
model,
restParams,
async (resolvedModel, transformedParams) => {
return await generateImage({ model: resolvedModel, ...transformedParams })
}
)
} catch (error) {
if (error instanceof Error) {
const modelId = typeof params.model === 'string' ? params.model : params.model.modelId
throw new ImageGenerationError(
`Failed to generate image: ${error.message}`,
this.config.providerId,
modelId,
error
)
}
throw error
}
}
// === 辅助方法 ===
/**
*
*/
private async resolveModel(
modelOrId: LanguageModel,
middlewares?: LanguageModelV2Middleware[]
): Promise<LanguageModelV2> {
if (typeof modelOrId === 'string') {
// 🎯 字符串modelId使用新的ModelResolver解析传递完整参数
return await globalModelResolver.resolveLanguageModel(
modelOrId, // 支持 'gpt-4' 和 'aihubmix:anthropic:claude-3.5-sonnet'
this.config.providerId, // fallback provider
this.config.providerSettings, // provider options
middlewares // 中间件数组
)
} else {
// 已经是模型,直接返回
return modelOrId
}
}
/**
*
*/
private async resolveImageModel(modelOrId: ImageModelV2 | string): Promise<ImageModelV2> {
try {
if (typeof modelOrId === 'string') {
// 字符串modelId使用新的ModelResolver解析
return await globalModelResolver.resolveImageModel(
modelOrId, // 支持 'dall-e-3' 和 'aihubmix:openai:dall-e-3'
this.config.providerId // fallback provider
)
} else {
// 已经是模型,直接返回
return modelOrId
}
} catch (error) {
throw new ImageModelResolutionError(
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
this.config.providerId,
error instanceof Error ? error : undefined
)
}
}
// === 静态工厂方法 ===
/**
* - provider的类型安全
*/
static create<T extends ProviderId>(
providerId: T,
options: ModelConfig<T>['providerSettings'],
plugins?: AiPlugin[]
): RuntimeExecutor<T> {
return new RuntimeExecutor({
providerId,
providerSettings: options,
plugins
})
}
/**
* OpenAI Compatible执行器
*/
static createOpenAICompatible(
options: ModelConfig<'openai-compatible'>['providerSettings'],
plugins: AiPlugin[] = []
): RuntimeExecutor<'openai-compatible'> {
return new RuntimeExecutor({
providerId: 'openai-compatible',
providerSettings: options,
plugins
})
}
}

View File

@ -0,0 +1,117 @@
/**
* Runtime
* AI调用处理
*/
// 主要的运行时执行器
export { RuntimeExecutor } from './executor'
// 导出类型
export type { RuntimeConfig } from './types'
// === 便捷工厂函数 ===
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import { type AiPlugin } from '../plugins'
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
import { RuntimeExecutor } from './executor'
/**
* - provider
*/
export function createExecutor<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
plugins?: AiPlugin[]
): RuntimeExecutor<T> {
return RuntimeExecutor.create(providerId, options, plugins)
}
/**
* OpenAI Compatible执行器
*/
export function createOpenAICompatibleExecutor(
options: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' },
plugins: AiPlugin[] = []
): RuntimeExecutor<'openai-compatible'> {
return RuntimeExecutor.createOpenAICompatible(options, plugins)
}
// === 直接调用API无需创建executor实例===
/**
* - middlewares
*/
export async function streamText<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
params: Parameters<RuntimeExecutor<T>['streamText']>[0],
plugins?: AiPlugin[],
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.streamText(params, { middlewares })
}
/**
* - middlewares
*/
export async function generateText<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
params: Parameters<RuntimeExecutor<T>['generateText']>[0],
plugins?: AiPlugin[],
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.generateText(params, { middlewares })
}
/**
* - middlewares
*/
export async function generateObject<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
params: Parameters<RuntimeExecutor<T>['generateObject']>[0],
plugins?: AiPlugin[],
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.generateObject(params, { middlewares })
}
/**
* - middlewares
*/
export async function streamObject<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
params: Parameters<RuntimeExecutor<T>['streamObject']>[0],
plugins?: AiPlugin[],
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.streamObject(params, { middlewares })
}
/**
* - middlewares
*/
export async function generateImage<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
params: Parameters<RuntimeExecutor<T>['generateImage']>[0],
plugins?: AiPlugin[]
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.generateImage(params)
}
// === Agent 功能预留 ===
// 未来将在 ../agents/ 文件夹中添加:
// - AgentExecutor.ts
// - WorkflowManager.ts
// - ConversationManager.ts
// 并在此处导出相关API

View File

@ -0,0 +1,290 @@
/* eslint-disable @eslint-react/naming-convention/context-name */
import { ImageModelV2 } from '@ai-sdk/provider'
import { LanguageModel } from 'ai'
import { type AiPlugin, createContext, PluginManager } from '../plugins'
import { type ProviderId } from '../providers/types'
/**
* AI
* API
*/
export class PluginEngine<T extends ProviderId = ProviderId> {
private pluginManager: PluginManager
constructor(
private readonly providerId: T,
// private readonly options: ProviderSettingsMap[T],
plugins: AiPlugin[] = []
) {
this.pluginManager = new PluginManager(plugins)
}
/**
*
*/
use(plugin: AiPlugin): this {
this.pluginManager.use(plugin)
return this
}
/**
*
*/
usePlugins(plugins: AiPlugin[]): this {
plugins.forEach((plugin) => this.use(plugin))
return this
}
/**
*
*/
removePlugin(pluginName: string): this {
this.pluginManager.remove(pluginName)
return this
}
/**
*
*/
getPluginStats() {
return this.pluginManager.getStats()
}
/**
*
*/
getPlugins() {
return this.pluginManager.getPlugins()
}
/**
*
* AiExecutor使用
*/
async executeWithPlugins<TParams, TResult>(
methodName: string,
model: LanguageModel,
params: TParams,
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// 统一处理模型解析
let resolvedModel: LanguageModel | undefined
let modelId: string
if (typeof model === 'string') {
// 字符串:需要通过插件解析
modelId = model
} else {
// 模型对象:直接使用
resolvedModel = model
modelId = model.modelId
}
// 使用正确的createContext创建请求上下文
const context = _context ? _context : createContext(this.providerId, modelId, params)
// 🔥 为上下文添加递归调用能力
context.recursiveCall = async (newParams: any): Promise<TResult> => {
// 递归调用自身,重新走完整的插件流程
context.isRecursiveCall = true
const result = await this.executeWithPlugins(methodName, model, newParams, executor, context)
context.isRecursiveCall = false
return result
}
try {
// 0. 配置上下文
await this.pluginManager.executeConfigureContext(context)
// 1. 触发请求开始事件
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. 解析模型(如果是字符串)
if (typeof model === 'string') {
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
if (!resolved) {
throw new Error(`Failed to resolve model: ${modelId}`)
}
resolvedModel = resolved
}
if (!resolvedModel) {
throw new Error(`Model resolution failed: no model available`)
}
// 3. 转换请求参数
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
// 4. 执行具体的 API 调用
const result = await executor(resolvedModel, transformedParams)
// 5. 转换结果(对于非流式调用)
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
// 6. 触发完成事件
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 7. 触发错误事件
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
throw error
}
}
/**
*
* AiExecutor使用
*/
async executeImageWithPlugins<TParams, TResult>(
methodName: string,
model: ImageModelV2 | string,
params: TParams,
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// 统一处理模型解析
let resolvedModel: ImageModelV2 | undefined
let modelId: string
if (typeof model === 'string') {
// 字符串:需要通过插件解析
modelId = model
} else {
// 模型对象:直接使用
resolvedModel = model
modelId = model.modelId
}
// 使用正确的createContext创建请求上下文
const context = _context ? _context : createContext(this.providerId, modelId, params)
// 🔥 为上下文添加递归调用能力
context.recursiveCall = async (newParams: any): Promise<TResult> => {
// 递归调用自身,重新走完整的插件流程
context.isRecursiveCall = true
const result = await this.executeImageWithPlugins(methodName, model, newParams, executor, context)
context.isRecursiveCall = false
return result
}
try {
// 0. 配置上下文
await this.pluginManager.executeConfigureContext(context)
// 1. 触发请求开始事件
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. 解析模型(如果是字符串)
if (typeof model === 'string') {
const resolved = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
if (!resolved) {
throw new Error(`Failed to resolve image model: ${modelId}`)
}
resolvedModel = resolved
}
if (!resolvedModel) {
throw new Error(`Image model resolution failed: no model available`)
}
// 3. 转换请求参数
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
// 4. 执行具体的 API 调用
const result = await executor(resolvedModel, transformedParams)
// 5. 转换结果
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
// 6. 触发完成事件
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 7. 触发错误事件
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
throw error
}
}
/**
*
* AiExecutor使用
*/
async executeStreamWithPlugins<TParams, TResult>(
methodName: string,
model: LanguageModel,
params: TParams,
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// 统一处理模型解析
let resolvedModel: LanguageModel | undefined
let modelId: string
if (typeof model === 'string') {
// 字符串:需要通过插件解析
modelId = model
} else {
// 模型对象:直接使用
resolvedModel = model
modelId = model.modelId
}
// 创建请求上下文
const context = _context ? _context : createContext(this.providerId, modelId, params)
// 🔥 为上下文添加递归调用能力
context.recursiveCall = async (newParams: any): Promise<TResult> => {
// 递归调用自身,重新走完整的插件流程
context.isRecursiveCall = true
const result = await this.executeStreamWithPlugins(methodName, model, newParams, executor, context)
context.isRecursiveCall = false
return result
}
try {
// 0. 配置上下文
await this.pluginManager.executeConfigureContext(context)
// 1. 触发请求开始事件
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. 解析模型(如果是字符串)
if (typeof model === 'string') {
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
if (!resolved) {
throw new Error(`Failed to resolve model: ${modelId}`)
}
resolvedModel = resolved
}
if (!resolvedModel) {
throw new Error(`Model resolution failed: no model available`)
}
// 3. 转换请求参数
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
// 4. 收集流转换器
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
// 5. 执行流式 API 调用
const result = await executor(resolvedModel, transformedParams, streamTransforms)
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 7. 触发错误事件
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
throw error
}
}
}

View File

@ -0,0 +1,15 @@
/**
* Runtime
*/
import { type ModelConfig } from '../models/types'
import { type AiPlugin } from '../plugins'
import { type ProviderId } from '../providers/types'
/**
*
*/
export interface RuntimeConfig<T extends ProviderId = ProviderId> {
providerId: T
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
plugins?: AiPlugin[]
}

View File

@ -0,0 +1,46 @@
/**
* Cherry Studio AI Core Package
* Vercel AI SDK AI Provider
*/
// 导入内部使用的类和函数
// ==================== 主要用户接口 ====================
export {
createExecutor,
createOpenAICompatibleExecutor,
generateImage,
generateObject,
generateText,
streamText
} from './core/runtime'
// ==================== 高级API ====================
export { globalModelResolver as modelResolver } from './core/models'
// ==================== 插件系统 ====================
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins'
export { createContext, definePlugin, PluginManager } from './core/plugins'
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
export { PluginEngine } from './core/runtime/pluginEngine'
// ==================== AI SDK 常用类型导出 ====================
// 直接导出 AI SDK 的常用类型,方便使用
export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
export type { ToolCall } from '@ai-sdk/provider-utils'
export type { ReasoningPart } from '@ai-sdk/provider-utils'
// ==================== 选项 ====================
export {
createAnthropicOptions,
createGoogleOptions,
createOpenAIOptions,
type ExtractProviderOptions,
mergeProviderOptions,
type ProviderOptionsMap,
type TypedProviderOptions
} from './core/options'
// ==================== 包信息 ====================
export const AI_CORE_VERSION = '1.0.0'
export const AI_CORE_NAME = '@cherrystudio/ai-core'

View File

@ -0,0 +1,2 @@
// 重新导出插件类型
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'

View File

@ -0,0 +1,26 @@
{
"compilerOptions": {
"target": "ES2020",
"module": "ESNext",
"moduleResolution": "bundler",
"declaration": true,
"outDir": "./dist",
"rootDir": "./src",
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"forceConsistentCasingInFileNames": true,
"resolveJsonModule": true,
"allowSyntheticDefaultImports": true,
"noEmitOnError": false,
"experimentalDecorators": true,
"emitDecoratorMetadata": true
},
"include": [
"src/**/*"
],
"exclude": [
"node_modules",
"dist"
]
}

View File

@ -0,0 +1,14 @@
import { defineConfig } from 'tsdown'
export default defineConfig({
entry: {
index: 'src/index.ts',
'built-in/plugins/index': 'src/core/plugins/built-in/index.ts',
'provider/index': 'src/core/providers/index.ts'
},
outDir: 'dist',
format: ['esm', 'cjs'],
clean: true,
dts: true,
tsconfig: 'tsconfig.json'
})

View File

@ -0,0 +1,15 @@
import { defineConfig } from 'vitest/config'
export default defineConfig({
test: {
globals: true
},
resolve: {
alias: {
'@': './src'
}
},
esbuild: {
target: 'node18'
}
})

View File

@ -83,7 +83,7 @@ export enum IpcChannel {
Mcp_UploadDxt = 'mcp:upload-dxt',
Mcp_AbortTool = 'mcp:abort-tool',
Mcp_GetServerVersion = 'mcp:get-server-version',
Mcp_Progress = 'mcp:progress',
// Python
Python_Execute = 'python:execute',
@ -123,6 +123,12 @@ export enum IpcChannel {
Windows_SetMinimumSize = 'window:set-minimum-size',
Windows_Resize = 'window:resize',
Windows_GetSize = 'window:get-size',
Windows_Minimize = 'window:minimize',
Windows_Maximize = 'window:maximize',
Windows_Unmaximize = 'window:unmaximize',
Windows_Close = 'window:close',
Windows_IsMaximized = 'window:is-maximized',
Windows_MaximizedChanged = 'window:maximized-changed',
KnowledgeBase_Create = 'knowledge-base:create',
KnowledgeBase_Reset = 'knowledge-base:reset',
@ -322,6 +328,14 @@ export enum IpcChannel {
TRACE_CLEAN_LOCAL_DATA = 'trace:cleanLocalData',
TRACE_ADD_STREAM_MESSAGE = 'trace:addStreamMessage',
// Anthropic OAuth
Anthropic_StartOAuthFlow = 'anthropic:start-oauth-flow',
Anthropic_CompleteOAuthWithCode = 'anthropic:complete-oauth-with-code',
Anthropic_CancelOAuthFlow = 'anthropic:cancel-oauth-flow',
Anthropic_GetAccessToken = 'anthropic:get-access-token',
Anthropic_HasCredentials = 'anthropic:has-credentials',
Anthropic_ClearCredentials = 'anthropic:clear-credentials',
// CodeTools
CodeTools_Run = 'code-tools:run',

View File

@ -7,7 +7,7 @@ export type LoaderReturn = {
loaderType: string
status?: ProcessingStatus
message?: string
messageSource?: 'preprocess' | 'embedding'
messageSource?: 'preprocess' | 'embedding' | 'validation'
}
export type FileChangeEventType = 'add' | 'change' | 'unlink' | 'addDir' | 'unlinkDir'
@ -17,3 +17,8 @@ export type FileChangeEvent = {
filePath: string
watchPath: string
}
export type MCPProgressEvent = {
callId: string
progress: number // 0-1 range
}

View File

@ -1,199 +1,274 @@
<!DOCTYPE html>
<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>许可协议 | License Agreement</title>
<script src="https://cdn.tailwindcss.com"></script>
</head>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>许可协议 | License Agreement</title>
<script src="https://cdn.tailwindcss.com"></script>
</head>
<body class="bg-gray-50">
<div class="max-w-4xl mx-auto px-4 py-8">
<!-- 中文版本 -->
<div class="mb-12">
<h1 class="text-3xl font-bold mb-8 text-gray-900">许可协议</h1>
<body class="bg-gray-50">
<div class="max-w-4xl mx-auto px-4 py-8">
<!-- 中文版本 -->
<div class="mb-12">
<h1 class="text-3xl font-bold mb-8 text-gray-900">许可协议</h1>
<p class="mb-6 text-gray-700">本项目采用<strong>区分用户的双重许可 (User-Segmented Dual Licensing)</strong> 模式。</p>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">核心原则</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li><strong>个人用户 和 10人及以下企业/组织:</strong> 默认适用 <strong>GNU Affero 通用公共许可证 v3.0 (AGPLv3)</strong></li>
<li><strong>超过10人的企业/组织:</strong> <strong>必须</strong> 获取 <strong>商业许可证 (Commercial License)</strong></li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">定义:"10人及以下"</h2>
<p class="text-gray-700">
指在您的组织包括公司、非营利组织、政府机构、教育机构等任何实体能够访问、使用或以任何方式直接或间接受益于本软件Cherry
Studio功能的个人总数不超过10人。这包括但不限于开发者、测试人员、运营人员、最终用户、通过集成系统间接使用者等。
<p class="mb-6 text-gray-700">
本项目采用<strong>区分用户的双重许可 (User-Segmented Dual Licensing)</strong> 模式。
</p>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">1. 开源许可证 (Open Source License): AGPLv3 - 适用于个人及10人及以下组织
</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>如果您是个人用户,或者您的组织满足上述"10人及以下"的定义,您可以在 <strong>AGPLv3</strong> 的条款下自由使用、修改和分发 Cherry Studio。AGPLv3 的完整文本可以访问
<a href="https://www.gnu.org/licenses/agpl-3.0.html"
class="text-blue-600 hover:underline">https://www.gnu.org/licenses/agpl-3.0.html</a> 获取。
</li>
<li><strong>核心义务:</strong> AGPLv3 的一个关键要求是,如果您修改了 Cherry Studio 并通过网络提供服务,或者分发了修改后的版本,您必须以 AGPLv3
许可证向接收者提供相应的<strong>完整源代码</strong>即使您符合"10人及以下"的标准,如果您希望避免此源代码公开义务,您也需要考虑获取商业许可证(见下文)。</li>
<li>使用前请务必仔细阅读并理解 AGPLv3 的所有条款。</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">核心原则</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>
<strong>个人用户 和 10人及以下企业/组织:</strong> 默认适用
<strong>GNU Affero 通用公共许可证 v3.0 (AGPLv3)</strong>
</li>
<li>
<strong>超过10人的企业/组织:</strong> <strong>必须</strong> 获取
<strong>商业许可证 (Commercial License)</strong>
</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">2. 商业许可证 (Commercial License) - 适用于超过10人的组织或希望规避 AGPLv3
义务的用户</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li><strong>强制要求:</strong>
如果您的组织<strong></strong>满足上述"10人及以下"的定义即有11人或更多人可以访问、使用或受益于本软件<strong>必须</strong>联系我们获取并签署一份商业许可证才能使用
Cherry Studio。</li>
<li><strong>自愿选择:</strong> 即使您的组织满足"10人及以下"的条件,但如果您的使用场景<strong>无法满足 AGPLv3
的条款要求</strong>(特别是关于<strong>源代码公开</strong>的义务),或者您需要 AGPLv3 <strong>未提供</strong>的特定商业条款(如保证、赔偿、无 Copyleft
限制等),您也<strong>必须</strong>联系我们获取并签署一份商业许可证。</li>
<li><strong>需要商业许可证的常见情况包括(但不限于):</strong>
<ul class="list-disc pl-6 mt-2 space-y-1">
<li>您的组织规模超过10人。</li>
<li>(无论组织规模)您希望分发修改过的 Cherry Studio 版本,但<strong>不希望</strong>根据 AGPLv3 公开您修改部分的源代码。</li>
<li>(无论组织规模)您希望基于修改过的 Cherry Studio 提供网络服务SaaS<strong>不希望</strong>根据 AGPLv3 向服务使用者提供修改后的源代码。</li>
<li>(无论组织规模)您的公司政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件,或要求闭源分发及保密。</li>
</ul>
</li>
<li><strong>获取商业许可:</strong> 请通过邮箱 <a href="mailto:bd@cherry-ai.com"
class="text-blue-600 hover:underline">bd@cherry-ai.com</a> 联系 Cherry Studio 开发团队洽谈商业授权事宜。</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">定义:"10人及以下"</h2>
<p class="text-gray-700">
指在您的组织包括公司、非营利组织、政府机构、教育机构等任何实体能够访问、使用或以任何方式直接或间接受益于本软件Cherry
Studio功能的个人总数不超过10人。这包括但不限于开发者、测试人员、运营人员、最终用户、通过集成系统间接使用者等。
</p>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. 贡献 (Contributions)</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>我们欢迎社区对 Cherry Studio 的贡献。所有向本项目提交的贡献都将被视为在 <strong>AGPLv3</strong> 许可证下提供。</li>
<li>通过向本项目提交贡献(例如通过 Pull Request即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。</li>
<li>您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 Cherry Studio 版本中。</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">
1. 开源许可证 (Open Source License): AGPLv3 - 适用于个人及10人及以下组织
</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>
如果您是个人用户,或者您的组织满足上述"10人及以下"的定义,您可以在
<strong>AGPLv3</strong> 的条款下自由使用、修改和分发 Cherry Studio。AGPLv3 的完整文本可以访问
<a href="https://www.gnu.org/licenses/agpl-3.0.html" class="text-blue-600 hover:underline"
>https://www.gnu.org/licenses/agpl-3.0.html</a
>
获取。
</li>
<li>
<strong>核心义务:</strong> AGPLv3 的一个关键要求是,如果您修改了 Cherry Studio
并通过网络提供服务,或者分发了修改后的版本,您必须以 AGPLv3
许可证向接收者提供相应的<strong>完整源代码</strong>。即使您符合"10人及以下"的标准,如果您希望避免此源代码公开义务,您也需要考虑获取商业许可证(见下文)。
</li>
<li>使用前请务必仔细阅读并理解 AGPLv3 的所有条款。</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. 其他条款 (Other Terms)</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>关于商业许可证的具体条款和条件,以双方签署的正式商业许可协议为准。</li>
<li>项目维护者保留根据需要更新本许可政策(包括用户规模定义和阈值)的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">
2. 商业许可证 (Commercial License) - 适用于超过10人的组织或希望规避 AGPLv3 义务的用户
</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>
<strong>强制要求:</strong>
如果您的组织<strong></strong>满足上述"10人及以下"的定义即有11人或更多人可以访问、使用或受益于本软件<strong>必须</strong>联系我们获取并签署一份商业许可证才能使用
Cherry Studio。
</li>
<li>
<strong>自愿选择:</strong> 即使您的组织满足"10人及以下"的条件,但如果您的使用场景<strong
>无法满足 AGPLv3 的条款要求</strong
>(特别是关于<strong>源代码公开</strong>的义务),或者您需要 AGPLv3
<strong>未提供</strong>的特定商业条款(如保证、赔偿、无 Copyleft
限制等),您也<strong>必须</strong>联系我们获取并签署一份商业许可证。
</li>
<li>
<strong>需要商业许可证的常见情况包括(但不限于):</strong>
<ul class="list-disc pl-6 mt-2 space-y-1">
<li>您的组织规模超过10人。</li>
<li>
(无论组织规模)您希望分发修改过的 Cherry Studio 版本,但<strong>不希望</strong>根据 AGPLv3
公开您修改部分的源代码。
</li>
<li>
(无论组织规模)您希望基于修改过的 Cherry Studio 提供网络服务SaaS<strong>不希望</strong>根据
AGPLv3 向服务使用者提供修改后的源代码。
</li>
<li>
(无论组织规模)您的公司政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件,或要求闭源分发及保密。
</li>
</ul>
</li>
<li>
<strong>获取商业许可:</strong> 请通过邮箱
<a href="mailto:bd@cherry-ai.com" class="text-blue-600 hover:underline">bd@cherry-ai.com</a> 联系 Cherry
Studio 开发团队洽谈商业授权事宜。
</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. 贡献 (Contributions)</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>
我们欢迎社区对 Cherry Studio 的贡献。所有向本项目提交的贡献都将被视为在
<strong>AGPLv3</strong> 许可证下提供。
</li>
<li>
通过向本项目提交贡献(例如通过 Pull Request即表示您同意您的代码以 AGPLv3
许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
</li>
<li>您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 Cherry Studio 版本中。</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. 其他条款 (Other Terms)</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>关于商业许可证的具体条款和条件,以双方签署的正式商业许可协议为准。</li>
<li>
项目维护者保留根据需要更新本许可政策(包括用户规模定义和阈值)的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
</li>
</ul>
</section>
</div>
<hr class="my-12 border-gray-300" />
<!-- English Version -->
<div>
<h1 class="text-3xl font-bold mb-8 text-gray-900">Licensing</h1>
<p class="mb-6 text-gray-700">This project employs a <strong>User-Segmented Dual Licensing</strong> model.</p>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">Core Principle</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>
<strong>Individual Users and Organizations with 10 or Fewer Individuals:</strong> Governed by default
under the <strong>GNU Affero General Public License v3.0 (AGPLv3)</strong>.
</li>
<li>
<strong>Organizations with More Than 10 Individuals:</strong> <strong>Must</strong> obtain a
<strong>Commercial License</strong>.
</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">Definition: "10 or Fewer Individuals"</h2>
<p class="text-gray-700">
Refers to any organization (including companies, non-profits, government agencies, educational institutions,
etc.) where the total number of individuals who can access, use, or in any way directly or indirectly
benefit from the functionality of this software (Cherry Studio) does not exceed 10. This includes, but is
not limited to, developers, testers, operations staff, end-users, and indirect users via integrated systems.
</p>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">
1. Open Source License: AGPLv3 - For Individuals and Organizations of 10 or Fewer
</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>
If you are an individual user, or if your organization meets the "10 or Fewer Individuals" definition
above, you are free to use, modify, and distribute Cherry Studio under the terms of the
<strong>AGPLv3</strong>. The full text of the AGPLv3 can be found at
<a href="https://www.gnu.org/licenses/agpl-3.0.html" class="text-blue-600 hover:underline"
>https://www.gnu.org/licenses/agpl-3.0.html</a
>.
</li>
<li>
<strong>Core Obligation:</strong> A key requirement of the AGPLv3 is that if you modify Cherry Studio and
make it available over a network, or distribute the modified version, you must provide the
<strong>complete corresponding source code</strong> under the AGPLv3 license to the recipients. Even if
you qualify under the "10 or Fewer Individuals" rule, if you wish to avoid this source code disclosure
obligation, you will need to obtain a Commercial License (see below).
</li>
<li>Please read and understand the full terms of the AGPLv3 carefully before use.</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">
2. Commercial License - For Organizations with More Than 10 Individuals, or Users Needing to Avoid AGPLv3
Obligations
</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>
<strong>Mandatory Requirement:</strong> If your organization does <strong>not</strong> meet the "10 or
Fewer Individuals" definition above (i.e., 11 or more individuals can access, use, or benefit from the
software), you <strong>must</strong> contact us to obtain and execute a Commercial License to use Cherry
Studio.
</li>
<li>
<strong>Voluntary Option:</strong> Even if your organization meets the "10 or Fewer Individuals"
condition, if your intended use case
<strong>cannot comply with the terms of the AGPLv3</strong> (particularly the obligations regarding
<strong>source code disclosure</strong>), or if you require specific commercial terms
<strong>not offered</strong> by the AGPLv3 (such as warranties, indemnities, or freedom from copyleft
restrictions), you also <strong>must</strong> contact us to obtain and execute a Commercial License.
</li>
<li>
<strong>Common scenarios requiring a Commercial License include (but are not limited to):</strong>
<ul class="list-disc pl-6 mt-2 space-y-1">
<li>
Your organization has more than 10 individuals who can access, use, or benefit from the software.
</li>
<li>
(Regardless of organization size) You wish to distribute a modified version of Cherry Studio but
<strong>do not want</strong> to disclose the source code of your modifications under AGPLv3.
</li>
<li>
(Regardless of organization size) You wish to provide a network service (SaaS) based on a modified
version of Cherry Studio but <strong>do not want</strong> to provide the modified source code to users
of the service under AGPLv3.
</li>
<li>
(Regardless of organization size) Your corporate policies, client contracts, or project requirements
prohibit the use of AGPLv3-licensed software or mandate closed-source distribution and
confidentiality.
</li>
</ul>
</li>
<li>
<strong>Obtaining a Commercial License:</strong> Please contact the Cherry Studio development team via
email at <a href="mailto:bd@cherry-ai.com" class="text-blue-600 hover:underline">bd@cherry-ai.com</a> to
discuss commercial licensing options.
</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. Contributions</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>
We welcome community contributions to Cherry Studio. All contributions submitted to this project are
considered to be offered under the <strong>AGPLv3</strong> license.
</li>
<li>
By submitting a contribution to this project (e.g., via a Pull Request), you agree to license your code
under the AGPLv3 to the project and all its downstream users (regardless of whether those users ultimately
operate under AGPLv3 or a Commercial License).
</li>
<li>
You also understand and agree that your contribution may be included in distributions of Cherry Studio
offered under our commercial license.
</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. Other Terms</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>
The specific terms and conditions of the Commercial License are governed by the formal commercial license
agreement signed by both parties.
</li>
<li>
The project maintainers reserve the right to update this licensing policy (including the definition and
threshold for user count) as needed. Updates will be communicated through official project channels (e.g.,
code repository, official website).
</li>
</ul>
</section>
</div>
</div>
<hr class="my-12 border-gray-300">
<!-- English Version -->
<div>
<h1 class="text-3xl font-bold mb-8 text-gray-900">Licensing</h1>
<p class="mb-6 text-gray-700">This project employs a <strong>User-Segmented Dual Licensing</strong> model.</p>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">Core Principle</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li><strong>Individual Users and Organizations with 10 or Fewer Individuals:</strong> Governed by default
under the <strong>GNU Affero General Public License v3.0 (AGPLv3)</strong>.</li>
<li><strong>Organizations with More Than 10 Individuals:</strong> <strong>Must</strong> obtain a
<strong>Commercial License</strong>.
</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">Definition: "10 or Fewer Individuals"</h2>
<p class="text-gray-700">
Refers to any organization (including companies, non-profits, government agencies, educational institutions,
etc.) where the total number of individuals who can access, use, or in any way directly or indirectly benefit
from the functionality of this software (Cherry Studio) does not exceed 10. This includes, but is not limited
to, developers, testers, operations staff, end-users, and indirect users via integrated systems.
</p>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">1. Open Source License: AGPLv3 - For Individuals and
Organizations of 10 or Fewer</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>If you are an individual user, or if your organization meets the "10 or Fewer Individuals" definition
above, you are free to use, modify, and distribute Cherry Studio under the terms of the
<strong>AGPLv3</strong>. The full text of the AGPLv3 can be found at <a
href="https://www.gnu.org/licenses/agpl-3.0.html"
class="text-blue-600 hover:underline">https://www.gnu.org/licenses/agpl-3.0.html</a>.
</li>
<li><strong>Core Obligation:</strong> A key requirement of the AGPLv3 is that if you modify Cherry Studio and
make it available over a network, or distribute the modified version, you must provide the <strong>complete
corresponding source code</strong> under the AGPLv3 license to the recipients. Even if you qualify under
the "10 or Fewer Individuals" rule, if you wish to avoid this source code disclosure obligation, you will
need to obtain a Commercial License (see below).</li>
<li>Please read and understand the full terms of the AGPLv3 carefully before use.</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">2. Commercial License - For Organizations with More Than 10
Individuals, or Users Needing to Avoid AGPLv3 Obligations</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li><strong>Mandatory Requirement:</strong> If your organization does <strong>not</strong> meet the "10 or
Fewer Individuals" definition above (i.e., 11 or more individuals can access, use, or benefit from the
software), you <strong>must</strong> contact us to obtain and execute a Commercial License to use Cherry
Studio.</li>
<li><strong>Voluntary Option:</strong> Even if your organization meets the "10 or Fewer Individuals"
condition, if your intended use case <strong>cannot comply with the terms of the AGPLv3</strong>
(particularly the obligations regarding <strong>source code disclosure</strong>), or if you require specific
commercial terms <strong>not offered</strong> by the AGPLv3 (such as warranties, indemnities, or freedom
from copyleft restrictions), you also <strong>must</strong> contact us to obtain and execute a Commercial
License.</li>
<li><strong>Common scenarios requiring a Commercial License include (but are not limited to):</strong>
<ul class="list-disc pl-6 mt-2 space-y-1">
<li>Your organization has more than 10 individuals who can access, use, or benefit from the software.</li>
<li>(Regardless of organization size) You wish to distribute a modified version of Cherry Studio but
<strong>do not want</strong> to disclose the source code of your modifications under AGPLv3.
</li>
<li>(Regardless of organization size) You wish to provide a network service (SaaS) based on a modified
version of Cherry Studio but <strong>do not want</strong> to provide the modified source code to users
of the service under AGPLv3.</li>
<li>(Regardless of organization size) Your corporate policies, client contracts, or project requirements
prohibit the use of AGPLv3-licensed software or mandate closed-source distribution and confidentiality.
</li>
</ul>
</li>
<li><strong>Obtaining a Commercial License:</strong> Please contact the Cherry Studio development team via
email at <a href="mailto:bd@cherry-ai.com" class="text-blue-600 hover:underline">bd@cherry-ai.com</a> to
discuss commercial licensing options.</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. Contributions</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>We welcome community contributions to Cherry Studio. All contributions submitted to this project are
considered to be offered under the <strong>AGPLv3</strong> license.</li>
<li>By submitting a contribution to this project (e.g., via a Pull Request), you agree to license your code
under the AGPLv3 to the project and all its downstream users (regardless of whether those users ultimately
operate under AGPLv3 or a Commercial License).</li>
<li>You also understand and agree that your contribution may be included in distributions of Cherry Studio
offered under our commercial license.</li>
</ul>
</section>
<section class="mb-8">
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. Other Terms</h2>
<ul class="list-disc pl-6 space-y-2 text-gray-700">
<li>The specific terms and conditions of the Commercial License are governed by the formal commercial license
agreement signed by both parties.</li>
<li>The project maintainers reserve the right to update this licensing policy (including the definition and
threshold for user count) as needed. Updates will be communicated through official project channels (e.g.,
code repository, official website).</li>
</ul>
</section>
</div>
</div>
</body>
</html>
</body>
</html>

File diff suppressed because one or more lines are too long

View File

@ -7,6 +7,7 @@ import { preferenceService } from '@data/PreferenceService'
import { loggerService } from '@logger'
import { isLinux, isMac, isPortable, isWin } from '@main/constant'
import { generateSignature } from '@main/integration/cherryin'
import anthropicService from '@main/services/AnthropicService'
import { getBinaryPath, isBinaryExists, runInstallScript } from '@main/utils/process'
import { handleZoomFactor } from '@main/utils/zoom'
import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
@ -27,7 +28,7 @@ import DxtService from './services/DxtService'
import { ExportService } from './services/ExportService'
import { fileStorage as fileManager } from './services/FileStorage'
import FileService from './services/FileSystemService'
import KnowledgeService from './services/KnowledgeService'
import KnowledgeService from './services/knowledge/KnowledgeService'
import mcpService from './services/MCPService'
import MemoryService from './services/memory/MemoryService'
import { openTraceWindow, setTraceWindowTitle } from './services/NodeTraceService'
@ -524,7 +525,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
}
})
// knowledge base
ipcMain.handle(IpcChannel.KnowledgeBase_Create, KnowledgeService.create.bind(KnowledgeService))
ipcMain.handle(IpcChannel.KnowledgeBase_Reset, KnowledgeService.reset.bind(KnowledgeService))
ipcMain.handle(IpcChannel.KnowledgeBase_Delete, KnowledgeService.delete.bind(KnowledgeService))
@ -588,6 +588,41 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
return [width, height]
})
// Window Controls
ipcMain.handle(IpcChannel.Windows_Minimize, () => {
checkMainWindow()
mainWindow.minimize()
})
ipcMain.handle(IpcChannel.Windows_Maximize, () => {
checkMainWindow()
mainWindow.maximize()
})
ipcMain.handle(IpcChannel.Windows_Unmaximize, () => {
checkMainWindow()
mainWindow.unmaximize()
})
ipcMain.handle(IpcChannel.Windows_Close, () => {
checkMainWindow()
mainWindow.close()
})
ipcMain.handle(IpcChannel.Windows_IsMaximized, () => {
checkMainWindow()
return mainWindow.isMaximized()
})
// Send maximized state changes to renderer
mainWindow.on('maximize', () => {
mainWindow.webContents.send(IpcChannel.Windows_MaximizedChanged, true)
})
mainWindow.on('unmaximize', () => {
mainWindow.webContents.send(IpcChannel.Windows_MaximizedChanged, false)
})
// VertexAI
ipcMain.handle(IpcChannel.VertexAI_GetAuthHeaders, async (_, params) => {
return vertexAIService.getAuthHeaders(params)
@ -747,6 +782,16 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
addStreamMessage(spanId, modelName, context, msg)
)
// Anthropic OAuth
ipcMain.handle(IpcChannel.Anthropic_StartOAuthFlow, () => anthropicService.startOAuthFlow())
ipcMain.handle(IpcChannel.Anthropic_CompleteOAuthWithCode, (_, code: string) =>
anthropicService.completeOAuthWithCode(code)
)
ipcMain.handle(IpcChannel.Anthropic_CancelOAuthFlow, () => anthropicService.cancelOAuthFlow())
ipcMain.handle(IpcChannel.Anthropic_GetAccessToken, () => anthropicService.getValidAccessToken())
ipcMain.handle(IpcChannel.Anthropic_HasCredentials, () => anthropicService.hasCredentials())
ipcMain.handle(IpcChannel.Anthropic_ClearCredentials, () => anthropicService.clearCredentials())
// CodeTools
ipcMain.handle(IpcChannel.CodeTools_Run, codeToolsService.run)

View File

@ -0,0 +1,63 @@
import { VoyageEmbeddings } from '@langchain/community/embeddings/voyage'
import type { Embeddings } from '@langchain/core/embeddings'
import { OllamaEmbeddings } from '@langchain/ollama'
import { AzureOpenAIEmbeddings, OpenAIEmbeddings } from '@langchain/openai'
import { ApiClient, SystemProviderIds } from '@types'
import { isJinaEmbeddingsModel, JinaEmbeddings } from './JinaEmbeddings'
export default class EmbeddingsFactory {
static create({ embedApiClient, dimensions }: { embedApiClient: ApiClient; dimensions?: number }): Embeddings {
const batchSize = 10
const { model, provider, apiKey, apiVersion, baseURL } = embedApiClient
if (provider === SystemProviderIds.ollama) {
let baseUrl = baseURL
if (baseURL.includes('v1/')) {
baseUrl = baseURL.replace('v1/', '')
}
const headers = apiKey
? {
Authorization: `Bearer ${apiKey}`
}
: undefined
return new OllamaEmbeddings({
model: model,
baseUrl,
...headers
})
} else if (provider === SystemProviderIds.voyageai) {
return new VoyageEmbeddings({
modelName: model,
apiKey,
outputDimension: dimensions,
batchSize
})
}
if (isJinaEmbeddingsModel(model)) {
return new JinaEmbeddings({
model,
apiKey,
batchSize,
dimensions,
baseUrl: baseURL
})
}
if (apiVersion !== undefined) {
return new AzureOpenAIEmbeddings({
azureOpenAIApiKey: apiKey,
azureOpenAIApiVersion: apiVersion,
azureOpenAIApiDeploymentName: model,
azureOpenAIEndpoint: baseURL,
dimensions,
batchSize
})
}
return new OpenAIEmbeddings({
model,
apiKey,
dimensions,
batchSize,
configuration: { baseURL }
})
}
}

View File

@ -0,0 +1,199 @@
import { Embeddings, type EmbeddingsParams } from '@langchain/core/embeddings'
import { chunkArray } from '@langchain/core/utils/chunk_array'
import { getEnvironmentVariable } from '@langchain/core/utils/env'
import z from 'zod/v4'
const jinaModelSchema = z.union([
z.literal('jina-clip-v2'),
z.literal('jina-embeddings-v3'),
z.literal('jina-colbert-v2'),
z.literal('jina-clip-v1'),
z.literal('jina-colbert-v1-en'),
z.literal('jina-embeddings-v2-base-es'),
z.literal('jina-embeddings-v2-base-code'),
z.literal('jina-embeddings-v2-base-de'),
z.literal('jina-embeddings-v2-base-zh'),
z.literal('jina-embeddings-v2-base-en')
])
type JinaModel = z.infer<typeof jinaModelSchema>
export const isJinaEmbeddingsModel = (model: string): model is JinaModel => {
return jinaModelSchema.safeParse(model).success
}
interface JinaEmbeddingsParams extends EmbeddingsParams {
/** Model name to use */
model: JinaModel
baseUrl?: string
/**
* Timeout to use when making requests to Jina.
*/
timeout?: number
/**
* The maximum number of documents to embed in a single request.
*/
batchSize?: number
/**
* Whether to strip new lines from the input text.
*/
stripNewLines?: boolean
/**
* The dimensions of the embedding.
*/
dimensions?: number
/**
* Scales the embedding so its Euclidean (L2) norm becomes 1, preserving direction. Useful when downstream involves dot-product, classification, visualization..
*/
normalized?: boolean
}
type JinaMultiModelInput =
| {
text: string
image?: never
}
| {
image: string
text?: never
}
type JinaEmbeddingsInput = string | JinaMultiModelInput
interface EmbeddingCreateParams {
model: JinaEmbeddingsParams['model']
/**
* input can be strings or JinaMultiModelInputs,if you want embed image,you should use JinaMultiModelInputs
*/
input: JinaEmbeddingsInput[]
dimensions: number
task?: 'retrieval.query' | 'retrieval.passage'
}
interface EmbeddingResponse {
model: string
object: string
usage: {
total_tokens: number
prompt_tokens: number
}
data: {
object: string
index: number
embedding: number[]
}[]
}
interface EmbeddingErrorResponse {
detail: string
}
export class JinaEmbeddings extends Embeddings implements JinaEmbeddingsParams {
model: JinaEmbeddingsParams['model'] = 'jina-clip-v2'
batchSize = 24
baseUrl = 'https://api.jina.ai/v1/embeddings'
stripNewLines = true
dimensions = 1024
apiKey: string
constructor(
fields?: Partial<JinaEmbeddingsParams> & {
apiKey?: string
}
) {
const fieldsWithDefaults = { maxConcurrency: 2, ...fields }
super(fieldsWithDefaults)
const apiKey =
fieldsWithDefaults?.apiKey || getEnvironmentVariable('JINA_API_KEY') || getEnvironmentVariable('JINA_AUTH_TOKEN')
if (!apiKey) throw new Error('Jina API key not found')
this.apiKey = apiKey
this.baseUrl = fieldsWithDefaults?.baseUrl ? `${fieldsWithDefaults?.baseUrl}embeddings` : this.baseUrl
this.model = fieldsWithDefaults?.model ?? this.model
this.dimensions = fieldsWithDefaults?.dimensions ?? this.dimensions
this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize
this.stripNewLines = fieldsWithDefaults?.stripNewLines ?? this.stripNewLines
}
private doStripNewLines(input: JinaEmbeddingsInput[]) {
if (this.stripNewLines) {
return input.map((i) => {
if (typeof i === 'string') {
return i.replace(/\n/g, ' ')
}
if (i.text) {
return { text: i.text.replace(/\n/g, ' ') }
}
return i
})
}
return input
}
async embedDocuments(input: JinaEmbeddingsInput[]): Promise<number[][]> {
const batches = chunkArray(this.doStripNewLines(input), this.batchSize)
const batchRequests = batches.map((batch) => {
const params = this.getParams(batch)
return this.embeddingWithRetry(params)
})
const batchResponses = await Promise.all(batchRequests)
const embeddings: number[][] = []
for (let i = 0; i < batchResponses.length; i += 1) {
const batch = batches[i]
const batchResponse = batchResponses[i] || []
for (let j = 0; j < batch.length; j += 1) {
embeddings.push(batchResponse[j])
}
}
return embeddings
}
async embedQuery(input: JinaEmbeddingsInput): Promise<number[]> {
const params = this.getParams(this.doStripNewLines([input]), true)
const embeddings = (await this.embeddingWithRetry(params)) || [[]]
return embeddings[0]
}
private getParams(input: JinaEmbeddingsInput[], query?: boolean): EmbeddingCreateParams {
return {
model: this.model,
input,
dimensions: this.dimensions,
task: query ? 'retrieval.query' : this.model === 'jina-clip-v2' ? undefined : 'retrieval.passage'
}
}
private async embeddingWithRetry(body: EmbeddingCreateParams) {
const response = await fetch(this.baseUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.apiKey}`
},
body: JSON.stringify(body)
})
const embeddingData: EmbeddingResponse | EmbeddingErrorResponse = await response.json()
if ('detail' in embeddingData && embeddingData.detail) {
throw new Error(`${embeddingData.detail}`)
}
return (embeddingData as EmbeddingResponse).data.map(({ embedding }) => embedding)
}
}

View File

@ -0,0 +1,25 @@
import type { Embeddings as BaseEmbeddings } from '@langchain/core/embeddings'
import { TraceMethod } from '@mcp-trace/trace-core'
import { ApiClient } from '@types'
import EmbeddingsFactory from './EmbeddingsFactory'
export default class TextEmbeddings {
private sdk: BaseEmbeddings
constructor({ embedApiClient, dimensions }: { embedApiClient: ApiClient; dimensions?: number }) {
this.sdk = EmbeddingsFactory.create({
embedApiClient,
dimensions
})
}
@TraceMethod({ spanName: 'embedDocuments', tag: 'Embeddings' })
public async embedDocuments(texts: string[]): Promise<number[][]> {
return this.sdk.embedDocuments(texts)
}
@TraceMethod({ spanName: 'embedQuery', tag: 'Embeddings' })
public async embedQuery(text: string): Promise<number[]> {
return this.sdk.embedQuery(text)
}
}

View File

@ -0,0 +1,97 @@
import { BaseDocumentLoader } from '@langchain/core/document_loaders/base'
import { Document } from '@langchain/core/documents'
import { readTextFileWithAutoEncoding } from '@main/utils/file'
import MarkdownIt from 'markdown-it'
export class MarkdownLoader extends BaseDocumentLoader {
private path: string
private md: MarkdownIt
constructor(path: string) {
super()
this.path = path
this.md = new MarkdownIt()
}
public async load(): Promise<Document[]> {
const content = await readTextFileWithAutoEncoding(this.path)
return this.parseMarkdown(content)
}
private parseMarkdown(content: string): Document[] {
const tokens = this.md.parse(content, {})
const documents: Document[] = []
let currentSection: {
heading?: string
level?: number
content: string
startLine?: number
} = { content: '' }
let i = 0
while (i < tokens.length) {
const token = tokens[i]
if (token.type === 'heading_open') {
// Save previous section if it has content
if (currentSection.content.trim()) {
documents.push(
new Document({
pageContent: currentSection.content.trim(),
metadata: {
source: this.path,
heading: currentSection.heading || 'Introduction',
level: currentSection.level || 0,
startLine: currentSection.startLine || 0
}
})
)
}
// Start new section
const level = parseInt(token.tag.slice(1)) // Extract number from h1, h2, etc.
const headingContent = tokens[i + 1]?.content || ''
currentSection = {
heading: headingContent,
level: level,
content: '',
startLine: token.map?.[0] || 0
}
// Skip heading_open, inline, heading_close tokens
i += 3
continue
}
// Add token content to current section
if (token.content) {
currentSection.content += token.content
}
// Add newlines for block tokens
if (token.block && token.type !== 'heading_close') {
currentSection.content += '\n'
}
i++
}
// Add the last section
if (currentSection.content.trim()) {
documents.push(
new Document({
pageContent: currentSection.content.trim(),
metadata: {
source: this.path,
heading: currentSection.heading || 'Introduction',
level: currentSection.level || 0,
startLine: currentSection.startLine || 0
}
})
)
}
return documents
}
}

View File

@ -0,0 +1,50 @@
import { BaseDocumentLoader } from '@langchain/core/document_loaders/base'
import { Document } from '@langchain/core/documents'
export class NoteLoader extends BaseDocumentLoader {
private text: string
private sourceUrl?: string
constructor(
public _text: string,
public _sourceUrl?: string
) {
super()
this.text = _text
this.sourceUrl = _sourceUrl
}
/**
* A protected method that takes a `raw` string as a parameter and returns
* a promise that resolves to an array containing the raw text as a single
* element.
* @param raw The raw text to be parsed.
* @returns A promise that resolves to an array containing the raw text as a single element.
*/
protected async parse(raw: string): Promise<string[]> {
return [raw]
}
public async load(): Promise<Document[]> {
const metadata = { source: this.sourceUrl || 'note' }
const parsed = await this.parse(this.text)
parsed.forEach((pageContent, i) => {
if (typeof pageContent !== 'string') {
throw new Error(`Expected string, at position ${i} got ${typeof pageContent}`)
}
})
return parsed.map(
(pageContent, i) =>
new Document({
pageContent,
metadata:
parsed.length === 1
? metadata
: {
...metadata,
line: i + 1
}
})
)
}
}

View File

@ -0,0 +1,170 @@
import { BaseDocumentLoader } from '@langchain/core/document_loaders/base'
import { Document } from '@langchain/core/documents'
import { Innertube } from 'youtubei.js'
// ... (接口定义 YoutubeConfig 和 VideoMetadata 保持不变)
/**
* Configuration options for the YoutubeLoader class. Includes properties
* such as the videoId, language, and addVideoInfo.
*/
interface YoutubeConfig {
videoId: string
language?: string
addVideoInfo?: boolean
// 新增一个选项,用于控制输出格式
transcriptFormat?: 'text' | 'srt'
}
/**
* Metadata of a YouTube video. Includes properties such as the source
* (videoId), description, title, view_count, author, and category.
*/
interface VideoMetadata {
source: string
description?: string
title?: string
view_count?: number
author?: string
category?: string
}
/**
* A document loader for loading data from YouTube videos. It uses the
* youtubei.js library to fetch the transcript and video metadata.
* @example
* ```typescript
* const loader = new YoutubeLoader({
* videoId: "VIDEO_ID",
* language: "en",
* addVideoInfo: true,
* transcriptFormat: "srt" // 获取 SRT 格式
* });
* const docs = await loader.load();
* console.log(docs[0].pageContent);
* ```
*/
export class YoutubeLoader extends BaseDocumentLoader {
private videoId: string
private language?: string
private addVideoInfo: boolean
// 新增格式化选项的私有属性
private transcriptFormat: 'text' | 'srt'
constructor(config: YoutubeConfig) {
super()
this.videoId = config.videoId
this.language = config?.language
this.addVideoInfo = config?.addVideoInfo ?? false
// 初始化格式化选项,默认为 'text' 以保持向后兼容
this.transcriptFormat = config?.transcriptFormat ?? 'text'
}
/**
* Extracts the videoId from a YouTube video URL.
* @param url The URL of the YouTube video.
* @returns The videoId of the YouTube video.
*/
private static getVideoID(url: string): string {
const match = url.match(/.*(?:youtu.be\/|v\/|u\/\w\/|embed\/|watch\?v=)([^#&?]*).*/)
if (match !== null && match[1].length === 11) {
return match[1]
} else {
throw new Error('Failed to get youtube video id from the url')
}
}
/**
* Creates a new instance of the YoutubeLoader class from a YouTube video
* URL.
* @param url The URL of the YouTube video.
* @param config Optional configuration options for the YoutubeLoader instance, excluding the videoId.
* @returns A new instance of the YoutubeLoader class.
*/
static createFromUrl(url: string, config?: Omit<YoutubeConfig, 'videoId'>): YoutubeLoader {
const videoId = YoutubeLoader.getVideoID(url)
return new YoutubeLoader({ ...config, videoId })
}
/**
* [] SRT (HH:MM:SS,ms)
* @param ms
* @returns
*/
private static formatTimestamp(ms: number): string {
const totalSeconds = Math.floor(ms / 1000)
const hours = Math.floor(totalSeconds / 3600)
.toString()
.padStart(2, '0')
const minutes = Math.floor((totalSeconds % 3600) / 60)
.toString()
.padStart(2, '0')
const seconds = (totalSeconds % 60).toString().padStart(2, '0')
const milliseconds = (ms % 1000).toString().padStart(3, '0')
return `${hours}:${minutes}:${seconds},${milliseconds}`
}
/**
* Loads the transcript and video metadata from the specified YouTube
* video. It can return the transcript as plain text or in SRT format.
* @returns An array of Documents representing the retrieved data.
*/
async load(): Promise<Document[]> {
const metadata: VideoMetadata = {
source: this.videoId
}
try {
const youtube = await Innertube.create({
lang: this.language,
retrieve_player: false
})
const info = await youtube.getInfo(this.videoId)
const transcriptData = await info.getTranscript()
if (!transcriptData.transcript.content?.body?.initial_segments) {
throw new Error('Transcript segments not found in the response.')
}
const segments = transcriptData.transcript.content.body.initial_segments
let pageContent: string
// 根据 transcriptFormat 选项决定如何格式化字幕
if (this.transcriptFormat === 'srt') {
// [修改] 将字幕片段格式化为 SRT 格式
pageContent = segments
.map((segment, index) => {
const srtIndex = index + 1
const startTime = YoutubeLoader.formatTimestamp(Number(segment.start_ms))
const endTime = YoutubeLoader.formatTimestamp(Number(segment.end_ms))
const text = segment.snippet?.text || '' // 使用 segment.snippet.text
return `${srtIndex}\n${startTime} --> ${endTime}\n${text}`
})
.join('\n\n') // 每个 SRT 块之间用两个换行符分隔
} else {
// [原始逻辑] 拼接为纯文本
pageContent = segments.map((segment) => segment.snippet?.text || '').join(' ')
}
if (this.addVideoInfo) {
const basicInfo = info.basic_info
metadata.description = basicInfo.short_description
metadata.title = basicInfo.title
metadata.view_count = basicInfo.view_count
metadata.author = basicInfo.author
}
const document = new Document({
pageContent,
metadata
})
return [document]
} catch (e: unknown) {
throw new Error(`Failed to get YouTube video transcription: ${(e as Error).message}`)
}
}
}

View File

@ -0,0 +1,235 @@
import { DocxLoader } from '@langchain/community/document_loaders/fs/docx'
import { EPubLoader } from '@langchain/community/document_loaders/fs/epub'
import { PDFLoader } from '@langchain/community/document_loaders/fs/pdf'
import { PPTXLoader } from '@langchain/community/document_loaders/fs/pptx'
import { CheerioWebBaseLoader } from '@langchain/community/document_loaders/web/cheerio'
import { SitemapLoader } from '@langchain/community/document_loaders/web/sitemap'
import { FaissStore } from '@langchain/community/vectorstores/faiss'
import { Document } from '@langchain/core/documents'
import { loggerService } from '@logger'
import { UrlSource } from '@main/utils/knowledge'
import { LoaderReturn } from '@shared/config/types'
import { FileMetadata, FileTypes, KnowledgeBaseParams } from '@types'
import { randomUUID } from 'crypto'
import { JSONLoader } from 'langchain/document_loaders/fs/json'
import { TextLoader } from 'langchain/document_loaders/fs/text'
import { SplitterFactory } from '../splitter'
import { MarkdownLoader } from './MarkdownLoader'
import { NoteLoader } from './NoteLoader'
import { YoutubeLoader } from './YoutubeLoader'
const logger = loggerService.withContext('KnowledgeService File Loader')
type LoaderInstance =
| TextLoader
| PDFLoader
| PPTXLoader
| DocxLoader
| JSONLoader
| EPubLoader
| CheerioWebBaseLoader
| YoutubeLoader
| SitemapLoader
| NoteLoader
| MarkdownLoader
/**
* metadata
*/
function formatDocument(docs: Document[], type: string): Document[] {
return docs.map((doc) => ({
...doc,
metadata: {
...doc.metadata,
type: type
}
}))
}
/**
*
*/
async function processDocuments(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
docs: Document[],
loaderType: string,
splitterType?: string
): Promise<LoaderReturn> {
const formattedDocs = formatDocument(docs, loaderType)
const splitter = SplitterFactory.create({
chunkSize: base.chunkSize,
chunkOverlap: base.chunkOverlap,
...(splitterType && { type: splitterType })
})
const splitterResults = await splitter.splitDocuments(formattedDocs)
const ids = splitterResults.map(() => randomUUID())
await vectorStore.addDocuments(splitterResults, { ids })
return {
entriesAdded: splitterResults.length,
uniqueId: ids[0] || '',
uniqueIds: ids,
loaderType
}
}
/**
*
*/
async function executeLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
loaderInstance: LoaderInstance,
loaderType: string,
identifier: string,
splitterType?: string
): Promise<LoaderReturn> {
const emptyResult: LoaderReturn = {
entriesAdded: 0,
uniqueId: '',
uniqueIds: [],
loaderType
}
try {
const docs = await loaderInstance.load()
return await processDocuments(base, vectorStore, docs, loaderType, splitterType)
} catch (error) {
logger.error(`Error loading or processing ${identifier} with loader ${loaderType}: ${error}`)
return emptyResult
}
}
/**
*
*/
const FILE_LOADER_MAP: Record<string, { loader: new (path: string) => LoaderInstance; type: string }> = {
'.pdf': { loader: PDFLoader, type: 'pdf' },
'.txt': { loader: TextLoader, type: 'text' },
'.pptx': { loader: PPTXLoader, type: 'pptx' },
'.docx': { loader: DocxLoader, type: 'docx' },
'.doc': { loader: DocxLoader, type: 'doc' },
'.json': { loader: JSONLoader, type: 'json' },
'.epub': { loader: EPubLoader, type: 'epub' },
'.md': { loader: MarkdownLoader, type: 'markdown' }
}
export async function addFileLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
file: FileMetadata
): Promise<LoaderReturn> {
const fileExt = file.ext.toLowerCase()
const loaderConfig = FILE_LOADER_MAP[fileExt]
if (!loaderConfig) {
// 默认使用文本加载器
const loaderInstance = new TextLoader(file.path)
const type = fileExt.replace('.', '') || 'unknown'
return executeLoader(base, vectorStore, loaderInstance, type, file.path)
}
const loaderInstance = new loaderConfig.loader(file.path)
return executeLoader(base, vectorStore, loaderInstance, loaderConfig.type, file.path)
}
export async function addWebLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
url: string,
source: UrlSource
): Promise<LoaderReturn> {
let loaderInstance: CheerioWebBaseLoader | YoutubeLoader | undefined
let splitterType: string | undefined
switch (source) {
case 'normal':
loaderInstance = new CheerioWebBaseLoader(url)
break
case 'youtube':
loaderInstance = YoutubeLoader.createFromUrl(url, {
addVideoInfo: true,
transcriptFormat: 'srt'
})
splitterType = 'srt'
break
}
if (!loaderInstance) {
return {
entriesAdded: 0,
uniqueId: '',
uniqueIds: [],
loaderType: source
}
}
return executeLoader(base, vectorStore, loaderInstance, source, url, splitterType)
}
export async function addSitemapLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
url: string
): Promise<LoaderReturn> {
const loaderInstance = new SitemapLoader(url)
return executeLoader(base, vectorStore, loaderInstance, 'sitemap', url)
}
export async function addNoteLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
content: string,
sourceUrl: string
): Promise<LoaderReturn> {
const loaderInstance = new NoteLoader(content, sourceUrl)
return executeLoader(base, vectorStore, loaderInstance, 'note', sourceUrl)
}
export async function addVideoLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
files: FileMetadata[]
): Promise<LoaderReturn> {
const srtFile = files.find((f) => f.type === FileTypes.TEXT)
const videoFile = files.find((f) => f.type === FileTypes.VIDEO)
const emptyResult: LoaderReturn = {
entriesAdded: 0,
uniqueId: '',
uniqueIds: [],
loaderType: 'video'
}
if (!srtFile || !videoFile) {
return emptyResult
}
try {
const loaderInstance = new TextLoader(srtFile.path)
const originalDocs = await loaderInstance.load()
const docsWithVideoMeta = originalDocs.map(
(doc) =>
new Document({
...doc,
metadata: {
...doc.metadata,
video: {
path: videoFile.path,
name: videoFile.origin_name
}
}
})
)
return await processDocuments(base, vectorStore, docsWithVideoMeta, 'video', 'srt')
} catch (error) {
logger.error(`Error loading or processing file ${srtFile.path} with loader video: ${error}`)
return emptyResult
}
}

View File

@ -0,0 +1,55 @@
import { BM25Retriever } from '@langchain/community/retrievers/bm25'
import { FaissStore } from '@langchain/community/vectorstores/faiss'
import { BaseRetriever } from '@langchain/core/retrievers'
import { loggerService } from '@main/services/LoggerService'
import { type KnowledgeBaseParams } from '@types'
import { type Document } from 'langchain/document'
import { EnsembleRetriever } from 'langchain/retrievers/ensemble'
const logger = loggerService.withContext('RetrieverFactory')
export class RetrieverFactory {
/**
* LangChain (Retriever)
* @param base
* @param vectorStore
* @param documents BM25Retriever
* @returns BaseRetriever
*/
public createRetriever(base: KnowledgeBaseParams, vectorStore: FaissStore, documents: Document[]): BaseRetriever {
const retrieverType = base.retriever?.mode ?? 'hybrid'
const retrieverWeight = base.retriever?.weight ?? 0.5
const searchK = base.documentCount ?? 5
logger.info(`Creating retriever of type: ${retrieverType} with k=${searchK}`)
switch (retrieverType) {
case 'bm25':
if (documents.length === 0) {
throw new Error('BM25Retriever requires documents, but none were provided or found.')
}
logger.info('Create BM25 Retriever')
return BM25Retriever.fromDocuments(documents, { k: searchK })
case 'hybrid': {
if (documents.length === 0) {
logger.warn('No documents provided for BM25 part of hybrid search. Falling back to vector search only.')
return vectorStore.asRetriever(searchK)
}
const vectorstoreRetriever = vectorStore.asRetriever(searchK)
const bm25Retriever = BM25Retriever.fromDocuments(documents, { k: searchK })
logger.info('Create Hybrid Retriever')
return new EnsembleRetriever({
retrievers: [bm25Retriever, vectorstoreRetriever],
weights: [retrieverWeight, 1 - retrieverWeight]
})
}
case 'vector':
default:
logger.info('Create Vector Retriever')
return vectorStore.asRetriever(searchK)
}
}
}

View File

@ -0,0 +1,133 @@
import { Document } from '@langchain/core/documents'
import { TextSplitter, TextSplitterParams } from 'langchain/text_splitter'
// 定义一个接口来表示解析后的单个字幕片段
interface SrtSegment {
text: string
startTime: number // in seconds
endTime: number // in seconds
}
// 辅助函数:将 SRT 时间戳字符串 (HH:MM:SS,ms) 转换为秒
function srtTimeToSeconds(time: string): number {
const parts = time.split(':')
const secondsAndMs = parts[2].split(',')
const hours = parseInt(parts[0], 10)
const minutes = parseInt(parts[1], 10)
const seconds = parseInt(secondsAndMs[0], 10)
const milliseconds = parseInt(secondsAndMs[1], 10)
return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
}
export class SrtSplitter extends TextSplitter {
constructor(fields?: Partial<TextSplitterParams>) {
// 传入 chunkSize 和 chunkOverlap
super(fields)
}
splitText(): Promise<string[]> {
throw new Error('Method not implemented.')
}
// 核心方法:重写 splitDocuments 来实现自定义逻辑
async splitDocuments(documents: Document[]): Promise<Document[]> {
const allChunks: Document[] = []
for (const doc of documents) {
// 1. 解析 SRT 内容
const segments = this.parseSrt(doc.pageContent)
if (segments.length === 0) continue
// 2. 将字幕片段组合成块
const chunks = this.mergeSegmentsIntoChunks(segments, doc.metadata)
allChunks.push(...chunks)
}
return allChunks
}
// 辅助方法:解析整个 SRT 字符串
private parseSrt(srt: string): SrtSegment[] {
const segments: SrtSegment[] = []
const blocks = srt.trim().split(/\n\n/)
for (const block of blocks) {
const lines = block.split('\n')
if (lines.length < 3) continue
const timeMatch = lines[1].match(/(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})/)
if (!timeMatch) continue
const startTime = srtTimeToSeconds(timeMatch[1])
const endTime = srtTimeToSeconds(timeMatch[2])
const text = lines.slice(2).join(' ').trim()
segments.push({ text, startTime, endTime })
}
return segments
}
// 辅助方法:将解析后的片段合并成每 5 段一个块
private mergeSegmentsIntoChunks(segments: SrtSegment[], baseMetadata: Record<string, any>): Document[] {
const chunks: Document[] = []
let currentChunkText = ''
let currentChunkStartTime = 0
let currentChunkEndTime = 0
let segmentCount = 0
for (const segment of segments) {
if (segmentCount === 0) {
currentChunkStartTime = segment.startTime
}
currentChunkText += (currentChunkText ? ' ' : '') + segment.text
currentChunkEndTime = segment.endTime
segmentCount++
// 当累积到 5 段时,创建一个新的 Document
if (segmentCount === 5) {
const metadata: Record<string, any> = {
...baseMetadata,
startTime: currentChunkStartTime,
endTime: currentChunkEndTime
}
if (baseMetadata.source_url) {
metadata.source_url_with_timestamp = `${baseMetadata.source_url}?t=${Math.floor(currentChunkStartTime)}s`
}
chunks.push(
new Document({
pageContent: currentChunkText,
metadata
})
)
// 重置计数器和临时变量
currentChunkText = ''
currentChunkStartTime = 0
currentChunkEndTime = 0
segmentCount = 0
}
}
// 如果还有剩余的片段,创建最后一个 Document
if (segmentCount > 0) {
const metadata: Record<string, any> = {
...baseMetadata,
startTime: currentChunkStartTime,
endTime: currentChunkEndTime
}
if (baseMetadata.source_url) {
metadata.source_url_with_timestamp = `${baseMetadata.source_url}?t=${Math.floor(currentChunkStartTime)}s`
}
chunks.push(
new Document({
pageContent: currentChunkText,
metadata
})
)
}
return chunks
}
}

View File

@ -0,0 +1,31 @@
import { RecursiveCharacterTextSplitter, TextSplitter } from '@langchain/textsplitters'
import { SrtSplitter } from './SrtSplitter'
export type SplitterConfig = {
chunkSize?: number
chunkOverlap?: number
type?: 'recursive' | 'srt' | string
}
export class SplitterFactory {
/**
* Creates a TextSplitter instance based on the provided configuration.
* @param config - The configuration object specifying the splitter type and its parameters.
* @returns An instance of a TextSplitter, or null if no splitting is required.
*/
public static create(config: SplitterConfig): TextSplitter {
switch (config.type) {
case 'srt':
return new SrtSplitter({
chunkSize: config.chunkSize,
chunkOverlap: config.chunkOverlap
})
case 'recursive':
default:
return new RecursiveCharacterTextSplitter({
chunkSize: config.chunkSize,
chunkOverlap: config.chunkOverlap
})
}
}
}

View File

@ -0,0 +1,63 @@
import PreprocessProvider from '@main/knowledge/preprocess/PreprocessProvider'
import { loggerService } from '@main/services/LoggerService'
import { windowService } from '@main/services/WindowService'
import type { FileMetadata, KnowledgeBaseParams, KnowledgeItem } from '@types'
const logger = loggerService.withContext('PreprocessingService')
class PreprocessingService {
public async preprocessFile(
file: FileMetadata,
base: KnowledgeBaseParams,
item: KnowledgeItem,
userId: string
): Promise<FileMetadata> {
let fileToProcess: FileMetadata = file
// Check if preprocessing is configured and applicable (e.g., for PDFs)
if (base.preprocessProvider && file.ext.toLowerCase() === '.pdf') {
try {
const provider = new PreprocessProvider(base.preprocessProvider.provider, userId)
// Check if file has already been preprocessed
const alreadyProcessed = await provider.checkIfAlreadyProcessed(file)
if (alreadyProcessed) {
logger.debug(`File already preprocessed, using cached result: ${file.path}`)
return alreadyProcessed
}
// Execute preprocessing
logger.debug(`Starting preprocess for scanned PDF: ${file.path}`)
const { processedFile, quota } = await provider.parseFile(item.id, file)
fileToProcess = processedFile
// Notify the UI
const mainWindow = windowService.getMainWindow()
mainWindow?.webContents.send('file-preprocess-finished', {
itemId: item.id,
quota: quota
})
} catch (err) {
logger.error(`Preprocessing failed: ${err}`)
// If preprocessing fails, re-throw the error to be handled by the caller
throw new Error(`Preprocessing failed: ${err}`)
}
}
return fileToProcess
}
public async checkQuota(base: KnowledgeBaseParams, userId: string): Promise<number> {
try {
if (base.preprocessProvider && base.preprocessProvider.type === 'preprocess') {
const provider = new PreprocessProvider(base.preprocessProvider.provider, userId)
return await provider.checkQuota()
}
throw new Error('No preprocess provider configured')
} catch (err) {
logger.error(`Failed to check quota: ${err}`)
throw new Error(`Failed to check quota: ${err}`)
}
}
}
export const preprocessingService = new PreprocessingService()

View File

@ -1,101 +1,46 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import { DEFAULT_DOCUMENT_COUNT, DEFAULT_RELEVANT_SCORE } from '@main/utils/knowledge'
import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types'
import { MultiModalDocument, RerankStrategy } from './strategies/RerankStrategy'
import { StrategyFactory } from './strategies/StrategyFactory'
export default abstract class BaseReranker {
protected base: KnowledgeBaseParams
protected strategy: RerankStrategy
constructor(base: KnowledgeBaseParams) {
if (!base.rerankApiClient) {
throw new Error('Rerank model is required')
}
this.base = base
this.strategy = StrategyFactory.createStrategy(base.rerankApiClient.provider)
}
abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>
/**
* Get Rerank Request Url
*/
protected getRerankUrl() {
if (this.base.rerankApiClient?.provider === 'bailian') {
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
}
let baseURL = this.base.rerankApiClient?.baseURL
if (baseURL && baseURL.endsWith('/')) {
// `/` 结尾强制使用rerankBaseURL
return `${baseURL}rerank`
}
if (baseURL && !baseURL.endsWith('/v1')) {
baseURL = `${baseURL}/v1`
}
return `${baseURL}/rerank`
abstract rerank(query: string, searchResults: KnowledgeSearchResult[]): Promise<KnowledgeSearchResult[]>
protected getRerankUrl(): string {
return this.strategy.buildUrl(this.base.rerankApiClient?.baseURL)
}
/**
* Get Rerank Request Body
*/
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
const provider = this.base.rerankApiClient?.provider
const documents = searchResults.map((doc) => doc.pageContent)
const topN = this.base.documentCount
if (provider === 'voyageai') {
return {
model: this.base.rerankApiClient?.model,
query,
documents,
top_k: topN
}
} else if (provider === 'bailian') {
return {
model: this.base.rerankApiClient?.model,
input: {
query,
documents
},
parameters: {
top_n: topN
}
}
} else if (provider?.includes('tei')) {
return {
query,
texts: documents,
return_text: true
}
} else {
return {
model: this.base.rerankApiClient?.model,
query,
documents,
top_n: topN
}
}
protected getRerankRequestBody(query: string, searchResults: KnowledgeSearchResult[]) {
const documents = this.buildDocuments(searchResults)
const topN = this.base.documentCount ?? DEFAULT_DOCUMENT_COUNT
const model = this.base.rerankApiClient?.model
return this.strategy.buildRequestBody(query, documents, topN, model)
}
private buildDocuments(searchResults: KnowledgeSearchResult[]): MultiModalDocument[] {
return searchResults.map((doc) => {
const document: MultiModalDocument = {}
/**
* Extract Rerank Result
*/
// 检查是否是图片类型,添加图片内容
if (doc.metadata?.type === 'image') {
document.image = doc.pageContent
} else {
document.text = doc.pageContent
}
return document
})
}
protected extractRerankResult(data: any) {
const provider = this.base.rerankApiClient?.provider
if (provider === 'bailian') {
return data.output.results
} else if (provider === 'voyageai') {
return data.data
} else if (provider?.includes('tei')) {
return data.map((item: any) => {
return {
index: item.index,
relevance_score: item.score
}
})
} else {
return data.results
}
return this.strategy.extractResults(data)
}
/**
@ -105,35 +50,30 @@ export default abstract class BaseReranker {
* @protected
*/
protected getRerankResult(
searchResults: ExtractChunkData[],
rerankResults: Array<{
index: number
relevance_score: number
}>
searchResults: KnowledgeSearchResult[],
rerankResults: Array<{ index: number; relevance_score: number }>
) {
const resultMap = new Map(rerankResults.map((result) => [result.index, result.relevance_score || 0]))
const resultMap = new Map(
rerankResults.map((result) => [result.index, result.relevance_score || DEFAULT_RELEVANT_SCORE])
)
return searchResults
.map((doc: ExtractChunkData, index: number) => {
const returenResults = searchResults
.map((doc: KnowledgeSearchResult, index: number) => {
const score = resultMap.get(index)
if (score === undefined) return undefined
return {
...doc,
score
}
return { ...doc, score }
})
.filter((doc): doc is ExtractChunkData => doc !== undefined)
.filter((doc): doc is KnowledgeSearchResult => doc !== undefined)
.sort((a, b) => b.score - a.score)
}
return returenResults
}
public defaultHeaders() {
return {
Authorization: `Bearer ${this.base.rerankApiClient?.apiKey}`,
'Content-Type': 'application/json'
}
}
protected formatErrorMessage(url: string, error: any, requestBody: any) {
const errorDetails = {
url: url,

View File

@ -1,19 +1,14 @@
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types'
import { net } from 'electron'
import BaseReranker from './BaseReranker'
export default class GeneralReranker extends BaseReranker {
constructor(base: KnowledgeBaseParams) {
super(base)
}
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
public rerank = async (query: string, searchResults: KnowledgeSearchResult[]): Promise<KnowledgeSearchResult[]> => {
const url = this.getRerankUrl()
const requestBody = this.getRerankRequestBody(query, searchResults)
try {
const response = await net.fetch(url, {
method: 'POST',

View File

@ -1,5 +1,4 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types'
import GeneralReranker from './GeneralReranker'
@ -8,7 +7,7 @@ export default class Reranker {
constructor(base: KnowledgeBaseParams) {
this.sdk = new GeneralReranker(base)
}
public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> {
public async rerank(query: string, searchResults: KnowledgeSearchResult[]): Promise<KnowledgeSearchResult[]> {
return this.sdk.rerank(query, searchResults)
}
}

View File

@ -0,0 +1,18 @@
import { MultiModalDocument, RerankStrategy } from './RerankStrategy'
export class BailianStrategy implements RerankStrategy {
buildUrl(): string {
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
}
buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) {
const textDocuments = documents.filter((d) => d.text).map((d) => d.text!)
return {
model,
input: { query, documents: textDocuments },
parameters: { top_n: topN }
}
}
extractResults(data: any) {
return data.output.results
}
}

View File

@ -0,0 +1,25 @@
import { MultiModalDocument, RerankStrategy } from './RerankStrategy'
export class DefaultStrategy implements RerankStrategy {
buildUrl(baseURL?: string): string {
if (baseURL && baseURL.endsWith('/')) {
return `${baseURL}rerank`
}
if (baseURL && !baseURL.endsWith('/v1')) {
baseURL = `${baseURL}/v1`
}
return `${baseURL}/rerank`
}
buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) {
const textDocuments = documents.filter((d) => d.text).map((d) => d.text!)
return {
model,
query,
documents: textDocuments,
top_n: topN
}
}
extractResults(data: any) {
return data.results
}
}

View File

@ -0,0 +1,33 @@
import { MultiModalDocument, RerankStrategy } from './RerankStrategy'
export class JinaStrategy implements RerankStrategy {
buildUrl(baseURL?: string): string {
if (baseURL && baseURL.endsWith('/')) {
return `${baseURL}rerank`
}
if (baseURL && !baseURL.endsWith('/v1')) {
baseURL = `${baseURL}/v1`
}
return `${baseURL}/rerank`
}
buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) {
if (model === 'jina-reranker-m0') {
return {
model,
query,
documents,
top_n: topN
}
}
const textDocuments = documents.filter((d) => d.text).map((d) => d.text!)
return {
model,
query,
documents: textDocuments,
top_n: topN
}
}
extractResults(data: any) {
return data.results
}
}

Some files were not shown because too many files have changed in this diff Show More