Mirror of https://github.com/CherryHQ/cherry-studio.git (synced 2025-12-31 16:49:07 +08:00)

Merge branch 'main' of github.com:CherryHQ/cherry-studio into wip/data-refactor
This commit is contained in: commit 2931e558b3

94  .github/ISSUE_TEMPLATE/#0_bug_report.yml (vendored)
@@ -1,94 +0,0 @@
name: 🐛 错误报告 (中文)
description: 创建一个报告以帮助我们改进
title: '[错误]: '
labels: ['BUG']
body:
  - type: markdown
    attributes:
      value: |
        感谢您花时间填写此错误报告!
        在提交此问题之前,请确保您已经了解了[常见问题](https://docs.cherry-ai.com/question-contact/questions)和[知识科普](https://docs.cherry-ai.com/question-contact/knowledge)

  - type: checkboxes
    id: checklist
    attributes:
      label: 提交前检查
      description: |
        在提交 Issue 前请确保您已经完成了以下所有步骤
      options:
        - label: 我理解 Issue 是用于反馈和解决问题的,而非吐槽评论区,将尽可能提供更多信息帮助问题解决。
          required: true
        - label: 我的问题不是 [常见问题](https://github.com/CherryHQ/cherry-studio/issues/3881) 中的内容。
          required: true
        - label: 我已经查看了 **置顶 Issue** 并搜索了现有的 [开放Issue](https://github.com/CherryHQ/cherry-studio/issues)和[已关闭Issue](https://github.com/CherryHQ/cherry-studio/issues?q=is%3Aissue%20state%3Aclosed%20),没有找到类似的问题。
          required: true
        - label: 我填写了简短且清晰明确的标题,以便开发者在翻阅 Issue 列表时能快速确定大致问题。而不是“一个建议”、“卡住了”等。
          required: true
        - label: 我确认我正在使用最新版本的 Cherry Studio。
          required: true

  - type: dropdown
    id: platform
    attributes:
      label: 平台
      description: 您正在使用哪个平台?
      options:
        - Windows
        - macOS
        - Linux
    validations:
      required: true

  - type: input
    id: version
    attributes:
      label: 版本
      description: 您正在运行的 Cherry Studio 版本是什么?
      placeholder: 例如 v1.0.0
    validations:
      required: true

  - type: textarea
    id: description
    attributes:
      label: 错误描述
      description: 描述问题时请尽可能详细。请尽可能提供截图或屏幕录制,以帮助我们更好地理解问题。
      placeholder: 告诉我们发生了什么...(记得附上截图/录屏,如果适用)
    validations:
      required: true

  - type: textarea
    id: reproduction
    attributes:
      label: 重现步骤
      description: 提供详细的重现步骤,以便于我们的开发人员可以准确地重现问题。请尽可能为每个步骤提供截图或屏幕录制。
      placeholder: |
        1. 转到 '...'
        2. 点击 '....'
        3. 向下滚动到 '....'
        4. 看到错误

        记得尽可能为每个步骤附上截图/录屏!
    validations:
      required: true

  - type: textarea
    id: expected
    attributes:
      label: 预期行为
      description: 清晰简洁地描述您期望发生的事情
    validations:
      required: true

  - type: textarea
    id: logs
    attributes:
      label: 相关日志输出
      description: 请复制并粘贴任何相关的日志输出
      render: shell

  - type: textarea
    id: additional
    attributes:
      label: 附加信息
      description: 任何能让我们对你所遇到的问题有更多了解的东西
76  .github/ISSUE_TEMPLATE/#1_feature_request.yml (vendored)
@@ -1,76 +0,0 @@
name: 💡 功能建议 (中文)
description: 为项目提出新的想法
title: '[功能]: '
labels: ['feature']
body:
  - type: markdown
    attributes:
      value: |
        感谢您花时间提出新的功能建议!
        在提交此问题之前,请确保您已经了解了[项目规划](https://docs.cherry-ai.com/cherrystudio/planning)和[功能介绍](https://docs.cherry-ai.com/cherrystudio/preview)

  - type: checkboxes
    id: checklist
    attributes:
      label: 提交前检查
      description: |
        在提交 Issue 前请确保您已经完成了以下所有步骤
      options:
        - label: 我理解 Issue 是用于反馈和解决问题的,而非吐槽评论区,将尽可能提供更多信息帮助问题解决。
          required: true
        - label: 我已经查看了置顶 Issue 并搜索了现有的 [开放Issue](https://github.com/CherryHQ/cherry-studio/issues)和[已关闭Issue](https://github.com/CherryHQ/cherry-studio/issues?q=is%3Aissue%20state%3Aclosed%20),没有找到类似的建议。
          required: true
        - label: 我填写了简短且清晰明确的标题,以便开发者在翻阅 Issue 列表时能快速确定大致问题。而不是“一个建议”、“卡住了”等。
          required: true
        - label: 最新的 Cherry Studio 版本没有实现我所提出的功能。
          required: true

  - type: dropdown
    id: platform
    attributes:
      label: 平台
      description: 您正在使用哪个平台?
      options:
        - Windows
        - macOS
        - Linux
    validations:
      required: true

  - type: input
    id: version
    attributes:
      label: 版本
      description: 您正在运行的 Cherry Studio 版本是什么?
      placeholder: 例如 v1.0.0
    validations:
      required: true

  - type: textarea
    id: problem
    attributes:
      label: 您的功能建议是否与某个问题/issue相关?
      description: 请简明扼要地描述您遇到的问题
      placeholder: 我总是感到沮丧,因为...
    validations:
      required: true

  - type: textarea
    id: solution
    attributes:
      label: 请描述您希望实现的解决方案
      description: 请简明扼要地描述您希望发生的情况
    validations:
      required: true

  - type: textarea
    id: alternatives
    attributes:
      label: 请描述您考虑过的其他方案
      description: 请简明扼要地描述您考虑过的任何其他解决方案或功能

  - type: textarea
    id: additional
    attributes:
      label: 其他补充信息
      description: 在此添加任何其他与功能建议相关的上下文或截图
77  .github/ISSUE_TEMPLATE/#2_question.yml (vendored)
@@ -1,77 +0,0 @@
name: ❓ 提问 & 讨论 (中文)
description: 寻求帮助、讨论问题、提出疑问等...
title: '[讨论]: '
labels: ['discussion', 'help wanted']
body:
  - type: markdown
    attributes:
      value: |
        感谢您的提问!请尽可能详细地描述您的问题,这样我们才能更好地帮助您。

  - type: checkboxes
    id: checklist
    attributes:
      label: Issue 检查清单
      description: |
        在提交 Issue 前请确保您已经完成了以下所有步骤
      options:
        - label: 我理解 Issue 是用于反馈和解决问题的,而非吐槽评论区,将尽可能提供更多信息帮助问题解决。
          required: true
        - label: 我确认自己需要的是提出问题并且讨论问题,而不是 Bug 反馈或需求建议。
          required: true

  - type: dropdown
    id: platform
    attributes:
      label: 平台
      description: 您正在使用哪个平台?
      options:
        - Windows
        - macOS
        - Linux
    validations:
      required: true

  - type: input
    id: version
    attributes:
      label: 版本
      description: 您正在运行的 Cherry Studio 版本是什么?
      placeholder: 例如 v1.0.0
    validations:
      required: true

  - type: textarea
    id: question
    attributes:
      label: 您的问题
      description: 请详细描述您的问题
      placeholder: 请尽可能清楚地说明您的问题...
    validations:
      required: true

  - type: textarea
    id: context
    attributes:
      label: 相关背景
      description: 请提供一些背景信息,帮助我们更好地理解您的问题
      placeholder: 例如:使用场景、已尝试的解决方案等

  - type: textarea
    id: additional
    attributes:
      label: 补充信息
      description: 任何其他相关的信息、截图或代码示例
      render: shell

  - type: dropdown
    id: priority
    attributes:
      label: 优先级
      description: 这个问题对您来说有多紧急?
      options:
        - 低 (有空再看)
        - 中 (希望尽快得到答复)
        - 高 (阻碍工作进行)
    validations:
      required: true
76  .github/ISSUE_TEMPLATE/#3_others.yml (vendored)
@@ -1,76 +0,0 @@
name: 🤔 其他问题 (中文)
description: 提交不属于错误报告或功能需求的问题
title: '[其他]: '
body:
  - type: markdown
    attributes:
      value: |
        感谢您花时间提出问题!
        在提交此问题之前,请确保您已经了解了[常见问题](https://docs.cherry-ai.com/question-contact/questions)和[知识科普](https://docs.cherry-ai.com/question-contact/knowledge)

  - type: checkboxes
    id: checklist
    attributes:
      label: 提交前检查
      description: |
        在提交 Issue 前请确保您已经完成了以下所有步骤
      options:
        - label: 我理解 Issue 是用于反馈和解决问题的,而非吐槽评论区,将尽可能提供更多信息帮助问题解决。
          required: true
        - label: 我已经查看了置顶 Issue 并搜索了现有的 [开放Issue](https://github.com/CherryHQ/cherry-studio/issues)和[已关闭Issue](https://github.com/CherryHQ/cherry-studio/issues?q=is%3Aissue%20state%3Aclosed%20),没有找到类似的问题。
          required: true
        - label: 我填写了简短且清晰明确的标题,以便开发者在翻阅 Issue 列表时能快速确定大致问题。而不是"一个问题"、"求助"等。
          required: true
        - label: 我的问题不属于错误报告或功能需求类别。
          required: true

  - type: dropdown
    id: platform
    attributes:
      label: 平台
      description: 您正在使用哪个平台?
      options:
        - Windows
        - macOS
        - Linux
    validations:
      required: true

  - type: input
    id: version
    attributes:
      label: 版本
      description: 您正在运行的 Cherry Studio 版本是什么?
      placeholder: 例如 v1.0.0
    validations:
      required: true

  - type: textarea
    id: question
    attributes:
      label: 问题描述
      description: 请详细描述您的问题或疑问
      placeholder: 我想了解有关...的更多信息
    validations:
      required: true

  - type: textarea
    id: context
    attributes:
      label: 相关背景
      description: 请提供与您的问题相关的任何背景信息或上下文
      placeholder: 我尝试实现...时遇到了疑问
    validations:
      required: true

  - type: textarea
    id: attempts
    attributes:
      label: 您已尝试的方法
      description: 请描述您为解决问题已经尝试过的方法(如果有)

  - type: textarea
    id: additional
    attributes:
      label: 附加信息
      description: 任何能让我们对您的问题有更多了解的信息,包括截图或相关链接
66  .github/workflows/auto-i18n.yml (vendored, new file)
@@ -0,0 +1,66 @@
name: Auto I18N

env:
  API_KEY: ${{ secrets.TRANSLATE_API_KEY}}
  MODEL: ${{ vars.MODEL || 'deepseek/deepseek-v3.1'}}
  BASE_URL: ${{ vars.BASE_URL || 'https://api.ppinfra.com/openai'}}

on:
  pull_request:
    types: [opened, synchronize, reopened]
  workflow_dispatch:

jobs:
  auto-i18n:
    runs-on: ubuntu-latest
    if: github.event.pull_request.head.repo.full_name == 'CherryHQ/cherry-studio'
    name: Auto I18N
    permissions:
      contents: write
      pull-requests: write

    steps:
      - name: 🐈⬛ Checkout
        uses: actions/checkout@v5
        with:
          ref: ${{ github.event.pull_request.head.ref }}

      - name: 📦 Setting Node.js
        uses: actions/setup-node@v4
        with:
          node-version: 20

      - name: 📦 Install dependencies in isolated directory
        run: |
          # 在临时目录安装依赖
          mkdir -p /tmp/translation-deps
          cd /tmp/translation-deps
          echo '{"dependencies": {"openai": "^5.12.2", "cli-progress": "^3.12.0", "tsx": "^4.20.3", "prettier": "^3.5.3", "prettier-plugin-sort-json": "^4.1.1"}}' > package.json
          npm install --no-package-lock

          # 设置 NODE_PATH 让项目能找到这些依赖
          echo "NODE_PATH=/tmp/translation-deps/node_modules" >> $GITHUB_ENV

      - name: 🏃♀️ Translate
        run: npx tsx scripts/auto-translate-i18n.ts

      - name: 🔍 Format
        run: cd /tmp/translation-deps && npx prettier --write --config /home/runner/work/cherry-studio/cherry-studio/.prettierrc /home/runner/work/cherry-studio/cherry-studio/src/renderer/src/i18n/

      - name: 🔄 Commit changes
        run: |
          git config --local user.email "action@github.com"
          git config --local user.name "GitHub Action"
          git add .
          git reset -- package.json yarn.lock # 不提交 package.json 和 yarn.lock 的更改
          if git diff --cached --quiet; then
            echo "No changes to commit"
          else
            git commit -m "fix(i18n): Auto update translations for PR #${{ github.event.pull_request.number }}"
          fi

      - name: 🚀 Push changes
        uses: ad-m/github-push-action@master
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          branch: ${{ github.event.pull_request.head.ref }}
54  .github/workflows/claude-code-review.yml (vendored, new file)
@@ -0,0 +1,54 @@
name: Claude Code Review

on:
  pull_request:
    types: [opened, synchronize]
    # Optional: Only run on specific file changes
    # paths:
    #   - "src/**/*.ts"
    #   - "src/**/*.tsx"
    #   - "src/**/*.js"
    #   - "src/**/*.jsx"

jobs:
  claude-review:
    # Optional: Filter by PR author
    # if: |
    #   github.event.pull_request.user.login == 'external-contributor' ||
    #   github.event.pull_request.user.login == 'new-developer' ||
    #   github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'

    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: write
      issues: read
      id-token: write
      actions: read

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      - name: Run Claude Code Review
        id: claude-review
        uses: anthropics/claude-code-action@v1
        with:
          claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
          prompt: |
            Please review this pull request and provide feedback on:
            - Code quality and best practices
            - Potential bugs or issues
            - Performance considerations
            - Security concerns
            - Test coverage

            Use the repository's CLAUDE.md for guidance on style and conventions. Be constructive and helpful in your feedback.

            Use `gh pr comment` with your Bash tool to leave your review as a comment on the PR.

          # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
          # or https://docs.anthropic.com/en/docs/claude-code/sdk#command-line for available options
          claude_args: '--allowed-tools "Bash(gh issue view:*),Bash(gh search:*),Bash(gh issue list:*),Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*)"'
69  .github/workflows/claude-translator.yml (vendored, new file)
@@ -0,0 +1,69 @@
name: English Translator
concurrency:
  group: translator-${{ github.event.issue.number }}
  cancel-in-progress: false

on:
  issues:
    types: [opened]
  issue_comment:
    types: [created, edited]

jobs:
  translate:
    if: |
      (github.event_name == 'issues' && github.event.issue.author_association == 'COLLABORATOR' && !contains(github.event.issue.body, 'This issue/comment was translated by Claude.')) ||
      (github.event_name == 'issue_comment' && github.event.comment.author_association == 'COLLABORATOR' && !contains(github.event.issue.body, 'This issue/comment was translated by Claude.'))
    runs-on: ubuntu-latest
    permissions:
      contents: read
      issues: write # 编辑issues/comments
      pull-requests: read
      id-token: write

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      - name: Run Claude for translation
        uses: anthropics/claude-code-action@v1
        id: claude
        with:
          claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
          claude_args: '--allowed-tools mcp__github_comment__update_claude_comment,Bash(gh issue:*),Bash(gh api:repos/*/issues:*)'
          prompt: |
            你是一个多语言翻译助手。请完成以下任务:

            1. 获取当前issue/comment的完整信息
            2. 智能检测内容。
               1. 如果是已经遵循格式要求翻译过的issue/comment,检查翻译内容和原始内容是否匹配。若不匹配,则重新翻译一次令其匹配,并遵循格式要求;若匹配,则跳过任务。
               2. 如果是未翻译过的issue/comment,检查其内容语言。若不是英文,则翻译成英文;若已经是英文,则跳过任务。
            3. 格式要求:
               - 标题:英文翻译(如果非英文)
               - 内容格式:
                 > [!NOTE]
                 > This issue/comment was translated by Claude.

                 [英文翻译内容]

                 ---
                 <details>
                 <summary>**Original Content:**</summary>
                 [原始内容]
                 </details>

            4. 使用gh工具更新:
               - 根据环境信息中的Event类型选择正确的命令:
                 - 如果Event是'issues':gh issue edit [ISSUE_NUMBER] --title "[英文标题]" --body "[翻译内容 + 原始内容]"
                 - 如果Event是'issue_comment':gh api -X PATCH /repos/[REPO]/issues/comments/[COMMENT_ID] -f body="[翻译内容 + 原始内容]"

            环境信息:
            - Event: ${{ github.event_name }}
            - Issue Number: ${{ github.event.issue.number }}
            - Repository: ${{ github.repository }}
            - Comment ID: ${{ github.event.comment.id || 'N/A' }} (only available for comment events)

            使用以下命令获取完整信息:
            gh issue view ${{ github.event.issue.number }} --json title,body,comments
60  .github/workflows/claude.yml (vendored, new file)
@@ -0,0 +1,60 @@
name: Claude Code

on:
  issue_comment:
    types: [created]
  pull_request_review_comment:
    types: [created]
  issues:
    types: [opened]
  pull_request_review:
    types: [submitted]

jobs:
  claude:
    if: |
      (github.event_name == 'issue_comment'
      && contains(github.event.comment.body, '@claude')
      && contains(fromJSON('["COLLABORATOR","MEMBER","OWNER"]'), github.event.comment.author_association))
      ||
      (github.event_name == 'pull_request_review_comment'
      && contains(github.event.comment.body, '@claude')
      && contains(fromJSON('["COLLABORATOR","MEMBER","OWNER"]'), github.event.comment.author_association))
      ||
      (github.event_name == 'pull_request_review'
      && contains(github.event.review.body, '@claude')
      && contains(fromJSON('["COLLABORATOR","MEMBER","OWNER"]'), github.event.review.author_association))
      ||
      (github.event_name == 'issues'
      && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))
      && contains(fromJSON('["COLLABORATOR","MEMBER","OWNER"]'), github.event.issue.author_association))
    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: read
      issues: read
      id-token: write
      actions: read # Required for Claude to read CI results on PRs
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      - name: Run Claude Code
        id: claude
        uses: anthropics/claude-code-action@v1
        with:
          claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}

          # This is an optional setting that allows Claude to read CI results on PRs
          additional_permissions: |
            actions: read

          # Optional: Give a custom prompt to Claude. If this is not specified, Claude will perform the instructions specified in the comment that tagged it.
          # prompt: 'Update the pull request description to include a summary of changes.'

          # Optional: Add claude_args to customize behavior and configuration
          # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
          # or https://docs.anthropic.com/en/docs/claude-code/sdk#command-line for available options
          # claude_args: '--model claude-opus-4-1-20250805 --allowed-tools Bash(gh pr:*)'
3  .github/workflows/pr-ci.yml (vendored)
@@ -45,6 +45,9 @@ jobs:
      - name: Install Dependencies
        run: yarn install

      - name: Format Check
        run: yarn format:check

      - name: Lint Check
        run: yarn test:lint
@@ -7,4 +7,6 @@ tsconfig.*.json
CHANGELOG*.md
agents.json
src/renderer/src/integration/nutstore/sso/lib
src/main/integration/cherryin/index.js
AGENT.md
src/main/integration/
.yarn/releases/
45  LICENSE
@@ -1,48 +1,3 @@
|
||||
**许可协议 (Licensing)**
|
||||
|
||||
本项目采用**区分用户的双重许可 (User-Segmented Dual Licensing)** 模式。
|
||||
|
||||
**核心原则:**
|
||||
|
||||
* **个人用户 和 10人及以下企业/组织:** 默认适用 **GNU Affero 通用公共许可证 v3.0 (AGPLv3)**。
|
||||
* **超过10人的企业/组织:** **必须** 获取 **商业许可证 (Commercial License)**。
|
||||
|
||||
定义:“10人及以下”
|
||||
指在您的组织(包括公司、非营利组织、政府机构、教育机构等任何实体)中,能够访问、使用或以任何方式直接或间接受益于本软件(Cherry Studio)功能的个人总数不超过10人。这包括但不限于开发者、测试人员、运营人员、最终用户、通过集成系统间接使用者等。
|
||||
|
||||
---
|
||||
|
||||
**1. 开源许可证 (Open Source License): AGPLv3 - 适用于个人及10人及以下组织**
|
||||
|
||||
* 如果您是个人用户,或者您的组织满足上述“10人及以下”的定义,您可以在 **AGPLv3** 的条款下自由使用、修改和分发 Cherry Studio。AGPLv3 的完整文本可以访问 [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html) 获取。
|
||||
* **核心义务:** AGPLv3 的一个关键要求是,如果您修改了 Cherry Studio 并通过网络提供服务,或者分发了修改后的版本,您必须以 AGPLv3 许可证向接收者提供相应的**完整源代码**。即使您符合“10人及以下”的标准,如果您希望避免此源代码公开义务,您也需要考虑获取商业许可证(见下文)。
|
||||
* 使用前请务必仔细阅读并理解 AGPLv3 的所有条款。
|
||||
|
||||
**2. 商业许可证 (Commercial License) - 适用于超过10人的组织,或希望规避 AGPLv3 义务的用户**
|
||||
|
||||
* **强制要求:** 如果您的组织**不**满足上述“10人及以下”的定义(即有11人或更多人可以访问、使用或受益于本软件),您**必须**联系我们获取并签署一份商业许可证才能使用 Cherry Studio。
|
||||
* **自愿选择:** 即使您的组织满足“10人及以下”的条件,但如果您的使用场景**无法满足 AGPLv3 的条款要求**(特别是关于**源代码公开**的义务),或者您需要 AGPLv3 **未提供**的特定商业条款(如保证、赔偿、无 Copyleft 限制等),您也**必须**联系我们获取并签署一份商业许可证。
|
||||
* **需要商业许可证的常见情况包括(但不限于):**
|
||||
* 您的组织规模超过10人。
|
||||
* (无论组织规模)您希望分发修改过的 Cherry Studio 版本,但**不希望**根据 AGPLv3 公开您修改部分的源代码。
|
||||
* (无论组织规模)您希望基于修改过的 Cherry Studio 提供网络服务(SaaS),但**不希望**根据 AGPLv3 向服务使用者提供修改后的源代码。
|
||||
* (无论组织规模)您的公司政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件,或要求闭源分发及保密。
|
||||
* 商业许可证将为您提供豁免 AGPLv3 义务(如源代码公开)的权利,并可能包含额外的商业保障条款。
|
||||
* **获取商业许可:** 请通过邮箱 **bd@cherry-ai.com** 联系 Cherry Studio 开发团队洽谈商业授权事宜。
|
||||
|
||||
**3. 贡献 (Contributions)**
|
||||
|
||||
* 我们欢迎社区对 Cherry Studio 的贡献。所有向本项目提交的贡献都将被视为在 **AGPLv3** 许可证下提供。
|
||||
* 通过向本项目提交贡献(例如通过 Pull Request),即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
|
||||
* 您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 Cherry Studio 版本中。
|
||||
|
||||
**4. 其他条款 (Other Terms)**
|
||||
|
||||
* 关于商业许可证的具体条款和条件,以双方签署的正式商业许可协议为准。
|
||||
* 项目维护者保留根据需要更新本许可政策(包括用户规模定义和阈值)的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
|
||||
|
||||
---
|
||||
|
||||
**Licensing**
|
||||
|
||||
This project employs a **User-Segmented Dual Licensing** model.
|
||||
|
||||
@@ -8,9 +8,9 @@

| 字段名     | 类型   | 是否主键 | 索引 | 说明                                                                       |
| ---------- | ------ | -------- | ---- | -------------------------------------------------------------------------- |
| `id`       | string | ✅ 是    | ✅   | 唯一标识符,主键                                                           |
| `langCode` | string | ❌ 否    | ✅   | 语言代码(如:`zh-cn`, `en-us`, `ja-jp` 等,均为小写),支持普通索引查询   |
| `value`    | string | ❌ 否    | ❌   | 语言的名称,用户输入                                                       |
| `emoji`    | string | ❌ 否    | ❌   | 语言的emoji,用户输入                                                      |

> `langCode` 虽非主键,但在业务层应当避免重复插入相同语言代码。
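下面给出一个示意性的表定义草稿(仅为示意,并非项目中的实际实现;项目通过 drizzle-kit 管理 SQLite 迁移,此处的表名 `languages` 与列名写法均为假设,字段含义以上表为准):

```typescript
// 示意性草稿:基于上表字段的 Drizzle ORM (SQLite) 表定义
import { index, sqliteTable, text } from 'drizzle-orm/sqlite-core'

export const languages = sqliteTable(
  'languages',
  {
    id: text('id').primaryKey(), // 唯一标识符,主键
    langCode: text('lang_code').notNull(), // 语言代码,均为小写,如 zh-cn
    value: text('value').notNull(), // 语言的名称,用户输入
    emoji: text('emoji') // 语言的 emoji,用户输入
  },
  (table) => ({
    // langCode 建普通索引以支持按语言代码查询;唯一性由业务层保证
    langCodeIdx: index('languages_lang_code_idx').on(table.langCode)
  })
)
```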
@ -124,24 +124,25 @@ afterSign: scripts/notarize.js
|
||||
artifactBuildCompleted: scripts/artifact-build-completed.js
|
||||
releaseInfo:
|
||||
releaseNotes: |
|
||||
✨ 重要更新:
|
||||
- 新增笔记模块,支持富文本编辑和管理
|
||||
- 内置 GLM-4.5-Flash 免费模型(由智谱开放平台提供)
|
||||
- 内置 Qwen3-8B 免费模型(由硅基流动提供)
|
||||
- 新增 Nano Banana(Gemini 2.5 Flash Image)模型支持
|
||||
- 新增系统 OCR 功能 (macOS & Windows)
|
||||
- 新增图片 OCR 识别和翻译功能
|
||||
- 模型切换支持通过标签筛选
|
||||
- 翻译功能增强:历史搜索和收藏
|
||||
✨ 新功能:
|
||||
- 重构知识库模块,提升文档处理能力和搜索性能
|
||||
- 新增 PaddleOCR 支持,增强文档识别能力
|
||||
- 支持自定义窗口控制按钮样式
|
||||
- 新增 AI SDK 包,扩展 AI 能力集成
|
||||
- 支持标签页拖拽重排序功能
|
||||
- 增强笔记编辑器的同步和日志功能
|
||||
|
||||
🔧 性能优化:
|
||||
- 优化历史页面搜索性能
|
||||
- 优化拖拽列表组件交互
|
||||
- 升级 Electron 到 37.4.0
|
||||
- 优化 MCP 服务的日志记录和错误处理
|
||||
- 改进 WebView 服务的 User-Agent 处理
|
||||
- 优化迷你应用的标题栏样式和状态栏适配
|
||||
- 重构依赖管理,清理和优化 package.json
|
||||
|
||||
🐛 修复问题:
|
||||
- 修复知识库加密 PDF 文档处理
|
||||
- 修复导航栏在左侧时笔记侧边栏按钮缺失
|
||||
- 修复多个模型兼容性问题
|
||||
- 修复 MCP 相关问题
|
||||
- 其他稳定性改进
|
||||
🐛 问题修复:
|
||||
- 修复输入栏无限状态更新循环问题
|
||||
- 修复窗口控制提示框的鼠标悬停延迟
|
||||
- 修复翻译输入框粘贴多内容源的处理
|
||||
- 修复导航服务初始化时序问题
|
||||
- 修复 MCP 通过 JSON 添加时的参数转换
|
||||
- 修复模型作用域服务器同步时的 URL 格式
|
||||
- 标准化工具提示图标样式
|
||||
|
||||
@@ -99,6 +99,9 @@ export default defineConfig({
        '@data': resolve('src/renderer/src/data'),
        '@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
        '@mcp-trace/trace-web': resolve('packages/mcp-trace/trace-web'),
        '@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
        '@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'),
        '@cherrystudio/ai-core': resolve('packages/aiCore/src'),
        '@cherrystudio/extension-table-plus': resolve('packages/extension-table-plus/src')
      }
    },
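新增的别名让渲染进程代码可以直接引用 aiCore 包及其子路径。下面是一个示意性的引用草稿(导入的符号沿用本次变更中 packages/aiCore/README.md 的示例,具体导出以该包为准):

```typescript
// 示意:通过上面新增的别名引用 aiCore 包与内置插件
import { AiCore } from '@cherrystudio/ai-core'
import { webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'

// 创建带网络搜索插件的执行器(参数均为示例值)
const executor = AiCore.create('openai', { apiKey: 'your-api-key' }, [webSearchPlugin({ anthropic: { maxUses: 5 } })])

const result = await executor.streamText('gpt-4', {
  messages: [{ role: 'user', content: 'Hello!' }]
})
```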
27  package.json
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "CherryStudio",
|
||||
"version": "1.5.9",
|
||||
"version": "1.6.0-beta.7",
|
||||
"private": true,
|
||||
"description": "A powerful AI assistant for producer.",
|
||||
"main": "./out/main/index.js",
|
||||
@ -66,6 +66,7 @@
|
||||
"test:lint": "eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts",
|
||||
"test:scripts": "vitest scripts",
|
||||
"format": "prettier --write .",
|
||||
"format:check": "prettier --check .",
|
||||
"lint": "eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix && yarn typecheck && yarn check:i18n",
|
||||
"prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky",
|
||||
"migrations:generate": "drizzle-kit generate --config ./migrations/sqlite-drizzle.config.ts"
|
||||
@ -75,6 +76,7 @@
|
||||
"@libsql/win32-x64-msvc": "^0.4.7",
|
||||
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
|
||||
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
||||
"faiss-node": "^0.5.1",
|
||||
"graceful-fs": "^4.2.11",
|
||||
"jsdom": "26.1.0",
|
||||
"node-stream-zip": "^1.15.0",
|
||||
@ -89,12 +91,16 @@
|
||||
"@agentic/exa": "^7.3.3",
|
||||
"@agentic/searxng": "^7.3.3",
|
||||
"@agentic/tavily": "^7.3.3",
|
||||
"@ai-sdk/amazon-bedrock": "^3.0.0",
|
||||
"@ai-sdk/google-vertex": "^3.0.0",
|
||||
"@ai-sdk/mistral": "^2.0.0",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
||||
"@aws-sdk/client-bedrock": "^3.840.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
|
||||
"@aws-sdk/client-s3": "^3.840.0",
|
||||
"@cherrystudio/ai-core": "workspace:*",
|
||||
"@cherrystudio/embedjs": "^0.1.31",
|
||||
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
||||
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
||||
@ -124,12 +130,15 @@
|
||||
"@google/genai": "patch:@google/genai@npm%3A1.0.1#~/.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch",
|
||||
"@hello-pangea/dnd": "^18.0.1",
|
||||
"@kangfenmao/keyv-storage": "^0.1.0",
|
||||
"@langchain/community": "^0.3.36",
|
||||
"@langchain/community": "^0.3.50",
|
||||
"@langchain/core": "^0.3.68",
|
||||
"@langchain/ollama": "^0.2.1",
|
||||
"@langchain/openai": "^0.6.7",
|
||||
"@mistralai/mistralai": "^1.7.5",
|
||||
"@modelcontextprotocol/sdk": "^1.17.0",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
"@openrouter/ai-sdk-provider": "^1.1.2",
|
||||
"@opentelemetry/api": "^1.9.0",
|
||||
"@opentelemetry/core": "2.0.0",
|
||||
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
|
||||
@ -139,7 +148,7 @@
|
||||
"@playwright/test": "^1.52.0",
|
||||
"@reduxjs/toolkit": "^2.2.5",
|
||||
"@shikijs/markdown-it": "^3.12.0",
|
||||
"@swc/plugin-styled-components": "^7.1.5",
|
||||
"@swc/plugin-styled-components": "^8.0.4",
|
||||
"@tanstack/react-query": "^5.85.5",
|
||||
"@tanstack/react-virtual": "^3.13.12",
|
||||
"@testing-library/dom": "^10.4.0",
|
||||
@ -167,9 +176,11 @@
|
||||
"@types/cli-progress": "^3",
|
||||
"@types/fs-extra": "^11",
|
||||
"@types/he": "^1",
|
||||
"@types/html-to-text": "^9",
|
||||
"@types/lodash": "^4.17.5",
|
||||
"@types/markdown-it": "^14",
|
||||
"@types/md5": "^2.3.5",
|
||||
"@types/mime-types": "^3",
|
||||
"@types/node": "^22.17.1",
|
||||
"@types/pako": "^1.0.2",
|
||||
"@types/react": "^19.0.12",
|
||||
@ -190,12 +201,14 @@
|
||||
"@viz-js/lang-dot": "^1.0.5",
|
||||
"@viz-js/viz": "^3.14.0",
|
||||
"@xyflow/react": "^12.4.4",
|
||||
"ai": "^5.0.29",
|
||||
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
||||
"archiver": "^7.0.1",
|
||||
"async-mutex": "^0.5.0",
|
||||
"axios": "^1.7.3",
|
||||
"browser-image-compression": "^2.0.2",
|
||||
"chardet": "^2.1.0",
|
||||
"cheerio": "^1.1.2",
|
||||
"chokidar": "^4.0.3",
|
||||
"cli-progress": "^3.12.0",
|
||||
"code-inspector-plugin": "^0.20.14",
|
||||
@ -234,6 +247,7 @@
|
||||
"he": "^1.2.0",
|
||||
"html-tags": "^5.1.0",
|
||||
"html-to-image": "^1.11.13",
|
||||
"html-to-text": "^9.0.5",
|
||||
"htmlparser2": "^10.0.0",
|
||||
"husky": "^9.1.7",
|
||||
"i18next": "^23.11.5",
|
||||
@ -250,12 +264,14 @@
|
||||
"markdown-it": "^14.1.0",
|
||||
"mermaid": "^11.10.1",
|
||||
"mime": "^4.0.4",
|
||||
"mime-types": "^3.0.1",
|
||||
"motion": "^12.10.5",
|
||||
"notion-helper": "^1.3.22",
|
||||
"npx-scope-finder": "^1.2.0",
|
||||
"openai": "patch:openai@npm%3A5.12.2#~/.yarn/patches/openai-npm-5.12.2-30b075401c.patch",
|
||||
"p-queue": "^8.1.0",
|
||||
"pdf-lib": "^1.17.1",
|
||||
"pdf-parse": "^1.1.1",
|
||||
"playwright": "^1.52.0",
|
||||
"prettier": "^3.5.3",
|
||||
"prettier-plugin-sort-json": "^4.1.1",
|
||||
@ -268,6 +284,7 @@
|
||||
"react-infinite-scroll-component": "^6.1.0",
|
||||
"react-json-view": "^1.21.3",
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-player": "^3.3.1",
|
||||
"react-redux": "^9.1.2",
|
||||
"react-router": "6",
|
||||
"react-router-dom": "6",
|
||||
@ -309,7 +326,9 @@
|
||||
"winston-daily-rotate-file": "^5.0.0",
|
||||
"word-extractor": "^1.0.4",
|
||||
"y-protocols": "^1.0.6",
|
||||
"yaml": "^2.8.1",
|
||||
"yjs": "^13.6.27",
|
||||
"youtubei.js": "^15.0.1",
|
||||
"zipread": "^1.3.3",
|
||||
"zod": "^3.25.74"
|
||||
},
|
||||
@ -340,7 +359,7 @@
|
||||
"prettier --write",
|
||||
"eslint --fix"
|
||||
],
|
||||
"*.{json,md,yml,yaml,css,scss,html}": [
|
||||
"*.{json,yml,yaml,css,scss,html}": [
|
||||
"prettier --write"
|
||||
]
|
||||
}
|
||||
|
||||
514  packages/aiCore/AI_SDK_ARCHITECTURE.md (new file)
@@ -0,0 +1,514 @@
|
||||
# AI Core 基于 Vercel AI SDK 的技术架构
|
||||
|
||||
## 1. 架构设计理念
|
||||
|
||||
### 1.1 设计目标
|
||||
|
||||
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
|
||||
- **统一接口**:使用 Vercel AI SDK 统一不同 AI Provider 的接口差异
|
||||
- **动态导入**:通过动态导入实现按需加载,减少打包体积
|
||||
- **最小包装**:直接使用 AI SDK 的类型和接口,避免重复定义
|
||||
- **插件系统**:基于钩子的通用插件架构,支持请求全生命周期扩展
|
||||
- **类型安全**:利用 TypeScript 和 AI SDK 的类型系统确保类型安全
|
||||
- **轻量级**:专注核心功能,保持包的轻量和高效
|
||||
- **包级独立**:作为独立包管理,便于复用和维护
|
||||
- **Agent就绪**:为将来集成 OpenAI Agents SDK 预留扩展空间
|
||||
|
||||
### 1.2 核心优势
|
||||
|
||||
- **标准化**:AI SDK 提供统一的模型接口,减少适配工作
|
||||
- **简化设计**:函数式API,避免过度抽象
|
||||
- **更好的开发体验**:完整的 TypeScript 支持和丰富的生态系统
|
||||
- **性能优化**:AI SDK 内置优化和最佳实践
|
||||
- **模块化设计**:独立包结构,支持跨项目复用
|
||||
- **可扩展插件**:通用的流转换和参数处理插件系统
|
||||
- **面向未来**:为 OpenAI Agents SDK 集成做好准备
|
||||
|
||||
## 2. 整体架构图
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph "用户应用 (如 Cherry Studio)"
|
||||
UI["用户界面"]
|
||||
Components["应用组件"]
|
||||
end
|
||||
|
||||
subgraph "packages/aiCore (AI Core 包)"
|
||||
subgraph "Runtime Layer (运行时层)"
|
||||
RuntimeExecutor["RuntimeExecutor (运行时执行器)"]
|
||||
PluginEngine["PluginEngine (插件引擎)"]
|
||||
RuntimeAPI["Runtime API (便捷函数)"]
|
||||
end
|
||||
|
||||
subgraph "Models Layer (模型层)"
|
||||
ModelFactory["createModel() (模型工厂)"]
|
||||
ProviderCreator["ProviderCreator (提供商创建器)"]
|
||||
end
|
||||
|
||||
subgraph "Core Systems (核心系统)"
|
||||
subgraph "Plugins (插件)"
|
||||
PluginManager["PluginManager (插件管理)"]
|
||||
BuiltInPlugins["Built-in Plugins (内置插件)"]
|
||||
StreamTransforms["Stream Transforms (流转换)"]
|
||||
end
|
||||
|
||||
subgraph "Middleware (中间件)"
|
||||
MiddlewareWrapper["wrapModelWithMiddlewares() (中间件包装)"]
|
||||
end
|
||||
|
||||
subgraph "Providers (提供商)"
|
||||
Registry["Provider Registry (注册表)"]
|
||||
Factory["Provider Factory (工厂)"]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
subgraph "Vercel AI SDK"
|
||||
AICore["ai (核心库)"]
|
||||
OpenAI["@ai-sdk/openai"]
|
||||
Anthropic["@ai-sdk/anthropic"]
|
||||
Google["@ai-sdk/google"]
|
||||
XAI["@ai-sdk/xai"]
|
||||
Others["其他 19+ Providers"]
|
||||
end
|
||||
|
||||
subgraph "Future: OpenAI Agents SDK"
|
||||
AgentSDK["@openai/agents (未来集成)"]
|
||||
AgentExtensions["Agent Extensions (预留)"]
|
||||
end
|
||||
|
||||
UI --> RuntimeAPI
|
||||
Components --> RuntimeExecutor
|
||||
RuntimeAPI --> RuntimeExecutor
|
||||
RuntimeExecutor --> PluginEngine
|
||||
RuntimeExecutor --> ModelFactory
|
||||
PluginEngine --> PluginManager
|
||||
ModelFactory --> ProviderCreator
|
||||
ModelFactory --> MiddlewareWrapper
|
||||
ProviderCreator --> Registry
|
||||
Registry --> Factory
|
||||
Factory --> OpenAI
|
||||
Factory --> Anthropic
|
||||
Factory --> Google
|
||||
Factory --> XAI
|
||||
Factory --> Others
|
||||
|
||||
RuntimeExecutor --> AICore
|
||||
AICore --> streamText
|
||||
AICore --> generateText
|
||||
AICore --> streamObject
|
||||
AICore --> generateObject
|
||||
|
||||
PluginManager --> StreamTransforms
|
||||
PluginManager --> BuiltInPlugins
|
||||
|
||||
%% 未来集成路径
|
||||
RuntimeExecutor -.-> AgentSDK
|
||||
AgentSDK -.-> AgentExtensions
|
||||
```
|
||||
|
||||
## 3. 包结构设计
|
||||
|
||||
### 3.1 新架构文件结构
|
||||
|
||||
```
|
||||
packages/aiCore/
|
||||
├── src/
|
||||
│ ├── core/ # 核心层 - 内部实现
|
||||
│ │ ├── models/ # 模型层 - 模型创建和配置
|
||||
│ │ │ ├── factory.ts # 模型工厂函数 ✅
|
||||
│ │ │ ├── ModelCreator.ts # 模型创建器 ✅
|
||||
│ │ │ ├── ConfigManager.ts # 配置管理器 ✅
|
||||
│ │ │ ├── types.ts # 模型类型定义 ✅
|
||||
│ │ │ └── index.ts # 模型层导出 ✅
|
||||
│ │ ├── runtime/ # 运行时层 - 执行和用户API
|
||||
│ │ │ ├── executor.ts # 运行时执行器 ✅
|
||||
│ │ │ ├── pluginEngine.ts # 插件引擎 ✅
|
||||
│ │ │ ├── types.ts # 运行时类型定义 ✅
|
||||
│ │ │ └── index.ts # 运行时导出 ✅
|
||||
│ │ ├── middleware/ # 中间件系统
|
||||
│ │ │ ├── wrapper.ts # 模型包装器 ✅
|
||||
│ │ │ ├── manager.ts # 中间件管理器 ✅
|
||||
│ │ │ ├── types.ts # 中间件类型 ✅
|
||||
│ │ │ └── index.ts # 中间件导出 ✅
|
||||
│ │ ├── plugins/ # 插件系统
|
||||
│ │ │ ├── types.ts # 插件类型定义 ✅
|
||||
│ │ │ ├── manager.ts # 插件管理器 ✅
|
||||
│ │ │ ├── built-in/ # 内置插件 ✅
|
||||
│ │ │ │ ├── logging.ts # 日志插件 ✅
|
||||
│ │ │ │ ├── webSearchPlugin/ # 网络搜索插件 ✅
|
||||
│ │ │ │ ├── toolUsePlugin/ # 工具使用插件 ✅
|
||||
│ │ │ │ └── index.ts # 内置插件导出 ✅
|
||||
│ │ │ ├── README.md # 插件文档 ✅
|
||||
│ │ │ └── index.ts # 插件导出 ✅
|
||||
│ │ ├── providers/ # 提供商管理
|
||||
│ │ │ ├── registry.ts # 提供商注册表 ✅
|
||||
│ │ │ ├── factory.ts # 提供商工厂 ✅
|
||||
│ │ │ ├── creator.ts # 提供商创建器 ✅
|
||||
│ │ │ ├── types.ts # 提供商类型 ✅
|
||||
│ │ │ ├── utils.ts # 工具函数 ✅
|
||||
│ │ │ └── index.ts # 提供商导出 ✅
|
||||
│ │ ├── options/ # 配置选项
|
||||
│ │ │ ├── factory.ts # 选项工厂 ✅
|
||||
│ │ │ ├── types.ts # 选项类型 ✅
|
||||
│ │ │ ├── xai.ts # xAI 选项 ✅
|
||||
│ │ │ ├── openrouter.ts # OpenRouter 选项 ✅
|
||||
│ │ │ ├── examples.ts # 示例配置 ✅
|
||||
│ │ │ └── index.ts # 选项导出 ✅
|
||||
│ │ └── index.ts # 核心层导出 ✅
|
||||
│ ├── types.ts # 全局类型定义 ✅
|
||||
│ └── index.ts # 包主入口文件 ✅
|
||||
├── package.json # 包配置文件 ✅
|
||||
├── tsconfig.json # TypeScript 配置 ✅
|
||||
├── README.md # 包说明文档 ✅
|
||||
└── AI_SDK_ARCHITECTURE.md # 本文档 ✅
|
||||
```
|
||||
|
||||
## 4. 架构分层详解
|
||||
|
||||
### 4.1 Models Layer (模型层)
|
||||
|
||||
**职责**:统一的模型创建和配置管理
|
||||
|
||||
**核心文件**:
|
||||
|
||||
- `factory.ts`: 模型工厂函数 (`createModel`, `createModels`)
|
||||
- `ProviderCreator.ts`: 底层提供商创建和模型实例化
|
||||
- `types.ts`: 模型配置类型定义
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 函数式设计,避免不必要的类抽象
|
||||
- 统一的模型配置接口
|
||||
- 自动处理中间件应用
|
||||
- 支持批量模型创建
|
||||
|
||||
**核心API**:
|
||||
|
||||
```typescript
|
||||
// 模型配置接口
|
||||
export interface ModelConfig {
|
||||
providerId: ProviderId
|
||||
modelId: string
|
||||
options: ProviderSettingsMap[ProviderId]
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
|
||||
// 核心模型创建函数
|
||||
export async function createModel(config: ModelConfig): Promise<LanguageModel>
|
||||
export async function createModels(configs: ModelConfig[]): Promise<LanguageModel[]>
|
||||
```
|
||||
|
||||
### 4.2 Runtime Layer (运行时层)
|
||||
|
||||
**职责**:运行时执行器和用户面向的API接口
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `executor.ts`: 运行时执行器类
|
||||
- `plugin-engine.ts`: 插件引擎(原PluginEnabledAiClient)
|
||||
- `index.ts`: 便捷函数和工厂方法
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 提供三种使用方式:类实例、静态工厂、函数式调用
|
||||
- 自动集成模型创建和插件处理
|
||||
- 完整的类型安全支持
|
||||
- 为 OpenAI Agents SDK 预留扩展接口
|
||||
|
||||
**核心API**:
|
||||
|
||||
```typescript
|
||||
// 运行时执行器
|
||||
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
static create<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T>
|
||||
|
||||
async streamText(modelId: string, params: StreamTextParams): Promise<StreamTextResult>
|
||||
async generateText(modelId: string, params: GenerateTextParams): Promise<GenerateTextResult>
|
||||
async streamObject(modelId: string, params: StreamObjectParams): Promise<StreamObjectResult>
|
||||
async generateObject(modelId: string, params: GenerateObjectParams): Promise<GenerateObjectResult>
|
||||
}
|
||||
|
||||
// 便捷函数式API
|
||||
export async function streamText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<StreamTextResult>
|
||||
```
|
||||
|
||||
### 4.3 Plugin System (插件系统)
|
||||
|
||||
**职责**:可扩展的插件架构
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `PluginManager`: 插件生命周期管理
|
||||
- `built-in/`: 内置插件集合
|
||||
- 流转换收集和应用
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 借鉴 Rollup 的钩子分类设计
|
||||
- 支持流转换 (`experimental_transform`)
|
||||
- 内置常用插件(日志、计数等)
|
||||
- 完整的生命周期钩子
|
||||
|
||||
**插件接口**:
|
||||
|
||||
```typescript
|
||||
export interface AiPlugin {
|
||||
name: string
|
||||
enforce?: 'pre' | 'post'
|
||||
|
||||
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||
|
||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
|
||||
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
|
||||
|
||||
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
|
||||
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
|
||||
|
||||
// 【Stream】流处理
|
||||
transformStream?: () => TransformStream
|
||||
}
|
||||
```
|
||||
|
||||
### 4.4 Middleware System (中间件系统)

**职责**:AI SDK 原生中间件支持

**核心组件**:

- `ModelWrapper.ts`: 模型包装函数

**设计哲学**:

- 直接使用 AI SDK 的 `wrapLanguageModel`
- 与插件系统分离,职责明确
- 函数式设计,简化使用

```typescript
export function wrapModelWithMiddlewares(model: LanguageModel, middlewares: LanguageModelV1Middleware[]): LanguageModel
```
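下面是一个示意性的使用草稿(假设 `wrapModelWithMiddlewares` 可从包入口导入;中间件对象遵循 AI SDK 的 `LanguageModelV1Middleware` 接口,日志中间件仅为演示):

```typescript
import type { LanguageModelV1Middleware } from 'ai'
import { wrapModelWithMiddlewares } from '@cherrystudio/ai-core' // 导入路径为假设
import { createModel } from '@cherrystudio/ai-core/models'

// 一个极简的日志中间件:在每次调用前打印参数
const logParamsMiddleware: LanguageModelV1Middleware = {
  transformParams: async ({ params }) => {
    console.log('LLM call params:', JSON.stringify(params))
    return params
  }
}

const baseModel = await createModel({
  providerId: 'openai',
  modelId: 'gpt-4',
  options: { apiKey: 'your-api-key' }
})

// 包装后得到的仍是标准 LanguageModel,可直接交给 streamText/generateText 使用
const model = wrapModelWithMiddlewares(baseModel, [logParamsMiddleware])
```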
### 4.5 Provider System (提供商系统)
|
||||
|
||||
**职责**:AI Provider注册表和动态导入
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `registry.ts`: 19+ Provider配置和类型
|
||||
- `factory.ts`: Provider配置工厂
|
||||
|
||||
**支持的Providers**:
|
||||
|
||||
- OpenAI, Anthropic, Google, XAI
|
||||
- Azure OpenAI, Amazon Bedrock, Google Vertex
|
||||
- Groq, Together.ai, Fireworks, DeepSeek
|
||||
- 等19+ AI SDK官方支持的providers
|
||||
|
||||
## 5. 使用方式
|
||||
|
||||
### 5.1 函数式调用 (推荐 - 简单场景)
|
||||
|
||||
```typescript
|
||||
import { streamText, generateText } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 直接函数调用
|
||||
const stream = await streamText(
|
||||
'anthropic',
|
||||
{ apiKey: 'your-api-key' },
|
||||
'claude-3',
|
||||
{ messages: [{ role: 'user', content: 'Hello!' }] },
|
||||
[loggingPlugin]
|
||||
)
|
||||
```
|
||||
|
||||
### 5.2 执行器实例 (推荐 - 复杂场景)
|
||||
|
||||
```typescript
|
||||
import { createExecutor } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 创建可复用的执行器
|
||||
const executor = createExecutor('openai', { apiKey: 'your-api-key' }, [plugin1, plugin2])
|
||||
|
||||
// 多次使用
|
||||
const stream = await executor.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
|
||||
const result = await executor.generateText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'How are you?' }]
|
||||
})
|
||||
```
|
||||
|
||||
### 5.3 静态工厂方法
|
||||
|
||||
```typescript
|
||||
import { RuntimeExecutor } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 静态创建
|
||||
const executor = RuntimeExecutor.create('anthropic', { apiKey: 'your-api-key' })
|
||||
await executor.streamText('claude-3', { messages: [...] })
|
||||
```
|
||||
|
||||
### 5.4 直接模型创建 (高级用法)
|
||||
|
||||
```typescript
|
||||
import { createModel } from '@cherrystudio/ai-core/models'
|
||||
import { streamText } from 'ai'
|
||||
|
||||
// 直接创建模型使用
|
||||
const model = await createModel({
|
||||
providerId: 'openai',
|
||||
modelId: 'gpt-4',
|
||||
options: { apiKey: 'your-api-key' },
|
||||
middlewares: [middleware1, middleware2]
|
||||
})
|
||||
|
||||
// 直接使用 AI SDK
|
||||
const result = await streamText({ model, messages: [...] })
|
||||
```
|
||||
|
||||
## 6. 为 OpenAI Agents SDK 预留的设计
|
||||
|
||||
### 6.1 架构兼容性
|
||||
|
||||
当前架构完全兼容 OpenAI Agents SDK 的集成需求:
|
||||
|
||||
```typescript
|
||||
// 当前的模型创建
|
||||
const model = await createModel({
|
||||
providerId: 'anthropic',
|
||||
modelId: 'claude-3',
|
||||
options: { apiKey: 'xxx' }
|
||||
})
|
||||
|
||||
// 将来可以直接用于 OpenAI Agents SDK
|
||||
import { Agent, run } from '@openai/agents'
|
||||
|
||||
const agent = new Agent({
|
||||
model, // ✅ 直接兼容 LanguageModel 接口
|
||||
name: 'Assistant',
|
||||
instructions: '...',
|
||||
tools: [tool1, tool2]
|
||||
})
|
||||
|
||||
const result = await run(agent, 'user input')
|
||||
```
|
||||
|
||||
### 6.2 预留的扩展点
|
||||
|
||||
1. **runtime/agents/** 目录预留
|
||||
2. **AgentExecutor** 类预留
|
||||
3. **Agent工具转换插件** 预留
|
||||
4. **多Agent编排** 预留
|
||||
|
||||
### 6.3 未来架构扩展
|
||||
|
||||
```
|
||||
packages/aiCore/src/core/
|
||||
├── runtime/
|
||||
│ ├── agents/ # 🚀 未来添加
|
||||
│ │ ├── AgentExecutor.ts
|
||||
│ │ ├── WorkflowManager.ts
|
||||
│ │ └── ConversationManager.ts
|
||||
│ ├── executor.ts
|
||||
│ └── index.ts
|
||||
```
|
||||
|
||||
## 7. 架构优势
|
||||
|
||||
### 7.1 简化设计
|
||||
|
||||
- **移除过度抽象**:删除了orchestration层和creation层的复杂包装
|
||||
- **函数式优先**:models层使用函数而非类
|
||||
- **直接明了**:runtime层直接提供用户API
|
||||
|
||||
### 7.2 职责清晰
|
||||
|
||||
- **Models**: 专注模型创建和配置
|
||||
- **Runtime**: 专注执行和用户API
|
||||
- **Plugins**: 专注扩展功能
|
||||
- **Providers**: 专注AI Provider管理
|
||||
|
||||
### 7.3 类型安全
|
||||
|
||||
- 完整的 TypeScript 支持
|
||||
- AI SDK 类型的直接复用
|
||||
- 避免类型重复定义
|
||||
|
||||
### 7.4 灵活使用
|
||||
|
||||
- 三种使用模式满足不同需求
|
||||
- 从简单函数调用到复杂执行器
|
||||
- 支持直接AI SDK使用
|
||||
|
||||
### 7.5 面向未来
|
||||
|
||||
- 为 OpenAI Agents SDK 集成做好准备
|
||||
- 清晰的扩展点和架构边界
|
||||
- 模块化设计便于功能添加
|
||||
|
||||
## 8. 技术决策记录
|
||||
|
||||
### 8.1 为什么选择简化的两层架构?
|
||||
|
||||
- **职责分离**:models专注创建,runtime专注执行
|
||||
- **模块化**:每层都有清晰的边界和职责
|
||||
- **扩展性**:为Agent功能预留了清晰的扩展空间
|
||||
|
||||
### 8.2 为什么选择函数式设计?
|
||||
|
||||
- **简洁性**:避免不必要的类设计
|
||||
- **性能**:减少对象创建开销
|
||||
- **易用性**:函数调用更直观
|
||||
|
||||
### 8.3 为什么分离插件和中间件?
|
||||
|
||||
- **职责明确**: 插件处理应用特定需求
|
||||
- **原生支持**: 中间件使用AI SDK原生功能
|
||||
- **灵活性**: 两套系统可以独立演进
|
||||
|
||||
## 9. 总结
|
||||
|
||||
AI Core架构实现了:
|
||||
|
||||
### 9.1 核心特点
|
||||
|
||||
- ✅ **简化架构**: 2层核心架构,职责清晰
|
||||
- ✅ **函数式设计**: models层完全函数化
|
||||
- ✅ **类型安全**: 统一的类型定义和AI SDK类型复用
|
||||
- ✅ **插件扩展**: 强大的插件系统
|
||||
- ✅ **多种使用方式**: 满足不同复杂度需求
|
||||
- ✅ **Agent就绪**: 为OpenAI Agents SDK集成做好准备
|
||||
|
||||
### 9.2 核心价值
|
||||
|
||||
- **统一接口**: 一套API支持19+ AI providers
|
||||
- **灵活使用**: 函数式、实例式、静态工厂式
|
||||
- **强类型**: 完整的TypeScript支持
|
||||
- **可扩展**: 插件和中间件双重扩展能力
|
||||
- **高性能**: 最小化包装,直接使用AI SDK
|
||||
- **面向未来**: Agent SDK集成架构就绪
|
||||
|
||||
### 9.3 未来发展
|
||||
|
||||
这个架构提供了:
|
||||
|
||||
- **优秀的开发体验**: 简洁的API和清晰的使用模式
|
||||
- **强大的扩展能力**: 为Agent功能预留了完整的架构空间
|
||||
- **良好的维护性**: 职责分离明确,代码易于维护
|
||||
- **广泛的适用性**: 既适合简单调用也适合复杂应用
|
||||
433  packages/aiCore/README.md (new file)
@@ -0,0 +1,433 @@
|
||||
# @cherrystudio/ai-core
|
||||
|
||||
Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口包,为 AI 应用提供强大的抽象层和插件化架构。
|
||||
|
||||
## ✨ 核心亮点
|
||||
|
||||
### 🏗️ 优雅的架构设计
|
||||
|
||||
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
|
||||
- **函数式优先**:避免过度抽象,提供简洁直观的 API
|
||||
- **类型安全**:完整的 TypeScript 支持,直接复用 AI SDK 类型系统
|
||||
- **最小包装**:直接使用 AI SDK 的接口,避免重复定义和性能损耗
|
||||
|
||||
### 🔌 强大的插件系统
|
||||
|
||||
- **生命周期钩子**:支持请求全生命周期的扩展点
|
||||
- **流转换支持**:基于 AI SDK 的 `experimental_transform` 实现流处理
|
||||
- **插件分类**:First、Sequential、Parallel 三种钩子类型,满足不同场景
|
||||
- **内置插件**:webSearch、logging、toolUse 等开箱即用的功能
|
||||
|
||||
### 🌐 统一多 Provider 接口
|
||||
|
||||
- **扩展注册**:支持自定义 Provider 注册,无限扩展能力
|
||||
- **配置统一**:统一的配置接口,简化多 Provider 管理
|
||||
|
||||
### 🚀 多种使用方式
|
||||
|
||||
- **函数式调用**:适合简单场景的直接函数调用
|
||||
- **执行器实例**:适合复杂场景的可复用执行器
|
||||
- **静态工厂**:便捷的静态创建方法
|
||||
- **原生兼容**:完全兼容 AI SDK 原生 Provider Registry
|
||||
|
||||
### 🔮 面向未来
|
||||
|
||||
- **Agent 就绪**:为 OpenAI Agents SDK 集成预留架构空间
|
||||
- **模块化设计**:独立包结构,支持跨项目复用
|
||||
- **渐进式迁移**:可以逐步从现有 AI SDK 代码迁移
|
||||
|
||||
## 特性
|
||||
|
||||
- 🚀 统一的 AI Provider 接口
|
||||
- 🔄 动态导入支持
|
||||
- 🛠️ TypeScript 支持
|
||||
- 📦 强大的插件系统
|
||||
- 🌍 内置webSearch(Openai,Google,Anthropic,xAI)
|
||||
- 🎯 多种使用模式(函数式/实例式/静态工厂)
|
||||
- 🔌 可扩展的 Provider 注册系统
|
||||
- 🧩 完整的中间件支持
|
||||
- 📊 插件统计和调试功能
|
||||
|
||||
## 支持的 Providers
|
||||
|
||||
基于 [AI SDK 官方支持的 providers](https://ai-sdk.dev/providers/ai-sdk-providers):
|
||||
|
||||
**核心 Providers(内置支持):**
|
||||
|
||||
- OpenAI
|
||||
- Anthropic
|
||||
- Google Generative AI
|
||||
- OpenAI-Compatible
|
||||
- xAI (Grok)
|
||||
- Azure OpenAI
|
||||
- DeepSeek
|
||||
|
||||
**扩展 Providers(通过注册API支持):**
|
||||
|
||||
- Google Vertex AI
|
||||
- ...
|
||||
- 自定义 Provider
|
||||
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
npm install @cherrystudio/ai-core ai
|
||||
```
|
||||
|
||||
### React Native
|
||||
|
||||
如果你在 React Native 项目中使用此包,需要在 `metro.config.js` 中添加以下配置:
|
||||
|
||||
```javascript
|
||||
// metro.config.js
|
||||
const { getDefaultConfig } = require('expo/metro-config')
|
||||
|
||||
const config = getDefaultConfig(__dirname)
|
||||
|
||||
// 添加对 @cherrystudio/ai-core 的支持
|
||||
config.resolver.resolverMainFields = ['react-native', 'browser', 'main']
|
||||
config.resolver.platforms = ['ios', 'android', 'native', 'web']
|
||||
|
||||
module.exports = config
|
||||
```
|
||||
|
||||
还需要安装你要使用的 AI SDK provider:
|
||||
|
||||
```bash
|
||||
npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google
|
||||
```
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 基础用法
|
||||
|
||||
```typescript
|
||||
import { AiCore } from '@cherrystudio/ai-core'
|
||||
|
||||
// 创建 OpenAI executor
|
||||
const executor = AiCore.create('openai', {
|
||||
apiKey: 'your-api-key'
|
||||
})
|
||||
|
||||
// 流式生成
|
||||
const result = await executor.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
|
||||
// 非流式生成
|
||||
const response = await executor.generateText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
```
|
||||
|
||||
### 便捷函数
|
||||
|
||||
```typescript
|
||||
import { createOpenAIExecutor } from '@cherrystudio/ai-core'
|
||||
|
||||
// 快速创建 OpenAI executor
|
||||
const executor = createOpenAIExecutor({
|
||||
apiKey: 'your-api-key'
|
||||
})
|
||||
|
||||
// 使用 executor
|
||||
const result = await executor.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
```
|
||||
|
||||
### 多 Provider 支持
|
||||
|
||||
```typescript
|
||||
import { AiCore } from '@cherrystudio/ai-core'
|
||||
|
||||
// 支持多种 AI providers
|
||||
const openaiExecutor = AiCore.create('openai', { apiKey: 'openai-key' })
|
||||
const anthropicExecutor = AiCore.create('anthropic', { apiKey: 'anthropic-key' })
|
||||
const googleExecutor = AiCore.create('google', { apiKey: 'google-key' })
|
||||
const xaiExecutor = AiCore.create('xai', { apiKey: 'xai-key' })
|
||||
```
|
||||
|
||||
### 扩展 Provider 注册
|
||||
|
||||
对于非内置的 providers,可以通过注册 API 扩展支持:
|
||||
|
||||
```typescript
|
||||
import { registerProvider, AiCore } from '@cherrystudio/ai-core'
|
||||
|
||||
// 方式一:导入并注册第三方 provider
|
||||
import { createGroq } from '@ai-sdk/groq'
|
||||
|
||||
registerProvider({
|
||||
id: 'groq',
|
||||
name: 'Groq',
|
||||
creator: createGroq,
|
||||
supportsImageGeneration: false
|
||||
})
|
||||
|
||||
// 现在可以使用 Groq
|
||||
const groqExecutor = AiCore.create('groq', { apiKey: 'groq-key' })
|
||||
|
||||
// 方式二:动态导入方式注册
|
||||
registerProvider({
|
||||
id: 'mistral',
|
||||
name: 'Mistral AI',
|
||||
import: () => import('@ai-sdk/mistral'),
|
||||
creatorFunctionName: 'createMistral'
|
||||
})
|
||||
|
||||
const mistralExecutor = AiCore.create('mistral', { apiKey: 'mistral-key' })
|
||||
```
|
||||
|
||||
## 🔌 插件系统
|
||||
|
||||
AI Core 提供了强大的插件系统,支持请求全生命周期的扩展。
|
||||
|
||||
### 内置插件
|
||||
|
||||
#### webSearchPlugin - 网络搜索插件
|
||||
|
||||
为不同 AI Provider 提供统一的网络搜索能力:
|
||||
|
||||
```typescript
|
||||
import { webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
|
||||
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
|
||||
webSearchPlugin({
|
||||
openai: {
|
||||
/* OpenAI 搜索配置 */
|
||||
},
|
||||
anthropic: { maxUses: 5 },
|
||||
google: {
|
||||
/* Google 搜索配置 */
|
||||
},
|
||||
xai: {
|
||||
mode: 'on',
|
||||
returnCitations: true,
|
||||
maxSearchResults: 5,
|
||||
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
|
||||
}
|
||||
})
|
||||
])
|
||||
```
|
||||
|
||||
#### loggingPlugin - 日志插件
|
||||
|
||||
提供详细的请求日志记录:
|
||||
|
||||
```typescript
|
||||
import { createLoggingPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
|
||||
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
|
||||
createLoggingPlugin({
|
||||
logLevel: 'info',
|
||||
includeParams: true,
|
||||
includeResult: false
|
||||
})
|
||||
])
|
||||
```
|
||||
|
||||
#### promptToolUsePlugin - 提示工具使用插件
|
||||
|
||||
为不支持原生 Function Call 的模型提供 prompt 方式的工具调用:
|
||||
|
||||
```typescript
|
||||
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
|
||||
// 对于不支持 function call 的模型
|
||||
const executor = AiCore.create(
|
||||
'providerId',
|
||||
{
|
||||
apiKey: 'your-key',
|
||||
baseURL: 'https://your-model-endpoint'
|
||||
},
|
||||
[
|
||||
createPromptToolUsePlugin({
|
||||
enabled: true,
|
||||
// 可选:自定义系统提示符构建
|
||||
buildSystemPrompt: (userPrompt, tools) => {
|
||||
return `${userPrompt}\n\nAvailable tools: ${Object.keys(tools).join(', ')}`
|
||||
}
|
||||
})
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### 自定义插件
|
||||
|
||||
创建自定义插件非常简单:
|
||||
|
||||
```typescript
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
|
||||
const customPlugin = definePlugin({
|
||||
name: 'custom-plugin',
|
||||
enforce: 'pre', // 'pre' | 'post' | undefined
|
||||
|
||||
// 在请求开始时记录日志
|
||||
onRequestStart: async (context) => {
|
||||
console.log(`Starting request for model: ${context.modelId}`)
|
||||
},
|
||||
|
||||
// 转换请求参数
|
||||
transformParams: async (params, context) => {
|
||||
// 添加自定义系统消息
|
||||
if (params.messages) {
|
||||
params.messages.unshift({
|
||||
role: 'system',
|
||||
content: 'You are a helpful assistant.'
|
||||
})
|
||||
}
|
||||
return params
|
||||
},
|
||||
|
||||
// 处理响应结果
|
||||
transformResult: async (result, context) => {
|
||||
// 添加元数据
|
||||
if (result.text) {
|
||||
result.metadata = {
|
||||
processedAt: new Date().toISOString(),
|
||||
modelId: context.modelId
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
})
|
||||
|
||||
// 使用自定义插件
|
||||
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [customPlugin])
|
||||
```
|
||||
|
||||
### 使用 AI SDK 原生 Provider 注册表
|
||||
|
||||
> https://ai-sdk.dev/docs/reference/ai-sdk-core/provider-registry
|
||||
|
||||
除了使用内建的 provider 管理,你还可以使用 AI SDK 原生的 `createProviderRegistry` 来构建自己的 provider 注册表。
|
||||
|
||||
#### 基本用法示例
|
||||
|
||||
```typescript
|
||||
import { createClient } from '@cherrystudio/ai-core'
|
||||
import { createProviderRegistry } from 'ai'
|
||||
import { createOpenAI } from '@ai-sdk/openai'
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
|
||||
// 1. 创建 AI SDK 原生注册表
|
||||
export const registry = createProviderRegistry({
|
||||
// register provider with prefix and default setup:
|
||||
anthropic,
|
||||
|
||||
// register provider with prefix and custom setup:
|
||||
openai: createOpenAI({
|
||||
apiKey: process.env.OPENAI_API_KEY
|
||||
})
|
||||
})
|
||||
|
||||
// 2. 创建client,'openai'可以传空或者传providerId(内建的provider)
|
||||
const client = PluginEnabledAiClient.create('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY
|
||||
})
|
||||
|
||||
// 3. 方式1:使用内建逻辑(传统方式)
|
||||
const result1 = await client.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello with built-in logic!' }]
|
||||
})
|
||||
|
||||
// 4. 方式2:使用自定义注册表(灵活方式)
|
||||
const result2 = await client.streamText({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
messages: [{ role: 'user', content: 'Hello with custom registry!' }]
|
||||
})
|
||||
|
||||
// 5. 支持的重载方法
|
||||
await client.generateObject({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
schema: z.object({ name: z.string() }),
|
||||
messages: [{ role: 'user', content: 'Generate a user' }]
|
||||
})
|
||||
|
||||
await client.streamObject({
|
||||
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
|
||||
schema: z.object({ items: z.array(z.string()) }),
|
||||
messages: [{ role: 'user', content: 'Generate a list' }]
|
||||
})
|
||||
```
|
||||
|
||||
#### 与插件系统配合使用
|
||||
|
||||
更强大的是,你还可以将自定义注册表与 Cherry Studio 的插件系统结合使用:
|
||||
|
||||
```typescript
|
||||
import { PluginEnabledAiClient } from '@cherrystudio/ai-core'
|
||||
import { createProviderRegistry } from 'ai'
|
||||
import { createOpenAI } from '@ai-sdk/openai'
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
|
||||
// 1. 创建带插件的客户端
|
||||
const client = PluginEnabledAiClient.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: process.env.OPENAI_API_KEY
|
||||
},
|
||||
[LoggingPlugin, RetryPlugin]
|
||||
)
|
||||
|
||||
// 2. 创建自定义注册表
|
||||
const registry = createProviderRegistry({
|
||||
openai: createOpenAI({ apiKey: process.env.OPENAI_API_KEY }),
|
||||
anthropic: anthropic({ apiKey: process.env.ANTHROPIC_API_KEY })
|
||||
})
|
||||
|
||||
// 3. 方式1:使用内建逻辑 + 完整插件系统
|
||||
await client.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello with plugins!' }]
|
||||
})
|
||||
|
||||
// 4. 方式2:使用自定义注册表 + 有限插件支持
|
||||
await client.streamText({
|
||||
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
|
||||
messages: [{ role: 'user', content: 'Hello from Claude!' }]
|
||||
})
|
||||
|
||||
// 5. 支持的方法
|
||||
await client.generateObject({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
schema: z.object({ name: z.string() }),
|
||||
messages: [{ role: 'user', content: 'Generate a user' }]
|
||||
})
|
||||
|
||||
await client.streamObject({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
schema: z.object({ items: z.array(z.string()) }),
|
||||
messages: [{ role: 'user', content: 'Generate a list' }]
|
||||
})
|
||||
```
|
||||
|
||||
#### 混合使用的优势
|
||||
|
||||
- **灵活性**:可以根据需要选择使用内建逻辑或自定义注册表
|
||||
- **兼容性**:完全兼容 AI SDK 的 `createProviderRegistry` API
|
||||
- **渐进式**:可以逐步迁移现有代码,无需一次性重构
|
||||
- **插件支持**:自定义注册表仍可享受插件系统的部分功能
|
||||
- **最佳实践**:结合两种方式的优点,既有动态加载的性能优势,又有统一注册表的便利性
|
||||
|
||||
## 📚 相关资源
|
||||
|
||||
- [Vercel AI SDK 文档](https://ai-sdk.dev/)
|
||||
- [Cherry Studio 项目](https://github.com/CherryHQ/cherry-studio)
|
||||
- [AI SDK Providers](https://ai-sdk.dev/providers/ai-sdk-providers)
|
||||
|
||||
## 未来版本
|
||||
|
||||
- 🔮 多 Agent 编排
|
||||
- 🔮 可视化插件配置
|
||||
- 🔮 实时监控和分析
|
||||
- 🔮 云端插件同步
|
||||
|
||||
## 📄 License
|
||||
|
||||
MIT License - 详见 [LICENSE](https://github.com/CherryHQ/cherry-studio/blob/main/LICENSE) 文件
|
||||
|
||||
---
|
||||
|
||||
**Cherry Studio AI Core** - 让 AI 开发更简单、更强大、更灵活 🚀
|
||||
103  packages/aiCore/examples/hub-provider-usage.ts (new file)
@@ -0,0 +1,103 @@
|
||||
/**
|
||||
* Hub Provider 使用示例
|
||||
*
|
||||
* 演示如何使用简化后的Hub Provider功能来路由到多个底层provider
|
||||
*/
|
||||
|
||||
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
|
||||
|
||||
async function demonstrateHubProvider() {
|
||||
try {
|
||||
// 1. 初始化底层providers
|
||||
console.log('📦 初始化底层providers...')
|
||||
|
||||
initializeProvider('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
|
||||
})
|
||||
|
||||
initializeProvider('anthropic', {
|
||||
apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
|
||||
})
|
||||
|
||||
// 2. 创建Hub Provider(自动包含所有已初始化的providers)
|
||||
console.log('🌐 创建Hub Provider...')
|
||||
|
||||
const aihubmixProvider = createHubProvider({
|
||||
hubId: 'aihubmix',
|
||||
debug: true
|
||||
})
|
||||
|
||||
// 3. 注册Hub Provider
|
||||
providerRegistry.registerProvider('aihubmix', aihubmixProvider)
|
||||
|
||||
console.log('✅ Hub Provider "aihubmix" 注册成功')
|
||||
|
||||
// 4. 使用Hub Provider访问不同的模型
|
||||
console.log('\n🚀 使用Hub模型...')
|
||||
|
||||
// 通过Hub路由到OpenAI
|
||||
const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||
console.log('✓ OpenAI模型已获取:', openaiModel.modelId)
|
||||
|
||||
// 通过Hub路由到Anthropic
|
||||
const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||
console.log('✓ Anthropic模型已获取:', anthropicModel.modelId)
|
||||
|
||||
// 5. 演示错误处理
|
||||
console.log('\n❌ 演示错误处理...')
|
||||
|
||||
try {
|
||||
// 尝试访问未初始化的provider
|
||||
providerRegistry.languageModel('aihubmix:google:gemini-pro')
|
||||
} catch (error) {
|
||||
console.log('预期错误:', error.message)
|
||||
}
|
||||
|
||||
try {
|
||||
// 尝试使用错误的模型ID格式
|
||||
providerRegistry.languageModel('aihubmix:invalid-format')
|
||||
} catch (error) {
|
||||
console.log('预期错误:', error.message)
|
||||
}
|
||||
|
||||
// 6. 多个Hub Provider示例
|
||||
console.log('\n🔄 创建多个Hub Provider...')
|
||||
|
||||
const localHubProvider = createHubProvider({
|
||||
hubId: 'local-ai'
|
||||
})
|
||||
|
||||
providerRegistry.registerProvider('local-ai', localHubProvider)
|
||||
console.log('✅ Hub Provider "local-ai" 注册成功')
|
||||
|
||||
console.log('\n🎉 Hub Provider演示完成!')
|
||||
} catch (error) {
|
||||
console.error('💥 演示过程中发生错误:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 演示简化的使用方式
|
||||
function simplifiedUsageExample() {
|
||||
console.log('\n📝 简化使用示例:')
|
||||
console.log(`
|
||||
// 1. 初始化providers
|
||||
initializeProvider('openai', { apiKey: 'sk-xxx' })
|
||||
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
|
||||
|
||||
// 2. 创建并注册Hub Provider
|
||||
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
|
||||
providerRegistry.registerProvider('aihubmix', hubProvider)
|
||||
|
||||
// 3. 直接使用
|
||||
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||
`)
|
||||
}
|
||||
|
||||
// 运行演示
|
||||
if (require.main === module) {
|
||||
demonstrateHubProvider()
|
||||
simplifiedUsageExample()
|
||||
}
|
||||
|
||||
export { demonstrateHubProvider, simplifiedUsageExample }
|
||||
167
packages/aiCore/examples/image-generation.ts
Normal file
167
packages/aiCore/examples/image-generation.ts
Normal file
@ -0,0 +1,167 @@
|
||||
/**
|
||||
* Image Generation Example
|
||||
* 演示如何使用 aiCore 的文生图功能
|
||||
*/
|
||||
|
||||
import { createExecutor, generateImage } from '../src/index'
|
||||
|
||||
async function main() {
|
||||
// 方式1: 使用执行器实例
|
||||
console.log('📸 创建 OpenAI 图像生成执行器...')
|
||||
const executor = createExecutor('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY!
|
||||
})
|
||||
|
||||
try {
|
||||
console.log('🎨 使用执行器生成图像...')
|
||||
const result1 = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A futuristic cityscape at sunset with flying cars',
|
||||
size: '1024x1024',
|
||||
n: 1
|
||||
})
|
||||
|
||||
console.log('✅ 图像生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result1.images.length,
|
||||
mediaType: result1.image.mediaType,
|
||||
hasBase64: !!result1.image.base64,
|
||||
providerMetadata: result1.providerMetadata
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 执行器生成失败:', error)
|
||||
}
|
||||
|
||||
// 方式2: 使用直接调用 API
|
||||
try {
|
||||
console.log('🎨 使用直接 API 生成图像...')
|
||||
const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
|
||||
prompt: 'A magical forest with glowing mushrooms and fairy lights',
|
||||
aspectRatio: '16:9',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: 'hd',
|
||||
style: 'vivid'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
console.log('✅ 直接 API 生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result2.images.length,
|
||||
mediaType: result2.image.mediaType,
|
||||
hasBase64: !!result2.image.base64
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 直接 API 生成失败:', error)
|
||||
}
|
||||
|
||||
// 方式3: 支持其他提供商 (Google Imagen)
|
||||
if (process.env.GOOGLE_API_KEY) {
|
||||
try {
|
||||
console.log('🎨 使用 Google Imagen 生成图像...')
|
||||
const googleExecutor = createExecutor('google', {
|
||||
apiKey: process.env.GOOGLE_API_KEY!
|
||||
})
|
||||
|
||||
const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
||||
prompt: 'A serene mountain lake at dawn with mist rising from the water',
|
||||
aspectRatio: '1:1'
|
||||
})
|
||||
|
||||
console.log('✅ Google Imagen 生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result3.images.length,
|
||||
mediaType: result3.image.mediaType,
|
||||
hasBase64: !!result3.image.base64
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ Google Imagen 生成失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 方式4: 支持插件系统
|
||||
const pluginExample = async () => {
|
||||
console.log('🔌 演示插件系统...')
|
||||
|
||||
// 创建一个示例插件,用于修改提示词
|
||||
const promptEnhancerPlugin = {
|
||||
name: 'prompt-enhancer',
|
||||
transformParams: async (params: any) => {
|
||||
console.log('🔧 插件: 增强提示词...')
|
||||
return {
|
||||
...params,
|
||||
prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
|
||||
}
|
||||
},
|
||||
transformResult: async (result: any) => {
|
||||
console.log('🔧 插件: 处理结果...')
|
||||
return {
|
||||
...result,
|
||||
enhanced: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const executorWithPlugin = createExecutor(
|
||||
'openai',
|
||||
{
|
||||
apiKey: process.env.OPENAI_API_KEY!
|
||||
},
|
||||
[promptEnhancerPlugin]
|
||||
)
|
||||
|
||||
try {
|
||||
const result4 = await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A cute robot playing in a garden'
|
||||
})
|
||||
|
||||
console.log('✅ 插件系统生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result4.images.length,
|
||||
enhanced: (result4 as any).enhanced,
|
||||
mediaType: result4.image.mediaType
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 插件系统生成失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
await pluginExample()
|
||||
}
|
||||
|
||||
// 错误处理演示
|
||||
async function errorHandlingExample() {
|
||||
console.log('⚠️ 演示错误处理...')
|
||||
|
||||
try {
|
||||
const executor = createExecutor('openai', {
|
||||
apiKey: 'invalid-key'
|
||||
})
|
||||
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'Test image'
|
||||
})
|
||||
} catch (error: any) {
|
||||
console.log('✅ 成功捕获错误:', error.constructor.name)
|
||||
console.log('📋 错误信息:', error.message)
|
||||
console.log('🏷️ 提供商ID:', error.providerId)
|
||||
console.log('🏷️ 模型ID:', error.modelId)
|
||||
}
|
||||
}
|
||||
|
||||
// 运行示例
|
||||
if (require.main === module) {
|
||||
main()
|
||||
.then(() => {
|
||||
console.log('🎉 所有示例完成!')
|
||||
return errorHandlingExample()
|
||||
})
|
||||
.then(() => {
|
||||
console.log('🎯 示例程序结束')
|
||||
process.exit(0)
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('💥 程序执行出错:', error)
|
||||
process.exit(1)
|
||||
})
|
||||
}
|
||||
85
packages/aiCore/package.json
Normal file
85
packages/aiCore/package.json
Normal file
@ -0,0 +1,85 @@
|
||||
{
|
||||
"name": "@cherrystudio/ai-core",
|
||||
"version": "1.0.0-alpha.11",
|
||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.mjs",
|
||||
"types": "dist/index.d.ts",
|
||||
"react-native": "dist/index.js",
|
||||
"scripts": {
|
||||
"build": "tsdown",
|
||||
"dev": "tsc -w",
|
||||
"clean": "rm -rf dist",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest"
|
||||
},
|
||||
"keywords": [
|
||||
"ai",
|
||||
"sdk",
|
||||
"openai",
|
||||
"anthropic",
|
||||
"google",
|
||||
"cherry-studio",
|
||||
"vercel-ai-sdk"
|
||||
],
|
||||
"author": "Cherry Studio",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/CherryHQ/cherry-studio.git"
|
||||
},
|
||||
"bugs": {
|
||||
"url": "https://github.com/CherryHQ/cherry-studio/issues"
|
||||
},
|
||||
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
|
||||
"peerDependencies": {
|
||||
"ai": "^5.0.26"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ai-sdk/anthropic": "^2.0.5",
|
||||
"@ai-sdk/azure": "^2.0.16",
|
||||
"@ai-sdk/deepseek": "^1.0.9",
|
||||
"@ai-sdk/google": "^2.0.7",
|
||||
"@ai-sdk/openai": "^2.0.19",
|
||||
"@ai-sdk/openai-compatible": "^1.0.9",
|
||||
"@ai-sdk/provider": "^2.0.0",
|
||||
"@ai-sdk/provider-utils": "^3.0.4",
|
||||
"@ai-sdk/xai": "^2.0.9",
|
||||
"zod": "^3.25.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"tsdown": "^0.12.9",
|
||||
"typescript": "^5.0.0",
|
||||
"vitest": "^3.2.4"
|
||||
},
|
||||
"sideEffects": false,
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
},
|
||||
"files": [
|
||||
"dist"
|
||||
],
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"react-native": "./dist/index.js",
|
||||
"import": "./dist/index.mjs",
|
||||
"require": "./dist/index.js",
|
||||
"default": "./dist/index.js"
|
||||
},
|
||||
"./built-in/plugins": {
|
||||
"types": "./dist/built-in/plugins/index.d.ts",
|
||||
"react-native": "./dist/built-in/plugins/index.js",
|
||||
"import": "./dist/built-in/plugins/index.mjs",
|
||||
"require": "./dist/built-in/plugins/index.js",
|
||||
"default": "./dist/built-in/plugins/index.js"
|
||||
},
|
||||
"./provider": {
|
||||
"types": "./dist/provider/index.d.ts",
|
||||
"react-native": "./dist/provider/index.js",
|
||||
"import": "./dist/provider/index.mjs",
|
||||
"require": "./dist/provider/index.js",
|
||||
"default": "./dist/provider/index.js"
|
||||
}
|
||||
}
|
||||
}
|
||||
2
packages/aiCore/setupVitest.ts
Normal file
2
packages/aiCore/setupVitest.ts
Normal file
@ -0,0 +1,2 @@
|
||||
// 模拟 Vite SSR helper,避免 Node 环境找不到时报错
|
||||
;(globalThis as any).__vite_ssr_exportName__ = (name: string, value: any) => value
|
||||
3
packages/aiCore/src/core/README.MD
Normal file
3
packages/aiCore/src/core/README.MD
Normal file
@ -0,0 +1,3 @@
|
||||
# @cherrystudio/ai-core

Core
|
||||
17
packages/aiCore/src/core/index.ts
Normal file
17
packages/aiCore/src/core/index.ts
Normal file
@ -0,0 +1,17 @@
|
||||
/**
|
||||
* Core 模块导出
|
||||
* 内部核心功能,供其他模块使用,不直接面向最终调用者
|
||||
*/
|
||||
|
||||
// 中间件系统
|
||||
export type { NamedMiddleware } from './middleware'
|
||||
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
|
||||
|
||||
// 创建管理
|
||||
export { globalModelResolver, ModelResolver } from './models'
|
||||
export type { ModelConfig as ModelConfigType } from './models/types'
|
||||
|
||||
// 执行管理
|
||||
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
|
||||
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
|
||||
export type { RuntimeConfig } from './runtime/types'
|
||||
8
packages/aiCore/src/core/middleware/index.ts
Normal file
8
packages/aiCore/src/core/middleware/index.ts
Normal file
@ -0,0 +1,8 @@
|
||||
/**
|
||||
* Middleware 模块导出
|
||||
* 提供通用的中间件管理能力
|
||||
*/
|
||||
|
||||
export { createMiddlewares } from './manager'
|
||||
export type { NamedMiddleware } from './types'
|
||||
export { wrapModelWithMiddlewares } from './wrapper'
|
||||
16
packages/aiCore/src/core/middleware/manager.ts
Normal file
16
packages/aiCore/src/core/middleware/manager.ts
Normal file
@ -0,0 +1,16 @@
|
||||
/**
|
||||
* 中间件管理器
|
||||
* 专注于 AI SDK 中间件的管理,与插件系统分离
|
||||
*/
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
/**
|
||||
* 创建中间件列表
|
||||
* 合并用户提供的中间件
|
||||
*/
|
||||
export function createMiddlewares(userMiddlewares: LanguageModelV2Middleware[] = []): LanguageModelV2Middleware[] {
|
||||
// 未来可以在这里添加默认的中间件
|
||||
const defaultMiddlewares: LanguageModelV2Middleware[] = []
|
||||
|
||||
return [...defaultMiddlewares, ...userMiddlewares]
|
||||
}
|
||||
12
packages/aiCore/src/core/middleware/types.ts
Normal file
12
packages/aiCore/src/core/middleware/types.ts
Normal file
@ -0,0 +1,12 @@
|
||||
/**
|
||||
* 中间件系统类型定义
|
||||
*/
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
/**
|
||||
* 具名中间件接口
|
||||
*/
|
||||
export interface NamedMiddleware {
|
||||
name: string
|
||||
middleware: LanguageModelV2Middleware
|
||||
}
|
||||
23
packages/aiCore/src/core/middleware/wrapper.ts
Normal file
23
packages/aiCore/src/core/middleware/wrapper.ts
Normal file
@ -0,0 +1,23 @@
|
||||
/**
|
||||
* 模型包装工具函数
|
||||
* 用于将中间件应用到LanguageModel上
|
||||
*/
|
||||
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import { wrapLanguageModel } from 'ai'
|
||||
|
||||
/**
|
||||
* 使用中间件包装模型
|
||||
*/
|
||||
export function wrapModelWithMiddlewares(
|
||||
model: LanguageModelV2,
|
||||
middlewares: LanguageModelV2Middleware[]
|
||||
): LanguageModelV2 {
|
||||
if (middlewares.length === 0) {
|
||||
return model
|
||||
}
|
||||
|
||||
return wrapLanguageModel({
|
||||
model,
|
||||
middleware: middlewares
|
||||
})
|
||||
}
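
To make the intended flow concrete, here is a minimal sketch of how `createMiddlewares` and `wrapModelWithMiddlewares` might be combined; the logging middleware and the `withLogging` helper are illustrative (not part of this package) and assume the AI SDK's `wrapGenerate` middleware hook:

```typescript
import type { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'

import { createMiddlewares } from './manager'
import { wrapModelWithMiddlewares } from './wrapper'

// Hypothetical middleware that logs each generate call before delegating.
const loggingMiddleware: LanguageModelV2Middleware = {
  wrapGenerate: async ({ doGenerate, params }) => {
    console.log('calling model with', params.prompt)
    return doGenerate()
  }
}

// Merge defaults with user-provided middlewares, then wrap the model.
export function withLogging(model: LanguageModelV2): LanguageModelV2 {
  return wrapModelWithMiddlewares(model, createMiddlewares([loggingMiddleware]))
}
```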
|
||||
125
packages/aiCore/src/core/models/ModelResolver.ts
Normal file
125
packages/aiCore/src/core/models/ModelResolver.ts
Normal file
@ -0,0 +1,125 @@
|
||||
/**
|
||||
* 模型解析器 - models模块的核心
|
||||
* 负责将modelId解析为AI SDK的LanguageModel实例
|
||||
* 支持传统格式和命名空间格式
|
||||
* 集成了来自 ModelCreator 的特殊处理逻辑
|
||||
*/
|
||||
|
||||
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
|
||||
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
|
||||
|
||||
export class ModelResolver {
|
||||
/**
|
||||
* 核心方法:解析任意格式的modelId为语言模型
|
||||
*
|
||||
* @param modelId 模型ID,支持 'gpt-4' 和 'anthropic>claude-3' 两种格式
|
||||
* @param fallbackProviderId 当modelId为传统格式时使用的providerId
|
||||
* @param providerOptions provider配置选项(用于OpenAI模式选择等)
|
||||
* @param middlewares 中间件数组,会应用到最终模型上
|
||||
*/
|
||||
async resolveLanguageModel(
|
||||
modelId: string,
|
||||
fallbackProviderId: string,
|
||||
providerOptions?: any,
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<LanguageModelV2> {
|
||||
let finalProviderId = fallbackProviderId
|
||||
let model: LanguageModelV2
|
||||
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
|
||||
if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') {
|
||||
finalProviderId = `${fallbackProviderId}-chat`
|
||||
}
|
||||
|
||||
// 检查是否是命名空间格式
|
||||
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||
model = this.resolveNamespacedModel(modelId)
|
||||
} else {
|
||||
// 传统格式:使用处理后的 providerId + modelId
|
||||
model = this.resolveTraditionalModel(finalProviderId, modelId)
|
||||
}
|
||||
|
||||
// 🎯 应用中间件(如果有)
|
||||
if (middlewares && middlewares.length > 0) {
|
||||
model = wrapModelWithMiddlewares(model, middlewares)
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析文本嵌入模型
|
||||
*/
|
||||
async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV2<string>> {
|
||||
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||
return this.resolveNamespacedEmbeddingModel(modelId)
|
||||
}
|
||||
|
||||
return this.resolveTraditionalEmbeddingModel(fallbackProviderId, modelId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析图像模型
|
||||
*/
|
||||
async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV2> {
|
||||
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||
return this.resolveNamespacedImageModel(modelId)
|
||||
}
|
||||
|
||||
return this.resolveTraditionalImageModel(fallbackProviderId, modelId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的语言模型
|
||||
* aihubmix:anthropic:claude-3 -> globalRegistryManagement.languageModel('aihubmix:anthropic:claude-3')
|
||||
*/
|
||||
private resolveNamespacedModel(modelId: string): LanguageModelV2 {
|
||||
return globalRegistryManagement.languageModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的语言模型
|
||||
* providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4')
|
||||
*/
|
||||
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
console.log('fullModelId', fullModelId)
|
||||
return globalRegistryManagement.languageModel(fullModelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的嵌入模型
|
||||
*/
|
||||
private resolveNamespacedEmbeddingModel(modelId: string): EmbeddingModelV2<string> {
|
||||
return globalRegistryManagement.textEmbeddingModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的嵌入模型
|
||||
*/
|
||||
private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV2<string> {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
return globalRegistryManagement.textEmbeddingModel(fullModelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的图像模型
|
||||
*/
|
||||
private resolveNamespacedImageModel(modelId: string): ImageModelV2 {
|
||||
return globalRegistryManagement.imageModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的图像模型
|
||||
*/
|
||||
private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV2 {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
return globalRegistryManagement.imageModel(fullModelId as any)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 全局模型解析器实例
|
||||
*/
|
||||
export const globalModelResolver = new ModelResolver()
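
As a rough usage sketch (not part of the diff): resolving a traditional model id against a fallback provider versus resolving a namespaced id, assuming the corresponding providers and hub have already been registered with the global registry.

```typescript
import { globalModelResolver } from './ModelResolver'

async function demo() {
  // Traditional format: 'gpt-4' is combined with the fallback provider, i.e. 'openai:gpt-4'
  const gpt4 = await globalModelResolver.resolveLanguageModel('gpt-4', 'openai')

  // Namespaced format: routed through a hub provider registered as 'aihubmix'
  const claude = await globalModelResolver.resolveLanguageModel('aihubmix:anthropic:claude-3.5-sonnet', 'openai')

  return { gpt4, claude }
}
```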
|
||||
9
packages/aiCore/src/core/models/index.ts
Normal file
9
packages/aiCore/src/core/models/index.ts
Normal file
@ -0,0 +1,9 @@
|
||||
/**
|
||||
* Models 模块统一导出 - 简化版
|
||||
*/
|
||||
|
||||
// 核心模型解析器
|
||||
export { globalModelResolver, ModelResolver } from './ModelResolver'
|
||||
|
||||
// 保留的类型定义(可能被其他地方使用)
|
||||
export type { ModelConfig as ModelConfigType } from './types'
|
||||
15
packages/aiCore/src/core/models/types.ts
Normal file
15
packages/aiCore/src/core/models/types.ts
Normal file
@ -0,0 +1,15 @@
|
||||
/**
|
||||
* Creation 模块类型定义
|
||||
*/
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import type { ProviderId, ProviderSettingsMap } from '../providers/types'
|
||||
|
||||
export interface ModelConfig<T extends ProviderId = ProviderId> {
|
||||
providerId: T
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
// 额外模型参数
|
||||
extraModelConfig?: Record<string, any>
|
||||
}
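
For illustration, a `ModelConfig` literal of this shape might look like the following; the concrete values are placeholders.

```typescript
import type { ModelConfig } from './types'

// Placeholder values; `mode` selects between OpenAI's chat and responses APIs.
export const exampleConfig: ModelConfig<'openai'> = {
  providerId: 'openai',
  modelId: 'gpt-4',
  providerSettings: {
    apiKey: process.env.OPENAI_API_KEY ?? '',
    mode: 'responses'
  }
}
```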
|
||||
87
packages/aiCore/src/core/options/examples.ts
Normal file
87
packages/aiCore/src/core/options/examples.ts
Normal file
@ -0,0 +1,87 @@
|
||||
import { streamText } from 'ai'
|
||||
|
||||
import {
|
||||
createAnthropicOptions,
|
||||
createGenericProviderOptions,
|
||||
createGoogleOptions,
|
||||
createOpenAIOptions,
|
||||
mergeProviderOptions
|
||||
} from './factory'
|
||||
|
||||
// 示例1: 使用已知供应商的严格类型约束
|
||||
export function exampleOpenAIWithOptions() {
|
||||
const openaiOptions = createOpenAIOptions({
|
||||
reasoningEffort: 'medium'
|
||||
})
|
||||
|
||||
// 这里会有类型检查,确保选项符合OpenAI的设置
|
||||
return streamText({
|
||||
model: {} as any, // 实际使用时替换为真实模型
|
||||
prompt: 'Hello',
|
||||
providerOptions: openaiOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例2: 使用Anthropic供应商选项
|
||||
export function exampleAnthropicWithOptions() {
|
||||
const anthropicOptions = createAnthropicOptions({
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budgetTokens: 1000
|
||||
}
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: anthropicOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例3: 使用Google供应商选项
|
||||
export function exampleGoogleWithOptions() {
|
||||
const googleOptions = createGoogleOptions({
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: 1000
|
||||
}
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: googleOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例4: 使用未知供应商(通用类型)
|
||||
export function exampleUnknownProviderWithOptions() {
|
||||
const customProviderOptions = createGenericProviderOptions('custom-provider', {
|
||||
temperature: 0.7,
|
||||
customSetting: 'value',
|
||||
anotherOption: true
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: customProviderOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例5: 合并多个供应商选项
|
||||
export function exampleMergedOptions() {
|
||||
const openaiOptions = createOpenAIOptions({})
|
||||
|
||||
const customOptions = createGenericProviderOptions('custom', {
|
||||
customParam: 'value'
|
||||
})
|
||||
|
||||
const mergedOptions = mergeProviderOptions(openaiOptions, customOptions)
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: mergedOptions
|
||||
})
|
||||
}
|
||||
71
packages/aiCore/src/core/options/factory.ts
Normal file
71
packages/aiCore/src/core/options/factory.ts
Normal file
@ -0,0 +1,71 @@
|
||||
import { ExtractProviderOptions, ProviderOptionsMap, TypedProviderOptions } from './types'
|
||||
|
||||
/**
|
||||
* 创建特定供应商的选项
|
||||
* @param provider 供应商名称
|
||||
* @param options 供应商特定的选项
|
||||
* @returns 格式化的provider options
|
||||
*/
|
||||
export function createProviderOptions<T extends keyof ProviderOptionsMap>(
|
||||
provider: T,
|
||||
options: ExtractProviderOptions<T>
|
||||
): Record<T, ExtractProviderOptions<T>> {
|
||||
return { [provider]: options } as Record<T, ExtractProviderOptions<T>>
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建任意供应商的选项(包括未知供应商)
|
||||
* @param provider 供应商名称
|
||||
* @param options 供应商选项
|
||||
* @returns 格式化的provider options
|
||||
*/
|
||||
export function createGenericProviderOptions<T extends string>(
|
||||
provider: T,
|
||||
options: Record<string, any>
|
||||
): Record<T, Record<string, any>> {
|
||||
return { [provider]: options } as Record<T, Record<string, any>>
|
||||
}
|
||||
|
||||
/**
|
||||
* 合并多个供应商的options
|
||||
* @param optionsMap 包含多个供应商选项的对象
|
||||
* @returns 合并后的TypedProviderOptions
|
||||
*/
|
||||
export function mergeProviderOptions(...optionsMap: Partial<TypedProviderOptions>[]): TypedProviderOptions {
|
||||
return Object.assign({}, ...optionsMap)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenAI供应商选项的便捷函数
|
||||
*/
|
||||
export function createOpenAIOptions(options: ExtractProviderOptions<'openai'>) {
|
||||
return createProviderOptions('openai', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Anthropic供应商选项的便捷函数
|
||||
*/
|
||||
export function createAnthropicOptions(options: ExtractProviderOptions<'anthropic'>) {
|
||||
return createProviderOptions('anthropic', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Google供应商选项的便捷函数
|
||||
*/
|
||||
export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
|
||||
return createProviderOptions('google', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenRouter供应商选项的便捷函数
|
||||
*/
|
||||
export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) {
|
||||
return createProviderOptions('openrouter', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建XAI供应商选项的便捷函数
|
||||
*/
|
||||
export function createXaiOptions(options: ExtractProviderOptions<'xai'>) {
|
||||
return createProviderOptions('xai', options)
|
||||
}
|
||||
2
packages/aiCore/src/core/options/index.ts
Normal file
2
packages/aiCore/src/core/options/index.ts
Normal file
@ -0,0 +1,2 @@
|
||||
export * from './factory'
|
||||
export * from './types'
|
||||
38
packages/aiCore/src/core/options/openrouter.ts
Normal file
38
packages/aiCore/src/core/options/openrouter.ts
Normal file
@ -0,0 +1,38 @@
|
||||
export type OpenRouterProviderOptions = {
|
||||
models?: string[]
|
||||
|
||||
/**
|
||||
* https://openrouter.ai/docs/use-cases/reasoning-tokens
|
||||
* One of `max_tokens` or `effort` is required.
|
||||
* If `exclude` is true, reasoning will be removed from the response. Default is false.
|
||||
*/
|
||||
reasoning?: {
|
||||
exclude?: boolean
|
||||
} & (
|
||||
| {
|
||||
max_tokens: number
|
||||
}
|
||||
| {
|
||||
effort: 'high' | 'medium' | 'low'
|
||||
}
|
||||
)
|
||||
|
||||
/**
|
||||
* A unique identifier representing your end-user, which can
|
||||
* help OpenRouter to monitor and detect abuse.
|
||||
*/
|
||||
user?: string
|
||||
|
||||
extraBody?: Record<string, unknown>
|
||||
|
||||
/**
|
||||
* Enable usage accounting to get detailed token usage information.
|
||||
* https://openrouter.ai/docs/use-cases/usage-accounting
|
||||
*/
|
||||
usage?: {
|
||||
/**
|
||||
* When true, includes token usage information in the response.
|
||||
*/
|
||||
include: boolean
|
||||
}
|
||||
}
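
For example, a valid `OpenRouterProviderOptions` object could look like this (values illustrative):

```typescript
import type { OpenRouterProviderOptions } from './openrouter'

// Reasoning requires either `max_tokens` or `effort`; this sketch uses `effort`.
export const openrouterOptions: OpenRouterProviderOptions = {
  reasoning: {
    effort: 'medium',
    exclude: false
  },
  // Include detailed token usage accounting in the response.
  usage: {
    include: true
  }
}
```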
|
||||
33
packages/aiCore/src/core/options/types.ts
Normal file
33
packages/aiCore/src/core/options/types.ts
Normal file
@ -0,0 +1,33 @@
|
||||
import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||
import { type GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||
import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||
import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'
|
||||
|
||||
import { type OpenRouterProviderOptions } from './openrouter'
|
||||
import { type XaiProviderOptions } from './xai'
|
||||
|
||||
export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]
|
||||
|
||||
/**
|
||||
* 供应商选项类型,如果map中没有,说明没有约束
|
||||
*/
|
||||
export type ProviderOptionsMap = {
|
||||
openai: OpenAIResponsesProviderOptions
|
||||
anthropic: AnthropicProviderOptions
|
||||
google: GoogleGenerativeAIProviderOptions
|
||||
openrouter: OpenRouterProviderOptions
|
||||
xai: XaiProviderOptions
|
||||
}
|
||||
|
||||
// 工具类型,用于从ProviderOptionsMap中提取特定供应商的选项类型
|
||||
export type ExtractProviderOptions<T extends keyof ProviderOptionsMap> = ProviderOptionsMap[T]
|
||||
|
||||
/**
|
||||
* 类型安全的ProviderOptions
|
||||
* 对于已知供应商使用严格类型,对于未知供应商允许任意Record<string, JSONValue>
|
||||
*/
|
||||
export type TypedProviderOptions = {
|
||||
[K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]
|
||||
} & {
|
||||
[K in string]?: Record<string, any>
|
||||
} & SharedV2ProviderMetadata
|
||||
86
packages/aiCore/src/core/options/xai.ts
Normal file
86
packages/aiCore/src/core/options/xai.ts
Normal file
@ -0,0 +1,86 @@
|
||||
// copy from @ai-sdk/xai/xai-chat-options.ts
|
||||
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
|
||||
|
||||
import * as z from 'zod/v4'
|
||||
|
||||
const webSourceSchema = z.object({
|
||||
type: z.literal('web'),
|
||||
country: z.string().length(2).optional(),
|
||||
excludedWebsites: z.array(z.string()).max(5).optional(),
|
||||
allowedWebsites: z.array(z.string()).max(5).optional(),
|
||||
safeSearch: z.boolean().optional()
|
||||
})
|
||||
|
||||
const xSourceSchema = z.object({
|
||||
type: z.literal('x'),
|
||||
xHandles: z.array(z.string()).optional()
|
||||
})
|
||||
|
||||
const newsSourceSchema = z.object({
|
||||
type: z.literal('news'),
|
||||
country: z.string().length(2).optional(),
|
||||
excludedWebsites: z.array(z.string()).max(5).optional(),
|
||||
safeSearch: z.boolean().optional()
|
||||
})
|
||||
|
||||
const rssSourceSchema = z.object({
|
||||
type: z.literal('rss'),
|
||||
links: z.array(z.url()).max(1) // currently only supports one RSS link
|
||||
})
|
||||
|
||||
const searchSourceSchema = z.discriminatedUnion('type', [
|
||||
webSourceSchema,
|
||||
xSourceSchema,
|
||||
newsSourceSchema,
|
||||
rssSourceSchema
|
||||
])
|
||||
|
||||
export const xaiProviderOptions = z.object({
|
||||
/**
|
||||
* reasoning effort for reasoning models
|
||||
* only supported by grok-3-mini and grok-3-mini-fast models
|
||||
*/
|
||||
reasoningEffort: z.enum(['low', 'high']).optional(),
|
||||
|
||||
searchParameters: z
|
||||
.object({
|
||||
/**
|
||||
* search mode preference
|
||||
* - "off": disables search completely
|
||||
* - "auto": model decides whether to search (default)
|
||||
* - "on": always enables search
|
||||
*/
|
||||
mode: z.enum(['off', 'auto', 'on']),
|
||||
|
||||
/**
|
||||
* whether to return citations in the response
|
||||
* defaults to true
|
||||
*/
|
||||
returnCitations: z.boolean().optional(),
|
||||
|
||||
/**
|
||||
* start date for search data (ISO8601 format: YYYY-MM-DD)
|
||||
*/
|
||||
fromDate: z.string().optional(),
|
||||
|
||||
/**
|
||||
* end date for search data (ISO8601 format: YYYY-MM-DD)
|
||||
*/
|
||||
toDate: z.string().optional(),
|
||||
|
||||
/**
|
||||
* maximum number of search results to consider
|
||||
* defaults to 20
|
||||
*/
|
||||
maxSearchResults: z.number().min(1).max(50).optional(),
|
||||
|
||||
/**
|
||||
* data sources to search from
|
||||
* defaults to ["web", "x"] if not specified
|
||||
*/
|
||||
sources: z.array(searchSourceSchema).optional()
|
||||
})
|
||||
.optional()
|
||||
})
|
||||
|
||||
export type XaiProviderOptions = z.infer<typeof xaiProviderOptions>
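
A quick validation sketch against the schema above; the option values are illustrative.

```typescript
import { xaiProviderOptions } from './xai'

// Parse (and validate) a candidate options object; `parse` throws on invalid input.
export const parsed = xaiProviderOptions.parse({
  reasoningEffort: 'low',
  searchParameters: {
    mode: 'auto',
    returnCitations: true,
    maxSearchResults: 10,
    sources: [{ type: 'web', country: 'US' }]
  }
})
```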
|
||||
257
packages/aiCore/src/core/plugins/README.md
Normal file
257
packages/aiCore/src/core/plugins/README.md
Normal file
@ -0,0 +1,257 @@
|
||||
# AI Core 插件系统
|
||||
|
||||
支持四种钩子类型:**First**、**Sequential**、**Parallel** 和 **Stream**。
|
||||
|
||||
## 🎯 设计理念
|
||||
|
||||
- **语义清晰**:不同钩子有不同的执行语义
|
||||
- **类型安全**:TypeScript 完整支持
|
||||
- **性能优化**:First 短路、Parallel 并发、Sequential 链式
|
||||
- **易于扩展**:`enforce` 排序 + 功能分类
|
||||
|
||||
## 📋 钩子类型
|
||||
|
||||
### 🥇 First 钩子 - 首个有效结果
|
||||
|
||||
```typescript
|
||||
// 只执行第一个返回值的插件,用于解析和查找
|
||||
resolveModel?: (modelId: string, context: AiRequestContext) => string | null
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null
|
||||
```
|
||||
|
||||
### 🔄 Sequential 钩子 - 链式数据转换
|
||||
|
||||
```typescript
|
||||
// 按顺序链式执行,每个插件可以修改数据
|
||||
transformParams?: (params: any, context: AiRequestContext) => any
|
||||
transformResult?: (result: any, context: AiRequestContext) => any
|
||||
```
|
||||
|
||||
### ⚡ Parallel 钩子 - 并行副作用
|
||||
|
||||
```typescript
|
||||
// 并发执行,用于日志、监控等副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void
|
||||
onRequestEnd?: (context: AiRequestContext, result: any) => void
|
||||
onError?: (error: Error, context: AiRequestContext) => void
|
||||
```
|
||||
|
||||
### 🌊 Stream 钩子 - 流处理
|
||||
|
||||
```typescript
|
||||
// 直接使用 AI SDK 的 TransformStream
|
||||
transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamPart>
|
||||
```
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 基础用法
|
||||
|
||||
```typescript
|
||||
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
|
||||
|
||||
// 创建插件管理器
|
||||
const pluginManager = new PluginManager()
|
||||
|
||||
// 添加插件
|
||||
pluginManager.use({
|
||||
name: 'my-plugin',
|
||||
async transformParams(params, context) {
|
||||
return { ...params, temperature: 0.7 }
|
||||
}
|
||||
})
|
||||
|
||||
// 使用插件
|
||||
const context = createContext('openai', 'gpt-4', { messages: [] })
|
||||
const transformedParams = await pluginManager.executeSequential(
|
||||
'transformParams',
|
||||
{ messages: [{ role: 'user', content: 'Hello' }] },
|
||||
context
|
||||
)
|
||||
```
|
||||
|
||||
### 完整示例
|
||||
|
||||
```typescript
|
||||
import {
|
||||
PluginManager,
|
||||
ModelAliasPlugin,
|
||||
LoggingPlugin,
|
||||
ParamsValidationPlugin,
|
||||
createContext
|
||||
} from '@cherrystudio/ai-core/middleware'
|
||||
|
||||
// 创建插件管理器
|
||||
const manager = new PluginManager([
|
||||
ModelAliasPlugin, // 模型别名解析
|
||||
ParamsValidationPlugin, // 参数验证
|
||||
LoggingPlugin // 日志记录
|
||||
])
|
||||
|
||||
// AI 请求流程
|
||||
async function aiRequest(providerId: string, modelId: string, params: any) {
|
||||
const context = createContext(providerId, modelId, params)
|
||||
|
||||
try {
|
||||
// 1. 【并行】触发请求开始事件
|
||||
await manager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 【首个】解析模型别名
|
||||
const resolvedModel = await manager.executeFirst('resolveModel', modelId, context)
|
||||
context.modelId = resolvedModel || modelId
|
||||
|
||||
// 3. 【串行】转换请求参数
|
||||
const transformedParams = await manager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 【流处理】收集流转换器(AI SDK 原生支持数组)
|
||||
const streamTransforms = manager.collectStreamTransforms()
|
||||
|
||||
// 5. 调用 AI SDK(这里省略具体实现)
|
||||
const result = await callAiSdk(transformedParams, streamTransforms)
|
||||
|
||||
// 6. 【串行】转换响应结果
|
||||
const transformedResult = await manager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 7. 【并行】触发请求完成事件
|
||||
await manager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 8. 【并行】触发错误事件
|
||||
await manager.executeParallel('onError', context, undefined, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 🔧 自定义插件
|
||||
|
||||
### 模型别名插件
|
||||
|
||||
```typescript
|
||||
const ModelAliasPlugin = definePlugin({
|
||||
name: 'model-alias',
|
||||
enforce: 'pre', // 最先执行
|
||||
|
||||
async resolveModel(modelId) {
|
||||
const aliases = {
|
||||
gpt4: 'gpt-4-turbo-preview',
|
||||
claude: 'claude-3-sonnet-20240229'
|
||||
}
|
||||
return aliases[modelId] || null
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 参数验证插件
|
||||
|
||||
```typescript
|
||||
const ValidationPlugin = definePlugin({
|
||||
name: 'validation',
|
||||
|
||||
async transformParams(params) {
|
||||
if (!params.messages) {
|
||||
throw new Error('messages is required')
|
||||
}
|
||||
|
||||
return {
|
||||
...params,
|
||||
temperature: params.temperature ?? 0.7,
|
||||
max_tokens: params.max_tokens ?? 4096
|
||||
}
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 监控插件
|
||||
|
||||
```typescript
|
||||
const MonitoringPlugin = definePlugin({
|
||||
name: 'monitoring',
|
||||
enforce: 'post', // 最后执行
|
||||
|
||||
async onRequestEnd(context, result) {
|
||||
const duration = Date.now() - context.startTime
|
||||
console.log(`请求耗时: ${duration}ms`)
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 内容过滤插件
|
||||
|
||||
```typescript
|
||||
const FilterPlugin = definePlugin({
|
||||
name: 'content-filter',
|
||||
|
||||
transformStream() {
|
||||
return () =>
|
||||
new TransformStream({
|
||||
transform(chunk, controller) {
|
||||
if (chunk.type === 'text-delta') {
|
||||
const filtered = chunk.textDelta.replace(/敏感词/g, '***')
|
||||
controller.enqueue({ ...chunk, textDelta: filtered })
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## 📊 执行顺序
|
||||
|
||||
### 插件排序
|
||||
|
||||
```
|
||||
enforce: 'pre' → normal → enforce: 'post'
|
||||
```
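
For example, with the following registration the `alias` plugin's hooks always run first and the `metrics` plugin's hooks always run last, regardless of registration order (a sketch using `definePlugin` and `PluginManager` as shown earlier; plugin bodies omitted):

```typescript
import { definePlugin, PluginManager } from '@cherrystudio/ai-core/middleware'

// Registration order is scrambled on purpose: hooks from 'alias' (pre)
// still run before 'validation' (normal), which runs before 'metrics' (post).
const manager = new PluginManager([
  definePlugin({ name: 'metrics', enforce: 'post' }),
  definePlugin({ name: 'alias', enforce: 'pre' }),
  definePlugin({ name: 'validation' })
])
```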
|
||||
|
||||
### 钩子执行流程
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[请求开始] --> B[onRequestStart 并行执行]
|
||||
B --> C[resolveModel 首个有效]
|
||||
C --> D[loadTemplate 首个有效]
|
||||
D --> E[transformParams 串行执行]
|
||||
E --> F[collectStreamTransforms]
|
||||
F --> G[AI SDK 调用]
|
||||
G --> H[transformResult 串行执行]
|
||||
H --> I[onRequestEnd 并行执行]
|
||||
|
||||
G --> J[异常处理]
|
||||
J --> K[onError 并行执行]
|
||||
```
|
||||
|
||||
## 💡 最佳实践
|
||||
|
||||
1. **功能单一**:每个插件专注一个功能
|
||||
2. **幂等性**:插件应该是幂等的,重复执行不会产生副作用
|
||||
3. **错误处理**:插件内部处理异常,不要让异常向上传播
|
||||
4. **性能优化**:使用合适的钩子类型(First vs Sequential vs Parallel)
|
||||
5. **命名规范**:使用语义化的插件名称
|
||||
|
||||
## 🔍 调试工具
|
||||
|
||||
```typescript
|
||||
// 查看插件统计信息
|
||||
const stats = manager.getStats()
|
||||
console.log('插件统计:', stats)
|
||||
|
||||
// 查看所有插件
|
||||
const plugins = manager.getPlugins()
|
||||
console.log(
|
||||
'已注册插件:',
|
||||
plugins.map((p) => p.name)
|
||||
)
|
||||
```
|
||||
|
||||
## ⚡ 性能优势
|
||||
|
||||
- **First 钩子**:一旦找到结果立即停止,避免无效计算
|
||||
- **Parallel 钩子**:真正并发执行,不阻塞主流程
|
||||
- **Sequential 钩子**:保证数据转换的顺序性
|
||||
- **Stream 钩子**:直接集成 AI SDK,零开销
|
||||
|
||||
这个设计兼顾了简洁性和强大功能,为 AI Core 提供了灵活而高效的扩展机制。
|
||||
10
packages/aiCore/src/core/plugins/built-in/index.ts
Normal file
10
packages/aiCore/src/core/plugins/built-in/index.ts
Normal file
@ -0,0 +1,10 @@
|
||||
/**
|
||||
* 内置插件命名空间
|
||||
* 所有内置插件都以 'built-in:' 为前缀
|
||||
*/
|
||||
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
|
||||
|
||||
export { createLoggingPlugin } from './logging'
|
||||
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
|
||||
export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type'
|
||||
export { webSearchPlugin } from './webSearchPlugin'
|
||||
86
packages/aiCore/src/core/plugins/built-in/logging.ts
Normal file
86
packages/aiCore/src/core/plugins/built-in/logging.ts
Normal file
@ -0,0 +1,86 @@
|
||||
/**
|
||||
* 内置插件:日志记录
|
||||
* 记录AI调用的关键信息,支持性能监控和调试
|
||||
*/
|
||||
import { definePlugin } from '../index'
|
||||
import type { AiRequestContext } from '../types'
|
||||
|
||||
export interface LoggingConfig {
|
||||
// 日志级别
|
||||
level?: 'debug' | 'info' | 'warn' | 'error'
|
||||
// 是否记录参数
|
||||
logParams?: boolean
|
||||
// 是否记录结果
|
||||
logResult?: boolean
|
||||
// 是否记录性能数据
|
||||
logPerformance?: boolean
|
||||
// 自定义日志函数
|
||||
logger?: (level: string, message: string, data?: any) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建日志插件
|
||||
*/
|
||||
export function createLoggingPlugin(config: LoggingConfig = {}) {
|
||||
const { level = 'info', logParams = true, logResult = false, logPerformance = true, logger = console.log } = config
|
||||
|
||||
const startTimes = new Map<string, number>()
|
||||
|
||||
return definePlugin({
|
||||
name: 'built-in:logging',
|
||||
|
||||
onRequestStart: (context: AiRequestContext) => {
|
||||
const requestId = context.requestId
|
||||
startTimes.set(requestId, Date.now())
|
||||
|
||||
logger(level, `🚀 AI Request Started`, {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId,
|
||||
originalParams: logParams ? context.originalParams : '[hidden]'
|
||||
})
|
||||
},
|
||||
|
||||
onRequestEnd: (context: AiRequestContext, result: any) => {
|
||||
const requestId = context.requestId
|
||||
const startTime = startTimes.get(requestId)
|
||||
const duration = startTime ? Date.now() - startTime : undefined
|
||||
startTimes.delete(requestId)
|
||||
|
||||
const logData: any = {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId
|
||||
}
|
||||
|
||||
if (logPerformance && duration) {
|
||||
logData.duration = `${duration}ms`
|
||||
}
|
||||
|
||||
if (logResult) {
|
||||
logData.result = result
|
||||
}
|
||||
|
||||
logger(level, `✅ AI Request Completed`, logData)
|
||||
},
|
||||
|
||||
onError: (error: Error, context: AiRequestContext) => {
|
||||
const requestId = context.requestId
|
||||
const startTime = startTimes.get(requestId)
|
||||
const duration = startTime ? Date.now() - startTime : undefined
|
||||
startTimes.delete(requestId)
|
||||
|
||||
logger('error', `❌ AI Request Failed`, {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId,
|
||||
duration: duration ? `${duration}ms` : undefined,
|
||||
error: {
|
||||
name: error.name,
|
||||
message: error.message,
|
||||
stack: error.stack
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
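
A usage sketch (not part of this file): enabling debug-level logging with results included and a custom log sink.

```typescript
import { createLoggingPlugin } from './logging'

// Debug-level logging with results included and a custom sink for log lines.
export const verboseLoggingPlugin = createLoggingPlugin({
  level: 'debug',
  logParams: true,
  logResult: true,
  logPerformance: true,
  logger: (level, message, data) => console.log(`[${level}] ${message}`, data ?? '')
})
```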
|
||||
@ -0,0 +1,139 @@
|
||||
/**
|
||||
* 流事件管理器
|
||||
*
|
||||
* 负责处理 AI SDK 流事件的发送和管理
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { ModelMessage } from 'ai'
|
||||
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import type { StreamController } from './ToolExecutor'
|
||||
|
||||
/**
|
||||
* 流事件管理器类
|
||||
*/
|
||||
export class StreamEventManager {
|
||||
/**
|
||||
* 发送工具调用步骤开始事件
|
||||
*/
|
||||
sendStepStartEvent(controller: StreamController): void {
|
||||
controller.enqueue({
|
||||
type: 'start-step',
|
||||
request: {},
|
||||
warnings: []
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送步骤完成事件
|
||||
*/
|
||||
sendStepFinishEvent(controller: StreamController, chunk: any): void {
|
||||
controller.enqueue({
|
||||
type: 'finish-step',
|
||||
finishReason: 'stop',
|
||||
response: chunk.response,
|
||||
usage: chunk.usage,
|
||||
providerMetadata: chunk.providerMetadata
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理递归调用并将结果流接入当前流
|
||||
*/
|
||||
async handleRecursiveCall(
|
||||
controller: StreamController,
|
||||
recursiveParams: any,
|
||||
context: AiRequestContext,
|
||||
stepId: string
|
||||
): Promise<void> {
|
||||
try {
|
||||
console.log('[MCP Prompt] Starting recursive call after tool execution...')
|
||||
|
||||
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||
|
||||
if (recursiveResult && recursiveResult.fullStream) {
|
||||
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
|
||||
} else {
|
||||
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
|
||||
}
|
||||
} catch (error) {
|
||||
this.handleRecursiveCallError(controller, error, stepId)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将递归流的数据传递到当前流
|
||||
*/
|
||||
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
|
||||
const reader = recursiveStream.getReader()
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
break
|
||||
}
|
||||
if (value.type === 'finish') {
|
||||
// 迭代的流不发finish
|
||||
break
|
||||
}
|
||||
// 将递归流的数据传递到当前流
|
||||
controller.enqueue(value)
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理递归调用错误
|
||||
*/
|
||||
private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
|
||||
console.error('[MCP Prompt] Recursive call failed:', error)
|
||||
|
||||
// 使用 AI SDK 标准错误格式,但不中断流
|
||||
controller.enqueue({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: error instanceof Error ? error.message : String(error),
|
||||
name: error instanceof Error ? error.name : 'RecursiveCallError'
|
||||
}
|
||||
})
|
||||
|
||||
// 继续发送文本增量,保持流的连续性
|
||||
controller.enqueue({
|
||||
type: 'text-delta',
|
||||
id: stepId,
|
||||
text: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建递归调用的参数
|
||||
*/
|
||||
buildRecursiveParams(context: AiRequestContext, textBuffer: string, toolResultsText: string, tools: any): any {
|
||||
// 构建新的对话消息
|
||||
const newMessages: ModelMessage[] = [
|
||||
...(context.originalParams.messages || []),
|
||||
{
|
||||
role: 'assistant',
|
||||
content: textBuffer
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: toolResultsText
|
||||
}
|
||||
]
|
||||
|
||||
// 递归调用,继续对话,重新传递 tools
|
||||
const recursiveParams = {
|
||||
...context.originalParams,
|
||||
messages: newMessages,
|
||||
tools: tools
|
||||
}
|
||||
|
||||
// 更新上下文中的消息
|
||||
context.originalParams.messages = newMessages
|
||||
|
||||
return recursiveParams
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,156 @@
|
||||
/**
|
||||
* 工具执行器
|
||||
*
|
||||
* 负责工具的执行、结果格式化和相关事件发送
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { ToolSet } from 'ai'
|
||||
|
||||
import type { ToolUseResult } from './type'
|
||||
|
||||
/**
|
||||
* 工具执行结果
|
||||
*/
|
||||
export interface ExecutedResult {
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
result: any
|
||||
isError?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 流控制器类型(从 AI SDK 提取)
|
||||
*/
|
||||
export interface StreamController {
|
||||
enqueue(chunk: any): void
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具执行器类
|
||||
*/
|
||||
export class ToolExecutor {
|
||||
/**
|
||||
* 执行多个工具调用
|
||||
*/
|
||||
async executeTools(
|
||||
toolUses: ToolUseResult[],
|
||||
tools: ToolSet,
|
||||
controller: StreamController
|
||||
): Promise<ExecutedResult[]> {
|
||||
const executedResults: ExecutedResult[] = []
|
||||
|
||||
for (const toolUse of toolUses) {
|
||||
try {
|
||||
const tool = tools[toolUse.toolName]
|
||||
if (!tool || typeof tool.execute !== 'function') {
|
||||
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
||||
}
|
||||
|
||||
// 发送工具调用开始事件
|
||||
this.sendToolStartEvents(controller, toolUse)
|
||||
|
||||
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
|
||||
|
||||
// 发送 tool-call 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-call',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: tool.inputSchema
|
||||
})
|
||||
|
||||
const result = await tool.execute(toolUse.arguments, {
|
||||
toolCallId: toolUse.id,
|
||||
messages: [],
|
||||
abortSignal: new AbortController().signal
|
||||
})
|
||||
|
||||
// 发送 tool-result 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-result',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: toolUse.arguments,
|
||||
output: result
|
||||
})
|
||||
|
||||
executedResults.push({
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result,
|
||||
isError: false
|
||||
})
|
||||
} catch (error) {
|
||||
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
|
||||
|
||||
// 处理错误情况
|
||||
const errorResult = this.handleToolError(toolUse, error, controller)
|
||||
executedResults.push(errorResult)
|
||||
}
|
||||
}
|
||||
|
||||
return executedResults
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化工具结果为 Cherry Studio 标准格式
|
||||
*/
|
||||
formatToolResults(executedResults: ExecutedResult[]): string {
|
||||
return executedResults
|
||||
.map((tr) => {
|
||||
if (!tr.isError) {
|
||||
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
|
||||
} else {
|
||||
const error = tr.result || 'Unknown error'
|
||||
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <error>${error}</error>\n</tool_use_result>`
|
||||
}
|
||||
})
|
||||
.join('\n\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送工具调用开始相关事件
|
||||
*/
|
||||
private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
|
||||
// 发送 tool-input-start 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-input-start',
|
||||
id: toolUse.id,
|
||||
toolName: toolUse.toolName
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具执行错误
|
||||
*/
|
||||
private handleToolError(
|
||||
toolUse: ToolUseResult,
|
||||
error: unknown,
|
||||
controller: StreamController
|
||||
// _tools: ToolSet
|
||||
): ExecutedResult {
|
||||
// 使用 AI SDK 标准错误格式
|
||||
// const toolError: TypedToolError<typeof _tools> = {
|
||||
// type: 'tool-error',
|
||||
// toolCallId: toolUse.id,
|
||||
// toolName: toolUse.toolName,
|
||||
// input: toolUse.arguments,
|
||||
// error: error instanceof Error ? error.message : String(error)
|
||||
// }
|
||||
|
||||
// controller.enqueue(toolError)
|
||||
|
||||
// 发送标准错误事件
|
||||
controller.enqueue({
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
|
||||
return {
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result: error instanceof Error ? error.message : String(error),
|
||||
isError: true
|
||||
}
|
||||
}
|
||||
}
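
For reference, a sketch of calling `formatToolResults` with one successful result; the comment shows roughly the XML block produced by the template above.

```typescript
import { ToolExecutor } from './ToolExecutor'

const executor = new ToolExecutor()

// Produces, roughly:
// <tool_use_result>
//   <name>search</name>
//   <result>{"answer":"42"}</result>
// </tool_use_result>
export const formatted = executor.formatToolResults([
  { toolCallId: 'search-0', toolName: 'search', result: { answer: '42' }, isError: false }
])
```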
|
||||
@ -0,0 +1,373 @@
|
||||
/**
|
||||
* 内置插件:MCP Prompt 模式
|
||||
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
||||
* 内置默认逻辑,支持自定义覆盖
|
||||
*/
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { definePlugin } from '../../index'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { StreamEventManager } from './StreamEventManager'
|
||||
import { ToolExecutor } from './ToolExecutor'
|
||||
import { PromptToolUseConfig, ToolUseResult } from './type'
|
||||
|
||||
/**
|
||||
* 默认系统提示符模板(提取自 Cherry Studio)
|
||||
*/
|
||||
const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \\
|
||||
You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
||||
|
||||
## Tool Use Formatting
|
||||
|
||||
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
|
||||
|
||||
<tool_use>
|
||||
<name>{tool_name}</name>
|
||||
<arguments>{json_arguments}</arguments>
|
||||
</tool_use>
|
||||
|
||||
The tool name should be the exact name of the tool you are using, and the arguments should be a JSON object containing the parameters required by that tool. For example:
|
||||
<tool_use>
|
||||
<name>python_interpreter</name>
|
||||
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
The user will respond with the result of the tool use, which should be formatted as follows:
|
||||
|
||||
<tool_use_result>
|
||||
<name>{tool_name}</name>
|
||||
<result>{result}</result>
|
||||
</tool_use_result>
|
||||
|
||||
The result should be a string, which can represent a file or any other output type. You can use this result as input for the next action.
|
||||
For example, if the result of the tool use is an image file, you can use it in the next action like this:
|
||||
|
||||
<tool_use>
|
||||
<name>image_transformer</name>
|
||||
<arguments>{"image": "image_1.jpg"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
Always adhere to this format for the tool use to ensure proper parsing and execution.
|
||||
|
||||
## Tool Use Examples
|
||||
{{ TOOL_USE_EXAMPLES }}
|
||||
|
||||
## Tool Use Available Tools
|
||||
The examples above use notional tools that might not exist for you. You only have access to these tools:
|
||||
{{ AVAILABLE_TOOLS }}
|
||||
|
||||
## Tool Use Rules
|
||||
Here are the rules you should always follow to solve your task:
|
||||
1. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead.
|
||||
2. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
|
||||
3. If no tool call is needed, just answer the question directly.
|
||||
4. Never re-do a tool call that you previously did with the exact same parameters.
|
||||
5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format.
|
||||
|
||||
# User Instructions
|
||||
{{ USER_SYSTEM_PROMPT }}
|
||||
|
||||
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.`
|
||||
|
||||
/**
|
||||
* 默认工具使用示例(提取自 Cherry Studio)
|
||||
*/
|
||||
const DEFAULT_TOOL_USE_EXAMPLES = `
|
||||
Here are a few examples using notional tools:
|
||||
---
|
||||
User: Generate an image of the oldest person in this document.
|
||||
|
||||
A: I can use the document_qa tool to find out who the oldest person is in the document.
|
||||
<tool_use>
|
||||
<name>document_qa</name>
|
||||
<arguments>{"document": "document.pdf", "question": "Who is the oldest person mentioned?"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>document_qa</name>
|
||||
<result>John Doe, a 55 year old lumberjack living in Newfoundland.</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: I can use the image_generator tool to create a portrait of John Doe.
|
||||
<tool_use>
|
||||
<name>image_generator</name>
|
||||
<arguments>{"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>image_generator</name>
|
||||
<result>image.png</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: the image is generated as image.png
|
||||
|
||||
---
|
||||
User: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
||||
|
||||
A: I can use the python_interpreter tool to calculate the result of the operation.
|
||||
<tool_use>
|
||||
<name>python_interpreter</name>
|
||||
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>python_interpreter</name>
|
||||
<result>1302.678</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: The result of the operation is 1302.678.
|
||||
|
||||
---
|
||||
User: "Which city has the highest population , Guangzhou or Shanghai?"
|
||||
|
||||
A: I can use the search tool to find the population of Guangzhou.
|
||||
<tool_use>
|
||||
<name>search</name>
|
||||
<arguments>{"query": "Population Guangzhou"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>search</name>
|
||||
<result>Guangzhou has a population of 15 million inhabitants as of 2021.</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: I can use the search tool to find the population of Shanghai.
|
||||
<tool_use>
|
||||
<name>search</name>
|
||||
<arguments>{"query": "Population Shanghai"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>search</name>
|
||||
<result>26 million (2019)</result>
|
||||
</tool_use_result>
|
||||
Assistant: The population of Shanghai is 26 million, while Guangzhou has a population of 15 million. Therefore, Shanghai has the highest population.`
|
||||
|
||||
/**
|
||||
* 构建可用工具部分(提取自 Cherry Studio)
|
||||
*/
|
||||
function buildAvailableTools(tools: ToolSet): string {
|
||||
const availableTools = Object.keys(tools)
|
||||
.map((toolName: string) => {
|
||||
const tool = tools[toolName]
|
||||
return `
|
||||
<tool>
|
||||
<name>${toolName}</name>
|
||||
<description>${tool.description || ''}</description>
|
||||
<arguments>
|
||||
${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}
|
||||
</arguments>
|
||||
</tool>
|
||||
`
|
||||
})
|
||||
.join('\n')
|
||||
return `<tools>
|
||||
${availableTools}
|
||||
</tools>`
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认的系统提示符构建函数(提取自 Cherry Studio)
|
||||
*/
|
||||
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string {
|
||||
const availableTools = buildAvailableTools(tools)
|
||||
|
||||
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
|
||||
.replace('{{ AVAILABLE_TOOLS }}', availableTools)
|
||||
.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt || '')
|
||||
|
||||
return fullPrompt
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认工具解析函数(提取自 Cherry Studio)
|
||||
* 解析 XML 格式的工具调用
|
||||
*/
|
||||
function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUseResult[]; content: string } {
|
||||
if (!content || !tools || Object.keys(tools).length === 0) {
|
||||
return { results: [], content: content }
|
||||
}
|
||||
|
||||
// 支持两种格式:
|
||||
// 1. 完整的 <tool_use></tool_use> 标签包围的内容
|
||||
// 2. 只有内部内容(从 TagExtractor 提取出来的)
|
||||
|
||||
let contentToProcess = content
|
||||
// 如果内容不包含 <tool_use> 标签,说明是从 TagExtractor 提取的内部内容,需要包装
|
||||
if (!content.includes('<tool_use>')) {
|
||||
contentToProcess = `<tool_use>\n${content}\n</tool_use>`
|
||||
}
|
||||
|
||||
const toolUsePattern =
|
||||
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
|
||||
const results: ToolUseResult[] = []
|
||||
let match
|
||||
let idx = 0
|
||||
|
||||
// Find all tool use blocks
|
||||
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
|
||||
const fullMatch = match[0]
|
||||
const toolName = match[2].trim()
|
||||
const toolArgs = match[4].trim()
|
||||
|
||||
// Try to parse the arguments as JSON
|
||||
let parsedArgs
|
||||
try {
|
||||
parsedArgs = JSON.parse(toolArgs)
|
||||
} catch (error) {
|
||||
// If parsing fails, use the string as is
|
||||
parsedArgs = toolArgs
|
||||
}
|
||||
|
||||
// Find the corresponding tool
|
||||
const tool = tools[toolName]
|
||||
if (!tool) {
|
||||
console.warn(`Tool "${toolName}" not found in available tools`)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to results array
|
||||
results.push({
|
||||
id: `${toolName}-${idx++}`, // Unique ID for each tool use
|
||||
toolName: toolName,
|
||||
arguments: parsedArgs,
|
||||
status: 'pending'
|
||||
})
|
||||
contentToProcess = contentToProcess.replace(fullMatch, '')
// The subject string just got shorter; reset the global regex cursor so the stale
// lastIndex cannot skip over an adjacent <tool_use> block in the shrunk string.
toolUsePattern.lastIndex = 0
|
||||
}
|
||||
return { results, content: contentToProcess }
|
||||
}
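// Behaviour sketch for the parser above. The `search` key is a hypothetical tool;
// only its presence in the ToolSet matters for this example.
const sampleReply = [
  'Let me look that up.',
  '<tool_use>',
  '<name>search</name>',
  '<arguments>{"query": "Population Guangzhou"}</arguments>',
  '</tool_use>'
].join('\n')
const sampleTools = { search: { description: 'web search' } } as unknown as ToolSet
const { results: sampleResults, content: strippedContent } = defaultParseToolUse(sampleReply, sampleTools)
// sampleResults[0] -> { id: 'search-0', toolName: 'search',
//                       arguments: { query: 'Population Guangzhou' }, status: 'pending' }
// strippedContent  -> the reply with the <tool_use>...</tool_use> block removed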
|
||||
|
||||
export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config
|
||||
|
||||
return definePlugin({
|
||||
name: 'built-in:prompt-tool-use',
|
||||
transformParams: (params: any, context: AiRequestContext) => {
|
||||
if (!enabled || !params.tools || typeof params.tools !== 'object') {
|
||||
return params
|
||||
}
|
||||
|
||||
context.mcpTools = params.tools
|
||||
console.log('tools stored in context', params.tools)
|
||||
|
||||
// 构建系统提示符
|
||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
|
||||
let systemMessage: string | null = systemPrompt
|
||||
console.log('config.context', context)
|
||||
if (config.createSystemMessage) {
|
||||
// 🎯 如果用户提供了自定义处理函数,使用它
|
||||
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
||||
}
|
||||
|
||||
// 移除 tools,改为 prompt 模式
|
||||
const transformedParams = {
|
||||
...params,
|
||||
...(systemMessage ? { system: systemMessage } : {}),
|
||||
tools: undefined
|
||||
}
|
||||
context.originalParams = transformedParams
|
||||
console.log('transformedParams', transformedParams)
|
||||
return transformedParams
|
||||
},
|
||||
transformStream: (_: any, context: AiRequestContext) => () => {
|
||||
let textBuffer = ''
|
||||
let stepId = ''
|
||||
|
||||
if (!context.mcpTools) {
|
||||
throw new Error('No tools available')
|
||||
}
|
||||
|
||||
// 创建工具执行器和流事件管理器
|
||||
const toolExecutor = new ToolExecutor()
|
||||
const streamEventManager = new StreamEventManager()
|
||||
|
||||
type TOOLS = NonNullable<typeof context.mcpTools>
|
||||
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
|
||||
async transform(
|
||||
chunk: TextStreamPart<TOOLS>,
|
||||
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
|
||||
) {
|
||||
// 收集文本内容
|
||||
if (chunk.type === 'text-delta') {
|
||||
textBuffer += chunk.text || ''
|
||||
stepId = chunk.id || ''
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
|
||||
const tools = context.mcpTools
|
||||
if (!tools || Object.keys(tools).length === 0) {
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析工具调用
|
||||
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
|
||||
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||
|
||||
// 如果没有有效的工具调用,直接传递原始事件
|
||||
if (validToolUses.length === 0) {
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'text-end') {
|
||||
controller.enqueue({
|
||||
type: 'text-end',
|
||||
id: stepId,
|
||||
providerMetadata: {
|
||||
text: {
|
||||
value: parsedContent
|
||||
}
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
finishReason: 'tool-calls'
|
||||
})
|
||||
|
||||
// 发送步骤开始事件
|
||||
streamEventManager.sendStepStartEvent(controller)
|
||||
|
||||
// 执行工具调用
|
||||
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
|
||||
|
||||
// 发送步骤完成事件
|
||||
streamEventManager.sendStepFinishEvent(controller, chunk)
|
||||
|
||||
// 处理递归调用
|
||||
if (validToolUses.length > 0) {
|
||||
const toolResultsText = toolExecutor.formatToolResults(executedResults)
|
||||
const recursiveParams = streamEventManager.buildRecursiveParams(
|
||||
context,
|
||||
textBuffer,
|
||||
toolResultsText,
|
||||
tools
|
||||
)
|
||||
|
||||
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
|
||||
}
|
||||
|
||||
// 清理状态
|
||||
textBuffer = ''
|
||||
return
|
||||
}
|
||||
|
||||
// 对于其他类型的事件,直接传递
|
||||
controller.enqueue(chunk)
|
||||
},
|
||||
|
||||
flush() {
|
||||
// 流结束时的清理工作
|
||||
console.log('[MCP Prompt] Stream ended, cleaning up...')
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
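// Hedged usage sketch: how a host application might instantiate the plugin. The
// createSystemMessage hook shown here (skip rebuilding the prompt on recursive
// tool-result calls) is an illustration, not Cherry Studio's actual wiring.
const examplePromptToolUse = createPromptToolUsePlugin({
  createSystemMessage: (systemPrompt, _originalParams, context) =>
    context.isRecursiveCall ? null : systemPrompt
})
console.log(examplePromptToolUse.name) // 'built-in:prompt-tool-use'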
|
||||
@ -0,0 +1,196 @@
|
||||
// Copied from https://github.com/vercel/ai/blob/main/packages/ai/core/util/get-potential-start-index.ts
|
||||
|
||||
/**
|
||||
* Returns the index of the start of the searchedText in the text, or null if it
|
||||
* is not found.
|
||||
*/
|
||||
export function getPotentialStartIndex(text: string, searchedText: string): number | null {
|
||||
// Return null immediately if searchedText is empty.
|
||||
if (searchedText.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Check if the searchedText exists as a direct substring of text.
|
||||
const directIndex = text.indexOf(searchedText)
|
||||
if (directIndex !== -1) {
|
||||
return directIndex
|
||||
}
|
||||
|
||||
// Otherwise, look for the largest suffix of "text" that matches
|
||||
// a prefix of "searchedText". We go from the end of text inward.
|
||||
for (let i = text.length - 1; i >= 0; i--) {
|
||||
const suffix = text.substring(i)
|
||||
if (searchedText.startsWith(suffix)) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
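// Behaviour sketch (values follow directly from the implementation above). The
// partial-suffix case is what lets a streaming consumer hold back an incomplete
// tag such as "<tool" until enough text has arrived to decide.
console.assert(getPotentialStartIndex('hello <tool_use>', '<tool_use>') === 6) // direct substring
console.assert(getPotentialStartIndex('hello <tool', '<tool_use>') === 6) // suffix matches a prefix
console.assert(getPotentialStartIndex('hello', '<tool_use>') === null) // no overlap at all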
|
||||
|
||||
export interface TagConfig {
|
||||
openingTag: string
|
||||
closingTag: string
|
||||
separator?: string
|
||||
}
|
||||
|
||||
export interface TagExtractionState {
|
||||
textBuffer: string
|
||||
isInsideTag: boolean
|
||||
isFirstTag: boolean
|
||||
isFirstText: boolean
|
||||
afterSwitch: boolean
|
||||
accumulatedTagContent: string
|
||||
hasTagContent: boolean
|
||||
}
|
||||
|
||||
export interface TagExtractionResult {
|
||||
content: string
|
||||
isTagContent: boolean
|
||||
complete: boolean
|
||||
tagContentExtracted?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用标签提取处理器
|
||||
* 可以处理各种形式的标签对,如 <think>...</think>, <tool_use>...</tool_use> 等
|
||||
*/
|
||||
export class TagExtractor {
|
||||
private config: TagConfig
|
||||
private state: TagExtractionState
|
||||
|
||||
constructor(config: TagConfig) {
|
||||
this.config = config
|
||||
this.state = {
|
||||
textBuffer: '',
|
||||
isInsideTag: false,
|
||||
isFirstTag: true,
|
||||
isFirstText: true,
|
||||
afterSwitch: false,
|
||||
accumulatedTagContent: '',
|
||||
hasTagContent: false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理文本块,返回处理结果
|
||||
*/
|
||||
processText(newText: string): TagExtractionResult[] {
|
||||
this.state.textBuffer += newText
|
||||
const results: TagExtractionResult[] = []
|
||||
|
||||
// 处理标签提取逻辑
|
||||
while (true) {
|
||||
const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
|
||||
const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)
|
||||
|
||||
if (startIndex == null) {
|
||||
const content = this.state.textBuffer
|
||||
if (content.length > 0) {
|
||||
results.push({
|
||||
content: this.addPrefix(content),
|
||||
isTagContent: this.state.isInsideTag,
|
||||
complete: false
|
||||
})
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.accumulatedTagContent += this.addPrefix(content)
|
||||
this.state.hasTagContent = true
|
||||
}
|
||||
}
|
||||
this.state.textBuffer = ''
|
||||
break
|
||||
}
|
||||
|
||||
// 处理标签前的内容
|
||||
const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
|
||||
if (contentBeforeTag.length > 0) {
|
||||
results.push({
|
||||
content: this.addPrefix(contentBeforeTag),
|
||||
isTagContent: this.state.isInsideTag,
|
||||
complete: false
|
||||
})
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
|
||||
this.state.hasTagContent = true
|
||||
}
|
||||
}
|
||||
|
||||
const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length
|
||||
|
||||
if (foundFullMatch) {
|
||||
// 如果找到完整的标签
|
||||
this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)
|
||||
|
||||
// 如果刚刚结束一个标签内容,生成完整的标签内容结果
|
||||
if (this.state.isInsideTag && this.state.hasTagContent) {
|
||||
results.push({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: this.state.accumulatedTagContent
|
||||
})
|
||||
this.state.accumulatedTagContent = ''
|
||||
this.state.hasTagContent = false
|
||||
}
|
||||
|
||||
this.state.isInsideTag = !this.state.isInsideTag
|
||||
this.state.afterSwitch = true
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.isFirstTag = false
|
||||
} else {
|
||||
this.state.isFirstText = false
|
||||
}
|
||||
} else {
|
||||
this.state.textBuffer = this.state.textBuffer.slice(startIndex)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
/**
|
||||
* 完成处理,返回任何剩余的标签内容
|
||||
*/
|
||||
finalize(): TagExtractionResult | null {
|
||||
if (this.state.hasTagContent && this.state.accumulatedTagContent) {
|
||||
const result = {
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: this.state.accumulatedTagContent
|
||||
}
|
||||
this.state.accumulatedTagContent = ''
|
||||
this.state.hasTagContent = false
|
||||
return result
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
private addPrefix(text: string): string {
|
||||
const needsPrefix =
|
||||
this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)
|
||||
|
||||
const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
|
||||
this.state.afterSwitch = false
|
||||
return prefix + text
|
||||
}
|
||||
|
||||
/**
|
||||
* 重置状态
|
||||
*/
|
||||
reset(): void {
|
||||
this.state = {
|
||||
textBuffer: '',
|
||||
isInsideTag: false,
|
||||
isFirstTag: true,
|
||||
isFirstText: true,
|
||||
afterSwitch: false,
|
||||
accumulatedTagContent: '',
|
||||
hasTagContent: false
|
||||
}
|
||||
}
|
||||
}
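// Minimal streaming sketch for TagExtractor. The <think> tag pair is just an
// example; the class is tag-agnostic. The two chunks deliberately split the
// tagged content across the chunk boundary.
const thinkExtractor = new TagExtractor({ openingTag: '<think>', closingTag: '</think>' })
for (const chunk of ['Hello <think>step 1', ' and step 2</think> done']) {
  for (const part of thinkExtractor.processText(chunk)) {
    if (part.complete) {
      console.log('thinking:', part.tagContentExtracted) // 'step 1 and step 2'
    } else if (!part.isTagContent) {
      console.log('visible:', part.content) // 'Hello ' then ' done'
    }
  }
}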
|
||||
@ -0,0 +1,33 @@
|
||||
import { ToolSet } from 'ai'
|
||||
|
||||
import { AiRequestContext } from '../..'
|
||||
|
||||
/**
|
||||
* 解析结果类型
|
||||
* 表示从AI响应中解析出的工具使用意图
|
||||
*/
|
||||
export interface ToolUseResult {
|
||||
id: string
|
||||
toolName: string
|
||||
arguments: any
|
||||
status: 'pending' | 'invoking' | 'done' | 'error'
|
||||
}
|
||||
|
||||
export interface BaseToolUsePluginConfig {
|
||||
enabled?: boolean
|
||||
}
|
||||
|
||||
export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
|
||||
// 自定义系统提示符构建函数(可选,有默认实现)
|
||||
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
|
||||
// 自定义工具解析函数(可选,有默认实现)
|
||||
parseToolUse?: (content: string, tools: ToolSet) => { results: ToolUseResult[]; content: string }
|
||||
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
|
||||
}
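// Illustrative PromptToolUseConfig: every field is optional, so callers usually
// override a single hook. The JSON-flavoured prompt below is an invented example
// of a custom buildSystemPrompt, not the default format used by the plugin.
const exampleJsonPromptConfig: PromptToolUseConfig = {
  buildSystemPrompt: (userSystemPrompt, tools) =>
    [
      userSystemPrompt,
      'You may call these tools by emitting a <tool_use> block:',
      JSON.stringify(Object.keys(tools))
    ].join('\n')
}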
|
||||
|
||||
/**
|
||||
* 扩展的 AI 请求上下文,支持 MCP 工具存储
|
||||
*/
|
||||
export interface ToolUseRequestContext extends AiRequestContext {
|
||||
mcpTools: ToolSet
|
||||
}
|
||||
@ -0,0 +1,67 @@
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
import { google } from '@ai-sdk/google'
|
||||
import { openai } from '@ai-sdk/openai'
|
||||
|
||||
import { ProviderOptionsMap } from '../../../options/types'
|
||||
|
||||
/**
|
||||
* 从 AI SDK 的工具函数中提取参数类型,以确保类型安全。
|
||||
*/
|
||||
type OpenAISearchConfig = Parameters<typeof openai.tools.webSearchPreview>[0]
|
||||
type AnthropicSearchConfig = Parameters<typeof anthropic.tools.webSearch_20250305>[0]
|
||||
type GoogleSearchConfig = Parameters<typeof google.tools.googleSearch>[0]
|
||||
|
||||
/**
|
||||
* 插件初始化时接收的完整配置对象
|
||||
*
|
||||
* 其结构与 ProviderOptions 保持一致,方便上游统一管理配置
|
||||
*/
|
||||
export interface WebSearchPluginConfig {
|
||||
openai?: OpenAISearchConfig
|
||||
anthropic?: AnthropicSearchConfig
|
||||
xai?: ProviderOptionsMap['xai']['searchParameters']
|
||||
google?: GoogleSearchConfig
|
||||
'google-vertex'?: GoogleSearchConfig
|
||||
}
|
||||
|
||||
/**
|
||||
* 插件的默认配置
|
||||
*/
|
||||
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
|
||||
google: {},
|
||||
'google-vertex': {},
|
||||
openai: {},
|
||||
xai: {
|
||||
mode: 'on',
|
||||
returnCitations: true,
|
||||
maxSearchResults: 5,
|
||||
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
|
||||
},
|
||||
anthropic: {
|
||||
maxUses: 5
|
||||
}
|
||||
}
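// Hedged sketch of a partial override. The option names mirror the per-provider
// defaults above; providers omitted from the object are simply left to whatever
// the plugin's branch for them does when the key is absent.
const exampleSearchConfig: WebSearchPluginConfig = {
  openai: {},
  anthropic: { maxUses: 3 },
  xai: { mode: 'on', returnCitations: true, maxSearchResults: 3, sources: [{ type: 'web' }] }
}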
|
||||
|
||||
export type WebSearchToolOutputSchema = {
|
||||
// Anthropic 工具 - 手动定义
|
||||
anthropicWebSearch: Array<{
|
||||
url: string
|
||||
title: string
|
||||
pageAge: string | null
|
||||
encryptedContent: string
|
||||
type: string
|
||||
}>
|
||||
|
||||
// OpenAI 工具 - 基于实际输出
|
||||
openaiWebSearch: {
|
||||
status: 'completed' | 'failed'
|
||||
}
|
||||
|
||||
// Google 工具
|
||||
googleSearch: {
|
||||
webSearchQueries?: string[]
|
||||
groundingChunks?: Array<{
|
||||
web?: { uri: string; title: string }
|
||||
}>
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,69 @@
|
||||
/**
|
||||
* Web Search Plugin
|
||||
* 提供统一的网络搜索能力,支持多个 AI Provider
|
||||
*/
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
import { google } from '@ai-sdk/google'
|
||||
import { openai } from '@ai-sdk/openai'
|
||||
|
||||
import { createXaiOptions, mergeProviderOptions } from '../../../options'
|
||||
import { definePlugin } from '../../'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper'
|
||||
|
||||
/**
|
||||
* 网络搜索插件
|
||||
*
|
||||
* @param config - 在插件初始化时传入的静态配置
|
||||
*/
|
||||
export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) =>
|
||||
definePlugin({
|
||||
name: 'webSearch',
|
||||
enforce: 'pre',
|
||||
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
const { providerId } = context
|
||||
switch (providerId) {
|
||||
case 'openai': {
|
||||
if (config.openai) {
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'anthropic': {
|
||||
if (config.anthropic) {
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'google': {
|
||||
// case 'google-vertex':
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search = google.tools.googleSearch(config.google || {})
|
||||
break
|
||||
}
|
||||
|
||||
case 'xai': {
|
||||
if (config.xai) {
|
||||
const searchOptions = createXaiOptions({
|
||||
searchParameters: { ...config.xai, mode: 'on' }
|
||||
})
|
||||
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
})
|
||||
|
||||
// 导出类型定义供开发者使用
|
||||
export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
|
||||
|
||||
// 默认导出
|
||||
export default webSearchPlugin
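// Quick usage sketch: the factory returns a ready-to-register 'pre' plugin; the
// config keys are the same per-provider objects defined in helper.ts.
const exampleWebSearch = webSearchPlugin({ openai: {}, anthropic: { maxUses: 3 } })
console.log(exampleWebSearch.name, exampleWebSearch.enforce) // 'webSearch' 'pre'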
|
||||
32
packages/aiCore/src/core/plugins/index.ts
Normal file
@ -0,0 +1,32 @@
|
||||
// 核心类型和接口
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
||||
import type { ProviderId } from '../providers'
|
||||
import type { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
// 插件管理器
|
||||
export { PluginManager } from './manager'
|
||||
|
||||
// 工具函数
|
||||
export function createContext<T extends ProviderId>(
|
||||
providerId: T,
|
||||
modelId: string,
|
||||
originalParams: any
|
||||
): AiRequestContext {
|
||||
return {
|
||||
providerId,
|
||||
modelId,
|
||||
originalParams,
|
||||
metadata: {},
|
||||
startTime: Date.now(),
|
||||
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
// 占位
|
||||
recursiveCall: () => Promise.resolve(null)
|
||||
}
|
||||
}
|
||||
|
||||
// 插件构建器 - 便于创建插件
|
||||
export function definePlugin(plugin: AiPlugin): AiPlugin
|
||||
export function definePlugin<T extends (...args: any[]) => AiPlugin>(pluginFactory: T): T
|
||||
export function definePlugin(plugin: AiPlugin | ((...args: any[]) => AiPlugin)) {
|
||||
return plugin
|
||||
}
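// Small sketch of both definePlugin call styles supported by the overloads above.
// Both plugins here are hypothetical examples, not built-ins.
const loggingPlugin = definePlugin({
  name: 'example:logging',
  onRequestStart: async (context) => {
    console.log(`[${context.requestId}] ${context.providerId}/${context.modelId} started`)
  }
})

// Factory form: definePlugin returns the factory unchanged while preserving its type.
const prefixedErrorLogger = definePlugin((prefix: string) => ({
  name: `example:error-logger-${prefix}`,
  onError: async (error) => {
    console.warn(`[${prefix}] request failed: ${error.message}`)
  }
}))
const warnLogger = prefixedErrorLogger('warn')
console.log(loggingPlugin.name, warnLogger.name)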
|
||||
184
packages/aiCore/src/core/plugins/manager.ts
Normal file
@ -0,0 +1,184 @@
|
||||
import { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
/**
|
||||
* 插件管理器
|
||||
*/
|
||||
export class PluginManager {
|
||||
private plugins: AiPlugin[] = []
|
||||
|
||||
constructor(plugins: AiPlugin[] = []) {
|
||||
this.plugins = this.sortPlugins(plugins)
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加插件
|
||||
*/
|
||||
use(plugin: AiPlugin): this {
|
||||
this.plugins = this.sortPlugins([...this.plugins, plugin])
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除插件
|
||||
*/
|
||||
remove(pluginName: string): this {
|
||||
this.plugins = this.plugins.filter((p) => p.name !== pluginName)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 插件排序:pre -> normal -> post
|
||||
*/
|
||||
private sortPlugins(plugins: AiPlugin[]): AiPlugin[] {
|
||||
const pre: AiPlugin[] = []
|
||||
const normal: AiPlugin[] = []
|
||||
const post: AiPlugin[] = []
|
||||
|
||||
plugins.forEach((plugin) => {
|
||||
if (plugin.enforce === 'pre') {
|
||||
pre.push(plugin)
|
||||
} else if (plugin.enforce === 'post') {
|
||||
post.push(plugin)
|
||||
} else {
|
||||
normal.push(plugin)
|
||||
}
|
||||
})
|
||||
|
||||
return [...pre, ...normal, ...post]
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 First 钩子 - 返回第一个有效结果
|
||||
*/
|
||||
async executeFirst<T>(
|
||||
hookName: 'resolveModel' | 'loadTemplate',
|
||||
arg: any,
|
||||
context: AiRequestContext
|
||||
): Promise<T | null> {
|
||||
for (const plugin of this.plugins) {
|
||||
const hook = plugin[hookName]
|
||||
if (hook) {
|
||||
const result = await hook(arg, context)
|
||||
if (result !== null && result !== undefined) {
|
||||
return result as T
|
||||
}
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 Sequential 钩子 - 链式数据转换
|
||||
*/
|
||||
async executeSequential<T>(
|
||||
hookName: 'transformParams' | 'transformResult',
|
||||
initialValue: T,
|
||||
context: AiRequestContext
|
||||
): Promise<T> {
|
||||
let result = initialValue
|
||||
|
||||
for (const plugin of this.plugins) {
|
||||
const hook = plugin[hookName]
|
||||
if (hook) {
|
||||
result = await hook<T>(result, context)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 ConfigureContext 钩子 - 串行配置上下文
|
||||
*/
|
||||
async executeConfigureContext(context: AiRequestContext): Promise<void> {
|
||||
for (const plugin of this.plugins) {
|
||||
const hook = plugin.configureContext
|
||||
if (hook) {
|
||||
await hook(context)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 Parallel 钩子 - 并行副作用
|
||||
*/
|
||||
async executeParallel(
|
||||
hookName: 'onRequestStart' | 'onRequestEnd' | 'onError',
|
||||
context: AiRequestContext,
|
||||
result?: any,
|
||||
error?: Error
|
||||
): Promise<void> {
|
||||
const promises = this.plugins
|
||||
.map((plugin) => {
|
||||
const hook = plugin[hookName]
|
||||
if (!hook) return null
|
||||
|
||||
if (hookName === 'onError' && error) {
|
||||
return (hook as any)(error, context)
|
||||
} else if (hookName === 'onRequestEnd' && result !== undefined) {
|
||||
return (hook as any)(context, result)
|
||||
} else if (hookName === 'onRequestStart') {
|
||||
return (hook as any)(context)
|
||||
}
|
||||
return null
|
||||
})
|
||||
.filter(Boolean)
|
||||
|
||||
// 使用 Promise.all 而不是 allSettled,让插件错误能够抛出
|
||||
await Promise.all(promises)
|
||||
}
|
||||
|
||||
/**
|
||||
* 收集所有流转换器(返回数组,AI SDK 原生支持)
|
||||
*/
|
||||
collectStreamTransforms(params: any, context: AiRequestContext) {
|
||||
return this.plugins
|
||||
.filter((plugin) => plugin.transformStream)
|
||||
.map((plugin) => plugin.transformStream?.(params, context))
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有插件信息
|
||||
*/
|
||||
getPlugins(): AiPlugin[] {
|
||||
return [...this.plugins]
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取插件统计信息
|
||||
*/
|
||||
getStats() {
|
||||
const stats = {
|
||||
total: this.plugins.length,
|
||||
pre: 0,
|
||||
normal: 0,
|
||||
post: 0,
|
||||
hooks: {
|
||||
resolveModel: 0,
|
||||
loadTemplate: 0,
|
||||
transformParams: 0,
|
||||
transformResult: 0,
|
||||
onRequestStart: 0,
|
||||
onRequestEnd: 0,
|
||||
onError: 0,
|
||||
transformStream: 0
|
||||
}
|
||||
}
|
||||
|
||||
this.plugins.forEach((plugin) => {
|
||||
// 统计 enforce 类型
|
||||
if (plugin.enforce === 'pre') stats.pre++
|
||||
else if (plugin.enforce === 'post') stats.post++
|
||||
else stats.normal++
|
||||
|
||||
// 统计钩子数量
|
||||
Object.keys(stats.hooks).forEach((hookName) => {
|
||||
if (plugin[hookName as keyof AiPlugin]) {
|
||||
stats.hooks[hookName as keyof typeof stats.hooks]++
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
return stats
|
||||
}
|
||||
}
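// Hedged sketch of the manager's hook flow. The inline plugin, context values and
// model name are illustrative stand-ins rather than real Cherry Studio wiring;
// hook ordering follows the pre -> normal -> post sorting implemented above.
async function examplePluginRun() {
  const manager = new PluginManager([
    {
      name: 'example:defaults',
      enforce: 'pre',
      // Adds a default temperature unless the caller already set one.
      transformParams: async (params: any) => ({ temperature: 0.7, ...params })
    }
  ])

  const context = {
    providerId: 'openai',
    modelId: 'gpt-4o-mini',
    originalParams: {},
    metadata: {},
    startTime: Date.now(),
    requestId: 'example-1',
    recursiveCall: async () => null
  } as AiRequestContext

  await manager.executeConfigureContext(context)
  const params = await manager.executeSequential('transformParams', { prompt: 'hi' }, context)
  console.log(params) // { temperature: 0.7, prompt: 'hi' }
  await manager.executeParallel('onRequestStart', context)
}
void examplePluginRun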
|
||||
79
packages/aiCore/src/core/plugins/types.ts
Normal file
@ -0,0 +1,79 @@
|
||||
import type { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { type ProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
* 递归调用函数类型
|
||||
* 使用 any 是因为递归调用时参数和返回类型可能完全不同
|
||||
*/
|
||||
export type RecursiveCallFn = (newParams: any) => Promise<any>
|
||||
|
||||
/**
|
||||
* AI 请求上下文
|
||||
*/
|
||||
export interface AiRequestContext {
|
||||
providerId: ProviderId
|
||||
modelId: string
|
||||
originalParams: any
|
||||
metadata: Record<string, any>
|
||||
startTime: number
|
||||
requestId: string
|
||||
recursiveCall: RecursiveCallFn
|
||||
isRecursiveCall?: boolean
|
||||
mcpTools?: ToolSet
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
/**
|
||||
* 钩子分类
|
||||
*/
|
||||
export interface AiPlugin {
|
||||
name: string
|
||||
enforce?: 'pre' | 'post'
|
||||
|
||||
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||
resolveModel?: (
|
||||
modelId: string,
|
||||
context: AiRequestContext
|
||||
) => Promise<LanguageModel | ImageModelV2 | null> | LanguageModel | ImageModelV2 | null
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||
|
||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||
configureContext?: (context: AiRequestContext) => void | Promise<void>
|
||||
transformParams?: <T>(params: T, context: AiRequestContext) => T | Promise<T>
|
||||
transformResult?: <T>(result: T, context: AiRequestContext) => T | Promise<T>
|
||||
|
||||
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
|
||||
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
|
||||
|
||||
// 【Stream】流处理 - 直接使用 AI SDK
|
||||
transformStream?: (
|
||||
params: any,
|
||||
context: AiRequestContext
|
||||
) => <TOOLS extends ToolSet>(options?: {
|
||||
tools: TOOLS
|
||||
stopStream: () => void
|
||||
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
|
||||
|
||||
// AI SDK 原生中间件
|
||||
// aiSdkMiddlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
|
||||
/**
|
||||
* 插件管理器配置
|
||||
*/
|
||||
export interface PluginManagerConfig {
|
||||
plugins: AiPlugin[]
|
||||
context: Partial<AiRequestContext>
|
||||
}
|
||||
|
||||
/**
|
||||
* 钩子执行结果
|
||||
*/
|
||||
export interface HookResult<T = any> {
|
||||
value: T
|
||||
stop?: boolean
|
||||
}
|
||||
101
packages/aiCore/src/core/providers/HubProvider.ts
Normal file
@ -0,0 +1,101 @@
|
||||
/**
|
||||
* Hub Provider - 支持路由到多个底层provider
|
||||
*
|
||||
* 支持格式: hubId:providerId:modelId
|
||||
* 例如: aihubmix:anthropic:claude-3.5-sonnet
|
||||
*/
|
||||
|
||||
import { ProviderV2 } from '@ai-sdk/provider'
|
||||
import { customProvider } from 'ai'
|
||||
|
||||
import { globalRegistryManagement } from './RegistryManagement'
|
||||
import type { AiSdkMethodName, AiSdkModelReturn, AiSdkModelType } from './types'
|
||||
|
||||
export interface HubProviderConfig {
|
||||
/** Hub的唯一标识符 */
|
||||
hubId: string
|
||||
/** 是否启用调试日志 */
|
||||
debug?: boolean
|
||||
}
|
||||
|
||||
export class HubProviderError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public readonly hubId: string,
|
||||
public readonly providerId?: string,
|
||||
public readonly originalError?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'HubProviderError'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析Hub模型ID
|
||||
*/
|
||||
function parseHubModelId(modelId: string): { provider: string; actualModelId: string } {
|
||||
const parts = modelId.split(':')
|
||||
if (parts.length !== 2) {
|
||||
throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown')
|
||||
}
|
||||
return {
|
||||
provider: parts[0],
|
||||
actualModelId: parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Hub Provider
|
||||
*/
|
||||
export function createHubProvider(config: HubProviderConfig): ProviderV2 {
|
||||
const { hubId } = config
|
||||
|
||||
function getTargetProvider(providerId: string): ProviderV2 {
|
||||
// 从全局注册表获取provider实例
|
||||
try {
|
||||
const provider = globalRegistryManagement.getProvider(providerId)
|
||||
if (!provider) {
|
||||
throw new HubProviderError(
|
||||
`Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`,
|
||||
hubId,
|
||||
providerId
|
||||
)
|
||||
}
|
||||
return provider
|
||||
} catch (error) {
|
||||
throw new HubProviderError(
|
||||
`Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
hubId,
|
||||
providerId,
|
||||
error instanceof Error ? error : undefined
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
function resolveModel<T extends AiSdkModelType>(
|
||||
modelId: string,
|
||||
modelType: T,
|
||||
methodName: AiSdkMethodName<T>
|
||||
): AiSdkModelReturn<T> {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
const fn = targetProvider[methodName] as (id: string) => AiSdkModelReturn<T>
|
||||
|
||||
if (!fn) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
|
||||
}
|
||||
|
||||
return fn(actualModelId)
|
||||
}
|
||||
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
languageModel: (modelId: string) => resolveModel(modelId, 'text', 'languageModel'),
|
||||
textEmbeddingModel: (modelId: string) => resolveModel(modelId, 'embedding', 'textEmbeddingModel'),
|
||||
imageModel: (modelId: string) => resolveModel(modelId, 'image', 'imageModel'),
|
||||
transcriptionModel: (modelId: string) => resolveModel(modelId, 'transcription', 'transcriptionModel'),
|
||||
speechModel: (modelId: string) => resolveModel(modelId, 'speech', 'speechModel')
|
||||
}
|
||||
})
|
||||
}
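// Hedged usage sketch. The hub resolves "provider:modelId" against providers that
// were registered on globalRegistryManagement beforehand; the hub id, provider id
// and model name below are examples only.
//
//   globalRegistryManagement.registerProvider('anthropic', someAnthropicProvider)
//
const exampleHub = createHubProvider({ hubId: 'aihubmix' })
// Throws HubProviderError if 'anthropic' was never registered or lacks languageModel.
const exampleModel = exampleHub.languageModel('anthropic:claude-3.5-sonnet')
console.log(exampleModel.modelId)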
|
||||
221
packages/aiCore/src/core/providers/RegistryManagement.ts
Normal file
@ -0,0 +1,221 @@
|
||||
/**
|
||||
* Provider 注册表管理器
|
||||
* 纯粹的管理功能:存储、检索已配置好的 provider 实例
|
||||
* 基于 AI SDK 原生的 createProviderRegistry
|
||||
*/
|
||||
|
||||
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, ProviderV2 } from '@ai-sdk/provider'
|
||||
import { createProviderRegistry, type ProviderRegistryProvider } from 'ai'
|
||||
|
||||
type PROVIDERS = Record<string, ProviderV2>
|
||||
|
||||
export const DEFAULT_SEPARATOR = '|'
|
||||
|
||||
// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}`
|
||||
|
||||
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
|
||||
private providers: PROVIDERS = {}
|
||||
private aliases: Set<string> = new Set() // 记录哪些key是别名
|
||||
private separator: SEPARATOR
|
||||
private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null
|
||||
|
||||
constructor(options: { separator: SEPARATOR } = { separator: DEFAULT_SEPARATOR as SEPARATOR }) {
|
||||
this.separator = options.separator
|
||||
}
|
||||
|
||||
/**
|
||||
* 注册已配置好的 provider 实例
|
||||
*/
|
||||
registerProvider(id: string, provider: ProviderV2, aliases?: string[]): this {
|
||||
// 注册主provider
|
||||
this.providers[id] = provider
|
||||
|
||||
// 注册别名(都指向同一个provider实例)
|
||||
if (aliases) {
|
||||
aliases.forEach((alias) => {
|
||||
this.providers[alias] = provider // 直接存储引用
|
||||
this.aliases.add(alias) // 标记为别名
|
||||
})
|
||||
}
|
||||
|
||||
this.rebuildRegistry()
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取已注册的provider实例
|
||||
*/
|
||||
getProvider(id: string): ProviderV2 | undefined {
|
||||
return this.providers[id]
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册 providers
|
||||
*/
|
||||
registerProviders(providers: Record<string, ProviderV2>): this {
|
||||
Object.assign(this.providers, providers)
|
||||
this.rebuildRegistry()
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除 provider(同时清理相关别名)
|
||||
*/
|
||||
unregisterProvider(id: string): this {
|
||||
const provider = this.providers[id]
|
||||
if (!provider) return this
|
||||
|
||||
// 如果移除的是真实ID,需要清理所有指向它的别名
|
||||
if (!this.aliases.has(id)) {
|
||||
// 找到所有指向此provider的别名并删除
|
||||
const aliasesToRemove: string[] = []
|
||||
this.aliases.forEach((alias) => {
|
||||
if (this.providers[alias] === provider) {
|
||||
aliasesToRemove.push(alias)
|
||||
}
|
||||
})
|
||||
|
||||
aliasesToRemove.forEach((alias) => {
|
||||
delete this.providers[alias]
|
||||
this.aliases.delete(alias)
|
||||
})
|
||||
} else {
|
||||
// 如果移除的是别名,只删除别名记录
|
||||
this.aliases.delete(id)
|
||||
}
|
||||
|
||||
delete this.providers[id]
|
||||
this.rebuildRegistry()
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 立即重建 registry - 每次变更都重建
|
||||
*/
|
||||
private rebuildRegistry(): void {
|
||||
if (Object.keys(this.providers).length === 0) {
|
||||
this.registry = null
|
||||
return
|
||||
}
|
||||
|
||||
this.registry = createProviderRegistry<PROVIDERS, SEPARATOR>(this.providers, {
|
||||
separator: this.separator
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取语言模型 - AI SDK 原生方法
|
||||
*/
|
||||
languageModel(id: `${string}${SEPARATOR}${string}`): LanguageModelV2 {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.languageModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取文本嵌入模型 - AI SDK 原生方法
|
||||
*/
|
||||
textEmbeddingModel(id: `${string}${SEPARATOR}${string}`): EmbeddingModelV2<string> {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.textEmbeddingModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取图像模型 - AI SDK 原生方法
|
||||
*/
|
||||
imageModel(id: `${string}${SEPARATOR}${string}`): ImageModelV2 {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.imageModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取转录模型 - AI SDK 原生方法
|
||||
*/
|
||||
transcriptionModel(id: `${string}${SEPARATOR}${string}`): any {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.transcriptionModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取语音模型 - AI SDK 原生方法
|
||||
*/
|
||||
speechModel(id: `${string}${SEPARATOR}${string}`): any {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.speechModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取已注册的 provider 列表
|
||||
*/
|
||||
getRegisteredProviders(): string[] {
|
||||
return Object.keys(this.providers)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有已注册的 providers
|
||||
*/
|
||||
hasProviders(): boolean {
|
||||
return Object.keys(this.providers).length > 0
|
||||
}
|
||||
|
||||
/**
|
||||
* 清除所有 providers
|
||||
*/
|
||||
clear(): this {
|
||||
this.providers = {}
|
||||
this.aliases.clear()
|
||||
this.registry = null
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析真实的Provider ID(供getAiSdkProviderId使用)
|
||||
* 如果传入的是别名,返回真实的Provider ID
|
||||
* 如果传入的是真实ID,直接返回
|
||||
*/
|
||||
resolveProviderId(id: string): string {
|
||||
if (!this.aliases.has(id)) return id // 不是别名,直接返回
|
||||
|
||||
// 是别名,找到真实ID
|
||||
const targetProvider = this.providers[id]
|
||||
for (const [realId, provider] of Object.entries(this.providers)) {
|
||||
if (provider === targetProvider && !this.aliases.has(realId)) {
|
||||
return realId
|
||||
}
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为别名
|
||||
*/
|
||||
isAlias(id: string): boolean {
|
||||
return this.aliases.has(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有别名映射关系
|
||||
*/
|
||||
getAllAliases(): Record<string, string> {
|
||||
const result: Record<string, string> = {}
|
||||
this.aliases.forEach((alias) => {
|
||||
result[alias] = this.resolveProviderId(alias)
|
||||
})
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 全局注册表管理器实例
|
||||
* 使用 | 作为分隔符,因为 : 会和 :free 等suffix冲突
|
||||
*/
|
||||
export const globalRegistryManagement = new RegistryManagement()
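// Hedged usage sketch on a throwaway instance, so the global singleton stays
// untouched. The fake provider is a minimal stand-in rather than a real AI SDK
// provider; the point is the '|' separator in lookups and alias resolution.
const exampleRegistry = new RegistryManagement()
exampleRegistry.registerProvider(
  'openai',
  { languageModel: (id: string) => ({ modelId: id }) } as unknown as ProviderV2,
  ['openai-chat'] // alias pointing at the same instance
)
exampleRegistry.languageModel('openai|gpt-4o-mini') // resolved through createProviderRegistry
console.log(exampleRegistry.resolveProviderId('openai-chat')) // 'openai'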
|
||||
@ -0,0 +1,632 @@
|
||||
/**
|
||||
* 测试真正的 AiProviderRegistry 功能
|
||||
*/
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
// 模拟 AI SDK
|
||||
vi.mock('@ai-sdk/openai', () => ({
|
||||
createOpenAI: vi.fn(() => ({ name: 'openai-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/anthropic', () => ({
|
||||
createAnthropic: vi.fn(() => ({ name: 'anthropic-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/azure', () => ({
|
||||
createAzure: vi.fn(() => ({ name: 'azure-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/deepseek', () => ({
|
||||
createDeepSeek: vi.fn(() => ({ name: 'deepseek-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/google', () => ({
|
||||
createGoogleGenerativeAI: vi.fn(() => ({ name: 'google-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/openai-compatible', () => ({
|
||||
createOpenAICompatible: vi.fn(() => ({ name: 'openai-compatible-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/xai', () => ({
|
||||
createXai: vi.fn(() => ({ name: 'xai-mock' }))
|
||||
}))
|
||||
|
||||
import {
|
||||
cleanup,
|
||||
clearAllProviders,
|
||||
createAndRegisterProvider,
|
||||
createProvider,
|
||||
getAllProviderConfigAliases,
|
||||
getAllProviderConfigs,
|
||||
getInitializedProviders,
|
||||
getLanguageModel,
|
||||
getProviderConfig,
|
||||
getProviderConfigByAlias,
|
||||
getSupportedProviders,
|
||||
hasInitializedProviders,
|
||||
hasProviderConfig,
|
||||
hasProviderConfigByAlias,
|
||||
isProviderConfigAlias,
|
||||
ProviderInitializationError,
|
||||
providerRegistry,
|
||||
registerMultipleProviderConfigs,
|
||||
registerProvider,
|
||||
registerProviderConfig,
|
||||
resolveProviderConfigId
|
||||
} from '../registry'
|
||||
import type { ProviderConfig } from '../schemas'
|
||||
|
||||
describe('Provider Registry 功能测试', () => {
|
||||
beforeEach(() => {
|
||||
// 清理状态
|
||||
cleanup()
|
||||
})
|
||||
|
||||
describe('基础功能', () => {
|
||||
it('能够获取支持的 providers 列表', () => {
|
||||
const providers = getSupportedProviders()
|
||||
expect(Array.isArray(providers)).toBe(true)
|
||||
expect(providers.length).toBeGreaterThan(0)
|
||||
|
||||
// 检查返回的数据结构
|
||||
providers.forEach((provider) => {
|
||||
expect(provider).toHaveProperty('id')
|
||||
expect(provider).toHaveProperty('name')
|
||||
expect(typeof provider.id).toBe('string')
|
||||
expect(typeof provider.name).toBe('string')
|
||||
})
|
||||
|
||||
// 包含基础 providers
|
||||
const providerIds = providers.map((p) => p.id)
|
||||
expect(providerIds).toContain('openai')
|
||||
expect(providerIds).toContain('anthropic')
|
||||
expect(providerIds).toContain('google')
|
||||
})
|
||||
|
||||
it('能够获取已初始化的 providers', () => {
|
||||
// 初始状态下没有已初始化的 providers
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasInitializedProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('能够访问全局注册管理器', () => {
|
||||
expect(providerRegistry).toBeDefined()
|
||||
expect(typeof providerRegistry.clear).toBe('function')
|
||||
expect(typeof providerRegistry.getRegisteredProviders).toBe('function')
|
||||
expect(typeof providerRegistry.hasProviders).toBe('function')
|
||||
})
|
||||
|
||||
it('能够获取语言模型', () => {
|
||||
// 在没有注册 provider 的情况下,这个函数应该会抛出错误
|
||||
expect(() => getLanguageModel('non-existent')).toThrow('No providers registered')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider 配置注册', () => {
|
||||
it('能够注册自定义 provider 配置', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(() => ({ name: 'custom' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
expect(hasProviderConfig('custom-provider')).toBe(true)
|
||||
expect(getProviderConfig('custom-provider')).toEqual(config)
|
||||
})
|
||||
|
||||
it('能够注册带别名的 provider 配置', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'custom-provider-with-aliases',
|
||||
name: 'Custom Provider with Aliases',
|
||||
creator: vi.fn(() => ({ name: 'custom-aliased' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['alias-1', 'alias-2']
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||
expect(getProviderConfigByAlias('alias-1')).toEqual(config)
|
||||
expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases')
|
||||
})
|
||||
|
||||
it('拒绝无效的配置', () => {
|
||||
// 缺少必要字段
|
||||
const invalidConfig = {
|
||||
id: 'invalid-provider'
|
||||
// 缺少 name, creator 等
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(invalidConfig as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('能够批量注册 provider 配置', () => {
|
||||
const configs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'provider-1',
|
||||
name: 'Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'provider-1' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider-2',
|
||||
name: 'Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'provider-2' })),
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: '', // 无效配置
|
||||
name: 'Invalid Provider',
|
||||
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||
supportsImageGeneration: false
|
||||
} as any
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(configs)
|
||||
expect(successCount).toBe(2) // 只有前两个成功
|
||||
|
||||
expect(hasProviderConfig('provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('provider-2')).toBe(true)
|
||||
expect(hasProviderConfig('')).toBe(false)
|
||||
})
|
||||
|
||||
it('能够获取所有配置和别名信息', () => {
|
||||
// 注册一些配置
|
||||
registerProviderConfig({
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['test-alias']
|
||||
})
|
||||
|
||||
const allConfigs = getAllProviderConfigs()
|
||||
expect(Array.isArray(allConfigs)).toBe(true)
|
||||
expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true)
|
||||
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
expect(aliases['test-alias']).toBe('test-provider')
|
||||
expect(isProviderConfigAlias('test-alias')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider 创建和注册', () => {
|
||||
it('能够创建 provider 实例', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-create-provider',
|
||||
name: 'Test Create Provider',
|
||||
creator: vi.fn(() => ({ name: 'test-created' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 创建 provider 实例
|
||||
const provider = await createProvider('test-create-provider', { apiKey: 'test' })
|
||||
expect(provider).toBeDefined()
|
||||
expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||
})
|
||||
|
||||
it('能够注册 provider 到全局管理器', () => {
|
||||
const mockProvider = { name: 'mock-provider' }
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-register-provider',
|
||||
name: 'Test Register Provider',
|
||||
creator: vi.fn(() => mockProvider),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 注册 provider 到全局管理器
|
||||
const success = registerProvider('test-register-provider', mockProvider)
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('test-register-provider')
|
||||
expect(hasInitializedProviders()).toBe(true)
|
||||
})
|
||||
|
||||
it('能够一步完成创建和注册', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-create-and-register',
|
||||
name: 'Test Create and Register',
|
||||
creator: vi.fn(() => ({ name: 'test-both' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 一步完成创建和注册
|
||||
const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' })
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('test-create-and-register')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Registry 管理', () => {
|
||||
it('能够清理所有配置和注册的 providers', () => {
|
||||
// 注册一些配置
|
||||
registerProviderConfig({
|
||||
id: 'temp-provider',
|
||||
name: 'Temp Provider',
|
||||
creator: vi.fn(() => ({ name: 'temp' })),
|
||||
supportsImageGeneration: false
|
||||
})
|
||||
|
||||
expect(hasProviderConfig('temp-provider')).toBe(true)
|
||||
|
||||
// 清理
|
||||
cleanup()
|
||||
|
||||
expect(hasProviderConfig('temp-provider')).toBe(false)
|
||||
// 但基础配置应该重新加载
|
||||
expect(hasProviderConfig('openai')).toBe(true) // 基础 providers 会重新初始化
|
||||
})
|
||||
|
||||
it('能够单独清理已注册的 providers', () => {
|
||||
// 清理所有 providers
|
||||
clearAllProviders()
|
||||
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasInitializedProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('ProviderInitializationError 错误类工作正常', () => {
|
||||
const error = new ProviderInitializationError('Test error', 'test-provider')
|
||||
expect(error.message).toBe('Test error')
|
||||
expect(error.providerId).toBe('test-provider')
|
||||
expect(error.name).toBe('ProviderInitializationError')
|
||||
})
|
||||
})
|
||||
|
||||
describe('错误处理', () => {
|
||||
it('优雅处理空配置', () => {
|
||||
const success = registerProviderConfig(null as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('优雅处理未定义配置', () => {
|
||||
const success = registerProviderConfig(undefined as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理空字符串 ID', () => {
|
||||
const config = {
|
||||
id: '',
|
||||
name: 'Empty ID Provider',
|
||||
creator: vi.fn(() => ({ name: 'empty' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理创建不存在配置的 provider', async () => {
|
||||
await expect(createProvider('non-existent-provider', {})).rejects.toThrow(
|
||||
'ProviderConfig not found for id: non-existent-provider'
|
||||
)
|
||||
})
|
||||
|
||||
it('处理注册不存在配置的 provider', () => {
|
||||
const mockProvider = { name: 'mock' }
|
||||
const success = registerProvider('non-existent-provider', mockProvider)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理获取不存在配置的情况', () => {
|
||||
expect(getProviderConfig('non-existent')).toBeUndefined()
|
||||
expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined()
|
||||
expect(hasProviderConfig('non-existent')).toBe(false)
|
||||
expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false)
|
||||
})
|
||||
|
||||
it('处理批量注册时的部分失败', () => {
|
||||
const mixedConfigs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'valid-provider-1',
|
||||
name: 'Valid Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'valid-1' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: '', // 无效配置
|
||||
name: 'Invalid Provider',
|
||||
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||
supportsImageGeneration: false
|
||||
} as any,
|
||||
{
|
||||
id: 'valid-provider-2',
|
||||
name: 'Valid Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'valid-2' })),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(mixedConfigs)
|
||||
expect(successCount).toBe(2) // 只有两个有效配置成功
|
||||
|
||||
expect(hasProviderConfig('valid-provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('valid-provider-2')).toBe(true)
|
||||
expect(hasProviderConfig('')).toBe(false)
|
||||
})
|
||||
|
||||
it('处理动态导入失败的情况', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'import-test-provider',
|
||||
name: 'Import Test Provider',
|
||||
import: vi.fn().mockRejectedValue(new Error('Import failed')),
|
||||
creatorFunctionName: 'createTest',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
registerProviderConfig(config)
|
||||
|
||||
await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed')
|
||||
})
|
||||
})
|
||||
|
||||
describe('集成测试', () => {
|
||||
it('正确处理复杂的配置、创建、注册和清理场景', async () => {
|
||||
// 初始状态验证
|
||||
const initialConfigs = getAllProviderConfigs()
|
||||
expect(initialConfigs.length).toBeGreaterThan(0) // 有基础配置
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
|
||||
// 注册多个带别名的 provider 配置
|
||||
const configs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'integration-provider-1',
|
||||
name: 'Integration Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'integration-1' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['alias-1', 'short-name-1']
|
||||
},
|
||||
{
|
||||
id: 'integration-provider-2',
|
||||
name: 'Integration Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'integration-2' })),
|
||||
supportsImageGeneration: true,
|
||||
aliases: ['alias-2', 'short-name-2']
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(configs)
|
||||
expect(successCount).toBe(2)
|
||||
|
||||
// 验证配置注册成功
|
||||
expect(hasProviderConfig('integration-provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('integration-provider-2')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||
|
||||
// 验证别名映射
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
expect(aliases['alias-1']).toBe('integration-provider-1')
|
||||
expect(aliases['alias-2']).toBe('integration-provider-2')
|
||||
|
||||
// 创建和注册 providers
|
||||
const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' })
|
||||
const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' })
|
||||
expect(success1).toBe(true)
|
||||
expect(success2).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('integration-provider-1')
|
||||
expect(registeredProviders).toContain('integration-provider-2')
|
||||
expect(hasInitializedProviders()).toBe(true)
|
||||
|
||||
// 清理
|
||||
cleanup()
|
||||
|
||||
// 验证清理后的状态
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasProviderConfig('integration-provider-1')).toBe(false)
|
||||
expect(hasProviderConfig('integration-provider-2')).toBe(false)
|
||||
expect(getAllProviderConfigAliases()).toEqual({})
|
||||
|
||||
// 基础配置应该重新加载
|
||||
expect(hasProviderConfig('openai')).toBe(true)
|
||||
})
|
||||
|
||||
it('正确处理动态导入配置的 provider', async () => {
|
||||
const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) }
|
||||
const dynamicImportConfig: ProviderConfig = {
|
||||
id: 'dynamic-import-provider',
|
||||
name: 'Dynamic Import Provider',
|
||||
import: vi.fn().mockResolvedValue(mockModule),
|
||||
creatorFunctionName: 'createCustomProvider',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 注册配置
|
||||
const configSuccess = registerProviderConfig(dynamicImportConfig)
|
||||
expect(configSuccess).toBe(true)
|
||||
|
||||
// 创建和注册 provider
|
||||
const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' })
|
||||
expect(registerSuccess).toBe(true)
|
||||
|
||||
// 验证导入函数被调用
|
||||
expect(dynamicImportConfig.import).toHaveBeenCalled()
|
||||
expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||
|
||||
// 验证注册成功
|
||||
expect(getInitializedProviders()).toContain('dynamic-import-provider')
|
||||
})
|
||||
|
||||
it('正确处理大量配置的注册和管理', () => {
|
||||
const largeConfigList: ProviderConfig[] = []
|
||||
|
||||
// 生成50个配置
|
||||
for (let i = 0; i < 50; i++) {
|
||||
largeConfigList.push({
|
||||
id: `bulk-provider-${i}`,
|
||||
name: `Bulk Provider ${i}`,
|
||||
creator: vi.fn(() => ({ name: `bulk-${i}` })),
|
||||
supportsImageGeneration: i % 2 === 0, // 偶数支持图像生成
|
||||
aliases: [`alias-${i}`, `short-${i}`]
|
||||
})
|
||||
}
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(largeConfigList)
|
||||
expect(successCount).toBe(50)
|
||||
|
||||
// 验证所有配置都被正确注册
|
||||
const allConfigs = getAllProviderConfigs()
|
||||
expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50)
|
||||
|
||||
// 验证别名数量
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
const bulkAliases = Object.keys(aliases).filter(
|
||||
(alias) => alias.startsWith('alias-') || alias.startsWith('short-')
|
||||
)
|
||||
expect(bulkAliases.length).toBe(100) // 每个 provider 有2个别名
|
||||
|
||||
// 随机验证几个配置
|
||||
expect(hasProviderConfig('bulk-provider-0')).toBe(true)
|
||||
expect(hasProviderConfig('bulk-provider-25')).toBe(true)
|
||||
expect(hasProviderConfig('bulk-provider-49')).toBe(true)
|
||||
|
||||
// 验证别名工作正常
|
||||
expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25')
|
||||
expect(isProviderConfigAlias('short-30')).toBe(true)
|
||||
|
||||
// 清理能正确处理大量数据
|
||||
cleanup()
|
||||
const cleanupAliases = getAllProviderConfigAliases()
|
||||
expect(
|
||||
Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-'))
|
||||
).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('边界测试', () => {
|
||||
it('处理包含特殊字符的 provider IDs', () => {
|
||||
const specialCharsConfigs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'provider-with-dashes',
|
||||
name: 'Provider With Dashes',
|
||||
creator: vi.fn(() => ({ name: 'dashes' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider_with_underscores',
|
||||
name: 'Provider With Underscores',
|
||||
creator: vi.fn(() => ({ name: 'underscores' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider.with.dots',
|
||||
name: 'Provider With Dots',
|
||||
creator: vi.fn(() => ({ name: 'dots' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(specialCharsConfigs)
|
||||
expect(successCount).toBeGreaterThan(0) // 至少有一些成功
|
||||
|
||||
// 验证支持的特殊字符格式
|
||||
if (hasProviderConfig('provider-with-dashes')) {
|
||||
expect(getProviderConfig('provider-with-dashes')).toBeDefined()
|
||||
}
|
||||
if (hasProviderConfig('provider_with_underscores')) {
|
||||
expect(getProviderConfig('provider_with_underscores')).toBeDefined()
|
||||
}
|
||||
})
|
||||
|
||||
it('处理空的批量注册', () => {
|
||||
const successCount = registerMultipleProviderConfigs([])
|
||||
expect(successCount).toBe(0)
|
||||
|
||||
// 确保没有额外的配置被添加
|
||||
const configsBefore = getAllProviderConfigs().length
|
||||
expect(configsBefore).toBeGreaterThan(0) // 应该有基础配置
|
||||
})
|
||||
|
||||
it('处理重复的配置注册', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'duplicate-test-provider',
|
||||
name: 'Duplicate Test Provider',
|
||||
creator: vi.fn(() => ({ name: 'duplicate' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 第一次注册成功
|
||||
expect(registerProviderConfig(config)).toBe(true)
|
||||
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||
|
||||
// 重复注册相同的配置(允许覆盖)
|
||||
const updatedConfig: ProviderConfig = {
|
||||
...config,
|
||||
name: 'Updated Duplicate Test Provider'
|
||||
}
|
||||
expect(registerProviderConfig(updatedConfig)).toBe(true)
|
||||
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||
|
||||
// 验证配置被更新
|
||||
const retrievedConfig = getProviderConfig('duplicate-test-provider')
|
||||
expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider')
|
||||
})
|
||||
|
||||
it('处理极长的 ID 和名称', () => {
|
||||
const longId = 'very-long-provider-id-' + 'x'.repeat(100)
|
||||
const longName = 'Very Long Provider Name ' + 'Y'.repeat(100)
|
||||
|
||||
const config: ProviderConfig = {
|
||||
id: longId,
|
||||
name: longName,
|
||||
creator: vi.fn(() => ({ name: 'long-test' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
expect(hasProviderConfig(longId)).toBe(true)
|
||||
|
||||
const retrievedConfig = getProviderConfig(longId)
|
||||
expect(retrievedConfig?.name).toBe(longName)
|
||||
})
|
||||
|
||||
it('处理大量别名的配置', () => {
|
||||
const manyAliases = Array.from({ length: 50 }, (_, i) => `alias-${i}`)
|
||||
|
||||
const config: ProviderConfig = {
|
||||
id: 'provider-with-many-aliases',
|
||||
name: 'Provider With Many Aliases',
|
||||
creator: vi.fn(() => ({ name: 'many-aliases' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: manyAliases
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证所有别名都能正确解析
|
||||
manyAliases.forEach((alias) => {
|
||||
expect(hasProviderConfigByAlias(alias)).toBe(true)
|
||||
expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases')
|
||||
expect(isProviderConfigAlias(alias)).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
264
packages/aiCore/src/core/providers/__tests__/schemas.test.ts
Normal file
@ -0,0 +1,264 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
type BaseProviderId,
|
||||
baseProviderIds,
|
||||
baseProviderIdSchema,
|
||||
baseProviders,
|
||||
type CustomProviderId,
|
||||
customProviderIdSchema,
|
||||
providerConfigSchema,
|
||||
type ProviderId,
|
||||
providerIdSchema
|
||||
} from '../schemas'
|
||||
|
||||
describe('Provider Schemas', () => {
|
||||
describe('baseProviders', () => {
|
||||
it('包含所有预期的基础 providers', () => {
|
||||
expect(baseProviders).toBeDefined()
|
||||
expect(Array.isArray(baseProviders)).toBe(true)
|
||||
expect(baseProviders.length).toBeGreaterThan(0)
|
||||
|
||||
const expectedIds = [
|
||||
'openai',
|
||||
'openai-responses',
|
||||
'openai-compatible',
|
||||
'anthropic',
|
||||
'google',
|
||||
'xai',
|
||||
'azure',
|
||||
'deepseek'
|
||||
]
|
||||
const actualIds = baseProviders.map((p) => p.id)
|
||||
expectedIds.forEach((id) => {
|
||||
expect(actualIds).toContain(id)
|
||||
})
|
||||
})
|
||||
|
||||
it('每个基础 provider 有必要的属性', () => {
|
||||
baseProviders.forEach((provider) => {
|
||||
expect(provider).toHaveProperty('id')
|
||||
expect(provider).toHaveProperty('name')
|
||||
expect(provider).toHaveProperty('creator')
|
||||
expect(provider).toHaveProperty('supportsImageGeneration')
|
||||
|
||||
expect(typeof provider.id).toBe('string')
|
||||
expect(typeof provider.name).toBe('string')
|
||||
expect(typeof provider.creator).toBe('function')
|
||||
expect(typeof provider.supportsImageGeneration).toBe('boolean')
|
||||
})
|
||||
})
|
||||
|
||||
it('provider ID 是唯一的', () => {
|
||||
const ids = baseProviders.map((p) => p.id)
|
||||
const uniqueIds = [...new Set(ids)]
|
||||
expect(ids).toEqual(uniqueIds)
|
||||
})
|
||||
})
|
||||
|
||||
describe('baseProviderIds', () => {
|
||||
it('正确提取所有基础 provider IDs', () => {
|
||||
expect(baseProviderIds).toBeDefined()
|
||||
expect(Array.isArray(baseProviderIds)).toBe(true)
|
||||
expect(baseProviderIds.length).toBe(baseProviders.length)
|
||||
|
||||
baseProviders.forEach((provider) => {
|
||||
expect(baseProviderIds).toContain(provider.id)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('baseProviderIdSchema', () => {
|
||||
it('验证有效的基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝无效的基础 provider IDs', () => {
|
||||
const invalidIds = ['invalid', 'not-exists', '']
|
||||
invalidIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('customProviderIdSchema', () => {
|
||||
it('接受有效的自定义 provider IDs', () => {
|
||||
const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
|
||||
validIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝与基础 provider IDs 冲突的 IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝空字符串', () => {
|
||||
expect(customProviderIdSchema.safeParse('').success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('providerIdSchema', () => {
|
||||
it('接受基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('接受有效的自定义 provider IDs', () => {
|
||||
const validCustomIds = ['custom-provider', 'my-ai-service']
|
||||
validCustomIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝无效的 IDs', () => {
|
||||
const invalidIds = ['', undefined, null, 123]
|
||||
invalidIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('providerConfigSchema', () => {
|
||||
it('验证带有 creator 的有效配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
})
|
||||
|
||||
it('验证带有 import 配置的有效配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
import: vi.fn(),
|
||||
creatorFunctionName: 'createCustom',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
})
|
||||
|
||||
it('拒绝既没有 creator 也没有 import 配置的配置', () => {
|
||||
const invalidConfig = {
|
||||
id: 'invalid',
|
||||
name: 'Invalid Provider',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
|
||||
it('为 supportsImageGeneration 设置默认值', () => {
|
||||
const config = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
creator: vi.fn()
|
||||
}
|
||||
const result = providerConfigSchema.safeParse(config)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.data.supportsImageGeneration).toBe(false)
|
||||
}
|
||||
})
|
||||
|
||||
it('拒绝使用基础 provider ID 的配置', () => {
|
||||
const invalidConfig = {
|
||||
id: 'openai', // 基础 provider ID
|
||||
name: 'Should Fail',
|
||||
creator: vi.fn()
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
|
||||
it('拒绝缺少必需字段的配置', () => {
|
||||
const invalidConfigs = [
|
||||
{ name: 'Missing ID', creator: vi.fn() },
|
||||
{ id: 'missing-name', creator: vi.fn() },
|
||||
{ id: '', name: 'Empty ID', creator: vi.fn() },
|
||||
{ id: 'valid-custom', name: '', creator: vi.fn() }
|
||||
]
|
||||
|
||||
invalidConfigs.forEach((config) => {
|
||||
expect(providerConfigSchema.safeParse(config).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Schema 验证功能', () => {
|
||||
it('baseProviderIdSchema 正确验证基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false)
|
||||
})
|
||||
|
||||
it('customProviderIdSchema 正确验证自定义 provider IDs', () => {
|
||||
const customIds = ['custom-provider', 'my-service', 'company-llm']
|
||||
customIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
// 拒绝基础 provider IDs
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('providerIdSchema 接受基础和自定义 provider IDs', () => {
|
||||
// 基础 IDs
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
// 自定义 IDs
|
||||
const customIds = ['custom-provider', 'my-service']
|
||||
customIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('providerConfigSchema 验证完整的 provider 配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
|
||||
const invalidConfig = {
|
||||
id: 'openai', // 不允许基础 provider ID
|
||||
name: 'OpenAI',
|
||||
creator: vi.fn()
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('类型推导', () => {
|
||||
it('BaseProviderId 类型正确', () => {
|
||||
const id: BaseProviderId = 'openai'
|
||||
expect(baseProviderIds).toContain(id)
|
||||
})
|
||||
|
||||
it('CustomProviderId 类型是字符串', () => {
|
||||
const id: CustomProviderId = 'custom-provider'
|
||||
expect(typeof id).toBe('string')
|
||||
})
|
||||
|
||||
it('ProviderId 类型支持基础和自定义 IDs', () => {
|
||||
const baseId: ProviderId = 'openai'
|
||||
const customId: ProviderId = 'custom-provider'
|
||||
expect(typeof baseId).toBe('string')
|
||||
expect(typeof customId).toBe('string')
|
||||
})
|
||||
})
|
||||
})
|
||||
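As these cases suggest, providerConfigSchema can be used directly to gate user-supplied configs before they reach the registry. A minimal sketch, assuming a relative import within this package and a stand-in creator:

import { providerConfigSchema } from '../schemas'

const parsed = providerConfigSchema.safeParse({
  id: 'my-provider',            // must not collide with a base provider id
  name: 'My Provider',
  creator: () => ({}) as any    // stand-in creator, as in the tests above
})

if (parsed.success) {
  console.log(parsed.data.supportsImageGeneration) // false – filled in by the schema default
} else {
  console.error(parsed.error.issues)
}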
291
packages/aiCore/src/core/providers/factory.ts
Normal file
291
packages/aiCore/src/core/providers/factory.ts
Normal file
@ -0,0 +1,291 @@
|
||||
/**
|
||||
* AI Provider 配置工厂
|
||||
* 提供类型安全的 Provider 配置构建器
|
||||
*/
|
||||
|
||||
import type { ProviderId, ProviderSettingsMap } from './types'
|
||||
|
||||
/**
|
||||
* 通用配置基础类型,包含所有 Provider 共有的属性
|
||||
*/
|
||||
export interface BaseProviderConfig {
|
||||
apiKey?: string
|
||||
baseURL?: string
|
||||
timeout?: number
|
||||
headers?: Record<string, string>
|
||||
fetch?: typeof globalThis.fetch
|
||||
}
|
||||
|
||||
/**
|
||||
* 完整的配置类型,结合基础配置、AI SDK 配置和特定 Provider 配置
|
||||
*/
|
||||
type CompleteProviderConfig<T extends ProviderId> = BaseProviderConfig & Partial<ProviderSettingsMap[T]>
|
||||
|
||||
type ConfigHandler<T extends ProviderId> = (
|
||||
builder: ProviderConfigBuilder<T>,
|
||||
provider: CompleteProviderConfig<T>
|
||||
) => void
|
||||
|
||||
const configHandlers: {
|
||||
[K in ProviderId]?: ConfigHandler<K>
|
||||
} = {
|
||||
azure: (builder, provider) => {
|
||||
const azureBuilder = builder as ProviderConfigBuilder<'azure'>
|
||||
const azureProvider = provider as CompleteProviderConfig<'azure'>
|
||||
azureBuilder.withAzureConfig({
|
||||
apiVersion: azureProvider.apiVersion,
|
||||
resourceName: azureProvider.resourceName
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
|
||||
private config: CompleteProviderConfig<T> = {} as CompleteProviderConfig<T>
|
||||
|
||||
constructor(private providerId: T) {}
|
||||
|
||||
/**
|
||||
* 设置 API Key
|
||||
*/
|
||||
withApiKey(apiKey: string): this
|
||||
withApiKey(apiKey: string, options: T extends 'openai' ? { organization?: string; project?: string } : never): this
|
||||
withApiKey(apiKey: string, options?: any): this {
|
||||
this.config.apiKey = apiKey
|
||||
|
||||
// 类型安全的 OpenAI 特定配置
|
||||
if (this.providerId === 'openai' && options) {
|
||||
const openaiConfig = this.config as CompleteProviderConfig<'openai'>
|
||||
if (options.organization) {
|
||||
openaiConfig.organization = options.organization
|
||||
}
|
||||
if (options.project) {
|
||||
openaiConfig.project = options.project
|
||||
}
|
||||
}
|
||||
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置基础 URL
|
||||
*/
|
||||
withBaseURL(baseURL: string) {
|
||||
this.config.baseURL = baseURL
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置请求配置
|
||||
*/
|
||||
withRequestConfig(options: { headers?: Record<string, string>; fetch?: typeof fetch }): this {
|
||||
if (options.headers) {
|
||||
this.config.headers = { ...this.config.headers, ...options.headers }
|
||||
}
|
||||
if (options.fetch) {
|
||||
this.config.fetch = options.fetch
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Azure OpenAI 特定配置
|
||||
*/
|
||||
withAzureConfig(options: { apiVersion?: string; resourceName?: string }): T extends 'azure' ? this : never
|
||||
withAzureConfig(options: any): any {
|
||||
if (this.providerId === 'azure') {
|
||||
const azureConfig = this.config as CompleteProviderConfig<'azure'>
|
||||
if (options.apiVersion) {
|
||||
azureConfig.apiVersion = options.apiVersion
|
||||
}
|
||||
if (options.resourceName) {
|
||||
azureConfig.resourceName = options.resourceName
|
||||
}
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置自定义参数
|
||||
*/
|
||||
withCustomParams(params: Record<string, any>) {
|
||||
Object.assign(this.config, params)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建最终配置
|
||||
*/
|
||||
build(): ProviderSettingsMap[T] {
|
||||
return this.config as ProviderSettingsMap[T]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider 配置工厂
|
||||
* 提供便捷的配置创建方法
|
||||
*/
|
||||
export class ProviderConfigFactory {
|
||||
/**
|
||||
* 创建配置构建器
|
||||
*/
|
||||
static builder<T extends ProviderId>(providerId: T): ProviderConfigBuilder<T> {
|
||||
return new ProviderConfigBuilder(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 从通用Provider对象创建配置 - 使用更优雅的处理器模式
|
||||
*/
|
||||
static fromProvider<T extends ProviderId>(
|
||||
providerId: T,
|
||||
provider: CompleteProviderConfig<T>,
|
||||
options?: {
|
||||
headers?: Record<string, string>
|
||||
[key: string]: any
|
||||
}
|
||||
): ProviderSettingsMap[T] {
|
||||
const builder = new ProviderConfigBuilder<T>(providerId)
|
||||
|
||||
// 设置基本配置
|
||||
if (provider.apiKey) {
|
||||
builder.withApiKey(provider.apiKey)
|
||||
}
|
||||
|
||||
if (provider.baseURL) {
|
||||
builder.withBaseURL(provider.baseURL)
|
||||
}
|
||||
|
||||
// 设置请求配置
|
||||
if (options?.headers) {
|
||||
builder.withRequestConfig({
|
||||
headers: options.headers
|
||||
})
|
||||
}
|
||||
|
||||
// 使用配置处理器模式 - 更加优雅和可扩展
|
||||
const handler = configHandlers[providerId]
|
||||
if (handler) {
|
||||
handler(builder, provider)
|
||||
}
|
||||
|
||||
// 添加其他自定义参数
|
||||
if (options) {
|
||||
const customOptions = { ...options }
|
||||
delete customOptions.headers // 已经处理过了
|
||||
if (Object.keys(customOptions).length > 0) {
|
||||
builder.withCustomParams(customOptions)
|
||||
}
|
||||
}
|
||||
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 OpenAI 配置
|
||||
*/
|
||||
static createOpenAI(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
organization?: string
|
||||
project?: string
|
||||
}
|
||||
) {
|
||||
const builder = this.builder('openai')
|
||||
|
||||
// 使用类型安全的重载
|
||||
if (options?.organization || options?.project) {
|
||||
builder.withApiKey(apiKey, {
|
||||
organization: options.organization,
|
||||
project: options.project
|
||||
})
|
||||
} else {
|
||||
builder.withApiKey(apiKey)
|
||||
}
|
||||
|
||||
return builder.withBaseURL(options?.baseURL || 'https://api.openai.com').build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Anthropic 配置
|
||||
*/
|
||||
static createAnthropic(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('anthropic')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options?.baseURL || 'https://api.anthropic.com')
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Azure OpenAI 配置
|
||||
*/
|
||||
static createAzureOpenAI(
|
||||
apiKey: string,
|
||||
options: {
|
||||
baseURL: string
|
||||
apiVersion?: string
|
||||
resourceName?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('azure')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options.baseURL)
|
||||
.withAzureConfig({
|
||||
apiVersion: options.apiVersion,
|
||||
resourceName: options.resourceName
|
||||
})
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Google 配置
|
||||
*/
|
||||
static createGoogle(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
projectId?: string
|
||||
location?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('google')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options?.baseURL || 'https://generativelanguage.googleapis.com')
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Vertex AI 配置
|
||||
*/
|
||||
static createVertexAI() {
|
||||
// credentials: {
|
||||
// clientEmail: string
|
||||
// privateKey: string
|
||||
// },
|
||||
// options?: {
|
||||
// project?: string
|
||||
// location?: string
|
||||
// }
|
||||
// return this.builder('google-vertex')
|
||||
// .withGoogleCredentials(credentials)
|
||||
// .withGoogleVertexConfig({
|
||||
// project: options?.project,
|
||||
// location: options?.location
|
||||
// })
|
||||
// .build()
|
||||
}
|
||||
|
||||
static createOpenAICompatible(baseURL: string, apiKey: string) {
|
||||
return this.builder('openai-compatible').withBaseURL(baseURL).withApiKey(apiKey).build()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 便捷的配置创建函数
|
||||
*/
|
||||
export const createProviderConfig = ProviderConfigFactory.fromProvider
|
||||
export const providerConfigBuilder = ProviderConfigFactory.builder
|
||||
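A typical use of this factory is to assemble typed provider settings before handing them to createProvider / createAndRegisterProvider in registry.ts. A brief sketch; the API keys, endpoints, and header values are placeholders:

import { ProviderConfigFactory } from './factory'

// Fluent builder form
const openaiSettings = ProviderConfigFactory.builder('openai')
  .withApiKey('sk-placeholder', { organization: 'org_placeholder' })
  .withBaseURL('https://api.openai.com')
  .withRequestConfig({ headers: { 'X-Trace-Id': 'demo' } })
  .build()

// Shorthand helper form
const azureSettings = ProviderConfigFactory.createAzureOpenAI('azure-key-placeholder', {
  baseURL: 'https://my-resource.openai.azure.com', // placeholder resource endpoint
  apiVersion: '2024-02-01'                          // placeholder API version
})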
83
packages/aiCore/src/core/providers/index.ts
Normal file
83
packages/aiCore/src/core/providers/index.ts
Normal file
@ -0,0 +1,83 @@
|
||||
/**
|
||||
* Providers 模块统一导出 - 独立Provider包
|
||||
*/
|
||||
|
||||
// ==================== 核心管理器 ====================
|
||||
|
||||
// Provider 注册表管理器
|
||||
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'
|
||||
|
||||
// Provider 核心功能
|
||||
export {
|
||||
// 状态管理
|
||||
cleanup,
|
||||
clearAllProviders,
|
||||
createAndRegisterProvider,
|
||||
createProvider,
|
||||
getAllProviderConfigAliases,
|
||||
getAllProviderConfigs,
|
||||
getImageModel,
|
||||
// 工具函数
|
||||
getInitializedProviders,
|
||||
getLanguageModel,
|
||||
getProviderConfig,
|
||||
getProviderConfigByAlias,
|
||||
getSupportedProviders,
|
||||
getTextEmbeddingModel,
|
||||
hasInitializedProviders,
|
||||
// 工具函数
|
||||
hasProviderConfig,
|
||||
// 别名支持
|
||||
hasProviderConfigByAlias,
|
||||
isProviderConfigAlias,
|
||||
// 错误类型
|
||||
ProviderInitializationError,
|
||||
// 全局访问
|
||||
providerRegistry,
|
||||
registerMultipleProviderConfigs,
|
||||
registerProvider,
|
||||
// 统一Provider系统
|
||||
registerProviderConfig,
|
||||
resolveProviderConfigId
|
||||
} from './registry'
|
||||
|
||||
// ==================== 基础数据和类型 ====================
|
||||
|
||||
// 基础Provider数据源
|
||||
export { baseProviderIds, baseProviders } from './schemas'
|
||||
|
||||
// 类型定义和Schema
|
||||
export type {
|
||||
BaseProviderId,
|
||||
CustomProviderId,
|
||||
DynamicProviderRegistration,
|
||||
ProviderConfig,
|
||||
ProviderId
|
||||
} from './schemas' // 从 schemas 导出的类型
|
||||
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出
|
||||
export type {
|
||||
DynamicProviderRegistry,
|
||||
ExtensibleProviderSettingsMap,
|
||||
ProviderError,
|
||||
ProviderSettingsMap,
|
||||
ProviderTypeRegistrar
|
||||
} from './types'
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
// Provider配置工厂
|
||||
export {
|
||||
type BaseProviderConfig,
|
||||
createProviderConfig,
|
||||
ProviderConfigBuilder,
|
||||
providerConfigBuilder,
|
||||
ProviderConfigFactory
|
||||
} from './factory'
|
||||
|
||||
// 工具函数
|
||||
export { formatPrivateKey } from './utils'
|
||||
|
||||
// ==================== 扩展功能 ====================
|
||||
|
||||
// Hub Provider 功能
|
||||
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'
|
||||
320
packages/aiCore/src/core/providers/registry.ts
Normal file
320
packages/aiCore/src/core/providers/registry.ts
Normal file
@ -0,0 +1,320 @@
|
||||
/**
|
||||
* Provider 初始化器
|
||||
* 负责根据配置创建 providers 并注册到全局管理器
|
||||
* 集成了来自 ModelCreator 的特殊处理逻辑
|
||||
*/
|
||||
|
||||
import { customProvider } from 'ai'
|
||||
|
||||
import { globalRegistryManagement } from './RegistryManagement'
|
||||
import { baseProviders, type ProviderConfig } from './schemas'
|
||||
|
||||
/**
|
||||
* Provider 初始化错误类型
|
||||
*/
|
||||
class ProviderInitializationError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ProviderInitializationError'
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 全局管理器导出 ====================
|
||||
|
||||
export { globalRegistryManagement as providerRegistry }
|
||||
|
||||
// ==================== 便捷访问方法 ====================
|
||||
|
||||
export const getLanguageModel = (id: string) => globalRegistryManagement.languageModel(id as any)
|
||||
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.textEmbeddingModel(id as any)
|
||||
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
/**
|
||||
* 获取支持的 Providers 列表
|
||||
*/
|
||||
export function getSupportedProviders(): Array<{
|
||||
id: string
|
||||
name: string
|
||||
}> {
|
||||
return baseProviders.map((provider) => ({
|
||||
id: provider.id,
|
||||
name: provider.name
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有已初始化的 providers
|
||||
*/
|
||||
export function getInitializedProviders(): string[] {
|
||||
return globalRegistryManagement.getRegisteredProviders()
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有任何已初始化的 providers
|
||||
*/
|
||||
export function hasInitializedProviders(): boolean {
|
||||
return globalRegistryManagement.hasProviders()
|
||||
}
|
||||
|
||||
// ==================== 统一Provider配置系统 ====================
|
||||
|
||||
// 全局Provider配置存储
|
||||
const providerConfigs = new Map<string, ProviderConfig>()
|
||||
// 全局ProviderConfig别名映射 - 借鉴RegistryManagement模式
|
||||
const providerConfigAliases = new Map<string, string>() // alias -> realId
|
||||
|
||||
/**
|
||||
* 初始化内置配置 - 将baseProviders转换为统一格式
|
||||
*/
|
||||
function initializeBuiltInConfigs(): void {
|
||||
baseProviders.forEach((provider) => {
|
||||
const config: ProviderConfig = {
|
||||
id: provider.id,
|
||||
name: provider.name,
|
||||
creator: provider.creator as any, // 类型转换以兼容多种creator签名
|
||||
supportsImageGeneration: provider.supportsImageGeneration || false
|
||||
}
|
||||
providerConfigs.set(provider.id, config)
|
||||
})
|
||||
}
|
||||
|
||||
// 启动时自动注册内置配置
|
||||
initializeBuiltInConfigs()
|
||||
|
||||
/**
|
||||
* 步骤1: 注册Provider配置 - 仅存储配置,不执行创建
|
||||
*/
|
||||
export function registerProviderConfig(config: ProviderConfig): boolean {
|
||||
try {
|
||||
// 验证配置
|
||||
if (!config || !config.id || !config.name) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否与已有配置冲突(包括内置配置)
|
||||
if (providerConfigs.has(config.id)) {
|
||||
console.warn(`ProviderConfig "${config.id}" already exists, will override`)
|
||||
}
|
||||
|
||||
// 存储配置(内置和用户配置统一处理)
|
||||
providerConfigs.set(config.id, config)
|
||||
|
||||
// 处理别名
|
||||
if (config.aliases && config.aliases.length > 0) {
|
||||
config.aliases.forEach((alias) => {
|
||||
if (providerConfigAliases.has(alias)) {
|
||||
console.warn(`ProviderConfig alias "${alias}" already exists, will override`)
|
||||
}
|
||||
providerConfigAliases.set(alias, config.id)
|
||||
})
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error(`Failed to register ProviderConfig:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 步骤2: 创建Provider - 根据配置执行实际创建
|
||||
*/
|
||||
export async function createProvider(providerId: string, options: any): Promise<any> {
|
||||
// 支持通过别名查找配置
|
||||
const config = getProviderConfigByAlias(providerId)
|
||||
|
||||
if (!config) {
|
||||
throw new Error(`ProviderConfig not found for id: ${providerId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
let creator: (options: any) => any
|
||||
|
||||
if (config.creator) {
|
||||
// 方式1: 直接执行 creator
|
||||
creator = config.creator
|
||||
} else if (config.import && config.creatorFunctionName) {
|
||||
// 方式2: 动态导入并执行
|
||||
const module = await config.import()
|
||||
creator = (module as any)[config.creatorFunctionName]
|
||||
|
||||
if (!creator || typeof creator !== 'function') {
|
||||
throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`)
|
||||
}
|
||||
} else {
|
||||
throw new Error('No valid creator method provided in ProviderConfig')
|
||||
}
|
||||
|
||||
// 使用真实配置创建provider实例
|
||||
return creator(options)
|
||||
} catch (error) {
|
||||
console.error(`Failed to create provider "${providerId}":`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 步骤3: 注册Provider到全局管理器
|
||||
*/
|
||||
export function registerProvider(providerId: string, provider: any): boolean {
|
||||
try {
|
||||
const config = providerConfigs.get(providerId)
|
||||
if (!config) {
|
||||
console.error(`ProviderConfig not found for id: ${providerId}`)
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取aliases配置
|
||||
const aliases = config.aliases
|
||||
|
||||
// 处理特殊provider逻辑
|
||||
if (providerId === 'openai') {
|
||||
// 注册默认 openai
|
||||
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||
|
||||
// 创建并注册 openai-chat 变体
|
||||
const openaiChatProvider = customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider)
|
||||
} else if (providerId === 'azure') {
|
||||
globalRegistryManagement.registerProvider(`${providerId}-chat`, provider, aliases)
|
||||
// 跟上面相反,creator产出的默认会调用chat
|
||||
const azureResponsesProvider = customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.responses(modelId)
|
||||
}
|
||||
})
|
||||
globalRegistryManagement.registerProvider(providerId, azureResponsesProvider)
|
||||
} else {
|
||||
// 其他provider直接注册
|
||||
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error(`Failed to register provider "${providerId}" to global registry:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 便捷函数: 一次性完成创建+注册
|
||||
*/
|
||||
export async function createAndRegisterProvider(providerId: string, options: any): Promise<boolean> {
|
||||
try {
|
||||
// 步骤2: 创建provider
|
||||
const provider = await createProvider(providerId, options)
|
||||
|
||||
// 步骤3: 注册到全局管理器
|
||||
return registerProvider(providerId, provider)
|
||||
} catch (error) {
|
||||
console.error(`Failed to create and register provider "${providerId}":`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册Provider配置
|
||||
*/
|
||||
export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number {
|
||||
let successCount = 0
|
||||
configs.forEach((config) => {
|
||||
if (registerProviderConfig(config)) {
|
||||
successCount++
|
||||
}
|
||||
})
|
||||
return successCount
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有对应的Provider配置
|
||||
*/
|
||||
export function hasProviderConfig(providerId: string): boolean {
|
||||
return providerConfigs.has(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过别名或ID检查是否有对应的Provider配置
|
||||
*/
|
||||
export function hasProviderConfigByAlias(aliasOrId: string): boolean {
|
||||
const realId = resolveProviderConfigId(aliasOrId)
|
||||
return providerConfigs.has(realId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有Provider配置
|
||||
*/
|
||||
export function getAllProviderConfigs(): ProviderConfig[] {
|
||||
return Array.from(providerConfigs.values())
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据ID获取Provider配置
|
||||
*/
|
||||
export function getProviderConfig(providerId: string): ProviderConfig | undefined {
|
||||
return providerConfigs.get(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过别名或ID获取Provider配置
|
||||
*/
|
||||
export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined {
|
||||
// 先检查是否为别名,如果是则解析为真实ID
|
||||
const realId = providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||
return providerConfigs.get(realId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析真实的ProviderConfig ID(去别名化)
|
||||
*/
|
||||
export function resolveProviderConfigId(aliasOrId: string): string {
|
||||
return providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为ProviderConfig别名
|
||||
*/
|
||||
export function isProviderConfigAlias(id: string): boolean {
|
||||
return providerConfigAliases.has(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有ProviderConfig别名映射关系
|
||||
*/
|
||||
export function getAllProviderConfigAliases(): Record<string, string> {
|
||||
const result: Record<string, string> = {}
|
||||
providerConfigAliases.forEach((realId, alias) => {
|
||||
result[alias] = realId
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理所有Provider配置和已注册的providers
|
||||
*/
|
||||
export function cleanup(): void {
|
||||
providerConfigs.clear()
|
||||
providerConfigAliases.clear() // 清理别名映射
|
||||
globalRegistryManagement.clear()
|
||||
// 重新初始化内置配置
|
||||
initializeBuiltInConfigs()
|
||||
}
|
||||
|
||||
export function clearAllProviders(): void {
|
||||
globalRegistryManagement.clear()
|
||||
}
|
||||
|
||||
// ==================== 导出错误类型 ====================
|
||||
|
||||
export { ProviderInitializationError }
|
||||
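The intended call order for the functions above is: register a config once (built-ins are registered at module load), create and register a provider instance, then resolve models through the registry. A condensed sketch; the model id is a placeholder and the 'providerId|modelId' form with the '|' separator follows the tests elsewhere in this commit:

import { createAndRegisterProvider, getLanguageModel, hasProviderConfig } from './registry'

async function bootstrapOpenAI(apiKey: string) {
  // built-in configs (including 'openai') are registered when this module loads
  if (!hasProviderConfig('openai')) return null

  const ok = await createAndRegisterProvider('openai', { apiKey })
  if (!ok) return null

  return getLanguageModel('openai|gpt-4o') // 'gpt-4o' is a placeholder model id
}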
178
packages/aiCore/src/core/providers/schemas.ts
Normal file
178
packages/aiCore/src/core/providers/schemas.ts
Normal file
@ -0,0 +1,178 @@
|
||||
/**
|
||||
* Provider Config 定义
|
||||
*/
|
||||
|
||||
import { createAnthropic } from '@ai-sdk/anthropic'
|
||||
import { createAzure } from '@ai-sdk/azure'
|
||||
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
||||
import { createDeepSeek } from '@ai-sdk/deepseek'
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
||||
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
||||
import { createXai } from '@ai-sdk/xai'
|
||||
import { customProvider, type Provider } from 'ai'
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
* 基础 Provider IDs
|
||||
*/
|
||||
export const baseProviderIds = [
|
||||
'openai',
|
||||
'openai-chat',
|
||||
'openai-compatible',
|
||||
'anthropic',
|
||||
'google',
|
||||
'xai',
|
||||
'azure',
|
||||
'azure-responses',
|
||||
'deepseek'
|
||||
] as const
|
||||
|
||||
/**
|
||||
* 基础 Provider ID Schema
|
||||
*/
|
||||
export const baseProviderIdSchema = z.enum(baseProviderIds)
|
||||
|
||||
/**
|
||||
* 基础 Provider ID
|
||||
*/
|
||||
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
|
||||
|
||||
export const baseProviderSchema = z.object({
|
||||
id: baseProviderIdSchema,
|
||||
name: z.string(),
|
||||
creator: z.function().args(z.any()).returns(z.any()) as z.ZodType<(options: any) => Provider>,
|
||||
supportsImageGeneration: z.boolean()
|
||||
})
|
||||
|
||||
export type BaseProvider = z.infer<typeof baseProviderSchema>
|
||||
|
||||
/**
|
||||
* 基础 Providers 定义
|
||||
* 作为唯一数据源,避免重复维护
|
||||
*/
|
||||
export const baseProviders = [
|
||||
{
|
||||
id: 'openai',
|
||||
name: 'OpenAI',
|
||||
creator: createOpenAI,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'openai-chat',
|
||||
name: 'OpenAI Chat',
|
||||
creator: (options: OpenAIProviderSettings) => {
|
||||
const provider = createOpenAI(options)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'openai-compatible',
|
||||
name: 'OpenAI Compatible',
|
||||
creator: createOpenAICompatible,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'anthropic',
|
||||
name: 'Anthropic',
|
||||
creator: createAnthropic,
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'google',
|
||||
name: 'Google Generative AI',
|
||||
creator: createGoogleGenerativeAI,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'xai',
|
||||
name: 'xAI (Grok)',
|
||||
creator: createXai,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'azure',
|
||||
name: 'Azure OpenAI',
|
||||
creator: createAzure,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'azure-responses',
|
||||
name: 'Azure OpenAI Responses',
|
||||
creator: (options: AzureOpenAIProviderSettings) => {
|
||||
const provider = createAzure(options)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.responses(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'deepseek',
|
||||
name: 'DeepSeek',
|
||||
creator: createDeepSeek,
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
] as const satisfies BaseProvider[]
|
||||
|
||||
/**
|
||||
* 用户自定义 Provider ID Schema
|
||||
* 允许任意字符串,但排除基础 provider IDs 以避免冲突
|
||||
*/
|
||||
export const customProviderIdSchema = z
|
||||
.string()
|
||||
.min(1)
|
||||
.refine((id) => !baseProviderIds.includes(id as any), {
|
||||
message: 'Custom provider ID cannot conflict with base provider IDs'
|
||||
})
|
||||
|
||||
/**
|
||||
* Provider ID Schema - 支持基础和自定义
|
||||
*/
|
||||
export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema])
|
||||
|
||||
/**
|
||||
* Provider 配置 Schema
|
||||
* 用于Provider的配置验证
|
||||
*/
|
||||
export const providerConfigSchema = z
|
||||
.object({
|
||||
id: customProviderIdSchema, // 只允许自定义ID
|
||||
name: z.string().min(1),
|
||||
creator: z.function().optional(),
|
||||
import: z.function().optional(),
|
||||
creatorFunctionName: z.string().optional(),
|
||||
supportsImageGeneration: z.boolean().default(false),
|
||||
imageCreator: z.function().optional(),
|
||||
validateOptions: z.function().optional(),
|
||||
aliases: z.array(z.string()).optional()
|
||||
})
|
||||
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
|
||||
message: 'Must provide either creator function or import configuration'
|
||||
})
|
||||
|
||||
/**
|
||||
* Provider ID 类型 - 基于 zod schema 推导
|
||||
*/
|
||||
export type ProviderId = z.infer<typeof providerIdSchema>
|
||||
export type CustomProviderId = z.infer<typeof customProviderIdSchema>
|
||||
|
||||
/**
|
||||
* Provider 配置类型
|
||||
*/
|
||||
export type ProviderConfig = z.infer<typeof providerConfigSchema>
|
||||
|
||||
/**
|
||||
* 兼容性类型别名
|
||||
* @deprecated 使用 ProviderConfig 替代
|
||||
*/
|
||||
export type DynamicProviderRegistration = ProviderConfig
|
||||
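Besides an inline creator, the refine above also accepts a lazily imported creator. A hypothetical registration sketch; the module path and function name are illustrative, not a real dependency of this package:

import { providerConfigSchema, type ProviderConfig } from './schemas'

const lazyConfig: ProviderConfig = {
  id: 'my-lazy-provider',
  name: 'My Lazy Provider',
  import: () => import('./my-lazy-provider-module'), // hypothetical local module
  creatorFunctionName: 'createMyLazyProvider',        // hypothetical exported creator
  supportsImageGeneration: false
}

providerConfigSchema.parse(lazyConfig) // passes: import + creatorFunctionName satisfy the refine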
96
packages/aiCore/src/core/providers/types.ts
Normal file
96
packages/aiCore/src/core/providers/types.ts
Normal file
@ -0,0 +1,96 @@
|
||||
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
|
||||
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
||||
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
|
||||
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
|
||||
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
|
||||
import {
|
||||
EmbeddingModelV2 as EmbeddingModel,
|
||||
ImageModelV2 as ImageModel,
|
||||
LanguageModelV2 as LanguageModel,
|
||||
ProviderV2,
|
||||
SpeechModelV2 as SpeechModel,
|
||||
TranscriptionModelV2 as TranscriptionModel
|
||||
} from '@ai-sdk/provider'
|
||||
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
||||
|
||||
// 导入基于 Zod 的 ProviderId 类型
|
||||
import { type ProviderId as ZodProviderId } from './schemas'
|
||||
|
||||
export interface ExtensibleProviderSettingsMap {
|
||||
// 基础的静态providers
|
||||
openai: OpenAIProviderSettings
|
||||
'openai-responses': OpenAIProviderSettings
|
||||
'openai-compatible': OpenAICompatibleProviderSettings
|
||||
anthropic: AnthropicProviderSettings
|
||||
google: GoogleGenerativeAIProviderSettings
|
||||
xai: XaiProviderSettings
|
||||
azure: AzureOpenAIProviderSettings
|
||||
deepseek: DeepSeekProviderSettings
|
||||
}
|
||||
|
||||
// 动态扩展的provider类型注册表
|
||||
export interface DynamicProviderRegistry {
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
// 合并基础和动态provider类型
|
||||
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
|
||||
|
||||
// 错误类型
|
||||
export class ProviderError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId: string,
|
||||
public code?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ProviderError'
|
||||
}
|
||||
}
|
||||
|
||||
// 动态ProviderId类型 - 基于 Zod Schema,支持运行时扩展和验证
|
||||
export type ProviderId = ZodProviderId
|
||||
|
||||
export interface ProviderTypeRegistrar {
|
||||
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
|
||||
getProviderSettings<T extends string>(providerId: T): any
|
||||
}
|
||||
|
||||
// 重新导出所有类型供外部使用
|
||||
export type {
|
||||
AnthropicProviderSettings,
|
||||
AzureOpenAIProviderSettings,
|
||||
DeepSeekProviderSettings,
|
||||
GoogleGenerativeAIProviderSettings,
|
||||
OpenAICompatibleProviderSettings,
|
||||
OpenAIProviderSettings,
|
||||
XaiProviderSettings
|
||||
}
|
||||
|
||||
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel<string> | TranscriptionModel | SpeechModel
|
||||
|
||||
export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech'
|
||||
|
||||
export const METHOD_MAP = {
|
||||
text: 'languageModel',
|
||||
image: 'imageModel',
|
||||
embedding: 'textEmbeddingModel',
|
||||
transcription: 'transcriptionModel',
|
||||
speech: 'speechModel'
|
||||
} as const satisfies Record<AiSdkModelType, keyof ProviderV2>
|
||||
|
||||
export type AiSdkModelMethodMap = Record<AiSdkModelType, keyof ProviderV2>
|
||||
|
||||
export type AiSdkModelReturnMap = {
|
||||
text: LanguageModel
|
||||
image: ImageModel
|
||||
embedding: EmbeddingModel<string>
|
||||
transcription: TranscriptionModel
|
||||
speech: SpeechModel
|
||||
}
|
||||
|
||||
export type AiSdkMethodName<T extends AiSdkModelType> = (typeof METHOD_MAP)[T]
|
||||
|
||||
export type AiSdkModelReturn<T extends AiSdkModelType> = AiSdkModelReturnMap[T]
|
||||
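METHOD_MAP is what lets callers map a model kind to the matching ProviderV2 accessor without a switch statement. A loose sketch; the typing is deliberately relaxed, see the comment:

import type { ProviderV2 } from '@ai-sdk/provider'

import { METHOD_MAP, type AiSdkModelType } from './types'

function getModel(provider: ProviderV2, kind: AiSdkModelType, modelId: string) {
  const method = METHOD_MAP[kind] // e.g. 'text' -> 'languageModel'
  // cast kept loose here: the transcription/speech accessors are optional on ProviderV2
  return (provider as any)[method](modelId)
}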
86
packages/aiCore/src/core/providers/utils.ts
Normal file
86
packages/aiCore/src/core/providers/utils.ts
Normal file
@ -0,0 +1,86 @@
|
||||
/**
|
||||
* 格式化私钥,确保它包含正确的PEM头部和尾部
|
||||
*/
|
||||
export function formatPrivateKey(privateKey: string): string {
|
||||
if (!privateKey || typeof privateKey !== 'string') {
|
||||
throw new Error('Private key must be a non-empty string')
|
||||
}
|
||||
|
||||
// 先处理 JSON 字符串中的转义换行符
|
||||
const key = privateKey.replace(/\\n/g, '\n')
|
||||
|
||||
// 检查是否已经是正确格式的 PEM 私钥
|
||||
const hasBeginMarker = key.includes('-----BEGIN PRIVATE KEY-----')
|
||||
const hasEndMarker = key.includes('-----END PRIVATE KEY-----')
|
||||
|
||||
if (hasBeginMarker && hasEndMarker) {
|
||||
// 已经是 PEM 格式,但可能格式不规范,重新格式化
|
||||
return normalizePemFormat(key)
|
||||
}
|
||||
|
||||
// 如果没有完整的 PEM 头尾,尝试重新构建
|
||||
return reconstructPemKey(key)
|
||||
}
|
||||
|
||||
/**
|
||||
* 标准化 PEM 格式
|
||||
*/
|
||||
function normalizePemFormat(pemKey: string): string {
|
||||
// 分离头部、内容和尾部
|
||||
const lines = pemKey
|
||||
.split('\n')
|
||||
.map((line) => line.trim())
|
||||
.filter((line) => line.length > 0)
|
||||
|
||||
let keyContent = ''
|
||||
let foundBegin = false
|
||||
let foundEnd = false
|
||||
|
||||
for (const line of lines) {
|
||||
if (line === '-----BEGIN PRIVATE KEY-----') {
|
||||
foundBegin = true
|
||||
continue
|
||||
}
|
||||
if (line === '-----END PRIVATE KEY-----') {
|
||||
foundEnd = true
|
||||
break
|
||||
}
|
||||
if (foundBegin && !foundEnd) {
|
||||
keyContent += line
|
||||
}
|
||||
}
|
||||
|
||||
if (!foundBegin || !foundEnd || !keyContent) {
|
||||
throw new Error('Invalid PEM format: missing BEGIN/END markers or key content')
|
||||
}
|
||||
|
||||
// 重新格式化为 64 字符一行
|
||||
const formattedContent = keyContent.match(/.{1,64}/g)?.join('\n') || keyContent
|
||||
|
||||
return `-----BEGIN PRIVATE KEY-----\n${formattedContent}\n-----END PRIVATE KEY-----`
|
||||
}
|
||||
|
||||
/**
|
||||
* 重新构建 PEM 私钥
|
||||
*/
|
||||
function reconstructPemKey(key: string): string {
|
||||
// 移除所有空白字符和可能存在的不完整头尾
|
||||
let cleanKey = key.replace(/\s+/g, '')
|
||||
cleanKey = cleanKey.replace(/-----BEGIN[^-]*-----/g, '')
|
||||
cleanKey = cleanKey.replace(/-----END[^-]*-----/g, '')
|
||||
|
||||
// 确保私钥内容不为空
|
||||
if (!cleanKey) {
|
||||
throw new Error('Private key content is empty after cleaning')
|
||||
}
|
||||
|
||||
// 验证是否是有效的 Base64 字符
|
||||
if (!/^[A-Za-z0-9+/=]+$/.test(cleanKey)) {
|
||||
throw new Error('Private key contains invalid characters (not valid Base64)')
|
||||
}
|
||||
|
||||
// 格式化为 64 字符一行
|
||||
const formattedKey = cleanKey.match(/.{1,64}/g)?.join('\n') || cleanKey
|
||||
|
||||
return `-----BEGIN PRIVATE KEY-----\n${formattedKey}\n-----END PRIVATE KEY-----`
|
||||
}
|
||||
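formatPrivateKey is mainly there to repair keys pasted as single-line JSON values. A quick illustration; the key body is a truncated placeholder:

import { formatPrivateKey } from './utils'

// A key copied out of a JSON service-account file, with literal "\n" escape sequences
const raw = '-----BEGIN PRIVATE KEY-----\\nMIIEvQIBADANBg...\\n-----END PRIVATE KEY-----'

const pem = formatPrivateKey(raw)
// -> real newlines restored and the key body re-wrapped at 64 characters per line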
523
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
523
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
@ -0,0 +1,523 @@
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { type AiPlugin } from '../../plugins'
|
||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from '../errors'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('ai', () => ({
|
||||
experimental_generateImage: vi.fn(),
|
||||
NoImageGeneratedError: class NoImageGeneratedError extends Error {
|
||||
static isInstance = vi.fn()
|
||||
constructor() {
|
||||
super('No image generated')
|
||||
this.name = 'NoImageGeneratedError'
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
imageModel: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
|
||||
describe('RuntimeExecutor.generateImage', () => {
|
||||
let executor: RuntimeExecutor<'openai'>
|
||||
let mockImageModel: ImageModelV2
|
||||
let mockGenerateImageResult: any
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset all mocks
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Create executor instance
|
||||
executor = RuntimeExecutor.create('openai', {
|
||||
apiKey: 'test-key'
|
||||
})
|
||||
|
||||
// Mock image model
|
||||
mockImageModel = {
|
||||
modelId: 'dall-e-3',
|
||||
provider: 'openai'
|
||||
} as ImageModelV2
|
||||
|
||||
// Mock generateImage result
|
||||
mockGenerateImageResult = {
|
||||
image: {
|
||||
base64: 'base64-encoded-image-data',
|
||||
uint8Array: new Uint8Array([1, 2, 3]),
|
||||
mediaType: 'image/png'
|
||||
},
|
||||
images: [
|
||||
{
|
||||
base64: 'base64-encoded-image-data',
|
||||
uint8Array: new Uint8Array([1, 2, 3]),
|
||||
mediaType: 'image/png'
|
||||
}
|
||||
],
|
||||
warnings: [],
|
||||
providerMetadata: {
|
||||
openai: {
|
||||
images: [{ revisedPrompt: 'A detailed prompt' }]
|
||||
}
|
||||
},
|
||||
responses: []
|
||||
}
|
||||
|
||||
// Setup mocks to avoid "No providers registered" error
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
|
||||
})
|
||||
|
||||
describe('Basic functionality', () => {
|
||||
it('should generate a single image with minimal parameters', async () => {
|
||||
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' })
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A futuristic cityscape at sunset'
|
||||
})
|
||||
|
||||
expect(result).toEqual(mockGenerateImageResult)
|
||||
})
|
||||
|
||||
it('should generate image with pre-created model', async () => {
|
||||
const result = await executor.generateImage({
|
||||
model: mockImageModel,
|
||||
prompt: 'A beautiful landscape'
|
||||
})
|
||||
|
||||
// Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A beautiful landscape'
|
||||
})
|
||||
|
||||
expect(result).toEqual(mockGenerateImageResult)
|
||||
})
|
||||
|
||||
it('should support multiple images generation', async () => {
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape', n: 3 })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A futuristic cityscape',
|
||||
n: 3
|
||||
})
|
||||
})
|
||||
|
||||
it('should support size specification', async () => {
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A beautiful sunset', size: '1024x1024' })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A beautiful sunset',
|
||||
size: '1024x1024'
|
||||
})
|
||||
})
|
||||
|
||||
it('should support aspect ratio specification', async () => {
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A mountain landscape', aspectRatio: '16:9' })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A mountain landscape',
|
||||
aspectRatio: '16:9'
|
||||
})
|
||||
})
|
||||
|
||||
it('should support seed for consistent output', async () => {
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A cat in space', seed: 1234567890 })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A cat in space',
|
||||
seed: 1234567890
|
||||
})
|
||||
})
|
||||
|
||||
it('should support abort signal', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A cityscape', abortSignal: abortController.signal })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A cityscape',
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
})
|
||||
|
||||
it('should support provider-specific options', async () => {
|
||||
await executor.generateImage({
|
||||
model: 'dall-e-3',
|
||||
prompt: 'A space station',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: 'hd',
|
||||
style: 'vivid'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A space station',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: 'hd',
|
||||
style: 'vivid'
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should support custom headers', async () => {
|
||||
await executor.generateImage({
|
||||
model: 'dall-e-3',
|
||||
prompt: 'A robot',
|
||||
headers: {
|
||||
'X-Custom-Header': 'test-value'
|
||||
}
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A robot',
|
||||
headers: {
|
||||
'X-Custom-Header': 'test-value'
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Plugin integration', () => {
|
||||
it('should execute plugins in correct order', async () => {
|
||||
const pluginCallOrder: string[] = []
|
||||
|
||||
const testPlugin: AiPlugin = {
|
||||
name: 'test-plugin',
|
||||
onRequestStart: vi.fn(async () => {
|
||||
pluginCallOrder.push('onRequestStart')
|
||||
}),
|
||||
transformParams: vi.fn(async (params) => {
|
||||
pluginCallOrder.push('transformParams')
|
||||
return { ...params, size: '512x512' }
|
||||
}),
|
||||
transformResult: vi.fn(async (result) => {
|
||||
pluginCallOrder.push('transformResult')
|
||||
return { ...result, processed: true }
|
||||
}),
|
||||
onRequestEnd: vi.fn(async () => {
|
||||
pluginCallOrder.push('onRequestEnd')
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[testPlugin]
|
||||
)
|
||||
|
||||
const result = await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||
|
||||
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
|
||||
|
||||
expect(testPlugin.transformParams).toHaveBeenCalledWith(
|
||||
{ prompt: 'A test image' },
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
)
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A test image',
|
||||
size: '512x512' // Should be transformed by plugin
|
||||
})
|
||||
|
||||
expect(result).toEqual({
|
||||
...mockGenerateImageResult,
|
||||
processed: true // Should be transformed by plugin
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle model resolution through plugins', async () => {
|
||||
const customImageModel = {
|
||||
modelId: 'custom-model',
|
||||
provider: 'openai'
|
||||
} as ImageModelV2
|
||||
|
||||
const modelResolutionPlugin: AiPlugin = {
|
||||
name: 'model-resolver',
|
||||
resolveModel: vi.fn(async () => customImageModel)
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[modelResolutionPlugin]
|
||||
)
|
||||
|
||||
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||
|
||||
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
|
||||
'dall-e-3',
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
)
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: customImageModel,
|
||||
prompt: 'A test image'
|
||||
})
|
||||
})
|
||||
|
||||
it('should support recursive calls from plugins', async () => {
|
||||
const recursivePlugin: AiPlugin = {
|
||||
name: 'recursive-plugin',
|
||||
transformParams: vi.fn(async (params, context) => {
|
||||
if (!context.isRecursiveCall && params.prompt === 'original') {
|
||||
// Make a recursive call with modified prompt
|
||||
await context.recursiveCall({
|
||||
model: 'dall-e-3',
|
||||
prompt: 'modified'
|
||||
})
|
||||
}
|
||||
return params
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[recursivePlugin]
|
||||
)
|
||||
|
||||
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'original' })
|
||||
|
||||
expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2)
|
||||
expect(aiGenerateImage).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error handling', () => {
|
||||
it('should handle model creation errors', async () => {
|
||||
const modelError = new Error('Failed to get image model')
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||
throw modelError
|
||||
})
|
||||
|
||||
await expect(executor.generateImage({ model: 'invalid-model', prompt: 'A test image' })).rejects.toThrow(
|
||||
ImageGenerationError
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle ImageModelResolutionError correctly', async () => {
|
||||
const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found'))
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||
throw resolutionError
|
||||
})
|
||||
|
||||
const thrownError = await executor
|
||||
.generateImage({ model: 'invalid-model', prompt: 'A test image' })
|
||||
.catch((error) => error)
|
||||
|
||||
expect(thrownError).toBeInstanceOf(ImageGenerationError)
|
||||
expect(thrownError.message).toContain('Failed to generate image:')
|
||||
expect(thrownError.providerId).toBe('openai')
|
||||
expect(thrownError.modelId).toBe('invalid-model')
|
||||
expect(thrownError.cause).toBeInstanceOf(ImageModelResolutionError)
|
||||
expect(thrownError.cause.message).toContain('Failed to resolve image model: invalid-model')
|
||||
})
|
||||
|
||||
it('should handle ImageModelResolutionError without provider', async () => {
|
||||
const resolutionError = new ImageModelResolutionError('unknown-model')
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||
throw resolutionError
|
||||
})
|
||||
|
||||
await expect(executor.generateImage({ model: 'unknown-model', prompt: 'A test image' })).rejects.toThrow(
|
||||
ImageGenerationError
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle image generation API errors', async () => {
|
||||
const apiError = new Error('API request failed')
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
|
||||
|
||||
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||
'Failed to generate image:'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle NoImageGeneratedError', async () => {
|
||||
const noImageError = new NoImageGeneratedError({
|
||||
cause: new Error('No image generated'),
|
||||
responses: []
|
||||
})
|
||||
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
|
||||
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
|
||||
|
||||
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||
'Failed to generate image:'
|
||||
)
|
||||
})
|
||||
|
||||
it('should execute onError plugin hook on failure', async () => {
|
||||
const error = new Error('Generation failed')
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(error)
|
||||
|
||||
const errorPlugin: AiPlugin = {
|
||||
name: 'error-handler',
|
||||
onError: vi.fn()
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[errorPlugin]
|
||||
)
|
||||
|
||||
await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||
'Failed to generate image:'
|
||||
)
|
||||
|
||||
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||
error,
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle abort signal timeout', async () => {
|
||||
const abortError = new Error('Operation was aborted')
|
||||
abortError.name = 'AbortError'
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(abortError)
|
||||
|
||||
const abortController = new AbortController()
|
||||
setTimeout(() => abortController.abort(), 10)
|
||||
|
||||
await expect(
|
||||
executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal })
|
||||
).rejects.toThrow('Failed to generate image:')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Multiple providers support', () => {
|
||||
it('should work with different providers', async () => {
|
||||
const googleExecutor = RuntimeExecutor.create('google', {
|
||||
apiKey: 'google-key'
|
||||
})
|
||||
|
||||
await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' })
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002')
|
||||
})
|
||||
|
||||
it('should support xAI Grok image models', async () => {
|
||||
const xaiExecutor = RuntimeExecutor.create('xai', {
|
||||
apiKey: 'xai-key'
|
||||
})
|
||||
|
||||
await xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' })
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai|grok-2-image')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Advanced features', () => {
|
||||
it('should support batch image generation with maxImagesPerCall', async () => {
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', n: 10, maxImagesPerCall: 5 })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A test image',
|
||||
n: 10,
|
||||
maxImagesPerCall: 5
|
||||
})
|
||||
})
|
||||
|
||||
it('should support retries with maxRetries', async () => {
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', maxRetries: 3 })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A test image',
|
||||
maxRetries: 3
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle warnings from the model', async () => {
|
||||
const resultWithWarnings = {
|
||||
...mockGenerateImageResult,
|
||||
warnings: [
|
||||
{
|
||||
type: 'unsupported-setting',
|
||||
message: 'Size parameter not supported for this model'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings)
|
||||
|
||||
const result = await executor.generateImage({
|
||||
model: 'dall-e-3',
|
||||
prompt: 'A test image',
|
||||
size: '2048x2048' // Unsupported size
|
||||
})
|
||||
|
||||
expect(result.warnings).toHaveLength(1)
|
||||
expect(result.warnings[0].type).toBe('unsupported-setting')
|
||||
})
|
||||
|
||||
it('should provide access to provider metadata', async () => {
|
||||
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||
|
||||
expect(result.providerMetadata).toBeDefined()
|
||||
expect(result.providerMetadata.openai).toBeDefined()
|
||||
})
|
||||
|
||||
it('should provide response metadata', async () => {
|
||||
const resultWithMetadata = {
|
||||
...mockGenerateImageResult,
|
||||
responses: [
|
||||
{
|
||||
timestamp: new Date(),
|
||||
modelId: 'dall-e-3',
|
||||
headers: { 'x-request-id': 'test-123' }
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata)
|
||||
|
||||
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||
|
||||
expect(result.responses).toHaveLength(1)
|
||||
expect(result.responses[0].modelId).toBe('dall-e-3')
|
||||
expect(result.responses[0].headers).toEqual({ 'x-request-id': 'test-123' })
|
||||
})
|
||||
})
|
||||
})
|
||||
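Outside of the mocks, the call pattern exercised above looks roughly like this in application code; the API key and model id are placeholders:

import { RuntimeExecutor } from '../executor'

async function renderCityscape() {
  const executor = RuntimeExecutor.create('openai', { apiKey: 'sk-placeholder' })

  const { image, warnings } = await executor.generateImage({
    model: 'dall-e-3',
    prompt: 'A futuristic cityscape at sunset',
    size: '1024x1024',
    n: 1
  })

  // image.base64 / image.uint8Array carry the bytes; warnings lists unsupported settings
  return { image, warnings }
}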
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Error classes for runtime operations
|
||||
*/
|
||||
|
||||
/**
|
||||
* Error thrown when image generation fails
|
||||
*/
|
||||
export class ImageGenerationError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId?: string,
|
||||
public modelId?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ImageGenerationError'
|
||||
|
||||
// Maintain proper stack trace (for V8 engines)
|
||||
if (Error.captureStackTrace) {
|
||||
Error.captureStackTrace(this, ImageGenerationError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Error thrown when model resolution fails during image generation
|
||||
*/
|
||||
export class ImageModelResolutionError extends ImageGenerationError {
|
||||
constructor(modelId: string, providerId?: string, cause?: Error) {
|
||||
super(
|
||||
`Failed to resolve image model: ${modelId}${providerId ? ` for provider: ${providerId}` : ''}`,
|
||||
providerId,
|
||||
modelId,
|
||||
cause
|
||||
)
|
||||
this.name = 'ImageModelResolutionError'
|
||||
}
|
||||
}
|
||||
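A minimal caller-side sketch (illustrative, not part of this diff) of how the two error classes above can be told apart; the executor variable and prompt are placeholders.

try {
  await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
} catch (err) {
  if (err instanceof ImageModelResolutionError) {
    // the model id could not be resolved for the configured provider
    console.warn(`Could not resolve ${err.modelId} for provider ${err.providerId}`)
  } else if (err instanceof ImageGenerationError) {
    // generation itself failed; the original error is kept on `cause`
    console.error(err.message, err.cause)
  } else {
    throw err
  }
}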
321  packages/aiCore/src/core/runtime/executor.ts  Normal file
@ -0,0 +1,321 @@
|
||||
/**
|
||||
* Runtime executor
* Focused on plugin-driven AI call handling
|
||||
*/
|
||||
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import {
|
||||
experimental_generateImage as generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
LanguageModel,
|
||||
streamObject,
|
||||
streamText
|
||||
} from 'ai'
|
||||
|
||||
import { globalModelResolver } from '../models'
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||
import { PluginEngine } from './pluginEngine'
|
||||
import { type RuntimeConfig } from './types'
|
||||
|
||||
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
public pluginEngine: PluginEngine<T>
|
||||
// private options: ProviderSettingsMap[T]
|
||||
private config: RuntimeConfig<T>
|
||||
|
||||
constructor(config: RuntimeConfig<T>) {
|
||||
// if (!isProviderSupported(config.providerId)) {
|
||||
// throw new Error(`Unsupported provider: ${config.providerId}`)
|
||||
// }
|
||||
|
||||
// 存储options供后续使用
|
||||
// this.options = config.options
|
||||
this.config = config
|
||||
// 创建插件客户端
|
||||
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
|
||||
}
|
||||
|
||||
private createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
|
||||
return definePlugin({
|
||||
name: '_internal_resolveModel',
|
||||
enforce: 'post',
|
||||
|
||||
resolveModel: async (modelId: string) => {
|
||||
// 注意:extraModelConfig 暂时不支持,已在新架构中移除
|
||||
return await this.resolveModel(modelId, middlewares)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private createResolveImageModelPlugin() {
|
||||
return definePlugin({
|
||||
name: '_internal_resolveImageModel',
|
||||
enforce: 'post',
|
||||
|
||||
resolveModel: async (modelId: string) => {
|
||||
return await this.resolveImageModel(modelId)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private createConfigureContextPlugin() {
|
||||
return definePlugin({
|
||||
name: '_internal_configureContext',
|
||||
configureContext: async (context: AiRequestContext) => {
|
||||
context.executor = this
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// === High-level overloads: use a model directly ===
|
||||
|
||||
/**
|
||||
* Streaming text generation
|
||||
*/
|
||||
async streamText(
|
||||
params: Parameters<typeof streamText>[0],
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamText>> {
|
||||
const { model, ...restParams } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
} else {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeStreamWithPlugins(
|
||||
'streamText',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams, streamTransforms) => {
|
||||
const experimental_transform =
|
||||
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
|
||||
|
||||
const finalParams = {
|
||||
model: resolvedModel,
|
||||
...transformedParams,
|
||||
experimental_transform
|
||||
} as Parameters<typeof streamText>[0]
|
||||
|
||||
return await streamText(finalParams)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// === Overloads for the remaining methods ===
|
||||
|
||||
/**
|
||||
* Generate text
|
||||
*/
|
||||
async generateText(
|
||||
params: Parameters<typeof generateText>[0],
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateText>> {
|
||||
const { model, ...restParams } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
} else {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'generateText',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
generateText({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateText>[0])
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a structured object
|
||||
*/
|
||||
async generateObject(
|
||||
params: Parameters<typeof generateObject>[0],
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateObject>> {
|
||||
const { model, ...restParams } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
} else {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'generateObject',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
generateObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateObject>[0])
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream a structured object
|
||||
*/
|
||||
async streamObject(
|
||||
params: Parameters<typeof streamObject>[0],
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamObject>> {
|
||||
const { model, ...restParams } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
} else {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'streamObject',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
streamObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof streamObject>[0])
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an image
|
||||
*/
|
||||
async generateImage(
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'> & { model: string | ImageModelV2 }
|
||||
): Promise<ReturnType<typeof generateImage>> {
|
||||
try {
|
||||
const { model, ...restParams } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
|
||||
} else {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return await this.pluginEngine.executeImageWithPlugins(
|
||||
'generateImage',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) => {
|
||||
return await generateImage({ model: resolvedModel, ...transformedParams })
|
||||
}
|
||||
)
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
const modelId = typeof params.model === 'string' ? params.model : params.model.modelId
|
||||
throw new ImageGenerationError(
|
||||
`Failed to generate image: ${error.message}`,
|
||||
this.config.providerId,
|
||||
modelId,
|
||||
error
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// === Helpers ===
|
||||
|
||||
/**
|
||||
* Resolve a model: build it from a string id, or return it unchanged if it is already a model instance
|
||||
*/
|
||||
private async resolveModel(
|
||||
modelOrId: LanguageModel,
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<LanguageModelV2> {
|
||||
if (typeof modelOrId === 'string') {
|
||||
// 🎯 字符串modelId,使用新的ModelResolver解析,传递完整参数
|
||||
return await globalModelResolver.resolveLanguageModel(
|
||||
modelOrId, // 支持 'gpt-4' 和 'aihubmix:anthropic:claude-3.5-sonnet'
|
||||
this.config.providerId, // fallback provider
|
||||
this.config.providerSettings, // provider options
|
||||
middlewares // 中间件数组
|
||||
)
|
||||
} else {
|
||||
// 已经是模型,直接返回
|
||||
return modelOrId
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve an image model: build it from a string id, or return it unchanged if it is already a model instance
|
||||
*/
|
||||
private async resolveImageModel(modelOrId: ImageModelV2 | string): Promise<ImageModelV2> {
|
||||
try {
|
||||
if (typeof modelOrId === 'string') {
|
||||
// 字符串modelId,使用新的ModelResolver解析
|
||||
return await globalModelResolver.resolveImageModel(
|
||||
modelOrId, // 支持 'dall-e-3' 和 'aihubmix:openai:dall-e-3'
|
||||
this.config.providerId // fallback provider
|
||||
)
|
||||
} else {
|
||||
// 已经是模型,直接返回
|
||||
return modelOrId
|
||||
}
|
||||
} catch (error) {
|
||||
throw new ImageModelResolutionError(
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
this.config.providerId,
|
||||
error instanceof Error ? error : undefined
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// === Static factory methods ===
|
||||
|
||||
/**
|
||||
* Create an executor - type-safe for known providers
|
||||
*/
|
||||
static create<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ModelConfig<T>['providerSettings'],
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T> {
|
||||
return new RuntimeExecutor({
|
||||
providerId,
|
||||
providerSettings: options,
|
||||
plugins
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an OpenAI-compatible executor
|
||||
*/
|
||||
static createOpenAICompatible(
|
||||
options: ModelConfig<'openai-compatible'>['providerSettings'],
|
||||
plugins: AiPlugin[] = []
|
||||
): RuntimeExecutor<'openai-compatible'> {
|
||||
return new RuntimeExecutor({
|
||||
providerId: 'openai-compatible',
|
||||
providerSettings: options,
|
||||
plugins
|
||||
})
|
||||
}
|
||||
}
|
||||
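Usage sketch for the executor above (illustrative, not part of this diff; the provider id 'openai' and the { apiKey } settings shape are assumptions about ProviderId and ModelConfig).

const executor = RuntimeExecutor.create('openai', { apiKey: process.env.OPENAI_API_KEY ?? '' })

// A string model id is resolved by the internal _internal_resolveModel plugin
const stream = await executor.streamText({
  model: 'gpt-4o-mini',
  prompt: 'Summarize the plugin engine in one sentence.'
})

for await (const chunk of stream.textStream) {
  process.stdout.write(chunk)
}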
117  packages/aiCore/src/core/runtime/index.ts  Normal file
@ -0,0 +1,117 @@
|
||||
/**
|
||||
* Runtime module exports
* Focused on plugin-driven AI call handling at runtime
|
||||
*/
|
||||
|
||||
// 主要的运行时执行器
|
||||
export { RuntimeExecutor } from './executor'
|
||||
|
||||
// 导出类型
|
||||
export type { RuntimeConfig } from './types'
|
||||
|
||||
// === Convenience factory functions ===
|
||||
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
|
||||
import { RuntimeExecutor } from './executor'
|
||||
|
||||
/**
|
||||
* Create a runtime executor - type-safe for known providers
|
||||
*/
|
||||
export function createExecutor<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T> {
|
||||
return RuntimeExecutor.create(providerId, options, plugins)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an OpenAI-compatible executor
|
||||
*/
|
||||
export function createOpenAICompatibleExecutor(
|
||||
options: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' },
|
||||
plugins: AiPlugin[] = []
|
||||
): RuntimeExecutor<'openai-compatible'> {
|
||||
return RuntimeExecutor.createOpenAICompatible(options, plugins)
|
||||
}
|
||||
|
||||
// === Direct-call API (no executor instance required) ===
|
||||
|
||||
/**
|
||||
* Direct streaming text generation - supports middlewares
|
||||
*/
|
||||
export async function streamText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['streamText']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.streamText(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct text generation - supports middlewares
|
||||
*/
|
||||
export async function generateText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['generateText']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateText(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct structured object generation - supports middlewares
|
||||
*/
|
||||
export async function generateObject<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['generateObject']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateObject(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct streaming structured object generation - supports middlewares
|
||||
*/
|
||||
export async function streamObject<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['streamObject']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.streamObject(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct image generation
|
||||
*/
|
||||
export async function generateImage<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['generateImage']>[0],
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateImage(params)
|
||||
}
|
||||
|
||||
// === Reserved for future Agent features ===
|
||||
// To be added later under ../agents/:
|
||||
// - AgentExecutor.ts
|
||||
// - WorkflowManager.ts
|
||||
// - ConversationManager.ts
|
||||
// with the related APIs exported from here
|
||||
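Sketch of the direct-call API defined above (illustrative; the provider id and settings are placeholder assumptions).

const { text } = await generateText(
  'openai',
  { apiKey: process.env.OPENAI_API_KEY ?? '' },
  { model: 'gpt-4o-mini', prompt: 'Hello from the direct-call API' }
)
console.log(text)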
290  packages/aiCore/src/core/runtime/pluginEngine.ts  Normal file
@ -0,0 +1,290 @@
|
||||
/* eslint-disable @eslint-react/naming-convention/context-name */
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { LanguageModel } from 'ai'
|
||||
|
||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
* Plugin-enhanced AI client
* Focused on plugin handling; does not expose a user-facing API
|
||||
*/
|
||||
export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
private pluginManager: PluginManager
|
||||
|
||||
constructor(
|
||||
private readonly providerId: T,
|
||||
// private readonly options: ProviderSettingsMap[T],
|
||||
plugins: AiPlugin[] = []
|
||||
) {
|
||||
this.pluginManager = new PluginManager(plugins)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a plugin
|
||||
*/
|
||||
use(plugin: AiPlugin): this {
|
||||
this.pluginManager.use(plugin)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Add multiple plugins
|
||||
*/
|
||||
usePlugins(plugins: AiPlugin[]): this {
|
||||
plugins.forEach((plugin) => this.use(plugin))
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a plugin
|
||||
*/
|
||||
removePlugin(pluginName: string): this {
|
||||
this.pluginManager.remove(pluginName)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Get plugin statistics
|
||||
*/
|
||||
getPluginStats() {
|
||||
return this.pluginManager.getStats()
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all plugins
|
||||
*/
|
||||
getPlugins() {
|
||||
return this.pluginManager.getPlugins()
|
||||
}
|
||||
|
||||
/**
|
||||
* Run an operation through the plugin pipeline (non-streaming)
* Used by the AiExecutor
|
||||
*/
|
||||
async executeWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
model: LanguageModel,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
let modelId: string
|
||||
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
} else {
|
||||
// 模型对象:直接使用
|
||||
resolvedModel = model
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeWithPlugins(methodName, model, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Model resolution failed: no model available`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(resolvedModel, transformedParams)
|
||||
|
||||
// 5. 转换结果(对于非流式调用)
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 6. 触发完成事件
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Run an image-generation operation through the plugin pipeline
* Used by the AiExecutor
|
||||
*/
|
||||
async executeImageWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
model: ImageModelV2 | string,
|
||||
params: TParams,
|
||||
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: ImageModelV2 | undefined
|
||||
let modelId: string
|
||||
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
} else {
|
||||
// 模型对象:直接使用
|
||||
resolvedModel = model
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeImageWithPlugins(methodName, model, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve image model: ${modelId}`)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Image model resolution failed: no model available`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(resolvedModel, transformedParams)
|
||||
|
||||
// 5. 转换结果
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 6. 触发完成事件
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Shared logic for streaming calls (supports stream transforms)
* Used by the AiExecutor
|
||||
*/
|
||||
async executeStreamWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
model: LanguageModel,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
let modelId: string
|
||||
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
} else {
|
||||
// 模型对象:直接使用
|
||||
resolvedModel = model
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeStreamWithPlugins(methodName, model, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Model resolution failed: no model available`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 收集流转换器
|
||||
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
|
||||
|
||||
// 5. 执行流式 API 调用
|
||||
const result = await executor(resolvedModel, transformedParams, streamTransforms)
|
||||
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
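Sketch of a plugin the engine above would pick up (illustrative; the hook signatures are inferred from the executeParallel/executeSequential calls and may differ from the actual AiPlugin type).

const loggingPlugin = definePlugin({
  name: 'request-logger',
  onRequestStart: async (context) => {
    console.log(`[${context.providerId}] ${context.modelId} started`)
  },
  transformParams: async (params) => {
    // pass parameters through unchanged; a real plugin could rewrite them here
    return params
  },
  onError: async (_context, _result, error) => {
    console.error('request failed:', error?.message)
  }
})

const engine = new PluginEngine('openai', [loggingPlugin]) // 'openai' as ProviderId is an assumption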
15  packages/aiCore/src/core/runtime/types.ts  Normal file
@ -0,0 +1,15 @@
|
||||
/**
|
||||
* Type definitions for the runtime layer
|
||||
*/
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
* Runtime executor configuration
|
||||
*/
|
||||
export interface RuntimeConfig<T extends ProviderId = ProviderId> {
|
||||
providerId: T
|
||||
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
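Example RuntimeConfig literal (illustrative; the provider id and the { apiKey } settings shape are assumptions).

const config: RuntimeConfig<'openai'> = {
  providerId: 'openai',
  providerSettings: { apiKey: process.env.OPENAI_API_KEY ?? '', mode: 'chat' },
  plugins: []
}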
46  packages/aiCore/src/index.ts  Normal file
@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Cherry Studio AI Core Package
|
||||
* Unified AI provider interface built on the Vercel AI SDK
|
||||
*/
|
||||
|
||||
// 导入内部使用的类和函数
|
||||
|
||||
// ==================== Primary user-facing API ====================
|
||||
export {
|
||||
createExecutor,
|
||||
createOpenAICompatibleExecutor,
|
||||
generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
streamText
|
||||
} from './core/runtime'
|
||||
|
||||
// ==================== Advanced API ====================
|
||||
export { globalModelResolver as modelResolver } from './core/models'
|
||||
|
||||
// ==================== Plugin system ====================
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins'
|
||||
export { createContext, definePlugin, PluginManager } from './core/plugins'
|
||||
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
|
||||
export { PluginEngine } from './core/runtime/pluginEngine'
|
||||
|
||||
// ==================== Commonly used AI SDK type exports ====================
// Re-export frequently used AI SDK types directly for convenience
|
||||
export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
|
||||
export type { ToolCall } from '@ai-sdk/provider-utils'
|
||||
export type { ReasoningPart } from '@ai-sdk/provider-utils'
|
||||
|
||||
// ==================== Options ====================
|
||||
export {
|
||||
createAnthropicOptions,
|
||||
createGoogleOptions,
|
||||
createOpenAIOptions,
|
||||
type ExtractProviderOptions,
|
||||
mergeProviderOptions,
|
||||
type ProviderOptionsMap,
|
||||
type TypedProviderOptions
|
||||
} from './core/options'
|
||||
|
||||
// ==================== Package metadata ====================
|
||||
export const AI_CORE_VERSION = '1.0.0'
|
||||
export const AI_CORE_NAME = '@cherrystudio/ai-core'
|
||||
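How downstream code would consume this entry point (illustrative; the package name comes from AI_CORE_NAME and assumes the package is built and linked in the workspace).

import { createExecutor } from '@cherrystudio/ai-core'

const executor = createExecutor('openai', { apiKey: process.env.OPENAI_API_KEY ?? '' })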
2  packages/aiCore/src/types.ts  Normal file
@ -0,0 +1,2 @@
|
||||
// Re-export plugin types
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
|
||||
26  packages/aiCore/tsconfig.json  Normal file
@ -0,0 +1,26 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"declaration": true,
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"resolveJsonModule": true,
|
||||
"allowSyntheticDefaultImports": true,
|
||||
"noEmitOnError": false,
|
||||
"experimentalDecorators": true,
|
||||
"emitDecoratorMetadata": true
|
||||
},
|
||||
"include": [
|
||||
"src/**/*"
|
||||
],
|
||||
"exclude": [
|
||||
"node_modules",
|
||||
"dist"
|
||||
]
|
||||
}
|
||||
14  packages/aiCore/tsdown.config.ts  Normal file
@ -0,0 +1,14 @@
|
||||
import { defineConfig } from 'tsdown'
|
||||
|
||||
export default defineConfig({
|
||||
entry: {
|
||||
index: 'src/index.ts',
|
||||
'built-in/plugins/index': 'src/core/plugins/built-in/index.ts',
|
||||
'provider/index': 'src/core/providers/index.ts'
|
||||
},
|
||||
outDir: 'dist',
|
||||
format: ['esm', 'cjs'],
|
||||
clean: true,
|
||||
dts: true,
|
||||
tsconfig: 'tsconfig.json'
|
||||
})
|
||||
15  packages/aiCore/vitest.config.ts  Normal file
@ -0,0 +1,15 @@
|
||||
import { defineConfig } from 'vitest/config'
|
||||
|
||||
export default defineConfig({
|
||||
test: {
|
||||
globals: true
|
||||
},
|
||||
resolve: {
|
||||
alias: {
|
||||
'@': './src'
|
||||
}
|
||||
},
|
||||
esbuild: {
|
||||
target: 'node18'
|
||||
}
|
||||
})
|
||||
@ -83,7 +83,7 @@ export enum IpcChannel {
|
||||
Mcp_UploadDxt = 'mcp:upload-dxt',
|
||||
Mcp_AbortTool = 'mcp:abort-tool',
|
||||
Mcp_GetServerVersion = 'mcp:get-server-version',
|
||||
|
||||
Mcp_Progress = 'mcp:progress',
|
||||
// Python
|
||||
Python_Execute = 'python:execute',
|
||||
|
||||
@ -123,6 +123,12 @@ export enum IpcChannel {
|
||||
Windows_SetMinimumSize = 'window:set-minimum-size',
|
||||
Windows_Resize = 'window:resize',
|
||||
Windows_GetSize = 'window:get-size',
|
||||
Windows_Minimize = 'window:minimize',
|
||||
Windows_Maximize = 'window:maximize',
|
||||
Windows_Unmaximize = 'window:unmaximize',
|
||||
Windows_Close = 'window:close',
|
||||
Windows_IsMaximized = 'window:is-maximized',
|
||||
Windows_MaximizedChanged = 'window:maximized-changed',
|
||||
|
||||
KnowledgeBase_Create = 'knowledge-base:create',
|
||||
KnowledgeBase_Reset = 'knowledge-base:reset',
|
||||
@ -322,6 +328,14 @@ export enum IpcChannel {
|
||||
TRACE_CLEAN_LOCAL_DATA = 'trace:cleanLocalData',
|
||||
TRACE_ADD_STREAM_MESSAGE = 'trace:addStreamMessage',
|
||||
|
||||
// Anthropic OAuth
|
||||
Anthropic_StartOAuthFlow = 'anthropic:start-oauth-flow',
|
||||
Anthropic_CompleteOAuthWithCode = 'anthropic:complete-oauth-with-code',
|
||||
Anthropic_CancelOAuthFlow = 'anthropic:cancel-oauth-flow',
|
||||
Anthropic_GetAccessToken = 'anthropic:get-access-token',
|
||||
Anthropic_HasCredentials = 'anthropic:has-credentials',
|
||||
Anthropic_ClearCredentials = 'anthropic:clear-credentials',
|
||||
|
||||
// CodeTools
|
||||
CodeTools_Run = 'code-tools:run',
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ export type LoaderReturn = {
|
||||
loaderType: string
|
||||
status?: ProcessingStatus
|
||||
message?: string
|
||||
messageSource?: 'preprocess' | 'embedding'
|
||||
messageSource?: 'preprocess' | 'embedding' | 'validation'
|
||||
}
|
||||
|
||||
export type FileChangeEventType = 'add' | 'change' | 'unlink' | 'addDir' | 'unlinkDir'
|
||||
@ -17,3 +17,8 @@ export type FileChangeEvent = {
|
||||
filePath: string
|
||||
watchPath: string
|
||||
}
|
||||
|
||||
export type MCPProgressEvent = {
|
||||
callId: string
|
||||
progress: number // 0-1 range
|
||||
}
|
||||
|
||||
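Example MCPProgressEvent payload (illustrative values).

const progressEvent: MCPProgressEvent = { callId: 'call-42', progress: 0.35 } // progress is in the 0-1 range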
@ -1,199 +1,274 @@
|
||||
<!DOCTYPE html>
|
||||
<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>许可协议 | License Agreement</title>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
</head>
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>许可协议 | License Agreement</title>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
</head>
|
||||
<body class="bg-gray-50">
|
||||
<div class="max-w-4xl mx-auto px-4 py-8">
|
||||
<!-- 中文版本 -->
|
||||
<div class="mb-12">
|
||||
<h1 class="text-3xl font-bold mb-8 text-gray-900">许可协议</h1>
|
||||
|
||||
<body class="bg-gray-50">
|
||||
<div class="max-w-4xl mx-auto px-4 py-8">
|
||||
<!-- 中文版本 -->
|
||||
<div class="mb-12">
|
||||
<h1 class="text-3xl font-bold mb-8 text-gray-900">许可协议</h1>
|
||||
|
||||
<p class="mb-6 text-gray-700">本项目采用<strong>区分用户的双重许可 (User-Segmented Dual Licensing)</strong> 模式。</p>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">核心原则</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li><strong>个人用户 和 10人及以下企业/组织:</strong> 默认适用 <strong>GNU Affero 通用公共许可证 v3.0 (AGPLv3)</strong>。</li>
|
||||
<li><strong>超过10人的企业/组织:</strong> <strong>必须</strong> 获取 <strong>商业许可证 (Commercial License)</strong>。</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">定义:"10人及以下"</h2>
|
||||
<p class="text-gray-700">
|
||||
指在您的组织(包括公司、非营利组织、政府机构、教育机构等任何实体)中,能够访问、使用或以任何方式直接或间接受益于本软件(Cherry
|
||||
Studio)功能的个人总数不超过10人。这包括但不限于开发者、测试人员、运营人员、最终用户、通过集成系统间接使用者等。
|
||||
<p class="mb-6 text-gray-700">
|
||||
本项目采用<strong>区分用户的双重许可 (User-Segmented Dual Licensing)</strong> 模式。
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">1. 开源许可证 (Open Source License): AGPLv3 - 适用于个人及10人及以下组织
|
||||
</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>如果您是个人用户,或者您的组织满足上述"10人及以下"的定义,您可以在 <strong>AGPLv3</strong> 的条款下自由使用、修改和分发 Cherry Studio。AGPLv3 的完整文本可以访问
|
||||
<a href="https://www.gnu.org/licenses/agpl-3.0.html"
|
||||
class="text-blue-600 hover:underline">https://www.gnu.org/licenses/agpl-3.0.html</a> 获取。
|
||||
</li>
|
||||
<li><strong>核心义务:</strong> AGPLv3 的一个关键要求是,如果您修改了 Cherry Studio 并通过网络提供服务,或者分发了修改后的版本,您必须以 AGPLv3
|
||||
许可证向接收者提供相应的<strong>完整源代码</strong>。即使您符合"10人及以下"的标准,如果您希望避免此源代码公开义务,您也需要考虑获取商业许可证(见下文)。</li>
|
||||
<li>使用前请务必仔细阅读并理解 AGPLv3 的所有条款。</li>
|
||||
</ul>
|
||||
</section>
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">核心原则</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>
|
||||
<strong>个人用户 和 10人及以下企业/组织:</strong> 默认适用
|
||||
<strong>GNU Affero 通用公共许可证 v3.0 (AGPLv3)</strong>。
|
||||
</li>
|
||||
<li>
|
||||
<strong>超过10人的企业/组织:</strong> <strong>必须</strong> 获取
|
||||
<strong>商业许可证 (Commercial License)</strong>。
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">2. 商业许可证 (Commercial License) - 适用于超过10人的组织,或希望规避 AGPLv3
|
||||
义务的用户</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li><strong>强制要求:</strong>
|
||||
如果您的组织<strong>不</strong>满足上述"10人及以下"的定义(即有11人或更多人可以访问、使用或受益于本软件),您<strong>必须</strong>联系我们获取并签署一份商业许可证才能使用
|
||||
Cherry Studio。</li>
|
||||
<li><strong>自愿选择:</strong> 即使您的组织满足"10人及以下"的条件,但如果您的使用场景<strong>无法满足 AGPLv3
|
||||
的条款要求</strong>(特别是关于<strong>源代码公开</strong>的义务),或者您需要 AGPLv3 <strong>未提供</strong>的特定商业条款(如保证、赔偿、无 Copyleft
|
||||
限制等),您也<strong>必须</strong>联系我们获取并签署一份商业许可证。</li>
|
||||
<li><strong>需要商业许可证的常见情况包括(但不限于):</strong>
|
||||
<ul class="list-disc pl-6 mt-2 space-y-1">
|
||||
<li>您的组织规模超过10人。</li>
|
||||
<li>(无论组织规模)您希望分发修改过的 Cherry Studio 版本,但<strong>不希望</strong>根据 AGPLv3 公开您修改部分的源代码。</li>
|
||||
<li>(无论组织规模)您希望基于修改过的 Cherry Studio 提供网络服务(SaaS),但<strong>不希望</strong>根据 AGPLv3 向服务使用者提供修改后的源代码。</li>
|
||||
<li>(无论组织规模)您的公司政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件,或要求闭源分发及保密。</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li><strong>获取商业许可:</strong> 请通过邮箱 <a href="mailto:bd@cherry-ai.com"
|
||||
class="text-blue-600 hover:underline">bd@cherry-ai.com</a> 联系 Cherry Studio 开发团队洽谈商业授权事宜。</li>
|
||||
</ul>
|
||||
</section>
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">定义:"10人及以下"</h2>
|
||||
<p class="text-gray-700">
|
||||
指在您的组织(包括公司、非营利组织、政府机构、教育机构等任何实体)中,能够访问、使用或以任何方式直接或间接受益于本软件(Cherry
|
||||
Studio)功能的个人总数不超过10人。这包括但不限于开发者、测试人员、运营人员、最终用户、通过集成系统间接使用者等。
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. 贡献 (Contributions)</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>我们欢迎社区对 Cherry Studio 的贡献。所有向本项目提交的贡献都将被视为在 <strong>AGPLv3</strong> 许可证下提供。</li>
|
||||
<li>通过向本项目提交贡献(例如通过 Pull Request),即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。</li>
|
||||
<li>您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 Cherry Studio 版本中。</li>
|
||||
</ul>
|
||||
</section>
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">
|
||||
1. 开源许可证 (Open Source License): AGPLv3 - 适用于个人及10人及以下组织
|
||||
</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>
|
||||
如果您是个人用户,或者您的组织满足上述"10人及以下"的定义,您可以在
|
||||
<strong>AGPLv3</strong> 的条款下自由使用、修改和分发 Cherry Studio。AGPLv3 的完整文本可以访问
|
||||
<a href="https://www.gnu.org/licenses/agpl-3.0.html" class="text-blue-600 hover:underline"
|
||||
>https://www.gnu.org/licenses/agpl-3.0.html</a
|
||||
>
|
||||
获取。
|
||||
</li>
|
||||
<li>
|
||||
<strong>核心义务:</strong> AGPLv3 的一个关键要求是,如果您修改了 Cherry Studio
|
||||
并通过网络提供服务,或者分发了修改后的版本,您必须以 AGPLv3
|
||||
许可证向接收者提供相应的<strong>完整源代码</strong>。即使您符合"10人及以下"的标准,如果您希望避免此源代码公开义务,您也需要考虑获取商业许可证(见下文)。
|
||||
</li>
|
||||
<li>使用前请务必仔细阅读并理解 AGPLv3 的所有条款。</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. 其他条款 (Other Terms)</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>关于商业许可证的具体条款和条件,以双方签署的正式商业许可协议为准。</li>
|
||||
<li>项目维护者保留根据需要更新本许可政策(包括用户规模定义和阈值)的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。</li>
|
||||
</ul>
|
||||
</section>
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">
|
||||
2. 商业许可证 (Commercial License) - 适用于超过10人的组织,或希望规避 AGPLv3 义务的用户
|
||||
</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>
|
||||
<strong>强制要求:</strong>
|
||||
如果您的组织<strong>不</strong>满足上述"10人及以下"的定义(即有11人或更多人可以访问、使用或受益于本软件),您<strong>必须</strong>联系我们获取并签署一份商业许可证才能使用
|
||||
Cherry Studio。
|
||||
</li>
|
||||
<li>
|
||||
<strong>自愿选择:</strong> 即使您的组织满足"10人及以下"的条件,但如果您的使用场景<strong
|
||||
>无法满足 AGPLv3 的条款要求</strong
|
||||
>(特别是关于<strong>源代码公开</strong>的义务),或者您需要 AGPLv3
|
||||
<strong>未提供</strong>的特定商业条款(如保证、赔偿、无 Copyleft
|
||||
限制等),您也<strong>必须</strong>联系我们获取并签署一份商业许可证。
|
||||
</li>
|
||||
<li>
|
||||
<strong>需要商业许可证的常见情况包括(但不限于):</strong>
|
||||
<ul class="list-disc pl-6 mt-2 space-y-1">
|
||||
<li>您的组织规模超过10人。</li>
|
||||
<li>
|
||||
(无论组织规模)您希望分发修改过的 Cherry Studio 版本,但<strong>不希望</strong>根据 AGPLv3
|
||||
公开您修改部分的源代码。
|
||||
</li>
|
||||
<li>
|
||||
(无论组织规模)您希望基于修改过的 Cherry Studio 提供网络服务(SaaS),但<strong>不希望</strong>根据
|
||||
AGPLv3 向服务使用者提供修改后的源代码。
|
||||
</li>
|
||||
<li>
|
||||
(无论组织规模)您的公司政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件,或要求闭源分发及保密。
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
<strong>获取商业许可:</strong> 请通过邮箱
|
||||
<a href="mailto:bd@cherry-ai.com" class="text-blue-600 hover:underline">bd@cherry-ai.com</a> 联系 Cherry
|
||||
Studio 开发团队洽谈商业授权事宜。
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. 贡献 (Contributions)</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>
|
||||
我们欢迎社区对 Cherry Studio 的贡献。所有向本项目提交的贡献都将被视为在
|
||||
<strong>AGPLv3</strong> 许可证下提供。
|
||||
</li>
|
||||
<li>
|
||||
通过向本项目提交贡献(例如通过 Pull Request),即表示您同意您的代码以 AGPLv3
|
||||
许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
|
||||
</li>
|
||||
<li>您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 Cherry Studio 版本中。</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. 其他条款 (Other Terms)</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>关于商业许可证的具体条款和条件,以双方签署的正式商业许可协议为准。</li>
|
||||
<li>
|
||||
项目维护者保留根据需要更新本许可政策(包括用户规模定义和阈值)的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
</div>
|
||||
|
||||
<hr class="my-12 border-gray-300" />
|
||||
|
||||
<!-- English Version -->
|
||||
<div>
|
||||
<h1 class="text-3xl font-bold mb-8 text-gray-900">Licensing</h1>
|
||||
|
||||
<p class="mb-6 text-gray-700">This project employs a <strong>User-Segmented Dual Licensing</strong> model.</p>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">Core Principle</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>
|
||||
<strong>Individual Users and Organizations with 10 or Fewer Individuals:</strong> Governed by default
|
||||
under the <strong>GNU Affero General Public License v3.0 (AGPLv3)</strong>.
|
||||
</li>
|
||||
<li>
|
||||
<strong>Organizations with More Than 10 Individuals:</strong> <strong>Must</strong> obtain a
|
||||
<strong>Commercial License</strong>.
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">Definition: "10 or Fewer Individuals"</h2>
|
||||
<p class="text-gray-700">
|
||||
Refers to any organization (including companies, non-profits, government agencies, educational institutions,
|
||||
etc.) where the total number of individuals who can access, use, or in any way directly or indirectly
|
||||
benefit from the functionality of this software (Cherry Studio) does not exceed 10. This includes, but is
|
||||
not limited to, developers, testers, operations staff, end-users, and indirect users via integrated systems.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">
|
||||
1. Open Source License: AGPLv3 - For Individuals and Organizations of 10 or Fewer
|
||||
</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>
|
||||
If you are an individual user, or if your organization meets the "10 or Fewer Individuals" definition
|
||||
above, you are free to use, modify, and distribute Cherry Studio under the terms of the
|
||||
<strong>AGPLv3</strong>. The full text of the AGPLv3 can be found at
|
||||
<a href="https://www.gnu.org/licenses/agpl-3.0.html" class="text-blue-600 hover:underline"
|
||||
>https://www.gnu.org/licenses/agpl-3.0.html</a
|
||||
>.
|
||||
</li>
|
||||
<li>
|
||||
<strong>Core Obligation:</strong> A key requirement of the AGPLv3 is that if you modify Cherry Studio and
|
||||
make it available over a network, or distribute the modified version, you must provide the
|
||||
<strong>complete corresponding source code</strong> under the AGPLv3 license to the recipients. Even if
|
||||
you qualify under the "10 or Fewer Individuals" rule, if you wish to avoid this source code disclosure
|
||||
obligation, you will need to obtain a Commercial License (see below).
|
||||
</li>
|
||||
<li>Please read and understand the full terms of the AGPLv3 carefully before use.</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">
|
||||
2. Commercial License - For Organizations with More Than 10 Individuals, or Users Needing to Avoid AGPLv3
|
||||
Obligations
|
||||
</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>
|
||||
<strong>Mandatory Requirement:</strong> If your organization does <strong>not</strong> meet the "10 or
|
||||
Fewer Individuals" definition above (i.e., 11 or more individuals can access, use, or benefit from the
|
||||
software), you <strong>must</strong> contact us to obtain and execute a Commercial License to use Cherry
|
||||
Studio.
|
||||
</li>
|
||||
<li>
|
||||
<strong>Voluntary Option:</strong> Even if your organization meets the "10 or Fewer Individuals"
|
||||
condition, if your intended use case
|
||||
<strong>cannot comply with the terms of the AGPLv3</strong> (particularly the obligations regarding
|
||||
<strong>source code disclosure</strong>), or if you require specific commercial terms
|
||||
<strong>not offered</strong> by the AGPLv3 (such as warranties, indemnities, or freedom from copyleft
|
||||
restrictions), you also <strong>must</strong> contact us to obtain and execute a Commercial License.
|
||||
</li>
|
||||
<li>
|
||||
<strong>Common scenarios requiring a Commercial License include (but are not limited to):</strong>
|
||||
<ul class="list-disc pl-6 mt-2 space-y-1">
|
||||
<li>
|
||||
Your organization has more than 10 individuals who can access, use, or benefit from the software.
|
||||
</li>
|
||||
<li>
|
||||
(Regardless of organization size) You wish to distribute a modified version of Cherry Studio but
|
||||
<strong>do not want</strong> to disclose the source code of your modifications under AGPLv3.
|
||||
</li>
|
||||
<li>
|
||||
(Regardless of organization size) You wish to provide a network service (SaaS) based on a modified
|
||||
version of Cherry Studio but <strong>do not want</strong> to provide the modified source code to users
|
||||
of the service under AGPLv3.
|
||||
</li>
|
||||
<li>
|
||||
(Regardless of organization size) Your corporate policies, client contracts, or project requirements
|
||||
prohibit the use of AGPLv3-licensed software or mandate closed-source distribution and
|
||||
confidentiality.
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
<strong>Obtaining a Commercial License:</strong> Please contact the Cherry Studio development team via
|
||||
email at <a href="mailto:bd@cherry-ai.com" class="text-blue-600 hover:underline">bd@cherry-ai.com</a> to
|
||||
discuss commercial licensing options.
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. Contributions</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>
|
||||
We welcome community contributions to Cherry Studio. All contributions submitted to this project are
|
||||
considered to be offered under the <strong>AGPLv3</strong> license.
|
||||
</li>
|
||||
<li>
|
||||
By submitting a contribution to this project (e.g., via a Pull Request), you agree to license your code
|
||||
under the AGPLv3 to the project and all its downstream users (regardless of whether those users ultimately
|
||||
operate under AGPLv3 or a Commercial License).
|
||||
</li>
|
||||
<li>
|
||||
You also understand and agree that your contribution may be included in distributions of Cherry Studio
|
||||
offered under our commercial license.
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. Other Terms</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>
|
||||
The specific terms and conditions of the Commercial License are governed by the formal commercial license
|
||||
agreement signed by both parties.
|
||||
</li>
|
||||
<li>
|
||||
The project maintainers reserve the right to update this licensing policy (including the definition and
|
||||
threshold for user count) as needed. Updates will be communicated through official project channels (e.g.,
|
||||
code repository, official website).
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<hr class="my-12 border-gray-300">
|
||||
|
||||
<!-- English Version -->
|
||||
<div>
|
||||
<h1 class="text-3xl font-bold mb-8 text-gray-900">Licensing</h1>
|
||||
|
||||
<p class="mb-6 text-gray-700">This project employs a <strong>User-Segmented Dual Licensing</strong> model.</p>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">Core Principle</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li><strong>Individual Users and Organizations with 10 or Fewer Individuals:</strong> Governed by default
|
||||
under the <strong>GNU Affero General Public License v3.0 (AGPLv3)</strong>.</li>
|
||||
<li><strong>Organizations with More Than 10 Individuals:</strong> <strong>Must</strong> obtain a
|
||||
<strong>Commercial License</strong>.
|
||||
</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">Definition: "10 or Fewer Individuals"</h2>
|
||||
<p class="text-gray-700">
|
||||
Refers to any organization (including companies, non-profits, government agencies, educational institutions,
|
||||
etc.) where the total number of individuals who can access, use, or in any way directly or indirectly benefit
|
||||
from the functionality of this software (Cherry Studio) does not exceed 10. This includes, but is not limited
|
||||
to, developers, testers, operations staff, end-users, and indirect users via integrated systems.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">1. Open Source License: AGPLv3 - For Individuals and
|
||||
Organizations of 10 or Fewer</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>If you are an individual user, or if your organization meets the "10 or Fewer Individuals" definition
|
||||
above, you are free to use, modify, and distribute Cherry Studio under the terms of the
|
||||
<strong>AGPLv3</strong>. The full text of the AGPLv3 can be found at <a
|
||||
href="https://www.gnu.org/licenses/agpl-3.0.html"
|
||||
class="text-blue-600 hover:underline">https://www.gnu.org/licenses/agpl-3.0.html</a>.
|
||||
</li>
|
||||
<li><strong>Core Obligation:</strong> A key requirement of the AGPLv3 is that if you modify Cherry Studio and
|
||||
make it available over a network, or distribute the modified version, you must provide the <strong>complete
|
||||
corresponding source code</strong> under the AGPLv3 license to the recipients. Even if you qualify under
|
||||
the "10 or Fewer Individuals" rule, if you wish to avoid this source code disclosure obligation, you will
|
||||
need to obtain a Commercial License (see below).</li>
|
||||
<li>Please read and understand the full terms of the AGPLv3 carefully before use.</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">2. Commercial License - For Organizations with More Than 10
|
||||
Individuals, or Users Needing to Avoid AGPLv3 Obligations</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li><strong>Mandatory Requirement:</strong> If your organization does <strong>not</strong> meet the "10 or
|
||||
Fewer Individuals" definition above (i.e., 11 or more individuals can access, use, or benefit from the
|
||||
software), you <strong>must</strong> contact us to obtain and execute a Commercial License to use Cherry
|
||||
Studio.</li>
|
||||
<li><strong>Voluntary Option:</strong> Even if your organization meets the "10 or Fewer Individuals"
|
||||
condition, if your intended use case <strong>cannot comply with the terms of the AGPLv3</strong>
|
||||
(particularly the obligations regarding <strong>source code disclosure</strong>), or if you require specific
|
||||
commercial terms <strong>not offered</strong> by the AGPLv3 (such as warranties, indemnities, or freedom
|
||||
from copyleft restrictions), you also <strong>must</strong> contact us to obtain and execute a Commercial
|
||||
License.</li>
|
||||
<li><strong>Common scenarios requiring a Commercial License include (but are not limited to):</strong>
|
||||
<ul class="list-disc pl-6 mt-2 space-y-1">
|
||||
<li>Your organization has more than 10 individuals who can access, use, or benefit from the software.</li>
|
||||
<li>(Regardless of organization size) You wish to distribute a modified version of Cherry Studio but
|
||||
<strong>do not want</strong> to disclose the source code of your modifications under AGPLv3.
|
||||
</li>
|
||||
<li>(Regardless of organization size) You wish to provide a network service (SaaS) based on a modified
|
||||
version of Cherry Studio but <strong>do not want</strong> to provide the modified source code to users
|
||||
of the service under AGPLv3.</li>
|
||||
<li>(Regardless of organization size) Your corporate policies, client contracts, or project requirements
|
||||
prohibit the use of AGPLv3-licensed software or mandate closed-source distribution and confidentiality.
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li><strong>Obtaining a Commercial License:</strong> Please contact the Cherry Studio development team via
|
||||
email at <a href="mailto:bd@cherry-ai.com" class="text-blue-600 hover:underline">bd@cherry-ai.com</a> to
|
||||
discuss commercial licensing options.</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">3. Contributions</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>We welcome community contributions to Cherry Studio. All contributions submitted to this project are
|
||||
considered to be offered under the <strong>AGPLv3</strong> license.</li>
|
||||
<li>By submitting a contribution to this project (e.g., via a Pull Request), you agree to license your code
|
||||
under the AGPLv3 to the project and all its downstream users (regardless of whether those users ultimately
|
||||
operate under AGPLv3 or a Commercial License).</li>
|
||||
<li>You also understand and agree that your contribution may be included in distributions of Cherry Studio
|
||||
offered under our commercial license.</li>
|
||||
</ul>
|
||||
</section>
|
||||
|
||||
<section class="mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4 text-gray-900">4. Other Terms</h2>
|
||||
<ul class="list-disc pl-6 space-y-2 text-gray-700">
|
||||
<li>The specific terms and conditions of the Commercial License are governed by the formal commercial license
|
||||
agreement signed by both parties.</li>
|
||||
<li>The project maintainers reserve the right to update this licensing policy (including the definition and
|
||||
threshold for user count) as needed. Updates will be communicated through official project channels (e.g.,
|
||||
code repository, official website).</li>
|
||||
</ul>
|
||||
</section>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -7,6 +7,7 @@ import { preferenceService } from '@data/PreferenceService'
|
||||
import { loggerService } from '@logger'
|
||||
import { isLinux, isMac, isPortable, isWin } from '@main/constant'
|
||||
import { generateSignature } from '@main/integration/cherryin'
|
||||
import anthropicService from '@main/services/AnthropicService'
|
||||
import { getBinaryPath, isBinaryExists, runInstallScript } from '@main/utils/process'
|
||||
import { handleZoomFactor } from '@main/utils/zoom'
|
||||
import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
|
||||
@ -27,7 +28,7 @@ import DxtService from './services/DxtService'
|
||||
import { ExportService } from './services/ExportService'
|
||||
import { fileStorage as fileManager } from './services/FileStorage'
|
||||
import FileService from './services/FileSystemService'
|
||||
import KnowledgeService from './services/KnowledgeService'
|
||||
import KnowledgeService from './services/knowledge/KnowledgeService'
|
||||
import mcpService from './services/MCPService'
|
||||
import MemoryService from './services/memory/MemoryService'
|
||||
import { openTraceWindow, setTraceWindowTitle } from './services/NodeTraceService'
|
||||
@ -524,7 +525,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
}
|
||||
})
|
||||
|
||||
// knowledge base
|
||||
ipcMain.handle(IpcChannel.KnowledgeBase_Create, KnowledgeService.create.bind(KnowledgeService))
|
||||
ipcMain.handle(IpcChannel.KnowledgeBase_Reset, KnowledgeService.reset.bind(KnowledgeService))
|
||||
ipcMain.handle(IpcChannel.KnowledgeBase_Delete, KnowledgeService.delete.bind(KnowledgeService))
|
||||
@ -588,6 +588,41 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
return [width, height]
|
||||
})
|
||||
|
||||
// Window Controls
|
||||
ipcMain.handle(IpcChannel.Windows_Minimize, () => {
|
||||
checkMainWindow()
|
||||
mainWindow.minimize()
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Windows_Maximize, () => {
|
||||
checkMainWindow()
|
||||
mainWindow.maximize()
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Windows_Unmaximize, () => {
|
||||
checkMainWindow()
|
||||
mainWindow.unmaximize()
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Windows_Close, () => {
|
||||
checkMainWindow()
|
||||
mainWindow.close()
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Windows_IsMaximized, () => {
|
||||
checkMainWindow()
|
||||
return mainWindow.isMaximized()
|
||||
})
|
||||
|
||||
// Send maximized state changes to renderer
|
||||
mainWindow.on('maximize', () => {
|
||||
mainWindow.webContents.send(IpcChannel.Windows_MaximizedChanged, true)
|
||||
})
|
||||
|
||||
mainWindow.on('unmaximize', () => {
|
||||
mainWindow.webContents.send(IpcChannel.Windows_MaximizedChanged, false)
|
||||
})
|
||||
|
||||
// VertexAI
|
||||
ipcMain.handle(IpcChannel.VertexAI_GetAuthHeaders, async (_, params) => {
|
||||
return vertexAIService.getAuthHeaders(params)
|
||||
@ -747,6 +782,16 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
addStreamMessage(spanId, modelName, context, msg)
|
||||
)
|
||||
|
||||
// Anthropic OAuth
|
||||
ipcMain.handle(IpcChannel.Anthropic_StartOAuthFlow, () => anthropicService.startOAuthFlow())
|
||||
ipcMain.handle(IpcChannel.Anthropic_CompleteOAuthWithCode, (_, code: string) =>
|
||||
anthropicService.completeOAuthWithCode(code)
|
||||
)
|
||||
ipcMain.handle(IpcChannel.Anthropic_CancelOAuthFlow, () => anthropicService.cancelOAuthFlow())
|
||||
ipcMain.handle(IpcChannel.Anthropic_GetAccessToken, () => anthropicService.getValidAccessToken())
|
||||
ipcMain.handle(IpcChannel.Anthropic_HasCredentials, () => anthropicService.hasCredentials())
|
||||
ipcMain.handle(IpcChannel.Anthropic_ClearCredentials, () => anthropicService.clearCredentials())
|
||||
|
||||
// CodeTools
|
||||
ipcMain.handle(IpcChannel.CodeTools_Run, codeToolsService.run)
|
||||
|
||||
|
||||
63
src/main/knowledge/langchain/embeddings/EmbeddingsFactory.ts
Normal file
@ -0,0 +1,63 @@
|
||||
import { VoyageEmbeddings } from '@langchain/community/embeddings/voyage'
|
||||
import type { Embeddings } from '@langchain/core/embeddings'
|
||||
import { OllamaEmbeddings } from '@langchain/ollama'
|
||||
import { AzureOpenAIEmbeddings, OpenAIEmbeddings } from '@langchain/openai'
|
||||
import { ApiClient, SystemProviderIds } from '@types'
|
||||
|
||||
import { isJinaEmbeddingsModel, JinaEmbeddings } from './JinaEmbeddings'
|
||||
|
||||
export default class EmbeddingsFactory {
|
||||
static create({ embedApiClient, dimensions }: { embedApiClient: ApiClient; dimensions?: number }): Embeddings {
|
||||
const batchSize = 10
|
||||
const { model, provider, apiKey, apiVersion, baseURL } = embedApiClient
|
||||
if (provider === SystemProviderIds.ollama) {
|
||||
let baseUrl = baseURL
|
||||
if (baseURL.includes('v1/')) {
|
||||
baseUrl = baseURL.replace('v1/', '')
|
||||
}
|
||||
const headers = apiKey
|
||||
? {
|
||||
Authorization: `Bearer ${apiKey}`
|
||||
}
|
||||
: undefined
|
||||
return new OllamaEmbeddings({
|
||||
model: model,
|
||||
baseUrl,
|
||||
...headers
|
||||
})
|
||||
} else if (provider === SystemProviderIds.voyageai) {
|
||||
return new VoyageEmbeddings({
|
||||
modelName: model,
|
||||
apiKey,
|
||||
outputDimension: dimensions,
|
||||
batchSize
|
||||
})
|
||||
}
|
||||
if (isJinaEmbeddingsModel(model)) {
|
||||
return new JinaEmbeddings({
|
||||
model,
|
||||
apiKey,
|
||||
batchSize,
|
||||
dimensions,
|
||||
baseUrl: baseURL
|
||||
})
|
||||
}
|
||||
if (apiVersion !== undefined) {
|
||||
return new AzureOpenAIEmbeddings({
|
||||
azureOpenAIApiKey: apiKey,
|
||||
azureOpenAIApiVersion: apiVersion,
|
||||
azureOpenAIApiDeploymentName: model,
|
||||
azureOpenAIEndpoint: baseURL,
|
||||
dimensions,
|
||||
batchSize
|
||||
})
|
||||
}
|
||||
return new OpenAIEmbeddings({
|
||||
model,
|
||||
apiKey,
|
||||
dimensions,
|
||||
batchSize,
|
||||
configuration: { baseURL }
|
||||
})
|
||||
}
|
||||
}
|
||||
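For context, a minimal usage sketch of the factory above (not part of the diff): the provider/model/API-key values are placeholders, and the `ApiClient` cast is used because only the fields the factory reads (model, provider, apiKey, apiVersion, baseURL) are known from this change.

```typescript
// Hypothetical usage sketch for EmbeddingsFactory (placeholder values, assumed path aliases).
import type { ApiClient } from '@types'

import EmbeddingsFactory from '@main/knowledge/langchain/embeddings/EmbeddingsFactory'

const embedApiClient = {
  provider: 'openai', // not ollama/voyageai and no apiVersion, so this falls through to OpenAIEmbeddings
  model: 'text-embedding-3-small', // placeholder model name
  apiKey: process.env.OPENAI_API_KEY ?? '',
  baseURL: 'https://api.openai.com/v1'
} as ApiClient // cast: ApiClient's full shape is defined elsewhere in the codebase

const embeddings = EmbeddingsFactory.create({ embedApiClient, dimensions: 1536 })

// embedQuery/embedDocuments come from the LangChain Embeddings interface.
embeddings.embedQuery('hello world').then((vector) => console.log(vector.length))
```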
199
src/main/knowledge/langchain/embeddings/JinaEmbeddings.ts
Normal file
@ -0,0 +1,199 @@
|
||||
import { Embeddings, type EmbeddingsParams } from '@langchain/core/embeddings'
|
||||
import { chunkArray } from '@langchain/core/utils/chunk_array'
|
||||
import { getEnvironmentVariable } from '@langchain/core/utils/env'
|
||||
import z from 'zod/v4'
|
||||
|
||||
const jinaModelSchema = z.union([
|
||||
z.literal('jina-clip-v2'),
|
||||
z.literal('jina-embeddings-v3'),
|
||||
z.literal('jina-colbert-v2'),
|
||||
z.literal('jina-clip-v1'),
|
||||
z.literal('jina-colbert-v1-en'),
|
||||
z.literal('jina-embeddings-v2-base-es'),
|
||||
z.literal('jina-embeddings-v2-base-code'),
|
||||
z.literal('jina-embeddings-v2-base-de'),
|
||||
z.literal('jina-embeddings-v2-base-zh'),
|
||||
z.literal('jina-embeddings-v2-base-en')
|
||||
])
|
||||
|
||||
type JinaModel = z.infer<typeof jinaModelSchema>
|
||||
|
||||
export const isJinaEmbeddingsModel = (model: string): model is JinaModel => {
|
||||
return jinaModelSchema.safeParse(model).success
|
||||
}
|
||||
|
||||
interface JinaEmbeddingsParams extends EmbeddingsParams {
|
||||
/** Model name to use */
|
||||
model: JinaModel
|
||||
|
||||
baseUrl?: string
|
||||
|
||||
/**
|
||||
* Timeout to use when making requests to Jina.
|
||||
*/
|
||||
timeout?: number
|
||||
|
||||
/**
|
||||
* The maximum number of documents to embed in a single request.
|
||||
*/
|
||||
batchSize?: number
|
||||
|
||||
/**
|
||||
* Whether to strip new lines from the input text.
|
||||
*/
|
||||
stripNewLines?: boolean
|
||||
|
||||
/**
|
||||
* The dimensions of the embedding.
|
||||
*/
|
||||
dimensions?: number
|
||||
|
||||
/**
|
||||
* Scales the embedding so its Euclidean (L2) norm becomes 1, preserving direction. Useful when downstream tasks involve dot products, classification, or visualization.
|
||||
*/
|
||||
normalized?: boolean
|
||||
}
|
||||
|
||||
type JinaMultiModelInput =
|
||||
| {
|
||||
text: string
|
||||
image?: never
|
||||
}
|
||||
| {
|
||||
image: string
|
||||
text?: never
|
||||
}
|
||||
|
||||
type JinaEmbeddingsInput = string | JinaMultiModelInput
|
||||
|
||||
interface EmbeddingCreateParams {
|
||||
model: JinaEmbeddingsParams['model']
|
||||
|
||||
/**
|
||||
* Input can be strings or JinaMultiModelInputs; to embed images, use JinaMultiModelInputs.
|
||||
*/
|
||||
input: JinaEmbeddingsInput[]
|
||||
dimensions: number
|
||||
task?: 'retrieval.query' | 'retrieval.passage'
|
||||
}
|
||||
|
||||
interface EmbeddingResponse {
|
||||
model: string
|
||||
object: string
|
||||
usage: {
|
||||
total_tokens: number
|
||||
prompt_tokens: number
|
||||
}
|
||||
data: {
|
||||
object: string
|
||||
index: number
|
||||
embedding: number[]
|
||||
}[]
|
||||
}
|
||||
|
||||
interface EmbeddingErrorResponse {
|
||||
detail: string
|
||||
}
|
||||
|
||||
export class JinaEmbeddings extends Embeddings implements JinaEmbeddingsParams {
|
||||
model: JinaEmbeddingsParams['model'] = 'jina-clip-v2'
|
||||
|
||||
batchSize = 24
|
||||
|
||||
baseUrl = 'https://api.jina.ai/v1/embeddings'
|
||||
|
||||
stripNewLines = true
|
||||
|
||||
dimensions = 1024
|
||||
|
||||
apiKey: string
|
||||
|
||||
constructor(
|
||||
fields?: Partial<JinaEmbeddingsParams> & {
|
||||
apiKey?: string
|
||||
}
|
||||
) {
|
||||
const fieldsWithDefaults = { maxConcurrency: 2, ...fields }
|
||||
super(fieldsWithDefaults)
|
||||
|
||||
const apiKey =
|
||||
fieldsWithDefaults?.apiKey || getEnvironmentVariable('JINA_API_KEY') || getEnvironmentVariable('JINA_AUTH_TOKEN')
|
||||
|
||||
if (!apiKey) throw new Error('Jina API key not found')
|
||||
|
||||
this.apiKey = apiKey
|
||||
this.baseUrl = fieldsWithDefaults?.baseUrl ? `${fieldsWithDefaults?.baseUrl}embeddings` : this.baseUrl
|
||||
this.model = fieldsWithDefaults?.model ?? this.model
|
||||
this.dimensions = fieldsWithDefaults?.dimensions ?? this.dimensions
|
||||
this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize
|
||||
this.stripNewLines = fieldsWithDefaults?.stripNewLines ?? this.stripNewLines
|
||||
}
|
||||
|
||||
private doStripNewLines(input: JinaEmbeddingsInput[]) {
|
||||
if (this.stripNewLines) {
|
||||
return input.map((i) => {
|
||||
if (typeof i === 'string') {
|
||||
return i.replace(/\n/g, ' ')
|
||||
}
|
||||
if (i.text) {
|
||||
return { text: i.text.replace(/\n/g, ' ') }
|
||||
}
|
||||
return i
|
||||
})
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
async embedDocuments(input: JinaEmbeddingsInput[]): Promise<number[][]> {
|
||||
const batches = chunkArray(this.doStripNewLines(input), this.batchSize)
|
||||
const batchRequests = batches.map((batch) => {
|
||||
const params = this.getParams(batch)
|
||||
return this.embeddingWithRetry(params)
|
||||
})
|
||||
|
||||
const batchResponses = await Promise.all(batchRequests)
|
||||
const embeddings: number[][] = []
|
||||
|
||||
for (let i = 0; i < batchResponses.length; i += 1) {
|
||||
const batch = batches[i]
|
||||
const batchResponse = batchResponses[i] || []
|
||||
for (let j = 0; j < batch.length; j += 1) {
|
||||
embeddings.push(batchResponse[j])
|
||||
}
|
||||
}
|
||||
|
||||
return embeddings
|
||||
}
|
||||
|
||||
async embedQuery(input: JinaEmbeddingsInput): Promise<number[]> {
|
||||
const params = this.getParams(this.doStripNewLines([input]), true)
|
||||
|
||||
const embeddings = (await this.embeddingWithRetry(params)) || [[]]
|
||||
return embeddings[0]
|
||||
}
|
||||
|
||||
private getParams(input: JinaEmbeddingsInput[], query?: boolean): EmbeddingCreateParams {
|
||||
return {
|
||||
model: this.model,
|
||||
input,
|
||||
dimensions: this.dimensions,
|
||||
task: query ? 'retrieval.query' : this.model === 'jina-clip-v2' ? undefined : 'retrieval.passage'
|
||||
}
|
||||
}
|
||||
|
||||
private async embeddingWithRetry(body: EmbeddingCreateParams) {
|
||||
const response = await fetch(this.baseUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${this.apiKey}`
|
||||
},
|
||||
body: JSON.stringify(body)
|
||||
})
|
||||
const embeddingData: EmbeddingResponse | EmbeddingErrorResponse = await response.json()
|
||||
if ('detail' in embeddingData && embeddingData.detail) {
|
||||
throw new Error(`${embeddingData.detail}`)
|
||||
}
|
||||
return (embeddingData as EmbeddingResponse).data.map(({ embedding }) => embedding)
|
||||
}
|
||||
}
|
||||
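A short sketch of how the multimodal input type above can be exercised (not part of the diff; the model name follows the schema in this file, the API key and image URL are placeholders):

```typescript
// Hypothetical usage sketch for JinaEmbeddings (placeholder values).
import { JinaEmbeddings } from '@main/knowledge/langchain/embeddings/JinaEmbeddings'

async function demo() {
  const jina = new JinaEmbeddings({
    model: 'jina-clip-v2',
    apiKey: process.env.JINA_API_KEY, // or rely on the JINA_API_KEY env fallback
    dimensions: 1024
  })

  // Plain strings and { image } entries can be mixed in a single batch.
  const vectors = await jina.embedDocuments([
    'A photo of a cat',
    { image: 'https://example.com/cat.jpg' } // placeholder URL
  ])
  console.log(vectors.length) // 2
}

demo().catch(console.error)
```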
25
src/main/knowledge/langchain/embeddings/TextEmbeddings.ts
Normal file
@ -0,0 +1,25 @@
|
||||
import type { Embeddings as BaseEmbeddings } from '@langchain/core/embeddings'
|
||||
import { TraceMethod } from '@mcp-trace/trace-core'
|
||||
import { ApiClient } from '@types'
|
||||
|
||||
import EmbeddingsFactory from './EmbeddingsFactory'
|
||||
|
||||
export default class TextEmbeddings {
|
||||
private sdk: BaseEmbeddings
|
||||
constructor({ embedApiClient, dimensions }: { embedApiClient: ApiClient; dimensions?: number }) {
|
||||
this.sdk = EmbeddingsFactory.create({
|
||||
embedApiClient,
|
||||
dimensions
|
||||
})
|
||||
}
|
||||
|
||||
@TraceMethod({ spanName: 'embedDocuments', tag: 'Embeddings' })
|
||||
public async embedDocuments(texts: string[]): Promise<number[][]> {
|
||||
return this.sdk.embedDocuments(texts)
|
||||
}
|
||||
|
||||
@TraceMethod({ spanName: 'embedQuery', tag: 'Embeddings' })
|
||||
public async embedQuery(text: string): Promise<number[]> {
|
||||
return this.sdk.embedQuery(text)
|
||||
}
|
||||
}
|
||||
97
src/main/knowledge/langchain/loader/MarkdownLoader.ts
Normal file
@ -0,0 +1,97 @@
|
||||
import { BaseDocumentLoader } from '@langchain/core/document_loaders/base'
|
||||
import { Document } from '@langchain/core/documents'
|
||||
import { readTextFileWithAutoEncoding } from '@main/utils/file'
|
||||
import MarkdownIt from 'markdown-it'
|
||||
|
||||
export class MarkdownLoader extends BaseDocumentLoader {
|
||||
private path: string
|
||||
private md: MarkdownIt
|
||||
|
||||
constructor(path: string) {
|
||||
super()
|
||||
this.path = path
|
||||
this.md = new MarkdownIt()
|
||||
}
|
||||
public async load(): Promise<Document[]> {
|
||||
const content = await readTextFileWithAutoEncoding(this.path)
|
||||
return this.parseMarkdown(content)
|
||||
}
|
||||
|
||||
private parseMarkdown(content: string): Document[] {
|
||||
const tokens = this.md.parse(content, {})
|
||||
const documents: Document[] = []
|
||||
|
||||
let currentSection: {
|
||||
heading?: string
|
||||
level?: number
|
||||
content: string
|
||||
startLine?: number
|
||||
} = { content: '' }
|
||||
|
||||
let i = 0
|
||||
while (i < tokens.length) {
|
||||
const token = tokens[i]
|
||||
|
||||
if (token.type === 'heading_open') {
|
||||
// Save previous section if it has content
|
||||
if (currentSection.content.trim()) {
|
||||
documents.push(
|
||||
new Document({
|
||||
pageContent: currentSection.content.trim(),
|
||||
metadata: {
|
||||
source: this.path,
|
||||
heading: currentSection.heading || 'Introduction',
|
||||
level: currentSection.level || 0,
|
||||
startLine: currentSection.startLine || 0
|
||||
}
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
// Start new section
|
||||
const level = parseInt(token.tag.slice(1)) // Extract number from h1, h2, etc.
|
||||
const headingContent = tokens[i + 1]?.content || ''
|
||||
|
||||
currentSection = {
|
||||
heading: headingContent,
|
||||
level: level,
|
||||
content: '',
|
||||
startLine: token.map?.[0] || 0
|
||||
}
|
||||
|
||||
// Skip heading_open, inline, heading_close tokens
|
||||
i += 3
|
||||
continue
|
||||
}
|
||||
|
||||
// Add token content to current section
|
||||
if (token.content) {
|
||||
currentSection.content += token.content
|
||||
}
|
||||
|
||||
// Add newlines for block tokens
|
||||
if (token.block && token.type !== 'heading_close') {
|
||||
currentSection.content += '\n'
|
||||
}
|
||||
|
||||
i++
|
||||
}
|
||||
|
||||
// Add the last section
|
||||
if (currentSection.content.trim()) {
|
||||
documents.push(
|
||||
new Document({
|
||||
pageContent: currentSection.content.trim(),
|
||||
metadata: {
|
||||
source: this.path,
|
||||
heading: currentSection.heading || 'Introduction',
|
||||
level: currentSection.level || 0,
|
||||
startLine: currentSection.startLine || 0
|
||||
}
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
return documents
|
||||
}
|
||||
}
|
||||
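A brief sketch of the loader above in isolation (not part of the diff; the file path is a placeholder):

```typescript
// Hypothetical usage sketch for MarkdownLoader.
import { MarkdownLoader } from '@main/knowledge/langchain/loader/MarkdownLoader'

async function demo() {
  const loader = new MarkdownLoader('/tmp/example.md') // placeholder path
  const docs = await loader.load()
  for (const doc of docs) {
    // Each document is one heading-delimited section, with its heading, level and start line.
    console.log(doc.metadata.heading, doc.metadata.level, doc.metadata.startLine)
  }
}

demo().catch(console.error)
```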
50
src/main/knowledge/langchain/loader/NoteLoader.ts
Normal file
@ -0,0 +1,50 @@
|
||||
import { BaseDocumentLoader } from '@langchain/core/document_loaders/base'
|
||||
import { Document } from '@langchain/core/documents'
|
||||
|
||||
export class NoteLoader extends BaseDocumentLoader {
|
||||
private text: string
|
||||
private sourceUrl?: string
|
||||
constructor(
|
||||
public _text: string,
|
||||
public _sourceUrl?: string
|
||||
) {
|
||||
super()
|
||||
this.text = _text
|
||||
this.sourceUrl = _sourceUrl
|
||||
}
|
||||
|
||||
/**
|
||||
* A protected method that takes a `raw` string as a parameter and returns
|
||||
* a promise that resolves to an array containing the raw text as a single
|
||||
* element.
|
||||
* @param raw The raw text to be parsed.
|
||||
* @returns A promise that resolves to an array containing the raw text as a single element.
|
||||
*/
|
||||
protected async parse(raw: string): Promise<string[]> {
|
||||
return [raw]
|
||||
}
|
||||
|
||||
public async load(): Promise<Document[]> {
|
||||
const metadata = { source: this.sourceUrl || 'note' }
|
||||
const parsed = await this.parse(this.text)
|
||||
parsed.forEach((pageContent, i) => {
|
||||
if (typeof pageContent !== 'string') {
|
||||
throw new Error(`Expected string at position ${i}, got ${typeof pageContent}`)
|
||||
}
|
||||
})
|
||||
|
||||
return parsed.map(
|
||||
(pageContent, i) =>
|
||||
new Document({
|
||||
pageContent,
|
||||
metadata:
|
||||
parsed.length === 1
|
||||
? metadata
|
||||
: {
|
||||
...metadata,
|
||||
line: i + 1
|
||||
}
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
170
src/main/knowledge/langchain/loader/YoutubeLoader.ts
Normal file
@ -0,0 +1,170 @@
|
||||
import { BaseDocumentLoader } from '@langchain/core/document_loaders/base'
|
||||
import { Document } from '@langchain/core/documents'
|
||||
import { Innertube } from 'youtubei.js'
|
||||
|
||||
// ... (the YoutubeConfig and VideoMetadata interface definitions are unchanged)
|
||||
|
||||
/**
|
||||
* Configuration options for the YoutubeLoader class. Includes properties
|
||||
* such as the videoId, language, and addVideoInfo.
|
||||
*/
|
||||
interface YoutubeConfig {
|
||||
videoId: string
|
||||
language?: string
|
||||
addVideoInfo?: boolean
|
||||
// New option to control the output format
|
||||
transcriptFormat?: 'text' | 'srt'
|
||||
}
|
||||
|
||||
/**
|
||||
* Metadata of a YouTube video. Includes properties such as the source
|
||||
* (videoId), description, title, view_count, author, and category.
|
||||
*/
|
||||
interface VideoMetadata {
|
||||
source: string
|
||||
description?: string
|
||||
title?: string
|
||||
view_count?: number
|
||||
author?: string
|
||||
category?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* A document loader for loading data from YouTube videos. It uses the
|
||||
* youtubei.js library to fetch the transcript and video metadata.
|
||||
* @example
|
||||
* ```typescript
|
||||
* const loader = new YoutubeLoader({
|
||||
* videoId: "VIDEO_ID",
|
||||
* language: "en",
|
||||
* addVideoInfo: true,
|
||||
* transcriptFormat: "srt" // request the SRT format
|
||||
* });
|
||||
* const docs = await loader.load();
|
||||
* console.log(docs[0].pageContent);
|
||||
* ```
|
||||
*/
|
||||
export class YoutubeLoader extends BaseDocumentLoader {
|
||||
private videoId: string
|
||||
private language?: string
|
||||
private addVideoInfo: boolean
|
||||
// Private field for the transcript format option
|
||||
private transcriptFormat: 'text' | 'srt'
|
||||
|
||||
constructor(config: YoutubeConfig) {
|
||||
super()
|
||||
this.videoId = config.videoId
|
||||
this.language = config?.language
|
||||
this.addVideoInfo = config?.addVideoInfo ?? false
|
||||
// Initialize the format option; defaults to 'text' for backward compatibility
|
||||
this.transcriptFormat = config?.transcriptFormat ?? 'text'
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the videoId from a YouTube video URL.
|
||||
* @param url The URL of the YouTube video.
|
||||
* @returns The videoId of the YouTube video.
|
||||
*/
|
||||
private static getVideoID(url: string): string {
|
||||
const match = url.match(/.*(?:youtu.be\/|v\/|u\/\w\/|embed\/|watch\?v=)([^#&?]*).*/)
|
||||
if (match !== null && match[1].length === 11) {
|
||||
return match[1]
|
||||
} else {
|
||||
throw new Error('Failed to get youtube video id from the url')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new instance of the YoutubeLoader class from a YouTube video
|
||||
* URL.
|
||||
* @param url The URL of the YouTube video.
|
||||
* @param config Optional configuration options for the YoutubeLoader instance, excluding the videoId.
|
||||
* @returns A new instance of the YoutubeLoader class.
|
||||
*/
|
||||
static createFromUrl(url: string, config?: Omit<YoutubeConfig, 'videoId'>): YoutubeLoader {
|
||||
const videoId = YoutubeLoader.getVideoID(url)
|
||||
return new YoutubeLoader({ ...config, videoId })
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper: converts milliseconds to the SRT timestamp format (HH:MM:SS,ms).
* @param ms The number of milliseconds.
* @returns The formatted timestamp string.
|
||||
*/
|
||||
private static formatTimestamp(ms: number): string {
|
||||
const totalSeconds = Math.floor(ms / 1000)
|
||||
const hours = Math.floor(totalSeconds / 3600)
|
||||
.toString()
|
||||
.padStart(2, '0')
|
||||
const minutes = Math.floor((totalSeconds % 3600) / 60)
|
||||
.toString()
|
||||
.padStart(2, '0')
|
||||
const seconds = (totalSeconds % 60).toString().padStart(2, '0')
|
||||
const milliseconds = (ms % 1000).toString().padStart(3, '0')
|
||||
return `${hours}:${minutes}:${seconds},${milliseconds}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the transcript and video metadata from the specified YouTube
|
||||
* video. It can return the transcript as plain text or in SRT format.
|
||||
* @returns An array of Documents representing the retrieved data.
|
||||
*/
|
||||
async load(): Promise<Document[]> {
|
||||
const metadata: VideoMetadata = {
|
||||
source: this.videoId
|
||||
}
|
||||
|
||||
try {
|
||||
const youtube = await Innertube.create({
|
||||
lang: this.language,
|
||||
retrieve_player: false
|
||||
})
|
||||
|
||||
const info = await youtube.getInfo(this.videoId)
|
||||
const transcriptData = await info.getTranscript()
|
||||
|
||||
if (!transcriptData.transcript.content?.body?.initial_segments) {
|
||||
throw new Error('Transcript segments not found in the response.')
|
||||
}
|
||||
|
||||
const segments = transcriptData.transcript.content.body.initial_segments
|
||||
|
||||
let pageContent: string
|
||||
|
||||
// Decide how to format the transcript based on the transcriptFormat option
|
||||
if (this.transcriptFormat === 'srt') {
|
||||
// Format the transcript segments as SRT
|
||||
pageContent = segments
|
||||
.map((segment, index) => {
|
||||
const srtIndex = index + 1
|
||||
const startTime = YoutubeLoader.formatTimestamp(Number(segment.start_ms))
|
||||
const endTime = YoutubeLoader.formatTimestamp(Number(segment.end_ms))
|
||||
const text = segment.snippet?.text || '' // use segment.snippet.text as the caption text
|
||||
|
||||
return `${srtIndex}\n${startTime} --> ${endTime}\n${text}`
|
||||
})
|
||||
.join('\n\n') // separate SRT blocks with a blank line
|
||||
} else {
|
||||
// Original behavior: concatenate the segments as plain text
|
||||
pageContent = segments.map((segment) => segment.snippet?.text || '').join(' ')
|
||||
}
|
||||
|
||||
if (this.addVideoInfo) {
|
||||
const basicInfo = info.basic_info
|
||||
metadata.description = basicInfo.short_description
|
||||
metadata.title = basicInfo.title
|
||||
metadata.view_count = basicInfo.view_count
|
||||
metadata.author = basicInfo.author
|
||||
}
|
||||
|
||||
const document = new Document({
|
||||
pageContent,
|
||||
metadata
|
||||
})
|
||||
|
||||
return [document]
|
||||
} catch (e: unknown) {
|
||||
throw new Error(`Failed to get YouTube video transcription: ${(e as Error).message}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
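A sketch of the SRT mode described in the class comment above (not part of the diff; the video URL is a placeholder):

```typescript
// Hypothetical usage sketch for YoutubeLoader in SRT mode.
import { YoutubeLoader } from '@main/knowledge/langchain/loader/YoutubeLoader'

async function demo() {
  const loader = YoutubeLoader.createFromUrl('https://www.youtube.com/watch?v=dQw4w9WgXcQ', {
    language: 'en',
    addVideoInfo: true,
    transcriptFormat: 'srt'
  })
  const [doc] = await loader.load()
  // pageContent holds SRT blocks ("1\n00:00:00,000 --> 00:00:05,000\n..."),
  // metadata carries title/author because addVideoInfo is true.
  console.log(doc.metadata.title, doc.pageContent.slice(0, 80))
}

demo().catch(console.error)
```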
235
src/main/knowledge/langchain/loader/index.ts
Normal file
@ -0,0 +1,235 @@
|
||||
import { DocxLoader } from '@langchain/community/document_loaders/fs/docx'
|
||||
import { EPubLoader } from '@langchain/community/document_loaders/fs/epub'
|
||||
import { PDFLoader } from '@langchain/community/document_loaders/fs/pdf'
|
||||
import { PPTXLoader } from '@langchain/community/document_loaders/fs/pptx'
|
||||
import { CheerioWebBaseLoader } from '@langchain/community/document_loaders/web/cheerio'
|
||||
import { SitemapLoader } from '@langchain/community/document_loaders/web/sitemap'
|
||||
import { FaissStore } from '@langchain/community/vectorstores/faiss'
|
||||
import { Document } from '@langchain/core/documents'
|
||||
import { loggerService } from '@logger'
|
||||
import { UrlSource } from '@main/utils/knowledge'
|
||||
import { LoaderReturn } from '@shared/config/types'
|
||||
import { FileMetadata, FileTypes, KnowledgeBaseParams } from '@types'
|
||||
import { randomUUID } from 'crypto'
|
||||
import { JSONLoader } from 'langchain/document_loaders/fs/json'
|
||||
import { TextLoader } from 'langchain/document_loaders/fs/text'
|
||||
|
||||
import { SplitterFactory } from '../splitter'
|
||||
import { MarkdownLoader } from './MarkdownLoader'
|
||||
import { NoteLoader } from './NoteLoader'
|
||||
import { YoutubeLoader } from './YoutubeLoader'
|
||||
|
||||
const logger = loggerService.withContext('KnowledgeService File Loader')
|
||||
|
||||
type LoaderInstance =
|
||||
| TextLoader
|
||||
| PDFLoader
|
||||
| PPTXLoader
|
||||
| DocxLoader
|
||||
| JSONLoader
|
||||
| EPubLoader
|
||||
| CheerioWebBaseLoader
|
||||
| YoutubeLoader
|
||||
| SitemapLoader
|
||||
| NoteLoader
|
||||
| MarkdownLoader
|
||||
|
||||
/**
|
||||
* Adds type information to the metadata of every document in the array.
|
||||
*/
|
||||
function formatDocument(docs: Document[], type: string): Document[] {
|
||||
return docs.map((doc) => ({
|
||||
...doc,
|
||||
metadata: {
|
||||
...doc.metadata,
|
||||
type: type
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Shared document processing pipeline
|
||||
*/
|
||||
async function processDocuments(
|
||||
base: KnowledgeBaseParams,
|
||||
vectorStore: FaissStore,
|
||||
docs: Document[],
|
||||
loaderType: string,
|
||||
splitterType?: string
|
||||
): Promise<LoaderReturn> {
|
||||
const formattedDocs = formatDocument(docs, loaderType)
|
||||
const splitter = SplitterFactory.create({
|
||||
chunkSize: base.chunkSize,
|
||||
chunkOverlap: base.chunkOverlap,
|
||||
...(splitterType && { type: splitterType })
|
||||
})
|
||||
|
||||
const splitterResults = await splitter.splitDocuments(formattedDocs)
|
||||
const ids = splitterResults.map(() => randomUUID())
|
||||
|
||||
await vectorStore.addDocuments(splitterResults, { ids })
|
||||
|
||||
return {
|
||||
entriesAdded: splitterResults.length,
|
||||
uniqueId: ids[0] || '',
|
||||
uniqueIds: ids,
|
||||
loaderType
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Shared loader execution helper
|
||||
*/
|
||||
async function executeLoader(
|
||||
base: KnowledgeBaseParams,
|
||||
vectorStore: FaissStore,
|
||||
loaderInstance: LoaderInstance,
|
||||
loaderType: string,
|
||||
identifier: string,
|
||||
splitterType?: string
|
||||
): Promise<LoaderReturn> {
|
||||
const emptyResult: LoaderReturn = {
|
||||
entriesAdded: 0,
|
||||
uniqueId: '',
|
||||
uniqueIds: [],
|
||||
loaderType
|
||||
}
|
||||
|
||||
try {
|
||||
const docs = await loaderInstance.load()
|
||||
return await processDocuments(base, vectorStore, docs, loaderType, splitterType)
|
||||
} catch (error) {
|
||||
logger.error(`Error loading or processing ${identifier} with loader ${loaderType}: ${error}`)
|
||||
return emptyResult
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mapping from file extension to loader
|
||||
*/
|
||||
const FILE_LOADER_MAP: Record<string, { loader: new (path: string) => LoaderInstance; type: string }> = {
|
||||
'.pdf': { loader: PDFLoader, type: 'pdf' },
|
||||
'.txt': { loader: TextLoader, type: 'text' },
|
||||
'.pptx': { loader: PPTXLoader, type: 'pptx' },
|
||||
'.docx': { loader: DocxLoader, type: 'docx' },
|
||||
'.doc': { loader: DocxLoader, type: 'doc' },
|
||||
'.json': { loader: JSONLoader, type: 'json' },
|
||||
'.epub': { loader: EPubLoader, type: 'epub' },
|
||||
'.md': { loader: MarkdownLoader, type: 'markdown' }
|
||||
}
|
||||
|
||||
export async function addFileLoader(
|
||||
base: KnowledgeBaseParams,
|
||||
vectorStore: FaissStore,
|
||||
file: FileMetadata
|
||||
): Promise<LoaderReturn> {
|
||||
const fileExt = file.ext.toLowerCase()
|
||||
const loaderConfig = FILE_LOADER_MAP[fileExt]
|
||||
|
||||
if (!loaderConfig) {
|
||||
// Fall back to the plain text loader
|
||||
const loaderInstance = new TextLoader(file.path)
|
||||
const type = fileExt.replace('.', '') || 'unknown'
|
||||
return executeLoader(base, vectorStore, loaderInstance, type, file.path)
|
||||
}
|
||||
|
||||
const loaderInstance = new loaderConfig.loader(file.path)
|
||||
return executeLoader(base, vectorStore, loaderInstance, loaderConfig.type, file.path)
|
||||
}
|
||||
|
||||
export async function addWebLoader(
|
||||
base: KnowledgeBaseParams,
|
||||
vectorStore: FaissStore,
|
||||
url: string,
|
||||
source: UrlSource
|
||||
): Promise<LoaderReturn> {
|
||||
let loaderInstance: CheerioWebBaseLoader | YoutubeLoader | undefined
|
||||
let splitterType: string | undefined
|
||||
|
||||
switch (source) {
|
||||
case 'normal':
|
||||
loaderInstance = new CheerioWebBaseLoader(url)
|
||||
break
|
||||
case 'youtube':
|
||||
loaderInstance = YoutubeLoader.createFromUrl(url, {
|
||||
addVideoInfo: true,
|
||||
transcriptFormat: 'srt'
|
||||
})
|
||||
splitterType = 'srt'
|
||||
break
|
||||
}
|
||||
|
||||
if (!loaderInstance) {
|
||||
return {
|
||||
entriesAdded: 0,
|
||||
uniqueId: '',
|
||||
uniqueIds: [],
|
||||
loaderType: source
|
||||
}
|
||||
}
|
||||
|
||||
return executeLoader(base, vectorStore, loaderInstance, source, url, splitterType)
|
||||
}
|
||||
|
||||
export async function addSitemapLoader(
|
||||
base: KnowledgeBaseParams,
|
||||
vectorStore: FaissStore,
|
||||
url: string
|
||||
): Promise<LoaderReturn> {
|
||||
const loaderInstance = new SitemapLoader(url)
|
||||
return executeLoader(base, vectorStore, loaderInstance, 'sitemap', url)
|
||||
}
|
||||
|
||||
export async function addNoteLoader(
|
||||
base: KnowledgeBaseParams,
|
||||
vectorStore: FaissStore,
|
||||
content: string,
|
||||
sourceUrl: string
|
||||
): Promise<LoaderReturn> {
|
||||
const loaderInstance = new NoteLoader(content, sourceUrl)
|
||||
return executeLoader(base, vectorStore, loaderInstance, 'note', sourceUrl)
|
||||
}
|
||||
|
||||
export async function addVideoLoader(
|
||||
base: KnowledgeBaseParams,
|
||||
vectorStore: FaissStore,
|
||||
files: FileMetadata[]
|
||||
): Promise<LoaderReturn> {
|
||||
const srtFile = files.find((f) => f.type === FileTypes.TEXT)
|
||||
const videoFile = files.find((f) => f.type === FileTypes.VIDEO)
|
||||
|
||||
const emptyResult: LoaderReturn = {
|
||||
entriesAdded: 0,
|
||||
uniqueId: '',
|
||||
uniqueIds: [],
|
||||
loaderType: 'video'
|
||||
}
|
||||
|
||||
if (!srtFile || !videoFile) {
|
||||
return emptyResult
|
||||
}
|
||||
|
||||
try {
|
||||
const loaderInstance = new TextLoader(srtFile.path)
|
||||
const originalDocs = await loaderInstance.load()
|
||||
|
||||
const docsWithVideoMeta = originalDocs.map(
|
||||
(doc) =>
|
||||
new Document({
|
||||
...doc,
|
||||
metadata: {
|
||||
...doc.metadata,
|
||||
video: {
|
||||
path: videoFile.path,
|
||||
name: videoFile.origin_name
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return await processDocuments(base, vectorStore, docsWithVideoMeta, 'video', 'srt')
|
||||
} catch (error) {
|
||||
logger.error(`Error loading or processing file ${srtFile.path} with loader video: ${error}`)
|
||||
return emptyResult
|
||||
}
|
||||
}
|
||||
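For orientation, a sketch of how the file-loading entry point above might be called (not part of the diff; `base` and `vectorStore` are assumed to be prepared elsewhere, e.g. by KnowledgeService, and the `@main` alias path is assumed to match the file's location):

```typescript
// Hypothetical usage sketch for addFileLoader.
import { FaissStore } from '@langchain/community/vectorstores/faiss'
import type { FileMetadata, KnowledgeBaseParams } from '@types'

import { addFileLoader } from '@main/knowledge/langchain/loader'

export async function indexPdf(base: KnowledgeBaseParams, vectorStore: FaissStore, file: FileMetadata) {
  // For a ".pdf" extension this resolves to PDFLoader, splits the pages with the
  // default recursive splitter, and writes the chunks into the vector store.
  const result = await addFileLoader(base, vectorStore, file)
  console.log(`added ${result.entriesAdded} chunks (${result.uniqueIds.length} ids)`)
  return result
}
```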
55
src/main/knowledge/langchain/retriever/index.ts
Normal file
@ -0,0 +1,55 @@
|
||||
import { BM25Retriever } from '@langchain/community/retrievers/bm25'
|
||||
import { FaissStore } from '@langchain/community/vectorstores/faiss'
|
||||
import { BaseRetriever } from '@langchain/core/retrievers'
|
||||
import { loggerService } from '@main/services/LoggerService'
|
||||
import { type KnowledgeBaseParams } from '@types'
|
||||
import { type Document } from 'langchain/document'
|
||||
import { EnsembleRetriever } from 'langchain/retrievers/ensemble'
|
||||
|
||||
const logger = loggerService.withContext('RetrieverFactory')
|
||||
export class RetrieverFactory {
|
||||
/**
|
||||
* Creates a LangChain retriever based on the provided parameters.
* @param base Knowledge base configuration parameters.
* @param vectorStore An initialized vector store instance.
* @param documents Document list used to initialize the BM25Retriever.
* @returns A BaseRetriever instance.
|
||||
*/
|
||||
public createRetriever(base: KnowledgeBaseParams, vectorStore: FaissStore, documents: Document[]): BaseRetriever {
|
||||
const retrieverType = base.retriever?.mode ?? 'hybrid'
|
||||
const retrieverWeight = base.retriever?.weight ?? 0.5
|
||||
const searchK = base.documentCount ?? 5
|
||||
|
||||
logger.info(`Creating retriever of type: ${retrieverType} with k=${searchK}`)
|
||||
|
||||
switch (retrieverType) {
|
||||
case 'bm25':
|
||||
if (documents.length === 0) {
|
||||
throw new Error('BM25Retriever requires documents, but none were provided or found.')
|
||||
}
|
||||
logger.info('Create BM25 Retriever')
|
||||
return BM25Retriever.fromDocuments(documents, { k: searchK })
|
||||
|
||||
case 'hybrid': {
|
||||
if (documents.length === 0) {
|
||||
logger.warn('No documents provided for BM25 part of hybrid search. Falling back to vector search only.')
|
||||
return vectorStore.asRetriever(searchK)
|
||||
}
|
||||
|
||||
const vectorstoreRetriever = vectorStore.asRetriever(searchK)
|
||||
const bm25Retriever = BM25Retriever.fromDocuments(documents, { k: searchK })
|
||||
|
||||
logger.info('Create Hybrid Retriever')
|
||||
return new EnsembleRetriever({
|
||||
retrievers: [bm25Retriever, vectorstoreRetriever],
|
||||
weights: [retrieverWeight, 1 - retrieverWeight]
|
||||
})
|
||||
}
|
||||
|
||||
case 'vector':
|
||||
default:
|
||||
logger.info('Create Vector Retriever')
|
||||
return vectorStore.asRetriever(searchK)
|
||||
}
|
||||
}
|
||||
}
|
||||
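A sketch of the retriever factory above in use (not part of the diff; the caller is assumed to already hold the vector store and the raw documents for the BM25 side):

```typescript
// Hypothetical usage sketch for RetrieverFactory.
import { FaissStore } from '@langchain/community/vectorstores/faiss'
import type { KnowledgeBaseParams } from '@types'
import type { Document } from 'langchain/document'

import { RetrieverFactory } from '@main/knowledge/langchain/retriever'

export async function search(base: KnowledgeBaseParams, vectorStore: FaissStore, documents: Document[]) {
  // With base.retriever.mode unset this builds the hybrid (BM25 + vector) ensemble.
  const retriever = new RetrieverFactory().createRetriever(base, vectorStore, documents)
  return retriever.invoke('what is cherry studio?') // resolves to the top-k documents
}
```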
133
src/main/knowledge/langchain/splitter/SrtSplitter.ts
Normal file
@ -0,0 +1,133 @@
|
||||
import { Document } from '@langchain/core/documents'
|
||||
import { TextSplitter, TextSplitterParams } from 'langchain/text_splitter'
|
||||
|
||||
// Interface representing a single parsed subtitle segment
|
||||
interface SrtSegment {
|
||||
text: string
|
||||
startTime: number // in seconds
|
||||
endTime: number // in seconds
|
||||
}
|
||||
|
||||
// Helper: converts an SRT timestamp string (HH:MM:SS,ms) to seconds
|
||||
function srtTimeToSeconds(time: string): number {
|
||||
const parts = time.split(':')
|
||||
const secondsAndMs = parts[2].split(',')
|
||||
const hours = parseInt(parts[0], 10)
|
||||
const minutes = parseInt(parts[1], 10)
|
||||
const seconds = parseInt(secondsAndMs[0], 10)
|
||||
const milliseconds = parseInt(secondsAndMs[1], 10)
|
||||
|
||||
return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
|
||||
}
|
||||
|
||||
export class SrtSplitter extends TextSplitter {
|
||||
constructor(fields?: Partial<TextSplitterParams>) {
|
||||
// Pass chunkSize and chunkOverlap through to the base class
|
||||
super(fields)
|
||||
}
|
||||
splitText(): Promise<string[]> {
|
||||
throw new Error('Method not implemented.')
|
||||
}
|
||||
|
||||
// Core method: override splitDocuments to implement the custom logic
|
||||
async splitDocuments(documents: Document[]): Promise<Document[]> {
|
||||
const allChunks: Document[] = []
|
||||
|
||||
for (const doc of documents) {
|
||||
// 1. Parse the SRT content
|
||||
const segments = this.parseSrt(doc.pageContent)
|
||||
if (segments.length === 0) continue
|
||||
|
||||
// 2. Merge the subtitle segments into chunks
|
||||
const chunks = this.mergeSegmentsIntoChunks(segments, doc.metadata)
|
||||
allChunks.push(...chunks)
|
||||
}
|
||||
|
||||
return allChunks
|
||||
}
|
||||
|
||||
// Helper: parse the full SRT string
|
||||
private parseSrt(srt: string): SrtSegment[] {
|
||||
const segments: SrtSegment[] = []
|
||||
const blocks = srt.trim().split(/\n\n/)
|
||||
|
||||
for (const block of blocks) {
|
||||
const lines = block.split('\n')
|
||||
if (lines.length < 3) continue
|
||||
|
||||
const timeMatch = lines[1].match(/(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})/)
|
||||
if (!timeMatch) continue
|
||||
|
||||
const startTime = srtTimeToSeconds(timeMatch[1])
|
||||
const endTime = srtTimeToSeconds(timeMatch[2])
|
||||
const text = lines.slice(2).join(' ').trim()
|
||||
|
||||
segments.push({ text, startTime, endTime })
|
||||
}
|
||||
|
||||
return segments
|
||||
}
|
||||
|
||||
// Helper: merge parsed segments into chunks of 5 segments each
|
||||
private mergeSegmentsIntoChunks(segments: SrtSegment[], baseMetadata: Record<string, any>): Document[] {
|
||||
const chunks: Document[] = []
|
||||
let currentChunkText = ''
|
||||
let currentChunkStartTime = 0
|
||||
let currentChunkEndTime = 0
|
||||
let segmentCount = 0
|
||||
|
||||
for (const segment of segments) {
|
||||
if (segmentCount === 0) {
|
||||
currentChunkStartTime = segment.startTime
|
||||
}
|
||||
|
||||
currentChunkText += (currentChunkText ? ' ' : '') + segment.text
|
||||
currentChunkEndTime = segment.endTime
|
||||
segmentCount++
|
||||
|
||||
// Once 5 segments have accumulated, create a new Document
|
||||
if (segmentCount === 5) {
|
||||
const metadata: Record<string, any> = {
|
||||
...baseMetadata,
|
||||
startTime: currentChunkStartTime,
|
||||
endTime: currentChunkEndTime
|
||||
}
|
||||
if (baseMetadata.source_url) {
|
||||
metadata.source_url_with_timestamp = `${baseMetadata.source_url}?t=${Math.floor(currentChunkStartTime)}s`
|
||||
}
|
||||
chunks.push(
|
||||
new Document({
|
||||
pageContent: currentChunkText,
|
||||
metadata
|
||||
})
|
||||
)
|
||||
|
||||
// Reset the counter and temporary variables
|
||||
currentChunkText = ''
|
||||
currentChunkStartTime = 0
|
||||
currentChunkEndTime = 0
|
||||
segmentCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
// If segments remain, create the final Document
|
||||
if (segmentCount > 0) {
|
||||
const metadata: Record<string, any> = {
|
||||
...baseMetadata,
|
||||
startTime: currentChunkStartTime,
|
||||
endTime: currentChunkEndTime
|
||||
}
|
||||
if (baseMetadata.source_url) {
|
||||
metadata.source_url_with_timestamp = `${baseMetadata.source_url}?t=${Math.floor(currentChunkStartTime)}s`
|
||||
}
|
||||
chunks.push(
|
||||
new Document({
|
||||
pageContent: currentChunkText,
|
||||
metadata
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
}
|
||||
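A small sketch of the splitter above on an inline SRT string (not part of the diff; the subtitle content and source_url are placeholders, and with only two segments everything lands in the trailing chunk):

```typescript
// Hypothetical usage sketch for SrtSplitter.
import { Document } from '@langchain/core/documents'

import { SrtSplitter } from '@main/knowledge/langchain/splitter/SrtSplitter'

const srt = [
  '1',
  '00:00:00,000 --> 00:00:02,000',
  'Hello and welcome.',
  '',
  '2',
  '00:00:02,000 --> 00:00:04,500',
  'Today we look at text splitters.'
].join('\n')

async function demo() {
  const splitter = new SrtSplitter({ chunkSize: 500, chunkOverlap: 0 })
  const chunks = await splitter.splitDocuments([
    new Document({ pageContent: srt, metadata: { source_url: 'https://example.com/video' } })
  ])
  // Each chunk keeps start/end times and, when source_url is present, a timestamped link.
  console.log(chunks[0].metadata.startTime, chunks[0].metadata.endTime, chunks[0].metadata.source_url_with_timestamp)
}

demo().catch(console.error)
```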
31
src/main/knowledge/langchain/splitter/index.ts
Normal file
@ -0,0 +1,31 @@
|
||||
import { RecursiveCharacterTextSplitter, TextSplitter } from '@langchain/textsplitters'
|
||||
|
||||
import { SrtSplitter } from './SrtSplitter'
|
||||
|
||||
export type SplitterConfig = {
|
||||
chunkSize?: number
|
||||
chunkOverlap?: number
|
||||
type?: 'recursive' | 'srt' | string
|
||||
}
|
||||
export class SplitterFactory {
|
||||
/**
|
||||
* Creates a TextSplitter instance based on the provided configuration.
|
||||
* @param config - The configuration object specifying the splitter type and its parameters.
|
||||
* @returns An instance of a TextSplitter.
|
||||
*/
|
||||
public static create(config: SplitterConfig): TextSplitter {
|
||||
switch (config.type) {
|
||||
case 'srt':
|
||||
return new SrtSplitter({
|
||||
chunkSize: config.chunkSize,
|
||||
chunkOverlap: config.chunkOverlap
|
||||
})
|
||||
case 'recursive':
|
||||
default:
|
||||
return new RecursiveCharacterTextSplitter({
|
||||
chunkSize: config.chunkSize,
|
||||
chunkOverlap: config.chunkOverlap
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
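A sketch of picking a splitter via the factory above (not part of the diff). Note that SrtSplitter only implements splitDocuments (its splitText throws), so plain-text splitting is shown with the default recursive splitter:

```typescript
// Hypothetical usage sketch for SplitterFactory.
import { SplitterFactory } from '@main/knowledge/langchain/splitter'

const srtSplitter = SplitterFactory.create({ type: 'srt', chunkSize: 1000, chunkOverlap: 100 })
const defaultSplitter = SplitterFactory.create({ chunkSize: 1000, chunkOverlap: 100 }) // recursive

defaultSplitter.splitText('some long passage of plain text...').then((parts) => console.log(parts.length))
```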
63
src/main/knowledge/preprocess/PreprocessingService.ts
Normal file
@ -0,0 +1,63 @@
|
||||
import PreprocessProvider from '@main/knowledge/preprocess/PreprocessProvider'
|
||||
import { loggerService } from '@main/services/LoggerService'
|
||||
import { windowService } from '@main/services/WindowService'
|
||||
import type { FileMetadata, KnowledgeBaseParams, KnowledgeItem } from '@types'
|
||||
|
||||
const logger = loggerService.withContext('PreprocessingService')
|
||||
|
||||
class PreprocessingService {
|
||||
public async preprocessFile(
|
||||
file: FileMetadata,
|
||||
base: KnowledgeBaseParams,
|
||||
item: KnowledgeItem,
|
||||
userId: string
|
||||
): Promise<FileMetadata> {
|
||||
let fileToProcess: FileMetadata = file
|
||||
// Check if preprocessing is configured and applicable (e.g., for PDFs)
|
||||
if (base.preprocessProvider && file.ext.toLowerCase() === '.pdf') {
|
||||
try {
|
||||
const provider = new PreprocessProvider(base.preprocessProvider.provider, userId)
|
||||
|
||||
// Check if file has already been preprocessed
|
||||
const alreadyProcessed = await provider.checkIfAlreadyProcessed(file)
|
||||
if (alreadyProcessed) {
|
||||
logger.debug(`File already preprocessed, using cached result: ${file.path}`)
|
||||
return alreadyProcessed
|
||||
}
|
||||
|
||||
// Execute preprocessing
|
||||
logger.debug(`Starting preprocess for scanned PDF: ${file.path}`)
|
||||
const { processedFile, quota } = await provider.parseFile(item.id, file)
|
||||
fileToProcess = processedFile
|
||||
|
||||
// Notify the UI
|
||||
const mainWindow = windowService.getMainWindow()
|
||||
mainWindow?.webContents.send('file-preprocess-finished', {
|
||||
itemId: item.id,
|
||||
quota: quota
|
||||
})
|
||||
} catch (err) {
|
||||
logger.error(`Preprocessing failed: ${err}`)
|
||||
// If preprocessing fails, re-throw the error to be handled by the caller
|
||||
throw new Error(`Preprocessing failed: ${err}`)
|
||||
}
|
||||
}
|
||||
|
||||
return fileToProcess
|
||||
}
|
||||
|
||||
public async checkQuota(base: KnowledgeBaseParams, userId: string): Promise<number> {
|
||||
try {
|
||||
if (base.preprocessProvider && base.preprocessProvider.type === 'preprocess') {
|
||||
const provider = new PreprocessProvider(base.preprocessProvider.provider, userId)
|
||||
return await provider.checkQuota()
|
||||
}
|
||||
throw new Error('No preprocess provider configured')
|
||||
} catch (err) {
|
||||
logger.error(`Failed to check quota: ${err}`)
|
||||
throw new Error(`Failed to check quota: ${err}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const preprocessingService = new PreprocessingService()
|
||||
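A minimal sketch of calling the service above (not part of the diff; the caller is assumed to already hold the file metadata, knowledge base params, item and user id):

```typescript
// Hypothetical usage sketch for preprocessingService.
import type { FileMetadata, KnowledgeBaseParams, KnowledgeItem } from '@types'

import { preprocessingService } from '@main/knowledge/preprocess/PreprocessingService'

export async function prepare(file: FileMetadata, base: KnowledgeBaseParams, item: KnowledgeItem, userId: string) {
  // For PDFs with a configured preprocess provider this returns the processed file
  // (reusing a cached result when available); otherwise the original file is returned unchanged.
  return preprocessingService.preprocessFile(file, base, item, userId)
}
```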
@ -1,101 +1,46 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
import { DEFAULT_DOCUMENT_COUNT, DEFAULT_RELEVANT_SCORE } from '@main/utils/knowledge'
|
||||
import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types'
|
||||
|
||||
import { MultiModalDocument, RerankStrategy } from './strategies/RerankStrategy'
|
||||
import { StrategyFactory } from './strategies/StrategyFactory'
|
||||
|
||||
export default abstract class BaseReranker {
|
||||
protected base: KnowledgeBaseParams
|
||||
protected strategy: RerankStrategy
|
||||
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
if (!base.rerankApiClient) {
|
||||
throw new Error('Rerank model is required')
|
||||
}
|
||||
this.base = base
|
||||
this.strategy = StrategyFactory.createStrategy(base.rerankApiClient.provider)
|
||||
}
|
||||
|
||||
abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>
|
||||
|
||||
/**
|
||||
* Get Rerank Request Url
|
||||
*/
|
||||
protected getRerankUrl() {
|
||||
if (this.base.rerankApiClient?.provider === 'bailian') {
|
||||
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
|
||||
}
|
||||
|
||||
let baseURL = this.base.rerankApiClient?.baseURL
|
||||
|
||||
if (baseURL && baseURL.endsWith('/')) {
|
||||
// A trailing `/` forces the rerank baseURL to be used as-is
|
||||
return `${baseURL}rerank`
|
||||
}
|
||||
|
||||
if (baseURL && !baseURL.endsWith('/v1')) {
|
||||
baseURL = `${baseURL}/v1`
|
||||
}
|
||||
|
||||
return `${baseURL}/rerank`
|
||||
abstract rerank(query: string, searchResults: KnowledgeSearchResult[]): Promise<KnowledgeSearchResult[]>
|
||||
protected getRerankUrl(): string {
|
||||
return this.strategy.buildUrl(this.base.rerankApiClient?.baseURL)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Rerank Request Body
|
||||
*/
|
||||
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
|
||||
const provider = this.base.rerankApiClient?.provider
|
||||
const documents = searchResults.map((doc) => doc.pageContent)
|
||||
const topN = this.base.documentCount
|
||||
|
||||
if (provider === 'voyageai') {
|
||||
return {
|
||||
model: this.base.rerankApiClient?.model,
|
||||
query,
|
||||
documents,
|
||||
top_k: topN
|
||||
}
|
||||
} else if (provider === 'bailian') {
|
||||
return {
|
||||
model: this.base.rerankApiClient?.model,
|
||||
input: {
|
||||
query,
|
||||
documents
|
||||
},
|
||||
parameters: {
|
||||
top_n: topN
|
||||
}
|
||||
}
|
||||
} else if (provider?.includes('tei')) {
|
||||
return {
|
||||
query,
|
||||
texts: documents,
|
||||
return_text: true
|
||||
}
|
||||
} else {
|
||||
return {
|
||||
model: this.base.rerankApiClient?.model,
|
||||
query,
|
||||
documents,
|
||||
top_n: topN
|
||||
}
|
||||
}
|
||||
protected getRerankRequestBody(query: string, searchResults: KnowledgeSearchResult[]) {
|
||||
const documents = this.buildDocuments(searchResults)
|
||||
const topN = this.base.documentCount ?? DEFAULT_DOCUMENT_COUNT
|
||||
const model = this.base.rerankApiClient?.model
|
||||
return this.strategy.buildRequestBody(query, documents, topN, model)
|
||||
}
|
||||
private buildDocuments(searchResults: KnowledgeSearchResult[]): MultiModalDocument[] {
|
||||
return searchResults.map((doc) => {
|
||||
const document: MultiModalDocument = {}
|
||||
|
||||
/**
|
||||
* Extract Rerank Result
|
||||
*/
|
||||
// If this is an image entry, attach the image content; otherwise use text
|
||||
if (doc.metadata?.type === 'image') {
|
||||
document.image = doc.pageContent
|
||||
} else {
|
||||
document.text = doc.pageContent
|
||||
}
|
||||
|
||||
return document
|
||||
})
|
||||
}
|
||||
protected extractRerankResult(data: any) {
|
||||
const provider = this.base.rerankApiClient?.provider
|
||||
if (provider === 'bailian') {
|
||||
return data.output.results
|
||||
} else if (provider === 'voyageai') {
|
||||
return data.data
|
||||
} else if (provider?.includes('tei')) {
|
||||
return data.map((item: any) => {
|
||||
return {
|
||||
index: item.index,
|
||||
relevance_score: item.score
|
||||
}
|
||||
})
|
||||
} else {
|
||||
return data.results
|
||||
}
|
||||
return this.strategy.extractResults(data)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -105,35 +50,30 @@ export default abstract class BaseReranker {
|
||||
* @protected
|
||||
*/
|
||||
protected getRerankResult(
|
||||
searchResults: ExtractChunkData[],
|
||||
rerankResults: Array<{
|
||||
index: number
|
||||
relevance_score: number
|
||||
}>
|
||||
searchResults: KnowledgeSearchResult[],
|
||||
rerankResults: Array<{ index: number; relevance_score: number }>
|
||||
) {
|
||||
const resultMap = new Map(rerankResults.map((result) => [result.index, result.relevance_score || 0]))
|
||||
const resultMap = new Map(
|
||||
rerankResults.map((result) => [result.index, result.relevance_score || DEFAULT_RELEVANT_SCORE])
|
||||
)
|
||||
|
||||
return searchResults
|
||||
.map((doc: ExtractChunkData, index: number) => {
|
||||
const rerankedResults = searchResults
|
||||
.map((doc: KnowledgeSearchResult, index: number) => {
|
||||
const score = resultMap.get(index)
|
||||
if (score === undefined) return undefined
|
||||
|
||||
return {
|
||||
...doc,
|
||||
score
|
||||
}
|
||||
return { ...doc, score }
|
||||
})
|
||||
.filter((doc): doc is ExtractChunkData => doc !== undefined)
|
||||
.filter((doc): doc is KnowledgeSearchResult => doc !== undefined)
|
||||
.sort((a, b) => b.score - a.score)
|
||||
}
|
||||
|
||||
return rerankedResults
|
||||
}
|
||||
public defaultHeaders() {
|
||||
return {
|
||||
Authorization: `Bearer ${this.base.rerankApiClient?.apiKey}`,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
}
|
||||
|
||||
protected formatErrorMessage(url: string, error: any, requestBody: any) {
|
||||
const errorDetails = {
|
||||
url: url,
|
||||
|
||||
@ -1,19 +1,14 @@
|
||||
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types'
|
||||
import { net } from 'electron'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class GeneralReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
public rerank = async (query: string, searchResults: KnowledgeSearchResult[]): Promise<KnowledgeSearchResult[]> => {
|
||||
const url = this.getRerankUrl()
|
||||
|
||||
const requestBody = this.getRerankRequestBody(query, searchResults)
|
||||
|
||||
try {
|
||||
const response = await net.fetch(url, {
|
||||
method: 'POST',
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types'
|
||||
|
||||
import GeneralReranker from './GeneralReranker'
|
||||
|
||||
@ -8,7 +7,7 @@ export default class Reranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
this.sdk = new GeneralReranker(base)
|
||||
}
|
||||
public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> {
|
||||
public async rerank(query: string, searchResults: KnowledgeSearchResult[]): Promise<KnowledgeSearchResult[]> {
|
||||
return this.sdk.rerank(query, searchResults)
|
||||
}
|
||||
}
|
||||
|
||||
18
src/main/knowledge/reranker/strategies/BailianStrategy.ts
Normal file
@ -0,0 +1,18 @@
|
||||
import { MultiModalDocument, RerankStrategy } from './RerankStrategy'
|
||||
export class BailianStrategy implements RerankStrategy {
|
||||
buildUrl(): string {
|
||||
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
|
||||
}
|
||||
buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) {
|
||||
const textDocuments = documents.filter((d) => d.text).map((d) => d.text!)
|
||||
|
||||
return {
|
||||
model,
|
||||
input: { query, documents: textDocuments },
|
||||
parameters: { top_n: topN }
|
||||
}
|
||||
}
|
||||
extractResults(data: any) {
|
||||
return data.output.results
|
||||
}
|
||||
}
|
||||
25
src/main/knowledge/reranker/strategies/DefaultStrategy.ts
Normal file
@ -0,0 +1,25 @@
|
||||
import { MultiModalDocument, RerankStrategy } from './RerankStrategy'
|
||||
export class DefaultStrategy implements RerankStrategy {
|
||||
buildUrl(baseURL?: string): string {
|
||||
if (baseURL && baseURL.endsWith('/')) {
|
||||
return `${baseURL}rerank`
|
||||
}
|
||||
if (baseURL && !baseURL.endsWith('/v1')) {
|
||||
baseURL = `${baseURL}/v1`
|
||||
}
|
||||
return `${baseURL}/rerank`
|
||||
}
|
||||
buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) {
|
||||
const textDocuments = documents.filter((d) => d.text).map((d) => d.text!)
|
||||
|
||||
return {
|
||||
model,
|
||||
query,
|
||||
documents: textDocuments,
|
||||
top_n: topN
|
||||
}
|
||||
}
|
||||
extractResults(data: any) {
|
||||
return data.results
|
||||
}
|
||||
}
|
||||
33
src/main/knowledge/reranker/strategies/JinaStrategy.ts
Normal file
@ -0,0 +1,33 @@
|
||||
import { MultiModalDocument, RerankStrategy } from './RerankStrategy'
|
||||
export class JinaStrategy implements RerankStrategy {
|
||||
buildUrl(baseURL?: string): string {
|
||||
if (baseURL && baseURL.endsWith('/')) {
|
||||
return `${baseURL}rerank`
|
||||
}
|
||||
if (baseURL && !baseURL.endsWith('/v1')) {
|
||||
baseURL = `${baseURL}/v1`
|
||||
}
|
||||
return `${baseURL}/rerank`
|
||||
}
|
||||
buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) {
|
||||
if (model === 'jina-reranker-m0') {
|
||||
return {
|
||||
model,
|
||||
query,
|
||||
documents,
|
||||
top_n: topN
|
||||
}
|
||||
}
|
||||
const textDocuments = documents.filter((d) => d.text).map((d) => d.text!)
|
||||
|
||||
return {
|
||||
model,
|
||||
query,
|
||||
documents: textDocuments,
|
||||
top_n: topN
|
||||
}
|
||||
}
|
||||
extractResults(data: any) {
|
||||
return data.results
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff.