diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 4596fc41d6..14eca3205f 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -3,3 +3,4 @@
/src/main/services/ConfigManager.ts @0xfullex
/packages/shared/IpcChannel.ts @0xfullex
/src/main/ipc.ts @0xfullex
+/app-upgrade-config.json @kangfenmao
diff --git a/.github/workflows/auto-i18n.yml b/.github/workflows/auto-i18n.yml
index 1584ab48db..ea9f05ae03 100644
--- a/.github/workflows/auto-i18n.yml
+++ b/.github/workflows/auto-i18n.yml
@@ -77,7 +77,7 @@ jobs:
with:
token: ${{ secrets.GITHUB_TOKEN }} # Use the built-in GITHUB_TOKEN for bot actions
commit-message: "feat(bot): Weekly automated script run"
- title: "🤖 Weekly Automated Update: ${{ env.CURRENT_DATE }}"
+ title: "🤖 Weekly Auto I18N Sync: ${{ env.CURRENT_DATE }}"
body: |
This PR includes changes generated by the weekly auto i18n.
Review the changes before merging.
diff --git a/.github/workflows/dispatch-docs-update.yml b/.github/workflows/dispatch-docs-update.yml
index b9457faec6..bb33c60b33 100644
--- a/.github/workflows/dispatch-docs-update.yml
+++ b/.github/workflows/dispatch-docs-update.yml
@@ -19,7 +19,7 @@ jobs:
echo "tag=${{ github.event.release.tag_name }}" >> $GITHUB_OUTPUT
- name: Dispatch update-download-version workflow to cherry-studio-docs
- uses: peter-evans/repository-dispatch@v3
+ uses: peter-evans/repository-dispatch@v4
with:
token: ${{ secrets.REPO_DISPATCH_TOKEN }}
repository: CherryHQ/cherry-studio-docs
diff --git a/.github/workflows/update-app-upgrade-config.yml b/.github/workflows/update-app-upgrade-config.yml
new file mode 100644
index 0000000000..7470bb0b6c
--- /dev/null
+++ b/.github/workflows/update-app-upgrade-config.yml
@@ -0,0 +1,212 @@
+name: Update App Upgrade Config
+
+on:
+ release:
+ types:
+ - released
+ - prereleased
+ workflow_dispatch:
+ inputs:
+ tag:
+ description: "Release tag (e.g., v1.2.3)"
+ required: true
+ type: string
+ is_prerelease:
+ description: "Mark the tag as a prerelease when running manually"
+ required: false
+ default: false
+ type: boolean
+
+permissions:
+ contents: write
+ pull-requests: write
+
+jobs:
+ propose-update:
+ runs-on: ubuntu-latest
+ if: github.event_name == 'workflow_dispatch' || (github.event_name == 'release' && github.event.release.draft == false)
+
+ steps:
+ - name: Check if should proceed
+ id: check
+ run: |
+ EVENT="${{ github.event_name }}"
+
+ if [ "$EVENT" = "workflow_dispatch" ]; then
+ TAG="${{ github.event.inputs.tag }}"
+ else
+ TAG="${{ github.event.release.tag_name }}"
+ fi
+
+ latest_tag=$(
+ curl -L \
+ -H "Accept: application/vnd.github+json" \
+ -H "Authorization: Bearer ${{ github.token }}" \
+ -H "X-GitHub-Api-Version: 2022-11-28" \
+ https://api.github.com/repos/${{ github.repository }}/releases/latest \
+ | jq -r '.tag_name'
+ )
+
+ if [ "$EVENT" = "workflow_dispatch" ]; then
+ MANUAL_IS_PRERELEASE="${{ github.event.inputs.is_prerelease }}"
+ if [ -z "$MANUAL_IS_PRERELEASE" ]; then
+ MANUAL_IS_PRERELEASE="false"
+ fi
+ if [ "$MANUAL_IS_PRERELEASE" = "true" ]; then
+ if ! echo "$TAG" | grep -E '(-beta([.-][0-9]+)?|-rc([.-][0-9]+)?)' >/dev/null; then
+ echo "Manual prerelease flag set but tag $TAG lacks beta/rc suffix. Skipping." >&2
+ echo "should_run=false" >> "$GITHUB_OUTPUT"
+ echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
+ echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
+ exit 0
+ fi
+ fi
+ echo "should_run=true" >> "$GITHUB_OUTPUT"
+ echo "is_prerelease=$MANUAL_IS_PRERELEASE" >> "$GITHUB_OUTPUT"
+ echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
+ exit 0
+ fi
+
+ IS_PRERELEASE="${{ github.event.release.prerelease }}"
+
+ if [ "$IS_PRERELEASE" = "true" ]; then
+ if ! echo "$TAG" | grep -E '(-beta([.-][0-9]+)?|-rc([.-][0-9]+)?)' >/dev/null; then
+ echo "Release marked as prerelease but tag $TAG lacks beta/rc suffix. Skipping." >&2
+ echo "should_run=false" >> "$GITHUB_OUTPUT"
+ echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
+ echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
+ exit 0
+ fi
+ echo "should_run=true" >> "$GITHUB_OUTPUT"
+ echo "is_prerelease=true" >> "$GITHUB_OUTPUT"
+ echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
+ echo "Release is prerelease, proceeding"
+ exit 0
+ fi
+
+ if [[ "${latest_tag}" == "$TAG" ]]; then
+ echo "should_run=true" >> "$GITHUB_OUTPUT"
+ echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
+ echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
+ echo "Release is latest, proceeding"
+ else
+ echo "should_run=false" >> "$GITHUB_OUTPUT"
+ echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
+ echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
+ echo "Release is neither prerelease nor latest, skipping"
+ fi
+
+ - name: Prepare metadata
+ id: meta
+ if: steps.check.outputs.should_run == 'true'
+ run: |
+ EVENT="${{ github.event_name }}"
+ LATEST_TAG="${{ steps.check.outputs.latest_tag }}"
+ if [ "$EVENT" = "release" ]; then
+ TAG="${{ github.event.release.tag_name }}"
+ PRE="${{ github.event.release.prerelease }}"
+
+ if [ -n "$LATEST_TAG" ] && [ "$LATEST_TAG" = "$TAG" ]; then
+ LATEST="true"
+ else
+ LATEST="false"
+ fi
+ TRIGGER="release"
+ else
+ TAG="${{ github.event.inputs.tag }}"
+ PRE="${{ github.event.inputs.is_prerelease }}"
+ if [ -z "$PRE" ]; then
+ PRE="false"
+ fi
+ if [ -n "$LATEST_TAG" ] && [ "$LATEST_TAG" = "$TAG" ] && [ "$PRE" != "true" ]; then
+ LATEST="true"
+ else
+ LATEST="false"
+ fi
+ TRIGGER="manual"
+ fi
+
+ SAFE_TAG=$(echo "$TAG" | sed 's/[^A-Za-z0-9._-]/-/g')
+ echo "tag=$TAG" >> "$GITHUB_OUTPUT"
+ echo "safe_tag=$SAFE_TAG" >> "$GITHUB_OUTPUT"
+ echo "prerelease=$PRE" >> "$GITHUB_OUTPUT"
+ echo "latest=$LATEST" >> "$GITHUB_OUTPUT"
+ echo "trigger=$TRIGGER" >> "$GITHUB_OUTPUT"
+
+ - name: Checkout default branch
+ if: steps.check.outputs.should_run == 'true'
+ uses: actions/checkout@v5
+ with:
+ ref: ${{ github.event.repository.default_branch }}
+ path: main
+ fetch-depth: 0
+
+ - name: Checkout x-files/app-upgrade-config branch
+ if: steps.check.outputs.should_run == 'true'
+ uses: actions/checkout@v5
+ with:
+ ref: x-files/app-upgrade-config
+ path: cs
+ fetch-depth: 0
+
+ - name: Setup Node.js
+ if: steps.check.outputs.should_run == 'true'
+ uses: actions/setup-node@v4
+ with:
+ node-version: 22
+
+ - name: Enable Corepack
+ if: steps.check.outputs.should_run == 'true'
+ run: corepack enable && corepack prepare yarn@4.9.1 --activate
+
+ - name: Install dependencies
+ if: steps.check.outputs.should_run == 'true'
+ working-directory: main
+ run: yarn install --immutable
+
+ - name: Update upgrade config
+ if: steps.check.outputs.should_run == 'true'
+ working-directory: main
+ env:
+ RELEASE_TAG: ${{ steps.meta.outputs.tag }}
+ IS_PRERELEASE: ${{ steps.check.outputs.is_prerelease }}
+ run: |
+ yarn tsx scripts/update-app-upgrade-config.ts \
+ --tag "$RELEASE_TAG" \
+ --config ../cs/app-upgrade-config.json \
+ --is-prerelease "$IS_PRERELEASE"
+
+ - name: Detect changes
+ if: steps.check.outputs.should_run == 'true'
+ id: diff
+ working-directory: cs
+ run: |
+ if git diff --quiet -- app-upgrade-config.json; then
+ echo "changed=false" >> "$GITHUB_OUTPUT"
+ else
+ echo "changed=true" >> "$GITHUB_OUTPUT"
+ fi
+
+ - name: Create pull request
+ if: steps.check.outputs.should_run == 'true' && steps.diff.outputs.changed == 'true'
+ uses: peter-evans/create-pull-request@v7
+ with:
+ path: cs
+ base: x-files/app-upgrade-config
+ branch: chore/update-app-upgrade-config/${{ steps.meta.outputs.safe_tag }}
+ commit-message: "🤖 chore: sync app-upgrade-config for ${{ steps.meta.outputs.tag }}"
+ title: "chore: update app-upgrade-config for ${{ steps.meta.outputs.tag }}"
+ body: |
+ Automated update triggered by `${{ steps.meta.outputs.trigger }}`.
+
+ - Source tag: `${{ steps.meta.outputs.tag }}`
+ - Pre-release: `${{ steps.meta.outputs.prerelease }}`
+ - Latest: `${{ steps.meta.outputs.latest }}`
+ - Workflow run: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
+ labels: |
+ automation
+ app-upgrade
+
+ - name: No changes detected
+ if: steps.check.outputs.should_run == 'true' && steps.diff.outputs.changed != 'true'
+ run: echo "No updates required for x-files/app-upgrade-config/app-upgrade-config.json"
diff --git a/.oxlintrc.json b/.oxlintrc.json
index 5d63538e2c..093ae25f18 100644
--- a/.oxlintrc.json
+++ b/.oxlintrc.json
@@ -11,6 +11,7 @@
"dist/**",
"out/**",
"local/**",
+ "tests/**",
".yarn/**",
".gitignore",
"scripts/cloudflare-worker.js",
@@ -22,7 +23,6 @@
"eslint.config.mjs"
],
"overrides": [
- // set different env
{
"env": {
"node": true
@@ -36,8 +36,7 @@
"files": [
"src/renderer/**/*.{ts,tsx}",
"packages/aiCore/**",
- "packages/extension-table-plus/**",
- "resources/js/**"
+ "packages/extension-table-plus/**"
]
},
{
@@ -53,76 +52,24 @@
"node": true
},
"files": ["src/preload/**"]
+ },
+ {
+ "files": ["packages/ai-sdk-provider/**"],
+ "globals": {
+ "fetch": "readonly"
+ }
}
],
- // We don't use the React plugin here because its behavior differs slightly from that of ESLint's React plugin.
"plugins": ["unicorn", "typescript", "oxc", "import"],
"rules": {
- "constructor-super": "error",
- "for-direction": "error",
- "getter-return": "error",
"no-array-constructor": "off",
- // "import/no-cycle": "error", // tons of error, bro
- "no-async-promise-executor": "error",
"no-caller": "warn",
- "no-case-declarations": "error",
- "no-class-assign": "error",
- "no-compare-neg-zero": "error",
- "no-cond-assign": "error",
- "no-const-assign": "error",
- "no-constant-binary-expression": "error",
- "no-constant-condition": "error",
- "no-control-regex": "error",
- "no-debugger": "error",
- "no-delete-var": "error",
- "no-dupe-args": "error",
- "no-dupe-class-members": "error",
- "no-dupe-else-if": "error",
- "no-dupe-keys": "error",
- "no-duplicate-case": "error",
- "no-empty": "error",
- "no-empty-character-class": "error",
- "no-empty-pattern": "error",
- "no-empty-static-block": "error",
"no-eval": "warn",
- "no-ex-assign": "error",
- "no-extra-boolean-cast": "error",
"no-fallthrough": "warn",
- "no-func-assign": "error",
- "no-global-assign": "error",
- "no-import-assign": "error",
- "no-invalid-regexp": "error",
- "no-irregular-whitespace": "error",
- "no-loss-of-precision": "error",
- "no-misleading-character-class": "error",
- "no-new-native-nonconstructor": "error",
- "no-nonoctal-decimal-escape": "error",
- "no-obj-calls": "error",
- "no-octal": "error",
- "no-prototype-builtins": "error",
- "no-redeclare": "error",
- "no-regex-spaces": "error",
- "no-self-assign": "error",
- "no-setter-return": "error",
- "no-shadow-restricted-names": "error",
- "no-sparse-arrays": "error",
- "no-this-before-super": "error",
"no-unassigned-vars": "warn",
- "no-undef": "error",
- "no-unexpected-multiline": "error",
- "no-unreachable": "error",
- "no-unsafe-finally": "error",
- "no-unsafe-negation": "error",
- "no-unsafe-optional-chaining": "error",
- "no-unused-expressions": "off", // this rule disallow us to use expression to call function, like `condition && fn()`
- "no-unused-labels": "error",
- "no-unused-private-class-members": "error",
+ "no-unused-expressions": "off",
"no-unused-vars": ["warn", { "caughtErrors": "none" }],
- "no-useless-backreference": "error",
- "no-useless-catch": "error",
- "no-useless-escape": "error",
"no-useless-rename": "warn",
- "no-with": "error",
"oxc/bad-array-method-on-arguments": "warn",
"oxc/bad-char-at-comparison": "warn",
"oxc/bad-comparison-sequence": "warn",
@@ -134,19 +81,17 @@
"oxc/erasing-op": "warn",
"oxc/missing-throw": "warn",
"oxc/number-arg-out-of-range": "warn",
- "oxc/only-used-in-recursion": "off", // manually off bacause of existing warning. may turn it on in the future
+ "oxc/only-used-in-recursion": "off",
"oxc/uninvoked-array-callback": "warn",
- "require-yield": "error",
"typescript/await-thenable": "warn",
- // "typescript/ban-ts-comment": "error",
- "typescript/no-array-constructor": "error",
"typescript/consistent-type-imports": "error",
+ "typescript/no-array-constructor": "error",
"typescript/no-array-delete": "warn",
"typescript/no-base-to-string": "warn",
"typescript/no-duplicate-enum-values": "error",
"typescript/no-duplicate-type-constituents": "warn",
"typescript/no-empty-object-type": "off",
- "typescript/no-explicit-any": "off", // not safe but too many errors
+ "typescript/no-explicit-any": "off",
"typescript/no-extra-non-null-assertion": "error",
"typescript/no-floating-promises": "warn",
"typescript/no-for-in-array": "warn",
@@ -155,7 +100,7 @@
"typescript/no-misused-new": "error",
"typescript/no-misused-spread": "warn",
"typescript/no-namespace": "error",
- "typescript/no-non-null-asserted-optional-chain": "off", // it's off now. but may turn it on.
+ "typescript/no-non-null-asserted-optional-chain": "off",
"typescript/no-redundant-type-constituents": "warn",
"typescript/no-require-imports": "off",
"typescript/no-this-alias": "error",
@@ -173,20 +118,18 @@
"typescript/triple-slash-reference": "error",
"typescript/unbound-method": "warn",
"unicorn/no-await-in-promise-methods": "warn",
- "unicorn/no-empty-file": "off", // manually off bacause of existing warning. may turn it on in the future
+ "unicorn/no-empty-file": "off",
"unicorn/no-invalid-fetch-options": "warn",
"unicorn/no-invalid-remove-event-listener": "warn",
- "unicorn/no-new-array": "off", // manually off bacause of existing warning. may turn it on in the future
+ "unicorn/no-new-array": "off",
"unicorn/no-single-promise-in-promise-methods": "warn",
- "unicorn/no-thenable": "off", // manually off bacause of existing warning. may turn it on in the future
+ "unicorn/no-thenable": "off",
"unicorn/no-unnecessary-await": "warn",
"unicorn/no-useless-fallback-in-spread": "warn",
"unicorn/no-useless-length-check": "warn",
- "unicorn/no-useless-spread": "off", // manually off bacause of existing warning. may turn it on in the future
+ "unicorn/no-useless-spread": "off",
"unicorn/prefer-set-size": "warn",
- "unicorn/prefer-string-starts-ends-with": "warn",
- "use-isnan": "error",
- "valid-typeof": "error"
+ "unicorn/prefer-string-starts-ends-with": "warn"
},
"settings": {
"jsdoc": {
diff --git a/.yarn/patches/@ai-sdk-google-npm-2.0.23-81682e07b0.patch b/.yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch
similarity index 68%
rename from .yarn/patches/@ai-sdk-google-npm-2.0.23-81682e07b0.patch
rename to .yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch
index ba4cd59d4c..3015e702ed 100644
--- a/.yarn/patches/@ai-sdk-google-npm-2.0.23-81682e07b0.patch
+++ b/.yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch
@@ -1,8 +1,8 @@
diff --git a/dist/index.js b/dist/index.js
-index 4cc66d83af1cef39f6447dc62e680251e05ddf9f..eb9819cb674c1808845ceb29936196c4bb355172 100644
+index 51ce7e423934fb717cb90245cdfcdb3dae6780e6..0f7f7009e2f41a79a8669d38c8a44867bbff5e1f 100644
--- a/dist/index.js
+++ b/dist/index.js
-@@ -471,7 +471,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
+@@ -474,7 +474,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
@@ -12,10 +12,10 @@ index 4cc66d83af1cef39f6447dc62e680251e05ddf9f..eb9819cb674c1808845ceb29936196c4
// src/google-generative-ai-options.ts
diff --git a/dist/index.mjs b/dist/index.mjs
-index a032505ec54e132dc386dde001dc51f710f84c83..5efada51b9a8b56e3f01b35e734908ebe3c37043 100644
+index f4b77e35c0cbfece85a3ef0d4f4e67aa6dde6271..8d2fecf8155a226006a0bde72b00b6036d4014b6 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
-@@ -477,7 +477,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
+@@ -480,7 +480,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
diff --git a/.yarn/patches/@ai-sdk-huggingface-npm-0.0.4-8080836bc1.patch b/.yarn/patches/@ai-sdk-huggingface-npm-0.0.4-8080836bc1.patch
deleted file mode 100644
index 7aeb4ea9cf..0000000000
--- a/.yarn/patches/@ai-sdk-huggingface-npm-0.0.4-8080836bc1.patch
+++ /dev/null
@@ -1,131 +0,0 @@
-diff --git a/dist/index.mjs b/dist/index.mjs
-index b3f018730a93639aad7c203f15fb1aeb766c73f4..ade2a43d66e9184799d072153df61ef7be4ea110 100644
---- a/dist/index.mjs
-+++ b/dist/index.mjs
-@@ -296,7 +296,14 @@ var HuggingFaceResponsesLanguageModel = class {
- metadata: huggingfaceOptions == null ? void 0 : huggingfaceOptions.metadata,
- instructions: huggingfaceOptions == null ? void 0 : huggingfaceOptions.instructions,
- ...preparedTools && { tools: preparedTools },
-- ...preparedToolChoice && { tool_choice: preparedToolChoice }
-+ ...preparedToolChoice && { tool_choice: preparedToolChoice },
-+ ...(huggingfaceOptions?.reasoningEffort != null && {
-+ reasoning: {
-+ ...(huggingfaceOptions?.reasoningEffort != null && {
-+ effort: huggingfaceOptions.reasoningEffort,
-+ }),
-+ },
-+ }),
- };
- return { args: baseArgs, warnings };
- }
-@@ -365,6 +372,20 @@ var HuggingFaceResponsesLanguageModel = class {
- }
- break;
- }
-+ case 'reasoning': {
-+ for (const contentPart of part.content) {
-+ content.push({
-+ type: 'reasoning',
-+ text: contentPart.text,
-+ providerMetadata: {
-+ huggingface: {
-+ itemId: part.id,
-+ },
-+ },
-+ });
-+ }
-+ break;
-+ }
- case "mcp_call": {
- content.push({
- type: "tool-call",
-@@ -519,6 +540,11 @@ var HuggingFaceResponsesLanguageModel = class {
- id: value.item.call_id,
- toolName: value.item.name
- });
-+ } else if (value.item.type === 'reasoning') {
-+ controller.enqueue({
-+ type: 'reasoning-start',
-+ id: value.item.id,
-+ });
- }
- return;
- }
-@@ -570,6 +596,22 @@ var HuggingFaceResponsesLanguageModel = class {
- });
- return;
- }
-+ if (isReasoningDeltaChunk(value)) {
-+ controller.enqueue({
-+ type: 'reasoning-delta',
-+ id: value.item_id,
-+ delta: value.delta,
-+ });
-+ return;
-+ }
-+
-+ if (isReasoningEndChunk(value)) {
-+ controller.enqueue({
-+ type: 'reasoning-end',
-+ id: value.item_id,
-+ });
-+ return;
-+ }
- },
- flush(controller) {
- controller.enqueue({
-@@ -593,7 +635,8 @@ var HuggingFaceResponsesLanguageModel = class {
- var huggingfaceResponsesProviderOptionsSchema = z2.object({
- metadata: z2.record(z2.string(), z2.string()).optional(),
- instructions: z2.string().optional(),
-- strictJsonSchema: z2.boolean().optional()
-+ strictJsonSchema: z2.boolean().optional(),
-+ reasoningEffort: z2.string().optional(),
- });
- var huggingfaceResponsesResponseSchema = z2.object({
- id: z2.string(),
-@@ -727,12 +770,31 @@ var responseCreatedChunkSchema = z2.object({
- model: z2.string()
- })
- });
-+var reasoningTextDeltaChunkSchema = z2.object({
-+ type: z2.literal('response.reasoning_text.delta'),
-+ item_id: z2.string(),
-+ output_index: z2.number(),
-+ content_index: z2.number(),
-+ delta: z2.string(),
-+ sequence_number: z2.number(),
-+});
-+
-+var reasoningTextEndChunkSchema = z2.object({
-+ type: z2.literal('response.reasoning_text.done'),
-+ item_id: z2.string(),
-+ output_index: z2.number(),
-+ content_index: z2.number(),
-+ text: z2.string(),
-+ sequence_number: z2.number(),
-+});
- var huggingfaceResponsesChunkSchema = z2.union([
- responseOutputItemAddedSchema,
- responseOutputItemDoneSchema,
- textDeltaChunkSchema,
- responseCompletedChunkSchema,
- responseCreatedChunkSchema,
-+ reasoningTextDeltaChunkSchema,
-+ reasoningTextEndChunkSchema,
- z2.object({ type: z2.string() }).loose()
- // fallback for unknown chunks
- ]);
-@@ -751,6 +813,12 @@ function isResponseCompletedChunk(chunk) {
- function isResponseCreatedChunk(chunk) {
- return chunk.type === "response.created";
- }
-+function isReasoningDeltaChunk(chunk) {
-+ return chunk.type === 'response.reasoning_text.delta';
-+}
-+function isReasoningEndChunk(chunk) {
-+ return chunk.type === 'response.reasoning_text.done';
-+}
-
- // src/huggingface-provider.ts
- function createHuggingFace(options = {}) {
diff --git a/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch b/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch
new file mode 100644
index 0000000000..2a13c33a78
--- /dev/null
+++ b/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch
@@ -0,0 +1,140 @@
+diff --git a/dist/index.js b/dist/index.js
+index 73045a7d38faafdc7f7d2cd79d7ff0e2b031056b..8d948c9ac4ea4b474db9ef3c5491961e7fcf9a07 100644
+--- a/dist/index.js
++++ b/dist/index.js
+@@ -421,6 +421,17 @@ var OpenAICompatibleChatLanguageModel = class {
+ text: reasoning
+ });
+ }
++ if (choice.message.images) {
++ for (const image of choice.message.images) {
++ const match1 = image.image_url.url.match(/^data:([^;]+)/)
++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
++ content.push({
++ type: 'file',
++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
++ data: match2 ? match2[1] : image.image_url.url,
++ });
++ }
++ }
+ if (choice.message.tool_calls != null) {
+ for (const toolCall of choice.message.tool_calls) {
+ content.push({
+@@ -598,6 +609,17 @@ var OpenAICompatibleChatLanguageModel = class {
+ delta: delta.content
+ });
+ }
++ if (delta.images) {
++ for (const image of delta.images) {
++ const match1 = image.image_url.url.match(/^data:([^;]+)/)
++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
++ controller.enqueue({
++ type: 'file',
++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
++ data: match2 ? match2[1] : image.image_url.url,
++ });
++ }
++ }
+ if (delta.tool_calls != null) {
+ for (const toolCallDelta of delta.tool_calls) {
+ const index = toolCallDelta.index;
+@@ -765,6 +787,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({
+ arguments: import_v43.z.string()
+ })
+ })
++ ).nullish(),
++ images: import_v43.z.array(
++ import_v43.z.object({
++ type: import_v43.z.literal('image_url'),
++ image_url: import_v43.z.object({
++ url: import_v43.z.string(),
++ })
++ })
+ ).nullish()
+ }),
+ finish_reason: import_v43.z.string().nullish()
+@@ -795,6 +825,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union(
+ arguments: import_v43.z.string().nullish()
+ })
+ })
++ ).nullish(),
++ images: import_v43.z.array(
++ import_v43.z.object({
++ type: import_v43.z.literal('image_url'),
++ image_url: import_v43.z.object({
++ url: import_v43.z.string(),
++ })
++ })
+ ).nullish()
+ }).nullish(),
+ finish_reason: import_v43.z.string().nullish()
+diff --git a/dist/index.mjs b/dist/index.mjs
+index 1c2b9560bbfbfe10cb01af080aeeed4ff59db29c..2c8ddc4fc9bfc5e7e06cfca105d197a08864c427 100644
+--- a/dist/index.mjs
++++ b/dist/index.mjs
+@@ -405,6 +405,17 @@ var OpenAICompatibleChatLanguageModel = class {
+ text: reasoning
+ });
+ }
++ if (choice.message.images) {
++ for (const image of choice.message.images) {
++ const match1 = image.image_url.url.match(/^data:([^;]+)/)
++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
++ content.push({
++ type: 'file',
++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
++ data: match2 ? match2[1] : image.image_url.url,
++ });
++ }
++ }
+ if (choice.message.tool_calls != null) {
+ for (const toolCall of choice.message.tool_calls) {
+ content.push({
+@@ -582,6 +593,17 @@ var OpenAICompatibleChatLanguageModel = class {
+ delta: delta.content
+ });
+ }
++ if (delta.images) {
++ for (const image of delta.images) {
++ const match1 = image.image_url.url.match(/^data:([^;]+)/)
++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
++ controller.enqueue({
++ type: 'file',
++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
++ data: match2 ? match2[1] : image.image_url.url,
++ });
++ }
++ }
+ if (delta.tool_calls != null) {
+ for (const toolCallDelta of delta.tool_calls) {
+ const index = toolCallDelta.index;
+@@ -749,6 +771,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({
+ arguments: z3.string()
+ })
+ })
++ ).nullish(),
++ images: z3.array(
++ z3.object({
++ type: z3.literal('image_url'),
++ image_url: z3.object({
++ url: z3.string(),
++ })
++ })
+ ).nullish()
+ }),
+ finish_reason: z3.string().nullish()
+@@ -779,6 +809,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([
+ arguments: z3.string().nullish()
+ })
+ })
++ ).nullish(),
++ images: z3.array(
++ z3.object({
++ type: z3.literal('image_url'),
++ image_url: z3.object({
++ url: z3.string(),
++ })
++ })
+ ).nullish()
+ }).nullish(),
+ finish_reason: z3.string().nullish()
diff --git a/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch b/.yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch
similarity index 85%
rename from .yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch
rename to .yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch
index a7985ddfcd..973ddc62ac 100644
--- a/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch
+++ b/.yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch
@@ -1,5 +1,5 @@
diff --git a/dist/index.js b/dist/index.js
-index cc6652c4e7f32878a64a2614115bf7eeb3b7c890..76e989017549c89b45d633525efb1f318026d9b2 100644
+index bf900591bf2847a3253fe441aad24c06da19c6c1..c1d9bb6fefa2df1383339324073db0a70ea2b5a2 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)(
@@ -18,30 +18,29 @@ index cc6652c4e7f32878a64a2614115bf7eeb3b7c890..76e989017549c89b45d633525efb1f31
tool_calls: import_v42.z.array(
import_v42.z.object({
index: import_v42.z.number(),
-@@ -785,6 +787,14 @@ var OpenAIChatLanguageModel = class {
+@@ -795,6 +797,13 @@ var OpenAIChatLanguageModel = class {
if (text != null && text.length > 0) {
content.push({ type: "text", text });
}
-+ const reasoning =
-+ choice.message.reasoning_content;
++ const reasoning = choice.message.reasoning_content;
+ if (reasoning != null && reasoning.length > 0) {
+ content.push({
+ type: 'reasoning',
-+ text: reasoning,
++ text: reasoning
+ });
+ }
for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) {
content.push({
type: "tool-call",
-@@ -866,6 +876,7 @@ var OpenAIChatLanguageModel = class {
+@@ -876,6 +885,7 @@ var OpenAIChatLanguageModel = class {
};
- let isFirstChunk = true;
+ let metadataExtracted = false;
let isActiveText = false;
+ let isActiveReasoning = false;
const providerMetadata = { openai: {} };
return {
stream: response.pipeThrough(
-@@ -920,6 +931,22 @@ var OpenAIChatLanguageModel = class {
+@@ -933,6 +943,21 @@ var OpenAIChatLanguageModel = class {
return;
}
const delta = choice.delta;
@@ -54,7 +53,6 @@ index cc6652c4e7f32878a64a2614115bf7eeb3b7c890..76e989017549c89b45d633525efb1f31
+ });
+ isActiveReasoning = true;
+ }
-+
+ controller.enqueue({
+ type: 'reasoning-delta',
+ id: 'reasoning-0',
@@ -64,7 +62,7 @@ index cc6652c4e7f32878a64a2614115bf7eeb3b7c890..76e989017549c89b45d633525efb1f31
if (delta.content != null) {
if (!isActiveText) {
controller.enqueue({ type: "text-start", id: "0" });
-@@ -1032,6 +1059,9 @@ var OpenAIChatLanguageModel = class {
+@@ -1045,6 +1070,9 @@ var OpenAIChatLanguageModel = class {
}
},
flush(controller) {
diff --git a/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.25-08bbabb5d3.patch b/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.53-4b77f4cf29.patch
similarity index 69%
rename from .yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.25-08bbabb5d3.patch
rename to .yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.53-4b77f4cf29.patch
index 057443aa43..4481b58f32 100644
--- a/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.25-08bbabb5d3.patch
+++ b/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.53-4b77f4cf29.patch
@@ -1,8 +1,8 @@
diff --git a/sdk.mjs b/sdk.mjs
-index 10162e5b1624f8ce667768943347a6e41089ad2f..32568ae08946590e382270c88d85fba81187568e 100755
+index bf429a344b7d59f70aead16b639f949b07688a81..f77d50cc5d3fb04292cb3ac7fa7085d02dcc628f 100755
--- a/sdk.mjs
+++ b/sdk.mjs
-@@ -6213,7 +6213,7 @@ function createAbortController(maxListeners = DEFAULT_MAX_LISTENERS) {
+@@ -6250,7 +6250,7 @@ function createAbortController(maxListeners = DEFAULT_MAX_LISTENERS) {
}
// ../src/transport/ProcessTransport.ts
@@ -11,16 +11,20 @@ index 10162e5b1624f8ce667768943347a6e41089ad2f..32568ae08946590e382270c88d85fba8
import { createInterface } from "readline";
// ../src/utils/fsOperations.ts
-@@ -6487,14 +6487,11 @@ class ProcessTransport {
+@@ -6619,18 +6619,11 @@ class ProcessTransport {
const errorMessage = isNativeBinary(pathToClaudeCodeExecutable) ? `Claude Code native binary not found at ${pathToClaudeCodeExecutable}. Please ensure Claude Code is installed via native installer or specify a valid path with options.pathToClaudeCodeExecutable.` : `Claude Code executable not found at ${pathToClaudeCodeExecutable}. Is options.pathToClaudeCodeExecutable set?`;
throw new ReferenceError(errorMessage);
}
- const isNative = isNativeBinary(pathToClaudeCodeExecutable);
- const spawnCommand = isNative ? pathToClaudeCodeExecutable : executable;
- const spawnArgs = isNative ? [...executableArgs, ...args] : [...executableArgs, pathToClaudeCodeExecutable, ...args];
-- this.logForDebugging(isNative ? `Spawning Claude Code native binary: ${spawnCommand} ${spawnArgs.join(" ")}` : `Spawning Claude Code process: ${spawnCommand} ${spawnArgs.join(" ")}`);
-+ this.logForDebugging(`Forking Claude Code Node.js process: ${pathToClaudeCodeExecutable} ${args.join(" ")}`);
- const stderrMode = env.DEBUG || stderr ? "pipe" : "ignore";
+- const spawnMessage = isNative ? `Spawning Claude Code native binary: ${spawnCommand} ${spawnArgs.join(" ")}` : `Spawning Claude Code process: ${spawnCommand} ${spawnArgs.join(" ")}`;
+- logForSdkDebugging(spawnMessage);
+- if (stderr) {
+- stderr(spawnMessage);
+- }
++ logForSdkDebugging(`Forking Claude Code Node.js process: ${pathToClaudeCodeExecutable} ${args.join(" ")}`);
+ const stderrMode = env.DEBUG_CLAUDE_AGENT_SDK || stderr ? "pipe" : "ignore";
- this.child = spawn(spawnCommand, spawnArgs, {
+ this.child = fork(pathToClaudeCodeExecutable, args, {
cwd,
diff --git a/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch b/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch
deleted file mode 100644
index e9ca84e6cd..0000000000
--- a/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch
+++ /dev/null
@@ -1,276 +0,0 @@
-diff --git a/out/macPackager.js b/out/macPackager.js
-index 852f6c4d16f86a7bb8a78bf1ed5a14647a279aa1..60e7f5f16a844541eb1909b215fcda1811e924b8 100644
---- a/out/macPackager.js
-+++ b/out/macPackager.js
-@@ -423,7 +423,7 @@ class MacPackager extends platformPackager_1.PlatformPackager {
- }
- appPlist.CFBundleName = appInfo.productName;
- appPlist.CFBundleDisplayName = appInfo.productName;
-- const minimumSystemVersion = this.platformSpecificBuildOptions.minimumSystemVersion;
-+ const minimumSystemVersion = this.platformSpecificBuildOptions.LSMinimumSystemVersion;
- if (minimumSystemVersion != null) {
- appPlist.LSMinimumSystemVersion = minimumSystemVersion;
- }
-diff --git a/out/publish/updateInfoBuilder.js b/out/publish/updateInfoBuilder.js
-index 7924c5b47d01f8dfccccb8f46658015fa66da1f7..1a1588923c3939ae1297b87931ba83f0ebc052d8 100644
---- a/out/publish/updateInfoBuilder.js
-+++ b/out/publish/updateInfoBuilder.js
-@@ -133,6 +133,7 @@ async function createUpdateInfo(version, event, releaseInfo) {
- const customUpdateInfo = event.updateInfo;
- const url = path.basename(event.file);
- const sha512 = (customUpdateInfo == null ? null : customUpdateInfo.sha512) || (await (0, hash_1.hashFile)(event.file));
-+ const minimumSystemVersion = customUpdateInfo == null ? null : customUpdateInfo.minimumSystemVersion;
- const files = [{ url, sha512 }];
- const result = {
- // @ts-ignore
-@@ -143,9 +144,13 @@ async function createUpdateInfo(version, event, releaseInfo) {
- path: url /* backward compatibility, electron-updater 1.x - electron-updater 2.15.0 */,
- // @ts-ignore
- sha512 /* backward compatibility, electron-updater 1.x - electron-updater 2.15.0 */,
-+ minimumSystemVersion,
- ...releaseInfo,
- };
- if (customUpdateInfo != null) {
-+ if (customUpdateInfo.minimumSystemVersion) {
-+ delete customUpdateInfo.minimumSystemVersion;
-+ }
- // file info or nsis web installer packages info
- Object.assign("sha512" in customUpdateInfo ? files[0] : result, customUpdateInfo);
- }
-diff --git a/out/targets/ArchiveTarget.js b/out/targets/ArchiveTarget.js
-index e1f52a5fa86fff6643b2e57eaf2af318d541f865..47cc347f154a24b365e70ae5e1f6d309f3582ed0 100644
---- a/out/targets/ArchiveTarget.js
-+++ b/out/targets/ArchiveTarget.js
-@@ -69,6 +69,9 @@ class ArchiveTarget extends core_1.Target {
- }
- }
- }
-+ if (updateInfo != null && this.packager.platformSpecificBuildOptions.minimumSystemVersion) {
-+ updateInfo.minimumSystemVersion = this.packager.platformSpecificBuildOptions.minimumSystemVersion;
-+ }
- await packager.info.emitArtifactBuildCompleted({
- updateInfo,
- file: artifactPath,
-diff --git a/out/targets/nsis/NsisTarget.js b/out/targets/nsis/NsisTarget.js
-index e8bd7bb46c8a54b3f55cf3a853ef924195271e01..f956e9f3fe9eb903c78aef3502553b01de4b89b1 100644
---- a/out/targets/nsis/NsisTarget.js
-+++ b/out/targets/nsis/NsisTarget.js
-@@ -305,6 +305,9 @@ class NsisTarget extends core_1.Target {
- if (updateInfo != null && isPerMachine && (oneClick || options.packElevateHelper)) {
- updateInfo.isAdminRightsRequired = true;
- }
-+ if (updateInfo != null && this.packager.platformSpecificBuildOptions.minimumSystemVersion) {
-+ updateInfo.minimumSystemVersion = this.packager.platformSpecificBuildOptions.minimumSystemVersion;
-+ }
- await packager.info.emitArtifactBuildCompleted({
- file: installerPath,
- updateInfo,
-diff --git a/out/util/yarn.js b/out/util/yarn.js
-index 1ee20f8b252a8f28d0c7b103789cf0a9a427aec1..c2878ec54d57da50bf14225e0c70c9c88664eb8a 100644
---- a/out/util/yarn.js
-+++ b/out/util/yarn.js
-@@ -140,6 +140,7 @@ async function rebuild(config, { appDir, projectDir }, options) {
- arch,
- platform,
- buildFromSource,
-+ ignoreModules: config.excludeReBuildModules || undefined,
- projectRootPath: projectDir,
- mode: config.nativeRebuilder || "sequential",
- disablePreGypCopy: true,
-diff --git a/scheme.json b/scheme.json
-index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a74dda74c9 100644
---- a/scheme.json
-+++ b/scheme.json
-@@ -1825,6 +1825,20 @@
- "string"
- ]
- },
-+ "excludeReBuildModules": {
-+ "anyOf": [
-+ {
-+ "items": {
-+ "type": "string"
-+ },
-+ "type": "array"
-+ },
-+ {
-+ "type": "null"
-+ }
-+ ],
-+ "description": "The modules to exclude from the rebuild."
-+ },
- "executableArgs": {
- "anyOf": [
- {
-@@ -1975,6 +1989,13 @@
- ],
- "description": "The mime types in addition to specified in the file associations. Use it if you don't want to register a new mime type, but reuse existing."
- },
-+ "minimumSystemVersion": {
-+ "description": "The minimum os kernel version required to install the application.",
-+ "type": [
-+ "null",
-+ "string"
-+ ]
-+ },
- "packageCategory": {
- "description": "backward compatibility + to allow specify fpm-only category for all possible fpm targets in one place",
- "type": [
-@@ -2327,6 +2348,13 @@
- "MacConfiguration": {
- "additionalProperties": false,
- "properties": {
-+ "LSMinimumSystemVersion": {
-+ "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
-+ "type": [
-+ "null",
-+ "string"
-+ ]
-+ },
- "additionalArguments": {
- "anyOf": [
- {
-@@ -2527,6 +2555,20 @@
- "string"
- ]
- },
-+ "excludeReBuildModules": {
-+ "anyOf": [
-+ {
-+ "items": {
-+ "type": "string"
-+ },
-+ "type": "array"
-+ },
-+ {
-+ "type": "null"
-+ }
-+ ],
-+ "description": "The modules to exclude from the rebuild."
-+ },
- "executableName": {
- "description": "The executable name. Defaults to `productName`.",
- "type": [
-@@ -2737,7 +2779,7 @@
- "type": "boolean"
- },
- "minimumSystemVersion": {
-- "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
-+ "description": "The minimum os kernel version required to install the application.",
- "type": [
- "null",
- "string"
-@@ -2959,6 +3001,13 @@
- "MasConfiguration": {
- "additionalProperties": false,
- "properties": {
-+ "LSMinimumSystemVersion": {
-+ "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
-+ "type": [
-+ "null",
-+ "string"
-+ ]
-+ },
- "additionalArguments": {
- "anyOf": [
- {
-@@ -3159,6 +3208,20 @@
- "string"
- ]
- },
-+ "excludeReBuildModules": {
-+ "anyOf": [
-+ {
-+ "items": {
-+ "type": "string"
-+ },
-+ "type": "array"
-+ },
-+ {
-+ "type": "null"
-+ }
-+ ],
-+ "description": "The modules to exclude from the rebuild."
-+ },
- "executableName": {
- "description": "The executable name. Defaults to `productName`.",
- "type": [
-@@ -3369,7 +3432,7 @@
- "type": "boolean"
- },
- "minimumSystemVersion": {
-- "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
-+ "description": "The minimum os kernel version required to install the application.",
- "type": [
- "null",
- "string"
-@@ -6381,6 +6444,20 @@
- "string"
- ]
- },
-+ "excludeReBuildModules": {
-+ "anyOf": [
-+ {
-+ "items": {
-+ "type": "string"
-+ },
-+ "type": "array"
-+ },
-+ {
-+ "type": "null"
-+ }
-+ ],
-+ "description": "The modules to exclude from the rebuild."
-+ },
- "executableName": {
- "description": "The executable name. Defaults to `productName`.",
- "type": [
-@@ -6507,6 +6584,13 @@
- "string"
- ]
- },
-+ "minimumSystemVersion": {
-+ "description": "The minimum os kernel version required to install the application.",
-+ "type": [
-+ "null",
-+ "string"
-+ ]
-+ },
- "protocols": {
- "anyOf": [
- {
-@@ -7153,6 +7237,20 @@
- "string"
- ]
- },
-+ "excludeReBuildModules": {
-+ "anyOf": [
-+ {
-+ "items": {
-+ "type": "string"
-+ },
-+ "type": "array"
-+ },
-+ {
-+ "type": "null"
-+ }
-+ ],
-+ "description": "The modules to exclude from the rebuild."
-+ },
- "executableName": {
- "description": "The executable name. Defaults to `productName`.",
- "type": [
-@@ -7376,6 +7474,13 @@
- ],
- "description": "MAS (Mac Application Store) development options (`mas-dev` target)."
- },
-+ "minimumSystemVersion": {
-+ "description": "The minimum os kernel version required to install the application.",
-+ "type": [
-+ "null",
-+ "string"
-+ ]
-+ },
- "msi": {
- "anyOf": [
- {
diff --git a/.yarn/patches/electron-updater-npm-6.7.0-47b11bb0d4.patch b/.yarn/patches/electron-updater-npm-6.7.0-47b11bb0d4.patch
new file mode 100644
index 0000000000..f9e54ac947
--- /dev/null
+++ b/.yarn/patches/electron-updater-npm-6.7.0-47b11bb0d4.patch
@@ -0,0 +1,14 @@
+diff --git a/out/util.js b/out/util.js
+index 9294ffd6ca8f02c2e0f90c663e7e9cdc02c1ac37..f52107493e2995320ee4efd0eb2a8c9bf03291a2 100644
+--- a/out/util.js
++++ b/out/util.js
+@@ -23,7 +23,8 @@ function newUrlFromBase(pathname, baseUrl, addRandomQueryToAvoidCaching = false)
+ result.search = search;
+ }
+ else if (addRandomQueryToAvoidCaching) {
+- result.search = `noCache=${Date.now().toString(32)}`;
++ // use no cache header instead
++ // result.search = `noCache=${Date.now().toString(32)}`;
+ }
+ return result;
+ }
diff --git a/CLAUDE.md b/CLAUDE.md
index 0728605824..c96fc0e403 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -10,8 +10,17 @@ This file provides guidance to AI coding assistants when working with code in th
- **Log centrally**: Route all logging through `loggerService` with the right context—no `console.log`.
- **Research via subagent**: Lean on `subagent` for external docs, APIs, news, and references.
- **Always propose before executing**: Before making any changes, clearly explain your planned approach and wait for explicit user approval to ensure alignment and prevent unwanted modifications.
-- **Write conventional commits with emoji**: Commit small, focused changes using emoji-prefixed Conventional Commit messages (e.g., `✨ feat:`, `🐛 fix:`, `♻️ refactor:`, `
-📝 docs:`).
+- **Lint, test, and format before completion**: Coding tasks are only complete after running `yarn lint`, `yarn test`, and `yarn format` successfully.
+- **Write conventional commits**: Commit small, focused changes using Conventional Commit messages (e.g., `feat:`, `fix:`, `refactor:`, `docs:`).
+
+## Pull Request Workflow (CRITICAL)
+
+When creating a Pull Request, you MUST:
+
+1. **Read the PR template first**: Always read `.github/pull_request_template.md` before creating the PR
+2. **Follow ALL template sections**: Structure the `--body` parameter to include every section from the template
+3. **Never skip sections**: Include all sections even if marking them as N/A or "None"
+4. **Use proper formatting**: Match the template's markdown structure exactly (headings, checkboxes, code blocks)
## Development Commands
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 545c34dc12..3ddc05ce85 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,4 +1,4 @@
-[中文](docs/CONTRIBUTING.zh.md) | [English](CONTRIBUTING.md)
+[中文](docs/zh/guides/contributing.md) | [English](CONTRIBUTING.md)
# Cherry Studio Contributor Guide
@@ -32,7 +32,7 @@ To help you get familiar with the codebase, we recommend tackling issues tagged
### Testing
-Features without tests are considered non-existent. To ensure code is truly effective, relevant processes should be covered by unit tests and functional tests. Therefore, when considering contributions, please also consider testability. All tests can be run locally without dependency on CI. Please refer to the "Testing" section in the [Developer Guide](docs/dev.md).
+Features without tests are considered non-existent. To ensure code is truly effective, relevant processes should be covered by unit tests and functional tests. Therefore, when considering contributions, please also consider testability. All tests can be run locally without dependency on CI. Please refer to the "Testing" section in the [Developer Guide](docs/zh/guides/development.md).
### Automated Testing for Pull Requests
@@ -60,7 +60,7 @@ Maintainers are here to help you implement your use case within a reasonable tim
### Participating in the Test Plan
-The Test Plan aims to provide users with a more stable application experience and faster iteration speed. For details, please refer to the [Test Plan](docs/testplan-en.md).
+The Test Plan aims to provide users with a more stable application experience and faster iteration speed. For details, please refer to the [Test Plan](docs/en/guides/test-plan.md).
### Other Suggestions
diff --git a/README.md b/README.md
index c3d3f915a1..f790c10cbd 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@
-
English | 中文 | Official Site | Documents | Development | Feedback
+English | 中文 | Official Site | Documents | Development | Feedback
@@ -67,7 +67,7 @@ Cherry Studio is a desktop client that supports multiple LLM providers, availabl
👏 Join [Telegram Group](https://t.me/CherryStudioAI)|[Discord](https://discord.gg/wez8HtpxqQ) | [QQ Group(575014769)](https://qm.qq.com/q/lo0D4qVZKi)
-❤️ Like Cherry Studio? Give it a star 🌟 or [Sponsor](docs/sponsor.md) to support the development!
+❤️ Like Cherry Studio? Give it a star 🌟 or [Sponsor](docs/zh/guides/sponsor.md) to support the development!
# 🌠 Screenshot
@@ -82,7 +82,7 @@ Cherry Studio is a desktop client that supports multiple LLM providers, availabl
1. **Diverse LLM Provider Support**:
- ☁️ Major LLM Cloud Services: OpenAI, Gemini, Anthropic, and more
-- 🔗 AI Web Service Integration: Claude, Perplexity, Poe, and others
+- 🔗 AI Web Service Integration: Claude, Perplexity, [Poe](https://poe.com/), and others
- 💻 Local Model Support with Ollama, LM Studio
2. **AI Assistants & Conversations**:
@@ -175,7 +175,7 @@ We welcome contributions to Cherry Studio! Here are some ways you can contribute
6. **Community Engagement**: Join discussions and help users.
7. **Promote Usage**: Spread the word about Cherry Studio.
-Refer to the [Branching Strategy](docs/branching-strategy-en.md) for contribution guidelines
+Refer to the [Branching Strategy](docs/en/guides/branching-strategy.md) for contribution guidelines
## Getting Started
@@ -238,10 +238,6 @@ The Enterprise Edition addresses core challenges in team collaboration by centra
## ✨ Online Demo
-> 🚧 **Public Beta Notice**
->
-> The Enterprise Edition is currently in its early public beta stage, and we are actively iterating and optimizing its features. We are aware that it may not be perfectly stable yet. If you encounter any issues or have valuable suggestions during your trial, we would be very grateful if you could contact us via email to provide feedback.
-
**🔗 [Cherry Studio Enterprise](https://www.cherry-ai.com/enterprise)**
## Version Comparison
@@ -249,7 +245,7 @@ The Enterprise Edition addresses core challenges in team collaboration by centra
| Feature | Community Edition | Enterprise Edition |
| :---------------- | :----------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------- |
| **Open Source** | ✅ Yes | ⭕️ Partially released to customers |
-| **Cost** | Free for Personal Use / Commercial License | Buyout / Subscription Fee |
+| **Cost** | [AGPL-3.0 License](https://github.com/CherryHQ/cherry-studio?tab=AGPL-3.0-1-ov-file) | Buyout / Subscription Fee |
| **Admin Backend** | — | ● Centralized **Model** Access
● **Employee** Management
● Shared **Knowledge Base**
● **Access** Control
● **Data** Backup |
| **Server** | — | ✅ Dedicated Private Deployment |
@@ -262,8 +258,12 @@ We believe the Enterprise Edition will become your team's AI productivity engine
# 🔗 Related Projects
+- [new-api](https://github.com/QuantumNous/new-api): The next-generation LLM gateway and AI asset management system supports multiple languages.
+
- [one-api](https://github.com/songquanpeng/one-api): LLM API management and distribution system supporting mainstream models like OpenAI, Azure, and Anthropic. Features a unified API interface, suitable for key management and secondary distribution.
+- [Poe](https://poe.com/): Poe gives you access to the best AI, all in one place. Explore GPT-5, Claude Opus 4.1, DeepSeek-R1, Veo 3, ElevenLabs, and millions of others.
+
- [ublacklist](https://github.com/iorate/ublacklist): Blocks specific sites from appearing in Google search results
# 🚀 Contributors
diff --git a/app-upgrade-config.json b/app-upgrade-config.json
new file mode 100644
index 0000000000..84e381c86a
--- /dev/null
+++ b/app-upgrade-config.json
@@ -0,0 +1,49 @@
+{
+ "lastUpdated": "2025-11-10T08:14:28Z",
+ "versions": {
+ "1.6.7": {
+ "metadata": {
+ "segmentId": "legacy-v1",
+ "segmentType": "legacy"
+ },
+ "minCompatibleVersion": "1.0.0",
+ "description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
+ "channels": {
+ "latest": {
+ "version": "1.6.7",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.7",
+ "gitcode": "https://releases.cherry-ai.com"
+ }
+ },
+ "rc": {
+ "version": "1.6.0-rc.5",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5",
+ "gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5"
+ }
+ },
+ "beta": {
+ "version": "1.7.0-beta.3",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3",
+ "gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3"
+ }
+ }
+ }
+ },
+ "2.0.0": {
+ "metadata": {
+ "segmentId": "gateway-v2",
+ "segmentType": "breaking"
+ },
+ "minCompatibleVersion": "1.7.0",
+ "description": "Major release v2.0 - required intermediate version for v2.x upgrades",
+ "channels": {
+ "latest": null,
+ "rc": null,
+ "beta": null
+ }
+ }
+ }
+}
diff --git a/biome.jsonc b/biome.jsonc
index 9509135fc4..705b1e01f3 100644
--- a/biome.jsonc
+++ b/biome.jsonc
@@ -14,7 +14,7 @@
}
},
"enabled": true,
- "includes": ["**/*.json", "!*.json", "!**/package.json"]
+ "includes": ["**/*.json", "!*.json", "!**/package.json", "!coverage/**"]
},
"css": {
"formatter": {
@@ -23,7 +23,7 @@
},
"files": {
"ignoreUnknown": false,
- "includes": ["**", "!**/.claude/**"],
+ "includes": ["**", "!**/.claude/**", "!**/.vscode/**"],
"maxSize": 2097152
},
"formatter": {
diff --git a/components.json b/components.json
deleted file mode 100644
index c5aceeb3ce..0000000000
--- a/components.json
+++ /dev/null
@@ -1,21 +0,0 @@
-{
- "$schema": "https://ui.shadcn.com/schema.json",
- "aliases": {
- "components": "@renderer/ui/third-party",
- "hooks": "@renderer/hooks",
- "lib": "@renderer/lib",
- "ui": "@renderer/ui",
- "utils": "@renderer/utils"
- },
- "iconLibrary": "lucide",
- "rsc": false,
- "style": "new-york",
- "tailwind": {
- "baseColor": "zinc",
- "config": "",
- "css": "src/renderer/src/assets/styles/tailwind.css",
- "cssVariables": true,
- "prefix": ""
- },
- "tsx": true
-}
diff --git a/config/app-upgrade-segments.json b/config/app-upgrade-segments.json
new file mode 100644
index 0000000000..70c8ac25f0
--- /dev/null
+++ b/config/app-upgrade-segments.json
@@ -0,0 +1,81 @@
+{
+ "segments": [
+ {
+ "id": "legacy-v1",
+ "type": "legacy",
+ "match": {
+ "range": ">=1.0.0 <2.0.0"
+ },
+ "minCompatibleVersion": "1.0.0",
+ "description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
+ "channelTemplates": {
+ "latest": {
+ "feedTemplates": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
+ "gitcode": "https://releases.cherry-ai.com"
+ }
+ },
+ "rc": {
+ "feedTemplates": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
+ "gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
+ }
+ },
+ "beta": {
+ "feedTemplates": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
+ "gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
+ }
+ }
+ }
+ },
+ {
+ "id": "gateway-v2",
+ "type": "breaking",
+ "match": {
+ "exact": ["2.0.0"]
+ },
+ "lockedVersion": "2.0.0",
+ "minCompatibleVersion": "1.7.0",
+ "description": "Major release v2.0 - required intermediate version for v2.x upgrades",
+ "channelTemplates": {
+ "latest": {
+ "feedTemplates": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
+ }
+ }
+ }
+ },
+ {
+ "id": "current-v2",
+ "type": "latest",
+ "match": {
+ "range": ">=2.0.0 <3.0.0",
+ "excludeExact": ["2.0.0"]
+ },
+ "minCompatibleVersion": "2.0.0",
+ "description": "Current latest v2.x release",
+ "channelTemplates": {
+ "latest": {
+ "feedTemplates": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
+ }
+ },
+ "rc": {
+ "feedTemplates": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
+ }
+ },
+ "beta": {
+ "feedTemplates": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
+ }
+ }
+ }
+ }
+ ]
+}
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 0000000000..bd5f055766
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,81 @@
+# Cherry Studio Documentation / 文档
+
+This directory contains the project documentation in multiple languages.
+
+本目录包含多语言项目文档。
+
+---
+
+## Languages / 语言
+
+- **[中文文档](./zh/README.md)** - Chinese Documentation
+- **English Documentation** - See sections below
+
+---
+
+## English Documentation
+
+### Guides
+
+| Document | Description |
+|----------|-------------|
+| [Development Setup](./en/guides/development.md) | Development environment setup |
+| [Branching Strategy](./en/guides/branching-strategy.md) | Git branching workflow |
+| [i18n Guide](./en/guides/i18n.md) | Internationalization guide |
+| [Logging Guide](./en/guides/logging.md) | How to use the logger service |
+| [Test Plan](./en/guides/test-plan.md) | Test plan and release channels |
+
+### References
+
+| Document | Description |
+|----------|-------------|
+| [App Upgrade Config](./en/references/app-upgrade.md) | Application upgrade configuration |
+| [CodeBlockView Component](./en/references/components/code-block-view.md) | Code block view component |
+| [Image Preview Components](./en/references/components/image-preview.md) | Image preview components |
+
+---
+
+## 中文文档
+
+### 指南 (Guides)
+
+| 文档 | 说明 |
+|------|------|
+| [开发环境设置](./zh/guides/development.md) | 开发环境配置 |
+| [贡献指南](./zh/guides/contributing.md) | 如何贡献代码 |
+| [分支策略](./zh/guides/branching-strategy.md) | Git 分支工作流 |
+| [测试计划](./zh/guides/test-plan.md) | 测试计划和发布通道 |
+| [国际化指南](./zh/guides/i18n.md) | 国际化开发指南 |
+| [日志使用指南](./zh/guides/logging.md) | 如何使用日志服务 |
+| [中间件开发](./zh/guides/middleware.md) | 如何编写中间件 |
+| [记忆功能](./zh/guides/memory.md) | 记忆功能使用指南 |
+| [赞助信息](./zh/guides/sponsor.md) | 赞助相关信息 |
+
+### 参考 (References)
+
+| 文档 | 说明 |
+|------|------|
+| [消息系统](./zh/references/message-system.md) | 消息系统架构和 API |
+| [数据库结构](./zh/references/database.md) | 数据库表结构 |
+| [服务](./zh/references/services.md) | 服务层文档 (KnowledgeService) |
+| [代码执行](./zh/references/code-execution.md) | 代码执行功能 |
+| [应用升级配置](./zh/references/app-upgrade.md) | 应用升级配置 |
+| [CodeBlockView 组件](./zh/references/components/code-block-view.md) | 代码块视图组件 |
+| [图像预览组件](./zh/references/components/image-preview.md) | 图像预览组件 |
+
+---
+
+## Missing Translations / 缺少翻译
+
+The following documents are only available in Chinese and need English translations:
+
+以下文档仅有中文版本,需要英文翻译:
+
+- `guides/contributing.md`
+- `guides/memory.md`
+- `guides/middleware.md`
+- `guides/sponsor.md`
+- `references/message-system.md`
+- `references/database.md`
+- `references/services.md`
+- `references/code-execution.md`
diff --git a/docs/technical/.assets.how-to-i18n/demo-1.png b/docs/assets/images/i18n/demo-1.png
similarity index 100%
rename from docs/technical/.assets.how-to-i18n/demo-1.png
rename to docs/assets/images/i18n/demo-1.png
diff --git a/docs/technical/.assets.how-to-i18n/demo-2.png b/docs/assets/images/i18n/demo-2.png
similarity index 100%
rename from docs/technical/.assets.how-to-i18n/demo-2.png
rename to docs/assets/images/i18n/demo-2.png
diff --git a/docs/technical/.assets.how-to-i18n/demo-3.png b/docs/assets/images/i18n/demo-3.png
similarity index 100%
rename from docs/technical/.assets.how-to-i18n/demo-3.png
rename to docs/assets/images/i18n/demo-3.png
diff --git a/docs/technical/message-lifecycle.png b/docs/assets/images/message-lifecycle.png
similarity index 100%
rename from docs/technical/message-lifecycle.png
rename to docs/assets/images/message-lifecycle.png
diff --git a/docs/branching-strategy-en.md b/docs/en/guides/branching-strategy.md
similarity index 98%
rename from docs/branching-strategy-en.md
rename to docs/en/guides/branching-strategy.md
index 8e646249ad..11eabeec73 100644
--- a/docs/branching-strategy-en.md
+++ b/docs/en/guides/branching-strategy.md
@@ -16,7 +16,7 @@ Cherry Studio implements a structured branching strategy to maintain code qualit
- Only accepts documentation updates and bug fixes
- Thoroughly tested before production deployment
-For details about the `testplan` branch used in the Test Plan, please refer to the [Test Plan](testplan-en.md).
+For details about the `testplan` branch used in the Test Plan, please refer to the [Test Plan](./test-plan.md).
## Contributing Branches
diff --git a/docs/dev.md b/docs/en/guides/development.md
similarity index 100%
rename from docs/dev.md
rename to docs/en/guides/development.md
diff --git a/docs/technical/how-to-i18n-en.md b/docs/en/guides/i18n.md
similarity index 97%
rename from docs/technical/how-to-i18n-en.md
rename to docs/en/guides/i18n.md
index 1bbf7edca8..a3284e3ab9 100644
--- a/docs/technical/how-to-i18n-en.md
+++ b/docs/en/guides/i18n.md
@@ -18,11 +18,11 @@ The plugin has already been configured in the project — simply install it to g
### Demo
-
+
-
+
-
+
## i18n Conventions
diff --git a/docs/technical/how-to-use-logger-en.md b/docs/en/guides/logging.md
similarity index 100%
rename from docs/technical/how-to-use-logger-en.md
rename to docs/en/guides/logging.md
diff --git a/docs/testplan-en.md b/docs/en/guides/test-plan.md
similarity index 92%
rename from docs/testplan-en.md
rename to docs/en/guides/test-plan.md
index 0f7cd41473..c7d0c4c660 100644
--- a/docs/testplan-en.md
+++ b/docs/en/guides/test-plan.md
@@ -11,13 +11,15 @@ The Test Plan is divided into the RC channel and the Beta channel, with the foll
Users can enable the "Test Plan" and select the version channel in the software's `Settings` > `About`. Please note that the versions in the "Test Plan" cannot guarantee data consistency, so be sure to back up your data before using them.
+After enabling the RC channel or Beta channel, if a stable version is released, users will still be upgraded to the stable version.
+
Users are welcome to submit issues or provide feedback through other channels for any bugs encountered during testing. Your feedback is very important to us.
## Developer Guide
### Participating in the Test Plan
-Developers should submit `PRs` according to the [Contributor Guide](../CONTRIBUTING.md) (and ensure the target branch is `main`). The repository maintainers will evaluate whether the `PR` should be included in the Test Plan based on factors such as the impact of the feature on the application, its importance, and whether broader testing is needed.
+Developers should submit `PRs` according to the [Contributor Guide](../../CONTRIBUTING.md) (and ensure the target branch is `main`). The repository maintainers will evaluate whether the `PR` should be included in the Test Plan based on factors such as the impact of the feature on the application, its importance, and whether broader testing is needed.
If the `PR` is added to the Test Plan, the repository maintainers will:
diff --git a/docs/en/references/app-upgrade.md b/docs/en/references/app-upgrade.md
new file mode 100644
index 0000000000..0662abf236
--- /dev/null
+++ b/docs/en/references/app-upgrade.md
@@ -0,0 +1,430 @@
+# Update Configuration System Design Document
+
+## Background
+
+Currently, AppUpdater directly queries the GitHub API to retrieve beta and rc update information. To support users in China, we need to fetch a static JSON configuration file from GitHub/GitCode based on IP geolocation, which contains update URLs for all channels.
+
+## Design Goals
+
+1. Support different configuration sources based on IP geolocation (GitHub/GitCode)
+2. Support version compatibility control (e.g., users below v1.x must upgrade to v1.7.0 before upgrading to v2.0)
+3. Easy to extend, supporting future multi-major-version upgrade paths (v1.6 → v1.7 → v2.0 → v2.8 → v3.0)
+4. Maintain compatibility with existing electron-updater mechanism
+
+## Current Version Strategy
+
+- **v1.7.x** is the last version of the 1.x series
+- Users **below v1.7.0** must first upgrade to v1.7.0 (or higher 1.7.x version)
+- Users **v1.7.0 and above** can directly upgrade to v2.x.x
+
+## Automation Workflow
+
+The `x-files/app-upgrade-config/app-upgrade-config.json` file is synchronized by the [`Update App Upgrade Config`](../../.github/workflows/update-app-upgrade-config.yml) workflow. The workflow runs the [`scripts/update-app-upgrade-config.ts`](../../scripts/update-app-upgrade-config.ts) helper so that every release tag automatically updates the JSON in `x-files/app-upgrade-config`.
+
+### Trigger Conditions
+
+- **Release events (`release: released/prereleased`)**
+ - Draft releases are ignored.
+ - When GitHub marks the release as _prerelease_, the tag must include `-beta`/`-rc` (with optional numeric suffix). Otherwise the workflow exits early.
+ - When GitHub marks the release as stable, the tag must match the latest release returned by the GitHub API. This prevents out-of-order updates when publishing historical tags.
+ - If the guard clauses pass, the version is tagged as `latest` or `beta/rc` based on its semantic suffix and propagated to the script through the `IS_PRERELEASE` flag.
+- **Manual dispatch (`workflow_dispatch`)**
+ - Required input: `tag` (e.g., `v2.0.1`). Optional input: `is_prerelease` (defaults to `false`).
+ - When `is_prerelease=true`, the tag must carry a beta/rc suffix, mirroring the automatic validation.
+ - Manual runs still download the latest release metadata so that the workflow knows whether the tag represents the newest stable version (for documentation inside the PR body).
+
+### Workflow Steps
+
+1. **Guard + metadata preparation** – the `Check if should proceed` and `Prepare metadata` steps compute the target tag, prerelease flag, whether the tag is the newest release, and a `safe_tag` slug used for branch names. When any rule fails, the workflow stops without touching the config.
+2. **Checkout source branches** – the default branch is checked out into `main/`, while the long-lived `x-files/app-upgrade-config` branch lives in `cs/`. All modifications happen in the latter directory.
+3. **Install toolchain** – Node.js 22, Corepack, and frozen Yarn dependencies are installed inside `main/`.
+4. **Run the update script** – `yarn tsx scripts/update-app-upgrade-config.ts --tag
--config ../cs/app-upgrade-config.json --is-prerelease ` updates the JSON in-place.
+ - The script normalizes the tag (e.g., strips `v` prefix), detects the release channel (`latest`, `rc`, `beta`), and loads segment rules from `config/app-upgrade-segments.json`.
+ - It validates that prerelease flags and semantic suffixes agree, enforces locked segments, builds mirror feed URLs, and performs release-availability checks (GitHub HEAD request for every channel; GitCode GET for latest channels, falling back to `https://releases.cherry-ai.com` when gitcode is delayed).
+ - After updating the relevant channel entry, the script rewrites the config with semver-sort order and a new `lastUpdated` timestamp.
+5. **Detect changes + create PR** – if `cs/app-upgrade-config.json` changed, the workflow opens a PR `chore/update-app-upgrade-config/` against `x-files/app-upgrade-config` with a commit message `🤖 chore: sync app-upgrade-config for `. Otherwise it logs that no update is required.
+
+### Manual Trigger Guide
+
+1. Open the Cherry Studio repository on GitHub → **Actions** tab → select **Update App Upgrade Config**.
+2. Click **Run workflow**, choose the default branch (usually `main`), and fill in the `tag` input (e.g., `v2.1.0`).
+3. Toggle `is_prerelease` only when the tag carries a prerelease suffix (`-beta`, `-rc`). Leave it unchecked for stable releases.
+4. Start the run and wait for it to finish. Check the generated PR in the `x-files/app-upgrade-config` branch, verify the diff in `app-upgrade-config.json`, and merge once validated.
+
+## JSON Configuration File Format
+
+### File Location
+
+- **GitHub**: `https://raw.githubusercontent.com/CherryHQ/cherry-studio/refs/heads/x-files/app-upgrade-config/app-upgrade-config.json`
+- **GitCode**: `https://gitcode.com/CherryHQ/cherry-studio/raw/x-files/app-upgrade-config/app-upgrade-config.json`
+
+**Note**: Both mirrors provide the same configuration file hosted on the `x-files/app-upgrade-config` branch. The client automatically selects the optimal mirror based on IP geolocation.
+
+### Configuration Structure (Current Implementation)
+
+```json
+{
+ "lastUpdated": "2025-01-05T00:00:00Z",
+ "versions": {
+ "1.6.7": {
+ "minCompatibleVersion": "1.0.0",
+ "description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
+ "channels": {
+ "latest": {
+ "version": "1.6.7",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.7",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v1.6.7"
+ }
+ },
+ "rc": {
+ "version": "1.6.0-rc.5",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5",
+ "gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5"
+ }
+ },
+ "beta": {
+ "version": "1.6.7-beta.3",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3",
+ "gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3"
+ }
+ }
+ }
+ },
+ "2.0.0": {
+ "minCompatibleVersion": "1.7.0",
+ "description": "Major release v2.0 - required intermediate version for v2.x upgrades",
+ "channels": {
+ "latest": null,
+ "rc": null,
+ "beta": null
+ }
+ }
+ }
+}
+```
+
+### Future Extension Example
+
+When releasing v3.0, if users need to first upgrade to v2.8, you can add:
+
+```json
+{
+ "2.8.0": {
+ "minCompatibleVersion": "2.0.0",
+ "description": "Stable v2.8 - required for v3 upgrade",
+ "channels": {
+ "latest": {
+ "version": "2.8.0",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v2.8.0",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v2.8.0"
+ }
+ },
+ "rc": null,
+ "beta": null
+ }
+ },
+ "3.0.0": {
+ "minCompatibleVersion": "2.8.0",
+ "description": "Major release v3.0",
+ "channels": {
+ "latest": {
+ "version": "3.0.0",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/latest",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/latest"
+ }
+ },
+ "rc": {
+ "version": "3.0.0-rc.1",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1"
+ }
+ },
+ "beta": null
+ }
+ }
+}
+```
+
+### Field Descriptions
+
+- `lastUpdated`: Last update time of the configuration file (ISO 8601 format)
+- `versions`: Version configuration object, key is the version number, sorted by semantic versioning
+ - `minCompatibleVersion`: Minimum compatible version that can upgrade to this version
+ - `description`: Version description
+ - `channels`: Update channel configuration
+ - `latest`: Stable release channel
+ - `rc`: Release Candidate channel
+ - `beta`: Beta testing channel
+ - Each channel contains:
+ - `version`: Version number for this channel
+ - `feedUrls`: Multi-mirror URL configuration
+ - `github`: electron-updater feed URL for GitHub mirror
+ - `gitcode`: electron-updater feed URL for GitCode mirror
+ - `metadata`: Stable mapping info for automation
+ - `segmentId`: ID from `config/app-upgrade-segments.json`
+ - `segmentType`: Optional flag (`legacy` | `breaking` | `latest`) for documentation/debugging
+
+## TypeScript Type Definitions
+
+```typescript
+// Mirror enum
+enum UpdateMirror {
+ GITHUB = 'github',
+ GITCODE = 'gitcode'
+}
+
+interface UpdateConfig {
+ lastUpdated: string
+ versions: {
+ [versionKey: string]: VersionConfig
+ }
+}
+
+interface VersionConfig {
+ minCompatibleVersion: string
+ description: string
+ channels: {
+ latest: ChannelConfig | null
+ rc: ChannelConfig | null
+ beta: ChannelConfig | null
+ }
+ metadata?: {
+ segmentId: string
+ segmentType?: 'legacy' | 'breaking' | 'latest'
+ }
+}
+
+interface ChannelConfig {
+ version: string
+ feedUrls: Record
+ // Equivalent to:
+ // feedUrls: {
+ // github: string
+ // gitcode: string
+ // }
+}
+```
+
+## Segment Metadata & Breaking Markers
+
+- **Segment definitions** now live in `config/app-upgrade-segments.json`. Each segment describes a semantic-version range (or exact matches) plus metadata such as `segmentId`, `segmentType`, `minCompatibleVersion`, and per-channel feed URL templates.
+- Each entry under `versions` carries a `metadata.segmentId`. This acts as the stable key that scripts use to decide which slot to update, even if the actual semantic version string changes.
+- Mark major upgrade gateways (e.g., `2.0.0`) by giving the related segment a `segmentType: "breaking"` and (optionally) `lockedVersion`. This prevents automation from accidentally moving that entry when other 2.x builds ship.
+- Adding another breaking hop (e.g., `3.0.0`) only requires defining a new segment in the JSON file; the automation will pick it up on the next run.
+
+## Automation Workflow
+
+Starting from this change, `.github/workflows/update-app-upgrade-config.yml` listens to GitHub release events (published + prerelease). The workflow:
+
+1. Checks out the default branch (for scripts) and the `x-files/app-upgrade-config` branch (where the config is hosted).
+2. Runs `yarn tsx scripts/update-app-upgrade-config.ts --tag --config ../cs/app-upgrade-config.json` to regenerate the config directly inside the `x-files/app-upgrade-config` working tree.
+3. If the file changed, it opens a PR against `x-files/app-upgrade-config` via `peter-evans/create-pull-request`, with the generated diff limited to `app-upgrade-config.json`.
+
+You can run the same script locally via `yarn update:upgrade-config --tag v2.1.6 --config ../cs/app-upgrade-config.json` (add `--dry-run` to preview) to reproduce or debug whatever the workflow does. Passing `--skip-release-checks` along with `--dry-run` lets you bypass the release-page existence check (useful when the GitHub/GitCode pages aren’t published yet). Running without `--config` continues to update the copy in your current working directory (main branch) for documentation purposes.
+
+## Version Matching Logic
+
+### Algorithm Flow
+
+1. Get user's current version (`currentVersion`) and requested channel (`requestedChannel`)
+2. Get all version numbers from configuration file, sort in descending order by semantic versioning
+3. Iterate through the sorted version list:
+ - Check if `currentVersion >= minCompatibleVersion`
+ - Check if the requested `channel` exists and is not `null`
+ - If conditions are met, return the channel configuration
+4. If no matching version is found, return `null`
+
+### Pseudocode Implementation
+
+```typescript
+function findCompatibleVersion(
+ currentVersion: string,
+ requestedChannel: UpgradeChannel,
+ config: UpdateConfig
+): ChannelConfig | null {
+ // Get all version numbers and sort in descending order
+ const versions = Object.keys(config.versions).sort(semver.rcompare)
+
+ for (const versionKey of versions) {
+ const versionConfig = config.versions[versionKey]
+ const channelConfig = versionConfig.channels[requestedChannel]
+
+ // Check version compatibility and channel availability
+ if (
+ semver.gte(currentVersion, versionConfig.minCompatibleVersion) &&
+ channelConfig !== null
+ ) {
+ return channelConfig
+ }
+ }
+
+ return null // No compatible version found
+}
+```
+
+## Upgrade Path Examples
+
+### Scenario 1: v1.6.5 User Upgrade (Below 1.7)
+
+- **Current Version**: 1.6.5
+- **Requested Channel**: latest
+- **Match Result**: 1.7.0
+- **Reason**: 1.6.5 >= 0.0.0 (satisfies 1.7.0's minCompatibleVersion), but doesn't satisfy 2.0.0's minCompatibleVersion (1.7.0)
+- **Action**: Prompt user to upgrade to 1.7.0, which is the required intermediate version for v2.x upgrade
+
+### Scenario 2: v1.6.5 User Requests rc/beta
+
+- **Current Version**: 1.6.5
+- **Requested Channel**: rc or beta
+- **Match Result**: 1.7.0 (latest)
+- **Reason**: 1.7.0 version doesn't provide rc/beta channels (values are null)
+- **Action**: Upgrade to 1.7.0 stable version
+
+### Scenario 3: v1.7.0 User Upgrades to Latest
+
+- **Current Version**: 1.7.0
+- **Requested Channel**: latest
+- **Match Result**: 2.0.0
+- **Reason**: 1.7.0 >= 1.7.0 (satisfies 2.0.0's minCompatibleVersion)
+- **Action**: Directly upgrade to 2.0.0 (current latest stable version)
+
+### Scenario 4: v1.7.2 User Upgrades to RC Version
+
+- **Current Version**: 1.7.2
+- **Requested Channel**: rc
+- **Match Result**: 2.0.0-rc.1
+- **Reason**: 1.7.2 >= 1.7.0 (satisfies 2.0.0's minCompatibleVersion), and rc channel exists
+- **Action**: Upgrade to 2.0.0-rc.1
+
+### Scenario 5: v1.7.0 User Upgrades to Beta Version
+
+- **Current Version**: 1.7.0
+- **Requested Channel**: beta
+- **Match Result**: 2.0.0-beta.1
+- **Reason**: 1.7.0 >= 1.7.0, and beta channel exists
+- **Action**: Upgrade to 2.0.0-beta.1
+
+### Scenario 6: v2.5.0 User Upgrade (Future)
+
+Assuming v2.8.0 and v3.0.0 configurations have been added:
+- **Current Version**: 2.5.0
+- **Requested Channel**: latest
+- **Match Result**: 2.8.0
+- **Reason**: 2.5.0 >= 2.0.0 (satisfies 2.8.0's minCompatibleVersion), but doesn't satisfy 3.0.0's requirement
+- **Action**: Prompt user to upgrade to 2.8.0, which is the required intermediate version for v3.x upgrade
+
+## Code Changes
+
+### Main Modifications
+
+1. **New Methods**
+ - `_fetchUpdateConfig(ipCountry: string): Promise` - Fetch configuration file based on IP
+ - `_findCompatibleChannel(currentVersion: string, channel: UpgradeChannel, config: UpdateConfig): ChannelConfig | null` - Find compatible channel configuration
+
+2. **Modified Methods**
+ - `_getReleaseVersionFromGithub()` → Remove or refactor to `_getChannelFeedUrl()`
+ - `_setFeedUrl()` - Use new configuration system to replace existing logic
+
+3. **New Type Definitions**
+ - `UpdateConfig`
+ - `VersionConfig`
+ - `ChannelConfig`
+
+### Mirror Selection Logic
+
+The client automatically selects the optimal mirror based on IP geolocation:
+
+```typescript
+private async _setFeedUrl() {
+ const currentVersion = app.getVersion()
+ const testPlan = configManager.getTestPlan()
+ const requestedChannel = testPlan ? this._getTestChannel() : UpgradeChannel.LATEST
+
+ // Determine mirror based on IP country
+ const ipCountry = await getIpCountry()
+ const mirror = ipCountry.toLowerCase() === 'cn' ? 'gitcode' : 'github'
+
+ // Fetch update config
+ const config = await this._fetchUpdateConfig(mirror)
+
+ if (config) {
+ const channelConfig = this._findCompatibleChannel(currentVersion, requestedChannel, config)
+ if (channelConfig) {
+ // Select feed URL from the corresponding mirror
+ const feedUrl = channelConfig.feedUrls[mirror]
+ this._setChannel(requestedChannel, feedUrl)
+ return
+ }
+ }
+
+ // Fallback logic
+ const defaultFeedUrl = mirror === 'gitcode'
+ ? FeedUrl.PRODUCTION
+ : FeedUrl.GITHUB_LATEST
+ this._setChannel(UpgradeChannel.LATEST, defaultFeedUrl)
+}
+
+private async _fetchUpdateConfig(mirror: 'github' | 'gitcode'): Promise {
+ const configUrl = mirror === 'gitcode'
+ ? UpdateConfigUrl.GITCODE
+ : UpdateConfigUrl.GITHUB
+
+ try {
+ const response = await net.fetch(configUrl, {
+ headers: {
+ 'User-Agent': generateUserAgent(),
+ 'Accept': 'application/json',
+ 'X-Client-Id': configManager.getClientId()
+ }
+ })
+ return await response.json() as UpdateConfig
+ } catch (error) {
+ logger.error('Failed to fetch update config:', error)
+ return null
+ }
+}
+```
+
+## Fallback and Error Handling Strategy
+
+1. **Configuration file fetch failure**: Log error, return current version, don't offer updates
+2. **No matching version**: Notify user that current version doesn't support automatic upgrade
+3. **Network exception**: Cache last successfully fetched configuration (optional)
+
+## GitHub Release Requirements
+
+To support intermediate version upgrades, the following files need to be retained:
+
+- **v1.7.0 release** and its latest*.yml files (as upgrade target for users below v1.7)
+- Future intermediate versions (e.g., v2.8.0) need to retain corresponding release and latest*.yml files
+- Complete installation packages for each version
+
+### Currently Required Releases
+
+| Version | Purpose | Must Retain |
+|---------|---------|-------------|
+| v1.7.0 | Upgrade target for users below 1.7 | ✅ Yes |
+| v2.0.0-rc.1 | RC testing channel | ❌ Optional |
+| v2.0.0-beta.1 | Beta testing channel | ❌ Optional |
+| latest | Latest stable version (automatic) | ✅ Yes |
+
+## Advantages
+
+1. **Flexibility**: Supports arbitrarily complex upgrade paths
+2. **Extensibility**: Adding new versions only requires adding new entries to the configuration file
+3. **Maintainability**: Configuration is separated from code, allowing upgrade strategy adjustments without releasing new versions
+4. **Multi-source support**: Automatically selects optimal configuration source based on geolocation
+5. **Version control**: Enforces intermediate version upgrades, ensuring data migration and compatibility
+
+## Future Extensions
+
+- Support more granular version range control (e.g., `>=1.5.0 <1.8.0`)
+- Support multi-step upgrade path hints (e.g., notify user needs 1.5 → 1.8 → 2.0)
+- Support A/B testing and gradual rollout
+- Support local caching and expiration strategy for configuration files
diff --git a/docs/technical/CodeBlockView-en.md b/docs/en/references/components/code-block-view.md
similarity index 98%
rename from docs/technical/CodeBlockView-en.md
rename to docs/en/references/components/code-block-view.md
index 786d7aa029..22a3ea5a1f 100644
--- a/docs/technical/CodeBlockView-en.md
+++ b/docs/en/references/components/code-block-view.md
@@ -85,7 +85,7 @@ Main responsibilities:
- **SvgPreview**: SVG image preview
- **GraphvizPreview**: Graphviz diagram preview
-All special view components share a common architecture for consistent user experience and functionality. For detailed information about these components and their implementation, see [Image Preview Components Documentation](./ImagePreview-en.md).
+All special view components share a common architecture for consistent user experience and functionality. For detailed information about these components and their implementation, see [Image Preview Components Documentation](./image-preview.md).
#### StatusBar
diff --git a/docs/technical/ImagePreview-en.md b/docs/en/references/components/image-preview.md
similarity index 98%
rename from docs/technical/ImagePreview-en.md
rename to docs/en/references/components/image-preview.md
index 383bf5c664..8244f8fe9b 100644
--- a/docs/technical/ImagePreview-en.md
+++ b/docs/en/references/components/image-preview.md
@@ -192,4 +192,4 @@ Image Preview Components integrate seamlessly with CodeBlockView:
- Shared state management
- Responsive layout adaptation
-For more information about the overall CodeBlockView architecture, see [CodeBlockView Documentation](./CodeBlockView-en.md).
+For more information about the overall CodeBlockView architecture, see [CodeBlockView Documentation](./code-block-view.md).
diff --git a/docs/technical/Message.md b/docs/technical/Message.md
deleted file mode 100644
index 673b1cce7b..0000000000
--- a/docs/technical/Message.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# 消息的生命周期
-
-
diff --git a/docs/technical/db.settings.md b/docs/technical/db.settings.md
deleted file mode 100644
index 1d63098851..0000000000
--- a/docs/technical/db.settings.md
+++ /dev/null
@@ -1,11 +0,0 @@
-# 数据库设置字段
-
-此文档包含部分字段的数据类型说明。
-
-## 字段
-
-| 字段名 | 类型 | 说明 |
-| ------------------------------ | ------------------------------ | ------------ |
-| `translate:target:language` | `LanguageCode` | 翻译目标语言 |
-| `translate:source:language` | `LanguageCode` | 翻译源语言 |
-| `translate:bidirectional:pair` | `[LanguageCode, LanguageCode]` | 双向翻译对 |
diff --git a/docs/technical/how-to-use-messageBlock.md b/docs/technical/how-to-use-messageBlock.md
deleted file mode 100644
index f60c2851ce..0000000000
--- a/docs/technical/how-to-use-messageBlock.md
+++ /dev/null
@@ -1,127 +0,0 @@
-# messageBlock.ts 使用指南
-
-该文件定义了用于管理应用程序中所有 `MessageBlock` 实体的 Redux Slice。它使用 Redux Toolkit 的 `createSlice` 和 `createEntityAdapter` 来高效地处理规范化的状态,并提供了一系列 actions 和 selectors 用于与消息块数据交互。
-
-## 核心目标
-
-- **状态管理**: 集中管理所有 `MessageBlock` 的状态。`MessageBlock` 代表消息中的不同内容单元(如文本、代码、图片、引用等)。
-- **规范化**: 使用 `createEntityAdapter` 将 `MessageBlock` 数据存储在规范化的结构中(`{ ids: [], entities: {} }`),这有助于提高性能和简化更新逻辑。
-- **可预测性**: 提供明确的 actions 来修改状态,并通过 selectors 安全地访问状态。
-
-## 关键概念
-
-- **Slice (`createSlice`)**: Redux Toolkit 的核心 API,用于创建包含 reducer 逻辑、action creators 和初始状态的 Redux 模块。
-- **Entity Adapter (`createEntityAdapter`)**: Redux Toolkit 提供的工具,用于简化对规范化数据的 CRUD(创建、读取、更新、删除)操作。它会自动生成 reducer 函数和 selectors。
-- **Selectors**: 用于从 Redux store 中派生和计算数据的函数。Selectors 可以被记忆化(memoized),以提高性能。
-
-## State 结构
-
-`messageBlocks` slice 的状态结构由 `createEntityAdapter` 定义,大致如下:
-
-```typescript
-{
- ids: string[]; // 存储所有 MessageBlock ID 的有序列表
- entities: { [id: string]: MessageBlock }; // 按 ID 存储 MessageBlock 对象的字典
- loadingState: 'idle' | 'loading' | 'succeeded' | 'failed'; // (可选) 其他状态,如加载状态
- error: string | null; // (可选) 错误信息
-}
-```
-
-## Actions
-
-该 slice 导出以下 actions (由 `createSlice` 和 `createEntityAdapter` 自动生成或自定义):
-
-- **`upsertOneBlock(payload: MessageBlock)`**:
-
- - 添加一个新的 `MessageBlock` 或更新一个已存在的 `MessageBlock`。如果 payload 中的 `id` 已存在,则执行更新;否则执行插入。
-
-- **`upsertManyBlocks(payload: MessageBlock[])`**:
-
- - 添加或更新多个 `MessageBlock`。常用于批量加载数据(例如,加载一个 Topic 的所有消息块)。
-
-- **`removeOneBlock(payload: string)`**:
-
- - 根据提供的 `id` (payload) 移除单个 `MessageBlock`。
-
-- **`removeManyBlocks(payload: string[])`**:
-
- - 根据提供的 `id` 数组 (payload) 移除多个 `MessageBlock`。常用于删除消息或清空 Topic 时清理相关的块。
-
-- **`removeAllBlocks()`**:
-
- - 移除 state 中的所有 `MessageBlock` 实体。
-
-- **`updateOneBlock(payload: { id: string; changes: Partial })`**:
-
- - 更新一个已存在的 `MessageBlock`。`payload` 需要包含块的 `id` 和一个包含要更改的字段的 `changes` 对象。
-
-- **`setMessageBlocksLoading(payload: 'idle' | 'loading')`**:
-
- - (自定义) 设置 `loadingState` 属性。
-
-- **`setMessageBlocksError(payload: string)`**:
- - (自定义) 设置 `loadingState` 为 `'failed'` 并记录错误信息。
-
-**使用示例 (在 Thunk 或其他 Dispatch 的地方):**
-
-```typescript
-import { upsertOneBlock, removeManyBlocks, updateOneBlock } from './messageBlock'
-import store from './store' // 假设这是你的 Redux store 实例
-
-// 添加或更新一个块
-const newBlock: MessageBlock = {
- /* ... block data ... */
-}
-store.dispatch(upsertOneBlock(newBlock))
-
-// 更新一个块的内容
-store.dispatch(updateOneBlock({ id: blockId, changes: { content: 'New content' } }))
-
-// 删除多个块
-const blockIdsToRemove = ['id1', 'id2']
-store.dispatch(removeManyBlocks(blockIdsToRemove))
-```
-
-## Selectors
-
-该 slice 导出由 `createEntityAdapter` 生成的基础 selectors,并通过 `messageBlocksSelectors` 对象访问:
-
-- **`messageBlocksSelectors.selectIds(state: RootState): string[]`**: 返回包含所有块 ID 的数组。
-- **`messageBlocksSelectors.selectEntities(state: RootState): { [id: string]: MessageBlock }`**: 返回块 ID 到块对象的映射字典。
-- **`messageBlocksSelectors.selectAll(state: RootState): MessageBlock[]`**: 返回包含所有块对象的数组。
-- **`messageBlocksSelectors.selectTotal(state: RootState): number`**: 返回块的总数。
-- **`messageBlocksSelectors.selectById(state: RootState, id: string): MessageBlock | undefined`**: 根据 ID 返回单个块对象,如果找不到则返回 `undefined`。
-
-**此外,还提供了一个自定义的、记忆化的 selector:**
-
-- **`selectFormattedCitationsByBlockId(state: RootState, blockId: string | undefined): Citation[]`**:
- - 接收一个 `blockId`。
- - 如果该 ID 对应的块是 `CITATION` 类型,则提取并格式化其包含的引用信息(来自网页搜索、知识库等),进行去重和重新编号,最后返回一个 `Citation[]` 数组,用于在 UI 中显示。
- - 如果块不存在或类型不匹配,返回空数组 `[]`。
- - 这个 selector 封装了处理不同引用来源(Gemini, OpenAI, OpenRouter, Zhipu 等)的复杂逻辑。
-
-**使用示例 (在 React 组件或 `useSelector` 中):**
-
-```typescript
-import { useSelector } from 'react-redux'
-import { messageBlocksSelectors, selectFormattedCitationsByBlockId } from './messageBlock'
-import type { RootState } from './store'
-
-// 获取所有块
-const allBlocks = useSelector(messageBlocksSelectors.selectAll)
-
-// 获取特定 ID 的块
-const specificBlock = useSelector((state: RootState) => messageBlocksSelectors.selectById(state, someBlockId))
-
-// 获取特定引用块格式化后的引用列表
-const formattedCitations = useSelector((state: RootState) => selectFormattedCitationsByBlockId(state, citationBlockId))
-
-// 在组件中使用引用数据
-// {formattedCitations.map(citation => ...)}
-```
-
-## 集成
-
-`messageBlock.ts` slice 通常与 `messageThunk.ts` 中的 Thunks 紧密协作。Thunks 负责处理异步逻辑(如 API 调用、数据库操作),并在需要时 dispatch `messageBlock` slice 的 actions 来更新状态。例如,当 `messageThunk` 接收到流式响应时,它会 dispatch `upsertOneBlock` 或 `updateOneBlock` 来实时更新对应的 `MessageBlock`。同样,删除消息的 Thunk 会 dispatch `removeManyBlocks`。
-
-理解 `messageBlock.ts` 的职责是管理**状态本身**,而 `messageThunk.ts` 负责**触发状态变更**的异步流程,这对于维护清晰的应用架构至关重要。
diff --git a/docs/technical/how-to-use-messageThunk.md b/docs/technical/how-to-use-messageThunk.md
deleted file mode 100644
index 86952f99ad..0000000000
--- a/docs/technical/how-to-use-messageThunk.md
+++ /dev/null
@@ -1,105 +0,0 @@
-# messageThunk.ts 使用指南
-
-该文件包含用于管理应用程序中消息流、处理助手交互以及同步 Redux 状态与 IndexedDB 数据库的核心 Thunk Action Creators。主要围绕 `Message` 和 `MessageBlock` 对象进行操作。
-
-## 核心功能
-
-1. **发送/接收消息**: 处理用户消息的发送,触发助手响应,并流式处理返回的数据,将其解析为不同的 `MessageBlock`。
-2. **状态管理**: 确保 Redux store 中的消息和消息块状态与 IndexedDB 中的持久化数据保持一致。
-3. **消息操作**: 提供删除、重发、重新生成、编辑后重发、追加响应、克隆等消息生命周期管理功能。
-4. **Block 处理**: 动态创建、更新和保存各种类型的 `MessageBlock`(文本、思考过程、工具调用、引用、图片、错误、翻译等)。
-
-## 主要 Thunks
-
-以下是一些关键的 Thunk 函数及其用途:
-
-1. **`sendMessage(userMessage, userMessageBlocks, assistant, topicId)`**
-
- - **用途**: 发送一条新的用户消息。
- - **流程**:
- - 保存用户消息 (`userMessage`) 及其块 (`userMessageBlocks`) 到 Redux 和 DB。
- - 检查 `@mentions` 以确定是单模型响应还是多模型响应。
- - 创建助手消息(们)的存根 (Stub)。
- - 将存根添加到 Redux 和 DB。
- - 将核心处理逻辑 `fetchAndProcessAssistantResponseImpl` 添加到该 `topicId` 的队列中以获取实际响应。
- - **Block 相关**: 主要处理用户消息的初始 `MessageBlock` 保存。
-
-2. **`fetchAndProcessAssistantResponseImpl(dispatch, getState, topicId, assistant, assistantMessage)`**
-
- - **用途**: (内部函数) 获取并处理单个助手响应的核心逻辑,被 `sendMessage`, `resend...`, `regenerate...`, `append...` 等调用。
- - **流程**:
- - 设置 Topic 加载状态。
- - 准备上下文消息。
- - 调用 `fetchChatCompletion` API 服务。
- - 使用 `createStreamProcessor` 处理流式响应。
- - 通过各种回调 (`onTextChunk`, `onThinkingChunk`, `onToolCallComplete`, `onImageGenerated`, `onError`, `onComplete` 等) 处理不同类型的事件。
- - **Block 相关**:
- - 根据流事件创建初始 `UNKNOWN` 块。
- - 实时创建和更新 `MAIN_TEXT` 和 `THINKING` 块,使用 `throttledBlockUpdate` 和 `throttledBlockDbUpdate` 进行节流更新。
- - 创建 `TOOL`, `CITATION`, `IMAGE`, `ERROR` 等类型的块。
- - 在事件完成时(如 `onTextComplete`, `onToolCallComplete`)将块状态标记为 `SUCCESS` 或 `ERROR`,并使用 `saveUpdatedBlockToDB` 保存最终状态。
- - 使用 `handleBlockTransition` 管理非流式块(如 `TOOL`, `CITATION`)的添加和状态更新。
-
-3. **`loadTopicMessagesThunk(topicId, forceReload)`**
-
- - **用途**: 从数据库加载指定主题的所有消息及其关联的 `MessageBlock`。
- - **流程**:
- - 从 DB 获取 `Topic` 及其 `messages` 列表。
- - 根据消息 ID 列表从 DB 获取所有相关的 `MessageBlock`。
- - 使用 `upsertManyBlocks` 将块更新到 Redux。
- - 将消息更新到 Redux。
- - **Block 相关**: 负责将持久化的 `MessageBlock` 加载到 Redux 状态。
-
-4. **删除 Thunks**
-
- - `deleteSingleMessageThunk(topicId, messageId)`: 删除单个消息及其所有 `MessageBlock`。
- - `deleteMessageGroupThunk(topicId, askId)`: 删除一个用户消息及其所有相关的助手响应消息和它们的所有 `MessageBlock`。
- - `clearTopicMessagesThunk(topicId)`: 清空主题下的所有消息及其所有 `MessageBlock`。
- - **Block 相关**: 从 Redux 和 DB 中移除指定的 `MessageBlock`。
-
-5. **重发/重新生成 Thunks**
-
- - `resendMessageThunk(topicId, userMessageToResend, assistant)`: 重发用户消息。会重置(清空 Block 并标记为 PENDING)所有与该用户消息关联的助手响应,然后重新请求生成。
- - `resendUserMessageWithEditThunk(topicId, originalMessage, mainTextBlockId, editedContent, assistant)`: 用户编辑消息内容后重发。先更新用户消息的 `MAIN_TEXT` 块内容,然后调用 `resendMessageThunk`。
- - `regenerateAssistantResponseThunk(topicId, assistantMessageToRegenerate, assistant)`: 重新生成单个助手响应。重置该助手消息(清空 Block 并标记为 PENDING),然后重新请求生成。
- - **Block 相关**: 删除旧的 `MessageBlock`,并在重新生成过程中创建新的 `MessageBlock`。
-
-6. **`appendAssistantResponseThunk(topicId, existingAssistantMessageId, newModel, assistant)`**
-
- - **用途**: 在已有的对话上下文中,针对同一个用户问题,使用新选择的模型追加一个新的助手响应。
- - **流程**:
- - 找到现有助手消息以获取原始 `askId`。
- - 创建使用 `newModel` 的新助手消息存根(使用相同的 `askId`)。
- - 添加新存根到 Redux 和 DB。
- - 将 `fetchAndProcessAssistantResponseImpl` 添加到队列以生成新响应。
- - **Block 相关**: 为新的助手响应创建全新的 `MessageBlock`。
-
-7. **`cloneMessagesToNewTopicThunk(sourceTopicId, branchPointIndex, newTopic)`**
-
- - **用途**: 将源主题的部分消息(及其 Block)克隆到一个**已存在**的新主题中。
- - **流程**:
- - 复制指定索引前的消息。
- - 为所有克隆的消息和 Block 生成新的 UUID。
- - 正确映射克隆消息之间的 `askId` 关系。
- - 复制 `MessageBlock` 内容,更新其 `messageId` 指向新的消息 ID。
- - 更新文件引用计数(如果 Block 是文件或图片)。
- - 将克隆的消息和 Block 保存到新主题的 Redux 状态和 DB 中。
- - **Block 相关**: 创建 `MessageBlock` 的副本,并更新其 ID 和 `messageId`。
-
-8. **`initiateTranslationThunk(messageId, topicId, targetLanguage, sourceBlockId?, sourceLanguage?)`**
- - **用途**: 为指定消息启动翻译流程,创建一个初始的 `TRANSLATION` 类型的 `MessageBlock`。
- - **流程**:
- - 创建一个状态为 `STREAMING` 的 `TranslationMessageBlock`。
- - 将其添加到 Redux 和 DB。
- - 更新原消息的 `blocks` 列表以包含新的翻译块 ID。
- - **Block 相关**: 创建并保存一个占位的 `TranslationMessageBlock`。实际翻译内容的获取和填充需要后续步骤。
-
-## 内部机制和注意事项
-
-- **数据库交互**: 通过 `saveMessageAndBlocksToDB`, `updateExistingMessageAndBlocksInDB`, `saveUpdatesToDB`, `saveUpdatedBlockToDB`, `throttledBlockDbUpdate` 等辅助函数与 IndexedDB (`db`) 交互,确保数据持久化。
-- **状态同步**: Thunks 负责协调 Redux Store 和 IndexedDB 之间的数据一致性。
-- **队列 (`getTopicQueue`)**: 使用 `AsyncQueue` 确保对同一主题的操作(尤其是 API 请求)按顺序执行,避免竞态条件。
-- **节流 (`throttle`)**: 对流式响应中频繁的 Block 更新(文本、思考)使用 `lodash.throttle` 优化性能,减少 Redux dispatch 和 DB 写入次数。
-- **错误处理**: `fetchAndProcessAssistantResponseImpl` 内的回调函数(特别是 `onError`)处理流处理和 API 调用中可能出现的错误,并创建 `ERROR` 类型的 `MessageBlock`。
-
-开发者在使用这些 Thunks 时,通常需要提供 `dispatch`, `getState` (由 Redux Thunk 中间件注入),以及如 `topicId`, `assistant` 配置对象, 相关的 `Message` 或 `MessageBlock` 对象/ID 等参数。理解每个 Thunk 的职责和它如何影响消息及块的状态至关重要。
diff --git a/docs/technical/how-to-use-useMessageOperations.md b/docs/technical/how-to-use-useMessageOperations.md
deleted file mode 100644
index df56ad5e5f..0000000000
--- a/docs/technical/how-to-use-useMessageOperations.md
+++ /dev/null
@@ -1,156 +0,0 @@
-# useMessageOperations.ts 使用指南
-
-该文件定义了一个名为 `useMessageOperations` 的自定义 React Hook。这个 Hook 的主要目的是为 React 组件提供一个便捷的接口,用于执行与特定主题(Topic)相关的各种消息操作。它封装了调用 Redux Thunks (`messageThunk.ts`) 和 Actions (`newMessage.ts`, `messageBlock.ts`) 的逻辑,简化了组件与消息数据交互的代码。
-
-## 核心目标
-
-- **封装**: 将复杂的消息操作逻辑(如删除、重发、重新生成、编辑、翻译等)封装在易于使用的函数中。
-- **简化**: 让组件可以直接调用这些操作函数,而无需直接与 Redux `dispatch` 或 Thunks 交互。
-- **上下文关联**: 所有操作都与传入的 `topic` 对象相关联,确保操作作用于正确的主题。
-
-## 如何使用
-
-在你的 React 函数组件中,导入并调用 `useMessageOperations` Hook,并传入当前活动的 `Topic` 对象。
-
-```typescript
-import React from 'react';
-import { useMessageOperations } from '@renderer/hooks/useMessageOperations';
-import type { Topic, Message, Assistant, Model } from '@renderer/types';
-
-interface MyComponentProps {
- currentTopic: Topic;
- currentAssistant: Assistant;
-}
-
-function MyComponent({ currentTopic, currentAssistant }: MyComponentProps) {
- const {
- deleteMessage,
- resendMessage,
- regenerateAssistantMessage,
- appendAssistantResponse,
- getTranslationUpdater,
- createTopicBranch,
- // ... 其他操作函数
- } = useMessageOperations(currentTopic);
-
- const handleDelete = (messageId: string) => {
- deleteMessage(messageId);
- };
-
- const handleResend = (message: Message) => {
- resendMessage(message, currentAssistant);
- };
-
- const handleAppend = (existingMsg: Message, newModel: Model) => {
- appendAssistantResponse(existingMsg, newModel, currentAssistant);
- }
-
- // ... 在组件中使用其他操作函数
-
- return (
-
- {/* Component UI */}
-
- {/* ... */}
-
- );
-}
-```
-
-## 返回值
-
-`useMessageOperations(topic)` Hook 返回一个包含以下函数和值的对象:
-
-- **`deleteMessage(id: string)`**:
-
- - 删除指定 `id` 的单个消息。
- - 内部调用 `deleteSingleMessageThunk`。
-
-- **`deleteGroupMessages(askId: string)`**:
-
- - 删除与指定 `askId` 相关联的一组消息(通常是用户提问及其所有助手回答)。
- - 内部调用 `deleteMessageGroupThunk`。
-
-- **`editMessage(messageId: string, updates: Partial)`**:
-
- - 更新指定 `messageId` 的消息的部分属性。
- - **注意**: 目前主要用于更新 Redux 状态
- - 内部调用 `newMessagesActions.updateMessage`。
-
-- **`resendMessage(message: Message, assistant: Assistant)`**:
-
- - 重新发送指定的用户消息 (`message`),这将触发其所有关联助手响应的重新生成。
- - 内部调用 `resendMessageThunk`。
-
-- **`resendUserMessageWithEdit(message: Message, editedContent: string, assistant: Assistant)`**:
-
- - 在用户消息的主要文本块被编辑后,重新发送该消息。
- - 会先查找消息的 `MAIN_TEXT` 块 ID,然后调用 `resendUserMessageWithEditThunk`。
-
-- **`clearTopicMessages(_topicId?: string)`**:
-
- - 清除当前主题(或可选的指定 `_topicId`)下的所有消息。
- - 内部调用 `clearTopicMessagesThunk`。
-
-- **`createNewContext()`**:
-
- - 发出一个全局事件 (`EVENT_NAMES.NEW_CONTEXT`),通常用于通知 UI 清空显示,准备新的上下文。不直接修改 Redux 状态。
-
-- **`displayCount`**:
-
- - (非操作函数) 从 Redux store 中获取当前的 `displayCount` 值。
-
-- **`pauseMessages()`**:
-
- - 尝试中止当前主题中正在进行的消息生成(状态为 `processing` 或 `pending`)。
- - 通过查找相关的 `askId` 并调用 `abortCompletion` 来实现。
- - 同时会 dispatch `setTopicLoading` action 将加载状态设为 `false`。
-
-- **`resumeMessage(message: Message, assistant: Assistant)`**:
-
- - 恢复/重新发送一个用户消息。目前实现为直接调用 `resendMessage`。
-
-- **`regenerateAssistantMessage(message: Message, assistant: Assistant)`**:
-
- - 重新生成指定的**助手**消息 (`message`) 的响应。
- - 内部调用 `regenerateAssistantResponseThunk`。
-
-- **`appendAssistantResponse(existingAssistantMessage: Message, newModel: Model, assistant: Assistant)`**:
-
- - 针对 `existingAssistantMessage` 所回复的**同一用户提问**,使用 `newModel` 追加一个新的助手响应。
- - 内部调用 `appendAssistantResponseThunk`。
-
-- **`getTranslationUpdater(messageId: string, targetLanguage: string, sourceBlockId?: string, sourceLanguage?: string)`**:
-
- - **用途**: 获取一个用于逐步更新翻译块内容的函数。
- - **流程**:
- 1. 内部调用 `initiateTranslationThunk` 来创建或获取一个 `TRANSLATION` 类型的 `MessageBlock`,并获取其 `blockId`。
- 2. 返回一个**异步更新函数**。
- - **返回的更新函数 `(accumulatedText: string, isComplete?: boolean) => void`**:
- - 接收累积的翻译文本和完成状态。
- - 调用 `updateOneBlock` 更新 Redux 中的翻译块内容和状态 (`STREAMING` 或 `SUCCESS`)。
- - 调用 `throttledBlockDbUpdate` 将更新(节流地)保存到数据库。
- - 如果初始化失败(Thunk 返回 `undefined`),则此函数返回 `null`。
-
-- **`createTopicBranch(sourceTopicId: string, branchPointIndex: number, newTopic: Topic)`**:
- - 创建一个主题分支,将 `sourceTopicId` 主题中 `branchPointIndex` 索引之前的消息克隆到 `newTopic` 中。
- - **注意**: `newTopic` 对象必须是调用此函数**之前**已经创建并添加到 Redux 和数据库中的。
- - 内部调用 `cloneMessagesToNewTopicThunk`。
-
-## 依赖
-
-- **`topic: Topic`**: 必须传入当前操作上下文的主题对象。Hook 返回的操作函数将始终作用于这个主题的 `topic.id`。
-- **Redux `dispatch`**: Hook 内部使用 `useAppDispatch` 获取 `dispatch` 函数来调用 actions 和 thunks。
-
-## 相关 Hooks
-
-在同一文件中还定义了两个辅助 Hook:
-
-- **`useTopicMessages(topic: Topic)`**:
-
- - 使用 `selectMessagesForTopic` selector 来获取并返回指定主题的消息列表。
-
-- **`useTopicLoading(topic: Topic)`**:
- - 使用 `selectNewTopicLoading` selector 来获取并返回指定主题的加载状态。
-
-这些 Hook 可以与 `useMessageOperations` 结合使用,方便地在组件中获取消息数据、加载状态,并执行相关操作。
diff --git a/docs/README.zh.md b/docs/zh/README.md
similarity index 97%
rename from docs/README.zh.md
rename to docs/zh/README.md
index 84546c57ee..f8a1f1ab8c 100644
--- a/docs/README.zh.md
+++ b/docs/zh/README.md
@@ -34,7 +34,7 @@
- English | 中文 | 官方网站 | 文档 | 开发 | 反馈
+ English | 中文 | 官方网站 | 文档 | 开发 | 反馈
@@ -70,7 +70,7 @@ Cherry Studio 是一款支持多个大语言模型(LLM)服务商的桌面客
👏 欢迎加入 [Telegram 群组](https://t.me/CherryStudioAI)|[Discord](https://discord.gg/wez8HtpxqQ) | [QQ群(575014769)](https://qm.qq.com/q/lo0D4qVZKi)
-❤️ 喜欢 Cherry Studio? 点亮小星星 🌟 或 [赞助开发者](sponsor.md)! ❤️
+❤️ 喜欢 Cherry Studio? 点亮小星星 🌟 或 [赞助开发者](./guides/sponsor.md)! ❤️
# 📖 使用教程
@@ -181,7 +181,7 @@ https://docs.cherry-ai.com
6. **社区参与**:加入讨论并帮助用户
7. **推广使用**:宣传 Cherry Studio
-参考[分支策略](branching-strategy-zh.md)了解贡献指南
+参考[分支策略](./guides/branching-strategy.md)了解贡献指南
## 入门
@@ -190,7 +190,7 @@ https://docs.cherry-ai.com
3. **提交更改**:提交并推送您的更改
4. **打开 Pull Request**:描述您的更改和原因
-有关更详细的指南,请参阅我们的 [贡献指南](CONTRIBUTING.zh.md)
+有关更详细的指南,请参阅我们的 [贡献指南](./guides/contributing.md)
感谢您的支持和贡献!
diff --git a/docs/branching-strategy-zh.md b/docs/zh/guides/branching-strategy.md
similarity index 98%
rename from docs/branching-strategy-zh.md
rename to docs/zh/guides/branching-strategy.md
index 36b7ca263d..c6ab0eb0b5 100644
--- a/docs/branching-strategy-zh.md
+++ b/docs/zh/guides/branching-strategy.md
@@ -16,7 +16,7 @@ Cherry Studio 采用结构化的分支策略来维护代码质量并简化开发
- 只接受文档更新和 bug 修复
- 经过完整测试后可以发布到生产环境
-关于测试计划所使用的`testplan`分支,请查阅[测试计划](testplan-zh.md)。
+关于测试计划所使用的`testplan`分支,请查阅[测试计划](./test-plan.md)。
## 贡献分支
diff --git a/docs/CONTRIBUTING.zh.md b/docs/zh/guides/contributing.md
similarity index 94%
rename from docs/CONTRIBUTING.zh.md
rename to docs/zh/guides/contributing.md
index 98efcc286e..dcea60cfbc 100644
--- a/docs/CONTRIBUTING.zh.md
+++ b/docs/zh/guides/contributing.md
@@ -1,6 +1,6 @@
# Cherry Studio 贡献者指南
-[**English**](../CONTRIBUTING.md) | [**中文**](CONTRIBUTING.zh.md)
+[**English**](../../../CONTRIBUTING.md) | **中文**
欢迎来到 Cherry Studio 的贡献者社区!我们致力于将 Cherry Studio 打造成一个长期提供价值的项目,并希望邀请更多的开发者加入我们的行列。无论您是经验丰富的开发者还是刚刚起步的初学者,您的贡献都将帮助我们更好地服务用户,提升软件质量。
@@ -24,7 +24,7 @@
## 开始之前
-请确保阅读了[行为准则](../CODE_OF_CONDUCT.md)和[LICENSE](../LICENSE)。
+请确保阅读了[行为准则](../../../CODE_OF_CONDUCT.md)和[LICENSE](../../../LICENSE)。
## 开始贡献
@@ -32,7 +32,7 @@
### 测试
-未经测试的功能等同于不存在。为确保代码真正有效,应通过单元测试和功能测试覆盖相关流程。因此,在考虑贡献时,也请考虑可测试性。所有测试均可本地运行,无需依赖 CI。请参阅[开发者指南](dev.md#test)中的“Test”部分。
+未经测试的功能等同于不存在。为确保代码真正有效,应通过单元测试和功能测试覆盖相关流程。因此,在考虑贡献时,也请考虑可测试性。所有测试均可本地运行,无需依赖 CI。请参阅[开发者指南](./development.md#test)中的"Test"部分。
### 拉取请求的自动化测试
@@ -60,11 +60,11 @@ git commit --signoff -m "Your commit message"
### 获取代码审查/合并
-维护者在此帮助您在合理时间内实现您的用例。他们会尽力在合理时间内审查您的代码并提供建设性反馈。但如果您在审查过程中受阻,或认为您的 Pull Request 未得到应有的关注,请通过 Issue 中的评论或者[社群](README.zh.md#-community)联系我们
+维护者在此帮助您在合理时间内实现您的用例。他们会尽力在合理时间内审查您的代码并提供建设性反馈。但如果您在审查过程中受阻,或认为您的 Pull Request 未得到应有的关注,请通过 Issue 中的评论或者[社群](../README.md#-community)联系我们
### 参与测试计划
-测试计划旨在为用户提供更稳定的应用体验和更快的迭代速度,详细情况请参阅[测试计划](testplan-zh.md)。
+测试计划旨在为用户提供更稳定的应用体验和更快的迭代速度,详细情况请参阅[测试计划](./test-plan.md)。
### 其他建议
diff --git a/docs/zh/guides/development.md b/docs/zh/guides/development.md
new file mode 100644
index 0000000000..fe67742768
--- /dev/null
+++ b/docs/zh/guides/development.md
@@ -0,0 +1,73 @@
+# 🖥️ Develop
+
+## IDE Setup
+
+- Editor: [Cursor](https://www.cursor.com/), etc. Any VS Code compatible editor.
+- Linter: [ESLint](https://marketplace.visualstudio.com/items?itemName=dbaeumer.vscode-eslint)
+- Formatter: [Biome](https://marketplace.visualstudio.com/items?itemName=biomejs.biome)
+
+## Project Setup
+
+### Install
+
+```bash
+yarn
+```
+
+### Development
+
+### Setup Node.js
+
+Download and install [Node.js v22.x.x](https://nodejs.org/en/download)
+
+### Setup Yarn
+
+```bash
+corepack enable
+corepack prepare yarn@4.9.1 --activate
+```
+
+### Install Dependencies
+
+```bash
+yarn install
+```
+
+### ENV
+
+```bash
+copy .env.example .env
+```
+
+### Start
+
+```bash
+yarn dev
+```
+
+### Debug
+
+```bash
+yarn debug
+```
+
+Then input chrome://inspect in browser
+
+### Test
+
+```bash
+yarn test
+```
+
+### Build
+
+```bash
+# For windows
+$ yarn build:win
+
+# For macOS
+$ yarn build:mac
+
+# For Linux
+$ yarn build:linux
+```
diff --git a/docs/technical/how-to-i18n-zh.md b/docs/zh/guides/i18n.md
similarity index 97%
rename from docs/technical/how-to-i18n-zh.md
rename to docs/zh/guides/i18n.md
index 5d0a93c369..82624d35c8 100644
--- a/docs/technical/how-to-i18n-zh.md
+++ b/docs/zh/guides/i18n.md
@@ -15,11 +15,11 @@ i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反
### 效果展示
-
+
-
+
-
+
## i18n 约定
diff --git a/docs/technical/how-to-use-logger-zh.md b/docs/zh/guides/logging.md
similarity index 100%
rename from docs/technical/how-to-use-logger-zh.md
rename to docs/zh/guides/logging.md
diff --git a/docs/features/memory-guide-zh.md b/docs/zh/guides/memory.md
similarity index 100%
rename from docs/features/memory-guide-zh.md
rename to docs/zh/guides/memory.md
diff --git a/docs/technical/how-to-write-middlewares.md b/docs/zh/guides/middleware.md
similarity index 100%
rename from docs/technical/how-to-write-middlewares.md
rename to docs/zh/guides/middleware.md
diff --git a/docs/sponsor.md b/docs/zh/guides/sponsor.md
similarity index 100%
rename from docs/sponsor.md
rename to docs/zh/guides/sponsor.md
diff --git a/docs/testplan-zh.md b/docs/zh/guides/test-plan.md
similarity index 91%
rename from docs/testplan-zh.md
rename to docs/zh/guides/test-plan.md
index ed4913d4a4..42147e8990 100644
--- a/docs/testplan-zh.md
+++ b/docs/zh/guides/test-plan.md
@@ -11,13 +11,15 @@
用户可以在软件的`设置`-`关于`中,开启“测试计划”并选择版本通道。请注意“测试计划”的版本无法保证数据的一致性,请使用前一定要备份数据。
+用户选择RC版通道或Beta版通道后,若发布了正式版,仍旧会升级到正式版。
+
用户在测试过程中发现的BUG,欢迎提交issue或通过其他渠道反馈。用户的反馈对我们非常重要。
## 开发者指南
### 参与测试计划
-开发者按照[贡献者指南](CONTRIBUTING.zh.md)要求正常提交`PR`(并注意提交target为`main`)。仓库维护者会综合考虑(例如该功能对应用的影响程度,功能的重要性,是否需要更广泛的测试等),决定该`PR`是否应加入测试计划。
+开发者按照[贡献者指南](./contributing.md)要求正常提交`PR`(并注意提交target为`main`)。仓库维护者会综合考虑(例如该功能对应用的影响程度,功能的重要性,是否需要更广泛的测试等),决定该`PR`是否应加入测试计划。
若该`PR`加入测试计划,仓库维护者会做如下操作:
diff --git a/docs/zh/references/app-upgrade.md b/docs/zh/references/app-upgrade.md
new file mode 100644
index 0000000000..29f9f75d79
--- /dev/null
+++ b/docs/zh/references/app-upgrade.md
@@ -0,0 +1,430 @@
+# 更新配置系统设计文档
+
+## 背景
+
+当前 AppUpdater 直接请求 GitHub API 获取 beta 和 rc 的更新信息。为了支持国内用户,需要根据 IP 地理位置,分别从 GitHub/GitCode 获取一个固定的 JSON 配置文件,该文件包含所有渠道的更新地址。
+
+## 设计目标
+
+1. 支持根据 IP 地理位置选择不同的配置源(GitHub/GitCode)
+2. 支持版本兼容性控制(如 v1.x 以下必须先升级到 v1.7.0 才能升级到 v2.0)
+3. 易于扩展,支持未来多个主版本的升级路径(v1.6 → v1.7 → v2.0 → v2.8 → v3.0)
+4. 保持与现有 electron-updater 机制的兼容性
+
+## 当前版本策略
+
+- **v1.7.x** 是 1.x 系列的最后版本
+- **v1.7.0 以下**的用户必须先升级到 v1.7.0(或更高的 1.7.x 版本)
+- **v1.7.0 及以上**的用户可以直接升级到 v2.x.x
+
+## 自动化工作流
+
+`x-files/app-upgrade-config/app-upgrade-config.json` 由 [`Update App Upgrade Config`](../../.github/workflows/update-app-upgrade-config.yml) workflow 自动同步。工作流会调用 [`scripts/update-app-upgrade-config.ts`](../../scripts/update-app-upgrade-config.ts) 脚本,根据指定 tag 更新 `x-files/app-upgrade-config` 分支上的配置文件。
+
+### 触发条件
+
+- **Release 事件(`release: released/prereleased`)**
+ - Draft release 会被忽略。
+ - 当 GitHub 将 release 标记为 *prerelease* 时,tag 必须包含 `-beta`/`-rc`(可带序号),否则直接跳过。
+ - 当 release 标记为稳定版时,tag 必须与 GitHub API 返回的最新稳定版本一致,防止发布历史 tag 时意外挂起工作流。
+ - 满足上述条件后,工作流会根据语义化版本判断渠道(`latest`/`beta`/`rc`),并通过 `IS_PRERELEASE` 传递给脚本。
+- **手动触发(`workflow_dispatch`)**
+ - 必填:`tag`(例:`v2.0.1`);选填:`is_prerelease`(默认 `false`)。
+ - 当 `is_prerelease=true` 时,同样要求 tag 带有 beta/rc 后缀。
+ - 手动运行仍会请求 GitHub 最新 release 信息,用于在 PR 说明中标注该 tag 是否是最新稳定版。
+
+### 工作流步骤
+
+1. **检查与元数据准备**:`Check if should proceed` 和 `Prepare metadata` 步骤会计算 tag、prerelease 标志、是否最新版本以及用于分支名的 `safe_tag`。若任意校验失败,工作流立即退出。
+2. **检出分支**:默认分支被检出到 `main/`,长期维护的 `x-files/app-upgrade-config` 分支则在 `cs/` 中,所有改动都发生在 `cs/`。
+3. **安装工具链**:安装 Node.js 22、启用 Corepack,并在 `main/` 目录执行 `yarn install --immutable`。
+4. **运行更新脚本**:执行 `yarn tsx scripts/update-app-upgrade-config.ts --tag --config ../cs/app-upgrade-config.json --is-prerelease `。
+ - 脚本会标准化 tag(去掉 `v` 前缀等)、识别渠道、加载 `config/app-upgrade-segments.json` 中的分段规则。
+ - 校验 prerelease 标志与语义后缀是否匹配、强制锁定的 segment 是否满足、生成镜像的下载地址,并检查 release 是否已经在 GitHub/GitCode 可用(latest 渠道在 GitCode 不可用时会回退到 `https://releases.cherry-ai.com`)。
+ - 更新对应的渠道配置后,脚本会按 semver 排序写回 JSON,并刷新 `lastUpdated`。
+5. **检测变更并创建 PR**:若 `cs/app-upgrade-config.json` 有变更,则创建 `chore/update-app-upgrade-config/` 分支,提交信息为 `🤖 chore: sync app-upgrade-config for `,并向 `x-files/app-upgrade-config` 提 PR;无变更则输出提示。
+
+### 手动触发指南
+
+1. 进入 Cherry Studio 仓库的 GitHub **Actions** 页面,选择 **Update App Upgrade Config** 工作流。
+2. 点击 **Run workflow**,保持默认分支(通常为 `main`),填写 `tag`(如 `v2.1.0`)。
+3. 只有在 tag 带 `-beta`/`-rc` 后缀时才勾选 `is_prerelease`,稳定版保持默认。
+4. 启动运行并等待完成,随后到 `x-files/app-upgrade-config` 分支的 PR 查看 `app-upgrade-config.json` 的变更并在验证后合并。
+
+## JSON 配置文件格式
+
+### 文件位置
+
+- **GitHub**: `https://raw.githubusercontent.com/CherryHQ/cherry-studio/refs/heads/x-files/app-upgrade-config/app-upgrade-config.json`
+- **GitCode**: `https://gitcode.com/CherryHQ/cherry-studio/raw/x-files/app-upgrade-config/app-upgrade-config.json`
+
+**说明**:两个镜像源提供相同的配置文件,统一托管在 `x-files/app-upgrade-config` 分支上。客户端根据 IP 地理位置自动选择最优镜像源。
+
+### 配置结构(当前实际配置)
+
+```json
+{
+ "lastUpdated": "2025-01-05T00:00:00Z",
+ "versions": {
+ "1.6.7": {
+ "minCompatibleVersion": "1.0.0",
+ "description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
+ "channels": {
+ "latest": {
+ "version": "1.6.7",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.7",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v1.6.7"
+ }
+ },
+ "rc": {
+ "version": "1.6.0-rc.5",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5",
+ "gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5"
+ }
+ },
+ "beta": {
+ "version": "1.6.7-beta.3",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3",
+ "gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3"
+ }
+ }
+ }
+ },
+ "2.0.0": {
+ "minCompatibleVersion": "1.7.0",
+ "description": "Major release v2.0 - required intermediate version for v2.x upgrades",
+ "channels": {
+ "latest": null,
+ "rc": null,
+ "beta": null
+ }
+ }
+ }
+}
+```
+
+### 未来扩展示例
+
+当需要发布 v3.0 时,如果需要强制用户先升级到 v2.8,可以添加:
+
+```json
+{
+ "2.8.0": {
+ "minCompatibleVersion": "2.0.0",
+ "description": "Stable v2.8 - required for v3 upgrade",
+ "channels": {
+ "latest": {
+ "version": "2.8.0",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v2.8.0",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v2.8.0"
+ }
+ },
+ "rc": null,
+ "beta": null
+ }
+ },
+ "3.0.0": {
+ "minCompatibleVersion": "2.8.0",
+ "description": "Major release v3.0",
+ "channels": {
+ "latest": {
+ "version": "3.0.0",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/latest",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/latest"
+ }
+ },
+ "rc": {
+ "version": "3.0.0-rc.1",
+ "feedUrls": {
+ "github": "https://github.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1",
+ "gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1"
+ }
+ },
+ "beta": null
+ }
+ }
+}
+```
+
+### 字段说明
+
+- `lastUpdated`: 配置文件最后更新时间(ISO 8601 格式)
+- `versions`: 版本配置对象,key 为版本号,按语义化版本排序
+ - `minCompatibleVersion`: 可以升级到此版本的最低兼容版本
+ - `description`: 版本描述
+ - `channels`: 更新渠道配置
+ - `latest`: 稳定版渠道
+ - `rc`: Release Candidate 渠道
+ - `beta`: Beta 测试渠道
+ - 每个渠道包含:
+ - `version`: 该渠道的版本号
+ - `feedUrls`: 多镜像源 URL 配置
+ - `github`: GitHub 镜像源的 electron-updater feed URL
+ - `gitcode`: GitCode 镜像源的 electron-updater feed URL
+ - `metadata`: 自动化匹配所需的稳定标识
+ - `segmentId`: 来自 `config/app-upgrade-segments.json` 的段位 ID
+ - `segmentType`: 可选字段(`legacy` | `breaking` | `latest`),便于文档/调试
+
+## TypeScript 类型定义
+
+```typescript
+// 镜像源枚举
+enum UpdateMirror {
+ GITHUB = 'github',
+ GITCODE = 'gitcode'
+}
+
+interface UpdateConfig {
+ lastUpdated: string
+ versions: {
+ [versionKey: string]: VersionConfig
+ }
+}
+
+interface VersionConfig {
+ minCompatibleVersion: string
+ description: string
+ channels: {
+ latest: ChannelConfig | null
+ rc: ChannelConfig | null
+ beta: ChannelConfig | null
+ }
+ metadata?: {
+ segmentId: string
+ segmentType?: 'legacy' | 'breaking' | 'latest'
+ }
+}
+
+interface ChannelConfig {
+ version: string
+ feedUrls: Record
+ // 等同于:
+ // feedUrls: {
+ // github: string
+ // gitcode: string
+ // }
+}
+```
+
+## 段位元数据(Break Change 标记)
+
+- 所有段位定义(如 `legacy-v1`、`gateway-v2` 等)集中在 `config/app-upgrade-segments.json`,用于描述匹配范围、`segmentId`、`segmentType`、默认 `minCompatibleVersion/description` 以及各渠道的 URL 模板。
+- `versions` 下的每个节点都会带上 `metadata.segmentId`。自动脚本始终依据该 ID 来定位并更新条目,即便 key 从 `2.1.5` 切换到 `2.1.6` 也不会错位。
+- 如果某段需要锁死在特定版本(例如 `2.0.0` 的 break change),可在段定义中设置 `segmentType: "breaking"` 并提供 `lockedVersion`,脚本在遇到不匹配的 tag 时会短路报错,保证升级路径安全。
+- 面对未来新的断层(例如 `3.0.0`),只需要在段定义里新增一段,自动化即可识别并更新。
+
+## 自动化工作流
+
+`.github/workflows/update-app-upgrade-config.yml` 会在 GitHub Release(包含正常发布与 Pre Release)触发:
+
+1. 同时 Checkout 仓库默认分支(用于脚本)和 `x-files/app-upgrade-config` 分支(真实托管配置的分支)。
+2. 在默认分支目录执行 `yarn tsx scripts/update-app-upgrade-config.ts --tag --config ../cs/app-upgrade-config.json`,直接重写 `x-files/app-upgrade-config` 分支里的配置文件。
+3. 如果 `app-upgrade-config.json` 有变化,则通过 `peter-evans/create-pull-request` 自动创建一个指向 `x-files/app-upgrade-config` 的 PR,Diff 仅包含该文件。
+
+如需本地调试,可执行 `yarn update:upgrade-config --tag v2.1.6 --config ../cs/app-upgrade-config.json`(加 `--dry-run` 仅打印结果)来复现 CI 行为。若需要暂时跳过 GitHub/GitCode Release 页面是否就绪的校验,可在 `--dry-run` 的同时附加 `--skip-release-checks`。不加 `--config` 时默认更新当前工作目录(通常是 main 分支)下的副本,方便文档/审查。
+
+## 版本匹配逻辑
+
+### 算法流程
+
+1. 获取用户当前版本(`currentVersion`)和请求的渠道(`requestedChannel`)
+2. 获取配置文件中所有版本号,按语义化版本从大到小排序
+3. 遍历排序后的版本列表:
+ - 检查 `currentVersion >= minCompatibleVersion`
+ - 检查请求的 `channel` 是否存在且不为 `null`
+ - 如果满足条件,返回该渠道配置
+4. 如果没有找到匹配版本,返回 `null`
+
+### 伪代码实现
+
+```typescript
+function findCompatibleVersion(
+ currentVersion: string,
+ requestedChannel: UpgradeChannel,
+ config: UpdateConfig
+): ChannelConfig | null {
+ // 获取所有版本号并从大到小排序
+ const versions = Object.keys(config.versions).sort(semver.rcompare)
+
+ for (const versionKey of versions) {
+ const versionConfig = config.versions[versionKey]
+ const channelConfig = versionConfig.channels[requestedChannel]
+
+ // 检查版本兼容性和渠道可用性
+ if (
+ semver.gte(currentVersion, versionConfig.minCompatibleVersion) &&
+ channelConfig !== null
+ ) {
+ return channelConfig
+ }
+ }
+
+ return null // 没有找到兼容版本
+}
+```
+
+## 升级路径示例
+
+### 场景 1: v1.6.5 用户升级(低于 1.7)
+
+- **当前版本**: 1.6.5
+- **请求渠道**: latest
+- **匹配结果**: 1.7.0
+- **原因**: 1.6.5 >= 0.0.0(满足 1.7.0 的 minCompatibleVersion),但不满足 2.0.0 的 minCompatibleVersion (1.7.0)
+- **操作**: 提示用户升级到 1.7.0,这是升级到 v2.x 的必要中间版本
+
+### 场景 2: v1.6.5 用户请求 rc/beta
+
+- **当前版本**: 1.6.5
+- **请求渠道**: rc 或 beta
+- **匹配结果**: 1.7.0 (latest)
+- **原因**: 1.7.0 版本不提供 rc/beta 渠道(值为 null)
+- **操作**: 升级到 1.7.0 稳定版
+
+### 场景 3: v1.7.0 用户升级到最新版
+
+- **当前版本**: 1.7.0
+- **请求渠道**: latest
+- **匹配结果**: 2.0.0
+- **原因**: 1.7.0 >= 1.7.0(满足 2.0.0 的 minCompatibleVersion)
+- **操作**: 直接升级到 2.0.0(当前最新稳定版)
+
+### 场景 4: v1.7.2 用户升级到 RC 版本
+
+- **当前版本**: 1.7.2
+- **请求渠道**: rc
+- **匹配结果**: 2.0.0-rc.1
+- **原因**: 1.7.2 >= 1.7.0(满足 2.0.0 的 minCompatibleVersion),且 rc 渠道存在
+- **操作**: 升级到 2.0.0-rc.1
+
+### 场景 5: v1.7.0 用户升级到 Beta 版本
+
+- **当前版本**: 1.7.0
+- **请求渠道**: beta
+- **匹配结果**: 2.0.0-beta.1
+- **原因**: 1.7.0 >= 1.7.0,且 beta 渠道存在
+- **操作**: 升级到 2.0.0-beta.1
+
+### 场景 6: v2.5.0 用户升级(未来)
+
+假设已添加 v2.8.0 和 v3.0.0 配置:
+- **当前版本**: 2.5.0
+- **请求渠道**: latest
+- **匹配结果**: 2.8.0
+- **原因**: 2.5.0 >= 2.0.0(满足 2.8.0 的 minCompatibleVersion),但不满足 3.0.0 的要求
+- **操作**: 提示用户升级到 2.8.0,这是升级到 v3.x 的必要中间版本
+
+## 代码改动计划
+
+### 主要修改
+
+1. **新增方法**
+ - `_fetchUpdateConfig(ipCountry: string): Promise` - 根据 IP 获取配置文件
+ - `_findCompatibleChannel(currentVersion: string, channel: UpgradeChannel, config: UpdateConfig): ChannelConfig | null` - 查找兼容的渠道配置
+
+2. **修改方法**
+ - `_getReleaseVersionFromGithub()` → 移除或重构为 `_getChannelFeedUrl()`
+ - `_setFeedUrl()` - 使用新的配置系统替代现有逻辑
+
+3. **新增类型定义**
+ - `UpdateConfig`
+ - `VersionConfig`
+ - `ChannelConfig`
+
+### 镜像源选择逻辑
+
+客户端根据 IP 地理位置自动选择最优镜像源:
+
+```typescript
+private async _setFeedUrl() {
+ const currentVersion = app.getVersion()
+ const testPlan = configManager.getTestPlan()
+ const requestedChannel = testPlan ? this._getTestChannel() : UpgradeChannel.LATEST
+
+ // 根据 IP 国家确定镜像源
+ const ipCountry = await getIpCountry()
+ const mirror = ipCountry.toLowerCase() === 'cn' ? 'gitcode' : 'github'
+
+ // 获取更新配置
+ const config = await this._fetchUpdateConfig(mirror)
+
+ if (config) {
+ const channelConfig = this._findCompatibleChannel(currentVersion, requestedChannel, config)
+ if (channelConfig) {
+ // 从配置中选择对应镜像源的 URL
+ const feedUrl = channelConfig.feedUrls[mirror]
+ this._setChannel(requestedChannel, feedUrl)
+ return
+ }
+ }
+
+ // Fallback 逻辑
+ const defaultFeedUrl = mirror === 'gitcode'
+ ? FeedUrl.PRODUCTION
+ : FeedUrl.GITHUB_LATEST
+ this._setChannel(UpgradeChannel.LATEST, defaultFeedUrl)
+}
+
+private async _fetchUpdateConfig(mirror: 'github' | 'gitcode'): Promise {
+ const configUrl = mirror === 'gitcode'
+ ? UpdateConfigUrl.GITCODE
+ : UpdateConfigUrl.GITHUB
+
+ try {
+ const response = await net.fetch(configUrl, {
+ headers: {
+ 'User-Agent': generateUserAgent(),
+ 'Accept': 'application/json',
+ 'X-Client-Id': configManager.getClientId()
+ }
+ })
+ return await response.json() as UpdateConfig
+ } catch (error) {
+ logger.error('Failed to fetch update config:', error)
+ return null
+ }
+}
+```
+
+## 降级和容错策略
+
+1. **配置文件获取失败**: 记录错误日志,返回当前版本,不提供更新
+2. **没有匹配的版本**: 提示用户当前版本不支持自动升级
+3. **网络异常**: 缓存上次成功获取的配置(可选)
+
+## GitHub Release 要求
+
+为支持中间版本升级,需要保留以下文件:
+
+- **v1.7.0 release** 及其 latest*.yml 文件(作为 v1.7 以下用户的升级目标)
+- 未来如需强制中间版本(如 v2.8.0),需要保留对应的 release 和 latest*.yml 文件
+- 各版本的完整安装包
+
+### 当前需要的 Release
+
+| 版本 | 用途 | 必须保留 |
+|------|------|---------|
+| v1.7.0 | 1.7 以下用户的升级目标 | ✅ 是 |
+| v2.0.0-rc.1 | RC 测试渠道 | ❌ 可选 |
+| v2.0.0-beta.1 | Beta 测试渠道 | ❌ 可选 |
+| latest | 最新稳定版(自动) | ✅ 是 |
+
+## 优势
+
+1. **灵活性**: 支持任意复杂的升级路径
+2. **可扩展性**: 新增版本只需在配置文件中添加新条目
+3. **可维护性**: 配置与代码分离,无需发版即可调整升级策略
+4. **多源支持**: 自动根据地理位置选择最优配置源
+5. **版本控制**: 强制中间版本升级,确保数据迁移和兼容性
+
+## 未来扩展
+
+- 支持更细粒度的版本范围控制(如 `>=1.5.0 <1.8.0`)
+- 支持多步升级路径提示(如提示用户需要 1.5 → 1.8 → 2.0)
+- 支持 A/B 测试和灰度发布
+- 支持配置文件的本地缓存和过期策略
diff --git a/docs/technical/code-execution.md b/docs/zh/references/code-execution.md
similarity index 100%
rename from docs/technical/code-execution.md
rename to docs/zh/references/code-execution.md
diff --git a/docs/technical/CodeBlockView-zh.md b/docs/zh/references/components/code-block-view.md
similarity index 98%
rename from docs/technical/CodeBlockView-zh.md
rename to docs/zh/references/components/code-block-view.md
index a817e99361..6805aac7a9 100644
--- a/docs/technical/CodeBlockView-zh.md
+++ b/docs/zh/references/components/code-block-view.md
@@ -85,7 +85,7 @@ graph TD
- **SvgPreview**: SVG 图像预览
- **GraphvizPreview**: Graphviz 图表预览
-所有特殊视图组件共享通用架构,以确保一致的用户体验和功能。有关这些组件及其实现的详细信息,请参阅 [图像预览组件文档](./ImagePreview-zh.md)。
+所有特殊视图组件共享通用架构,以确保一致的用户体验和功能。有关这些组件及其实现的详细信息,请参阅[图像预览组件文档](./image-preview.md)。
#### StatusBar 状态栏
diff --git a/docs/technical/ImagePreview-zh.md b/docs/zh/references/components/image-preview.md
similarity index 99%
rename from docs/technical/ImagePreview-zh.md
rename to docs/zh/references/components/image-preview.md
index 8a68b84312..99c51cb995 100644
--- a/docs/technical/ImagePreview-zh.md
+++ b/docs/zh/references/components/image-preview.md
@@ -192,4 +192,4 @@ const { containerRef, error, isLoading, triggerRender, cancelRender, clearError,
- 共享状态管理
- 响应式布局适应
-有关整体 CodeBlockView 架构的更多信息,请参阅 [CodeBlockView 文档](./CodeBlockView-zh.md)。
+有关整体 CodeBlockView 架构的更多信息,请参阅 [CodeBlockView 文档](./code-block-view.md)。
diff --git a/docs/technical/db.translate_languages.md b/docs/zh/references/database.md
similarity index 57%
rename from docs/technical/db.translate_languages.md
rename to docs/zh/references/database.md
index 37231c89cd..9fd72d0286 100644
--- a/docs/technical/db.translate_languages.md
+++ b/docs/zh/references/database.md
@@ -1,6 +1,24 @@
-# `translate_languages` 表技术文档
+# 数据库参考文档
-## 📄 概述
+本文档介绍 Cherry Studio 的数据库结构,包括设置字段和翻译语言表。
+
+---
+
+## 设置字段 (settings)
+
+此部分包含设置相关字段的数据类型说明。
+
+### 翻译相关字段
+
+| 字段名 | 类型 | 说明 |
+| ------------------------------ | ------------------------------ | ------------ |
+| `translate:target:language` | `LanguageCode` | 翻译目标语言 |
+| `translate:source:language` | `LanguageCode` | 翻译源语言 |
+| `translate:bidirectional:pair` | `[LanguageCode, LanguageCode]` | 双向翻译对 |
+
+---
+
+## 翻译语言表 (translate_languages)
`translate_languages` 记录用户自定义的的语言类型(`Language`)。
diff --git a/docs/zh/references/message-system.md b/docs/zh/references/message-system.md
new file mode 100644
index 0000000000..91eb2fd82f
--- /dev/null
+++ b/docs/zh/references/message-system.md
@@ -0,0 +1,404 @@
+# 消息系统
+
+本文档介绍 Cherry Studio 的消息系统架构,包括消息生命周期、状态管理和操作接口。
+
+## 消息的生命周期
+
+
+
+---
+
+# messageBlock.ts 使用指南
+
+该文件定义了用于管理应用程序中所有 `MessageBlock` 实体的 Redux Slice。它使用 Redux Toolkit 的 `createSlice` 和 `createEntityAdapter` 来高效地处理规范化的状态,并提供了一系列 actions 和 selectors 用于与消息块数据交互。
+
+## 核心目标
+
+- **状态管理**: 集中管理所有 `MessageBlock` 的状态。`MessageBlock` 代表消息中的不同内容单元(如文本、代码、图片、引用等)。
+- **规范化**: 使用 `createEntityAdapter` 将 `MessageBlock` 数据存储在规范化的结构中(`{ ids: [], entities: {} }`),这有助于提高性能和简化更新逻辑。
+- **可预测性**: 提供明确的 actions 来修改状态,并通过 selectors 安全地访问状态。
+
+## 关键概念
+
+- **Slice (`createSlice`)**: Redux Toolkit 的核心 API,用于创建包含 reducer 逻辑、action creators 和初始状态的 Redux 模块。
+- **Entity Adapter (`createEntityAdapter`)**: Redux Toolkit 提供的工具,用于简化对规范化数据的 CRUD(创建、读取、更新、删除)操作。它会自动生成 reducer 函数和 selectors。
+- **Selectors**: 用于从 Redux store 中派生和计算数据的函数。Selectors 可以被记忆化(memoized),以提高性能。
+
+## State 结构
+
+`messageBlocks` slice 的状态结构由 `createEntityAdapter` 定义,大致如下:
+
+```typescript
+{
+ ids: string[]; // 存储所有 MessageBlock ID 的有序列表
+ entities: { [id: string]: MessageBlock }; // 按 ID 存储 MessageBlock 对象的字典
+ loadingState: 'idle' | 'loading' | 'succeeded' | 'failed'; // (可选) 其他状态,如加载状态
+ error: string | null; // (可选) 错误信息
+}
+```
+
+## Actions
+
+该 slice 导出以下 actions (由 `createSlice` 和 `createEntityAdapter` 自动生成或自定义):
+
+- **`upsertOneBlock(payload: MessageBlock)`**:
+
+ - 添加一个新的 `MessageBlock` 或更新一个已存在的 `MessageBlock`。如果 payload 中的 `id` 已存在,则执行更新;否则执行插入。
+
+- **`upsertManyBlocks(payload: MessageBlock[])`**:
+
+ - 添加或更新多个 `MessageBlock`。常用于批量加载数据(例如,加载一个 Topic 的所有消息块)。
+
+- **`removeOneBlock(payload: string)`**:
+
+ - 根据提供的 `id` (payload) 移除单个 `MessageBlock`。
+
+- **`removeManyBlocks(payload: string[])`**:
+
+ - 根据提供的 `id` 数组 (payload) 移除多个 `MessageBlock`。常用于删除消息或清空 Topic 时清理相关的块。
+
+- **`removeAllBlocks()`**:
+
+ - 移除 state 中的所有 `MessageBlock` 实体。
+
+- **`updateOneBlock(payload: { id: string; changes: Partial })`**:
+
+ - 更新一个已存在的 `MessageBlock`。`payload` 需要包含块的 `id` 和一个包含要更改的字段的 `changes` 对象。
+
+- **`setMessageBlocksLoading(payload: 'idle' | 'loading')`**:
+
+ - (自定义) 设置 `loadingState` 属性。
+
+- **`setMessageBlocksError(payload: string)`**:
+ - (自定义) 设置 `loadingState` 为 `'failed'` 并记录错误信息。
+
+**使用示例 (在 Thunk 或其他 Dispatch 的地方):**
+
+```typescript
+import { upsertOneBlock, removeManyBlocks, updateOneBlock } from './messageBlock'
+import store from './store' // 假设这是你的 Redux store 实例
+
+// 添加或更新一个块
+const newBlock: MessageBlock = {
+ /* ... block data ... */
+}
+store.dispatch(upsertOneBlock(newBlock))
+
+// 更新一个块的内容
+store.dispatch(updateOneBlock({ id: blockId, changes: { content: 'New content' } }))
+
+// 删除多个块
+const blockIdsToRemove = ['id1', 'id2']
+store.dispatch(removeManyBlocks(blockIdsToRemove))
+```
+
+## Selectors
+
+该 slice 导出由 `createEntityAdapter` 生成的基础 selectors,并通过 `messageBlocksSelectors` 对象访问:
+
+- **`messageBlocksSelectors.selectIds(state: RootState): string[]`**: 返回包含所有块 ID 的数组。
+- **`messageBlocksSelectors.selectEntities(state: RootState): { [id: string]: MessageBlock }`**: 返回块 ID 到块对象的映射字典。
+- **`messageBlocksSelectors.selectAll(state: RootState): MessageBlock[]`**: 返回包含所有块对象的数组。
+- **`messageBlocksSelectors.selectTotal(state: RootState): number`**: 返回块的总数。
+- **`messageBlocksSelectors.selectById(state: RootState, id: string): MessageBlock | undefined`**: 根据 ID 返回单个块对象,如果找不到则返回 `undefined`。
+
+**此外,还提供了一个自定义的、记忆化的 selector:**
+
+- **`selectFormattedCitationsByBlockId(state: RootState, blockId: string | undefined): Citation[]`**:
+ - 接收一个 `blockId`。
+ - 如果该 ID 对应的块是 `CITATION` 类型,则提取并格式化其包含的引用信息(来自网页搜索、知识库等),进行去重和重新编号,最后返回一个 `Citation[]` 数组,用于在 UI 中显示。
+ - 如果块不存在或类型不匹配,返回空数组 `[]`。
+ - 这个 selector 封装了处理不同引用来源(Gemini, OpenAI, OpenRouter, Zhipu 等)的复杂逻辑。
+
+**使用示例 (在 React 组件或 `useSelector` 中):**
+
+```typescript
+import { useSelector } from 'react-redux'
+import { messageBlocksSelectors, selectFormattedCitationsByBlockId } from './messageBlock'
+import type { RootState } from './store'
+
+// 获取所有块
+const allBlocks = useSelector(messageBlocksSelectors.selectAll)
+
+// 获取特定 ID 的块
+const specificBlock = useSelector((state: RootState) => messageBlocksSelectors.selectById(state, someBlockId))
+
+// 获取特定引用块格式化后的引用列表
+const formattedCitations = useSelector((state: RootState) => selectFormattedCitationsByBlockId(state, citationBlockId))
+
+// 在组件中使用引用数据
+// {formattedCitations.map(citation => ...)}
+```
+
+## 集成
+
+`messageBlock.ts` slice 通常与 `messageThunk.ts` 中的 Thunks 紧密协作。Thunks 负责处理异步逻辑(如 API 调用、数据库操作),并在需要时 dispatch `messageBlock` slice 的 actions 来更新状态。例如,当 `messageThunk` 接收到流式响应时,它会 dispatch `upsertOneBlock` 或 `updateOneBlock` 来实时更新对应的 `MessageBlock`。同样,删除消息的 Thunk 会 dispatch `removeManyBlocks`。
+
+理解 `messageBlock.ts` 的职责是管理**状态本身**,而 `messageThunk.ts` 负责**触发状态变更**的异步流程,这对于维护清晰的应用架构至关重要。
+
+---
+
+# messageThunk.ts 使用指南
+
+该文件包含用于管理应用程序中消息流、处理助手交互以及同步 Redux 状态与 IndexedDB 数据库的核心 Thunk Action Creators。主要围绕 `Message` 和 `MessageBlock` 对象进行操作。
+
+## 核心功能
+
+1. **发送/接收消息**: 处理用户消息的发送,触发助手响应,并流式处理返回的数据,将其解析为不同的 `MessageBlock`。
+2. **状态管理**: 确保 Redux store 中的消息和消息块状态与 IndexedDB 中的持久化数据保持一致。
+3. **消息操作**: 提供删除、重发、重新生成、编辑后重发、追加响应、克隆等消息生命周期管理功能。
+4. **Block 处理**: 动态创建、更新和保存各种类型的 `MessageBlock`(文本、思考过程、工具调用、引用、图片、错误、翻译等)。
+
+## 主要 Thunks
+
+以下是一些关键的 Thunk 函数及其用途:
+
+1. **`sendMessage(userMessage, userMessageBlocks, assistant, topicId)`**
+
+ - **用途**: 发送一条新的用户消息。
+ - **流程**:
+ - 保存用户消息 (`userMessage`) 及其块 (`userMessageBlocks`) 到 Redux 和 DB。
+ - 检查 `@mentions` 以确定是单模型响应还是多模型响应。
+ - 创建助手消息(们)的存根 (Stub)。
+ - 将存根添加到 Redux 和 DB。
+ - 将核心处理逻辑 `fetchAndProcessAssistantResponseImpl` 添加到该 `topicId` 的队列中以获取实际响应。
+ - **Block 相关**: 主要处理用户消息的初始 `MessageBlock` 保存。
+
+2. **`fetchAndProcessAssistantResponseImpl(dispatch, getState, topicId, assistant, assistantMessage)`**
+
+ - **用途**: (内部函数) 获取并处理单个助手响应的核心逻辑,被 `sendMessage`, `resend...`, `regenerate...`, `append...` 等调用。
+ - **流程**:
+ - 设置 Topic 加载状态。
+ - 准备上下文消息。
+ - 调用 `fetchChatCompletion` API 服务。
+ - 使用 `createStreamProcessor` 处理流式响应。
+ - 通过各种回调 (`onTextChunk`, `onThinkingChunk`, `onToolCallComplete`, `onImageGenerated`, `onError`, `onComplete` 等) 处理不同类型的事件。
+ - **Block 相关**:
+ - 根据流事件创建初始 `UNKNOWN` 块。
+ - 实时创建和更新 `MAIN_TEXT` 和 `THINKING` 块,使用 `throttledBlockUpdate` 和 `throttledBlockDbUpdate` 进行节流更新。
+ - 创建 `TOOL`, `CITATION`, `IMAGE`, `ERROR` 等类型的块。
+ - 在事件完成时(如 `onTextComplete`, `onToolCallComplete`)将块状态标记为 `SUCCESS` 或 `ERROR`,并使用 `saveUpdatedBlockToDB` 保存最终状态。
+ - 使用 `handleBlockTransition` 管理非流式块(如 `TOOL`, `CITATION`)的添加和状态更新。
+
+3. **`loadTopicMessagesThunk(topicId, forceReload)`**
+
+ - **用途**: 从数据库加载指定主题的所有消息及其关联的 `MessageBlock`。
+ - **流程**:
+ - 从 DB 获取 `Topic` 及其 `messages` 列表。
+ - 根据消息 ID 列表从 DB 获取所有相关的 `MessageBlock`。
+ - 使用 `upsertManyBlocks` 将块更新到 Redux。
+ - 将消息更新到 Redux。
+ - **Block 相关**: 负责将持久化的 `MessageBlock` 加载到 Redux 状态。
+
+4. **删除 Thunks**
+
+ - `deleteSingleMessageThunk(topicId, messageId)`: 删除单个消息及其所有 `MessageBlock`。
+ - `deleteMessageGroupThunk(topicId, askId)`: 删除一个用户消息及其所有相关的助手响应消息和它们的所有 `MessageBlock`。
+ - `clearTopicMessagesThunk(topicId)`: 清空主题下的所有消息及其所有 `MessageBlock`。
+ - **Block 相关**: 从 Redux 和 DB 中移除指定的 `MessageBlock`。
+
+5. **重发/重新生成 Thunks**
+
+ - `resendMessageThunk(topicId, userMessageToResend, assistant)`: 重发用户消息。会重置(清空 Block 并标记为 PENDING)所有与该用户消息关联的助手响应,然后重新请求生成。
+ - `resendUserMessageWithEditThunk(topicId, originalMessage, mainTextBlockId, editedContent, assistant)`: 用户编辑消息内容后重发。先更新用户消息的 `MAIN_TEXT` 块内容,然后调用 `resendMessageThunk`。
+ - `regenerateAssistantResponseThunk(topicId, assistantMessageToRegenerate, assistant)`: 重新生成单个助手响应。重置该助手消息(清空 Block 并标记为 PENDING),然后重新请求生成。
+ - **Block 相关**: 删除旧的 `MessageBlock`,并在重新生成过程中创建新的 `MessageBlock`。
+
+6. **`appendAssistantResponseThunk(topicId, existingAssistantMessageId, newModel, assistant)`**
+
+ - **用途**: 在已有的对话上下文中,针对同一个用户问题,使用新选择的模型追加一个新的助手响应。
+ - **流程**:
+ - 找到现有助手消息以获取原始 `askId`。
+ - 创建使用 `newModel` 的新助手消息存根(使用相同的 `askId`)。
+ - 添加新存根到 Redux 和 DB。
+ - 将 `fetchAndProcessAssistantResponseImpl` 添加到队列以生成新响应。
+ - **Block 相关**: 为新的助手响应创建全新的 `MessageBlock`。
+
+7. **`cloneMessagesToNewTopicThunk(sourceTopicId, branchPointIndex, newTopic)`**
+
+ - **用途**: 将源主题的部分消息(及其 Block)克隆到一个**已存在**的新主题中。
+ - **流程**:
+ - 复制指定索引前的消息。
+ - 为所有克隆的消息和 Block 生成新的 UUID。
+ - 正确映射克隆消息之间的 `askId` 关系。
+ - 复制 `MessageBlock` 内容,更新其 `messageId` 指向新的消息 ID。
+ - 更新文件引用计数(如果 Block 是文件或图片)。
+ - 将克隆的消息和 Block 保存到新主题的 Redux 状态和 DB 中。
+ - **Block 相关**: 创建 `MessageBlock` 的副本,并更新其 ID 和 `messageId`。
+
+8. **`initiateTranslationThunk(messageId, topicId, targetLanguage, sourceBlockId?, sourceLanguage?)`**
+ - **用途**: 为指定消息启动翻译流程,创建一个初始的 `TRANSLATION` 类型的 `MessageBlock`。
+ - **流程**:
+ - 创建一个状态为 `STREAMING` 的 `TranslationMessageBlock`。
+ - 将其添加到 Redux 和 DB。
+ - 更新原消息的 `blocks` 列表以包含新的翻译块 ID。
+ - **Block 相关**: 创建并保存一个占位的 `TranslationMessageBlock`。实际翻译内容的获取和填充需要后续步骤。
+
+## 内部机制和注意事项
+
+- **数据库交互**: 通过 `saveMessageAndBlocksToDB`, `updateExistingMessageAndBlocksInDB`, `saveUpdatesToDB`, `saveUpdatedBlockToDB`, `throttledBlockDbUpdate` 等辅助函数与 IndexedDB (`db`) 交互,确保数据持久化。
+- **状态同步**: Thunks 负责协调 Redux Store 和 IndexedDB 之间的数据一致性。
+- **队列 (`getTopicQueue`)**: 使用 `AsyncQueue` 确保对同一主题的操作(尤其是 API 请求)按顺序执行,避免竞态条件。
+- **节流 (`throttle`)**: 对流式响应中频繁的 Block 更新(文本、思考)使用 `lodash.throttle` 优化性能,减少 Redux dispatch 和 DB 写入次数。
+- **错误处理**: `fetchAndProcessAssistantResponseImpl` 内的回调函数(特别是 `onError`)处理流处理和 API 调用中可能出现的错误,并创建 `ERROR` 类型的 `MessageBlock`。
+
+开发者在使用这些 Thunks 时,通常需要提供 `dispatch`, `getState` (由 Redux Thunk 中间件注入),以及如 `topicId`, `assistant` 配置对象, 相关的 `Message` 或 `MessageBlock` 对象/ID 等参数。理解每个 Thunk 的职责和它如何影响消息及块的状态至关重要。
+
+---
+
+# useMessageOperations.ts 使用指南
+
+该文件定义了一个名为 `useMessageOperations` 的自定义 React Hook。这个 Hook 的主要目的是为 React 组件提供一个便捷的接口,用于执行与特定主题(Topic)相关的各种消息操作。它封装了调用 Redux Thunks (`messageThunk.ts`) 和 Actions (`newMessage.ts`, `messageBlock.ts`) 的逻辑,简化了组件与消息数据交互的代码。
+
+## 核心目标
+
+- **封装**: 将复杂的消息操作逻辑(如删除、重发、重新生成、编辑、翻译等)封装在易于使用的函数中。
+- **简化**: 让组件可以直接调用这些操作函数,而无需直接与 Redux `dispatch` 或 Thunks 交互。
+- **上下文关联**: 所有操作都与传入的 `topic` 对象相关联,确保操作作用于正确的主题。
+
+## 如何使用
+
+在你的 React 函数组件中,导入并调用 `useMessageOperations` Hook,并传入当前活动的 `Topic` 对象。
+
+```typescript
+import React from 'react';
+import { useMessageOperations } from '@renderer/hooks/useMessageOperations';
+import type { Topic, Message, Assistant, Model } from '@renderer/types';
+
+interface MyComponentProps {
+ currentTopic: Topic;
+ currentAssistant: Assistant;
+}
+
+function MyComponent({ currentTopic, currentAssistant }: MyComponentProps) {
+ const {
+ deleteMessage,
+ resendMessage,
+ regenerateAssistantMessage,
+ appendAssistantResponse,
+ getTranslationUpdater,
+ createTopicBranch,
+ // ... 其他操作函数
+ } = useMessageOperations(currentTopic);
+
+ const handleDelete = (messageId: string) => {
+ deleteMessage(messageId);
+ };
+
+ const handleResend = (message: Message) => {
+ resendMessage(message, currentAssistant);
+ };
+
+ const handleAppend = (existingMsg: Message, newModel: Model) => {
+ appendAssistantResponse(existingMsg, newModel, currentAssistant);
+ }
+
+ // ... 在组件中使用其他操作函数
+
+ return (
+
+ {/* Component UI */}
+
+ {/* ... */}
+
+ );
+}
+```
+
+## 返回值
+
+`useMessageOperations(topic)` Hook 返回一个包含以下函数和值的对象:
+
+- **`deleteMessage(id: string)`**:
+
+ - 删除指定 `id` 的单个消息。
+ - 内部调用 `deleteSingleMessageThunk`。
+
+- **`deleteGroupMessages(askId: string)`**:
+
+ - 删除与指定 `askId` 相关联的一组消息(通常是用户提问及其所有助手回答)。
+ - 内部调用 `deleteMessageGroupThunk`。
+
+- **`editMessage(messageId: string, updates: Partial)`**:
+
+ - 更新指定 `messageId` 的消息的部分属性。
+ - **注意**: 目前主要用于更新 Redux 状态
+ - 内部调用 `newMessagesActions.updateMessage`。
+
+- **`resendMessage(message: Message, assistant: Assistant)`**:
+
+ - 重新发送指定的用户消息 (`message`),这将触发其所有关联助手响应的重新生成。
+ - 内部调用 `resendMessageThunk`。
+
+- **`resendUserMessageWithEdit(message: Message, editedContent: string, assistant: Assistant)`**:
+
+ - 在用户消息的主要文本块被编辑后,重新发送该消息。
+ - 会先查找消息的 `MAIN_TEXT` 块 ID,然后调用 `resendUserMessageWithEditThunk`。
+
+- **`clearTopicMessages(_topicId?: string)`**:
+
+ - 清除当前主题(或可选的指定 `_topicId`)下的所有消息。
+ - 内部调用 `clearTopicMessagesThunk`。
+
+- **`createNewContext()`**:
+
+ - 发出一个全局事件 (`EVENT_NAMES.NEW_CONTEXT`),通常用于通知 UI 清空显示,准备新的上下文。不直接修改 Redux 状态。
+
+- **`displayCount`**:
+
+ - (非操作函数) 从 Redux store 中获取当前的 `displayCount` 值。
+
+- **`pauseMessages()`**:
+
+ - 尝试中止当前主题中正在进行的消息生成(状态为 `processing` 或 `pending`)。
+ - 通过查找相关的 `askId` 并调用 `abortCompletion` 来实现。
+ - 同时会 dispatch `setTopicLoading` action 将加载状态设为 `false`。
+
+- **`resumeMessage(message: Message, assistant: Assistant)`**:
+
+ - 恢复/重新发送一个用户消息。目前实现为直接调用 `resendMessage`。
+
+- **`regenerateAssistantMessage(message: Message, assistant: Assistant)`**:
+
+ - 重新生成指定的**助手**消息 (`message`) 的响应。
+ - 内部调用 `regenerateAssistantResponseThunk`。
+
+- **`appendAssistantResponse(existingAssistantMessage: Message, newModel: Model, assistant: Assistant)`**:
+
+ - 针对 `existingAssistantMessage` 所回复的**同一用户提问**,使用 `newModel` 追加一个新的助手响应。
+ - 内部调用 `appendAssistantResponseThunk`。
+
+- **`getTranslationUpdater(messageId: string, targetLanguage: string, sourceBlockId?: string, sourceLanguage?: string)`**:
+
+ - **用途**: 获取一个用于逐步更新翻译块内容的函数。
+ - **流程**:
+ 1. 内部调用 `initiateTranslationThunk` 来创建或获取一个 `TRANSLATION` 类型的 `MessageBlock`,并获取其 `blockId`。
+ 2. 返回一个**异步更新函数**。
+ - **返回的更新函数 `(accumulatedText: string, isComplete?: boolean) => void`**:
+ - 接收累积的翻译文本和完成状态。
+ - 调用 `updateOneBlock` 更新 Redux 中的翻译块内容和状态 (`STREAMING` 或 `SUCCESS`)。
+ - 调用 `throttledBlockDbUpdate` 将更新(节流地)保存到数据库。
+ - 如果初始化失败(Thunk 返回 `undefined`),则此函数返回 `null`。
+
+- **`createTopicBranch(sourceTopicId: string, branchPointIndex: number, newTopic: Topic)`**:
+ - 创建一个主题分支,将 `sourceTopicId` 主题中 `branchPointIndex` 索引之前的消息克隆到 `newTopic` 中。
+ - **注意**: `newTopic` 对象必须是调用此函数**之前**已经创建并添加到 Redux 和数据库中的。
+ - 内部调用 `cloneMessagesToNewTopicThunk`。
+
+## 依赖
+
+- **`topic: Topic`**: 必须传入当前操作上下文的主题对象。Hook 返回的操作函数将始终作用于这个主题的 `topic.id`。
+- **Redux `dispatch`**: Hook 内部使用 `useAppDispatch` 获取 `dispatch` 函数来调用 actions 和 thunks。
+
+## 相关 Hooks
+
+在同一文件中还定义了两个辅助 Hook:
+
+- **`useTopicMessages(topic: Topic)`**:
+
+ - 使用 `selectMessagesForTopic` selector 来获取并返回指定主题的消息列表。
+
+- **`useTopicLoading(topic: Topic)`**:
+ - 使用 `selectNewTopicLoading` selector 来获取并返回指定主题的加载状态。
+
+这些 Hook 可以与 `useMessageOperations` 结合使用,方便地在组件中获取消息数据、加载状态,并执行相关操作。
diff --git a/docs/technical/KnowledgeService.md b/docs/zh/references/services.md
similarity index 100%
rename from docs/technical/KnowledgeService.md
rename to docs/zh/references/services.md
diff --git a/electron-builder.yml b/electron-builder.yml
index 6b14548b75..5e63e7231d 100644
--- a/electron-builder.yml
+++ b/electron-builder.yml
@@ -97,7 +97,6 @@ mac:
entitlementsInherit: build/entitlements.mac.plist
notarize: false
artifactName: ${productName}-${version}-${arch}.${ext}
- minimumSystemVersion: "20.1.0" # 最低支持 macOS 11.0
extendInfo:
- NSCameraUsageDescription: Application requests access to the device's camera.
- NSMicrophoneUsageDescription: Application requests access to the device's microphone.
@@ -135,59 +134,108 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
- What's New in v1.7.0-beta.4
+ A New Era of Intelligence with Cherry Studio 1.7.1
- Major Changes:
- - UI Framework Upgrade: Improved performance and user experience with new design system
- - App Menu i18n: Menu now supports multiple languages and syncs with app language settings
+ Today we're releasing Cherry Studio 1.7.1 — our most ambitious update yet, introducing Agent: autonomous AI that thinks, plans, and acts.
- New Features:
- - AWS Bedrock API Key: Support Bedrock API key authentication with Extended Thinking (reasoning) capability
- - SophNet Provider: Added support for SophNet LLM provider
- - Auto Session Rename: Agent sessions automatically rename based on conversation topics
- - TopP Parameter: Added TopP parameter support for more precise model control
- - Reasoning Effort Control: Quick access to reasoning effort settings in input bar
+ For years, AI assistants have been reactive — waiting for your commands, responding to your questions. With Agent, we're changing that. Now, AI can truly work alongside you: understanding complex goals, breaking them into steps, and executing them independently.
- Improvements:
- - Topics & Sessions: Enhanced UI with better styling and smoother interactions
- - Quick Panel: Improved option visibility and control
- - Painting Models: Smarter model initialization with better defaults
- - System Shutdown: Better handling of shutdown events to prevent data loss
- - Smaller Package Size: Optimized build configuration for faster downloads
+ This is what we've been building toward. And it's just the beginning.
- Bug Fixes:
- - Fixed Perplexity provider support and API host formatting
- - Fixed CherryAI provider support and API host formatting
- - Fixed i18n translations for painting image size options
- - Fixed agent session message token usage tracking
- - Fixed prompt stream handling on completion or error
- - Fixed message API initialization issues
+ 🤖 Meet Agent
+ Imagine having a brilliant colleague who never sleeps. Give Agent a goal — write a report, analyze data, refactor code — and watch it work. It reasons through problems, breaks them into steps, calls the right tools, and adapts when things change.
+
+ - **Think → Plan → Act**: From goal to execution, fully autonomous
+ - **Deep Reasoning**: Multi-turn thinking that solves real problems
+ - **Tool Mastery**: File operations, web search, code execution, and more
+ - **Skill Plugins**: Extend with custom commands and capabilities
+ - **You Stay in Control**: Real-time approval for sensitive actions
+ - **Full Visibility**: Every thought, every decision, fully transparent
+
+ 🌐 Expanding Ecosystem
+ - **New Providers**: HuggingFace, Mistral, CherryIN, AI Gateway, Intel OVMS, Didi MCP
+ - **New Models**: Claude 4.5 Haiku, DeepSeek v3.2, GLM-4.6, Doubao, Ling series
+ - **MCP Integration**: Alibaba Cloud, ModelScope, Higress, MCP.so, TokenFlux and more
+
+ 📚 Smarter Knowledge Base
+ - **OpenMinerU**: Self-hosted document processing
+ - **Full-Text Search**: Find anything instantly across your notes
+ - **Enhanced Tool Selection**: Smarter configuration for better AI assistance
+
+ 📝 Notes, Reimagined
+ - Full-text search with highlighted results
+ - AI-powered smart rename
+ - Export as image
+ - Auto-wrap for tables
+
+ 🖼️ Image & OCR
+ - Intel OVMS painting capabilities
+ - Intel OpenVINO NPU-accelerated OCR
+
+ 🌍 Now in 10+ Languages
+ - Added German support
+ - Enhanced internationalization
+
+ ⚡ Faster & More Polished
+ - Electron 38 upgrade
+ - New MCP management interface
+ - Dozens of UI refinements
+
+ ❤️ Fully Open Source
+ Commercial restrictions removed. Cherry Studio now follows standard AGPL v3 — free for teams of any size.
+
+ The Agent Era is here. We can't wait to see what you'll create.
- v1.7.0-beta.4 新特性
+ Cherry Studio 1.7.1:开启智能新纪元
- 重大变更:
- - UI 框架升级:采用新设计系统,提升性能和用户体验
- - 应用菜单国际化:菜单支持多语言,并自动同步应用语言设置
+ 今天,我们正式发布 Cherry Studio 1.7.1 —— 迄今最具雄心的版本,带来全新的 Agent:能够自主思考、规划和行动的 AI。
- 新功能:
- - AWS Bedrock API 密钥:支持 Bedrock API 密钥身份验证,并支持扩展思考(推理)能力
- - SophNet 提供商:添加 SophNet LLM 提供商支持
- - 自动会话重命名:Agent 会话根据对话主题自动重命名
- - TopP 参数:添加 TopP 参数支持,更精确控制模型输出
+ 多年来,AI 助手一直是被动的——等待你的指令,回应你的问题。Agent 改变了这一切。现在,AI 能够真正与你并肩工作:理解复杂目标,将其拆解为步骤,并独立执行。
- 改进:
- - 主题和会话:增强 UI,改进样式和交互体验
- - 快速面板:改进选项可见性和控制
- - 绘图模型:更智能的模型初始化和更好的默认值
- - 系统关机:更好地处理关机事件,防止数据丢失
- - 更小的安装包:优化构建配置,下载更快
+ 这是我们一直在构建的未来。而这,仅仅是开始。
- 问题修复:
- - 修复 Perplexity 提供商支持和 API 主机格式化
- - 修复 CherryAI 提供商支持和 API 主机格式化
- - 修复绘图图像大小选项的 i18n 翻译
- - 修复 Agent 会话消息的 token 使用量跟踪
- - 修复完成或错误时的提示流处理
- - 修复消息 API 初始化问题
+ 🤖 认识 Agent
+ 想象一位永不疲倦的得力伙伴。给 Agent 一个目标——撰写报告、分析数据、重构代码——然后看它工作。它会推理问题、拆解步骤、调用工具,并在情况变化时灵活应对。
+
+ - **思考 → 规划 → 行动**:从目标到执行,全程自主
+ - **深度推理**:多轮思考,解决真实问题
+ - **工具大师**:文件操作、网络搜索、代码执行,样样精通
+ - **技能插件**:自定义命令,无限扩展
+ - **你掌控全局**:敏感操作,实时审批
+ - **完全透明**:每一步思考,每一个决策,清晰可见
+
+ 🌐 生态持续壮大
+ - **新增服务商**:Hugging Face、Mistral、Perplexity、SophNet、AI Gateway、Cerebras AI
+ - **新增模型**:Gemini 3、Gemini 3 Pro(支持图像预览)、GPT-5.1、Claude Opus 4.5
+ - **MCP 集成**:百炼、魔搭、Higress、MCP.so、TokenFlux 等平台
+
+ 📚 更智能的知识库
+ - **OpenMinerU**:本地自部署文档处理
+ - **全文搜索**:笔记内容一搜即达
+ - **增强工具选择**:更智能的配置,更好的 AI 协助
+
+ 📝 笔记,焕然一新
+ - 全文搜索,结果高亮
+ - AI 智能重命名
+ - 导出为图片
+ - 表格自动换行
+
+ 🖼️ 图像与 OCR
+ - Intel OVMS 绘图能力
+ - Intel OpenVINO NPU 加速 OCR
+
+ 🌍 支持 10+ 种语言
+ - 新增德语支持
+ - 全面增强国际化
+
+ ⚡ 更快、更精致
+ - 升级 Electron 38
+ - 新的 MCP 管理界面
+ - 数十处 UI 细节打磨
+
+ ❤️ 完全开源
+ 商用限制已移除。Cherry Studio 现遵循标准 AGPL v3 协议——任意规模团队均可自由使用。
+
+ Agent 纪元已至。期待你的创造。
diff --git a/electron.vite.config.ts b/electron.vite.config.ts
index b4914539c7..172d48ca9a 100644
--- a/electron.vite.config.ts
+++ b/electron.vite.config.ts
@@ -95,7 +95,8 @@ export default defineConfig({
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
'@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'),
'@cherrystudio/ai-core': resolve('packages/aiCore/src'),
- '@cherrystudio/extension-table-plus': resolve('packages/extension-table-plus/src')
+ '@cherrystudio/extension-table-plus': resolve('packages/extension-table-plus/src'),
+ '@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src')
}
},
optimizeDeps: {
diff --git a/eslint.config.mjs b/eslint.config.mjs
index fcc952ed65..64fdefa1dc 100644
--- a/eslint.config.mjs
+++ b/eslint.config.mjs
@@ -58,6 +58,7 @@ export default defineConfig([
'dist/**',
'out/**',
'local/**',
+ 'tests/**',
'.yarn/**',
'.gitignore',
'scripts/cloudflare-worker.js',
diff --git a/package.json b/package.json
index 5c41f0df65..fd5eb0151d 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "CherryStudio",
- "version": "1.7.0-beta.3",
+ "version": "1.7.1",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
@@ -58,9 +58,11 @@
"update:i18n": "dotenv -e .env -- tsx scripts/update-i18n.ts",
"auto:i18n": "dotenv -e .env -- tsx scripts/auto-translate-i18n.ts",
"update:languages": "tsx scripts/update-languages.ts",
+ "update:upgrade-config": "tsx scripts/update-app-upgrade-config.ts",
"test": "vitest run --silent",
"test:main": "vitest run --project main",
"test:renderer": "vitest run --project renderer",
+ "test:aicore": "vitest run --project aiCore",
"test:update": "yarn test:renderer --update",
"test:coverage": "vitest run --coverage --silent",
"test:ui": "vitest --ui",
@@ -73,17 +75,19 @@
"format:check": "biome format && biome lint",
"prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky",
"claude": "dotenv -e .env -- claude",
- "release:aicore:alpha": "yarn workspace @cherrystudio/ai-core version prerelease --immediate && yarn workspace @cherrystudio/ai-core npm publish --tag alpha --access public",
- "release:aicore:beta": "yarn workspace @cherrystudio/ai-core version prerelease --immediate && yarn workspace @cherrystudio/ai-core npm publish --tag beta --access public",
- "release:aicore": "yarn workspace @cherrystudio/ai-core version patch --immediate && yarn workspace @cherrystudio/ai-core npm publish --access public"
+ "release:aicore:alpha": "yarn workspace @cherrystudio/ai-core version prerelease --preid alpha --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --tag alpha --access public",
+ "release:aicore:beta": "yarn workspace @cherrystudio/ai-core version prerelease --preid beta --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --tag beta --access public",
+ "release:aicore": "yarn workspace @cherrystudio/ai-core version patch --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --access public",
+ "release:ai-sdk-provider": "yarn workspace @cherrystudio/ai-sdk-provider version patch --immediate && yarn workspace @cherrystudio/ai-sdk-provider build && yarn workspace @cherrystudio/ai-sdk-provider npm publish --access public"
},
"dependencies": {
- "@anthropic-ai/claude-agent-sdk": "patch:@anthropic-ai/claude-agent-sdk@npm%3A0.1.25#~/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.25-08bbabb5d3.patch",
+ "@anthropic-ai/claude-agent-sdk": "patch:@anthropic-ai/claude-agent-sdk@npm%3A0.1.53#~/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.53-4b77f4cf29.patch",
"@libsql/client": "0.14.0",
"@libsql/win32-x64-msvc": "^0.4.7",
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
"@paymoapp/electron-shutdown-handler": "^1.1.2",
"@strongtz/win32-arm64-msvc": "^0.4.7",
+ "emoji-picker-element-data": "^1",
"express": "^5.1.0",
"font-list": "^2.0.0",
"graceful-fs": "^4.2.11",
@@ -106,11 +110,17 @@
"@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
- "@ai-sdk/amazon-bedrock": "^3.0.42",
- "@ai-sdk/google-vertex": "^3.0.48",
- "@ai-sdk/huggingface": "patch:@ai-sdk/huggingface@npm%3A0.0.4#~/.yarn/patches/@ai-sdk-huggingface-npm-0.0.4-8080836bc1.patch",
- "@ai-sdk/mistral": "^2.0.19",
- "@ai-sdk/perplexity": "^2.0.13",
+ "@ai-sdk/amazon-bedrock": "^3.0.61",
+ "@ai-sdk/anthropic": "^2.0.49",
+ "@ai-sdk/cerebras": "^1.0.31",
+ "@ai-sdk/gateway": "^2.0.15",
+ "@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.43#~/.yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch",
+ "@ai-sdk/google-vertex": "^3.0.79",
+ "@ai-sdk/huggingface": "^0.0.10",
+ "@ai-sdk/mistral": "^2.0.24",
+ "@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.72#~/.yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch",
+ "@ai-sdk/perplexity": "^2.0.20",
+ "@ai-sdk/test-server": "^0.0.1",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
@@ -118,7 +128,7 @@
"@aws-sdk/client-bedrock-runtime": "^3.910.0",
"@aws-sdk/client-s3": "^3.910.0",
"@biomejs/biome": "2.2.4",
- "@cherrystudio/ai-core": "workspace:^1.0.0-alpha.18",
+ "@cherrystudio/ai-core": "workspace:^1.0.9",
"@cherrystudio/embedjs": "^0.1.31",
"@cherrystudio/embedjs-libsql": "^0.1.31",
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
@@ -132,7 +142,7 @@
"@cherrystudio/embedjs-ollama": "^0.1.31",
"@cherrystudio/embedjs-openai": "^0.1.31",
"@cherrystudio/extension-table-plus": "workspace:^",
- "@cherrystudio/openai": "^6.5.0",
+ "@cherrystudio/openai": "^6.9.0",
"@dnd-kit/core": "^6.3.1",
"@dnd-kit/modifiers": "^9.0.0",
"@dnd-kit/sortable": "^10.0.0",
@@ -152,18 +162,18 @@
"@langchain/core": "patch:@langchain/core@npm%3A1.0.2#~/.yarn/patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
"@langchain/openai": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@mistralai/mistralai": "^1.7.5",
- "@modelcontextprotocol/sdk": "^1.17.5",
+ "@modelcontextprotocol/sdk": "^1.23.0",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
- "@openrouter/ai-sdk-provider": "^1.2.0",
+ "@openrouter/ai-sdk-provider": "^1.2.8",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/core": "2.0.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
"@opentelemetry/sdk-trace-base": "^2.0.0",
"@opentelemetry/sdk-trace-node": "^2.0.0",
"@opentelemetry/sdk-trace-web": "^2.0.0",
- "@opeoginni/github-copilot-openai-compatible": "0.1.19",
- "@playwright/test": "^1.52.0",
+ "@opeoginni/github-copilot-openai-compatible": "^0.1.21",
+ "@playwright/test": "^1.55.1",
"@radix-ui/react-context-menu": "^2.2.16",
"@reduxjs/toolkit": "^2.2.5",
"@shikijs/markdown-it": "^3.12.0",
@@ -197,6 +207,7 @@
"@types/content-type": "^1.1.9",
"@types/cors": "^2.8.19",
"@types/diff": "^7",
+ "@types/dotenv": "^8.2.3",
"@types/express": "^5",
"@types/fs-extra": "^11",
"@types/he": "^1",
@@ -208,8 +219,8 @@
"@types/mime-types": "^3",
"@types/node": "^22.17.1",
"@types/pako": "^1.0.2",
- "@types/react": "^19.0.12",
- "@types/react-dom": "^19.0.4",
+ "@types/react": "^19.2.7",
+ "@types/react-dom": "^19.2.3",
"@types/react-infinite-scroll-component": "^5.0.0",
"@types/react-transition-group": "^4.4.12",
"@types/react-window": "^1",
@@ -231,7 +242,7 @@
"@viz-js/lang-dot": "^1.0.5",
"@viz-js/viz": "^3.14.0",
"@xyflow/react": "^12.4.4",
- "ai": "^5.0.76",
+ "ai": "^5.0.98",
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
@@ -241,7 +252,7 @@
"check-disk-space": "3.4.0",
"cheerio": "^1.1.2",
"chokidar": "^4.0.3",
- "claude-code-plugins": "1.0.1",
+ "claude-code-plugins": "1.0.3",
"cli-progress": "^3.12.0",
"clsx": "^2.1.1",
"code-inspector-plugin": "^0.20.14",
@@ -257,12 +268,12 @@
"dotenv-cli": "^7.4.2",
"drizzle-kit": "^0.31.4",
"drizzle-orm": "^0.44.5",
- "electron": "38.4.0",
- "electron-builder": "26.0.15",
+ "electron": "38.7.0",
+ "electron-builder": "26.1.0",
"electron-devtools-installer": "^3.2.0",
"electron-reload": "^2.0.0-alpha.1",
"electron-store": "^8.2.0",
- "electron-updater": "6.6.4",
+ "electron-updater": "patch:electron-updater@npm%3A6.7.0#~/.yarn/patches/electron-updater-npm-6.7.0-47b11bb0d4.patch",
"electron-vite": "4.0.1",
"electron-window-state": "^5.0.3",
"emittery": "^1.0.3",
@@ -307,12 +318,12 @@
"motion": "^12.10.5",
"notion-helper": "^1.3.22",
"npx-scope-finder": "^1.2.0",
+ "ollama-ai-provider-v2": "^1.5.5",
"oxlint": "^1.22.0",
"oxlint-tsgolint": "^0.2.0",
"p-queue": "^8.1.0",
"pdf-lib": "^1.17.1",
"pdf-parse": "^1.1.1",
- "playwright": "^1.55.1",
"proxy-agent": "^6.5.0",
"react": "^19.2.0",
"react-dom": "^19.2.0",
@@ -379,13 +390,11 @@
"@codemirror/lint": "6.8.5",
"@codemirror/view": "6.38.1",
"@langchain/core@npm:^0.3.26": "patch:@langchain/core@npm%3A1.0.2#~/.yarn/patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
- "app-builder-lib@npm:26.0.13": "patch:app-builder-lib@npm%3A26.0.13#~/.yarn/patches/app-builder-lib-npm-26.0.13-a064c9e1d0.patch",
- "app-builder-lib@npm:26.0.15": "patch:app-builder-lib@npm%3A26.0.15#~/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch",
"atomically@npm:^1.7.0": "patch:atomically@npm%3A1.7.0#~/.yarn/patches/atomically-npm-1.7.0-e742e5293b.patch",
"esbuild": "^0.25.0",
"file-stream-rotator@npm:^0.6.1": "patch:file-stream-rotator@npm%3A0.6.1#~/.yarn/patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch",
"libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch",
- "node-abi": "4.12.0",
+ "node-abi": "4.24.0",
"openai@npm:^4.77.0": "npm:@cherrystudio/openai@6.5.0",
"openai@npm:^4.87.3": "npm:@cherrystudio/openai@6.5.0",
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
@@ -394,7 +403,6 @@
"undici": "6.21.2",
"vite": "npm:rolldown-vite@7.1.5",
"tesseract.js@npm:*": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch",
- "@ai-sdk/google@npm:2.0.23": "patch:@ai-sdk/google@npm%3A2.0.23#~/.yarn/patches/@ai-sdk-google-npm-2.0.23-81682e07b0.patch",
"@ai-sdk/openai@npm:^2.0.52": "patch:@ai-sdk/openai@npm%3A2.0.52#~/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch",
"@img/sharp-darwin-arm64": "0.34.3",
"@img/sharp-darwin-x64": "0.34.3",
@@ -405,7 +413,10 @@
"openai@npm:5.12.2": "npm:@cherrystudio/openai@6.5.0",
"@langchain/openai@npm:>=0.1.0 <0.6.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
- "@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch"
+ "@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
+ "@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.72#~/.yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch",
+ "@ai-sdk/google@npm:^2.0.40": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
+ "@ai-sdk/openai-compatible@npm:^1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch"
},
"packageManager": "yarn@4.9.1",
"lint-staged": {
diff --git a/packages/ai-sdk-provider/README.md b/packages/ai-sdk-provider/README.md
new file mode 100644
index 0000000000..ecd9df2923
--- /dev/null
+++ b/packages/ai-sdk-provider/README.md
@@ -0,0 +1,39 @@
+# @cherrystudio/ai-sdk-provider
+
+CherryIN provider bundle for the [Vercel AI SDK](https://ai-sdk.dev/).
+It exposes the CherryIN OpenAI-compatible entrypoints and dynamically routes Anthropic and Gemini model ids to their CherryIN upstream equivalents.
+
+## Installation
+
+```bash
+npm install ai @cherrystudio/ai-sdk-provider @ai-sdk/anthropic @ai-sdk/google @ai-sdk/openai
+# or
+yarn add ai @cherrystudio/ai-sdk-provider @ai-sdk/anthropic @ai-sdk/google @ai-sdk/openai
+```
+
+> **Note**: This package requires peer dependencies `ai`, `@ai-sdk/anthropic`, `@ai-sdk/google`, and `@ai-sdk/openai` to be installed.
+
+## Usage
+
+```ts
+import { createCherryIn, cherryIn } from '@cherrystudio/ai-sdk-provider'
+
+const cherryInProvider = createCherryIn({
+ apiKey: process.env.CHERRYIN_API_KEY,
+ // optional overrides:
+ // baseURL: 'https://open.cherryin.net/v1',
+ // anthropicBaseURL: 'https://open.cherryin.net/anthropic',
+ // geminiBaseURL: 'https://open.cherryin.net/gemini/v1beta',
+})
+
+// Chat models will auto-route based on the model id prefix:
+const openaiModel = cherryInProvider.chat('gpt-4o-mini')
+const anthropicModel = cherryInProvider.chat('claude-3-5-sonnet-latest')
+const geminiModel = cherryInProvider.chat('gemini-2.0-pro-exp')
+
+const { text } = await openaiModel.invoke('Hello CherryIN!')
+```
+
+The provider also exposes `completion`, `responses`, `embedding`, `image`, `transcription`, and `speech` helpers aligned with the upstream APIs.
+
+See [AI SDK docs](https://ai-sdk.dev/providers/community-providers/custom-providers) for configuring custom providers.
diff --git a/packages/ai-sdk-provider/package.json b/packages/ai-sdk-provider/package.json
new file mode 100644
index 0000000000..25864f3b1f
--- /dev/null
+++ b/packages/ai-sdk-provider/package.json
@@ -0,0 +1,65 @@
+{
+ "name": "@cherrystudio/ai-sdk-provider",
+ "version": "0.1.3",
+ "description": "Cherry Studio AI SDK provider bundle with CherryIN routing.",
+ "keywords": [
+ "ai-sdk",
+ "provider",
+ "cherryin",
+ "vercel-ai-sdk",
+ "cherry-studio"
+ ],
+ "author": "Cherry Studio",
+ "license": "MIT",
+ "homepage": "https://github.com/CherryHQ/cherry-studio",
+ "repository": {
+ "type": "git",
+ "url": "git+https://github.com/CherryHQ/cherry-studio.git",
+ "directory": "packages/ai-sdk-provider"
+ },
+ "bugs": {
+ "url": "https://github.com/CherryHQ/cherry-studio/issues"
+ },
+ "type": "module",
+ "main": "dist/index.cjs",
+ "module": "dist/index.js",
+ "types": "dist/index.d.ts",
+ "files": [
+ "dist"
+ ],
+ "scripts": {
+ "build": "tsdown",
+ "dev": "tsc -w",
+ "clean": "rm -rf dist",
+ "test": "vitest run",
+ "test:watch": "vitest"
+ },
+ "peerDependencies": {
+ "@ai-sdk/anthropic": "^2.0.29",
+ "@ai-sdk/google": "^2.0.23",
+ "@ai-sdk/openai": "^2.0.64",
+ "ai": "^5.0.26"
+ },
+ "dependencies": {
+ "@ai-sdk/openai-compatible": "^1.0.28",
+ "@ai-sdk/provider": "^2.0.0",
+ "@ai-sdk/provider-utils": "^3.0.17"
+ },
+ "devDependencies": {
+ "tsdown": "^0.13.3",
+ "typescript": "^5.8.2",
+ "vitest": "^3.2.4"
+ },
+ "sideEffects": false,
+ "engines": {
+ "node": ">=18.0.0"
+ },
+ "exports": {
+ ".": {
+ "types": "./dist/index.d.ts",
+ "import": "./dist/index.js",
+ "require": "./dist/index.cjs",
+ "default": "./dist/index.js"
+ }
+ }
+}
diff --git a/packages/ai-sdk-provider/src/cherryin-provider.ts b/packages/ai-sdk-provider/src/cherryin-provider.ts
new file mode 100644
index 0000000000..33ec1a2a3a
--- /dev/null
+++ b/packages/ai-sdk-provider/src/cherryin-provider.ts
@@ -0,0 +1,348 @@
+import { AnthropicMessagesLanguageModel } from '@ai-sdk/anthropic/internal'
+import { GoogleGenerativeAILanguageModel } from '@ai-sdk/google/internal'
+import type { OpenAIProviderSettings } from '@ai-sdk/openai'
+import {
+ OpenAICompletionLanguageModel,
+ OpenAIEmbeddingModel,
+ OpenAIImageModel,
+ OpenAIResponsesLanguageModel,
+ OpenAISpeechModel,
+ OpenAITranscriptionModel
+} from '@ai-sdk/openai/internal'
+import { OpenAICompatibleChatLanguageModel } from '@ai-sdk/openai-compatible'
+import {
+ type EmbeddingModelV2,
+ type ImageModelV2,
+ type LanguageModelV2,
+ type ProviderV2,
+ type SpeechModelV2,
+ type TranscriptionModelV2
+} from '@ai-sdk/provider'
+import type { FetchFunction } from '@ai-sdk/provider-utils'
+import { loadApiKey, withoutTrailingSlash } from '@ai-sdk/provider-utils'
+
+export const CHERRYIN_PROVIDER_NAME = 'cherryin' as const
+export const DEFAULT_CHERRYIN_BASE_URL = 'https://open.cherryin.net/v1'
+export const DEFAULT_CHERRYIN_ANTHROPIC_BASE_URL = 'https://open.cherryin.net/v1'
+export const DEFAULT_CHERRYIN_GEMINI_BASE_URL = 'https://open.cherryin.net/v1beta/models'
+
+const ANTHROPIC_PREFIX = /^anthropic\//i
+const GEMINI_PREFIX = /^google\//i
+// const GEMINI_EXCLUDED_SUFFIXES = ['-nothink', '-search']
+
+type HeaderValue = string | undefined
+
+type HeadersInput = Record | (() => Record)
+
+export interface CherryInProviderSettings {
+ /**
+ * CherryIN API key.
+ *
+ * If omitted, the provider will read the `CHERRYIN_API_KEY` environment variable.
+ */
+ apiKey?: string
+ /**
+ * Optional custom fetch implementation.
+ */
+ fetch?: FetchFunction
+ /**
+ * Base URL for OpenAI-compatible CherryIN endpoints.
+ *
+ * Defaults to `https://open.cherryin.net/v1`.
+ */
+ baseURL?: string
+ /**
+ * Base URL for Anthropic-compatible endpoints.
+ *
+ * Defaults to `https://open.cherryin.net/anthropic`.
+ */
+ anthropicBaseURL?: string
+ /**
+ * Base URL for Gemini-compatible endpoints.
+ *
+ * Defaults to `https://open.cherryin.net/gemini/v1beta`.
+ */
+ geminiBaseURL?: string
+ /**
+ * Optional static headers applied to every request.
+ */
+ headers?: HeadersInput
+ /**
+ * Optional endpoint type to distinguish different endpoint behaviors.
+ * "image-generation" is also openai endpoint, but specifically for image generation.
+ */
+ endpointType?: 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
+}
+
+export interface CherryInProvider extends ProviderV2 {
+ (modelId: string, settings?: OpenAIProviderSettings): LanguageModelV2
+ languageModel(modelId: string, settings?: OpenAIProviderSettings): LanguageModelV2
+ chat(modelId: string, settings?: OpenAIProviderSettings): LanguageModelV2
+ responses(modelId: string): LanguageModelV2
+ completion(modelId: string, settings?: OpenAIProviderSettings): LanguageModelV2
+ embedding(modelId: string, settings?: OpenAIProviderSettings): EmbeddingModelV2
+ textEmbedding(modelId: string, settings?: OpenAIProviderSettings): EmbeddingModelV2
+ textEmbeddingModel(modelId: string, settings?: OpenAIProviderSettings): EmbeddingModelV2
+ image(modelId: string, settings?: OpenAIProviderSettings): ImageModelV2
+ imageModel(modelId: string, settings?: OpenAIProviderSettings): ImageModelV2
+ transcription(modelId: string): TranscriptionModelV2
+ transcriptionModel(modelId: string): TranscriptionModelV2
+ speech(modelId: string): SpeechModelV2
+ speechModel(modelId: string): SpeechModelV2
+}
+
+const resolveApiKey = (options: CherryInProviderSettings): string =>
+ loadApiKey({
+ apiKey: options.apiKey,
+ environmentVariableName: 'CHERRYIN_API_KEY',
+ description: 'CherryIN'
+ })
+
+const isAnthropicModel = (modelId: string) => ANTHROPIC_PREFIX.test(modelId)
+const isGeminiModel = (modelId: string) => GEMINI_PREFIX.test(modelId)
+
+const createCustomFetch = (originalFetch?: any) => {
+ return async (url: string, options: any) => {
+ if (options?.body) {
+ try {
+ const body = JSON.parse(options.body)
+ if (body.tools && Array.isArray(body.tools) && body.tools.length === 0 && body.tool_choice) {
+ delete body.tool_choice
+ options.body = JSON.stringify(body)
+ }
+ } catch (error) {
+ // ignore error
+ }
+ }
+
+ return originalFetch ? originalFetch(url, options) : fetch(url, options)
+ }
+}
+class CherryInOpenAIChatLanguageModel extends OpenAICompatibleChatLanguageModel {
+ constructor(modelId: string, settings: any) {
+ super(modelId, {
+ ...settings,
+ fetch: createCustomFetch(settings.fetch)
+ })
+ }
+}
+
+const resolveConfiguredHeaders = (headers?: HeadersInput): Record => {
+ if (typeof headers === 'function') {
+ return { ...headers() }
+ }
+ return headers ? { ...headers } : {}
+}
+
+const toBearerToken = (authorization?: string) => (authorization ? authorization.replace(/^Bearer\s+/i, '') : undefined)
+
+const createJsonHeadersGetter = (options: CherryInProviderSettings): (() => Record) => {
+ return () => ({
+ Authorization: `Bearer ${resolveApiKey(options)}`,
+ 'Content-Type': 'application/json',
+ ...resolveConfiguredHeaders(options.headers)
+ })
+}
+
+const createAuthHeadersGetter = (options: CherryInProviderSettings): (() => Record) => {
+ return () => ({
+ Authorization: `Bearer ${resolveApiKey(options)}`,
+ ...resolveConfiguredHeaders(options.headers)
+ })
+}
+
+export const createCherryIn = (options: CherryInProviderSettings = {}): CherryInProvider => {
+ const {
+ baseURL = DEFAULT_CHERRYIN_BASE_URL,
+ anthropicBaseURL = DEFAULT_CHERRYIN_ANTHROPIC_BASE_URL,
+ geminiBaseURL = DEFAULT_CHERRYIN_GEMINI_BASE_URL,
+ fetch,
+ endpointType
+ } = options
+
+ const getJsonHeaders = createJsonHeadersGetter(options)
+ const getAuthHeaders = createAuthHeadersGetter(options)
+
+ const url = ({ path }: { path: string; modelId: string }) => `${withoutTrailingSlash(baseURL)}${path}`
+
+ const createAnthropicModel = (modelId: string) =>
+ new AnthropicMessagesLanguageModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.anthropic`,
+ baseURL: anthropicBaseURL,
+ headers: () => {
+ const headers = getJsonHeaders()
+ const apiKey = toBearerToken(headers.Authorization)
+ return {
+ ...headers,
+ 'x-api-key': apiKey
+ }
+ },
+ fetch,
+ supportedUrls: () => ({
+ 'image/*': [/^https?:\/\/.*$/]
+ })
+ })
+
+ const createGeminiModel = (modelId: string) =>
+ new GoogleGenerativeAILanguageModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.google`,
+ baseURL: geminiBaseURL,
+ headers: () => {
+ const headers = getJsonHeaders()
+ const apiKey = toBearerToken(headers.Authorization)
+ return {
+ ...headers,
+ 'x-goog-api-key': apiKey
+ }
+ },
+ fetch,
+ generateId: () => `${CHERRYIN_PROVIDER_NAME}-${Date.now()}`,
+ supportedUrls: () => ({})
+ })
+
+ const createOpenAIChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) =>
+ new CherryInOpenAIChatLanguageModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.openai-chat`,
+ url,
+ headers: () => ({
+ ...getJsonHeaders(),
+ ...settings.headers
+ }),
+ fetch
+ })
+
+ const createChatModelByModelId = (modelId: string, settings: OpenAIProviderSettings = {}) => {
+ if (isAnthropicModel(modelId)) {
+ return createAnthropicModel(modelId)
+ }
+ if (isGeminiModel(modelId)) {
+ return createGeminiModel(modelId)
+ }
+ return new OpenAIResponsesLanguageModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.openai`,
+ url,
+ headers: () => ({
+ ...getJsonHeaders(),
+ ...settings.headers
+ }),
+ fetch
+ })
+ }
+
+ const createChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) => {
+ if (!endpointType) return createChatModelByModelId(modelId, settings)
+ switch (endpointType) {
+ case 'anthropic':
+ return createAnthropicModel(modelId)
+ case 'gemini':
+ return createGeminiModel(modelId)
+ case 'openai':
+ return createOpenAIChatModel(modelId)
+ case 'openai-response':
+ default:
+ return new OpenAIResponsesLanguageModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.openai`,
+ url,
+ headers: () => ({
+ ...getJsonHeaders(),
+ ...settings.headers
+ }),
+ fetch
+ })
+ }
+ }
+
+ const createCompletionModel = (modelId: string, settings: OpenAIProviderSettings = {}) =>
+ new OpenAICompletionLanguageModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.completion`,
+ url,
+ headers: () => ({
+ ...getJsonHeaders(),
+ ...settings.headers
+ }),
+ fetch
+ })
+
+ const createEmbeddingModel = (modelId: string, settings: OpenAIProviderSettings = {}) =>
+ new OpenAIEmbeddingModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.embeddings`,
+ url,
+ headers: () => ({
+ ...getJsonHeaders(),
+ ...settings.headers
+ }),
+ fetch
+ })
+
+ const createResponsesModel = (modelId: string) =>
+ new OpenAIResponsesLanguageModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.responses`,
+ url,
+ headers: () => ({
+ ...getJsonHeaders()
+ }),
+ fetch
+ })
+
+ const createImageModel = (modelId: string, settings: OpenAIProviderSettings = {}) =>
+ new OpenAIImageModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.image`,
+ url,
+ headers: () => ({
+ ...getJsonHeaders(),
+ ...settings.headers
+ }),
+ fetch
+ })
+
+ const createTranscriptionModel = (modelId: string) =>
+ new OpenAITranscriptionModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.transcription`,
+ url,
+ headers: () => ({
+ ...getAuthHeaders()
+ }),
+ fetch
+ })
+
+ const createSpeechModel = (modelId: string) =>
+ new OpenAISpeechModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.speech`,
+ url,
+ headers: () => ({
+ ...getJsonHeaders()
+ }),
+ fetch
+ })
+
+ const provider: CherryInProvider = function (modelId: string, settings?: OpenAIProviderSettings) {
+ if (new.target) {
+ throw new Error('CherryIN provider function cannot be called with the new keyword.')
+ }
+
+ return createChatModel(modelId, settings)
+ }
+
+ provider.languageModel = createChatModel
+ provider.chat = createOpenAIChatModel
+
+ provider.responses = createResponsesModel
+ provider.completion = createCompletionModel
+
+ provider.embedding = createEmbeddingModel
+ provider.textEmbedding = createEmbeddingModel
+ provider.textEmbeddingModel = createEmbeddingModel
+
+ provider.image = createImageModel
+ provider.imageModel = createImageModel
+
+ provider.transcription = createTranscriptionModel
+ provider.transcriptionModel = createTranscriptionModel
+
+ provider.speech = createSpeechModel
+ provider.speechModel = createSpeechModel
+
+ return provider
+}
+
+export const cherryIn = createCherryIn()
diff --git a/packages/ai-sdk-provider/src/index.ts b/packages/ai-sdk-provider/src/index.ts
new file mode 100644
index 0000000000..d397dd5af5
--- /dev/null
+++ b/packages/ai-sdk-provider/src/index.ts
@@ -0,0 +1 @@
+export * from './cherryin-provider'
diff --git a/packages/ai-sdk-provider/tsconfig.json b/packages/ai-sdk-provider/tsconfig.json
new file mode 100644
index 0000000000..26ee731bb7
--- /dev/null
+++ b/packages/ai-sdk-provider/tsconfig.json
@@ -0,0 +1,19 @@
+{
+ "compilerOptions": {
+ "allowSyntheticDefaultImports": true,
+ "declaration": true,
+ "esModuleInterop": true,
+ "forceConsistentCasingInFileNames": true,
+ "module": "ESNext",
+ "moduleResolution": "bundler",
+ "noEmitOnError": false,
+ "outDir": "./dist",
+ "resolveJsonModule": true,
+ "rootDir": "./src",
+ "skipLibCheck": true,
+ "strict": true,
+ "target": "ES2020"
+ },
+ "exclude": ["node_modules", "dist"],
+ "include": ["src/**/*"]
+}
diff --git a/packages/ai-sdk-provider/tsdown.config.ts b/packages/ai-sdk-provider/tsdown.config.ts
new file mode 100644
index 0000000000..0e07d34cac
--- /dev/null
+++ b/packages/ai-sdk-provider/tsdown.config.ts
@@ -0,0 +1,12 @@
+import { defineConfig } from 'tsdown'
+
+export default defineConfig({
+ entry: {
+ index: 'src/index.ts'
+ },
+ outDir: 'dist',
+ format: ['esm', 'cjs'],
+ clean: true,
+ dts: true,
+ tsconfig: 'tsconfig.json'
+})
diff --git a/packages/aiCore/README.md b/packages/aiCore/README.md
index 4ca5ea6640..1380019094 100644
--- a/packages/aiCore/README.md
+++ b/packages/aiCore/README.md
@@ -71,7 +71,7 @@ Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口
## 安装
```bash
-npm install @cherrystudio/ai-core ai
+npm install @cherrystudio/ai-core ai @ai-sdk/google @ai-sdk/openai
```
### React Native
diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json
index 8310b4164c..a648dcf3c7 100644
--- a/packages/aiCore/package.json
+++ b/packages/aiCore/package.json
@@ -1,6 +1,6 @@
{
"name": "@cherrystudio/ai-core",
- "version": "1.0.1",
+ "version": "1.0.9",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "dist/index.js",
"module": "dist/index.mjs",
@@ -33,17 +33,19 @@
},
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
"peerDependencies": {
+ "@ai-sdk/google": "^2.0.36",
+ "@ai-sdk/openai": "^2.0.64",
+ "@cherrystudio/ai-sdk-provider": "^0.1.3",
"ai": "^5.0.26"
},
"dependencies": {
- "@ai-sdk/anthropic": "^2.0.32",
- "@ai-sdk/azure": "^2.0.53",
- "@ai-sdk/deepseek": "^1.0.23",
- "@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.52#~/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch",
- "@ai-sdk/openai-compatible": "^1.0.22",
+ "@ai-sdk/anthropic": "^2.0.49",
+ "@ai-sdk/azure": "^2.0.74",
+ "@ai-sdk/deepseek": "^1.0.31",
+ "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
"@ai-sdk/provider": "^2.0.0",
- "@ai-sdk/provider-utils": "^3.0.12",
- "@ai-sdk/xai": "^2.0.26",
+ "@ai-sdk/provider-utils": "^3.0.17",
+ "@ai-sdk/xai": "^2.0.36",
"zod": "^4.1.5"
},
"devDependencies": {
diff --git a/packages/aiCore/src/__tests__/fixtures/mock-providers.ts b/packages/aiCore/src/__tests__/fixtures/mock-providers.ts
new file mode 100644
index 0000000000..e8ec2a4a05
--- /dev/null
+++ b/packages/aiCore/src/__tests__/fixtures/mock-providers.ts
@@ -0,0 +1,180 @@
+/**
+ * Mock Provider Instances
+ * Provides mock implementations for all supported AI providers
+ */
+
+import type { ImageModelV2, LanguageModelV2 } from '@ai-sdk/provider'
+import { vi } from 'vitest'
+
+/**
+ * Creates a mock language model with customizable behavior
+ */
+export function createMockLanguageModel(overrides?: Partial): LanguageModelV2 {
+ return {
+ specificationVersion: 'v1',
+ provider: 'mock-provider',
+ modelId: 'mock-model',
+ defaultObjectGenerationMode: 'tool',
+
+ doGenerate: vi.fn().mockResolvedValue({
+ text: 'Mock response text',
+ finishReason: 'stop',
+ usage: {
+ promptTokens: 10,
+ completionTokens: 20,
+ totalTokens: 30
+ },
+ rawCall: { rawPrompt: null, rawSettings: {} },
+ rawResponse: { headers: {} },
+ warnings: []
+ }),
+
+ doStream: vi.fn().mockReturnValue({
+ stream: (async function* () {
+ yield {
+ type: 'text-delta',
+ textDelta: 'Mock '
+ }
+ yield {
+ type: 'text-delta',
+ textDelta: 'streaming '
+ }
+ yield {
+ type: 'text-delta',
+ textDelta: 'response'
+ }
+ yield {
+ type: 'finish',
+ finishReason: 'stop',
+ usage: {
+ promptTokens: 10,
+ completionTokens: 15,
+ totalTokens: 25
+ }
+ }
+ })(),
+ rawCall: { rawPrompt: null, rawSettings: {} },
+ rawResponse: { headers: {} },
+ warnings: []
+ }),
+
+ ...overrides
+ } as LanguageModelV2
+}
+
+/**
+ * Creates a mock image model with customizable behavior
+ */
+export function createMockImageModel(overrides?: Partial): ImageModelV2 {
+ return {
+ specificationVersion: 'v2',
+ provider: 'mock-provider',
+ modelId: 'mock-image-model',
+
+ doGenerate: vi.fn().mockResolvedValue({
+ images: [
+ {
+ base64: 'mock-base64-image-data',
+ uint8Array: new Uint8Array([1, 2, 3, 4, 5]),
+ mimeType: 'image/png'
+ }
+ ],
+ warnings: []
+ }),
+
+ ...overrides
+ } as ImageModelV2
+}
+
+/**
+ * Mock provider configurations for testing
+ */
+export const mockProviderConfigs = {
+ openai: {
+ apiKey: 'sk-test-openai-key-123456789',
+ baseURL: 'https://api.openai.com/v1',
+ organization: 'test-org'
+ },
+
+ anthropic: {
+ apiKey: 'sk-ant-test-key-123456789',
+ baseURL: 'https://api.anthropic.com'
+ },
+
+ google: {
+ apiKey: 'test-google-api-key-123456789',
+ baseURL: 'https://generativelanguage.googleapis.com/v1'
+ },
+
+ xai: {
+ apiKey: 'xai-test-key-123456789',
+ baseURL: 'https://api.x.ai/v1'
+ },
+
+ azure: {
+ apiKey: 'test-azure-key-123456789',
+ resourceName: 'test-resource',
+ deployment: 'test-deployment'
+ },
+
+ deepseek: {
+ apiKey: 'sk-test-deepseek-key-123456789',
+ baseURL: 'https://api.deepseek.com/v1'
+ },
+
+ openrouter: {
+ apiKey: 'sk-or-test-key-123456789',
+ baseURL: 'https://openrouter.ai/api/v1'
+ },
+
+ huggingface: {
+ apiKey: 'hf_test_key_123456789',
+ baseURL: 'https://api-inference.huggingface.co'
+ },
+
+ 'openai-compatible': {
+ apiKey: 'test-compatible-key-123456789',
+ baseURL: 'https://api.example.com/v1',
+ name: 'test-provider'
+ },
+
+ 'openai-chat': {
+ apiKey: 'sk-test-chat-key-123456789',
+ baseURL: 'https://api.openai.com/v1'
+ }
+} as const
+
+/**
+ * Mock provider instances for testing
+ */
+export const mockProviderInstances = {
+ openai: {
+ name: 'openai-mock',
+ languageModel: createMockLanguageModel({ provider: 'openai', modelId: 'gpt-4' }),
+ imageModel: createMockImageModel({ provider: 'openai', modelId: 'dall-e-3' })
+ },
+
+ anthropic: {
+ name: 'anthropic-mock',
+ languageModel: createMockLanguageModel({ provider: 'anthropic', modelId: 'claude-3-5-sonnet-20241022' })
+ },
+
+ google: {
+ name: 'google-mock',
+ languageModel: createMockLanguageModel({ provider: 'google', modelId: 'gemini-2.0-flash-exp' }),
+ imageModel: createMockImageModel({ provider: 'google', modelId: 'imagen-3.0-generate-001' })
+ },
+
+ xai: {
+ name: 'xai-mock',
+ languageModel: createMockLanguageModel({ provider: 'xai', modelId: 'grok-2-latest' }),
+ imageModel: createMockImageModel({ provider: 'xai', modelId: 'grok-2-image-latest' })
+ },
+
+ deepseek: {
+ name: 'deepseek-mock',
+ languageModel: createMockLanguageModel({ provider: 'deepseek', modelId: 'deepseek-chat' })
+ }
+}
+
+export type ProviderId = keyof typeof mockProviderConfigs
diff --git a/packages/aiCore/src/__tests__/fixtures/mock-responses.ts b/packages/aiCore/src/__tests__/fixtures/mock-responses.ts
new file mode 100644
index 0000000000..388a4f7fd5
--- /dev/null
+++ b/packages/aiCore/src/__tests__/fixtures/mock-responses.ts
@@ -0,0 +1,238 @@
+/**
+ * Mock Responses
+ * Provides realistic mock responses for all provider types
+ */
+
+import type { ModelMessage, Tool } from 'ai'
+import { jsonSchema } from 'ai'
+
+/**
+ * Standard test messages for all scenarios
+ */
+export const testMessages: Record = {
+ simple: [{ role: 'user' as const, content: 'Hello, how are you?' }],
+
+ conversation: [
+ { role: 'user' as const, content: 'What is the capital of France?' },
+ { role: 'assistant' as const, content: 'The capital of France is Paris.' },
+ { role: 'user' as const, content: 'What is its population?' }
+ ],
+
+ withSystem: [
+ { role: 'system' as const, content: 'You are a helpful assistant that provides concise answers.' },
+ { role: 'user' as const, content: 'Explain quantum computing in one sentence.' }
+ ],
+
+ withImages: [
+ {
+ role: 'user' as const,
+ content: [
+ { type: 'text' as const, text: 'What is in this image?' },
+ {
+ type: 'image' as const,
+ image:
+ ''
+ }
+ ]
+ }
+ ],
+
+ toolUse: [{ role: 'user' as const, content: 'What is the weather in San Francisco?' }],
+
+ multiTurn: [
+ { role: 'user' as const, content: 'Can you help me with a math problem?' },
+ { role: 'assistant' as const, content: 'Of course! What math problem would you like help with?' },
+ { role: 'user' as const, content: 'What is 15 * 23?' },
+ { role: 'assistant' as const, content: '15 * 23 = 345' },
+ { role: 'user' as const, content: 'Now divide that by 5' }
+ ]
+}
+
+/**
+ * Standard test tools for tool calling scenarios
+ */
+export const testTools: Record = {
+ getWeather: {
+ description: 'Get the current weather in a given location',
+ inputSchema: jsonSchema({
+ type: 'object',
+ properties: {
+ location: {
+ type: 'string',
+ description: 'The city and state, e.g. San Francisco, CA'
+ },
+ unit: {
+ type: 'string',
+ enum: ['celsius', 'fahrenheit'],
+ description: 'The temperature unit to use'
+ }
+ },
+ required: ['location']
+ }),
+ execute: async ({ location, unit = 'fahrenheit' }) => {
+ return {
+ location,
+ temperature: unit === 'celsius' ? 22 : 72,
+ unit,
+ condition: 'sunny'
+ }
+ }
+ },
+
+ calculate: {
+ description: 'Perform a mathematical calculation',
+ inputSchema: jsonSchema({
+ type: 'object',
+ properties: {
+ operation: {
+ type: 'string',
+ enum: ['add', 'subtract', 'multiply', 'divide'],
+ description: 'The operation to perform'
+ },
+ a: {
+ type: 'number',
+ description: 'The first number'
+ },
+ b: {
+ type: 'number',
+ description: 'The second number'
+ }
+ },
+ required: ['operation', 'a', 'b']
+ }),
+ execute: async ({ operation, a, b }) => {
+ const operations = {
+ add: (x: number, y: number) => x + y,
+ subtract: (x: number, y: number) => x - y,
+ multiply: (x: number, y: number) => x * y,
+ divide: (x: number, y: number) => x / y
+ }
+ return { result: operations[operation as keyof typeof operations](a, b) }
+ }
+ },
+
+ searchDatabase: {
+ description: 'Search for information in a database',
+ inputSchema: jsonSchema({
+ type: 'object',
+ properties: {
+ query: {
+ type: 'string',
+ description: 'The search query'
+ },
+ limit: {
+ type: 'number',
+ description: 'Maximum number of results to return',
+ default: 10
+ }
+ },
+ required: ['query']
+ }),
+ execute: async ({ query, limit = 10 }) => {
+ return {
+ results: [
+ { id: 1, title: `Result 1 for ${query}`, relevance: 0.95 },
+ { id: 2, title: `Result 2 for ${query}`, relevance: 0.87 }
+ ].slice(0, limit)
+ }
+ }
+ }
+}
+
+/**
+ * Mock complete responses for non-streaming scenarios
+ * Note: AI SDK v5 uses inputTokens/outputTokens instead of promptTokens/completionTokens
+ */
+export const mockCompleteResponses = {
+ simple: {
+ text: 'This is a simple response.',
+ finishReason: 'stop' as const,
+ usage: {
+ inputTokens: 15,
+ outputTokens: 8,
+ totalTokens: 23
+ }
+ },
+
+ withToolCalls: {
+ text: 'I will check the weather for you.',
+ toolCalls: [
+ {
+ toolCallId: 'call_456',
+ toolName: 'getWeather',
+ args: { location: 'New York, NY', unit: 'celsius' }
+ }
+ ],
+ finishReason: 'tool-calls' as const,
+ usage: {
+ inputTokens: 25,
+ outputTokens: 12,
+ totalTokens: 37
+ }
+ },
+
+ withWarnings: {
+ text: 'Response with warnings.',
+ finishReason: 'stop' as const,
+ usage: {
+ inputTokens: 10,
+ outputTokens: 5,
+ totalTokens: 15
+ },
+ warnings: [
+ {
+ type: 'unsupported-setting' as const,
+ setting: 'temperature',
+ details: 'Temperature parameter not supported for this model'
+ }
+ ]
+ }
+}
+
+/**
+ * Mock image generation responses
+ */
+export const mockImageResponses = {
+ single: {
+ image: {
+ base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
+ uint8Array: new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82]),
+ mimeType: 'image/png' as const
+ },
+ warnings: []
+ },
+
+ multiple: {
+ images: [
+ {
+ base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
+ uint8Array: new Uint8Array([137, 80, 78, 71]),
+ mimeType: 'image/png' as const
+ },
+ {
+ base64: 'iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAYAAABytg0kAAAAEklEQVR42mNk+M9QzwAEjDAGACCKAgdZ9zImAAAAAElFTkSuQmCC',
+ uint8Array: new Uint8Array([137, 80, 78, 71]),
+ mimeType: 'image/png' as const
+ }
+ ],
+ warnings: []
+ },
+
+ withProviderMetadata: {
+ image: {
+ base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
+ uint8Array: new Uint8Array([137, 80, 78, 71]),
+ mimeType: 'image/png' as const
+ },
+ providerMetadata: {
+ openai: {
+ images: [
+ {
+ revisedPrompt: 'A detailed and enhanced version of the original prompt'
+ }
+ ]
+ }
+ },
+ warnings: []
+ }
+}
diff --git a/packages/aiCore/src/__tests__/helpers/provider-test-utils.ts b/packages/aiCore/src/__tests__/helpers/provider-test-utils.ts
new file mode 100644
index 0000000000..f8a2051b4b
--- /dev/null
+++ b/packages/aiCore/src/__tests__/helpers/provider-test-utils.ts
@@ -0,0 +1,329 @@
+/**
+ * Provider-Specific Test Utilities
+ * Helper functions for testing individual providers with all their parameters
+ */
+
+import type { Tool } from 'ai'
+import { expect } from 'vitest'
+
+/**
+ * Provider parameter configurations for comprehensive testing
+ */
+export const providerParameterMatrix = {
+ openai: {
+ models: ['gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-4o'],
+ parameters: {
+ temperature: [0, 0.5, 0.7, 1.0, 1.5, 2.0],
+ maxTokens: [100, 500, 1000, 2000, 4000],
+ topP: [0.1, 0.5, 0.9, 1.0],
+ frequencyPenalty: [-2.0, -1.0, 0, 1.0, 2.0],
+ presencePenalty: [-2.0, -1.0, 0, 1.0, 2.0],
+ stop: [undefined, ['stop'], ['STOP', 'END']],
+ seed: [undefined, 12345, 67890],
+ responseFormat: [undefined, { type: 'json_object' as const }],
+ user: [undefined, 'test-user-123']
+ },
+ toolChoice: ['auto', 'required', 'none', { type: 'function' as const, name: 'getWeather' }],
+ parallelToolCalls: [true, false]
+ },
+
+ anthropic: {
+ models: ['claude-3-5-sonnet-20241022', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
+ parameters: {
+ temperature: [0, 0.5, 1.0],
+ maxTokens: [100, 1000, 4000, 8000],
+ topP: [0.1, 0.5, 0.9, 1.0],
+ topK: [undefined, 1, 5, 10, 40],
+ stop: [undefined, ['Human:', 'Assistant:']],
+ metadata: [undefined, { userId: 'test-123' }]
+ },
+ toolChoice: ['auto', 'any', { type: 'tool' as const, name: 'getWeather' }]
+ },
+
+ google: {
+ models: ['gemini-2.0-flash-exp', 'gemini-1.5-pro', 'gemini-1.5-flash'],
+ parameters: {
+ temperature: [0, 0.5, 0.9, 1.0],
+ maxTokens: [100, 1000, 2000, 8000],
+ topP: [0.1, 0.5, 0.95, 1.0],
+ topK: [undefined, 1, 16, 40],
+ stopSequences: [undefined, ['END'], ['STOP', 'TERMINATE']]
+ },
+ safetySettings: [
+ undefined,
+ [
+ { category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_MEDIUM_AND_ABOVE' },
+ { category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_ONLY_HIGH' }
+ ]
+ ]
+ },
+
+ xai: {
+ models: ['grok-2-latest', 'grok-2-1212'],
+ parameters: {
+ temperature: [0, 0.5, 1.0, 1.5],
+ maxTokens: [100, 500, 2000, 4000],
+ topP: [0.1, 0.5, 0.9, 1.0],
+ stop: [undefined, ['STOP'], ['END', 'TERMINATE']],
+ seed: [undefined, 12345]
+ }
+ },
+
+ deepseek: {
+ models: ['deepseek-chat', 'deepseek-coder'],
+ parameters: {
+ temperature: [0, 0.5, 1.0],
+ maxTokens: [100, 1000, 4000],
+ topP: [0.1, 0.5, 0.95],
+ frequencyPenalty: [0, 0.5, 1.0],
+ presencePenalty: [0, 0.5, 1.0],
+ stop: [undefined, ['```'], ['END']]
+ }
+ },
+
+ azure: {
+ deployments: ['gpt-4-deployment', 'gpt-35-turbo-deployment'],
+ parameters: {
+ temperature: [0, 0.7, 1.0],
+ maxTokens: [100, 1000, 2000],
+ topP: [0.1, 0.5, 0.95],
+ frequencyPenalty: [0, 1.0],
+ presencePenalty: [0, 1.0],
+ stop: [undefined, ['STOP']]
+ }
+ }
+} as const
+
+/**
+ * Creates test cases for all parameter combinations
+ */
+export function generateParameterTestCases>(
+ params: T,
+ maxCombinations = 50
+): Array> {
+ const keys = Object.keys(params) as Array
+ const testCases: Array> = []
+
+ // Generate combinations using sampling strategy for large parameter spaces
+ const totalCombinations = keys.reduce((acc, key) => acc * params[key].length, 1)
+
+ if (totalCombinations <= maxCombinations) {
+ // Generate all combinations if total is small
+ generateAllCombinations(params, keys, 0, {}, testCases)
+ } else {
+ // Sample diverse combinations if total is large
+ generateSampledCombinations(params, keys, maxCombinations, testCases)
+ }
+
+ return testCases
+}
+
+function generateAllCombinations>(
+ params: T,
+ keys: Array,
+ index: number,
+ current: Partial<{ [K in keyof T]: T[K][number] }>,
+ results: Array>
+) {
+ if (index === keys.length) {
+ results.push({ ...current })
+ return
+ }
+
+ const key = keys[index]
+ for (const value of params[key]) {
+ generateAllCombinations(params, keys, index + 1, { ...current, [key]: value }, results)
+ }
+}
+
+function generateSampledCombinations>(
+ params: T,
+ keys: Array,
+ count: number,
+ results: Array>
+) {
+ // Generate edge cases first (min/max values)
+ const edgeCase1: any = {}
+ const edgeCase2: any = {}
+
+ for (const key of keys) {
+ edgeCase1[key] = params[key][0]
+ edgeCase2[key] = params[key][params[key].length - 1]
+ }
+
+ results.push(edgeCase1, edgeCase2)
+
+ // Generate random combinations for the rest
+ for (let i = results.length; i < count; i++) {
+ const combination: any = {}
+ for (const key of keys) {
+ const values = params[key]
+ combination[key] = values[Math.floor(Math.random() * values.length)]
+ }
+ results.push(combination)
+ }
+}
+
+/**
+ * Validates that all provider-specific parameters are correctly passed through
+ */
+export function validateProviderParams(providerId: string, actualParams: any, expectedParams: any): void {
+ const requiredFields: Record = {
+ openai: ['model', 'messages'],
+ anthropic: ['model', 'messages'],
+ google: ['model', 'contents'],
+ xai: ['model', 'messages'],
+ deepseek: ['model', 'messages'],
+ azure: ['messages']
+ }
+
+ const fields = requiredFields[providerId] || ['model', 'messages']
+
+ for (const field of fields) {
+ expect(actualParams).toHaveProperty(field)
+ }
+
+ // Validate optional parameters if they were provided
+ const optionalParams = ['temperature', 'max_tokens', 'top_p', 'stop', 'tools']
+
+ for (const param of optionalParams) {
+ if (expectedParams[param] !== undefined) {
+ expect(actualParams[param]).toEqual(expectedParams[param])
+ }
+ }
+}
+
+/**
+ * Creates a comprehensive test suite for a provider
+ */
+// oxlint-disable-next-line no-unused-vars
+export function createProviderTestSuite(_providerId: string) {
+ return {
+ testBasicCompletion: async (executor: any, model: string) => {
+ const result = await executor.generateText({
+ model,
+ messages: [{ role: 'user' as const, content: 'Hello' }]
+ })
+
+ expect(result).toBeDefined()
+ expect(result.text).toBeDefined()
+ expect(typeof result.text).toBe('string')
+ },
+
+ testStreaming: async (executor: any, model: string) => {
+ const chunks: any[] = []
+ const result = await executor.streamText({
+ model,
+ messages: [{ role: 'user' as const, content: 'Hello' }]
+ })
+
+ for await (const chunk of result.textStream) {
+ chunks.push(chunk)
+ }
+
+ expect(chunks.length).toBeGreaterThan(0)
+ },
+
+ testTemperature: async (executor: any, model: string, temperatures: number[]) => {
+ for (const temperature of temperatures) {
+ const result = await executor.generateText({
+ model,
+ messages: [{ role: 'user' as const, content: 'Hello' }],
+ temperature
+ })
+
+ expect(result).toBeDefined()
+ }
+ },
+
+ testMaxTokens: async (executor: any, model: string, maxTokensValues: number[]) => {
+ for (const maxTokens of maxTokensValues) {
+ const result = await executor.generateText({
+ model,
+ messages: [{ role: 'user' as const, content: 'Hello' }],
+ maxTokens
+ })
+
+ expect(result).toBeDefined()
+ if (result.usage?.completionTokens) {
+ expect(result.usage.completionTokens).toBeLessThanOrEqual(maxTokens)
+ }
+ }
+ },
+
+ testToolCalling: async (executor: any, model: string, tools: Record) => {
+ const result = await executor.generateText({
+ model,
+ messages: [{ role: 'user' as const, content: 'What is the weather in SF?' }],
+ tools
+ })
+
+ expect(result).toBeDefined()
+ },
+
+ testStopSequences: async (executor: any, model: string, stopSequences: string[][]) => {
+ for (const stop of stopSequences) {
+ const result = await executor.generateText({
+ model,
+ messages: [{ role: 'user' as const, content: 'Count to 10' }],
+ stop
+ })
+
+ expect(result).toBeDefined()
+ }
+ }
+ }
+}
+
+/**
+ * Generates test data for vision/multimodal testing
+ */
+export function createVisionTestData() {
+ return {
+ imageUrl: 'https://example.com/test-image.jpg',
+ base64Image:
+ '',
+ messages: [
+ {
+ role: 'user' as const,
+ content: [
+ { type: 'text' as const, text: 'What is in this image?' },
+ {
+ type: 'image' as const,
+ image:
+ ''
+ }
+ ]
+ }
+ ]
+ }
+}
+
+/**
+ * Creates mock responses for different finish reasons
+ */
+export function createFinishReasonMocks() {
+ return {
+ stop: {
+ text: 'Complete response.',
+ finishReason: 'stop' as const,
+ usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 }
+ },
+ length: {
+ text: 'Incomplete response due to',
+ finishReason: 'length' as const,
+ usage: { promptTokens: 10, completionTokens: 100, totalTokens: 110 }
+ },
+ 'tool-calls': {
+ text: 'Calling tools',
+ finishReason: 'tool-calls' as const,
+ toolCalls: [{ toolCallId: 'call_1', toolName: 'getWeather', args: { location: 'SF' } }],
+ usage: { promptTokens: 10, completionTokens: 8, totalTokens: 18 }
+ },
+ 'content-filter': {
+ text: '',
+ finishReason: 'content-filter' as const,
+ usage: { promptTokens: 10, completionTokens: 0, totalTokens: 10 }
+ }
+ }
+}
diff --git a/packages/aiCore/src/__tests__/helpers/test-utils.ts b/packages/aiCore/src/__tests__/helpers/test-utils.ts
new file mode 100644
index 0000000000..8231075785
--- /dev/null
+++ b/packages/aiCore/src/__tests__/helpers/test-utils.ts
@@ -0,0 +1,291 @@
+/**
+ * Test Utilities
+ * Helper functions for testing AI Core functionality
+ */
+
+import { expect, vi } from 'vitest'
+
+import type { ProviderId } from '../fixtures/mock-providers'
+import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../fixtures/mock-providers'
+
+/**
+ * Creates a test provider with streaming support
+ */
+export function createTestStreamingProvider(chunks: any[]) {
+ return createMockLanguageModel({
+ doStream: vi.fn().mockReturnValue({
+ stream: (async function* () {
+ for (const chunk of chunks) {
+ yield chunk
+ }
+ })(),
+ rawCall: { rawPrompt: null, rawSettings: {} },
+ rawResponse: { headers: {} },
+ warnings: []
+ })
+ })
+}
+
+/**
+ * Creates a test provider that throws errors
+ */
+export function createErrorProvider(error: Error) {
+ return createMockLanguageModel({
+ doGenerate: vi.fn().mockRejectedValue(error),
+ doStream: vi.fn().mockImplementation(() => {
+ throw error
+ })
+ })
+}
+
+/**
+ * Collects all chunks from a stream
+ */
+export async function collectStreamChunks(stream: AsyncIterable): Promise {
+ const chunks: T[] = []
+ for await (const chunk of stream) {
+ chunks.push(chunk)
+ }
+ return chunks
+}
+
+/**
+ * Waits for a specific number of milliseconds
+ */
+export function wait(ms: number): Promise {
+ return new Promise((resolve) => setTimeout(resolve, ms))
+}
+
+/**
+ * Creates a mock abort controller that aborts after a delay
+ */
+export function createDelayedAbortController(delayMs: number): AbortController {
+ const controller = new AbortController()
+ setTimeout(() => controller.abort(), delayMs)
+ return controller
+}
+
+/**
+ * Asserts that a function throws an error with a specific message
+ */
+export async function expectError(fn: () => Promise, expectedMessage?: string | RegExp): Promise {
+ try {
+ await fn()
+ throw new Error('Expected function to throw an error, but it did not')
+ } catch (error) {
+ if (expectedMessage) {
+ const message = (error as Error).message
+ if (typeof expectedMessage === 'string') {
+ if (!message.includes(expectedMessage)) {
+ throw new Error(`Expected error message to include "${expectedMessage}", but got "${message}"`)
+ }
+ } else {
+ if (!expectedMessage.test(message)) {
+ throw new Error(`Expected error message to match ${expectedMessage}, but got "${message}"`)
+ }
+ }
+ }
+ return error as Error
+ }
+}
+
+/**
+ * Creates a spy function that tracks calls and arguments
+ */
+export function createSpy any>() {
+ const calls: Array<{ args: Parameters; result?: ReturnType; error?: Error }> = []
+
+ const spy = vi.fn((...args: Parameters) => {
+ try {
+ const result = undefined as ReturnType
+ calls.push({ args, result })
+ return result
+ } catch (error) {
+ calls.push({ args, error: error as Error })
+ throw error
+ }
+ })
+
+ return {
+ fn: spy,
+ calls,
+ getCalls: () => calls,
+ getCallCount: () => calls.length,
+ getLastCall: () => calls[calls.length - 1],
+ reset: () => {
+ calls.length = 0
+ spy.mockClear()
+ }
+ }
+}
+
+/**
+ * Validates provider configuration
+ */
+export function validateProviderConfig(providerId: ProviderId) {
+ const config = mockProviderConfigs[providerId]
+ if (!config) {
+ throw new Error(`No mock configuration found for provider: ${providerId}`)
+ }
+
+ if (!config.apiKey) {
+ throw new Error(`Provider ${providerId} is missing apiKey in mock config`)
+ }
+
+ return config
+}
+
+/**
+ * Creates a test context with common setup
+ */
+export function createTestContext() {
+ const mocks = {
+ languageModel: createMockLanguageModel(),
+ imageModel: createMockImageModel(),
+ providers: new Map()
+ }
+
+ const cleanup = () => {
+ mocks.providers.clear()
+ vi.clearAllMocks()
+ }
+
+ return {
+ mocks,
+ cleanup
+ }
+}
+
+/**
+ * Measures execution time of an async function
+ */
+export async function measureTime(fn: () => Promise): Promise<{ result: T; duration: number }> {
+ const start = Date.now()
+ const result = await fn()
+ const duration = Date.now() - start
+ return { result, duration }
+}
+
+/**
+ * Retries a function until it succeeds or max attempts reached
+ */
+export async function retryUntilSuccess(fn: () => Promise, maxAttempts = 3, delayMs = 100): Promise {
+ let lastError: Error | undefined
+
+ for (let attempt = 1; attempt <= maxAttempts; attempt++) {
+ try {
+ return await fn()
+ } catch (error) {
+ lastError = error as Error
+ if (attempt < maxAttempts) {
+ await wait(delayMs)
+ }
+ }
+ }
+
+ throw lastError || new Error('All retry attempts failed')
+}
+
+/**
+ * Creates a mock streaming response that emits chunks at intervals
+ */
+export function createTimedStream(chunks: T[], intervalMs = 10) {
+ return {
+ async *[Symbol.asyncIterator]() {
+ for (const chunk of chunks) {
+ await wait(intervalMs)
+ yield chunk
+ }
+ }
+ }
+}
+
+/**
+ * Asserts that two objects are deeply equal, ignoring specified keys
+ */
+export function assertDeepEqualIgnoring>(
+ actual: T,
+ expected: T,
+ ignoreKeys: string[] = []
+): void {
+ const filterKeys = (obj: T): Partial => {
+ const filtered = { ...obj }
+ for (const key of ignoreKeys) {
+ delete filtered[key]
+ }
+ return filtered
+ }
+
+ const filteredActual = filterKeys(actual)
+ const filteredExpected = filterKeys(expected)
+
+ expect(filteredActual).toEqual(filteredExpected)
+}
+
+/**
+ * Creates a provider mock that simulates rate limiting
+ */
+export function createRateLimitedProvider(limitPerSecond: number) {
+ const calls: number[] = []
+
+ return createMockLanguageModel({
+ doGenerate: vi.fn().mockImplementation(async () => {
+ const now = Date.now()
+ calls.push(now)
+
+ // Remove calls older than 1 second
+ const recentCalls = calls.filter((time) => now - time < 1000)
+
+ if (recentCalls.length > limitPerSecond) {
+ throw new Error('Rate limit exceeded')
+ }
+
+ return {
+ text: 'Rate limited response',
+ finishReason: 'stop' as const,
+ usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 },
+ rawCall: { rawPrompt: null, rawSettings: {} },
+ rawResponse: { headers: {} },
+ warnings: []
+ }
+ })
+ })
+}
+
+/**
+ * Validates streaming response structure
+ */
+export function validateStreamChunk(chunk: any): void {
+ expect(chunk).toBeDefined()
+ expect(chunk).toHaveProperty('type')
+
+ if (chunk.type === 'text-delta') {
+ expect(chunk).toHaveProperty('textDelta')
+ expect(typeof chunk.textDelta).toBe('string')
+ } else if (chunk.type === 'finish') {
+ expect(chunk).toHaveProperty('finishReason')
+ expect(chunk).toHaveProperty('usage')
+ } else if (chunk.type === 'tool-call') {
+ expect(chunk).toHaveProperty('toolCallId')
+ expect(chunk).toHaveProperty('toolName')
+ expect(chunk).toHaveProperty('args')
+ }
+}
+
+/**
+ * Creates a test logger that captures log messages
+ */
+export function createTestLogger() {
+ const logs: Array<{ level: string; message: string; meta?: any }> = []
+
+ return {
+ info: (message: string, meta?: any) => logs.push({ level: 'info', message, meta }),
+ warn: (message: string, meta?: any) => logs.push({ level: 'warn', message, meta }),
+ error: (message: string, meta?: any) => logs.push({ level: 'error', message, meta }),
+ debug: (message: string, meta?: any) => logs.push({ level: 'debug', message, meta }),
+ getLogs: () => logs,
+ clear: () => {
+ logs.length = 0
+ }
+ }
+}
diff --git a/packages/aiCore/src/__tests__/index.ts b/packages/aiCore/src/__tests__/index.ts
new file mode 100644
index 0000000000..23ecd167a4
--- /dev/null
+++ b/packages/aiCore/src/__tests__/index.ts
@@ -0,0 +1,12 @@
+/**
+ * Test Infrastructure Exports
+ * Central export point for all test utilities, fixtures, and helpers
+ */
+
+// Fixtures
+export * from './fixtures/mock-providers'
+export * from './fixtures/mock-responses'
+
+// Helpers
+export * from './helpers/provider-test-utils'
+export * from './helpers/test-utils'
diff --git a/packages/aiCore/src/__tests__/mocks/ai-sdk-provider.ts b/packages/aiCore/src/__tests__/mocks/ai-sdk-provider.ts
new file mode 100644
index 0000000000..57dcdd0fd1
--- /dev/null
+++ b/packages/aiCore/src/__tests__/mocks/ai-sdk-provider.ts
@@ -0,0 +1,35 @@
+/**
+ * Mock for @cherrystudio/ai-sdk-provider
+ * This mock is used in tests to avoid importing the actual package
+ */
+
+export type CherryInProviderSettings = {
+ apiKey?: string
+ baseURL?: string
+}
+
+// oxlint-disable-next-line no-unused-vars
+export const createCherryIn = (_options?: CherryInProviderSettings) => ({
+ // oxlint-disable-next-line no-unused-vars
+ languageModel: (_modelId: string) => ({
+ specificationVersion: 'v1',
+ provider: 'cherryin',
+ modelId: 'mock-model',
+ doGenerate: async () => ({ text: 'mock response' }),
+ doStream: async () => ({ stream: (async function* () {})() })
+ }),
+ // oxlint-disable-next-line no-unused-vars
+ chat: (_modelId: string) => ({
+ specificationVersion: 'v1',
+ provider: 'cherryin-chat',
+ modelId: 'mock-model',
+ doGenerate: async () => ({ text: 'mock response' }),
+ doStream: async () => ({ stream: (async function* () {})() })
+ }),
+ // oxlint-disable-next-line no-unused-vars
+ textEmbeddingModel: (_modelId: string) => ({
+ specificationVersion: 'v1',
+ provider: 'cherryin',
+ modelId: 'mock-embedding-model'
+ })
+})
diff --git a/packages/aiCore/src/__tests__/setup.ts b/packages/aiCore/src/__tests__/setup.ts
new file mode 100644
index 0000000000..1e35458ad6
--- /dev/null
+++ b/packages/aiCore/src/__tests__/setup.ts
@@ -0,0 +1,9 @@
+/**
+ * Vitest Setup File
+ * Global test configuration and mocks for @cherrystudio/ai-core package
+ */
+
+// Mock Vite SSR helper to avoid Node environment errors
+;(globalThis as any).__vite_ssr_exportName__ = (_name: string, value: any) => value
+
+// Note: @cherrystudio/ai-sdk-provider is mocked via alias in vitest.config.ts
diff --git a/packages/aiCore/src/core/options/__tests__/factory.test.ts b/packages/aiCore/src/core/options/__tests__/factory.test.ts
new file mode 100644
index 0000000000..86f8017818
--- /dev/null
+++ b/packages/aiCore/src/core/options/__tests__/factory.test.ts
@@ -0,0 +1,109 @@
+import { describe, expect, it } from 'vitest'
+
+import { createOpenAIOptions, createOpenRouterOptions, mergeProviderOptions } from '../factory'
+
+describe('mergeProviderOptions', () => {
+ it('deep merges provider options for the same provider', () => {
+ const reasoningOptions = createOpenRouterOptions({
+ reasoning: {
+ enabled: true,
+ effort: 'medium'
+ }
+ })
+ const webSearchOptions = createOpenRouterOptions({
+ plugins: [{ id: 'web', max_results: 5 }]
+ })
+
+ const merged = mergeProviderOptions(reasoningOptions, webSearchOptions)
+
+ expect(merged.openrouter).toEqual({
+ reasoning: {
+ enabled: true,
+ effort: 'medium'
+ },
+ plugins: [{ id: 'web', max_results: 5 }]
+ })
+ })
+
+ it('preserves options from other providers while merging', () => {
+ const openRouter = createOpenRouterOptions({
+ reasoning: { enabled: true }
+ })
+ const openAI = createOpenAIOptions({
+ reasoningEffort: 'low'
+ })
+ const merged = mergeProviderOptions(openRouter, openAI)
+
+ expect(merged.openrouter).toEqual({ reasoning: { enabled: true } })
+ expect(merged.openai).toEqual({ reasoningEffort: 'low' })
+ })
+
+ it('overwrites primitive values with later values', () => {
+ const first = createOpenAIOptions({
+ reasoningEffort: 'low',
+ user: 'user-123'
+ })
+ const second = createOpenAIOptions({
+ reasoningEffort: 'high',
+ maxToolCalls: 5
+ })
+
+ const merged = mergeProviderOptions(first, second)
+
+ expect(merged.openai).toEqual({
+ reasoningEffort: 'high', // overwritten by second
+ user: 'user-123', // preserved from first
+ maxToolCalls: 5 // added from second
+ })
+ })
+
+ it('overwrites arrays with later values instead of merging', () => {
+ const first = createOpenRouterOptions({
+ models: ['gpt-4', 'gpt-3.5-turbo']
+ })
+ const second = createOpenRouterOptions({
+ models: ['claude-3-opus', 'claude-3-sonnet']
+ })
+
+ const merged = mergeProviderOptions(first, second)
+
+ // Array is completely replaced, not merged
+ expect(merged.openrouter?.models).toEqual(['claude-3-opus', 'claude-3-sonnet'])
+ })
+
+ it('deeply merges nested objects while overwriting primitives', () => {
+ const first = createOpenRouterOptions({
+ reasoning: {
+ enabled: true,
+ effort: 'low'
+ },
+ user: 'user-123'
+ })
+ const second = createOpenRouterOptions({
+ reasoning: {
+ effort: 'high',
+ max_tokens: 500
+ },
+ user: 'user-456'
+ })
+
+ const merged = mergeProviderOptions(first, second)
+
+ expect(merged.openrouter).toEqual({
+ reasoning: {
+ enabled: true, // preserved from first
+ effort: 'high', // overwritten by second
+ max_tokens: 500 // added from second
+ },
+ user: 'user-456' // overwritten by second
+ })
+ })
+
+ it('replaces arrays instead of merging them', () => {
+ const first = createOpenRouterOptions({ plugins: [{ id: 'old' }] })
+ const second = createOpenRouterOptions({ plugins: [{ id: 'new' }] })
+ const merged = mergeProviderOptions(first, second)
+ // @ts-expect-error type-check for openrouter options is skipped. see function signature of createOpenRouterOptions
+ expect(merged.openrouter?.plugins).toEqual([{ id: 'new' }])
+ })
+})
diff --git a/packages/aiCore/src/core/options/factory.ts b/packages/aiCore/src/core/options/factory.ts
index ecd53e6330..1e493b2337 100644
--- a/packages/aiCore/src/core/options/factory.ts
+++ b/packages/aiCore/src/core/options/factory.ts
@@ -26,13 +26,65 @@ export function createGenericProviderOptions(
return { [provider]: options } as Record>
}
+type PlainObject = Record
+
+const isPlainObject = (value: unknown): value is PlainObject => {
+ return typeof value === 'object' && value !== null && !Array.isArray(value)
+}
+
+function deepMergeObjects(target: T, source: PlainObject): T {
+ const result: PlainObject = { ...target }
+ Object.entries(source).forEach(([key, value]) => {
+ if (isPlainObject(value) && isPlainObject(result[key])) {
+ result[key] = deepMergeObjects(result[key], value)
+ } else {
+ result[key] = value
+ }
+ })
+ return result as T
+}
+
/**
- * 合并多个供应商的options
- * @param optionsMap 包含多个供应商选项的对象
- * @returns 合并后的TypedProviderOptions
+ * Deep-merge multiple provider-specific options.
+ * Nested objects are recursively merged; primitive values are overwritten.
+ *
+ * When the same key appears in multiple options:
+ * - If both values are plain objects: they are deeply merged (recursive merge)
+ * - If values are primitives/arrays: the later value overwrites the earlier one
+ *
+ * @example
+ * mergeProviderOptions(
+ * { openrouter: { reasoning: { enabled: true, effort: 'low' }, user: 'user-123' } },
+ * { openrouter: { reasoning: { effort: 'high', max_tokens: 500 }, models: ['gpt-4'] } }
+ * )
+ * // Result: {
+ * // openrouter: {
+ * // reasoning: { enabled: true, effort: 'high', max_tokens: 500 },
+ * // user: 'user-123',
+ * // models: ['gpt-4']
+ * // }
+ * // }
+ *
+ * @param optionsMap Objects containing options for multiple providers
+ * @returns Fully merged TypedProviderOptions
*/
export function mergeProviderOptions(...optionsMap: Partial[]): TypedProviderOptions {
- return Object.assign({}, ...optionsMap)
+ return optionsMap.reduce((acc, options) => {
+ if (!options) {
+ return acc
+ }
+ Object.entries(options).forEach(([providerId, providerOptions]) => {
+ if (!providerOptions) {
+ return
+ }
+ if (acc[providerId]) {
+ acc[providerId] = deepMergeObjects(acc[providerId] as PlainObject, providerOptions as PlainObject)
+ } else {
+ acc[providerId] = providerOptions as any
+ }
+ })
+ return acc
+ }, {} as TypedProviderOptions)
}
/**
diff --git a/packages/aiCore/src/core/plugins/built-in/index.ts b/packages/aiCore/src/core/plugins/built-in/index.ts
index 1f8916b09a..d7f35d0cd1 100644
--- a/packages/aiCore/src/core/plugins/built-in/index.ts
+++ b/packages/aiCore/src/core/plugins/built-in/index.ts
@@ -4,12 +4,7 @@
*/
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
-export { googleToolsPlugin } from './googleToolsPlugin'
-export { createLoggingPlugin } from './logging'
-export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
-export type {
- PromptToolUseConfig,
- ToolUseRequestContext,
- ToolUseResult
-} from './toolUsePlugin/type'
-export { webSearchPlugin, type WebSearchPluginConfig } from './webSearchPlugin'
+export * from './googleToolsPlugin'
+export * from './toolUsePlugin/promptToolUsePlugin'
+export * from './toolUsePlugin/type'
+export * from './webSearchPlugin'
diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts
index 42bd17e09c..61e6f49b81 100644
--- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts
+++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts
@@ -1,9 +1,10 @@
-import type { anthropic } from '@ai-sdk/anthropic'
-import type { google } from '@ai-sdk/google'
-import type { openai } from '@ai-sdk/openai'
+import { anthropic } from '@ai-sdk/anthropic'
+import { google } from '@ai-sdk/google'
+import { openai } from '@ai-sdk/openai'
import type { InferToolInput, InferToolOutput } from 'ai'
import { type Tool } from 'ai'
+import { createOpenRouterOptions, createXaiOptions, mergeProviderOptions } from '../../../options'
import type { ProviderOptionsMap } from '../../../options/types'
import type { OpenRouterSearchConfig } from './openrouter'
@@ -34,7 +35,6 @@ export interface WebSearchPluginConfig {
anthropic?: AnthropicSearchConfig
xai?: ProviderOptionsMap['xai']['searchParameters']
google?: GoogleSearchConfig
- 'google-vertex'?: GoogleSearchConfig
openrouter?: OpenRouterSearchConfig
}
@@ -43,7 +43,6 @@ export interface WebSearchPluginConfig {
*/
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
google: {},
- 'google-vertex': {},
openai: {},
'openai-chat': {},
xai: {
@@ -95,3 +94,29 @@ export type WebSearchToolInputSchema = {
google: InferToolInput
'openai-chat': InferToolInput
}
+
+export const switchWebSearchTool = (config: WebSearchPluginConfig, params: any) => {
+ if (config.openai) {
+ if (!params.tools) params.tools = {}
+ params.tools.web_search = openai.tools.webSearch(config.openai)
+ } else if (config['openai-chat']) {
+ if (!params.tools) params.tools = {}
+ params.tools.web_search_preview = openai.tools.webSearchPreview(config['openai-chat'])
+ } else if (config.anthropic) {
+ if (!params.tools) params.tools = {}
+ params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
+ } else if (config.google) {
+ // case 'google-vertex':
+ if (!params.tools) params.tools = {}
+ params.tools.web_search = google.tools.googleSearch(config.google || {})
+ } else if (config.xai) {
+ const searchOptions = createXaiOptions({
+ searchParameters: { ...config.xai, mode: 'on' }
+ })
+ params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
+ } else if (config.openrouter) {
+ const searchOptions = createOpenRouterOptions(config.openrouter)
+ params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
+ }
+ return params
+}
diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts
index 34eba79637..a46df7dd4c 100644
--- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts
+++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts
@@ -2,15 +2,10 @@
* Web Search Plugin
* 提供统一的网络搜索能力,支持多个 AI Provider
*/
-import { anthropic } from '@ai-sdk/anthropic'
-import { google } from '@ai-sdk/google'
-import { openai } from '@ai-sdk/openai'
-import { createOpenRouterOptions, createXaiOptions, mergeProviderOptions } from '../../../options'
import { definePlugin } from '../../'
-import type { AiRequestContext } from '../../types'
import type { WebSearchPluginConfig } from './helper'
-import { DEFAULT_WEB_SEARCH_CONFIG } from './helper'
+import { DEFAULT_WEB_SEARCH_CONFIG, switchWebSearchTool } from './helper'
/**
* 网络搜索插件
@@ -22,64 +17,14 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
name: 'webSearch',
enforce: 'pre',
- transformParams: async (params: any, context: AiRequestContext) => {
- const { providerId } = context
- switch (providerId) {
- case 'openai': {
- if (config.openai) {
- if (!params.tools) params.tools = {}
- params.tools.web_search = openai.tools.webSearch(config.openai)
- }
- break
- }
- case 'openai-chat': {
- if (config['openai-chat']) {
- if (!params.tools) params.tools = {}
- params.tools.web_search_preview = openai.tools.webSearchPreview(config['openai-chat'])
- }
- break
- }
-
- case 'anthropic': {
- if (config.anthropic) {
- if (!params.tools) params.tools = {}
- params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
- }
- break
- }
-
- case 'google': {
- // case 'google-vertex':
- if (!params.tools) params.tools = {}
- params.tools.web_search = google.tools.googleSearch(config.google || {})
- break
- }
-
- case 'xai': {
- if (config.xai) {
- const searchOptions = createXaiOptions({
- searchParameters: { ...config.xai, mode: 'on' }
- })
- params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
- }
- break
- }
-
- case 'openrouter': {
- if (config.openrouter) {
- const searchOptions = createOpenRouterOptions(config.openrouter)
- params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
- }
- break
- }
- }
-
+ transformParams: async (params: any) => {
+ switchWebSearchTool(config, params)
return params
}
})
// 导出类型定义供开发者使用
-export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
+export * from './helper'
// 默认导出
export default webSearchPlugin
diff --git a/packages/aiCore/src/core/providers/__tests__/schemas.test.ts b/packages/aiCore/src/core/providers/__tests__/schemas.test.ts
index 82b390ba05..02fe21889a 100644
--- a/packages/aiCore/src/core/providers/__tests__/schemas.test.ts
+++ b/packages/aiCore/src/core/providers/__tests__/schemas.test.ts
@@ -19,15 +19,20 @@ describe('Provider Schemas', () => {
expect(Array.isArray(baseProviders)).toBe(true)
expect(baseProviders.length).toBeGreaterThan(0)
+ // These are the actual base providers defined in schemas.ts
const expectedIds = [
'openai',
- 'openai-responses',
+ 'openai-chat',
'openai-compatible',
'anthropic',
'google',
'xai',
'azure',
- 'deepseek'
+ 'azure-responses',
+ 'deepseek',
+ 'openrouter',
+ 'cherryin',
+ 'cherryin-chat'
]
const actualIds = baseProviders.map((p) => p.id)
expectedIds.forEach((id) => {
diff --git a/packages/aiCore/src/core/providers/index.ts b/packages/aiCore/src/core/providers/index.ts
index 3ac445cb22..b9ebd6f682 100644
--- a/packages/aiCore/src/core/providers/index.ts
+++ b/packages/aiCore/src/core/providers/index.ts
@@ -44,7 +44,7 @@ export {
// ==================== 基础数据和类型 ====================
// 基础Provider数据源
-export { baseProviderIds, baseProviders } from './schemas'
+export { baseProviderIds, baseProviders, isBaseProvider } from './schemas'
// 类型定义和Schema
export type {
diff --git a/packages/aiCore/src/core/providers/schemas.ts b/packages/aiCore/src/core/providers/schemas.ts
index 7ca4f6b0c8..43a370af9b 100644
--- a/packages/aiCore/src/core/providers/schemas.ts
+++ b/packages/aiCore/src/core/providers/schemas.ts
@@ -7,11 +7,11 @@ import { createAzure } from '@ai-sdk/azure'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { createDeepSeek } from '@ai-sdk/deepseek'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
-import { createHuggingFace } from '@ai-sdk/huggingface'
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import type { LanguageModelV2 } from '@ai-sdk/provider'
import { createXai } from '@ai-sdk/xai'
+import { type CherryInProviderSettings, createCherryIn } from '@cherrystudio/ai-sdk-provider'
import { createOpenRouter } from '@openrouter/ai-sdk-provider'
import type { Provider } from 'ai'
import { customProvider } from 'ai'
@@ -31,7 +31,8 @@ export const baseProviderIds = [
'azure-responses',
'deepseek',
'openrouter',
- 'huggingface'
+ 'cherryin',
+ 'cherryin-chat'
] as const
/**
@@ -137,9 +138,23 @@ export const baseProviders = [
supportsImageGeneration: true
},
{
- id: 'huggingface',
- name: 'HuggingFace',
- creator: createHuggingFace,
+ id: 'cherryin',
+ name: 'CherryIN',
+ creator: createCherryIn,
+ supportsImageGeneration: true
+ },
+ {
+ id: 'cherryin-chat',
+ name: 'CherryIN Chat',
+ creator: (options: CherryInProviderSettings) => {
+ const provider = createCherryIn(options)
+ return customProvider({
+ fallbackProvider: {
+ ...provider,
+ languageModel: (modelId: string) => provider.chat(modelId)
+ }
+ })
+ },
supportsImageGeneration: true
}
] as const satisfies BaseProvider[]
diff --git a/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts b/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
index 217319aacc..56ab87dbcc 100644
--- a/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
+++ b/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
@@ -232,11 +232,13 @@ describe('RuntimeExecutor.generateImage', () => {
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
+ // transformParams receives params without model (model is handled separately)
+ // and context with core fields + dynamic fields (requestId, startTime, etc.)
expect(testPlugin.transformParams).toHaveBeenCalledWith(
- { prompt: 'A test image' },
+ expect.objectContaining({ prompt: 'A test image' }),
expect.objectContaining({
providerId: 'openai',
- modelId: 'dall-e-3'
+ model: 'dall-e-3'
})
)
@@ -273,11 +275,12 @@ describe('RuntimeExecutor.generateImage', () => {
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
+ // resolveModel receives model id and context with core fields
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
'dall-e-3',
expect.objectContaining({
providerId: 'openai',
- modelId: 'dall-e-3'
+ model: 'dall-e-3'
})
)
@@ -339,12 +342,11 @@ describe('RuntimeExecutor.generateImage', () => {
.generateImage({ model: 'invalid-model', prompt: 'A test image' })
.catch((error) => error)
- expect(thrownError).toBeInstanceOf(ImageGenerationError)
- expect(thrownError.message).toContain('Failed to generate image:')
+ // Error is thrown from pluginEngine directly as ImageModelResolutionError
+ expect(thrownError).toBeInstanceOf(ImageModelResolutionError)
+ expect(thrownError.message).toContain('Failed to resolve image model: invalid-model')
expect(thrownError.providerId).toBe('openai')
expect(thrownError.modelId).toBe('invalid-model')
- expect(thrownError.cause).toBeInstanceOf(ImageModelResolutionError)
- expect(thrownError.cause.message).toContain('Failed to resolve image model: invalid-model')
})
it('should handle ImageModelResolutionError without provider', async () => {
@@ -362,8 +364,9 @@ describe('RuntimeExecutor.generateImage', () => {
const apiError = new Error('API request failed')
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
+ // Error propagates directly from pluginEngine without wrapping
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
- 'Failed to generate image:'
+ 'API request failed'
)
})
@@ -376,8 +379,9 @@ describe('RuntimeExecutor.generateImage', () => {
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
+ // Error propagates directly from pluginEngine
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
- 'Failed to generate image:'
+ 'No image generated'
)
})
@@ -398,15 +402,17 @@ describe('RuntimeExecutor.generateImage', () => {
[errorPlugin]
)
+ // Error propagates directly from pluginEngine
await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
- 'Failed to generate image:'
+ 'Generation failed'
)
+ // onError receives the original error and context with core fields
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
- modelId: 'dall-e-3'
+ model: 'dall-e-3'
})
)
})
@@ -419,9 +425,10 @@ describe('RuntimeExecutor.generateImage', () => {
const abortController = new AbortController()
setTimeout(() => abortController.abort(), 10)
+ // Error propagates directly from pluginEngine
await expect(
executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal })
- ).rejects.toThrow('Failed to generate image:')
+ ).rejects.toThrow('Operation was aborted')
})
})
diff --git a/packages/aiCore/src/core/runtime/__tests__/generateText.test.ts b/packages/aiCore/src/core/runtime/__tests__/generateText.test.ts
new file mode 100644
index 0000000000..cb1d1d671a
--- /dev/null
+++ b/packages/aiCore/src/core/runtime/__tests__/generateText.test.ts
@@ -0,0 +1,504 @@
+/**
+ * RuntimeExecutor.generateText Comprehensive Tests
+ * Tests non-streaming text generation across all providers with various parameters
+ */
+
+import { generateText } from 'ai'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+import {
+ createMockLanguageModel,
+ mockCompleteResponses,
+ mockProviderConfigs,
+ testMessages,
+ testTools
+} from '../../../__tests__'
+import type { AiPlugin } from '../../plugins'
+import { globalRegistryManagement } from '../../providers/RegistryManagement'
+import { RuntimeExecutor } from '../executor'
+
+// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
+vi.mock('ai', async (importOriginal) => {
+ const actual = (await importOriginal()) as Record
+ return {
+ ...actual,
+ generateText: vi.fn()
+ }
+})
+
+vi.mock('../../providers/RegistryManagement', () => ({
+ globalRegistryManagement: {
+ languageModel: vi.fn()
+ },
+ DEFAULT_SEPARATOR: '|'
+}))
+
+describe('RuntimeExecutor.generateText', () => {
+ let executor: RuntimeExecutor<'openai'>
+ let mockLanguageModel: any
+
+ beforeEach(() => {
+ vi.clearAllMocks()
+
+ executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
+
+ mockLanguageModel = createMockLanguageModel({
+ provider: 'openai',
+ modelId: 'gpt-4'
+ })
+
+ vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
+ vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
+ })
+
+ describe('Basic Functionality', () => {
+ it('should generate text with minimal parameters', async () => {
+ const result = await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple
+ })
+
+ expect(generateText).toHaveBeenCalledWith({
+ model: mockLanguageModel,
+ messages: testMessages.simple
+ })
+
+ expect(result.text).toBe('This is a simple response.')
+ expect(result.finishReason).toBe('stop')
+ expect(result.usage).toBeDefined()
+ })
+
+ it('should generate with system messages', async () => {
+ await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.withSystem
+ })
+
+ expect(generateText).toHaveBeenCalledWith({
+ model: mockLanguageModel,
+ messages: testMessages.withSystem
+ })
+ })
+
+ it('should generate with conversation history', async () => {
+ await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.conversation
+ })
+
+ expect(generateText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ messages: testMessages.conversation
+ })
+ )
+ })
+ })
+
+ describe('All Parameter Combinations', () => {
+ it('should support all parameters together', async () => {
+ await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ temperature: 0.7,
+ maxOutputTokens: 500,
+ topP: 0.9,
+ frequencyPenalty: 0.5,
+ presencePenalty: 0.3,
+ stopSequences: ['STOP'],
+ seed: 12345
+ })
+
+ expect(generateText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ temperature: 0.7,
+ maxOutputTokens: 500,
+ topP: 0.9,
+ frequencyPenalty: 0.5,
+ presencePenalty: 0.3,
+ stopSequences: ['STOP'],
+ seed: 12345
+ })
+ )
+ })
+
+ it('should support partial parameters', async () => {
+ await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ temperature: 0.5,
+ maxOutputTokens: 100
+ })
+
+ expect(generateText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ temperature: 0.5,
+ maxOutputTokens: 100
+ })
+ )
+ })
+ })
+
+ describe('Tool Calling', () => {
+ beforeEach(() => {
+ vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withToolCalls as any)
+ })
+
+ it('should support tool calling', async () => {
+ const result = await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.toolUse,
+ tools: testTools
+ })
+
+ expect(generateText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ tools: testTools
+ })
+ )
+
+ expect(result.toolCalls).toBeDefined()
+ expect(result.toolCalls).toHaveLength(1)
+ })
+
+ it('should support toolChoice auto', async () => {
+ await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.toolUse,
+ tools: testTools,
+ toolChoice: 'auto'
+ })
+
+ expect(generateText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ toolChoice: 'auto'
+ })
+ )
+ })
+
+ it('should support toolChoice required', async () => {
+ await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.toolUse,
+ tools: testTools,
+ toolChoice: 'required'
+ })
+
+ expect(generateText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ toolChoice: 'required'
+ })
+ )
+ })
+
+ it('should support toolChoice none', async () => {
+ vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
+
+ await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ tools: testTools,
+ toolChoice: 'none'
+ })
+
+ expect(generateText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ toolChoice: 'none'
+ })
+ )
+ })
+
+ it('should support specific tool selection', async () => {
+ await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.toolUse,
+ tools: testTools,
+ toolChoice: {
+ type: 'tool',
+ toolName: 'getWeather'
+ }
+ })
+
+ expect(generateText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ toolChoice: {
+ type: 'tool',
+ toolName: 'getWeather'
+ }
+ })
+ )
+ })
+ })
+
+ describe('Multiple Providers', () => {
+ it('should work with Anthropic provider', async () => {
+ const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
+
+ const anthropicModel = createMockLanguageModel({
+ provider: 'anthropic',
+ modelId: 'claude-3-5-sonnet-20241022'
+ })
+
+ vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(anthropicModel)
+
+ await anthropicExecutor.generateText({
+ model: 'claude-3-5-sonnet-20241022',
+ messages: testMessages.simple
+ })
+
+ expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3-5-sonnet-20241022')
+ })
+
+ it('should work with Google provider', async () => {
+ const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
+
+ const googleModel = createMockLanguageModel({
+ provider: 'google',
+ modelId: 'gemini-2.0-flash-exp'
+ })
+
+ vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(googleModel)
+
+ await googleExecutor.generateText({
+ model: 'gemini-2.0-flash-exp',
+ messages: testMessages.simple
+ })
+
+ expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('google|gemini-2.0-flash-exp')
+ })
+
+ it('should work with xAI provider', async () => {
+ const xaiExecutor = RuntimeExecutor.create('xai', mockProviderConfigs.xai)
+
+ const xaiModel = createMockLanguageModel({
+ provider: 'xai',
+ modelId: 'grok-2-latest'
+ })
+
+ vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(xaiModel)
+
+ await xaiExecutor.generateText({
+ model: 'grok-2-latest',
+ messages: testMessages.simple
+ })
+
+ expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('xai|grok-2-latest')
+ })
+
+ it('should work with DeepSeek provider', async () => {
+ const deepseekExecutor = RuntimeExecutor.create('deepseek', mockProviderConfigs.deepseek)
+
+ const deepseekModel = createMockLanguageModel({
+ provider: 'deepseek',
+ modelId: 'deepseek-chat'
+ })
+
+ vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(deepseekModel)
+
+ await deepseekExecutor.generateText({
+ model: 'deepseek-chat',
+ messages: testMessages.simple
+ })
+
+ expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('deepseek|deepseek-chat')
+ })
+ })
+
+ describe('Plugin Integration', () => {
+ it('should execute all plugin hooks', async () => {
+ const pluginCalls: string[] = []
+
+ const testPlugin: AiPlugin = {
+ name: 'test-plugin',
+ onRequestStart: vi.fn(async () => {
+ pluginCalls.push('onRequestStart')
+ }),
+ transformParams: vi.fn(async (params) => {
+ pluginCalls.push('transformParams')
+ return { ...params, temperature: 0.8 }
+ }),
+ transformResult: vi.fn(async (result) => {
+ pluginCalls.push('transformResult')
+ return { ...result, text: result.text + ' [modified]' }
+ }),
+ onRequestEnd: vi.fn(async () => {
+ pluginCalls.push('onRequestEnd')
+ })
+ }
+
+ const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
+
+ const result = await executorWithPlugin.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple
+ })
+
+ expect(pluginCalls).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
+
+ // Verify transformed parameters
+ expect(generateText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ temperature: 0.8
+ })
+ )
+
+ // Verify transformed result
+ expect(result.text).toContain('[modified]')
+ })
+
+ it('should handle multiple plugins in order', async () => {
+ const pluginOrder: string[] = []
+
+ const plugin1: AiPlugin = {
+ name: 'plugin-1',
+ transformParams: vi.fn(async (params) => {
+ pluginOrder.push('plugin-1')
+ return { ...params, temperature: 0.5 }
+ })
+ }
+
+ const plugin2: AiPlugin = {
+ name: 'plugin-2',
+ transformParams: vi.fn(async (params) => {
+ pluginOrder.push('plugin-2')
+ return { ...params, maxTokens: 200 }
+ })
+ }
+
+ const executorWithPlugins = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [plugin1, plugin2])
+
+ await executorWithPlugins.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple
+ })
+
+ expect(pluginOrder).toEqual(['plugin-1', 'plugin-2'])
+
+ expect(generateText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ temperature: 0.5,
+ maxTokens: 200
+ })
+ )
+ })
+ })
+
+ describe('Error Handling', () => {
+ it('should handle API errors', async () => {
+ const error = new Error('API request failed')
+ vi.mocked(generateText).mockRejectedValue(error)
+
+ await expect(
+ executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple
+ })
+ ).rejects.toThrow('API request failed')
+ })
+
+ it('should execute onError plugin hook', async () => {
+ const error = new Error('Generation failed')
+ vi.mocked(generateText).mockRejectedValue(error)
+
+ const errorPlugin: AiPlugin = {
+ name: 'error-handler',
+ onError: vi.fn()
+ }
+
+ const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
+
+ await expect(
+ executorWithPlugin.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple
+ })
+ ).rejects.toThrow('Generation failed')
+
+ // onError receives the original error and context with core fields
+ expect(errorPlugin.onError).toHaveBeenCalledWith(
+ error,
+ expect.objectContaining({
+ providerId: 'openai',
+ model: 'gpt-4'
+ })
+ )
+ })
+
+ it('should handle model not found error', async () => {
+ const error = new Error('Model not found: invalid-model')
+ vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
+ throw error
+ })
+
+ await expect(
+ executor.generateText({
+ model: 'invalid-model',
+ messages: testMessages.simple
+ })
+ ).rejects.toThrow('Model not found')
+ })
+ })
+
+ describe('Usage and Metadata', () => {
+ it('should return usage information', async () => {
+ const result = await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple
+ })
+
+ expect(result.usage).toBeDefined()
+ expect(result.usage.inputTokens).toBe(15)
+ expect(result.usage.outputTokens).toBe(8)
+ expect(result.usage.totalTokens).toBe(23)
+ })
+
+ it('should handle warnings', async () => {
+ vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withWarnings as any)
+
+ const result = await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ temperature: 2.5 // Unsupported value
+ })
+
+ expect(result.warnings).toBeDefined()
+ expect(result.warnings).toHaveLength(1)
+ expect(result.warnings![0].type).toBe('unsupported-setting')
+ })
+ })
+
+ describe('Abort Signal', () => {
+ it('should support abort signal', async () => {
+ const abortController = new AbortController()
+
+ await executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ abortSignal: abortController.signal
+ })
+
+ expect(generateText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ abortSignal: abortController.signal
+ })
+ )
+ })
+
+ it('should handle aborted request', async () => {
+ const abortError = new Error('Request aborted')
+ abortError.name = 'AbortError'
+
+ vi.mocked(generateText).mockRejectedValue(abortError)
+
+ const abortController = new AbortController()
+ abortController.abort()
+
+ await expect(
+ executor.generateText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ abortSignal: abortController.signal
+ })
+ ).rejects.toThrow('Request aborted')
+ })
+ })
+})
diff --git a/packages/aiCore/src/core/runtime/__tests__/streamText.test.ts b/packages/aiCore/src/core/runtime/__tests__/streamText.test.ts
new file mode 100644
index 0000000000..49253594cc
--- /dev/null
+++ b/packages/aiCore/src/core/runtime/__tests__/streamText.test.ts
@@ -0,0 +1,531 @@
+/**
+ * RuntimeExecutor.streamText Comprehensive Tests
+ * Tests streaming text generation across all providers with various parameters
+ */
+
+import { streamText } from 'ai'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+import { collectStreamChunks, createMockLanguageModel, mockProviderConfigs, testMessages } from '../../../__tests__'
+import type { AiPlugin } from '../../plugins'
+import { globalRegistryManagement } from '../../providers/RegistryManagement'
+import { RuntimeExecutor } from '../executor'
+
+// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
+vi.mock('ai', async (importOriginal) => {
+ const actual = (await importOriginal()) as Record
+ return {
+ ...actual,
+ streamText: vi.fn()
+ }
+})
+
+vi.mock('../../providers/RegistryManagement', () => ({
+ globalRegistryManagement: {
+ languageModel: vi.fn()
+ },
+ DEFAULT_SEPARATOR: '|'
+}))
+
+describe('RuntimeExecutor.streamText', () => {
+ let executor: RuntimeExecutor<'openai'>
+ let mockLanguageModel: any
+
+ beforeEach(() => {
+ vi.clearAllMocks()
+
+ executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
+
+ mockLanguageModel = createMockLanguageModel({
+ provider: 'openai',
+ modelId: 'gpt-4'
+ })
+
+ vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
+ })
+
+ describe('Basic Functionality', () => {
+ it('should stream text with minimal parameters', async () => {
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Hello'
+ yield ' '
+ yield 'World'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Hello' }
+ yield { type: 'text-delta', textDelta: ' ' }
+ yield { type: 'text-delta', textDelta: 'World' }
+ })(),
+ usage: Promise.resolve({ promptTokens: 5, completionTokens: 3, totalTokens: 8 })
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ const result = await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple
+ })
+
+ expect(streamText).toHaveBeenCalledWith({
+ model: mockLanguageModel,
+ messages: testMessages.simple
+ })
+
+ const chunks = await collectStreamChunks(result.textStream)
+ expect(chunks).toEqual(['Hello', ' ', 'World'])
+ })
+
+ it('should stream with system messages', async () => {
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Response' }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.withSystem
+ })
+
+ expect(streamText).toHaveBeenCalledWith({
+ model: mockLanguageModel,
+ messages: testMessages.withSystem
+ })
+ })
+
+ it('should stream multi-turn conversations', async () => {
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Multi-turn response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Multi-turn response' }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.multiTurn
+ })
+
+ expect(streamText).toHaveBeenCalled()
+ expect(streamText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ messages: testMessages.multiTurn
+ })
+ )
+ })
+ })
+
+ describe('Temperature Parameter', () => {
+ const temperatures = [0, 0.3, 0.5, 0.7, 0.9, 1.0, 1.5, 2.0]
+
+ it.each(temperatures)('should support temperature=%s', async (temperature) => {
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Response' }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ temperature
+ })
+
+ expect(streamText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ temperature
+ })
+ )
+ })
+ })
+
+ describe('Max Tokens Parameter', () => {
+ const maxTokensValues = [10, 50, 100, 500, 1000, 2000, 4000]
+
+ it.each(maxTokensValues)('should support maxOutputTokens=%s', async (maxOutputTokens) => {
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Response' }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ maxOutputTokens
+ })
+
+ // Parameters are passed through without transformation
+ expect(streamText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ maxOutputTokens
+ })
+ )
+ })
+ })
+
+ describe('Top P Parameter', () => {
+ const topPValues = [0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 1.0]
+
+ it.each(topPValues)('should support topP=%s', async (topP) => {
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Response' }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ topP
+ })
+
+ expect(streamText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ topP
+ })
+ )
+ })
+ })
+
+ describe('Frequency and Presence Penalty', () => {
+ it('should support frequency penalty', async () => {
+ const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
+
+ for (const frequencyPenalty of penalties) {
+ vi.clearAllMocks()
+
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Response' }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ frequencyPenalty
+ })
+
+ expect(streamText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ frequencyPenalty
+ })
+ )
+ }
+ })
+
+ it('should support presence penalty', async () => {
+ const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
+
+ for (const presencePenalty of penalties) {
+ vi.clearAllMocks()
+
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Response' }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ presencePenalty
+ })
+
+ expect(streamText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ presencePenalty
+ })
+ )
+ }
+ })
+
+ it('should support both penalties together', async () => {
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Response' }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ frequencyPenalty: 0.5,
+ presencePenalty: 0.5
+ })
+
+ expect(streamText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ frequencyPenalty: 0.5,
+ presencePenalty: 0.5
+ })
+ )
+ })
+ })
+
+ describe('Seed Parameter', () => {
+ it('should support seed for deterministic output', async () => {
+ const seeds = [0, 12345, 67890, 999999]
+
+ for (const seed of seeds) {
+ vi.clearAllMocks()
+
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Response' }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ seed
+ })
+
+ expect(streamText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ seed
+ })
+ )
+ }
+ })
+ })
+
+ describe('Abort Signal', () => {
+ it('should support abort signal', async () => {
+ const abortController = new AbortController()
+
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Response' }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ abortSignal: abortController.signal
+ })
+
+ expect(streamText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ abortSignal: abortController.signal
+ })
+ )
+ })
+
+ it('should handle abort during streaming', async () => {
+ const abortController = new AbortController()
+
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Start'
+ // Simulate abort
+ abortController.abort()
+ throw new Error('Aborted')
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Start' }
+ throw new Error('Aborted')
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ const result = await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple,
+ abortSignal: abortController.signal
+ })
+
+ await expect(async () => {
+ // oxlint-disable-next-line no-unused-vars
+ for await (const _chunk of result.textStream) {
+ // Stream should be interrupted
+ }
+ }).rejects.toThrow('Aborted')
+ })
+ })
+
+ describe('Plugin Integration', () => {
+ it('should execute plugins during streaming', async () => {
+ const pluginCalls: string[] = []
+
+ const testPlugin: AiPlugin = {
+ name: 'test-plugin',
+ onRequestStart: vi.fn(async () => {
+ pluginCalls.push('onRequestStart')
+ }),
+ transformParams: vi.fn(async (params) => {
+ pluginCalls.push('transformParams')
+ return { ...params, temperature: 0.5 }
+ }),
+ onRequestEnd: vi.fn(async () => {
+ pluginCalls.push('onRequestEnd')
+ })
+ }
+
+ const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
+
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Response' }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ const result = await executorWithPlugin.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple
+ })
+
+ // Consume stream
+ // oxlint-disable-next-line no-unused-vars
+ for await (const _chunk of result.textStream) {
+ // Stream chunks
+ }
+
+ expect(pluginCalls).toContain('onRequestStart')
+ expect(pluginCalls).toContain('transformParams')
+
+ // Verify transformed parameters were used
+ expect(streamText).toHaveBeenCalledWith(
+ expect.objectContaining({
+ temperature: 0.5
+ })
+ )
+ })
+ })
+
+ describe('Full Stream with Finish Reason', () => {
+ it('should provide finish reason in full stream', async () => {
+ const mockStream = {
+ textStream: (async function* () {
+ yield 'Response'
+ })(),
+ fullStream: (async function* () {
+ yield { type: 'text-delta', textDelta: 'Response' }
+ yield {
+ type: 'finish',
+ finishReason: 'stop',
+ usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
+ }
+ })()
+ }
+
+ vi.mocked(streamText).mockResolvedValue(mockStream as any)
+
+ const result = await executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple
+ })
+
+ const fullChunks = await collectStreamChunks(result.fullStream)
+
+ expect(fullChunks).toHaveLength(2)
+ expect(fullChunks[0]).toEqual({ type: 'text-delta', textDelta: 'Response' })
+ expect(fullChunks[1]).toEqual({
+ type: 'finish',
+ finishReason: 'stop',
+ usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
+ })
+ })
+ })
+
+ describe('Error Handling', () => {
+ it('should handle streaming errors', async () => {
+ const error = new Error('Streaming failed')
+ vi.mocked(streamText).mockRejectedValue(error)
+
+ await expect(
+ executor.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple
+ })
+ ).rejects.toThrow('Streaming failed')
+ })
+
+ it('should execute onError plugin hook on failure', async () => {
+ const error = new Error('Stream error')
+ vi.mocked(streamText).mockRejectedValue(error)
+
+ const errorPlugin: AiPlugin = {
+ name: 'error-handler',
+ onError: vi.fn()
+ }
+
+ const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
+
+ await expect(
+ executorWithPlugin.streamText({
+ model: 'gpt-4',
+ messages: testMessages.simple
+ })
+ ).rejects.toThrow('Stream error')
+
+ // onError receives the original error and context with core fields
+ expect(errorPlugin.onError).toHaveBeenCalledWith(
+ error,
+ expect.objectContaining({
+ providerId: 'openai',
+ model: 'gpt-4'
+ })
+ )
+ })
+ })
+})
diff --git a/packages/aiCore/vitest.config.ts b/packages/aiCore/vitest.config.ts
index 0cc6b51df4..2f520ea967 100644
--- a/packages/aiCore/vitest.config.ts
+++ b/packages/aiCore/vitest.config.ts
@@ -1,12 +1,20 @@
+import path from 'node:path'
+import { fileURLToPath } from 'node:url'
+
import { defineConfig } from 'vitest/config'
+const __dirname = path.dirname(fileURLToPath(import.meta.url))
+
export default defineConfig({
test: {
- globals: true
+ globals: true,
+ setupFiles: [path.resolve(__dirname, './src/__tests__/setup.ts')]
},
resolve: {
alias: {
- '@': './src'
+ '@': path.resolve(__dirname, './src'),
+ // Mock external packages that may not be available in test environment
+ '@cherrystudio/ai-sdk-provider': path.resolve(__dirname, './src/__tests__/mocks/ai-sdk-provider.ts')
}
},
esbuild: {
diff --git a/packages/shared/IpcChannel.ts b/packages/shared/IpcChannel.ts
index 111f9304e4..31db6bece9 100644
--- a/packages/shared/IpcChannel.ts
+++ b/packages/shared/IpcChannel.ts
@@ -41,6 +41,7 @@ export enum IpcChannel {
App_SetFullScreen = 'app:set-full-screen',
App_IsFullScreen = 'app:is-full-screen',
App_GetSystemFonts = 'app:get-system-fonts',
+ APP_CrashRenderProcess = 'app:crash-render-process',
App_MacIsProcessTrusted = 'app:mac-is-process-trusted',
App_MacRequestProcessTrust = 'app:mac-request-process-trust',
@@ -189,6 +190,7 @@ export enum IpcChannel {
Fs_ReadText = 'fs:readText',
File_OpenWithRelativePath = 'file:openWithRelativePath',
File_IsTextFile = 'file:isTextFile',
+ File_ListDirectory = 'file:listDirectory',
File_GetDirectoryStructure = 'file:getDirectoryStructure',
File_CheckFileName = 'file:checkFileName',
File_ValidateNotesDirectory = 'file:validateNotesDirectory',
@@ -233,6 +235,7 @@ export enum IpcChannel {
System_GetDeviceType = 'system:getDeviceType',
System_GetHostname = 'system:getHostname',
System_GetCpuName = 'system:getCpuName',
+ System_CheckGitBash = 'system:checkGitBash',
// DevTools
System_ToggleDevTools = 'system:toggleDevTools',
diff --git a/packages/shared/anthropic/index.ts b/packages/shared/anthropic/index.ts
index b9e9cb8846..bff143d118 100644
--- a/packages/shared/anthropic/index.ts
+++ b/packages/shared/anthropic/index.ts
@@ -88,11 +88,16 @@ export function getSdkClient(
}
})
}
- const baseURL =
+ let baseURL =
provider.type === 'anthropic'
? provider.apiHost
: (provider.anthropicApiHost && provider.anthropicApiHost.trim()) || provider.apiHost
+ // Anthropic SDK automatically appends /v1 to all endpoints (like /v1/messages, /v1/models)
+ // We need to strip api version from baseURL to avoid duplication (e.g., /v3/v1/models)
+ // formatProviderApiHost adds /v1 for AI SDK compatibility, but Anthropic SDK needs it removed
+ baseURL = baseURL.replace(/\/v\d+(?:alpha|beta)?(?=\/|$)/i, '')
+
logger.debug('Anthropic API baseURL', { baseURL, providerId: provider.id })
if (provider.id === 'aihubmix') {
diff --git a/packages/shared/config/constant.ts b/packages/shared/config/constant.ts
index 3b38592005..1e02ce7706 100644
--- a/packages/shared/config/constant.ts
+++ b/packages/shared/config/constant.ts
@@ -7,6 +7,11 @@ export const documentExts = ['.pdf', '.doc', '.docx', '.pptx', '.xlsx', '.odt',
export const thirdPartyApplicationExts = ['.draftsExport']
export const bookExts = ['.epub']
+export const API_SERVER_DEFAULTS = {
+ HOST: '127.0.0.1',
+ PORT: 23333
+}
+
/**
* A flat array of all file extensions known by the linguist database.
* This is the primary source for identifying code files.
@@ -197,12 +202,22 @@ export enum FeedUrl {
GITHUB_LATEST = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download'
}
+export enum UpdateConfigUrl {
+ GITHUB = 'https://raw.githubusercontent.com/CherryHQ/cherry-studio/refs/heads/x-files/app-upgrade-config/app-upgrade-config.json',
+ GITCODE = 'https://raw.gitcode.com/CherryHQ/cherry-studio/raw/x-files%2Fapp-upgrade-config/app-upgrade-config.json'
+}
+
export enum UpgradeChannel {
LATEST = 'latest', // 最新稳定版本
RC = 'rc', // 公测版本
BETA = 'beta' // 预览版本
}
+export enum UpdateMirror {
+ GITHUB = 'github',
+ GITCODE = 'gitcode'
+}
+
export const defaultTimeout = 10 * 1000 * 60
export const occupiedDirs = ['logs', 'Network', 'Partitions/webview/Network']
@@ -470,3 +485,6 @@ export const MACOS_TERMINALS_WITH_COMMANDS: TerminalConfigWithCommand[] = [
})
}
]
+
+// resources/scripts should be maintained manually
+export const HOME_CHERRY_DIR = '.cherrystudio'
diff --git a/packages/shared/config/providers.ts b/packages/shared/config/providers.ts
new file mode 100644
index 0000000000..f7744150e2
--- /dev/null
+++ b/packages/shared/config/providers.ts
@@ -0,0 +1,48 @@
+/**
+ * @fileoverview Shared provider configuration for Claude Code and Anthropic API compatibility
+ *
+ * This module defines which models from specific providers support the Anthropic API endpoint.
+ * Used by both the Code Tools page and the Anthropic SDK client.
+ */
+
+/**
+ * Silicon provider models that support Anthropic API endpoint.
+ * These models can be used with Claude Code via the Anthropic-compatible API.
+ *
+ * @see https://docs.siliconflow.cn/cn/api-reference/chat-completions/messages
+ */
+export const SILICON_ANTHROPIC_COMPATIBLE_MODELS: readonly string[] = [
+ // DeepSeek V3.1 series
+ 'Pro/deepseek-ai/DeepSeek-V3.1-Terminus',
+ 'deepseek-ai/DeepSeek-V3.1',
+ 'Pro/deepseek-ai/DeepSeek-V3.1',
+ // DeepSeek V3 series
+ 'deepseek-ai/DeepSeek-V3',
+ 'Pro/deepseek-ai/DeepSeek-V3',
+ // Moonshot/Kimi series
+ 'moonshotai/Kimi-K2-Instruct-0905',
+ 'Pro/moonshotai/Kimi-K2-Instruct-0905',
+ 'moonshotai/Kimi-Dev-72B',
+ // Baidu ERNIE
+ 'baidu/ERNIE-4.5-300B-A47B'
+]
+
+/**
+ * Creates a Set for efficient lookup of silicon Anthropic-compatible model IDs.
+ */
+const SILICON_ANTHROPIC_COMPATIBLE_MODEL_SET = new Set(SILICON_ANTHROPIC_COMPATIBLE_MODELS)
+
+/**
+ * Checks if a model ID is compatible with Anthropic API on Silicon provider.
+ *
+ * @param modelId - The model ID to check
+ * @returns true if the model supports Anthropic API endpoint
+ */
+export function isSiliconAnthropicCompatibleModel(modelId: string): boolean {
+ return SILICON_ANTHROPIC_COMPATIBLE_MODEL_SET.has(modelId)
+}
+
+/**
+ * Silicon provider's Anthropic API host URL.
+ */
+export const SILICON_ANTHROPIC_API_HOST = 'https://api.siliconflow.cn'
diff --git a/packages/shared/utils.ts b/packages/shared/utils.ts
index e87e2f2bef..a14f78958d 100644
--- a/packages/shared/utils.ts
+++ b/packages/shared/utils.ts
@@ -4,3 +4,34 @@ export const defaultAppHeaders = () => {
'X-Title': 'Cherry Studio'
}
}
+
+// Following two function are not being used for now.
+// I may use them in the future, so just keep them commented. - by eurfelux
+
+/**
+ * Converts an `undefined` value to `null`, otherwise returns the value as-is.
+ * @param value - The value to check
+ * @returns `null` if the input is `undefined`; otherwise the input value
+ */
+
+// export function toNullIfUndefined(value: T | undefined): T | null {
+// if (value === undefined) {
+// return null
+// } else {
+// return value
+// }
+// }
+
+/**
+ * Converts a `null` value to `undefined`, otherwise returns the value as-is.
+ * @param value - The value to check
+ * @returns `undefined` if the input is `null`; otherwise the input value
+ */
+
+// export function toUndefinedIfNull(value: T | null): T | undefined {
+// if (value === null) {
+// return undefined
+// } else {
+// return value
+// }
+// }
diff --git a/playwright.config.ts b/playwright.config.ts
index e12ce7ab6d..0b67f0e76f 100644
--- a/playwright.config.ts
+++ b/playwright.config.ts
@@ -1,42 +1,64 @@
-import { defineConfig, devices } from '@playwright/test'
+import { defineConfig } from '@playwright/test'
/**
- * See https://playwright.dev/docs/test-configuration.
+ * Playwright configuration for Electron e2e testing.
+ * See https://playwright.dev/docs/test-configuration
*/
export default defineConfig({
- // Look for test files, relative to this configuration file.
- testDir: './tests/e2e',
- /* Run tests in files in parallel */
- fullyParallel: true,
- /* Fail the build on CI if you accidentally left test.only in the source code. */
- forbidOnly: !!process.env.CI,
- /* Retry on CI only */
- retries: process.env.CI ? 2 : 0,
- /* Opt out of parallel tests on CI. */
- workers: process.env.CI ? 1 : undefined,
- /* Reporter to use. See https://playwright.dev/docs/test-reporters */
- reporter: 'html',
- /* Shared settings for all the projects below. See https://playwright.dev/docs/api/class-testoptions. */
- use: {
- /* Base URL to use in actions like `await page.goto('/')`. */
- // baseURL: 'http://localhost:3000',
+ // Look for test files in the specs directory
+ testDir: './tests/e2e/specs',
- /* Collect trace when retrying the failed test. See https://playwright.dev/docs/trace-viewer */
- trace: 'on-first-retry'
+ // Global timeout for each test
+ timeout: 60000,
+
+ // Assertion timeout
+ expect: {
+ timeout: 10000
},
- /* Configure projects for major browsers */
+ // Electron apps should run tests sequentially to avoid conflicts
+ fullyParallel: false,
+ workers: 1,
+
+ // Fail the build on CI if you accidentally left test.only in the source code
+ forbidOnly: !!process.env.CI,
+
+ // Retry on CI only
+ retries: process.env.CI ? 2 : 0,
+
+ // Reporter configuration
+ reporter: [['html', { outputFolder: 'playwright-report' }], ['list']],
+
+ // Global setup and teardown
+ globalSetup: './tests/e2e/global-setup.ts',
+ globalTeardown: './tests/e2e/global-teardown.ts',
+
+ // Output directory for test artifacts
+ outputDir: './test-results',
+
+ // Shared settings for all tests
+ use: {
+ // Collect trace when retrying the failed test
+ trace: 'retain-on-failure',
+
+ // Take screenshot only on failure
+ screenshot: 'only-on-failure',
+
+ // Record video only on failure
+ video: 'retain-on-failure',
+
+ // Action timeout
+ actionTimeout: 15000,
+
+ // Navigation timeout
+ navigationTimeout: 30000
+ },
+
+ // Single project for Electron testing
projects: [
{
- name: 'chromium',
- use: { ...devices['Desktop Chrome'] }
+ name: 'electron',
+ testMatch: '**/*.spec.ts'
}
]
-
- /* Run your local dev server before starting the tests */
- // webServer: {
- // command: 'npm run start',
- // url: 'http://localhost:3000',
- // reuseExistingServer: !process.env.CI,
- // },
})
diff --git a/resources/database/drizzle/0002_wealthy_naoko.sql b/resources/database/drizzle/0002_wealthy_naoko.sql
new file mode 100644
index 0000000000..c369ccf61f
--- /dev/null
+++ b/resources/database/drizzle/0002_wealthy_naoko.sql
@@ -0,0 +1 @@
+ALTER TABLE `sessions` ADD `slash_commands` text;
\ No newline at end of file
diff --git a/resources/database/drizzle/meta/0002_snapshot.json b/resources/database/drizzle/meta/0002_snapshot.json
new file mode 100644
index 0000000000..ef5eefcb65
--- /dev/null
+++ b/resources/database/drizzle/meta/0002_snapshot.json
@@ -0,0 +1,346 @@
+{
+ "version": "6",
+ "dialect": "sqlite",
+ "id": "0cf3d79e-69bf-4dba-8df4-996b9b67d2e8",
+ "prevId": "dabab6db-a2cd-4e96-b06e-6cb87d445a87",
+ "tables": {
+ "agents": {
+ "name": "agents",
+ "columns": {
+ "id": {
+ "name": "id",
+ "type": "text",
+ "primaryKey": true,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "type": {
+ "name": "type",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "name": {
+ "name": "name",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "description": {
+ "name": "description",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "accessible_paths": {
+ "name": "accessible_paths",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "instructions": {
+ "name": "instructions",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "model": {
+ "name": "model",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "plan_model": {
+ "name": "plan_model",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "small_model": {
+ "name": "small_model",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "mcps": {
+ "name": "mcps",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "allowed_tools": {
+ "name": "allowed_tools",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "configuration": {
+ "name": "configuration",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "created_at": {
+ "name": "created_at",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "updated_at": {
+ "name": "updated_at",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ }
+ },
+ "indexes": {},
+ "foreignKeys": {},
+ "compositePrimaryKeys": {},
+ "uniqueConstraints": {},
+ "checkConstraints": {}
+ },
+ "session_messages": {
+ "name": "session_messages",
+ "columns": {
+ "id": {
+ "name": "id",
+ "type": "integer",
+ "primaryKey": true,
+ "notNull": true,
+ "autoincrement": true
+ },
+ "session_id": {
+ "name": "session_id",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "role": {
+ "name": "role",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "content": {
+ "name": "content",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "agent_session_id": {
+ "name": "agent_session_id",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false,
+ "default": "''"
+ },
+ "metadata": {
+ "name": "metadata",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "created_at": {
+ "name": "created_at",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "updated_at": {
+ "name": "updated_at",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ }
+ },
+ "indexes": {},
+ "foreignKeys": {},
+ "compositePrimaryKeys": {},
+ "uniqueConstraints": {},
+ "checkConstraints": {}
+ },
+ "migrations": {
+ "name": "migrations",
+ "columns": {
+ "version": {
+ "name": "version",
+ "type": "integer",
+ "primaryKey": true,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "tag": {
+ "name": "tag",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "executed_at": {
+ "name": "executed_at",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ }
+ },
+ "indexes": {},
+ "foreignKeys": {},
+ "compositePrimaryKeys": {},
+ "uniqueConstraints": {},
+ "checkConstraints": {}
+ },
+ "sessions": {
+ "name": "sessions",
+ "columns": {
+ "id": {
+ "name": "id",
+ "type": "text",
+ "primaryKey": true,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "agent_type": {
+ "name": "agent_type",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "agent_id": {
+ "name": "agent_id",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "name": {
+ "name": "name",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "description": {
+ "name": "description",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "accessible_paths": {
+ "name": "accessible_paths",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "instructions": {
+ "name": "instructions",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "model": {
+ "name": "model",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "plan_model": {
+ "name": "plan_model",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "small_model": {
+ "name": "small_model",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "mcps": {
+ "name": "mcps",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "allowed_tools": {
+ "name": "allowed_tools",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "slash_commands": {
+ "name": "slash_commands",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "configuration": {
+ "name": "configuration",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false,
+ "autoincrement": false
+ },
+ "created_at": {
+ "name": "created_at",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ },
+ "updated_at": {
+ "name": "updated_at",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "autoincrement": false
+ }
+ },
+ "indexes": {},
+ "foreignKeys": {},
+ "compositePrimaryKeys": {},
+ "uniqueConstraints": {},
+ "checkConstraints": {}
+ }
+ },
+ "views": {},
+ "enums": {},
+ "_meta": {
+ "schemas": {},
+ "tables": {},
+ "columns": {}
+ },
+ "internal": {
+ "indexes": {}
+ }
+}
diff --git a/resources/database/drizzle/meta/_journal.json b/resources/database/drizzle/meta/_journal.json
index 8648e01703..ac026637aa 100644
--- a/resources/database/drizzle/meta/_journal.json
+++ b/resources/database/drizzle/meta/_journal.json
@@ -15,6 +15,13 @@
"when": 1758187378775,
"tag": "0001_woozy_captain_flint",
"breakpoints": true
+ },
+ {
+ "idx": 2,
+ "version": "6",
+ "when": 1762526423527,
+ "tag": "0002_wealthy_naoko",
+ "breakpoints": true
}
]
}
diff --git a/resources/js/bridge.js b/resources/js/bridge.js
deleted file mode 100644
index f6c0021a63..0000000000
--- a/resources/js/bridge.js
+++ /dev/null
@@ -1,36 +0,0 @@
-;(() => {
- let messageId = 0
- const pendingCalls = new Map()
-
- function api(method, ...args) {
- const id = messageId++
- return new Promise((resolve, reject) => {
- pendingCalls.set(id, { resolve, reject })
- window.parent.postMessage({ id, type: 'api-call', method, args }, '*')
- })
- }
-
- window.addEventListener('message', (event) => {
- if (event.data.type === 'api-response') {
- const { id, result, error } = event.data
- const pendingCall = pendingCalls.get(id)
- if (pendingCall) {
- if (error) {
- pendingCall.reject(new Error(error))
- } else {
- pendingCall.resolve(result)
- }
- pendingCalls.delete(id)
- }
- }
- })
-
- window.api = new Proxy(
- {},
- {
- get: (target, prop) => {
- return (...args) => api(prop, ...args)
- }
- }
- )
-})()
diff --git a/resources/js/utils.js b/resources/js/utils.js
deleted file mode 100644
index 36981ac44f..0000000000
--- a/resources/js/utils.js
+++ /dev/null
@@ -1,5 +0,0 @@
-export function getQueryParam(paramName) {
- const url = new URL(window.location.href)
- const params = new URLSearchParams(url.search)
- return params.get(paramName)
-}
diff --git a/resources/scripts/install-bun.js b/resources/scripts/install-bun.js
index 1467a4cde4..33ee18d732 100644
--- a/resources/scripts/install-bun.js
+++ b/resources/scripts/install-bun.js
@@ -7,7 +7,7 @@ const { downloadWithRedirects } = require('./download')
// Base URL for downloading bun binaries
const BUN_RELEASE_BASE_URL = 'https://gitcode.com/CherryHQ/bun/releases/download'
-const DEFAULT_BUN_VERSION = '1.2.17' // Default fallback version
+const DEFAULT_BUN_VERSION = '1.3.1' // Default fallback version
// Mapping of platform+arch to binary package name
const BUN_PACKAGES = {
diff --git a/resources/scripts/install-uv.js b/resources/scripts/install-uv.js
index 3dc8b3e477..c3d34efc33 100644
--- a/resources/scripts/install-uv.js
+++ b/resources/scripts/install-uv.js
@@ -7,28 +7,29 @@ const { downloadWithRedirects } = require('./download')
// Base URL for downloading uv binaries
const UV_RELEASE_BASE_URL = 'https://gitcode.com/CherryHQ/uv/releases/download'
-const DEFAULT_UV_VERSION = '0.7.13'
+const DEFAULT_UV_VERSION = '0.9.5'
// Mapping of platform+arch to binary package name
const UV_PACKAGES = {
- 'darwin-arm64': 'uv-aarch64-apple-darwin.zip',
- 'darwin-x64': 'uv-x86_64-apple-darwin.zip',
+ 'darwin-arm64': 'uv-aarch64-apple-darwin.tar.gz',
+ 'darwin-x64': 'uv-x86_64-apple-darwin.tar.gz',
'win32-arm64': 'uv-aarch64-pc-windows-msvc.zip',
'win32-ia32': 'uv-i686-pc-windows-msvc.zip',
'win32-x64': 'uv-x86_64-pc-windows-msvc.zip',
- 'linux-arm64': 'uv-aarch64-unknown-linux-gnu.zip',
- 'linux-ia32': 'uv-i686-unknown-linux-gnu.zip',
- 'linux-ppc64': 'uv-powerpc64-unknown-linux-gnu.zip',
- 'linux-ppc64le': 'uv-powerpc64le-unknown-linux-gnu.zip',
- 'linux-s390x': 'uv-s390x-unknown-linux-gnu.zip',
- 'linux-x64': 'uv-x86_64-unknown-linux-gnu.zip',
- 'linux-armv7l': 'uv-armv7-unknown-linux-gnueabihf.zip',
+ 'linux-arm64': 'uv-aarch64-unknown-linux-gnu.tar.gz',
+ 'linux-ia32': 'uv-i686-unknown-linux-gnu.tar.gz',
+ 'linux-ppc64': 'uv-powerpc64-unknown-linux-gnu.tar.gz',
+ 'linux-ppc64le': 'uv-powerpc64le-unknown-linux-gnu.tar.gz',
+ 'linux-riscv64': 'uv-riscv64gc-unknown-linux-gnu.tar.gz',
+ 'linux-s390x': 'uv-s390x-unknown-linux-gnu.tar.gz',
+ 'linux-x64': 'uv-x86_64-unknown-linux-gnu.tar.gz',
+ 'linux-armv7l': 'uv-armv7-unknown-linux-gnueabihf.tar.gz',
// MUSL variants
- 'linux-musl-arm64': 'uv-aarch64-unknown-linux-musl.zip',
- 'linux-musl-ia32': 'uv-i686-unknown-linux-musl.zip',
- 'linux-musl-x64': 'uv-x86_64-unknown-linux-musl.zip',
- 'linux-musl-armv6l': 'uv-arm-unknown-linux-musleabihf.zip',
- 'linux-musl-armv7l': 'uv-armv7-unknown-linux-musleabihf.zip'
+ 'linux-musl-arm64': 'uv-aarch64-unknown-linux-musl.tar.gz',
+ 'linux-musl-ia32': 'uv-i686-unknown-linux-musl.tar.gz',
+ 'linux-musl-x64': 'uv-x86_64-unknown-linux-musl.tar.gz',
+ 'linux-musl-armv6l': 'uv-arm-unknown-linux-musleabihf.tar.gz',
+ 'linux-musl-armv7l': 'uv-armv7-unknown-linux-musleabihf.tar.gz'
}
/**
@@ -56,6 +57,7 @@ async function downloadUvBinary(platform, arch, version = DEFAULT_UV_VERSION, is
const downloadUrl = `${UV_RELEASE_BASE_URL}/${version}/${packageName}`
const tempdir = os.tmpdir()
const tempFilename = path.join(tempdir, packageName)
+ const isTarGz = packageName.endsWith('.tar.gz')
try {
console.log(`Downloading uv ${version} for ${platformKey}...`)
@@ -65,34 +67,58 @@ async function downloadUvBinary(platform, arch, version = DEFAULT_UV_VERSION, is
console.log(`Extracting ${packageName} to ${binDir}...`)
- const zip = new StreamZip.async({ file: tempFilename })
+ if (isTarGz) {
+ // Use tar command to extract tar.gz files (macOS and Linux)
+ const tempExtractDir = path.join(tempdir, `uv-extract-${Date.now()}`)
+ fs.mkdirSync(tempExtractDir, { recursive: true })
- // Get all entries in the zip file
- const entries = await zip.entries()
+ execSync(`tar -xzf "${tempFilename}" -C "${tempExtractDir}"`, { stdio: 'inherit' })
- // Extract files directly to binDir, flattening the directory structure
- for (const entry of Object.values(entries)) {
- if (!entry.isDirectory) {
- // Get just the filename without path
- const filename = path.basename(entry.name)
- const outputPath = path.join(binDir, filename)
-
- console.log(`Extracting ${entry.name} -> ${filename}`)
- await zip.extract(entry.name, outputPath)
- // Make executable files executable on Unix-like systems
- if (platform !== 'win32') {
- try {
+ // Find all files in the extracted directory and move them to binDir
+ const findAndMoveFiles = (dir) => {
+ const entries = fs.readdirSync(dir, { withFileTypes: true })
+ for (const entry of entries) {
+ const fullPath = path.join(dir, entry.name)
+ if (entry.isDirectory()) {
+ findAndMoveFiles(fullPath)
+ } else {
+ const filename = path.basename(entry.name)
+ const outputPath = path.join(binDir, filename)
+ fs.copyFileSync(fullPath, outputPath)
+ console.log(`Extracted ${entry.name} -> ${outputPath}`)
+ // Make executable on Unix-like systems
fs.chmodSync(outputPath, 0o755)
- } catch (chmodError) {
- console.error(`Warning: Failed to set executable permissions on ${filename}`)
- return 102
}
}
- console.log(`Extracted ${entry.name} -> ${outputPath}`)
}
+
+ findAndMoveFiles(tempExtractDir)
+
+ // Clean up temporary extraction directory
+ fs.rmSync(tempExtractDir, { recursive: true })
+ } else {
+ // Use StreamZip for zip files (Windows)
+ const zip = new StreamZip.async({ file: tempFilename })
+
+ // Get all entries in the zip file
+ const entries = await zip.entries()
+
+ // Extract files directly to binDir, flattening the directory structure
+ for (const entry of Object.values(entries)) {
+ if (!entry.isDirectory) {
+ // Get just the filename without path
+ const filename = path.basename(entry.name)
+ const outputPath = path.join(binDir, filename)
+
+ console.log(`Extracting ${entry.name} -> ${filename}`)
+ await zip.extract(entry.name, outputPath)
+ console.log(`Extracted ${entry.name} -> ${outputPath}`)
+ }
+ }
+
+ await zip.close()
}
- await zip.close()
fs.unlinkSync(tempFilename)
console.log(`Successfully installed uv ${version} for ${platform}-${arch}`)
return 0
diff --git a/resources/scripts/ipService.js b/resources/scripts/ipService.js
deleted file mode 100644
index 8e997659a7..0000000000
--- a/resources/scripts/ipService.js
+++ /dev/null
@@ -1,88 +0,0 @@
-const https = require('https')
-const { loggerService } = require('@logger')
-
-const logger = loggerService.withContext('IpService')
-
-/**
- * 获取用户的IP地址所在国家
- * @returns {Promise} 返回国家代码,默认为'CN'
- */
-async function getIpCountry() {
- return new Promise((resolve) => {
- // 添加超时控制
- const timeout = setTimeout(() => {
- logger.info('IP Address Check Timeout, default to China Mirror')
- resolve('CN')
- }, 5000)
-
- const options = {
- hostname: 'ipinfo.io',
- path: '/json',
- method: 'GET',
- headers: {
- 'User-Agent':
- 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36',
- 'Accept-Language': 'en-US,en;q=0.9'
- }
- }
-
- const req = https.request(options, (res) => {
- clearTimeout(timeout)
- let data = ''
-
- res.on('data', (chunk) => {
- data += chunk
- })
-
- res.on('end', () => {
- try {
- const parsed = JSON.parse(data)
- const country = parsed.country || 'CN'
- logger.info(`Detected user IP address country: ${country}`)
- resolve(country)
- } catch (error) {
- logger.error('Failed to parse IP address information:', error.message)
- resolve('CN')
- }
- })
- })
-
- req.on('error', (error) => {
- clearTimeout(timeout)
- logger.error('Failed to get IP address information:', error.message)
- resolve('CN')
- })
-
- req.end()
- })
-}
-
-/**
- * 检查用户是否在中国
- * @returns {Promise} 如果用户在中国返回true,否则返回false
- */
-async function isUserInChina() {
- const country = await getIpCountry()
- return country.toLowerCase() === 'cn'
-}
-
-/**
- * 根据用户位置获取适合的npm镜像URL
- * @returns {Promise} 返回npm镜像URL
- */
-async function getNpmRegistryUrl() {
- const inChina = await isUserInChina()
- if (inChina) {
- logger.info('User in China, using Taobao npm mirror')
- return 'https://registry.npmmirror.com'
- } else {
- logger.info('User not in China, using default npm mirror')
- return 'https://registry.npmjs.org'
- }
-}
-
-module.exports = {
- getIpCountry,
- isUserInChina,
- getNpmRegistryUrl
-}
diff --git a/scripts/auto-translate-i18n.ts b/scripts/auto-translate-i18n.ts
index 71650f6618..7a1bea6f35 100644
--- a/scripts/auto-translate-i18n.ts
+++ b/scripts/auto-translate-i18n.ts
@@ -18,8 +18,10 @@ import { sortedObjectByKeys } from './sort'
// ========== SCRIPT CONFIGURATION AREA - MODIFY SETTINGS HERE ==========
const SCRIPT_CONFIG = {
// 🔧 Concurrency Control Configuration
- MAX_CONCURRENT_TRANSLATIONS: 5, // Max concurrent requests (Make sure the concurrency level does not exceed your provider's limits.)
- TRANSLATION_DELAY_MS: 100, // Delay between requests to avoid rate limiting (Recommended: 100-500ms, Range: 0-5000ms)
+ MAX_CONCURRENT_TRANSLATIONS: process.env.TRANSLATION_MAX_CONCURRENT_REQUESTS
+ ? parseInt(process.env.TRANSLATION_MAX_CONCURRENT_REQUESTS)
+ : 5, // Max concurrent requests (Make sure the concurrency level does not exceed your provider's limits.)
+ TRANSLATION_DELAY_MS: process.env.TRANSLATION_DELAY_MS ? parseInt(process.env.TRANSLATION_DELAY_MS) : 500, // Delay between requests to avoid rate limiting (Recommended: 100-500ms, Range: 0-5000ms)
// 🔑 API Configuration
API_KEY: process.env.TRANSLATION_API_KEY || '', // API key from environment variable
diff --git a/scripts/feishu-notify.js b/scripts/feishu-notify.js
index aae9004a48..d238dedb90 100644
--- a/scripts/feishu-notify.js
+++ b/scripts/feishu-notify.js
@@ -91,23 +91,6 @@ function createIssueCard(issueData) {
return {
elements: [
- {
- tag: 'div',
- text: {
- tag: 'lark_md',
- content: `**🐛 New GitHub Issue #${issueNumber}**`
- }
- },
- {
- tag: 'hr'
- },
- {
- tag: 'div',
- text: {
- tag: 'lark_md',
- content: `**📝 Title:** ${issueTitle}`
- }
- },
{
tag: 'div',
text: {
@@ -158,7 +141,7 @@ function createIssueCard(issueData) {
template: 'blue',
title: {
tag: 'plain_text',
- content: '🆕 Cherry Studio - New Issue'
+ content: `#${issueNumber} - ${issueTitle}`
}
}
}
diff --git a/scripts/update-app-upgrade-config.ts b/scripts/update-app-upgrade-config.ts
new file mode 100644
index 0000000000..4fcfa647f7
--- /dev/null
+++ b/scripts/update-app-upgrade-config.ts
@@ -0,0 +1,532 @@
+import fs from 'fs/promises'
+import path from 'path'
+import semver from 'semver'
+
+type UpgradeChannel = 'latest' | 'rc' | 'beta'
+type UpdateMirror = 'github' | 'gitcode'
+
+const CHANNELS: UpgradeChannel[] = ['latest', 'rc', 'beta']
+const MIRRORS: UpdateMirror[] = ['github', 'gitcode']
+const GITHUB_REPO = 'CherryHQ/cherry-studio'
+const GITCODE_REPO = 'CherryHQ/cherry-studio'
+const DEFAULT_FEED_TEMPLATES: Record = {
+ github: `https://github.com/${GITHUB_REPO}/releases/download/{{tag}}`,
+ gitcode: `https://gitcode.com/${GITCODE_REPO}/releases/download/{{tag}}`
+}
+const GITCODE_LATEST_FALLBACK = 'https://releases.cherry-ai.com'
+
+interface CliOptions {
+ tag?: string
+ configPath?: string
+ segmentsPath?: string
+ dryRun?: boolean
+ skipReleaseChecks?: boolean
+ isPrerelease?: boolean
+}
+
+interface ChannelTemplateConfig {
+ feedTemplates?: Partial>
+}
+
+interface SegmentMatchRule {
+ range?: string
+ exact?: string[]
+ excludeExact?: string[]
+}
+
+interface SegmentDefinition {
+ id: string
+ type: 'legacy' | 'breaking' | 'latest'
+ match: SegmentMatchRule
+ lockedVersion?: string
+ minCompatibleVersion: string
+ description: string
+ channelTemplates?: Partial>
+}
+
+interface SegmentMetadataFile {
+ segments: SegmentDefinition[]
+}
+
+interface ChannelConfig {
+ version: string
+ feedUrls: Record
+}
+
+interface VersionMetadata {
+ segmentId: string
+ segmentType?: string
+}
+
+interface VersionEntry {
+ metadata?: VersionMetadata
+ minCompatibleVersion: string
+ description: string
+ channels: Record
+}
+
+interface UpgradeConfigFile {
+ lastUpdated: string
+ versions: Record
+}
+
+interface ReleaseInfo {
+ tag: string
+ version: string
+ channel: UpgradeChannel
+}
+
+interface UpdateVersionsResult {
+ versions: Record
+ updated: boolean
+}
+
+const ROOT_DIR = path.resolve(__dirname, '..')
+const DEFAULT_CONFIG_PATH = path.join(ROOT_DIR, 'app-upgrade-config.json')
+const DEFAULT_SEGMENTS_PATH = path.join(ROOT_DIR, 'config/app-upgrade-segments.json')
+
+async function main() {
+ const options = parseArgs()
+ const releaseTag = resolveTag(options)
+ const normalizedVersion = normalizeVersion(releaseTag)
+ const releaseChannel = detectChannel(normalizedVersion)
+ if (!releaseChannel) {
+ console.warn(`[update-app-upgrade-config] Tag ${normalizedVersion} does not map to beta/rc/latest. Skipping.`)
+ return
+ }
+
+ // Validate version format matches prerelease status
+ if (options.isPrerelease !== undefined) {
+ const hasPrereleaseSuffix = releaseChannel === 'beta' || releaseChannel === 'rc'
+
+ if (options.isPrerelease && !hasPrereleaseSuffix) {
+ console.warn(
+ `[update-app-upgrade-config] ⚠️ Release marked as prerelease but version ${normalizedVersion} has no beta/rc suffix. Skipping.`
+ )
+ return
+ }
+
+ if (!options.isPrerelease && hasPrereleaseSuffix) {
+ console.warn(
+ `[update-app-upgrade-config] ⚠️ Release marked as latest but version ${normalizedVersion} has prerelease suffix (${releaseChannel}). Skipping.`
+ )
+ return
+ }
+ }
+
+ const [config, segmentFile] = await Promise.all([
+ readJson(options.configPath ?? DEFAULT_CONFIG_PATH),
+ readJson(options.segmentsPath ?? DEFAULT_SEGMENTS_PATH)
+ ])
+
+ const segment = pickSegment(segmentFile.segments, normalizedVersion)
+ if (!segment) {
+ throw new Error(`Unable to find upgrade segment for version ${normalizedVersion}`)
+ }
+
+ if (segment.lockedVersion && segment.lockedVersion !== normalizedVersion) {
+ throw new Error(`Segment ${segment.id} is locked to ${segment.lockedVersion}, but received ${normalizedVersion}`)
+ }
+
+ const releaseInfo: ReleaseInfo = {
+ tag: formatTag(releaseTag),
+ version: normalizedVersion,
+ channel: releaseChannel
+ }
+
+ const { versions: updatedVersions, updated } = await updateVersions(
+ config.versions,
+ segment,
+ releaseInfo,
+ Boolean(options.skipReleaseChecks)
+ )
+
+ if (!updated) {
+ throw new Error(
+ `[update-app-upgrade-config] Feed URLs are not ready for ${releaseInfo.version} (${releaseInfo.channel}). Try again after the release mirrors finish syncing.`
+ )
+ }
+
+ const updatedConfig: UpgradeConfigFile = {
+ ...config,
+ lastUpdated: new Date().toISOString(),
+ versions: updatedVersions
+ }
+
+ const output = JSON.stringify(updatedConfig, null, 2) + '\n'
+
+ if (options.dryRun) {
+ console.log('Dry run enabled. Generated configuration:\n')
+ console.log(output)
+ return
+ }
+
+ await fs.writeFile(options.configPath ?? DEFAULT_CONFIG_PATH, output, 'utf-8')
+ console.log(
+ `✅ Updated ${path.relative(process.cwd(), options.configPath ?? DEFAULT_CONFIG_PATH)} for ${segment.id} (${releaseInfo.channel}) -> ${releaseInfo.version}`
+ )
+}
+
+function parseArgs(): CliOptions {
+ const args = process.argv.slice(2)
+ const options: CliOptions = {}
+
+ for (let i = 0; i < args.length; i += 1) {
+ const arg = args[i]
+ if (arg === '--tag') {
+ options.tag = args[i + 1]
+ i += 1
+ } else if (arg === '--config') {
+ options.configPath = args[i + 1]
+ i += 1
+ } else if (arg === '--segments') {
+ options.segmentsPath = args[i + 1]
+ i += 1
+ } else if (arg === '--dry-run') {
+ options.dryRun = true
+ } else if (arg === '--skip-release-checks') {
+ options.skipReleaseChecks = true
+ } else if (arg === '--is-prerelease') {
+ options.isPrerelease = args[i + 1] === 'true'
+ i += 1
+ } else if (arg === '--help') {
+ printHelp()
+ process.exit(0)
+ } else {
+ console.warn(`Ignoring unknown argument "${arg}"`)
+ }
+ }
+
+ if (options.skipReleaseChecks && !options.dryRun) {
+ throw new Error('--skip-release-checks can only be used together with --dry-run')
+ }
+
+ return options
+}
+
+function printHelp() {
+ console.log(`Usage: tsx scripts/update-app-upgrade-config.ts [options]
+
+Options:
+ --tag Release tag (e.g. v2.1.6). Falls back to GITHUB_REF_NAME/RELEASE_TAG.
+ --config Path to app-upgrade-config.json.
+ --segments Path to app-upgrade-segments.json.
+ --is-prerelease Whether this is a prerelease (validates version format).
+ --dry-run Print the result without writing to disk.
+ --skip-release-checks Skip release page availability checks (only valid with --dry-run).
+ --help Show this help message.`)
+}
+
+function resolveTag(options: CliOptions): string {
+ const envTag = process.env.RELEASE_TAG ?? process.env.GITHUB_REF_NAME ?? process.env.TAG_NAME
+ const tag = options.tag ?? envTag
+
+ if (!tag) {
+ throw new Error('A release tag is required. Pass --tag or set RELEASE_TAG/GITHUB_REF_NAME.')
+ }
+
+ return tag
+}
+
+function normalizeVersion(tag: string): string {
+ const cleaned = semver.clean(tag, { loose: true })
+ if (!cleaned) {
+ throw new Error(`Tag "${tag}" is not a valid semantic version`)
+ }
+
+ const valid = semver.valid(cleaned, { loose: true })
+ if (!valid) {
+ throw new Error(`Unable to normalize tag "${tag}" to a valid semantic version`)
+ }
+
+ return valid
+}
+
+function detectChannel(version: string): UpgradeChannel | null {
+ const parsed = semver.parse(version, { loose: true, includePrerelease: true })
+ if (!parsed) {
+ return null
+ }
+
+ if (parsed.prerelease.length === 0) {
+ return 'latest'
+ }
+
+ const label = String(parsed.prerelease[0]).toLowerCase()
+ if (label === 'beta') {
+ return 'beta'
+ }
+ if (label === 'rc') {
+ return 'rc'
+ }
+
+ return null
+}
+
+async function readJson(filePath: string): Promise {
+ const absolute = path.isAbsolute(filePath) ? filePath : path.resolve(filePath)
+ const data = await fs.readFile(absolute, 'utf-8')
+ return JSON.parse(data) as T
+}
+
+function pickSegment(segments: SegmentDefinition[], version: string): SegmentDefinition | null {
+ for (const segment of segments) {
+ if (matchesSegment(segment.match, version)) {
+ return segment
+ }
+ }
+ return null
+}
+
+function matchesSegment(matchRule: SegmentMatchRule, version: string): boolean {
+ if (matchRule.exact && matchRule.exact.includes(version)) {
+ return true
+ }
+
+ if (matchRule.excludeExact && matchRule.excludeExact.includes(version)) {
+ return false
+ }
+
+ if (matchRule.range && !semver.satisfies(version, matchRule.range, { includePrerelease: true })) {
+ return false
+ }
+
+ if (matchRule.exact) {
+ return matchRule.exact.includes(version)
+ }
+
+ return Boolean(matchRule.range)
+}
+
+function formatTag(tag: string): string {
+ if (tag.startsWith('refs/tags/')) {
+ return tag.replace('refs/tags/', '')
+ }
+ return tag
+}
+
+async function updateVersions(
+ versions: Record,
+ segment: SegmentDefinition,
+ releaseInfo: ReleaseInfo,
+ skipReleaseValidation: boolean
+): Promise {
+ const versionsCopy: Record = { ...versions }
+ const existingKey = findVersionKeyBySegment(versionsCopy, segment.id)
+ const targetKey = resolveVersionKey(existingKey, segment, releaseInfo)
+ const shouldRename = existingKey && existingKey !== targetKey
+
+ let entry: VersionEntry
+ if (existingKey) {
+ entry = { ...versionsCopy[existingKey], channels: { ...versionsCopy[existingKey].channels } }
+ } else {
+ entry = createEmptyVersionEntry()
+ }
+
+ entry.channels = ensureChannelSlots(entry.channels)
+
+ const channelUpdated = await applyChannelUpdate(entry, segment, releaseInfo, skipReleaseValidation)
+ if (!channelUpdated) {
+ return { versions, updated: false }
+ }
+
+ if (shouldRename && existingKey) {
+ delete versionsCopy[existingKey]
+ }
+
+ entry.metadata = {
+ segmentId: segment.id,
+ segmentType: segment.type
+ }
+ entry.minCompatibleVersion = segment.minCompatibleVersion
+ entry.description = segment.description
+
+ versionsCopy[targetKey] = entry
+ return {
+ versions: sortVersionMap(versionsCopy),
+ updated: true
+ }
+}
+
+function findVersionKeyBySegment(versions: Record, segmentId: string): string | null {
+ for (const [key, value] of Object.entries(versions)) {
+ if (value.metadata?.segmentId === segmentId) {
+ return key
+ }
+ }
+ return null
+}
+
+function resolveVersionKey(existingKey: string | null, segment: SegmentDefinition, releaseInfo: ReleaseInfo): string {
+ if (segment.lockedVersion) {
+ return segment.lockedVersion
+ }
+
+ if (releaseInfo.channel === 'latest') {
+ return releaseInfo.version
+ }
+
+ if (existingKey) {
+ return existingKey
+ }
+
+ const baseVersion = getBaseVersion(releaseInfo.version)
+ return baseVersion ?? releaseInfo.version
+}
+
+function getBaseVersion(version: string): string | null {
+ const parsed = semver.parse(version, { loose: true, includePrerelease: true })
+ if (!parsed) {
+ return null
+ }
+ return `${parsed.major}.${parsed.minor}.${parsed.patch}`
+}
+
+function createEmptyVersionEntry(): VersionEntry {
+ return {
+ minCompatibleVersion: '',
+ description: '',
+ channels: {
+ latest: null,
+ rc: null,
+ beta: null
+ }
+ }
+}
+
+function ensureChannelSlots(
+ channels: Record
+): Record {
+ return CHANNELS.reduce(
+ (acc, channel) => {
+ acc[channel] = channels[channel] ?? null
+ return acc
+ },
+ {} as Record
+ )
+}
+
+async function applyChannelUpdate(
+ entry: VersionEntry,
+ segment: SegmentDefinition,
+ releaseInfo: ReleaseInfo,
+ skipReleaseValidation: boolean
+): Promise {
+ if (!CHANNELS.includes(releaseInfo.channel)) {
+ throw new Error(`Unsupported channel "${releaseInfo.channel}"`)
+ }
+
+ const feedUrls = buildFeedUrls(segment, releaseInfo)
+
+ if (skipReleaseValidation) {
+ console.warn(
+ `[update-app-upgrade-config] Skipping release availability validation for ${releaseInfo.version} (${releaseInfo.channel}).`
+ )
+ } else {
+ const availability = await ensureReleaseAvailability(releaseInfo)
+ if (!availability.github) {
+ return false
+ }
+ if (releaseInfo.channel === 'latest' && !availability.gitcode) {
+ console.warn(
+ `[update-app-upgrade-config] gitcode release page not ready for ${releaseInfo.tag}. Falling back to ${GITCODE_LATEST_FALLBACK}.`
+ )
+ feedUrls.gitcode = GITCODE_LATEST_FALLBACK
+ }
+ }
+
+ entry.channels[releaseInfo.channel] = {
+ version: releaseInfo.version,
+ feedUrls
+ }
+
+ return true
+}
+
+function buildFeedUrls(segment: SegmentDefinition, releaseInfo: ReleaseInfo): Record {
+ return MIRRORS.reduce(
+ (acc, mirror) => {
+ const template = resolveFeedTemplate(segment, releaseInfo, mirror)
+ acc[mirror] = applyTemplate(template, releaseInfo)
+ return acc
+ },
+ {} as Record
+ )
+}
+
+function resolveFeedTemplate(segment: SegmentDefinition, releaseInfo: ReleaseInfo, mirror: UpdateMirror): string {
+ if (mirror === 'gitcode' && releaseInfo.channel !== 'latest') {
+ return segment.channelTemplates?.[releaseInfo.channel]?.feedTemplates?.github ?? DEFAULT_FEED_TEMPLATES.github
+ }
+
+ return segment.channelTemplates?.[releaseInfo.channel]?.feedTemplates?.[mirror] ?? DEFAULT_FEED_TEMPLATES[mirror]
+}
+
+function applyTemplate(template: string, releaseInfo: ReleaseInfo): string {
+ return template.replace(/{{\s*tag\s*}}/gi, releaseInfo.tag).replace(/{{\s*version\s*}}/gi, releaseInfo.version)
+}
+
+function sortVersionMap(versions: Record): Record {
+ const sorted = Object.entries(versions).sort(([a], [b]) => semver.rcompare(a, b))
+ return sorted.reduce(
+ (acc, [version, entry]) => {
+ acc[version] = entry
+ return acc
+ },
+ {} as Record
+ )
+}
+
+interface ReleaseAvailability {
+ github: boolean
+ gitcode: boolean
+}
+
+async function ensureReleaseAvailability(releaseInfo: ReleaseInfo): Promise {
+ const mirrorsToCheck: UpdateMirror[] = releaseInfo.channel === 'latest' ? MIRRORS : ['github']
+ const availability: ReleaseAvailability = {
+ github: false,
+ gitcode: releaseInfo.channel === 'latest' ? false : true
+ }
+
+ for (const mirror of mirrorsToCheck) {
+ const url = getReleasePageUrl(mirror, releaseInfo.tag)
+ try {
+ const response = await fetch(url, {
+ method: mirror === 'github' ? 'HEAD' : 'GET',
+ redirect: 'follow'
+ })
+
+ if (response.ok) {
+ availability[mirror] = true
+ } else {
+ console.warn(
+ `[update-app-upgrade-config] ${mirror} release not available for ${releaseInfo.tag} (status ${response.status}, ${url}).`
+ )
+ availability[mirror] = false
+ }
+ } catch (error) {
+ console.warn(
+ `[update-app-upgrade-config] Failed to verify ${mirror} release page for ${releaseInfo.tag} (${url}). Continuing.`,
+ error
+ )
+ availability[mirror] = false
+ }
+ }
+
+ return availability
+}
+
+function getReleasePageUrl(mirror: UpdateMirror, tag: string): string {
+ if (mirror === 'github') {
+ return `https://github.com/${GITHUB_REPO}/releases/tag/${encodeURIComponent(tag)}`
+ }
+ // Use latest.yml download URL for GitCode to check if release exists
+ // Note: GitCode returns 401 for HEAD requests, so we use GET in ensureReleaseAvailability
+ return `https://gitcode.com/${GITCODE_REPO}/releases/download/${encodeURIComponent(tag)}/latest.yml`
+}
+
+main().catch((error) => {
+ console.error('❌ Failed to update app-upgrade-config:', error)
+ process.exit(1)
+})
diff --git a/src/main/apiServer/config.ts b/src/main/apiServer/config.ts
index 60b1986be9..0966827a7b 100644
--- a/src/main/apiServer/config.ts
+++ b/src/main/apiServer/config.ts
@@ -1,3 +1,4 @@
+import { API_SERVER_DEFAULTS } from '@shared/config/constant'
import type { ApiServerConfig } from '@types'
import { v4 as uuidv4 } from 'uuid'
@@ -6,9 +7,6 @@ import { reduxService } from '../services/ReduxService'
const logger = loggerService.withContext('ApiServerConfig')
-const defaultHost = 'localhost'
-const defaultPort = 23333
-
class ConfigManager {
private _config: ApiServerConfig | null = null
@@ -30,8 +28,8 @@ class ConfigManager {
}
this._config = {
enabled: serverSettings?.enabled ?? false,
- port: serverSettings?.port ?? defaultPort,
- host: defaultHost,
+ port: serverSettings?.port ?? API_SERVER_DEFAULTS.PORT,
+ host: serverSettings?.host ?? API_SERVER_DEFAULTS.HOST,
apiKey: apiKey
}
return this._config
@@ -39,8 +37,8 @@ class ConfigManager {
logger.warn('Failed to load config from Redux, using defaults', { error })
this._config = {
enabled: false,
- port: defaultPort,
- host: defaultHost,
+ port: API_SERVER_DEFAULTS.PORT,
+ host: API_SERVER_DEFAULTS.HOST,
apiKey: this.generateApiKey()
}
return this._config
diff --git a/src/main/apiServer/middleware/openapi.ts b/src/main/apiServer/middleware/openapi.ts
index ff01005bd9..6b374901ca 100644
--- a/src/main/apiServer/middleware/openapi.ts
+++ b/src/main/apiServer/middleware/openapi.ts
@@ -20,8 +20,8 @@ const swaggerOptions: swaggerJSDoc.Options = {
},
servers: [
{
- url: 'http://localhost:23333',
- description: 'Local development server'
+ url: '/',
+ description: 'Current server'
}
],
components: {
diff --git a/src/main/apiServer/routes/models.ts b/src/main/apiServer/routes/models.ts
index 8481e1ea59..d776d5ea91 100644
--- a/src/main/apiServer/routes/models.ts
+++ b/src/main/apiServer/routes/models.ts
@@ -104,12 +104,6 @@ const router = express
logger.warn('No models available from providers', { filter })
}
- logger.info('Models response ready', {
- filter,
- total: response.total,
- modelIds: response.data.map((m) => m.id)
- })
-
return res.json(response satisfies ApiModelsResponse)
} catch (error: any) {
logger.error('Error fetching models', { error })
diff --git a/src/main/apiServer/server.ts b/src/main/apiServer/server.ts
index 9b15e56da0..e59e6bd504 100644
--- a/src/main/apiServer/server.ts
+++ b/src/main/apiServer/server.ts
@@ -3,7 +3,6 @@ import { createServer } from 'node:http'
import { loggerService } from '@logger'
import { IpcChannel } from '@shared/IpcChannel'
-import { agentService } from '../services/agents'
import { windowService } from '../services/WindowService'
import { app } from './app'
import { config } from './config'
@@ -32,11 +31,6 @@ export class ApiServer {
// Load config
const { port, host } = await config.load()
- // Initialize AgentService
- logger.info('Initializing AgentService')
- await agentService.initialize()
- logger.info('AgentService initialized')
-
// Create server with Express app
this.server = createServer(app)
this.applyServerTimeouts(this.server)
diff --git a/src/main/apiServer/services/models.ts b/src/main/apiServer/services/models.ts
index a32d6d37dc..52f0db857f 100644
--- a/src/main/apiServer/services/models.ts
+++ b/src/main/apiServer/services/models.ts
@@ -32,7 +32,7 @@ export class ModelsService {
for (const model of models) {
const provider = providers.find((p) => p.id === model.provider)
- logger.debug(`Processing model ${model.id}`)
+ // logger.debug(`Processing model ${model.id}`)
if (!provider) {
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
continue
diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts
index f9f751c559..e25b49e750 100644
--- a/src/main/apiServer/utils/index.ts
+++ b/src/main/apiServer/utils/index.ts
@@ -1,6 +1,7 @@
import { CacheService } from '@main/services/CacheService'
import { loggerService } from '@main/services/LoggerService'
import { reduxService } from '@main/services/ReduxService'
+import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
import type { ApiModel, Model, Provider } from '@types'
const logger = loggerService.withContext('ApiServerUtils')
@@ -287,6 +288,8 @@ export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model
return (m: Model) => m.endpoint_type === 'anthropic'
case 'aihubmix':
return (m: Model) => m.id.includes('claude')
+ case 'silicon':
+ return (m: Model) => isSiliconAnthropicCompatibleModel(m.id)
default:
// allow all models when checker not configured
return () => true
diff --git a/src/main/index.ts b/src/main/index.ts
index 025268b651..56750e6b61 100644
--- a/src/main/index.ts
+++ b/src/main/index.ts
@@ -8,7 +8,7 @@ import '@main/config'
import { loggerService } from '@logger'
import { electronApp, optimizer } from '@electron-toolkit/utils'
import { replaceDevtoolsFont } from '@main/utils/windowUtil'
-import { app } from 'electron'
+import { app, crashReporter } from 'electron'
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
import { isDev, isLinux, isWin } from './constant'
@@ -34,9 +34,18 @@ import { TrayService } from './services/TrayService'
import { versionService } from './services/VersionService'
import { windowService } from './services/WindowService'
import { initWebviewHotkeys } from './services/WebviewService'
+import { runAsyncFunction } from './utils'
const logger = loggerService.withContext('MainEntry')
+// enable local crash reports
+crashReporter.start({
+ companyName: 'CherryHQ',
+ productName: 'CherryStudio',
+ submitURL: '',
+ uploadToServer: false
+})
+
/**
* Disable hardware acceleration if setting is enabled
*/
@@ -162,39 +171,33 @@ if (!app.requestSingleInstanceLock()) {
//start selection assistant service
initSelectionService()
- // Initialize Agent Service
- try {
- await agentService.initialize()
- logger.info('Agent service initialized successfully')
- } catch (error: any) {
- logger.error('Failed to initialize Agent service:', error)
- }
+ runAsyncFunction(async () => {
+ // Start API server if enabled or if agents exist
+ try {
+ const config = await apiServerService.getCurrentConfig()
+ logger.info('API server config:', config)
- // Start API server if enabled or if agents exist
- try {
- const config = await apiServerService.getCurrentConfig()
- logger.info('API server config:', config)
-
- // Check if there are any agents
- let shouldStart = config.enabled
- if (!shouldStart) {
- try {
- const { total } = await agentService.listAgents({ limit: 1 })
- if (total > 0) {
- shouldStart = true
- logger.info(`Detected ${total} agent(s), auto-starting API server`)
+ // Check if there are any agents
+ let shouldStart = config.enabled
+ if (!shouldStart) {
+ try {
+ const { total } = await agentService.listAgents({ limit: 1 })
+ if (total > 0) {
+ shouldStart = true
+ logger.info(`Detected ${total} agent(s), auto-starting API server`)
+ }
+ } catch (error: any) {
+ logger.warn('Failed to check agent count:', error)
}
- } catch (error: any) {
- logger.warn('Failed to check agent count:', error)
}
- }
- if (shouldStart) {
- await apiServerService.start()
+ if (shouldStart) {
+ await apiServerService.start()
+ }
+ } catch (error: any) {
+ logger.error('Failed to check/start API server:', error)
}
- } catch (error: any) {
- logger.error('Failed to check/start API server:', error)
- }
+ })
})
registerProtocolClient(app)
diff --git a/src/main/ipc.ts b/src/main/ipc.ts
index 91b6bdf20f..c34050b692 100644
--- a/src/main/ipc.ts
+++ b/src/main/ipc.ts
@@ -494,6 +494,44 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
ipcMain.handle(IpcChannel.System_GetCpuName, () => require('os').cpus()[0].model)
+ ipcMain.handle(IpcChannel.System_CheckGitBash, () => {
+ if (!isWin) {
+ return true // Non-Windows systems don't need Git Bash
+ }
+
+ try {
+ // Check common Git Bash installation paths
+ const commonPaths = [
+ path.join(process.env.ProgramFiles || 'C:\\Program Files', 'Git', 'bin', 'bash.exe'),
+ path.join(process.env['ProgramFiles(x86)'] || 'C:\\Program Files (x86)', 'Git', 'bin', 'bash.exe'),
+ path.join(process.env.LOCALAPPDATA || '', 'Programs', 'Git', 'bin', 'bash.exe')
+ ]
+
+ // Check if any of the common paths exist
+ for (const bashPath of commonPaths) {
+ if (fs.existsSync(bashPath)) {
+ logger.debug('Git Bash found', { path: bashPath })
+ return true
+ }
+ }
+
+ // Check if git is in PATH
+ const { execSync } = require('child_process')
+ try {
+ execSync('git --version', { stdio: 'ignore' })
+ logger.debug('Git found in PATH')
+ return true
+ } catch {
+ // Git not in PATH
+ }
+
+ logger.debug('Git Bash not found on Windows system')
+ return false
+ } catch (error) {
+ logger.error('Error checking Git Bash', error as Error)
+ return false
+ }
+ })
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
const win = BrowserWindow.fromWebContents(e.sender)
win && win.webContents.toggleDevTools()
@@ -552,6 +590,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.File_BinaryImage, fileManager.binaryImage.bind(fileManager))
ipcMain.handle(IpcChannel.File_OpenWithRelativePath, fileManager.openFileWithRelativePath.bind(fileManager))
ipcMain.handle(IpcChannel.File_IsTextFile, fileManager.isTextFile.bind(fileManager))
+ ipcMain.handle(IpcChannel.File_ListDirectory, fileManager.listDirectory.bind(fileManager))
ipcMain.handle(IpcChannel.File_GetDirectoryStructure, fileManager.getDirectoryStructure.bind(fileManager))
ipcMain.handle(IpcChannel.File_CheckFileName, fileManager.fileNameGuard.bind(fileManager))
ipcMain.handle(IpcChannel.File_ValidateNotesDirectory, fileManager.validateNotesDirectory.bind(fileManager))
@@ -1052,4 +1091,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.WebSocket_Status, WebSocketService.getStatus)
ipcMain.handle(IpcChannel.WebSocket_SendFile, WebSocketService.sendFile)
ipcMain.handle(IpcChannel.WebSocket_GetAllCandidates, WebSocketService.getAllCandidates)
+
+ ipcMain.handle(IpcChannel.APP_CrashRenderProcess, () => {
+ mainWindow.webContents.forcefullyCrashRenderer()
+ })
}
diff --git a/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts b/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts
index 8a780d5618..e9f459fd6c 100644
--- a/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts
+++ b/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts
@@ -19,19 +19,9 @@ export default class EmbeddingsFactory {
})
}
if (provider === 'ollama') {
- if (baseURL.includes('v1/')) {
- return new OllamaEmbeddings({
- model: model,
- baseUrl: baseURL.replace('v1/', ''),
- requestOptions: {
- // @ts-ignore expected
- 'encoding-format': 'float'
- }
- })
- }
return new OllamaEmbeddings({
model: model,
- baseUrl: baseURL,
+ baseUrl: baseURL.replace(/\/api$/, ''),
requestOptions: {
// @ts-ignore expected
'encoding-format': 'float'
diff --git a/src/main/knowledge/preprocess/MineruPreprocessProvider.ts b/src/main/knowledge/preprocess/MineruPreprocessProvider.ts
index 0e93af674a..80aec40622 100644
--- a/src/main/knowledge/preprocess/MineruPreprocessProvider.ts
+++ b/src/main/knowledge/preprocess/MineruPreprocessProvider.ts
@@ -21,6 +21,7 @@ type ApiResponse = {
type BatchUploadResponse = {
batch_id: string
file_urls: string[]
+ headers?: Record[]
}
type ExtractProgress = {
@@ -55,7 +56,7 @@ type QuotaResponse = {
export default class MineruPreprocessProvider extends BasePreprocessProvider {
constructor(provider: PreprocessProvider, userId?: string) {
super(provider, userId)
- // todo:免费期结束后删除
+ // TODO: remove after free period ends
this.provider.apiKey = this.provider.apiKey || import.meta.env.MAIN_VITE_MINERU_API_KEY
}
@@ -68,21 +69,21 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
logger.info(`MinerU preprocess processing started: ${filePath}`)
await this.validateFile(filePath)
- // 1. 获取上传URL并上传文件
+ // 1. Get upload URL and upload file
const batchId = await this.uploadFile(file)
logger.info(`MinerU file upload completed: batch_id=${batchId}`)
- // 2. 等待处理完成并获取结果
+ // 2. Wait for completion and fetch results
const extractResult = await this.waitForCompletion(sourceId, batchId, file.origin_name)
logger.info(`MinerU processing completed for batch: ${batchId}`)
- // 3. 下载并解压文件
+ // 3. Download and extract output
const { path: outputPath } = await this.downloadAndExtractFile(extractResult.full_zip_url!, file)
// 4. check quota
const quota = await this.checkQuota()
- // 5. 创建处理后的文件信息
+ // 5. Create processed file metadata
return {
processedFile: this.createProcessedFileInfo(file, outputPath),
quota
@@ -115,23 +116,48 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
}
private async validateFile(filePath: string): Promise {
+ // Phase 1: check file size (without loading into memory)
+ logger.info(`Validating PDF file: ${filePath}`)
+ const stats = await fs.promises.stat(filePath)
+ const fileSizeBytes = stats.size
+
+ // Ensure file size is under 200MB
+ if (fileSizeBytes >= 200 * 1024 * 1024) {
+ const fileSizeMB = Math.round(fileSizeBytes / (1024 * 1024))
+ throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
+ }
+
+ // Phase 2: check page count (requires reading file with error handling)
const pdfBuffer = await fs.promises.readFile(filePath)
- const doc = await this.readPdf(pdfBuffer)
+ try {
+ const doc = await this.readPdf(pdfBuffer)
- // 文件页数小于600页
- if (doc.numPages >= 600) {
- throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
- }
- // 文件大小小于200MB
- if (pdfBuffer.length >= 200 * 1024 * 1024) {
- const fileSizeMB = Math.round(pdfBuffer.length / (1024 * 1024))
- throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
+ // Ensure page count is under 600 pages
+ if (doc.numPages >= 600) {
+ throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
+ }
+
+ logger.info(`PDF validation passed: ${doc.numPages} pages, ${Math.round(fileSizeBytes / (1024 * 1024))}MB`)
+ } catch (error: any) {
+ // If the page limit is exceeded, rethrow immediately
+ if (error.message.includes('exceeds the limit')) {
+ throw error
+ }
+
+ // If PDF parsing fails, log a detailed warning but continue processing
+ logger.warn(
+ `Failed to parse PDF structure (file may be corrupted or use non-standard format). ` +
+ `Skipping page count validation. Will attempt to process with MinerU API. ` +
+ `Error details: ${error.message}. ` +
+ `Suggestion: If processing fails, try repairing the PDF using tools like Adobe Acrobat or online PDF repair services.`
+ )
+ // Do not throw; continue processing
}
}
private createProcessedFileInfo(file: FileMetadata, outputPath: string): FileMetadata {
- // 查找解压后的主要文件
+ // Locate the main extracted file
let finalPath = ''
let finalName = file.origin_name.replace('.pdf', '.md')
@@ -143,14 +169,14 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
const originalMdPath = path.join(outputPath, mdFile)
const newMdPath = path.join(outputPath, finalName)
- // 重命名文件为原始文件名
+ // Rename the file to match the original name
try {
fs.renameSync(originalMdPath, newMdPath)
finalPath = newMdPath
logger.info(`Renamed markdown file from ${mdFile} to ${finalName}`)
} catch (renameError) {
logger.warn(`Failed to rename file ${mdFile} to ${finalName}: ${renameError}`)
- // 如果重命名失败,使用原文件
+ // If renaming fails, fall back to the original file
finalPath = originalMdPath
finalName = mdFile
}
@@ -178,7 +204,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
logger.info(`Downloading MinerU result to: ${zipPath}`)
try {
- // 下载ZIP文件
+ // Download the ZIP file
const response = await net.fetch(zipUrl, { method: 'GET' })
if (!response.ok) {
throw new Error(`HTTP ${response.status}: ${response.statusText}`)
@@ -187,17 +213,17 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
fs.writeFileSync(zipPath, Buffer.from(arrayBuffer))
logger.info(`Downloaded ZIP file: ${zipPath}`)
- // 确保提取目录存在
+ // Ensure the extraction directory exists
if (!fs.existsSync(extractPath)) {
fs.mkdirSync(extractPath, { recursive: true })
}
- // 解压文件
+ // Extract the ZIP contents
const zip = new AdmZip(zipPath)
zip.extractAllTo(extractPath, true)
logger.info(`Extracted files to: ${extractPath}`)
- // 删除临时ZIP文件
+ // Remove the temporary ZIP file
fs.unlinkSync(zipPath)
return { path: extractPath }
@@ -209,11 +235,11 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
private async uploadFile(file: FileMetadata): Promise {
try {
- // 步骤1: 获取上传URL
- const { batchId, fileUrls } = await this.getBatchUploadUrls(file)
- // 步骤2: 上传文件到获取的URL
+ // Step 1: obtain the upload URL
+ const { batchId, fileUrls, uploadHeaders } = await this.getBatchUploadUrls(file)
+ // Step 2: upload the file to the obtained URL
const filePath = fileStorage.getFilePathById(file)
- await this.putFileToUrl(filePath, fileUrls[0])
+ await this.putFileToUrl(filePath, fileUrls[0], file.origin_name, uploadHeaders?.[0])
logger.info(`File uploaded successfully: ${filePath}`, { batchId, fileUrls })
return batchId
@@ -223,7 +249,9 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
}
}
- private async getBatchUploadUrls(file: FileMetadata): Promise<{ batchId: string; fileUrls: string[] }> {
+ private async getBatchUploadUrls(
+ file: FileMetadata
+ ): Promise<{ batchId: string; fileUrls: string[]; uploadHeaders?: Record[] }> {
const endpoint = `${this.provider.apiHost}/api/v4/file-urls/batch`
const payload = {
@@ -254,10 +282,11 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
if (response.ok) {
const data: ApiResponse = await response.json()
if (data.code === 0 && data.data) {
- const { batch_id, file_urls } = data.data
+ const { batch_id, file_urls, headers: uploadHeaders } = data.data
return {
batchId: batch_id,
- fileUrls: file_urls
+ fileUrls: file_urls,
+ uploadHeaders
}
} else {
throw new Error(`API returned error: ${data.msg || JSON.stringify(data)}`)
@@ -271,23 +300,28 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
}
}
- private async putFileToUrl(filePath: string, uploadUrl: string): Promise {
+ private async putFileToUrl(
+ filePath: string,
+ uploadUrl: string,
+ fileName?: string,
+ headers?: Record
+ ): Promise {
try {
const fileBuffer = await fs.promises.readFile(filePath)
+ const fileSize = fileBuffer.byteLength
+ const displayName = fileName ?? path.basename(filePath)
+ logger.info(`Uploading file to MinerU OSS: ${displayName} (${fileSize} bytes)`)
+
+ // https://mineru.net/apiManage/docs
const response = await net.fetch(uploadUrl, {
method: 'PUT',
- body: fileBuffer,
- headers: {
- 'Content-Type': 'application/pdf'
- }
- // headers: {
- // 'Content-Length': fileBuffer.length.toString()
- // }
+ headers,
+ body: new Uint8Array(fileBuffer)
})
if (!response.ok) {
- // 克隆 response 以避免消费 body stream
+ // Clone the response to avoid consuming the body stream
const responseClone = response.clone()
try {
@@ -358,20 +392,20 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
try {
const result = await this.getExtractResults(batchId)
- // 查找对应文件的处理结果
+ // Find the corresponding file result
const fileResult = result.extract_result.find((item) => item.file_name === fileName)
if (!fileResult) {
throw new Error(`File ${fileName} not found in batch results`)
}
- // 检查处理状态
+ // Check the processing state
if (fileResult.state === 'done' && fileResult.full_zip_url) {
logger.info(`Processing completed for file: ${fileName}`)
return fileResult
} else if (fileResult.state === 'failed') {
throw new Error(`Processing failed for file: ${fileName}, error: ${fileResult.err_msg}`)
} else if (fileResult.state === 'running') {
- // 发送进度更新
+ // Send progress updates
if (fileResult.extract_progress) {
const progress = Math.round(
(fileResult.extract_progress.extracted_pages / fileResult.extract_progress.total_pages) * 100
@@ -379,7 +413,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
await this.sendPreprocessProgress(sourceId, progress)
logger.info(`File ${fileName} processing progress: ${progress}%`)
} else {
- // 如果没有具体进度信息,发送一个通用进度
+ // If no detailed progress information is available, send a generic update
await this.sendPreprocessProgress(sourceId, 50)
logger.info(`File ${fileName} is still processing...`)
}
diff --git a/src/main/knowledge/preprocess/OpenMineruPreprocessProvider.ts b/src/main/knowledge/preprocess/OpenMineruPreprocessProvider.ts
index 9a3bca65a1..f322fbac35 100644
--- a/src/main/knowledge/preprocess/OpenMineruPreprocessProvider.ts
+++ b/src/main/knowledge/preprocess/OpenMineruPreprocessProvider.ts
@@ -53,18 +53,43 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
}
private async validateFile(filePath: string): Promise {
+ // 第一阶段:检查文件大小(无需读取文件到内存)
+ logger.info(`Validating PDF file: ${filePath}`)
+ const stats = await fs.promises.stat(filePath)
+ const fileSizeBytes = stats.size
+
+ // File size must be less than 200MB
+ if (fileSizeBytes >= 200 * 1024 * 1024) {
+ const fileSizeMB = Math.round(fileSizeBytes / (1024 * 1024))
+ throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
+ }
+
+ // 第二阶段:检查页数(需要读取文件,带错误处理)
const pdfBuffer = await fs.promises.readFile(filePath)
- const doc = await this.readPdf(pdfBuffer)
+ try {
+ const doc = await this.readPdf(pdfBuffer)
- // File page count must be less than 600 pages
- if (doc.numPages >= 600) {
- throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
- }
- // File size must be less than 200MB
- if (pdfBuffer.length >= 200 * 1024 * 1024) {
- const fileSizeMB = Math.round(pdfBuffer.length / (1024 * 1024))
- throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
+ // File page count must be less than 600 pages
+ if (doc.numPages >= 600) {
+ throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
+ }
+
+ logger.info(`PDF validation passed: ${doc.numPages} pages, ${Math.round(fileSizeBytes / (1024 * 1024))}MB`)
+ } catch (error: any) {
+ // 如果是页数超限错误,直接抛出
+ if (error.message.includes('exceeds the limit')) {
+ throw error
+ }
+
+ // PDF 解析失败,记录详细警告但允许继续处理
+ logger.warn(
+ `Failed to parse PDF structure (file may be corrupted or use non-standard format). ` +
+ `Skipping page count validation. Will attempt to process with MinerU API. ` +
+ `Error details: ${error.message}. ` +
+ `Suggestion: If processing fails, try repairing the PDF using tools like Adobe Acrobat or online PDF repair services.`
+ )
+ // 不抛出错误,允许继续处理
}
}
@@ -72,8 +97,8 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
// Find the main file after extraction
let finalPath = ''
let finalName = file.origin_name.replace('.pdf', '.md')
- // Find the corresponding folder by file name
- outputPath = path.join(outputPath, `${file.origin_name.replace('.pdf', '')}`)
+ // Find the corresponding folder by file id
+ outputPath = path.join(outputPath, file.id)
try {
const files = fs.readdirSync(outputPath)
@@ -125,7 +150,7 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
formData.append('return_md', 'true')
formData.append('response_format_zip', 'true')
formData.append('files', fileBuffer, {
- filename: file.origin_name
+ filename: file.name
})
while (retries < maxRetries) {
@@ -139,7 +164,7 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
...(this.provider.apiKey ? { Authorization: `Bearer ${this.provider.apiKey}` } : {}),
...formData.getHeaders()
},
- body: formData.getBuffer()
+ body: new Uint8Array(formData.getBuffer())
})
if (!response.ok) {
diff --git a/src/main/services/AppUpdater.ts b/src/main/services/AppUpdater.ts
index 168084bd32..57dc3fb2a8 100644
--- a/src/main/services/AppUpdater.ts
+++ b/src/main/services/AppUpdater.ts
@@ -2,7 +2,7 @@ import { loggerService } from '@logger'
import { isWin } from '@main/constant'
import { getIpCountry } from '@main/utils/ipService'
import { generateUserAgent } from '@main/utils/systemInfo'
-import { FeedUrl, UpgradeChannel } from '@shared/config/constant'
+import { FeedUrl, UpdateConfigUrl, UpdateMirror, UpgradeChannel } from '@shared/config/constant'
import { IpcChannel } from '@shared/IpcChannel'
import type { UpdateInfo } from 'builder-util-runtime'
import { CancellationToken } from 'builder-util-runtime'
@@ -22,7 +22,29 @@ const LANG_MARKERS = {
EN_START: '',
ZH_CN_START: '',
END: ''
-} as const
+}
+
+interface UpdateConfig {
+ lastUpdated: string
+ versions: {
+ [versionKey: string]: VersionConfig
+ }
+}
+
+interface VersionConfig {
+ minCompatibleVersion: string
+ description: string
+ channels: {
+ latest: ChannelConfig | null
+ rc: ChannelConfig | null
+ beta: ChannelConfig | null
+ }
+}
+
+interface ChannelConfig {
+ version: string
+ feedUrls: Record
+}
export default class AppUpdater {
autoUpdater: _AppUpdater = autoUpdater
@@ -37,7 +59,9 @@ export default class AppUpdater {
autoUpdater.requestHeaders = {
...autoUpdater.requestHeaders,
'User-Agent': generateUserAgent(),
- 'X-Client-Id': configManager.getClientId()
+ 'X-Client-Id': configManager.getClientId(),
+ // no-cache
+ 'Cache-Control': 'no-cache'
}
autoUpdater.on('error', (error) => {
@@ -75,61 +99,6 @@ export default class AppUpdater {
this.autoUpdater = autoUpdater
}
- private async _getReleaseVersionFromGithub(channel: UpgradeChannel) {
- const headers = {
- Accept: 'application/vnd.github+json',
- 'X-GitHub-Api-Version': '2022-11-28',
- 'Accept-Language': 'en-US,en;q=0.9'
- }
- try {
- logger.info(`get release version from github: ${channel}`)
- const responses = await net.fetch('https://api.github.com/repos/CherryHQ/cherry-studio/releases?per_page=8', {
- headers
- })
- const data = (await responses.json()) as GithubReleaseInfo[]
- let mightHaveLatest = false
- const release: GithubReleaseInfo | undefined = data.find((item: GithubReleaseInfo) => {
- if (!item.draft && !item.prerelease) {
- mightHaveLatest = true
- }
-
- return item.prerelease && item.tag_name.includes(`-${channel}.`)
- })
-
- if (!release) {
- return null
- }
-
- // if the release version is the same as the current version, return null
- if (release.tag_name === app.getVersion()) {
- return null
- }
-
- if (mightHaveLatest) {
- logger.info(`might have latest release, get latest release`)
- const latestReleaseResponse = await net.fetch(
- 'https://api.github.com/repos/CherryHQ/cherry-studio/releases/latest',
- {
- headers
- }
- )
- const latestRelease = (await latestReleaseResponse.json()) as GithubReleaseInfo
- if (semver.gt(latestRelease.tag_name, release.tag_name)) {
- logger.info(
- `latest release version is ${latestRelease.tag_name}, prerelease version is ${release.tag_name}, return null`
- )
- return null
- }
- }
-
- logger.info(`release url is ${release.tag_name}, set channel to ${channel}`)
- return `https://github.com/CherryHQ/cherry-studio/releases/download/${release.tag_name}`
- } catch (error) {
- logger.error('Failed to get latest not draft version from github:', error as Error)
- return null
- }
- }
-
public setAutoUpdate(isActive: boolean) {
autoUpdater.autoDownload = isActive
autoUpdater.autoInstallOnAppQuit = isActive
@@ -161,6 +130,88 @@ export default class AppUpdater {
return UpgradeChannel.LATEST
}
+ /**
+ * Fetch update configuration from GitHub or GitCode based on mirror
+ * @param mirror - Mirror to fetch config from
+ * @returns UpdateConfig object or null if fetch fails
+ */
+ private async _fetchUpdateConfig(mirror: UpdateMirror): Promise {
+ const configUrl = mirror === UpdateMirror.GITCODE ? UpdateConfigUrl.GITCODE : UpdateConfigUrl.GITHUB
+
+ try {
+ logger.info(`Fetching update config from ${configUrl} (mirror: ${mirror})`)
+ const response = await net.fetch(configUrl, {
+ headers: {
+ 'User-Agent': generateUserAgent(),
+ Accept: 'application/json',
+ 'X-Client-Id': configManager.getClientId(),
+ // no-cache
+ 'Cache-Control': 'no-cache'
+ }
+ })
+
+ if (!response.ok) {
+ throw new Error(`HTTP error! status: ${response.status}`)
+ }
+
+ const config = (await response.json()) as UpdateConfig
+ logger.info(`Update config fetched successfully, last updated: ${config.lastUpdated}`)
+ return config
+ } catch (error) {
+ logger.error('Failed to fetch update config:', error as Error)
+ return null
+ }
+ }
+
+ /**
+ * Find compatible channel configuration based on current version
+ * @param currentVersion - Current app version
+ * @param requestedChannel - Requested upgrade channel (latest/rc/beta)
+ * @param config - Update configuration object
+ * @returns Object containing ChannelConfig and actual channel if found, null otherwise
+ */
+ private _findCompatibleChannel(
+ currentVersion: string,
+ requestedChannel: UpgradeChannel,
+ config: UpdateConfig
+ ): { config: ChannelConfig; channel: UpgradeChannel } | null {
+ // Get all version keys and sort descending (newest first)
+ const versionKeys = Object.keys(config.versions).sort(semver.rcompare)
+
+ logger.info(
+ `Finding compatible channel for version ${currentVersion}, requested channel: ${requestedChannel}, available versions: ${versionKeys.join(', ')}`
+ )
+
+ for (const versionKey of versionKeys) {
+ const versionConfig = config.versions[versionKey]
+ const channelConfig = versionConfig.channels[requestedChannel]
+ const latestChannelConfig = versionConfig.channels[UpgradeChannel.LATEST]
+
+ // Check version compatibility and channel availability
+ if (semver.gte(currentVersion, versionConfig.minCompatibleVersion) && channelConfig !== null) {
+ logger.info(
+ `Found compatible version: ${versionKey} (minCompatibleVersion: ${versionConfig.minCompatibleVersion}), version: ${channelConfig.version}`
+ )
+
+ if (
+ requestedChannel !== UpgradeChannel.LATEST &&
+ latestChannelConfig &&
+ semver.gte(latestChannelConfig.version, channelConfig.version)
+ ) {
+ logger.info(
+ `latest channel version is greater than the requested channel version: ${latestChannelConfig.version} > ${channelConfig.version}, using latest instead`
+ )
+ return { config: latestChannelConfig, channel: UpgradeChannel.LATEST }
+ }
+
+ return { config: channelConfig, channel: requestedChannel }
+ }
+ }
+
+ logger.warn(`No compatible channel found for version ${currentVersion} and channel ${requestedChannel}`)
+ return null
+ }
+
private _setChannel(channel: UpgradeChannel, feedUrl: string) {
this.autoUpdater.channel = channel
this.autoUpdater.setFeedURL(feedUrl)
@@ -172,33 +223,42 @@ export default class AppUpdater {
}
private async _setFeedUrl() {
+ const currentVersion = app.getVersion()
const testPlan = configManager.getTestPlan()
- if (testPlan) {
- const channel = this._getTestChannel()
+ const requestedChannel = testPlan ? this._getTestChannel() : UpgradeChannel.LATEST
- if (channel === UpgradeChannel.LATEST) {
- this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
- return
- }
-
- const releaseUrl = await this._getReleaseVersionFromGithub(channel)
- if (releaseUrl) {
- logger.info(`release url is ${releaseUrl}, set channel to ${channel}`)
- this._setChannel(channel, releaseUrl)
- return
- }
-
- // if no prerelease url, use github latest to get release
- this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
- return
- }
-
- this._setChannel(UpgradeChannel.LATEST, FeedUrl.PRODUCTION)
+ // Determine mirror based on IP country
const ipCountry = await getIpCountry()
- logger.info(`ipCountry is ${ipCountry}, set channel to ${UpgradeChannel.LATEST}`)
- if (ipCountry.toLowerCase() !== 'cn') {
- this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
+ const mirror = ipCountry.toLowerCase() === 'cn' ? UpdateMirror.GITCODE : UpdateMirror.GITHUB
+
+ logger.info(
+ `Setting feed URL for version ${currentVersion}, testPlan: ${testPlan}, requested channel: ${requestedChannel}, mirror: ${mirror} (IP country: ${ipCountry})`
+ )
+
+ // Try to fetch update config from remote
+ const config = await this._fetchUpdateConfig(mirror)
+
+ if (config) {
+ // Use new config-based system
+ const result = this._findCompatibleChannel(currentVersion, requestedChannel, config)
+
+ if (result) {
+ const { config: channelConfig, channel: actualChannel } = result
+ const feedUrl = channelConfig.feedUrls[mirror]
+ logger.info(
+ `Using config-based feed URL: ${feedUrl} for channel ${actualChannel} (requested: ${requestedChannel}, mirror: ${mirror})`
+ )
+ this._setChannel(actualChannel, feedUrl)
+ return
+ }
}
+
+ logger.info('Failed to fetch update config, falling back to default feed URL')
+ // Fallback: use default feed URL based on mirror
+ const defaultFeedUrl = mirror === UpdateMirror.GITCODE ? FeedUrl.PRODUCTION : FeedUrl.GITHUB_LATEST
+
+ logger.info(`Using fallback feed URL: ${defaultFeedUrl}`)
+ this._setChannel(UpgradeChannel.LATEST, defaultFeedUrl)
}
public cancelDownload() {
@@ -320,8 +380,3 @@ export default class AppUpdater {
return processedInfo
}
}
-interface GithubReleaseInfo {
- draft: boolean
- prerelease: boolean
- tag_name: string
-}
diff --git a/src/main/services/CodeToolsService.ts b/src/main/services/CodeToolsService.ts
index 3a93a40d79..35655a88e7 100644
--- a/src/main/services/CodeToolsService.ts
+++ b/src/main/services/CodeToolsService.ts
@@ -10,6 +10,7 @@ import { getBinaryName } from '@main/utils/process'
import type { TerminalConfig, TerminalConfigWithCommand } from '@shared/config/constant'
import {
codeTools,
+ HOME_CHERRY_DIR,
MACOS_TERMINALS,
MACOS_TERMINALS_WITH_COMMANDS,
terminalApps,
@@ -66,7 +67,7 @@ class CodeToolsService {
}
public async getBunPath() {
- const dir = path.join(os.homedir(), '.cherrystudio', 'bin')
+ const dir = path.join(os.homedir(), HOME_CHERRY_DIR, 'bin')
const bunName = await getBinaryName('bun')
const bunPath = path.join(dir, bunName)
return bunPath
@@ -362,7 +363,7 @@ class CodeToolsService {
private async isPackageInstalled(cliTool: string): Promise {
const executableName = await this.getCliExecutableName(cliTool)
- const binDir = path.join(os.homedir(), '.cherrystudio', 'bin')
+ const binDir = path.join(os.homedir(), HOME_CHERRY_DIR, 'bin')
const executablePath = path.join(binDir, executableName + (isWin ? '.exe' : ''))
// Ensure bin directory exists
@@ -389,7 +390,7 @@ class CodeToolsService {
logger.info(`${cliTool} is installed, getting current version`)
try {
const executableName = await this.getCliExecutableName(cliTool)
- const binDir = path.join(os.homedir(), '.cherrystudio', 'bin')
+ const binDir = path.join(os.homedir(), HOME_CHERRY_DIR, 'bin')
const executablePath = path.join(binDir, executableName + (isWin ? '.exe' : ''))
const { stdout } = await execAsync(`"${executablePath}" --version`, {
@@ -500,7 +501,7 @@ class CodeToolsService {
try {
const packageName = await this.getPackageName(cliTool)
const bunPath = await this.getBunPath()
- const bunInstallPath = path.join(os.homedir(), '.cherrystudio')
+ const bunInstallPath = path.join(os.homedir(), HOME_CHERRY_DIR)
const registryUrl = await this.getNpmRegistryUrl()
const installEnvPrefix = isWin
@@ -547,10 +548,21 @@ class CodeToolsService {
logger.debug(`Environment variables:`, Object.keys(env))
logger.debug(`Options:`, options)
+ // Validate directory exists before proceeding
+ if (!directory || !fs.existsSync(directory)) {
+ const errorMessage = `Directory does not exist: ${directory}`
+ logger.error(errorMessage)
+ return {
+ success: false,
+ message: errorMessage,
+ command: ''
+ }
+ }
+
const packageName = await this.getPackageName(cliTool)
const bunPath = await this.getBunPath()
const executableName = await this.getCliExecutableName(cliTool)
- const binDir = path.join(os.homedir(), '.cherrystudio', 'bin')
+ const binDir = path.join(os.homedir(), HOME_CHERRY_DIR, 'bin')
const executablePath = path.join(binDir, executableName + (isWin ? '.exe' : ''))
logger.debug(`Package name: ${packageName}`)
@@ -652,7 +664,7 @@ class CodeToolsService {
baseCommand = `${baseCommand} ${configParams}`
}
- const bunInstallPath = path.join(os.homedir(), '.cherrystudio')
+ const bunInstallPath = path.join(os.homedir(), HOME_CHERRY_DIR)
if (isInstalled) {
// If already installed, run executable directly (with optional update message)
@@ -708,6 +720,7 @@ class CodeToolsService {
// Build bat file content, including debug information
const batContent = [
'@echo off',
+ 'chcp 65001 >nul 2>&1', // Switch to UTF-8 code page for international path support
`title ${cliTool} - Cherry Studio`, // Set window title in bat file
'echo ================================================',
'echo Cherry Studio CLI Tool Launcher',
diff --git a/src/main/services/FileStorage.ts b/src/main/services/FileStorage.ts
index 00dda778be..c8eb6abb03 100644
--- a/src/main/services/FileStorage.ts
+++ b/src/main/services/FileStorage.ts
@@ -16,6 +16,7 @@ import type { FSWatcher } from 'chokidar'
import chokidar from 'chokidar'
import * as crypto from 'crypto'
import type { OpenDialogOptions, OpenDialogReturnValue, SaveDialogOptions, SaveDialogReturnValue } from 'electron'
+import { app } from 'electron'
import { dialog, net, shell } from 'electron'
import * as fs from 'fs'
import { writeFileSync } from 'fs'
@@ -30,6 +31,73 @@ import WordExtractor from 'word-extractor'
const logger = loggerService.withContext('FileStorage')
+// Get ripgrep binary path
+const getRipgrepBinaryPath = (): string | null => {
+ try {
+ const arch = process.arch === 'arm64' ? 'arm64' : 'x64'
+ const platform = process.platform === 'darwin' ? 'darwin' : process.platform === 'win32' ? 'win32' : 'linux'
+ let ripgrepBinaryPath = path.join(
+ __dirname,
+ '../../node_modules/@anthropic-ai/claude-agent-sdk/vendor/ripgrep',
+ `${arch}-${platform}`,
+ process.platform === 'win32' ? 'rg.exe' : 'rg'
+ )
+
+ if (app.isPackaged) {
+ ripgrepBinaryPath = ripgrepBinaryPath.replace(/\.asar([\\/])/, '.asar.unpacked$1')
+ }
+
+ if (fs.existsSync(ripgrepBinaryPath)) {
+ return ripgrepBinaryPath
+ }
+ return null
+ } catch (error) {
+ logger.error('Failed to locate ripgrep binary:', error as Error)
+ return null
+ }
+}
+
+/**
+ * Execute ripgrep with captured output
+ */
+function executeRipgrep(args: string[]): Promise<{ exitCode: number; output: string }> {
+ return new Promise((resolve, reject) => {
+ const ripgrepBinaryPath = getRipgrepBinaryPath()
+
+ if (!ripgrepBinaryPath) {
+ reject(new Error('Ripgrep binary not available'))
+ return
+ }
+
+ const { spawn } = require('child_process')
+ const child = spawn(ripgrepBinaryPath, ['--no-config', '--ignore-case', ...args], {
+ stdio: ['pipe', 'pipe', 'pipe']
+ })
+
+ let output = ''
+ let errorOutput = ''
+
+ child.stdout.on('data', (data: Buffer) => {
+ output += data.toString()
+ })
+
+ child.stderr.on('data', (data: Buffer) => {
+ errorOutput += data.toString()
+ })
+
+ child.on('close', (code: number) => {
+ resolve({
+ exitCode: code || 0,
+ output: output || errorOutput
+ })
+ })
+
+ child.on('error', (error: Error) => {
+ reject(error)
+ })
+ })
+}
+
interface FileWatcherConfig {
watchExtensions?: string[]
ignoredPatterns?: (string | RegExp)[]
@@ -54,6 +122,26 @@ const DEFAULT_WATCHER_CONFIG: Required = {
eventChannel: 'file-change'
}
+interface DirectoryListOptions {
+ recursive?: boolean
+ maxDepth?: number
+ includeHidden?: boolean
+ includeFiles?: boolean
+ includeDirectories?: boolean
+ maxEntries?: number
+ searchPattern?: string
+}
+
+const DEFAULT_DIRECTORY_LIST_OPTIONS: Required = {
+ recursive: true,
+ maxDepth: 3,
+ includeHidden: false,
+ includeFiles: true,
+ includeDirectories: true,
+ maxEntries: 10,
+ searchPattern: '.'
+}
+
class FileStorage {
private storageDir = getFilesDir()
private notesDir = getNotesDir()
@@ -390,13 +478,16 @@ class FileStorage {
}
}
- public readFile = async (
- _: Electron.IpcMainInvokeEvent,
- id: string,
- detectEncoding: boolean = false
- ): Promise => {
- const filePath = path.join(this.storageDir, id)
-
+ /**
+ * Core file reading logic that handles both documents and text files.
+ *
+ * @private
+ * @param filePath - Full path to the file
+ * @param detectEncoding - Whether to auto-detect text file encoding
+ * @returns Promise resolving to the extracted text content
+ * @throws Error if file reading fails
+ */
+ private async readFileCore(filePath: string, detectEncoding: boolean = false): Promise {
const fileExtension = path.extname(filePath)
if (documentExts.includes(fileExtension)) {
@@ -416,7 +507,7 @@ class FileStorage {
return data
} catch (error) {
chdir(originalCwd)
- logger.error('Failed to read file:', error as Error)
+ logger.error('Failed to read document file:', error as Error)
throw error
}
}
@@ -428,11 +519,72 @@ class FileStorage {
return fs.readFileSync(filePath, 'utf-8')
}
} catch (error) {
- logger.error('Failed to read file:', error as Error)
+ logger.error('Failed to read text file:', error as Error)
throw new Error(`Failed to read file: ${filePath}.`)
}
}
+ /**
+ * Reads and extracts content from a stored file.
+ *
+ * Supports multiple file formats including:
+ * - Complex documents: .pdf, .doc, .docx, .pptx, .xlsx, .odt, .odp, .ods
+ * - Text files: .txt, .md, .json, .csv, etc.
+ * - Code files: .js, .ts, .py, .java, etc.
+ *
+ * For document formats, extracts text content using specialized parsers:
+ * - .doc files: Uses word-extractor library
+ * - Other Office formats: Uses officeparser library
+ *
+ * For text files, can optionally detect encoding automatically.
+ *
+ * @param _ - Electron IPC invoke event (unused)
+ * @param id - File identifier with extension (e.g., "uuid.docx")
+ * @param detectEncoding - Whether to auto-detect text file encoding (default: false)
+ * @returns Promise resolving to the extracted text content of the file
+ * @throws Error if file reading fails or file is not found
+ *
+ * @example
+ * // Read a DOCX file
+ * const content = await readFile(event, "document.docx");
+ *
+ * @example
+ * // Read a text file with encoding detection
+ * const content = await readFile(event, "text.txt", true);
+ *
+ * @example
+ * // Read a PDF file
+ * const content = await readFile(event, "manual.pdf");
+ */
+ public readFile = async (
+ _: Electron.IpcMainInvokeEvent,
+ id: string,
+ detectEncoding: boolean = false
+ ): Promise => {
+ const filePath = path.join(this.storageDir, id)
+ return this.readFileCore(filePath, detectEncoding)
+ }
+
+ /**
+ * Reads and extracts content from an external file path.
+ *
+ * Similar to readFile, but operates on external file paths instead of stored files.
+ * Supports the same file formats including complex documents and text files.
+ *
+ * @param _ - Electron IPC invoke event (unused)
+ * @param filePath - Absolute path to the external file
+ * @param detectEncoding - Whether to auto-detect text file encoding (default: false)
+ * @returns Promise resolving to the extracted text content of the file
+ * @throws Error if file does not exist or reading fails
+ *
+ * @example
+ * // Read an external DOCX file
+ * const content = await readExternalFile(event, "/path/to/document.docx");
+ *
+ * @example
+ * // Read an external text file with encoding detection
+ * const content = await readExternalFile(event, "/path/to/text.txt", true);
+ */
public readExternalFile = async (
_: Electron.IpcMainInvokeEvent,
filePath: string,
@@ -442,40 +594,7 @@ class FileStorage {
throw new Error(`File does not exist: ${filePath}`)
}
- const fileExtension = path.extname(filePath)
-
- if (documentExts.includes(fileExtension)) {
- const originalCwd = process.cwd()
- try {
- chdir(this.tempDir)
-
- if (fileExtension === '.doc') {
- const extractor = new WordExtractor()
- const extracted = await extractor.extract(filePath)
- chdir(originalCwd)
- return extracted.getBody()
- }
-
- const data = await officeParser.parseOfficeAsync(filePath)
- chdir(originalCwd)
- return data
- } catch (error) {
- chdir(originalCwd)
- logger.error('Failed to read file:', error as Error)
- throw error
- }
- }
-
- try {
- if (detectEncoding) {
- return readTextFileWithAutoEncoding(filePath)
- } else {
- return fs.readFileSync(filePath, 'utf-8')
- }
- } catch (error) {
- logger.error('Failed to read file:', error as Error)
- throw new Error(`Failed to read file: ${filePath}.`)
- }
+ return this.readFileCore(filePath, detectEncoding)
}
public createTempFile = async (_: Electron.IpcMainInvokeEvent, fileName: string): Promise => {
@@ -748,6 +867,284 @@ class FileStorage {
}
}
+ public listDirectory = async (
+ _: Electron.IpcMainInvokeEvent,
+ dirPath: string,
+ options?: DirectoryListOptions
+ ): Promise => {
+ const mergedOptions: Required = {
+ ...DEFAULT_DIRECTORY_LIST_OPTIONS,
+ ...options
+ }
+
+ const resolvedPath = path.resolve(dirPath)
+
+ const stat = await fs.promises.stat(resolvedPath).catch((error) => {
+ logger.error(`[IPC - Error] Failed to access directory: ${resolvedPath}`, error as Error)
+ throw error
+ })
+
+ if (!stat.isDirectory()) {
+ throw new Error(`Path is not a directory: ${resolvedPath}`)
+ }
+
+ // Use ripgrep for file listing with relevance-based sorting
+ if (!getRipgrepBinaryPath()) {
+ throw new Error('Ripgrep binary not available')
+ }
+
+ return await this.listDirectoryWithRipgrep(resolvedPath, mergedOptions)
+ }
+
+ /**
+ * Search directories by name pattern
+ */
+ private async searchDirectories(
+ resolvedPath: string,
+ options: Required,
+ currentDepth: number = 0
+ ): Promise {
+ if (!options.includeDirectories) return []
+ if (!options.recursive && currentDepth > 0) return []
+ if (options.maxDepth > 0 && currentDepth >= options.maxDepth) return []
+
+ const directories: string[] = []
+ const excludedDirs = new Set([
+ 'node_modules',
+ '.git',
+ '.idea',
+ '.vscode',
+ 'dist',
+ 'build',
+ '.next',
+ '.nuxt',
+ 'coverage',
+ '.cache'
+ ])
+
+ try {
+ const entries = await fs.promises.readdir(resolvedPath, { withFileTypes: true })
+ const searchPatternLower = options.searchPattern.toLowerCase()
+
+ for (const entry of entries) {
+ if (!entry.isDirectory()) continue
+
+ // Skip hidden directories unless explicitly included
+ if (!options.includeHidden && entry.name.startsWith('.')) continue
+
+ // Skip excluded directories
+ if (excludedDirs.has(entry.name)) continue
+
+ const fullPath = path.join(resolvedPath, entry.name).replace(/\\/g, '/')
+
+ // Check if directory name matches search pattern
+ if (options.searchPattern === '.' || entry.name.toLowerCase().includes(searchPatternLower)) {
+ directories.push(fullPath)
+ }
+
+ // Recursively search subdirectories
+ if (options.recursive && currentDepth < options.maxDepth) {
+ const subDirs = await this.searchDirectories(fullPath, options, currentDepth + 1)
+ directories.push(...subDirs)
+ }
+ }
+ } catch (error) {
+ logger.warn(`Failed to search directories in: ${resolvedPath}`, error as Error)
+ }
+
+ return directories
+ }
+
+ /**
+ * Search files by filename pattern
+ */
+ private async searchByFilename(resolvedPath: string, options: Required): Promise {
+ const files: string[] = []
+ const directories: string[] = []
+
+ // Search for files using ripgrep
+ if (options.includeFiles) {
+ const args: string[] = ['--files']
+
+ // Handle hidden files
+ if (!options.includeHidden) {
+ args.push('--glob', '!.*')
+ }
+
+ // Use --iglob to let ripgrep filter filenames (case-insensitive)
+ if (options.searchPattern && options.searchPattern !== '.') {
+ args.push('--iglob', `*${options.searchPattern}*`)
+ }
+
+ // Exclude common hidden directories and large directories
+ args.push('-g', '!**/node_modules/**')
+ args.push('-g', '!**/.git/**')
+ args.push('-g', '!**/.idea/**')
+ args.push('-g', '!**/.vscode/**')
+ args.push('-g', '!**/.DS_Store')
+ args.push('-g', '!**/dist/**')
+ args.push('-g', '!**/build/**')
+ args.push('-g', '!**/.next/**')
+ args.push('-g', '!**/.nuxt/**')
+ args.push('-g', '!**/coverage/**')
+ args.push('-g', '!**/.cache/**')
+
+ // Handle max depth
+ if (!options.recursive) {
+ args.push('--max-depth', '1')
+ } else if (options.maxDepth > 0) {
+ args.push('--max-depth', options.maxDepth.toString())
+ }
+
+ // Add the directory path
+ args.push(resolvedPath)
+
+ const { exitCode, output } = await executeRipgrep(args)
+
+ // Exit code 0 means files found, 1 means no files found (still success), 2+ means error
+ if (exitCode >= 2) {
+ throw new Error(`Ripgrep failed with exit code ${exitCode}: ${output}`)
+ }
+
+ // Parse ripgrep output (no need to filter by filename - ripgrep already did it)
+ files.push(
+ ...output
+ .split('\n')
+ .filter((line) => line.trim())
+ .map((line) => line.replace(/\\/g, '/'))
+ )
+ }
+
+ // Search for directories
+ if (options.includeDirectories) {
+ directories.push(...(await this.searchDirectories(resolvedPath, options)))
+ }
+
+ // Combine and sort: directories first (alphabetically), then files (alphabetically)
+ const sortedDirectories = directories.sort((a, b) => {
+ const aName = path.basename(a)
+ const bName = path.basename(b)
+ return aName.localeCompare(bName)
+ })
+
+ const sortedFiles = files.sort((a, b) => {
+ const aName = path.basename(a)
+ const bName = path.basename(b)
+ return aName.localeCompare(bName)
+ })
+
+ return [...sortedDirectories, ...sortedFiles].slice(0, options.maxEntries)
+ }
+
+ /**
+ * Search files by content pattern
+ */
+ private async searchByContent(resolvedPath: string, options: Required): Promise {
+ const args: string[] = ['-l']
+
+ // Handle hidden files
+ if (!options.includeHidden) {
+ args.push('--glob', '!.*')
+ }
+
+ // Exclude common hidden directories and large directories
+ args.push('-g', '!**/node_modules/**')
+ args.push('-g', '!**/.git/**')
+ args.push('-g', '!**/.idea/**')
+ args.push('-g', '!**/.vscode/**')
+ args.push('-g', '!**/.DS_Store')
+ args.push('-g', '!**/dist/**')
+ args.push('-g', '!**/build/**')
+ args.push('-g', '!**/.next/**')
+ args.push('-g', '!**/.nuxt/**')
+ args.push('-g', '!**/coverage/**')
+ args.push('-g', '!**/.cache/**')
+
+ // Handle max depth
+ if (!options.recursive) {
+ args.push('--max-depth', '1')
+ } else if (options.maxDepth > 0) {
+ args.push('--max-depth', options.maxDepth.toString())
+ }
+
+ // Handle max count
+ if (options.maxEntries > 0) {
+ args.push('--max-count', options.maxEntries.toString())
+ }
+
+ // Add search pattern (search in content)
+ args.push(options.searchPattern)
+
+ // Add the directory path
+ args.push(resolvedPath)
+
+ const { exitCode, output } = await executeRipgrep(args)
+
+ // Exit code 0 means files found, 1 means no files found (still success), 2+ means error
+ if (exitCode >= 2) {
+ throw new Error(`Ripgrep failed with exit code ${exitCode}: ${output}`)
+ }
+
+ // Parse ripgrep output (already sorted by relevance)
+ const results = output
+ .split('\n')
+ .filter((line) => line.trim())
+ .map((line) => line.replace(/\\/g, '/'))
+ .slice(0, options.maxEntries)
+
+ return results
+ }
+
+ private async listDirectoryWithRipgrep(
+ resolvedPath: string,
+ options: Required
+ ): Promise {
+ const maxEntries = options.maxEntries
+
+ // Step 1: Search by filename first
+ logger.debug('Searching by filename pattern', { pattern: options.searchPattern, path: resolvedPath })
+ const filenameResults = await this.searchByFilename(resolvedPath, options)
+
+ logger.debug('Found matches by filename', { count: filenameResults.length })
+
+ // If we have enough filename matches, return them
+ if (filenameResults.length >= maxEntries) {
+ return filenameResults.slice(0, maxEntries)
+ }
+
+ // Step 2: If filename matches are less than maxEntries, search by content to fill up
+ logger.debug('Filename matches insufficient, searching by content to fill up', {
+ filenameCount: filenameResults.length,
+ needed: maxEntries - filenameResults.length
+ })
+
+ // Adjust maxEntries for content search to get enough results
+ const contentOptions = {
+ ...options,
+ maxEntries: maxEntries - filenameResults.length + 20 // Request extra to account for duplicates
+ }
+
+ const contentResults = await this.searchByContent(resolvedPath, contentOptions)
+
+ logger.debug('Found matches by content', { count: contentResults.length })
+
+ // Combine results: filename matches first, then content matches (deduplicated)
+ const combined = [...filenameResults]
+ const filenameSet = new Set(filenameResults)
+
+ for (const filePath of contentResults) {
+ if (!filenameSet.has(filePath)) {
+ combined.push(filePath)
+ if (combined.length >= maxEntries) {
+ break
+ }
+ }
+ }
+
+ logger.debug('Combined results', { total: combined.length, filenameCount: filenameResults.length })
+ return combined.slice(0, maxEntries)
+ }
+
public validateNotesDirectory = async (_: Electron.IpcMainInvokeEvent, dirPath: string): Promise => {
try {
if (!dirPath || typeof dirPath !== 'string') {
diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts
index ba3340780b..3925376226 100644
--- a/src/main/services/MCPService.ts
+++ b/src/main/services/MCPService.ts
@@ -12,6 +12,7 @@ import { TraceMethod, withSpanFunc } from '@mcp-trace/trace-core'
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import type { SSEClientTransportOptions } from '@modelcontextprotocol/sdk/client/sse.js'
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
+import type { StdioServerParameters } from '@modelcontextprotocol/sdk/client/stdio.js'
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'
import {
StreamableHTTPClientTransport,
@@ -30,6 +31,7 @@ import {
ToolListChangedNotificationSchema
} from '@modelcontextprotocol/sdk/types.js'
import { nanoid } from '@reduxjs/toolkit'
+import { HOME_CHERRY_DIR } from '@shared/config/constant'
import type { MCPProgressEvent } from '@shared/config/types'
import { IpcChannel } from '@shared/IpcChannel'
import { defaultAppHeaders } from '@shared/utils'
@@ -41,11 +43,14 @@ import {
type MCPPrompt,
type MCPResource,
type MCPServer,
- type MCPTool
+ type MCPTool,
+ MCPToolInputSchema,
+ MCPToolOutputSchema
} from '@types'
import { app, net } from 'electron'
import { EventEmitter } from 'events'
import { v4 as uuidv4 } from 'uuid'
+import * as z from 'zod'
import { CacheService } from './CacheService'
import DxtService from './DxtService'
@@ -342,7 +347,7 @@ class McpService {
removeEnvProxy(loginShellEnv)
}
- const transportOptions: any = {
+ const transportOptions: StdioServerParameters = {
command: cmd,
args,
env: {
@@ -619,7 +624,9 @@ class McpService {
tools.map((tool: SDKTool) => {
const serverTool: MCPTool = {
...tool,
- id: buildFunctionCallToolName(server.name, tool.name),
+ inputSchema: z.parse(MCPToolInputSchema, tool.inputSchema),
+ outputSchema: tool.outputSchema ? z.parse(MCPToolOutputSchema, tool.outputSchema) : undefined,
+ id: buildFunctionCallToolName(server.name, tool.name, server.id),
serverId: server.id,
serverName: server.name,
type: 'mcp'
@@ -715,7 +722,7 @@ class McpService {
}
public async getInstallInfo() {
- const dir = path.join(os.homedir(), '.cherrystudio', 'bin')
+ const dir = path.join(os.homedir(), HOME_CHERRY_DIR, 'bin')
const uvName = await getBinaryName('uv')
const bunName = await getBinaryName('bun')
const uvPath = path.join(dir, uvName)
diff --git a/src/main/services/OvmsManager.ts b/src/main/services/OvmsManager.ts
index f319200ac3..3a32d74ecf 100644
--- a/src/main/services/OvmsManager.ts
+++ b/src/main/services/OvmsManager.ts
@@ -3,6 +3,7 @@ import { homedir } from 'node:os'
import { promisify } from 'node:util'
import { loggerService } from '@logger'
+import { HOME_CHERRY_DIR } from '@shared/config/constant'
import * as fs from 'fs-extra'
import * as path from 'path'
@@ -145,7 +146,7 @@ class OvmsManager {
*/
public async runOvms(): Promise<{ success: boolean; message?: string }> {
const homeDir = homedir()
- const ovmsDir = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms')
+ const ovmsDir = path.join(homeDir, HOME_CHERRY_DIR, 'ovms', 'ovms')
const configPath = path.join(ovmsDir, 'models', 'config.json')
const runBatPath = path.join(ovmsDir, 'run.bat')
@@ -195,7 +196,7 @@ class OvmsManager {
*/
public async getOvmsStatus(): Promise<'not-installed' | 'not-running' | 'running'> {
const homeDir = homedir()
- const ovmsPath = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms', 'ovms.exe')
+ const ovmsPath = path.join(homeDir, HOME_CHERRY_DIR, 'ovms', 'ovms', 'ovms.exe')
try {
// Check if OVMS executable exists
@@ -273,7 +274,7 @@ class OvmsManager {
}
const homeDir = homedir()
- const configPath = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms', 'models', 'config.json')
+ const configPath = path.join(homeDir, HOME_CHERRY_DIR, 'ovms', 'ovms', 'models', 'config.json')
try {
if (!(await fs.pathExists(configPath))) {
logger.warn(`Config file does not exist: ${configPath}`)
@@ -304,7 +305,7 @@ class OvmsManager {
private async applyModelPath(modelDirPath: string): Promise {
const homeDir = homedir()
- const patchDir = path.join(homeDir, '.cherrystudio', 'ovms', 'patch')
+ const patchDir = path.join(homeDir, HOME_CHERRY_DIR, 'ovms', 'patch')
if (!(await fs.pathExists(patchDir))) {
return true
}
@@ -355,7 +356,7 @@ class OvmsManager {
logger.info(`Adding model: ${modelName} with ID: ${modelId}, Source: ${modelSource}, Task: ${task}`)
const homeDir = homedir()
- const ovdndDir = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms')
+ const ovdndDir = path.join(homeDir, HOME_CHERRY_DIR, 'ovms', 'ovms')
const pathModel = path.join(ovdndDir, 'models', modelId)
try {
@@ -468,7 +469,7 @@ class OvmsManager {
*/
public async checkModelExists(modelId: string): Promise {
const homeDir = homedir()
- const ovmsDir = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms')
+ const ovmsDir = path.join(homeDir, HOME_CHERRY_DIR, 'ovms', 'ovms')
const configPath = path.join(ovmsDir, 'models', 'config.json')
try {
@@ -495,7 +496,7 @@ class OvmsManager {
*/
public async updateModelConfig(modelName: string, modelId: string): Promise {
const homeDir = homedir()
- const ovmsDir = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms')
+ const ovmsDir = path.join(homeDir, HOME_CHERRY_DIR, 'ovms', 'ovms')
const configPath = path.join(ovmsDir, 'models', 'config.json')
try {
@@ -548,7 +549,7 @@ class OvmsManager {
*/
public async getModels(): Promise {
const homeDir = homedir()
- const ovmsDir = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms')
+ const ovmsDir = path.join(homeDir, HOME_CHERRY_DIR, 'ovms', 'ovms')
const configPath = path.join(ovmsDir, 'models', 'config.json')
try {
diff --git a/src/main/services/SpanCacheService.ts b/src/main/services/SpanCacheService.ts
index 62707388a4..47a89d4327 100644
--- a/src/main/services/SpanCacheService.ts
+++ b/src/main/services/SpanCacheService.ts
@@ -3,6 +3,7 @@ import type { Attributes, SpanEntity, TokenUsage, TraceCache } from '@mcp-trace/
import { convertSpanToSpanEntity } from '@mcp-trace/trace-core'
import { SpanStatusCode } from '@opentelemetry/api'
import type { ReadableSpan } from '@opentelemetry/sdk-trace-base'
+import { HOME_CHERRY_DIR } from '@shared/config/constant'
import fs from 'fs/promises'
import * as os from 'os'
import * as path from 'path'
@@ -18,7 +19,7 @@ class SpanCacheService implements TraceCache {
pri
constructor() {
- this.fileDir = path.join(os.homedir(), '.cherrystudio', 'trace')
+ this.fileDir = path.join(os.homedir(), HOME_CHERRY_DIR, 'trace')
}
createSpan: (span: ReadableSpan) => void = (span: ReadableSpan) => {
diff --git a/src/main/services/WindowService.ts b/src/main/services/WindowService.ts
index 66aed098e7..63eaaba995 100644
--- a/src/main/services/WindowService.ts
+++ b/src/main/services/WindowService.ts
@@ -375,13 +375,16 @@ export class WindowService {
mainWindow.hide()
- // TODO: don't hide dock icon when close to tray
- // will cause the cmd+h behavior not working
- // after the electron fix the bug, we can restore this code
- // //for mac users, should hide dock icon if close to tray
- // if (isMac && isTrayOnClose) {
- // app.dock?.hide()
- // }
+ //for mac users, should hide dock icon if close to tray
+ if (isMac && isTrayOnClose) {
+ app.dock?.hide()
+
+ mainWindow.once('show', () => {
+ //restore the window can hide by cmd+h when the window is shown again
+ // https://github.com/electron/electron/pull/47970
+ app.dock?.show()
+ })
+ }
})
mainWindow.on('closed', () => {
diff --git a/src/main/services/__tests__/AppUpdater.test.ts b/src/main/services/__tests__/AppUpdater.test.ts
index 1be0e2f486..f7de00475a 100644
--- a/src/main/services/__tests__/AppUpdater.test.ts
+++ b/src/main/services/__tests__/AppUpdater.test.ts
@@ -85,6 +85,9 @@ vi.mock('electron-updater', () => ({
}))
// Import after mocks
+import { UpdateMirror } from '@shared/config/constant'
+import { app, net } from 'electron'
+
import AppUpdater from '../AppUpdater'
import { configManager } from '../ConfigManager'
@@ -274,4 +277,711 @@ describe('AppUpdater', () => {
expect(result.releaseNotes).toBeNull()
})
})
+
+ describe('_fetchUpdateConfig', () => {
+ const mockConfig = {
+ lastUpdated: '2025-01-05T00:00:00Z',
+ versions: {
+ '1.6.7': {
+ minCompatibleVersion: '1.0.0',
+ description: 'Test version',
+ channels: {
+ latest: {
+ version: '1.6.7',
+ feedUrls: {
+ github: 'https://github.com/test/v1.6.7',
+ gitcode: 'https://gitcode.com/test/v1.6.7'
+ }
+ },
+ rc: null,
+ beta: null
+ }
+ }
+ }
+ }
+
+ it('should fetch config from GitHub mirror', async () => {
+ vi.mocked(net.fetch).mockResolvedValue({
+ ok: true,
+ json: async () => mockConfig
+ } as any)
+
+ const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITHUB)
+
+ expect(result).toEqual(mockConfig)
+ expect(net.fetch).toHaveBeenCalledWith(expect.stringContaining('github'), expect.any(Object))
+ })
+
+ it('should fetch config from GitCode mirror', async () => {
+ vi.mocked(net.fetch).mockResolvedValue({
+ ok: true,
+ json: async () => mockConfig
+ } as any)
+
+ const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITCODE)
+
+ expect(result).toEqual(mockConfig)
+ // GitCode URL may vary, just check that fetch was called
+ expect(net.fetch).toHaveBeenCalledWith(expect.any(String), expect.any(Object))
+ })
+
+ it('should return null on HTTP error', async () => {
+ vi.mocked(net.fetch).mockResolvedValue({
+ ok: false,
+ status: 404
+ } as any)
+
+ const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITHUB)
+
+ expect(result).toBeNull()
+ })
+
+ it('should return null on network error', async () => {
+ vi.mocked(net.fetch).mockRejectedValue(new Error('Network error'))
+
+ const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITHUB)
+
+ expect(result).toBeNull()
+ })
+ })
+
+ describe('_findCompatibleChannel', () => {
+ const mockConfig = {
+ lastUpdated: '2025-01-05T00:00:00Z',
+ versions: {
+ '1.6.7': {
+ minCompatibleVersion: '1.0.0',
+ description: 'v1.6.7',
+ channels: {
+ latest: {
+ version: '1.6.7',
+ feedUrls: {
+ github: 'https://github.com/test/v1.6.7',
+ gitcode: 'https://gitcode.com/test/v1.6.7'
+ }
+ },
+ rc: {
+ version: '1.7.0-rc.1',
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-rc.1',
+ gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
+ }
+ },
+ beta: {
+ version: '1.7.0-beta.3',
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-beta.3',
+ gitcode: 'https://gitcode.com/test/v1.7.0-beta.3'
+ }
+ }
+ }
+ },
+ '2.0.0': {
+ minCompatibleVersion: '1.7.0',
+ description: 'v2.0.0',
+ channels: {
+ latest: null,
+ rc: null,
+ beta: null
+ }
+ }
+ }
+ }
+
+ it('should find compatible latest channel', () => {
+ vi.mocked(app.getVersion).mockReturnValue('1.5.0')
+
+ const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'latest', mockConfig)
+
+ expect(result?.config).toEqual({
+ version: '1.6.7',
+ feedUrls: {
+ github: 'https://github.com/test/v1.6.7',
+ gitcode: 'https://gitcode.com/test/v1.6.7'
+ }
+ })
+ expect(result?.channel).toBe('latest')
+ })
+
+ it('should find compatible rc channel', () => {
+ vi.mocked(app.getVersion).mockReturnValue('1.5.0')
+
+ const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'rc', mockConfig)
+
+ expect(result?.config).toEqual({
+ version: '1.7.0-rc.1',
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-rc.1',
+ gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
+ }
+ })
+ expect(result?.channel).toBe('rc')
+ })
+
+ it('should find compatible beta channel', () => {
+ vi.mocked(app.getVersion).mockReturnValue('1.5.0')
+
+ const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'beta', mockConfig)
+
+ expect(result?.config).toEqual({
+ version: '1.7.0-beta.3',
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-beta.3',
+ gitcode: 'https://gitcode.com/test/v1.7.0-beta.3'
+ }
+ })
+ expect(result?.channel).toBe('beta')
+ })
+
+ it('should return latest when latest version >= rc version', () => {
+ const configWithNewerLatest = {
+ lastUpdated: '2025-01-05T00:00:00Z',
+ versions: {
+ '1.7.0': {
+ minCompatibleVersion: '1.0.0',
+ description: 'v1.7.0',
+ channels: {
+ latest: {
+ version: '1.7.0',
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0',
+ gitcode: 'https://gitcode.com/test/v1.7.0'
+ }
+ },
+ rc: {
+ version: '1.7.0-rc.1',
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-rc.1',
+ gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
+ }
+ },
+ beta: null
+ }
+ }
+ }
+ }
+
+ const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'rc', configWithNewerLatest)
+
+ // Should return latest instead of rc because 1.7.0 >= 1.7.0-rc.1
+ expect(result?.config).toEqual({
+ version: '1.7.0',
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0',
+ gitcode: 'https://gitcode.com/test/v1.7.0'
+ }
+ })
+ expect(result?.channel).toBe('latest') // ✅ 返回 latest 频道
+ })
+
+ it('should return latest when latest version >= beta version', () => {
+ const configWithNewerLatest = {
+ lastUpdated: '2025-01-05T00:00:00Z',
+ versions: {
+ '1.7.0': {
+ minCompatibleVersion: '1.0.0',
+ description: 'v1.7.0',
+ channels: {
+ latest: {
+ version: '1.7.0',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0'
+ }
+ },
+ rc: null,
+ beta: {
+ version: '1.6.8-beta.1',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.6.8-beta.1',
+
+ gitcode: 'https://gitcode.com/test/v1.6.8-beta.1'
+ }
+ }
+ }
+ }
+ }
+ }
+
+ const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'beta', configWithNewerLatest)
+
+ // Should return latest instead of beta because 1.7.0 >= 1.6.8-beta.1
+ expect(result?.config).toEqual({
+ version: '1.7.0',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0'
+ }
+ })
+ })
+
+ it('should not compare latest with itself when requesting latest channel', () => {
+ const config = {
+ lastUpdated: '2025-01-05T00:00:00Z',
+ versions: {
+ '1.7.0': {
+ minCompatibleVersion: '1.0.0',
+ description: 'v1.7.0',
+ channels: {
+ latest: {
+ version: '1.7.0',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0'
+ }
+ },
+ rc: {
+ version: '1.7.0-rc.1',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-rc.1',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
+ }
+ },
+ beta: null
+ }
+ }
+ }
+ }
+
+ const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'latest', config)
+
+ // Should return latest directly without comparing with itself
+ expect(result?.config).toEqual({
+ version: '1.7.0',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0'
+ }
+ })
+ })
+
+ it('should return rc when rc version > latest version', () => {
+ const configWithNewerRc = {
+ lastUpdated: '2025-01-05T00:00:00Z',
+ versions: {
+ '1.7.0': {
+ minCompatibleVersion: '1.0.0',
+ description: 'v1.7.0',
+ channels: {
+ latest: {
+ version: '1.6.7',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.6.7',
+
+ gitcode: 'https://gitcode.com/test/v1.6.7'
+ }
+ },
+ rc: {
+ version: '1.7.0-rc.1',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-rc.1',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
+ }
+ },
+ beta: null
+ }
+ }
+ }
+ }
+
+ const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'rc', configWithNewerRc)
+
+ // Should return rc because 1.7.0-rc.1 > 1.6.7
+ expect(result?.config).toEqual({
+ version: '1.7.0-rc.1',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-rc.1',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
+ }
+ })
+ })
+
+ it('should return beta when beta version > latest version', () => {
+ const configWithNewerBeta = {
+ lastUpdated: '2025-01-05T00:00:00Z',
+ versions: {
+ '1.7.0': {
+ minCompatibleVersion: '1.0.0',
+ description: 'v1.7.0',
+ channels: {
+ latest: {
+ version: '1.6.7',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.6.7',
+
+ gitcode: 'https://gitcode.com/test/v1.6.7'
+ }
+ },
+ rc: null,
+ beta: {
+ version: '1.7.0-beta.5',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-beta.5',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0-beta.5'
+ }
+ }
+ }
+ }
+ }
+ }
+
+ const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'beta', configWithNewerBeta)
+
+ // Should return beta because 1.7.0-beta.5 > 1.6.7
+ expect(result?.config).toEqual({
+ version: '1.7.0-beta.5',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-beta.5',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0-beta.5'
+ }
+ })
+ })
+
+ it('should return lower version when higher version has no compatible channel', () => {
+ vi.mocked(app.getVersion).mockReturnValue('1.8.0')
+
+ const result = (appUpdater as any)._findCompatibleChannel('1.8.0', 'latest', mockConfig)
+
+ // 1.8.0 >= 1.7.0 but 2.0.0 has no latest channel, so return 1.6.7
+ expect(result?.config).toEqual({
+ version: '1.6.7',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.6.7',
+
+ gitcode: 'https://gitcode.com/test/v1.6.7'
+ }
+ })
+ })
+
+ it('should return null when current version does not meet minCompatibleVersion', () => {
+ vi.mocked(app.getVersion).mockReturnValue('0.9.0')
+
+ const result = (appUpdater as any)._findCompatibleChannel('0.9.0', 'latest', mockConfig)
+
+ // 0.9.0 < 1.0.0 (minCompatibleVersion)
+ expect(result).toBeNull()
+ })
+
+ it('should return lower version rc when higher version has no rc channel', () => {
+ const result = (appUpdater as any)._findCompatibleChannel('1.8.0', 'rc', mockConfig)
+
+ // 1.8.0 >= 1.7.0 but 2.0.0 has no rc channel, so return 1.6.7 rc
+ expect(result?.config).toEqual({
+ version: '1.7.0-rc.1',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-rc.1',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
+ }
+ })
+ })
+
+ it('should return null when no version has the requested channel', () => {
+ const configWithoutRc = {
+ lastUpdated: '2025-01-05T00:00:00Z',
+ versions: {
+ '1.6.7': {
+ minCompatibleVersion: '1.0.0',
+ description: 'v1.6.7',
+ channels: {
+ latest: {
+ version: '1.6.7',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.6.7',
+
+ gitcode: 'https://gitcode.com/test/v1.6.7'
+ }
+ },
+ rc: null,
+ beta: null
+ }
+ }
+ }
+ }
+
+ const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'rc', configWithoutRc)
+
+ expect(result).toBeNull()
+ })
+ })
+
+ describe('Upgrade Path', () => {
+ const fullConfig = {
+ lastUpdated: '2025-01-05T00:00:00Z',
+ versions: {
+ '1.6.7': {
+ minCompatibleVersion: '1.0.0',
+ description: 'Last v1.x',
+ channels: {
+ latest: {
+ version: '1.6.7',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.6.7',
+
+ gitcode: 'https://gitcode.com/test/v1.6.7'
+ }
+ },
+ rc: {
+ version: '1.7.0-rc.1',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-rc.1',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
+ }
+ },
+ beta: {
+ version: '1.7.0-beta.3',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.0-beta.3',
+
+ gitcode: 'https://gitcode.com/test/v1.7.0-beta.3'
+ }
+ }
+ }
+ },
+ '2.0.0': {
+ minCompatibleVersion: '1.7.0',
+ description: 'First v2.x',
+ channels: {
+ latest: null,
+ rc: null,
+ beta: null
+ }
+ }
+ }
+ }
+
+ it('should upgrade from 1.6.3 to 1.6.7', () => {
+ const result = (appUpdater as any)._findCompatibleChannel('1.6.3', 'latest', fullConfig)
+
+ expect(result?.config).toEqual({
+ version: '1.6.7',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.6.7',
+
+ gitcode: 'https://gitcode.com/test/v1.6.7'
+ }
+ })
+ })
+
+ it('should block upgrade from 1.6.7 to 2.0.0 (minCompatibleVersion not met)', () => {
+ const result = (appUpdater as any)._findCompatibleChannel('1.6.7', 'latest', fullConfig)
+
+ // Should return 1.6.7, not 2.0.0, because 1.6.7 < 1.7.0 (minCompatibleVersion of 2.0.0)
+ expect(result?.config).toEqual({
+ version: '1.6.7',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.6.7',
+
+ gitcode: 'https://gitcode.com/test/v1.6.7'
+ }
+ })
+ })
+
+ it('should allow upgrade from 1.7.0 to 2.0.0', () => {
+ const configWith2x = {
+ ...fullConfig,
+ versions: {
+ ...fullConfig.versions,
+ '2.0.0': {
+ minCompatibleVersion: '1.7.0',
+ description: 'First v2.x',
+ channels: {
+ latest: {
+ version: '2.0.0',
+
+ feedUrls: {
+ github: 'https://github.com/test/v2.0.0',
+
+ gitcode: 'https://gitcode.com/test/v2.0.0'
+ }
+ },
+ rc: null,
+ beta: null
+ }
+ }
+ }
+ }
+
+ const result = (appUpdater as any)._findCompatibleChannel('1.7.0', 'latest', configWith2x)
+
+ expect(result?.config).toEqual({
+ version: '2.0.0',
+
+ feedUrls: {
+ github: 'https://github.com/test/v2.0.0',
+
+ gitcode: 'https://gitcode.com/test/v2.0.0'
+ }
+ })
+ })
+ })
+
+ describe('Complete Multi-Step Upgrade Path', () => {
+ const fullUpgradeConfig = {
+ lastUpdated: '2025-01-05T00:00:00Z',
+ versions: {
+ '1.7.5': {
+ minCompatibleVersion: '1.0.0',
+ description: 'Last v1.x stable',
+ channels: {
+ latest: {
+ version: '1.7.5',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.5',
+
+ gitcode: 'https://gitcode.com/test/v1.7.5'
+ }
+ },
+ rc: null,
+ beta: null
+ }
+ },
+ '2.0.0': {
+ minCompatibleVersion: '1.7.0',
+ description: 'First v2.x - intermediate version',
+ channels: {
+ latest: {
+ version: '2.0.0',
+
+ feedUrls: {
+ github: 'https://github.com/test/v2.0.0',
+
+ gitcode: 'https://gitcode.com/test/v2.0.0'
+ }
+ },
+ rc: null,
+ beta: null
+ }
+ },
+ '2.1.6': {
+ minCompatibleVersion: '2.0.0',
+ description: 'Current v2.x stable',
+ channels: {
+ latest: {
+ version: '2.1.6',
+
+ feedUrls: {
+ github: 'https://github.com/test/latest',
+
+ gitcode: 'https://gitcode.com/test/latest'
+ }
+ },
+ rc: null,
+ beta: null
+ }
+ }
+ }
+ }
+
+ it('should upgrade from 1.6.3 to 1.7.5 (step 1)', () => {
+ const result = (appUpdater as any)._findCompatibleChannel('1.6.3', 'latest', fullUpgradeConfig)
+
+ expect(result?.config).toEqual({
+ version: '1.7.5',
+
+ feedUrls: {
+ github: 'https://github.com/test/v1.7.5',
+
+ gitcode: 'https://gitcode.com/test/v1.7.5'
+ }
+ })
+ })
+
+ it('should upgrade from 1.7.5 to 2.0.0 (step 2)', () => {
+ const result = (appUpdater as any)._findCompatibleChannel('1.7.5', 'latest', fullUpgradeConfig)
+
+ expect(result?.config).toEqual({
+ version: '2.0.0',
+
+ feedUrls: {
+ github: 'https://github.com/test/v2.0.0',
+
+ gitcode: 'https://gitcode.com/test/v2.0.0'
+ }
+ })
+ })
+
+ it('should upgrade from 2.0.0 to 2.1.6 (step 3)', () => {
+ const result = (appUpdater as any)._findCompatibleChannel('2.0.0', 'latest', fullUpgradeConfig)
+
+ expect(result?.config).toEqual({
+ version: '2.1.6',
+
+ feedUrls: {
+ github: 'https://github.com/test/latest',
+
+ gitcode: 'https://gitcode.com/test/latest'
+ }
+ })
+ })
+
+ it('should complete full upgrade path: 1.6.3 -> 1.7.5 -> 2.0.0 -> 2.1.6', () => {
+ // Step 1: 1.6.3 -> 1.7.5
+ let currentVersion = '1.6.3'
+ let result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
+ expect(result?.config.version).toBe('1.7.5')
+
+ // Step 2: 1.7.5 -> 2.0.0
+ currentVersion = result?.config.version!
+ result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
+ expect(result?.config.version).toBe('2.0.0')
+
+ // Step 3: 2.0.0 -> 2.1.6
+ currentVersion = result?.config.version!
+ result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
+ expect(result?.config.version).toBe('2.1.6')
+
+ // Final: 2.1.6 is the latest, no more upgrades
+ currentVersion = result?.config.version!
+ result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
+ expect(result?.config.version).toBe('2.1.6')
+ })
+
+ it('should block direct upgrade from 1.6.3 to 2.0.0 (skip intermediate)', () => {
+ const result = (appUpdater as any)._findCompatibleChannel('1.6.3', 'latest', fullUpgradeConfig)
+
+ // Should return 1.7.5, not 2.0.0, because 1.6.3 < 1.7.0 (minCompatibleVersion of 2.0.0)
+ expect(result?.config.version).toBe('1.7.5')
+ expect(result?.config.version).not.toBe('2.0.0')
+ })
+
+ it('should block direct upgrade from 1.7.5 to 2.1.6 (skip intermediate)', () => {
+ const result = (appUpdater as any)._findCompatibleChannel('1.7.5', 'latest', fullUpgradeConfig)
+
+ // Should return 2.0.0, not 2.1.6, because 1.7.5 < 2.0.0 (minCompatibleVersion of 2.1.6)
+ expect(result?.config.version).toBe('2.0.0')
+ expect(result?.config.version).not.toBe('2.1.6')
+ })
+ })
})
diff --git a/src/main/services/agents/BaseService.ts b/src/main/services/agents/BaseService.ts
index d96ce1b8e4..78bf72a952 100644
--- a/src/main/services/agents/BaseService.ts
+++ b/src/main/services/agents/BaseService.ts
@@ -1,17 +1,13 @@
-import { type Client, createClient } from '@libsql/client'
import { loggerService } from '@logger'
import { mcpApiService } from '@main/apiServer/services/mcp'
import type { ModelValidationError } from '@main/apiServer/utils'
import { validateModelId } from '@main/apiServer/utils'
import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
import { objectKeys } from '@types'
-import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
import fs from 'fs'
import path from 'path'
-import { MigrationService } from './database/MigrationService'
-import * as schema from './database/schema'
-import { dbPath } from './drizzle.config'
+import { DatabaseManager } from './database/DatabaseManager'
import type { AgentModelField } from './errors'
import { AgentModelValidationError } from './errors'
import { builtinSlashCommands } from './services/claudecode/commands'
@@ -20,40 +16,24 @@ import { builtinTools } from './services/claudecode/tools'
const logger = loggerService.withContext('BaseService')
/**
- * Base service class providing shared database connection and utilities
- * for all agent-related services.
+ * Base service class providing shared utilities for all agent-related services.
*
* Features:
- * - Programmatic schema management (no CLI dependencies)
- * - Automatic table creation and migration
- * - Schema version tracking and compatibility checks
- * - Transaction-based operations for safety
- * - Development vs production mode handling
- * - Connection retry logic with exponential backoff
+ * - Database access through DatabaseManager singleton
+ * - JSON field serialization/deserialization
+ * - Path validation and creation
+ * - Model validation
+ * - MCP tools and slash commands listing
*/
export abstract class BaseService {
- protected static client: Client | null = null
- protected static db: LibSQLDatabase | null = null
- protected static isInitialized = false
- protected static initializationPromise: Promise | null = null
- protected jsonFields: string[] = ['tools', 'mcps', 'configuration', 'accessible_paths', 'allowed_tools']
-
- /**
- * Initialize database with retry logic and proper error handling
- */
- protected static async initialize(): Promise {
- // Return existing initialization if in progress
- if (BaseService.initializationPromise) {
- return BaseService.initializationPromise
- }
-
- if (BaseService.isInitialized) {
- return
- }
-
- BaseService.initializationPromise = BaseService.performInitialization()
- return BaseService.initializationPromise
- }
+ protected jsonFields: string[] = [
+ 'tools',
+ 'mcps',
+ 'configuration',
+ 'accessible_paths',
+ 'allowed_tools',
+ 'slash_commands'
+ ]
public async listMcpTools(agentType: AgentType, ids?: string[]): Promise {
const tools: Tool[] = []
@@ -94,78 +74,13 @@ export abstract class BaseService {
return []
}
- private static async performInitialization(): Promise {
- const maxRetries = 3
- let lastError: Error
-
- for (let attempt = 1; attempt <= maxRetries; attempt++) {
- try {
- logger.info(`Initializing Agent database at: ${dbPath} (attempt ${attempt}/${maxRetries})`)
-
- // Ensure the database directory exists
- const dbDir = path.dirname(dbPath)
- if (!fs.existsSync(dbDir)) {
- logger.info(`Creating database directory: ${dbDir}`)
- fs.mkdirSync(dbDir, { recursive: true })
- }
-
- BaseService.client = createClient({
- url: `file:${dbPath}`
- })
-
- BaseService.db = drizzle(BaseService.client, { schema })
-
- // Run database migrations
- const migrationService = new MigrationService(BaseService.db, BaseService.client)
- await migrationService.runMigrations()
-
- BaseService.isInitialized = true
- logger.info('Agent database initialized successfully')
- return
- } catch (error) {
- lastError = error as Error
- logger.warn(`Database initialization attempt ${attempt} failed:`, lastError)
-
- // Clean up on failure
- if (BaseService.client) {
- try {
- BaseService.client.close()
- } catch (closeError) {
- logger.warn('Failed to close client during cleanup:', closeError as Error)
- }
- }
- BaseService.client = null
- BaseService.db = null
-
- // Wait before retrying (exponential backoff)
- if (attempt < maxRetries) {
- const delay = Math.pow(2, attempt) * 1000 // 2s, 4s, 8s
- logger.info(`Retrying in ${delay}ms...`)
- await new Promise((resolve) => setTimeout(resolve, delay))
- }
- }
- }
-
- // All retries failed
- BaseService.initializationPromise = null
- logger.error('Failed to initialize Agent database after all retries:', lastError!)
- throw lastError!
- }
-
- protected ensureInitialized(): void {
- if (!BaseService.isInitialized || !BaseService.db || !BaseService.client) {
- throw new Error('Database not initialized. Call initialize() first.')
- }
- }
-
- protected get database(): LibSQLDatabase {
- this.ensureInitialized()
- return BaseService.db!
- }
-
- protected get rawClient(): Client {
- this.ensureInitialized()
- return BaseService.client!
+ /**
+ * Get database instance
+ * Automatically waits for initialization to complete
+ */
+ protected async getDatabase() {
+ const dbManager = await DatabaseManager.getInstance()
+ return dbManager.getDatabase()
}
protected serializeJsonFields(data: any): any {
@@ -277,7 +192,7 @@ export abstract class BaseService {
}
/**
- * Force re-initialization (for development/testing)
+ * Validate agent model configuration
*/
protected async validateAgentModels(
agentType: AgentType,
@@ -318,22 +233,4 @@ export abstract class BaseService {
}
}
}
-
- static async reinitialize(): Promise {
- BaseService.isInitialized = false
- BaseService.initializationPromise = null
-
- if (BaseService.client) {
- try {
- BaseService.client.close()
- } catch (error) {
- logger.warn('Failed to close client during reinitialize:', error as Error)
- }
- }
-
- BaseService.client = null
- BaseService.db = null
-
- await BaseService.initialize()
- }
}
diff --git a/src/main/services/agents/database/DatabaseManager.ts b/src/main/services/agents/database/DatabaseManager.ts
new file mode 100644
index 0000000000..f4b13971c7
--- /dev/null
+++ b/src/main/services/agents/database/DatabaseManager.ts
@@ -0,0 +1,156 @@
+import { type Client, createClient } from '@libsql/client'
+import { loggerService } from '@logger'
+import type { LibSQLDatabase } from 'drizzle-orm/libsql'
+import { drizzle } from 'drizzle-orm/libsql'
+import fs from 'fs'
+import path from 'path'
+
+import { dbPath } from '../drizzle.config'
+import { MigrationService } from './MigrationService'
+import * as schema from './schema'
+
+const logger = loggerService.withContext('DatabaseManager')
+
+/**
+ * Database initialization state
+ */
+enum InitState {
+ INITIALIZING = 'initializing',
+ INITIALIZED = 'initialized',
+ FAILED = 'failed'
+}
+
+/**
+ * DatabaseManager - Singleton class for managing libsql database connections
+ *
+ * Responsibilities:
+ * - Single source of truth for database connection
+ * - Thread-safe initialization with state management
+ * - Automatic migration handling
+ * - Safe connection cleanup
+ * - Error recovery and retry logic
+ * - Windows platform compatibility fixes
+ */
+export class DatabaseManager {
+ private static instance: DatabaseManager | null = null
+
+ private client: Client | null = null
+ private db: LibSQLDatabase | null = null
+ private state: InitState = InitState.INITIALIZING
+
+ /**
+ * Get the singleton instance (database initialization starts automatically)
+ */
+ public static async getInstance(): Promise {
+ if (DatabaseManager.instance) {
+ return DatabaseManager.instance
+ }
+
+ const instance = new DatabaseManager()
+ await instance.initialize()
+ DatabaseManager.instance = instance
+
+ return instance
+ }
+
+ /**
+ * Perform the actual initialization
+ */
+ public async initialize(): Promise {
+ if (this.state === InitState.INITIALIZED) {
+ return
+ }
+
+ try {
+ logger.info(`Initializing database at: ${dbPath}`)
+
+ // Ensure database directory exists
+ const dbDir = path.dirname(dbPath)
+ if (!fs.existsSync(dbDir)) {
+ logger.info(`Creating database directory: ${dbDir}`)
+ fs.mkdirSync(dbDir, { recursive: true })
+ }
+
+ // Check if database file is corrupted (Windows specific check)
+ if (fs.existsSync(dbPath)) {
+ const stats = fs.statSync(dbPath)
+ if (stats.size === 0) {
+ logger.warn('Database file is empty, removing corrupted file')
+ fs.unlinkSync(dbPath)
+ }
+ }
+
+ // Create client with platform-specific options
+ this.client = createClient({
+ url: `file:${dbPath}`,
+ // intMode: 'number' helps avoid some Windows compatibility issues
+ intMode: 'number'
+ })
+
+ // Create drizzle instance
+ this.db = drizzle(this.client, { schema })
+
+ // Run migrations
+ const migrationService = new MigrationService(this.db, this.client)
+ await migrationService.runMigrations()
+
+ this.state = InitState.INITIALIZED
+ logger.info('Database initialized successfully')
+ } catch (error) {
+ const err = error as Error
+ logger.error('Database initialization failed:', {
+ error: err.message,
+ stack: err.stack
+ })
+
+ // Clean up failed initialization
+ this.cleanupFailedInit()
+
+ // Set failed state
+ this.state = InitState.FAILED
+ throw new Error(`Database initialization failed: ${err.message || 'Unknown error'}`)
+ }
+ }
+
+ /**
+ * Clean up after failed initialization
+ */
+ private cleanupFailedInit(): void {
+ if (this.client) {
+ try {
+ // On Windows, closing a partially initialized client can crash
+ // Wrap in try-catch and ignore errors during cleanup
+ this.client.close()
+ } catch (error) {
+ logger.warn('Failed to close client during cleanup:', error as Error)
+ }
+ }
+ this.client = null
+ this.db = null
+ }
+
+ /**
+ * Get the database instance
+ * Automatically waits for initialization to complete
+ * @throws Error if database initialization failed
+ */
+ public getDatabase(): LibSQLDatabase {
+ return this.db!
+ }
+
+ /**
+ * Get the raw client (for advanced operations)
+ * Automatically waits for initialization to complete
+ * @throws Error if database initialization failed
+ */
+ public async getClient(): Promise {
+ return this.client!
+ }
+
+ /**
+ * Check if database is initialized
+ */
+ public isInitialized(): boolean {
+ return this.state === InitState.INITIALIZED
+ }
+}
diff --git a/src/main/services/agents/database/index.ts b/src/main/services/agents/database/index.ts
index 61b3a9ffcc..43302a6b25 100644
--- a/src/main/services/agents/database/index.ts
+++ b/src/main/services/agents/database/index.ts
@@ -7,8 +7,14 @@
* Schema evolution is handled by Drizzle Kit migrations.
*/
+// Database Manager (Singleton)
+export * from './DatabaseManager'
+
// Drizzle ORM schemas
export * from './schema'
// Repository helpers
export * from './sessionMessageRepository'
+
+// Migration Service
+export * from './MigrationService'
diff --git a/src/main/services/agents/database/schema/sessions.schema.ts b/src/main/services/agents/database/schema/sessions.schema.ts
index 21ac2fe2c6..4b16a9ec41 100644
--- a/src/main/services/agents/database/schema/sessions.schema.ts
+++ b/src/main/services/agents/database/schema/sessions.schema.ts
@@ -22,6 +22,7 @@ export const sessionsTable = sqliteTable('sessions', {
mcps: text('mcps'), // JSON array of MCP tool IDs
allowed_tools: text('allowed_tools'), // JSON array of allowed tool IDs (whitelist)
+ slash_commands: text('slash_commands'), // JSON array of slash command objects from SDK init
configuration: text('configuration'), // JSON, extensible settings
diff --git a/src/main/services/agents/database/sessionMessageRepository.ts b/src/main/services/agents/database/sessionMessageRepository.ts
index 4567c61ec0..a9b1d2e572 100644
--- a/src/main/services/agents/database/sessionMessageRepository.ts
+++ b/src/main/services/agents/database/sessionMessageRepository.ts
@@ -15,26 +15,16 @@ import { sessionMessagesTable } from './schema'
const logger = loggerService.withContext('AgentMessageRepository')
-type TxClient = any
-
export type PersistUserMessageParams = AgentMessageUserPersistPayload & {
sessionId: string
agentSessionId?: string
- tx?: TxClient
}
export type PersistAssistantMessageParams = AgentMessageAssistantPersistPayload & {
sessionId: string
agentSessionId: string
- tx?: TxClient
}
-type PersistExchangeParams = AgentMessagePersistExchangePayload & {
- tx?: TxClient
-}
-
-type PersistExchangeResult = AgentMessagePersistExchangeResult
-
class AgentMessageRepository extends BaseService {
private static instance: AgentMessageRepository | null = null
@@ -87,17 +77,13 @@ class AgentMessageRepository extends BaseService {
return deserialized
}
- private getWriter(tx?: TxClient): TxClient {
- return tx ?? this.database
- }
-
private async findExistingMessageRow(
- writer: TxClient,
sessionId: string,
role: string,
messageId: string
): Promise {
- const candidateRows: SessionMessageRow[] = await writer
+ const database = await this.getDatabase()
+ const candidateRows: SessionMessageRow[] = await database
.select()
.from(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role)))
@@ -122,10 +108,7 @@ class AgentMessageRepository extends BaseService {
private async upsertMessage(
params: PersistUserMessageParams | PersistAssistantMessageParams
): Promise {
- await AgentMessageRepository.initialize()
- this.ensureInitialized()
-
- const { sessionId, agentSessionId = '', payload, metadata, createdAt, tx } = params
+ const { sessionId, agentSessionId = '', payload, metadata, createdAt } = params
if (!payload?.message?.role) {
throw new Error('Message payload missing role')
@@ -135,18 +118,18 @@ class AgentMessageRepository extends BaseService {
throw new Error('Message payload missing id')
}
- const writer = this.getWriter(tx)
+ const database = await this.getDatabase()
const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString()
const serializedPayload = this.serializeMessage(payload)
const serializedMetadata = this.serializeMetadata(metadata)
- const existingRow = await this.findExistingMessageRow(writer, sessionId, payload.message.role, payload.message.id)
+ const existingRow = await this.findExistingMessageRow(sessionId, payload.message.role, payload.message.id)
if (existingRow) {
const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined
const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || ''
- await writer
+ await database
.update(sessionMessagesTable)
.set({
content: serializedPayload,
@@ -175,7 +158,7 @@ class AgentMessageRepository extends BaseService {
updated_at: now
}
- const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning()
+ const [saved] = await database.insert(sessionMessagesTable).values(insertData).returning()
return this.deserialize(saved)
}
@@ -188,49 +171,38 @@ class AgentMessageRepository extends BaseService {
return this.upsertMessage(params)
}
- async persistExchange(params: PersistExchangeParams): Promise {
- await AgentMessageRepository.initialize()
- this.ensureInitialized()
-
+ async persistExchange(params: AgentMessagePersistExchangePayload): Promise {
const { sessionId, agentSessionId, user, assistant } = params
- const result = await this.database.transaction(async (tx) => {
- const exchangeResult: PersistExchangeResult = {}
+ const exchangeResult: AgentMessagePersistExchangeResult = {}
- if (user?.payload) {
- exchangeResult.userMessage = await this.persistUserMessage({
- sessionId,
- agentSessionId,
- payload: user.payload,
- metadata: user.metadata,
- createdAt: user.createdAt,
- tx
- })
- }
+ if (user?.payload) {
+ exchangeResult.userMessage = await this.persistUserMessage({
+ sessionId,
+ agentSessionId,
+ payload: user.payload,
+ metadata: user.metadata,
+ createdAt: user.createdAt
+ })
+ }
- if (assistant?.payload) {
- exchangeResult.assistantMessage = await this.persistAssistantMessage({
- sessionId,
- agentSessionId,
- payload: assistant.payload,
- metadata: assistant.metadata,
- createdAt: assistant.createdAt,
- tx
- })
- }
+ if (assistant?.payload) {
+ exchangeResult.assistantMessage = await this.persistAssistantMessage({
+ sessionId,
+ agentSessionId,
+ payload: assistant.payload,
+ metadata: assistant.metadata,
+ createdAt: assistant.createdAt
+ })
+ }
- return exchangeResult
- })
-
- return result
+ return exchangeResult
}
async getSessionHistory(sessionId: string): Promise {
- await AgentMessageRepository.initialize()
- this.ensureInitialized()
-
try {
- const rows = await this.database
+ const database = await this.getDatabase()
+ const rows = await database
.select()
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId))
diff --git a/src/main/services/agents/services/AgentService.ts b/src/main/services/agents/services/AgentService.ts
index 07ed89a0f3..2faa87bb45 100644
--- a/src/main/services/agents/services/AgentService.ts
+++ b/src/main/services/agents/services/AgentService.ts
@@ -32,14 +32,8 @@ export class AgentService extends BaseService {
return AgentService.instance
}
- async initialize(): Promise {
- await BaseService.initialize()
- }
-
// Agent Methods
async createAgent(req: CreateAgentRequest): Promise {
- this.ensureInitialized()
-
const id = `agent_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`
const now = new Date().toISOString()
@@ -75,8 +69,9 @@ export class AgentService extends BaseService {
updated_at: now
}
- await this.database.insert(agentsTable).values(insertData)
- const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
+ const database = await this.getDatabase()
+ await database.insert(agentsTable).values(insertData)
+ const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
if (!result[0]) {
throw new Error('Failed to create agent')
}
@@ -86,9 +81,8 @@ export class AgentService extends BaseService {
}
async getAgent(id: string): Promise {
- this.ensureInitialized()
-
- const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
+ const database = await this.getDatabase()
+ const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
if (!result[0]) {
return null
@@ -118,9 +112,9 @@ export class AgentService extends BaseService {
}
async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> {
- this.ensureInitialized() // Build query with pagination
-
- const totalResult = await this.database.select({ count: count() }).from(agentsTable)
+ // Build query with pagination
+ const database = await this.getDatabase()
+ const totalResult = await database.select({ count: count() }).from(agentsTable)
const sortBy = options.sortBy || 'created_at'
const orderBy = options.orderBy || 'desc'
@@ -128,7 +122,7 @@ export class AgentService extends BaseService {
const sortField = agentsTable[sortBy]
const orderFn = orderBy === 'asc' ? asc : desc
- const baseQuery = this.database.select().from(agentsTable).orderBy(orderFn(sortField))
+ const baseQuery = database.select().from(agentsTable).orderBy(orderFn(sortField))
const result =
options.limit !== undefined
@@ -151,8 +145,6 @@ export class AgentService extends BaseService {
updates: UpdateAgentRequest,
options: { replace?: boolean } = {}
): Promise {
- this.ensureInitialized()
-
// Check if agent exists
const existing = await this.getAgent(id)
if (!existing) {
@@ -195,22 +187,21 @@ export class AgentService extends BaseService {
}
}
- await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
+ const database = await this.getDatabase()
+ await database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
return await this.getAgent(id)
}
async deleteAgent(id: string): Promise {
- this.ensureInitialized()
-
- const result = await this.database.delete(agentsTable).where(eq(agentsTable.id, id))
+ const database = await this.getDatabase()
+ const result = await database.delete(agentsTable).where(eq(agentsTable.id, id))
return result.rowsAffected > 0
}
async agentExists(id: string): Promise {
- this.ensureInitialized()
-
- const result = await this.database
+ const database = await this.getDatabase()
+ const result = await database
.select({ id: agentsTable.id })
.from(agentsTable)
.where(eq(agentsTable.id, id))
diff --git a/src/main/services/agents/services/SessionMessageService.ts b/src/main/services/agents/services/SessionMessageService.ts
index 46435fa371..48ef8621ef 100644
--- a/src/main/services/agents/services/SessionMessageService.ts
+++ b/src/main/services/agents/services/SessionMessageService.ts
@@ -104,14 +104,9 @@ export class SessionMessageService extends BaseService {
return SessionMessageService.instance
}
- async initialize(): Promise {
- await BaseService.initialize()
- }
-
async sessionMessageExists(id: number): Promise {
- this.ensureInitialized()
-
- const result = await this.database
+ const database = await this.getDatabase()
+ const result = await database
.select({ id: sessionMessagesTable.id })
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.id, id))
@@ -124,10 +119,9 @@ export class SessionMessageService extends BaseService {
sessionId: string,
options: ListOptions = {}
): Promise<{ messages: AgentSessionMessageEntity[] }> {
- this.ensureInitialized()
-
// Get messages with pagination
- const baseQuery = this.database
+ const database = await this.getDatabase()
+ const baseQuery = database
.select()
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId))
@@ -146,9 +140,8 @@ export class SessionMessageService extends BaseService {
}
async deleteSessionMessage(sessionId: string, messageId: number): Promise {
- this.ensureInitialized()
-
- const result = await this.database
+ const database = await this.getDatabase()
+ const result = await database
.delete(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.id, messageId), eq(sessionMessagesTable.session_id, sessionId)))
@@ -160,8 +153,6 @@ export class SessionMessageService extends BaseService {
messageData: CreateSessionMessageRequest,
abortController: AbortController
): Promise {
- this.ensureInitialized()
-
return await this.startSessionMessageStream(session, messageData, abortController)
}
@@ -270,10 +261,9 @@ export class SessionMessageService extends BaseService {
}
private async getLastAgentSessionId(sessionId: string): Promise {
- this.ensureInitialized()
-
try {
- const result = await this.database
+ const database = await this.getDatabase()
+ const result = await database
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
.from(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.session_id, sessionId), not(eq(sessionMessagesTable.agent_session_id, ''))))
diff --git a/src/main/services/agents/services/SessionService.ts b/src/main/services/agents/services/SessionService.ts
index 62dad3ed51..d933ef8dd9 100644
--- a/src/main/services/agents/services/SessionService.ts
+++ b/src/main/services/agents/services/SessionService.ts
@@ -1,4 +1,5 @@
-import type { UpdateSessionResponse } from '@types'
+import { loggerService } from '@logger'
+import type { SlashCommand, UpdateSessionResponse } from '@types'
import {
AgentBaseSchema,
type AgentEntity,
@@ -13,6 +14,10 @@ import { and, count, desc, eq, type SQL } from 'drizzle-orm'
import { BaseService } from '../BaseService'
import { agentsTable, type InsertSessionRow, type SessionRow, sessionsTable } from '../database/schema'
import type { AgentModelField } from '../errors'
+import { pluginService } from '../plugins/PluginService'
+import { builtinSlashCommands } from './claudecode/commands'
+
+const logger = loggerService.withContext('SessionService')
export class SessionService extends BaseService {
private static instance: SessionService | null = null
@@ -25,21 +30,62 @@ export class SessionService extends BaseService {
return SessionService.instance
}
- async initialize(): Promise {
- await BaseService.initialize()
+ /**
+ * Override BaseService.listSlashCommands to merge builtin and plugin commands
+ */
+ async listSlashCommands(agentType: string, agentId?: string): Promise {
+ const commands: SlashCommand[] = []
+
+ // Add builtin slash commands
+ if (agentType === 'claude-code') {
+ commands.push(...builtinSlashCommands)
+ }
+
+ // Add local command plugins from .claude/commands/
+ if (agentId) {
+ try {
+ const installedPlugins = await pluginService.listInstalled(agentId)
+
+ // Filter for command type plugins
+ const commandPlugins = installedPlugins.filter((p) => p.type === 'command')
+
+ // Convert plugin metadata to SlashCommand format
+ for (const plugin of commandPlugins) {
+ const commandName = plugin.metadata.filename.replace(/\.md$/i, '')
+ commands.push({
+ command: `/${commandName}`,
+ description: plugin.metadata.description
+ })
+ }
+
+ logger.info('Listed slash commands', {
+ agentType,
+ agentId,
+ builtinCount: builtinSlashCommands.length,
+ localCount: commandPlugins.length,
+ totalCount: commands.length
+ })
+ } catch (error) {
+ logger.warn('Failed to list local command plugins', {
+ agentId,
+ error: error instanceof Error ? error.message : String(error)
+ })
+ }
+ }
+
+ return commands
}
async createSession(
agentId: string,
req: Partial = {}
): Promise {
- this.ensureInitialized()
-
// Validate agent exists - we'll need to import AgentService for this check
// For now, we'll skip this validation to avoid circular dependencies
// The database foreign key constraint will handle this
- const agents = await this.database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
+ const database = await this.getDatabase()
+ const agents = await database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
if (!agents[0]) {
throw new Error('Agent not found')
}
@@ -78,14 +124,16 @@ export class SessionService extends BaseService {
plan_model: serializedData.plan_model || null,
small_model: serializedData.small_model || null,
mcps: serializedData.mcps || null,
+ allowed_tools: serializedData.allowed_tools || null,
configuration: serializedData.configuration || null,
created_at: now,
updated_at: now
}
- await this.database.insert(sessionsTable).values(insertData)
+ const db = await this.getDatabase()
+ await db.insert(sessionsTable).values(insertData)
- const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
+ const result = await db.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
if (!result[0]) {
throw new Error('Failed to create session')
@@ -96,9 +144,8 @@ export class SessionService extends BaseService {
}
async getSession(agentId: string, id: string): Promise {
- this.ensureInitialized()
-
- const result = await this.database
+ const database = await this.getDatabase()
+ const result = await database
.select()
.from(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
@@ -110,7 +157,13 @@ export class SessionService extends BaseService {
const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse
session.tools = await this.listMcpTools(session.agent_type, session.mcps)
- session.slash_commands = await this.listSlashCommands(session.agent_type)
+
+ // If slash_commands is not in database yet (e.g., first invoke before init message),
+ // fall back to builtin + local commands. Otherwise, use the merged commands from database.
+ if (!session.slash_commands || session.slash_commands.length === 0) {
+ session.slash_commands = await this.listSlashCommands(session.agent_type, agentId)
+ }
+
return session
}
@@ -118,8 +171,6 @@ export class SessionService extends BaseService {
agentId?: string,
options: ListOptions = {}
): Promise<{ sessions: AgentSessionEntity[]; total: number }> {
- this.ensureInitialized()
-
// Build where conditions
const whereConditions: SQL[] = []
if (agentId) {
@@ -134,16 +185,13 @@ export class SessionService extends BaseService {
: undefined
// Get total count
- const totalResult = await this.database.select({ count: count() }).from(sessionsTable).where(whereClause)
+ const database = await this.getDatabase()
+ const totalResult = await database.select({ count: count() }).from(sessionsTable).where(whereClause)
const total = totalResult[0].count
// Build list query with pagination - sort by updated_at descending (latest first)
- const baseQuery = this.database
- .select()
- .from(sessionsTable)
- .where(whereClause)
- .orderBy(desc(sessionsTable.updated_at))
+ const baseQuery = database.select().from(sessionsTable).where(whereClause).orderBy(desc(sessionsTable.updated_at))
const result =
options.limit !== undefined
@@ -162,8 +210,6 @@ export class SessionService extends BaseService {
id: string,
updates: UpdateSessionRequest
): Promise {
- this.ensureInitialized()
-
// Check if session exists
const existing = await this.getSession(agentId, id)
if (!existing) {
@@ -204,15 +250,15 @@ export class SessionService extends BaseService {
}
}
- await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
+ const database = await this.getDatabase()
+ await database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
return await this.getSession(agentId, id)
}
async deleteSession(agentId: string, id: string): Promise {
- this.ensureInitialized()
-
- const result = await this.database
+ const database = await this.getDatabase()
+ const result = await database
.delete(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
@@ -220,9 +266,8 @@ export class SessionService extends BaseService {
}
async sessionExists(agentId: string, id: string): Promise {
- this.ensureInitialized()
-
- const result = await this.database
+ const database = await this.getDatabase()
+ const result = await database
.select({ id: sessionsTable.id })
.from(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
diff --git a/src/main/services/agents/services/claudecode/__tests__/transform.test.ts b/src/main/services/agents/services/claudecode/__tests__/transform.test.ts
index 1c5c2ade6b..2565f5e605 100644
--- a/src/main/services/agents/services/claudecode/__tests__/transform.test.ts
+++ b/src/main/services/agents/services/claudecode/__tests__/transform.test.ts
@@ -1,7 +1,7 @@
import type { SDKMessage } from '@anthropic-ai/claude-agent-sdk'
import { describe, expect, it } from 'vitest'
-import { ClaudeStreamState, transformSDKMessageToStreamParts } from '../transform'
+import { ClaudeStreamState, stripLocalCommandTags, transformSDKMessageToStreamParts } from '../transform'
const baseStreamMetadata = {
parent_tool_use_id: null,
@@ -10,9 +10,27 @@ const baseStreamMetadata = {
const uuid = (n: number) => `00000000-0000-0000-0000-${n.toString().padStart(12, '0')}`
+describe('stripLocalCommandTags', () => {
+ it('removes stdout wrapper while preserving inner text', () => {
+ const input = 'before echo "hi" after'
+ expect(stripLocalCommandTags(input)).toBe('before echo "hi" after')
+ })
+
+ it('strips multiple stdout/stderr blocks and leaves other content intact', () => {
+ const input =
+ 'line1\nkeep\nError'
+ expect(stripLocalCommandTags(input)).toBe('line1\nkeep\nError')
+ })
+
+ it('if no tags present, returns original string', () => {
+ const input = 'just some normal text'
+ expect(stripLocalCommandTags(input)).toBe(input)
+ })
+})
+
describe('Claude → AiSDK transform', () => {
it('handles tool call streaming lifecycle', () => {
- const state = new ClaudeStreamState()
+ const state = new ClaudeStreamState({ agentSessionId: baseStreamMetadata.session_id })
const parts: ReturnType[number][] = []
const messages: SDKMessage[] = [
@@ -169,14 +187,119 @@ describe('Claude → AiSDK transform', () => {
(typeof parts)[number],
{ type: 'tool-result' }
>
- expect(toolResult.toolCallId).toBe('tool-1')
+ expect(toolResult.toolCallId).toBe('session-123:tool-1')
expect(toolResult.toolName).toBe('Bash')
expect(toolResult.input).toEqual({ command: 'ls' })
expect(toolResult.output).toBe('ok')
})
+ it('handles tool calls without streaming events (no content_block_start/stop)', () => {
+ const state = new ClaudeStreamState({ agentSessionId: '12344' })
+ const parts: ReturnType[number][] = []
+
+ const messages: SDKMessage[] = [
+ {
+ ...baseStreamMetadata,
+ type: 'assistant',
+ uuid: uuid(20),
+ message: {
+ id: 'msg-tool-no-stream',
+ type: 'message',
+ role: 'assistant',
+ model: 'claude-test',
+ content: [
+ {
+ type: 'tool_use',
+ id: 'tool-read',
+ name: 'Read',
+ input: { file_path: '/test.txt' }
+ },
+ {
+ type: 'tool_use',
+ id: 'tool-bash',
+ name: 'Bash',
+ input: { command: 'ls -la' }
+ }
+ ],
+ stop_reason: 'tool_use',
+ stop_sequence: null,
+ usage: {
+ input_tokens: 10,
+ output_tokens: 20
+ }
+ }
+ } as unknown as SDKMessage,
+ {
+ ...baseStreamMetadata,
+ type: 'user',
+ uuid: uuid(21),
+ message: {
+ role: 'user',
+ content: [
+ {
+ type: 'tool_result',
+ tool_use_id: 'tool-read',
+ content: 'file contents',
+ is_error: false
+ }
+ ]
+ }
+ } as SDKMessage,
+ {
+ ...baseStreamMetadata,
+ type: 'user',
+ uuid: uuid(22),
+ message: {
+ role: 'user',
+ content: [
+ {
+ type: 'tool_result',
+ tool_use_id: 'tool-bash',
+ content: 'total 42\n...',
+ is_error: false
+ }
+ ]
+ }
+ } as SDKMessage
+ ]
+
+ for (const message of messages) {
+ const transformed = transformSDKMessageToStreamParts(message, state)
+ parts.push(...transformed)
+ }
+
+ const types = parts.map((part) => part.type)
+ expect(types).toEqual(['tool-call', 'tool-call', 'tool-result', 'tool-result'])
+
+ const toolCalls = parts.filter((part) => part.type === 'tool-call') as Extract<
+ (typeof parts)[number],
+ { type: 'tool-call' }
+ >[]
+ expect(toolCalls).toHaveLength(2)
+ expect(toolCalls[0].toolName).toBe('Read')
+ expect(toolCalls[0].toolCallId).toBe('12344:tool-read')
+ expect(toolCalls[1].toolName).toBe('Bash')
+ expect(toolCalls[1].toolCallId).toBe('12344:tool-bash')
+
+ const toolResults = parts.filter((part) => part.type === 'tool-result') as Extract<
+ (typeof parts)[number],
+ { type: 'tool-result' }
+ >[]
+ expect(toolResults).toHaveLength(2)
+ // This is the key assertion - toolName should NOT be 'unknown'
+ expect(toolResults[0].toolName).toBe('Read')
+ expect(toolResults[0].toolCallId).toBe('12344:tool-read')
+ expect(toolResults[0].input).toEqual({ file_path: '/test.txt' })
+ expect(toolResults[0].output).toBe('file contents')
+
+ expect(toolResults[1].toolName).toBe('Bash')
+ expect(toolResults[1].toolCallId).toBe('12344:tool-bash')
+ expect(toolResults[1].input).toEqual({ command: 'ls -la' })
+ expect(toolResults[1].output).toBe('total 42\n...')
+ })
+
it('handles streaming text completion', () => {
- const state = new ClaudeStreamState()
+ const state = new ClaudeStreamState({ agentSessionId: baseStreamMetadata.session_id })
const parts: ReturnType[number][] = []
const messages: SDKMessage[] = [
@@ -287,4 +410,87 @@ describe('Claude → AiSDK transform', () => {
expect(finishStep.finishReason).toBe('stop')
expect(finishStep.usage).toEqual({ inputTokens: 2, outputTokens: 4, totalTokens: 6 })
})
+
+ it('emits fallback text when Claude sends a snapshot instead of deltas', () => {
+ const state = new ClaudeStreamState({ agentSessionId: '12344' })
+ const parts: ReturnType[number][] = []
+
+ const messages: SDKMessage[] = [
+ {
+ ...baseStreamMetadata,
+ type: 'stream_event',
+ uuid: uuid(30),
+ event: {
+ type: 'message_start',
+ message: {
+ id: 'msg-fallback',
+ type: 'message',
+ role: 'assistant',
+ model: 'claude-test',
+ content: [],
+ stop_reason: null,
+ stop_sequence: null,
+ usage: {}
+ }
+ }
+ } as unknown as SDKMessage,
+ {
+ ...baseStreamMetadata,
+ type: 'stream_event',
+ uuid: uuid(31),
+ event: {
+ type: 'content_block_start',
+ index: 0,
+ content_block: {
+ type: 'text',
+ text: ''
+ }
+ }
+ } as unknown as SDKMessage,
+ {
+ ...baseStreamMetadata,
+ type: 'assistant',
+ uuid: uuid(32),
+ message: {
+ id: 'msg-fallback-content',
+ type: 'message',
+ role: 'assistant',
+ model: 'claude-test',
+ content: [
+ {
+ type: 'text',
+ text: 'Final answer without streaming deltas.'
+ }
+ ],
+ stop_reason: 'end_turn',
+ stop_sequence: null,
+ usage: {
+ input_tokens: 3,
+ output_tokens: 7
+ }
+ }
+ } as unknown as SDKMessage
+ ]
+
+ for (const message of messages) {
+ const transformed = transformSDKMessageToStreamParts(message, state)
+ parts.push(...transformed)
+ }
+
+ const types = parts.map((part) => part.type)
+ expect(types).toEqual(['start-step', 'text-start', 'text-delta', 'text-end', 'finish-step'])
+
+ const delta = parts.find((part) => part.type === 'text-delta') as Extract<
+ (typeof parts)[number],
+ { type: 'text-delta' }
+ >
+ expect(delta.text).toBe('Final answer without streaming deltas.')
+
+ const finish = parts.find((part) => part.type === 'finish-step') as Extract<
+ (typeof parts)[number],
+ { type: 'finish-step' }
+ >
+ expect(finish.usage).toEqual({ inputTokens: 3, outputTokens: 7, totalTokens: 10 })
+ expect(finish.finishReason).toBe('stop')
+ })
})
diff --git a/src/main/services/agents/services/claudecode/claude-stream-state.ts b/src/main/services/agents/services/claudecode/claude-stream-state.ts
index 078f048ce8..30b5790c82 100644
--- a/src/main/services/agents/services/claudecode/claude-stream-state.ts
+++ b/src/main/services/agents/services/claudecode/claude-stream-state.ts
@@ -10,8 +10,21 @@
* Every Claude turn gets its own instance. `resetStep` should be invoked once the finish event has
* been emitted to avoid leaking state into the next turn.
*/
+import { loggerService } from '@logger'
import type { FinishReason, LanguageModelUsage, ProviderMetadata } from 'ai'
+/**
+ * Builds a namespaced tool call ID by combining session ID with raw tool call ID.
+ * This ensures tool calls from different sessions don't conflict even if they have
+ * the same raw ID from the SDK.
+ *
+ * @param sessionId - The agent session ID
+ * @param rawToolCallId - The raw tool call ID from SDK (e.g., "WebFetch_0")
+ */
+export function buildNamespacedToolCallId(sessionId: string, rawToolCallId: string): string {
+ return `${sessionId}:${rawToolCallId}`
+}
+
/**
* Shared fields for every block that Claude can stream (text, reasoning, tool).
*/
@@ -34,6 +47,7 @@ type ReasoningBlockState = BaseBlockState & {
type ToolBlockState = BaseBlockState & {
kind: 'tool'
toolCallId: string
+ rawToolCallId: string
toolName: string
inputBuffer: string
providerMetadata?: ProviderMetadata
@@ -48,12 +62,17 @@ type PendingUsageState = {
}
type PendingToolCall = {
+ rawToolCallId: string
toolCallId: string
toolName: string
input: unknown
providerMetadata?: ProviderMetadata
}
+type ClaudeStreamStateOptions = {
+ agentSessionId: string
+}
+
/**
* Tracks the lifecycle of Claude streaming blocks (text, thinking, tool calls)
* across individual websocket events. The transformer relies on this class to
@@ -61,12 +80,20 @@ type PendingToolCall = {
* usage/finish metadata once Anthropic closes a message.
*/
export class ClaudeStreamState {
+ private logger
+ private readonly agentSessionId: string
private blocksByIndex = new Map()
- private toolIndexById = new Map()
+ private toolIndexByNamespacedId = new Map()
private pendingUsage: PendingUsageState = {}
private pendingToolCalls = new Map()
private stepActive = false
+ constructor(options: ClaudeStreamStateOptions) {
+ this.logger = loggerService.withContext('ClaudeStreamState')
+ this.agentSessionId = options.agentSessionId
+ this.logger.silly('ClaudeStreamState', options)
+ }
+
/** Marks the beginning of a new AiSDK step. */
beginStep(): void {
this.stepActive = true
@@ -104,19 +131,21 @@ export class ClaudeStreamState {
/** Caches tool metadata so subsequent input deltas and results can find it. */
openToolBlock(
index: number,
- params: { toolCallId: string; toolName: string; providerMetadata?: ProviderMetadata }
+ params: { rawToolCallId: string; toolName: string; providerMetadata?: ProviderMetadata }
): ToolBlockState {
+ const toolCallId = buildNamespacedToolCallId(this.agentSessionId, params.rawToolCallId)
const block: ToolBlockState = {
kind: 'tool',
- id: params.toolCallId,
+ id: toolCallId,
index,
- toolCallId: params.toolCallId,
+ toolCallId,
+ rawToolCallId: params.rawToolCallId,
toolName: params.toolName,
inputBuffer: '',
providerMetadata: params.providerMetadata
}
this.blocksByIndex.set(index, block)
- this.toolIndexById.set(params.toolCallId, index)
+ this.toolIndexByNamespacedId.set(toolCallId, index)
return block
}
@@ -124,14 +153,32 @@ export class ClaudeStreamState {
return this.blocksByIndex.get(index)
}
+ getFirstOpenTextBlock(): TextBlockState | undefined {
+ const candidates: TextBlockState[] = []
+ for (const block of this.blocksByIndex.values()) {
+ if (block.kind === 'text') {
+ candidates.push(block)
+ }
+ }
+ if (candidates.length === 0) {
+ return undefined
+ }
+ candidates.sort((a, b) => a.index - b.index)
+ return candidates[0]
+ }
+
getToolBlockById(toolCallId: string): ToolBlockState | undefined {
- const index = this.toolIndexById.get(toolCallId)
+ const index = this.toolIndexByNamespacedId.get(toolCallId)
if (index === undefined) return undefined
const block = this.blocksByIndex.get(index)
if (!block || block.kind !== 'tool') return undefined
return block
}
+ getToolBlockByRawId(rawToolCallId: string): ToolBlockState | undefined {
+ return this.getToolBlockById(buildNamespacedToolCallId(this.agentSessionId, rawToolCallId))
+ }
+
/** Appends streamed text to a text block, returning the updated state when present. */
appendTextDelta(index: number, text: string): TextBlockState | undefined {
const block = this.blocksByIndex.get(index)
@@ -158,10 +205,12 @@ export class ClaudeStreamState {
/** Records a tool call to be consumed once its result arrives from the user. */
registerToolCall(
- toolCallId: string,
+ rawToolCallId: string,
payload: { toolName: string; input: unknown; providerMetadata?: ProviderMetadata }
): void {
- this.pendingToolCalls.set(toolCallId, {
+ const toolCallId = buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
+ this.pendingToolCalls.set(rawToolCallId, {
+ rawToolCallId,
toolCallId,
toolName: payload.toolName,
input: payload.input,
@@ -170,10 +219,10 @@ export class ClaudeStreamState {
}
/** Retrieves and clears the buffered tool call metadata for the given id. */
- consumePendingToolCall(toolCallId: string): PendingToolCall | undefined {
- const entry = this.pendingToolCalls.get(toolCallId)
+ consumePendingToolCall(rawToolCallId: string): PendingToolCall | undefined {
+ const entry = this.pendingToolCalls.get(rawToolCallId)
if (entry) {
- this.pendingToolCalls.delete(toolCallId)
+ this.pendingToolCalls.delete(rawToolCallId)
}
return entry
}
@@ -182,13 +231,13 @@ export class ClaudeStreamState {
* Persists the final input payload for a tool block once the provider signals
* completion so that downstream tool results can reference the original call.
*/
- completeToolBlock(toolCallId: string, input: unknown, providerMetadata?: ProviderMetadata): void {
+ completeToolBlock(toolCallId: string, toolName: string, input: unknown, providerMetadata?: ProviderMetadata): void {
+ const block = this.getToolBlockByRawId(toolCallId)
this.registerToolCall(toolCallId, {
- toolName: this.getToolBlockById(toolCallId)?.toolName ?? 'unknown',
+ toolName,
input,
providerMetadata
})
- const block = this.getToolBlockById(toolCallId)
if (block) {
block.resolvedInput = input
}
@@ -200,7 +249,7 @@ export class ClaudeStreamState {
if (!block) return undefined
this.blocksByIndex.delete(index)
if (block.kind === 'tool') {
- this.toolIndexById.delete(block.toolCallId)
+ this.toolIndexByNamespacedId.delete(block.toolCallId)
}
return block
}
@@ -227,7 +276,7 @@ export class ClaudeStreamState {
/** Drops cached block metadata for the currently active message. */
resetBlocks(): void {
this.blocksByIndex.clear()
- this.toolIndexById.clear()
+ this.toolIndexByNamespacedId.clear()
}
/** Resets the entire step lifecycle after emitting a terminal frame. */
@@ -236,6 +285,10 @@ export class ClaudeStreamState {
this.resetPendingUsage()
this.stepActive = false
}
+
+ getNamespacedToolCallId(rawToolCallId: string): string {
+ return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
+ }
}
export type { PendingToolCall }
diff --git a/src/main/services/agents/services/claudecode/commands.ts b/src/main/services/agents/services/claudecode/commands.ts
index f30d620572..0ce4f4ccef 100644
--- a/src/main/services/agents/services/claudecode/commands.ts
+++ b/src/main/services/agents/services/claudecode/commands.ts
@@ -1,25 +1,12 @@
import type { SlashCommand } from '@types'
export const builtinSlashCommands: SlashCommand[] = [
- { command: '/add-dir', description: 'Add additional working directories' },
- { command: '/agents', description: 'Manage custom AI subagents for specialized tasks' },
- { command: '/bug', description: 'Report bugs (sends conversation to Anthropic)' },
{ command: '/clear', description: 'Clear conversation history' },
{ command: '/compact', description: 'Compact conversation with optional focus instructions' },
- { command: '/config', description: 'View/modify configuration' },
- { command: '/cost', description: 'Show token usage statistics' },
- { command: '/doctor', description: 'Checks the health of your Claude Code installation' },
- { command: '/help', description: 'Get usage help' },
- { command: '/init', description: 'Initialize project with CLAUDE.md guide' },
- { command: '/login', description: 'Switch Anthropic accounts' },
- { command: '/logout', description: 'Sign out from your Anthropic account' },
- { command: '/mcp', description: 'Manage MCP server connections and OAuth authentication' },
- { command: '/memory', description: 'Edit CLAUDE.md memory files' },
- { command: '/model', description: 'Select or change the AI model' },
- { command: '/permissions', description: 'View or update permissions' },
- { command: '/pr_comments', description: 'View pull request comments' },
- { command: '/review', description: 'Request code review' },
- { command: '/status', description: 'View account and system statuses' },
- { command: '/terminal-setup', description: 'Install Shift+Enter key binding for newlines (iTerm2 and VSCode only)' },
- { command: '/vim', description: 'Enter vim mode for alternating insert and command modes' }
+ { command: '/context', description: 'Visualize current context usage as a colored grid' },
+ {
+ command: '/cost',
+ description: 'Show token usage statistics (see cost tracking guide for subscription-specific details)'
+ },
+ { command: '/todos', description: 'List current todo items' }
]
diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts
index 4e20520017..e5cefadd68 100644
--- a/src/main/services/agents/services/claudecode/index.ts
+++ b/src/main/services/agents/services/claudecode/index.ts
@@ -1,8 +1,16 @@
// src/main/services/agents/services/claudecode/index.ts
import { EventEmitter } from 'node:events'
import { createRequire } from 'node:module'
+import path from 'node:path'
-import type { CanUseTool, McpHttpServerConfig, Options, SDKMessage } from '@anthropic-ai/claude-agent-sdk'
+import type {
+ CanUseTool,
+ HookCallback,
+ McpHttpServerConfig,
+ Options,
+ PreToolUseHookInput,
+ SDKMessage
+} from '@anthropic-ai/claude-agent-sdk'
import { query } from '@anthropic-ai/claude-agent-sdk'
import { loggerService } from '@logger'
import { config as apiConfigService } from '@main/apiServer/config'
@@ -12,6 +20,8 @@ import { app } from 'electron'
import type { GetAgentSessionResponse } from '../..'
import type { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
+import { sessionService } from '../SessionService'
+import { buildNamespacedToolCallId } from './claude-stream-state'
import { promptForToolApproval } from './tool-permissions'
import { ClaudeStreamState, transformSDKMessageToStreamParts } from './transform'
@@ -19,6 +29,7 @@ const require_ = createRequire(import.meta.url)
const logger = loggerService.withContext('ClaudeCodeService')
const DEFAULT_AUTO_ALLOW_TOOLS = new Set(['Read', 'Glob', 'Grep'])
const shouldAutoApproveTools = process.env.CHERRY_AUTO_ALLOW_TOOLS === '1'
+const NO_RESUME_COMMANDS = ['/clear']
type UserInputMessage = {
type: 'user'
@@ -111,7 +122,11 @@ class ClaudeCodeService implements AgentServiceInterface {
// TODO: support set small model in UI
ANTHROPIC_DEFAULT_HAIKU_MODEL: modelInfo.modelId,
ELECTRON_RUN_AS_NODE: '1',
- ELECTRON_NO_ATTACH_CONSOLE: '1'
+ ELECTRON_NO_ATTACH_CONSOLE: '1',
+ // Set CLAUDE_CONFIG_DIR to app's userData directory to avoid path encoding issues
+ // on Windows when the username contains non-ASCII characters (e.g., Chinese characters)
+ // This prevents the SDK from using the user's home directory which may have encoding problems
+ CLAUDE_CONFIG_DIR: path.join(app.getPath('userData'), '.claude')
}
const errorChunks: string[] = []
@@ -148,7 +163,67 @@ class ClaudeCodeService implements AgentServiceInterface {
return { behavior: 'allow', updatedInput: input }
}
- return promptForToolApproval(toolName, input, options)
+ return promptForToolApproval(toolName, input, {
+ ...options,
+ toolCallId: buildNamespacedToolCallId(session.id, options.toolUseID)
+ })
+ }
+
+ const preToolUseHook: HookCallback = async (input, toolUseID, options) => {
+ // Type guard to ensure we're handling PreToolUse event
+ if (input.hook_event_name !== 'PreToolUse') {
+ return {}
+ }
+
+ const hookInput = input as PreToolUseHookInput
+ const toolName = hookInput.tool_name
+
+ logger.debug('PreToolUse hook triggered', {
+ session_id: hookInput.session_id,
+ tool_name: hookInput.tool_name,
+ tool_use_id: toolUseID,
+ tool_input: hookInput.tool_input,
+ cwd: hookInput.cwd,
+ permission_mode: hookInput.permission_mode,
+ autoAllowTools: autoAllowTools
+ })
+
+ if (options?.signal?.aborted) {
+ logger.debug('PreToolUse hook signal already aborted; skipping tool use', {
+ tool_name: hookInput.tool_name
+ })
+ return {}
+ }
+
+ // handle auto approved tools since it never triggers canUseTool
+ const normalizedToolName = normalizeToolName(toolName)
+ if (toolUseID) {
+ const bypassAll = input.permission_mode === 'bypassPermissions'
+ const autoAllowed = autoAllowTools.has(toolName) || autoAllowTools.has(normalizedToolName)
+ if (bypassAll || autoAllowed) {
+ const namespacedToolCallId = buildNamespacedToolCallId(session.id, toolUseID)
+ logger.debug('handling auto approved tools', {
+ toolName,
+ normalizedToolName,
+ namespacedToolCallId,
+ permission_mode: input.permission_mode,
+ autoAllowTools
+ })
+ const isRecord = (v: unknown): v is Record => {
+ return !!v && typeof v === 'object' && !Array.isArray(v)
+ }
+ const toolInput = isRecord(input.tool_input) ? input.tool_input : {}
+
+ await promptForToolApproval(toolName, toolInput, {
+ ...options,
+ toolCallId: namespacedToolCallId,
+ autoApprove: true
+ })
+ }
+ }
+
+ // Return to proceed without modification
+ return {}
}
// Build SDK options from parameters
@@ -174,7 +249,14 @@ class ClaudeCodeService implements AgentServiceInterface {
permissionMode: session.configuration?.permission_mode,
maxTurns: session.configuration?.max_turns,
allowedTools: session.allowed_tools,
- canUseTool
+ canUseTool,
+ hooks: {
+ PreToolUse: [
+ {
+ hooks: [preToolUseHook]
+ }
+ ]
+ }
}
if (session.accessible_paths.length > 1) {
@@ -197,7 +279,7 @@ class ClaudeCodeService implements AgentServiceInterface {
options.strictMcpConfig = true
}
- if (lastAgentSessionId) {
+ if (lastAgentSessionId && !NO_RESUME_COMMANDS.some((cmd) => prompt.includes(cmd))) {
options.resume = lastAgentSessionId
// TODO: use fork session when we support branching sessions
// options.forkSession = true
@@ -220,7 +302,15 @@ class ClaudeCodeService implements AgentServiceInterface {
// Start async processing on the next tick so listeners can subscribe first
setImmediate(() => {
- this.processSDKQuery(userInputStream, closeUserStream, options, aiStream, errorChunks).catch((error) => {
+ this.processSDKQuery(
+ userInputStream,
+ closeUserStream,
+ options,
+ aiStream,
+ errorChunks,
+ session.agent_id,
+ session.id
+ ).catch((error) => {
logger.error('Unhandled Claude Code stream error', {
error: error instanceof Error ? { name: error.name, message: error.message } : String(error)
})
@@ -329,12 +419,14 @@ class ClaudeCodeService implements AgentServiceInterface {
closePromptStream: () => void,
options: Options,
stream: ClaudeCodeStream,
- errorChunks: string[]
+ errorChunks: string[],
+ agentId: string,
+ sessionId: string
): Promise {
const jsonOutput: SDKMessage[] = []
let hasCompleted = false
const startTime = Date.now()
- const streamState = new ClaudeStreamState()
+ const streamState = new ClaudeStreamState({ agentSessionId: sessionId })
try {
for await (const message of query({ prompt: promptStream, options })) {
@@ -342,21 +434,60 @@ class ClaudeCodeService implements AgentServiceInterface {
jsonOutput.push(message)
- if (message.type === 'assistant' || message.type === 'user') {
- logger.silly('claude response', {
- message,
- content: JSON.stringify(message.message.content)
- })
- } else if (message.type === 'stream_event') {
- // logger.silly('Claude stream event', {
- // message,
- // event: JSON.stringify(message.event)
- // })
- } else {
- logger.silly('Claude response', {
- message,
- event: JSON.stringify(message)
+ // Handle init message - merge builtin and SDK slash_commands
+ if (message.type === 'system' && message.subtype === 'init') {
+ const sdkSlashCommands = message.slash_commands || []
+ logger.info('Received init message with slash commands', {
+ sessionId,
+ commands: sdkSlashCommands
})
+
+ try {
+ // Get builtin + local slash commands from BaseService
+ const existingCommands = await sessionService.listSlashCommands('claude-code', agentId)
+
+ // Convert SDK slash_commands (string[]) to SlashCommand[] format
+ // Ensure all commands start with '/'
+ const sdkCommands = sdkSlashCommands.map((cmd) => {
+ const normalizedCmd = cmd.startsWith('/') ? cmd : `/${cmd}`
+ return {
+ command: normalizedCmd,
+ description: undefined
+ }
+ })
+
+ // Merge: existing commands (builtin + local) + SDK commands, deduplicate by command name
+ const commandMap = new Map()
+
+ for (const cmd of existingCommands) {
+ commandMap.set(cmd.command, cmd)
+ }
+
+ for (const cmd of sdkCommands) {
+ if (!commandMap.has(cmd.command)) {
+ commandMap.set(cmd.command, cmd)
+ }
+ }
+
+ const mergedCommands = Array.from(commandMap.values())
+
+ // Update session in database
+ await sessionService.updateSession(agentId, sessionId, {
+ slash_commands: mergedCommands
+ })
+
+ logger.info('Updated session with merged slash commands', {
+ sessionId,
+ existingCount: existingCommands.length,
+ sdkCount: sdkCommands.length,
+ totalCount: mergedCommands.length
+ })
+ } catch (error) {
+ logger.error('Failed to update session slash_commands', {
+ sessionId,
+ error: error instanceof Error ? error.message : String(error)
+ })
+ }
}
const chunks = transformSDKMessageToStreamParts(message, streamState)
@@ -378,7 +509,6 @@ class ClaudeCodeService implements AgentServiceInterface {
}
}
- hasCompleted = true
const duration = Date.now() - startTime
logger.debug('SDK query completed successfully', {
diff --git a/src/main/services/agents/services/claudecode/tool-permissions.ts b/src/main/services/agents/services/claudecode/tool-permissions.ts
index c95f4c679e..bbca3bd40e 100644
--- a/src/main/services/agents/services/claudecode/tool-permissions.ts
+++ b/src/main/services/agents/services/claudecode/tool-permissions.ts
@@ -31,12 +31,14 @@ type PendingPermissionRequest = {
abortListener?: () => void
originalInput: Record
toolName: string
+ toolCallId?: string
}
type RendererPermissionRequestPayload = {
requestId: string
toolName: string
toolId: string
+ toolCallId: string
description?: string
requiresPermissions: boolean
input: Record
@@ -44,6 +46,7 @@ type RendererPermissionRequestPayload = {
createdAt: number
expiresAt: number
suggestions: PermissionUpdate[]
+ autoApprove?: boolean
}
type RendererPermissionResultPayload = {
@@ -51,6 +54,7 @@ type RendererPermissionResultPayload = {
behavior: ToolPermissionBehavior
message?: string
reason: 'response' | 'timeout' | 'aborted' | 'no-window'
+ toolCallId?: string
}
const pendingRequests = new Map()
@@ -144,7 +148,8 @@ const finalizeRequest = (
requestId,
behavior: update.behavior,
message: update.behavior === 'deny' ? update.message : undefined,
- reason
+ reason,
+ toolCallId: pending.toolCallId
}
const dispatched = broadcastToRenderer(IpcChannel.AgentToolPermission_Result, resultPayload)
@@ -206,10 +211,20 @@ const ensureIpcHandlersRegistered = () => {
})
}
+type PromptForToolApprovalOptions = {
+ signal: AbortSignal
+ suggestions?: PermissionUpdate[]
+ autoApprove?: boolean
+
+ // NOTICE: This ID is namespaced with session ID, not the raw SDK tool call ID.
+ // Format: `${sessionId}:${rawToolCallId}`, e.g., `session_123:WebFetch_0`
+ toolCallId: string
+}
+
export async function promptForToolApproval(
toolName: string,
input: Record,
- options?: { signal: AbortSignal; suggestions?: PermissionUpdate[] }
+ options: PromptForToolApprovalOptions
): Promise {
if (shouldAutoApproveTools) {
logger.debug('promptForToolApproval auto-approving tool for test', {
@@ -245,6 +260,7 @@ export async function promptForToolApproval(
logger.info('Requesting user approval for tool usage', {
requestId,
toolName,
+ toolCallId: options.toolCallId,
description: toolMetadata?.description
})
@@ -252,13 +268,15 @@ export async function promptForToolApproval(
requestId,
toolName,
toolId: toolMetadata?.id ?? toolName,
+ toolCallId: options.toolCallId,
description: toolMetadata?.description,
requiresPermissions: toolMetadata?.requirePermissions ?? false,
input: sanitizedInput,
inputPreview,
createdAt,
expiresAt,
- suggestions: sanitizedSuggestions
+ suggestions: sanitizedSuggestions,
+ autoApprove: options.autoApprove
}
const defaultDenyUpdate: PermissionResult = { behavior: 'deny', message: 'Tool request aborted before user decision' }
@@ -266,6 +284,7 @@ export async function promptForToolApproval(
logger.debug('Registering tool permission request', {
requestId,
toolName,
+ toolCallId: options.toolCallId,
requiresPermissions: requestPayload.requiresPermissions,
timeoutMs: TOOL_APPROVAL_TIMEOUT_MS,
suggestionCount: sanitizedSuggestions.length
@@ -273,7 +292,11 @@ export async function promptForToolApproval(
return new Promise((resolve) => {
const timeout = setTimeout(() => {
- logger.info('User tool permission request timed out', { requestId, toolName })
+ logger.info('User tool permission request timed out', {
+ requestId,
+ toolName,
+ toolCallId: options.toolCallId
+ })
finalizeRequest(requestId, { behavior: 'deny', message: 'Timed out waiting for approval' }, 'timeout')
}, TOOL_APPROVAL_TIMEOUT_MS)
@@ -282,12 +305,17 @@ export async function promptForToolApproval(
timeout,
originalInput: sanitizedInput,
toolName,
- signal: options?.signal
+ signal: options?.signal,
+ toolCallId: options.toolCallId
}
if (options?.signal) {
const abortListener = () => {
- logger.info('Tool permission request aborted before user responded', { requestId, toolName })
+ logger.info('Tool permission request aborted before user responded', {
+ requestId,
+ toolName,
+ toolCallId: options.toolCallId
+ })
finalizeRequest(requestId, defaultDenyUpdate, 'aborted')
}
diff --git a/src/main/services/agents/services/claudecode/transform.ts b/src/main/services/agents/services/claudecode/transform.ts
index 5905ed6434..00be683ba8 100644
--- a/src/main/services/agents/services/claudecode/transform.ts
+++ b/src/main/services/agents/services/claudecode/transform.ts
@@ -73,13 +73,21 @@ const emptyUsage: LanguageModelUsage = {
*/
const generateMessageId = (): string => `msg_${uuidv4().replace(/-/g, '')}`
+/**
+ * Removes any local command stdout/stderr XML wrappers that should never surface to the UI.
+ */
+export const stripLocalCommandTags = (text: string): string => {
+ return text.replace(/(.*?)<\/local-command-\1>/gs, '$2')
+}
+
/**
* Filters out command-* tags from text content to prevent internal command
* messages from appearing in the user-facing UI.
* Removes tags like ... and ...
*/
const filterCommandTags = (text: string): string => {
- return text.replace(/]+>.*?<\/command-[^>]+>/gs, '').trim()
+ const withoutLocalCommandTags = stripLocalCommandTags(text)
+ return withoutLocalCommandTags.replace(/]+>.*?<\/command-[^>]+>/gs, '').trim()
}
/**
@@ -102,6 +110,7 @@ const sdkMessageToProviderMetadata = (message: SDKMessage): ProviderMetadata =>
* blocks across calls so that incremental deltas can be correlated correctly.
*/
export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state: ClaudeStreamState): AgentStreamPart[] {
+ logger.silly('Transforming SDKMessage', { message: JSON.stringify(sdkMessage) })
switch (sdkMessage.type) {
case 'assistant':
return handleAssistantMessage(sdkMessage, state)
@@ -135,7 +144,8 @@ function handleAssistantMessage(
const isStreamingActive = state.hasActiveStep()
if (typeof content === 'string') {
- if (!content) {
+ const sanitizedContent = stripLocalCommandTags(content)
+ if (!sanitizedContent) {
return chunks
}
@@ -157,7 +167,7 @@ function handleAssistantMessage(
chunks.push({
type: 'text-delta',
id: textId,
- text: content,
+ text: sanitizedContent,
providerMetadata
})
chunks.push({
@@ -176,11 +186,13 @@ function handleAssistantMessage(
for (const block of content) {
switch (block.type) {
- case 'text':
- if (!isStreamingActive) {
- textBlocks.push(block.text)
+ case 'text': {
+ const sanitizedText = stripLocalCommandTags(block.text)
+ if (sanitizedText) {
+ textBlocks.push(sanitizedText)
}
break
+ }
case 'tool_use':
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
break
@@ -190,7 +202,16 @@ function handleAssistantMessage(
}
}
- if (!isStreamingActive && textBlocks.length > 0) {
+ if (textBlocks.length === 0) {
+ return chunks
+ }
+
+ const combinedText = textBlocks.join('')
+ if (!combinedText) {
+ return chunks
+ }
+
+ if (!isStreamingActive) {
const id = message.uuid?.toString() || generateMessageId()
state.beginStep()
chunks.push({
@@ -206,7 +227,7 @@ function handleAssistantMessage(
chunks.push({
type: 'text-delta',
id,
- text: textBlocks.join(''),
+ text: combinedText,
providerMetadata
})
chunks.push({
@@ -217,7 +238,27 @@ function handleAssistantMessage(
return finalizeNonStreamingStep(message, state, chunks)
}
- return chunks
+ const existingTextBlock = state.getFirstOpenTextBlock()
+ const fallbackId = existingTextBlock?.id || message.uuid?.toString() || generateMessageId()
+ if (!existingTextBlock) {
+ chunks.push({
+ type: 'text-start',
+ id: fallbackId,
+ providerMetadata
+ })
+ }
+ chunks.push({
+ type: 'text-delta',
+ id: fallbackId,
+ text: combinedText,
+ providerMetadata
+ })
+ chunks.push({
+ type: 'text-end',
+ id: fallbackId,
+ providerMetadata
+ })
+ return finalizeNonStreamingStep(message, state, chunks)
}
/**
@@ -230,15 +271,16 @@ function handleAssistantToolUse(
state: ClaudeStreamState,
chunks: AgentStreamPart[]
): void {
+ const toolCallId = state.getNamespacedToolCallId(block.id)
chunks.push({
type: 'tool-call',
- toolCallId: block.id,
+ toolCallId,
toolName: block.name,
input: block.input,
providerExecuted: true,
providerMetadata
})
- state.completeToolBlock(block.id, block.input, providerMetadata)
+ state.completeToolBlock(block.id, block.name, block.input, providerMetadata)
}
/**
@@ -318,10 +360,11 @@ function handleUserMessage(
if (block.type === 'tool_result') {
const toolResult = block as ToolResultContent
const pendingCall = state.consumePendingToolCall(toolResult.tool_use_id)
+ const toolCallId = pendingCall?.toolCallId ?? state.getNamespacedToolCallId(toolResult.tool_use_id)
if (toolResult.is_error) {
chunks.push({
type: 'tool-error',
- toolCallId: toolResult.tool_use_id,
+ toolCallId,
toolName: pendingCall?.toolName ?? 'unknown',
input: pendingCall?.input,
error: toolResult.content,
@@ -330,7 +373,7 @@ function handleUserMessage(
} else {
chunks.push({
type: 'tool-result',
- toolCallId: toolResult.tool_use_id,
+ toolCallId,
toolName: pendingCall?.toolName ?? 'unknown',
input: pendingCall?.input,
output: toolResult.content,
@@ -444,6 +487,9 @@ function handleStreamEvent(
}
case 'message_stop': {
+ if (!state.hasActiveStep()) {
+ break
+ }
const pending = state.getPendingUsage()
chunks.push({
type: 'finish-step',
@@ -501,7 +547,7 @@ function handleContentBlockStart(
}
case 'tool_use': {
const block = state.openToolBlock(index, {
- toolCallId: contentBlock.id,
+ rawToolCallId: contentBlock.id,
toolName: contentBlock.name,
providerMetadata
})
@@ -537,6 +583,10 @@ function handleContentBlockDelta(
logger.warn('Received text_delta for unknown block', { index })
return
}
+ block.text = stripLocalCommandTags(block.text)
+ if (!block.text) {
+ break
+ }
chunks.push({
type: 'text-delta',
id: block.id,
diff --git a/src/main/services/mcp/oauth/callback.ts b/src/main/services/mcp/oauth/callback.ts
index 81d435f867..c13ecd5c07 100644
--- a/src/main/services/mcp/oauth/callback.ts
+++ b/src/main/services/mcp/oauth/callback.ts
@@ -1,4 +1,6 @@
import { loggerService } from '@logger'
+import { configManager } from '@main/services/ConfigManager'
+import { locales } from '@main/utils/locales'
import type EventEmitter from 'events'
import http from 'http'
import { URL } from 'url'
@@ -7,6 +9,36 @@ import type { OAuthCallbackServerOptions } from './types'
const logger = loggerService.withContext('MCP:OAuthCallbackServer')
+function getTranslation(key: string): string {
+ const language = configManager.getLanguage()
+ const localeData = locales[language]
+
+ if (!localeData) {
+ logger.warn(`No locale data found for language: ${language}`)
+ return key
+ }
+
+ const translations = localeData.translation as any
+ if (!translations) {
+ logger.warn(`No translations found for language: ${language}`)
+ return key
+ }
+
+ const keys = key.split('.')
+ let value = translations
+
+ for (const k of keys) {
+ if (value && typeof value === 'object' && k in value) {
+ value = value[k]
+ } else {
+ logger.warn(`Translation key not found: ${key} (failed at: ${k})`)
+ return key // fallback to key if translation not found
+ }
+ }
+
+ return typeof value === 'string' ? value : key
+}
+
export class CallBackServer {
private server: Promise
private events: EventEmitter
@@ -28,6 +60,55 @@ export class CallBackServer {
if (code) {
// Emit the code event
this.events.emit('auth-code-received', code)
+ // Send success response to browser
+ const title = getTranslation('settings.mcp.oauth.callback.title')
+ const message = getTranslation('settings.mcp.oauth.callback.message')
+
+ res.writeHead(200, { 'Content-Type': 'text/html; charset=utf-8' })
+ res.end(`
+
+
+
+
+ ${title}
+
+
+
+
+
${title}
+
${message}
+
+
+
+ `)
+ } else {
+ res.writeHead(400, { 'Content-Type': 'text/plain' })
+ res.end('Missing authorization code')
}
} catch (error) {
logger.error('Error processing OAuth callback:', error as Error)
diff --git a/src/main/services/ocr/builtin/OvOcrService.ts b/src/main/services/ocr/builtin/OvOcrService.ts
index 6e0eee1c37..052682be64 100644
--- a/src/main/services/ocr/builtin/OvOcrService.ts
+++ b/src/main/services/ocr/builtin/OvOcrService.ts
@@ -1,5 +1,6 @@
import { loggerService } from '@logger'
import { isWin } from '@main/constant'
+import { HOME_CHERRY_DIR } from '@shared/config/constant'
import type { OcrOvConfig, OcrResult, SupportedOcrFile } from '@types'
import { isImageFileMetadata } from '@types'
import { exec } from 'child_process'
@@ -13,7 +14,7 @@ import { OcrBaseService } from './OcrBaseService'
const logger = loggerService.withContext('OvOcrService')
const execAsync = promisify(exec)
-const PATH_BAT_FILE = path.join(os.homedir(), '.cherrystudio', 'ovms', 'ovocr', 'run.npu.bat')
+const PATH_BAT_FILE = path.join(os.homedir(), HOME_CHERRY_DIR, 'ovms', 'ovocr', 'run.npu.bat')
export class OvOcrService extends OcrBaseService {
constructor() {
@@ -30,7 +31,7 @@ export class OvOcrService extends OcrBaseService {
}
private getOvOcrPath(): string {
- return path.join(os.homedir(), '.cherrystudio', 'ovms', 'ovocr')
+ return path.join(os.homedir(), HOME_CHERRY_DIR, 'ovms', 'ovocr')
}
private getImgDir(): string {
diff --git a/src/main/utils/__tests__/mcp.test.ts b/src/main/utils/__tests__/mcp.test.ts
new file mode 100644
index 0000000000..b1a35f925e
--- /dev/null
+++ b/src/main/utils/__tests__/mcp.test.ts
@@ -0,0 +1,196 @@
+import { describe, expect, it } from 'vitest'
+
+import { buildFunctionCallToolName } from '../mcp'
+
+describe('buildFunctionCallToolName', () => {
+ describe('basic functionality', () => {
+ it('should combine server name and tool name', () => {
+ const result = buildFunctionCallToolName('github', 'search_issues')
+ expect(result).toContain('github')
+ expect(result).toContain('search')
+ })
+
+ it('should sanitize names by replacing dashes with underscores', () => {
+ const result = buildFunctionCallToolName('my-server', 'my-tool')
+ // Input dashes are replaced, but the separator between server and tool is a dash
+ expect(result).toBe('my_serv-my_tool')
+ expect(result).toContain('_')
+ })
+
+ it('should handle empty server names gracefully', () => {
+ const result = buildFunctionCallToolName('', 'tool')
+ expect(result).toBeTruthy()
+ })
+ })
+
+ describe('uniqueness with serverId', () => {
+ it('should generate different IDs for same server name but different serverIds', () => {
+ const serverId1 = 'server-id-123456'
+ const serverId2 = 'server-id-789012'
+ const serverName = 'github'
+ const toolName = 'search_repos'
+
+ const result1 = buildFunctionCallToolName(serverName, toolName, serverId1)
+ const result2 = buildFunctionCallToolName(serverName, toolName, serverId2)
+
+ expect(result1).not.toBe(result2)
+ expect(result1).toContain('123456')
+ expect(result2).toContain('789012')
+ })
+
+ it('should generate same ID when serverId is not provided', () => {
+ const serverName = 'github'
+ const toolName = 'search_repos'
+
+ const result1 = buildFunctionCallToolName(serverName, toolName)
+ const result2 = buildFunctionCallToolName(serverName, toolName)
+
+ expect(result1).toBe(result2)
+ })
+
+ it('should include serverId suffix when provided', () => {
+ const serverId = 'abc123def456'
+ const result = buildFunctionCallToolName('server', 'tool', serverId)
+
+ // Should include last 6 chars of serverId
+ expect(result).toContain('ef456')
+ })
+ })
+
+ describe('character sanitization', () => {
+ it('should replace invalid characters with underscores', () => {
+ const result = buildFunctionCallToolName('test@server', 'tool#name')
+ expect(result).not.toMatch(/[@#]/)
+ expect(result).toMatch(/^[a-zA-Z0-9_-]+$/)
+ })
+
+ it('should ensure name starts with a letter', () => {
+ const result = buildFunctionCallToolName('123server', '456tool')
+ expect(result).toMatch(/^[a-zA-Z]/)
+ })
+
+ it('should handle consecutive underscores/dashes', () => {
+ const result = buildFunctionCallToolName('my--server', 'my__tool')
+ expect(result).not.toMatch(/[_-]{2,}/)
+ })
+ })
+
+ describe('length constraints', () => {
+ it('should truncate names longer than 63 characters', () => {
+ const longServerName = 'a'.repeat(50)
+ const longToolName = 'b'.repeat(50)
+ const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456')
+
+ expect(result.length).toBeLessThanOrEqual(63)
+ })
+
+ it('should not end with underscore or dash after truncation', () => {
+ const longServerName = 'a'.repeat(50)
+ const longToolName = 'b'.repeat(50)
+ const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456')
+
+ expect(result).not.toMatch(/[_-]$/)
+ })
+
+ it('should preserve serverId suffix even with long server/tool names', () => {
+ const longServerName = 'a'.repeat(50)
+ const longToolName = 'b'.repeat(50)
+ const serverId = 'server-id-xyz789'
+
+ const result = buildFunctionCallToolName(longServerName, longToolName, serverId)
+
+ // The suffix should be preserved and not truncated
+ expect(result).toContain('xyz789')
+ expect(result.length).toBeLessThanOrEqual(63)
+ })
+
+ it('should ensure two long-named servers with different IDs produce different results', () => {
+ const longServerName = 'a'.repeat(50)
+ const longToolName = 'b'.repeat(50)
+ const serverId1 = 'server-id-abc123'
+ const serverId2 = 'server-id-def456'
+
+ const result1 = buildFunctionCallToolName(longServerName, longToolName, serverId1)
+ const result2 = buildFunctionCallToolName(longServerName, longToolName, serverId2)
+
+ // Both should be within limit
+ expect(result1.length).toBeLessThanOrEqual(63)
+ expect(result2.length).toBeLessThanOrEqual(63)
+
+ // They should be different due to preserved suffix
+ expect(result1).not.toBe(result2)
+ })
+ })
+
+ describe('edge cases with serverId', () => {
+ it('should handle serverId with only non-alphanumeric characters', () => {
+ const serverId = '------' // All dashes
+ const result = buildFunctionCallToolName('server', 'tool', serverId)
+
+ // Should still produce a valid unique suffix via fallback hash
+ expect(result).toBeTruthy()
+ expect(result.length).toBeLessThanOrEqual(63)
+ expect(result).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
+ // Should have a suffix (underscore followed by something)
+ expect(result).toMatch(/_[a-z0-9]+$/)
+ })
+
+ it('should produce different results for different non-alphanumeric serverIds', () => {
+ const serverId1 = '------'
+ const serverId2 = '!!!!!!'
+
+ const result1 = buildFunctionCallToolName('server', 'tool', serverId1)
+ const result2 = buildFunctionCallToolName('server', 'tool', serverId2)
+
+ // Should be different because the hash fallback produces different values
+ expect(result1).not.toBe(result2)
+ })
+
+ it('should handle empty string serverId differently from undefined', () => {
+ const resultWithEmpty = buildFunctionCallToolName('server', 'tool', '')
+ const resultWithUndefined = buildFunctionCallToolName('server', 'tool', undefined)
+
+ // Empty string is falsy, so both should behave the same (no suffix)
+ expect(resultWithEmpty).toBe(resultWithUndefined)
+ })
+
+ it('should handle serverId with mixed alphanumeric and special chars', () => {
+ const serverId = 'ab@#cd' // Mixed chars, last 6 chars contain some alphanumeric
+ const result = buildFunctionCallToolName('server', 'tool', serverId)
+
+ // Should extract alphanumeric chars: 'abcd' from 'ab@#cd'
+ expect(result).toContain('abcd')
+ })
+ })
+
+ describe('real-world scenarios', () => {
+ it('should handle GitHub MCP server instances correctly', () => {
+ const serverName = 'github'
+ const toolName = 'search_repositories'
+
+ const githubComId = 'server-github-com-abc123'
+ const gheId = 'server-ghe-internal-xyz789'
+
+ const tool1 = buildFunctionCallToolName(serverName, toolName, githubComId)
+ const tool2 = buildFunctionCallToolName(serverName, toolName, gheId)
+
+ // Should be different
+ expect(tool1).not.toBe(tool2)
+
+ // Both should be valid identifiers
+ expect(tool1).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
+ expect(tool2).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
+
+ // Both should be <= 63 chars
+ expect(tool1.length).toBeLessThanOrEqual(63)
+ expect(tool2.length).toBeLessThanOrEqual(63)
+ })
+
+ it('should handle tool names that already include server name prefix', () => {
+ const result = buildFunctionCallToolName('github', 'github_search_repos')
+ expect(result).toBeTruthy()
+ // Should not double the server name
+ expect(result.split('github').length - 1).toBeLessThanOrEqual(2)
+ })
+ })
+})
diff --git a/src/main/utils/file.ts b/src/main/utils/file.ts
index 17155f423b..1432dccc8a 100644
--- a/src/main/utils/file.ts
+++ b/src/main/utils/file.ts
@@ -5,7 +5,7 @@ import os from 'node:os'
import path from 'node:path'
import { loggerService } from '@logger'
-import { audioExts, documentExts, imageExts, MB, textExts, videoExts } from '@shared/config/constant'
+import { audioExts, documentExts, HOME_CHERRY_DIR, imageExts, MB, textExts, videoExts } from '@shared/config/constant'
import type { FileMetadata, NotesTreeNode } from '@types'
import { FileTypes } from '@types'
import chardet from 'chardet'
@@ -160,7 +160,7 @@ export function getNotesDir() {
}
export function getConfigDir() {
- return path.join(os.homedir(), '.cherrystudio', 'config')
+ return path.join(os.homedir(), HOME_CHERRY_DIR, 'config')
}
export function getCacheDir() {
@@ -172,7 +172,7 @@ export function getAppConfigDir(name: string) {
}
export function getMcpDir() {
- return path.join(os.homedir(), '.cherrystudio', 'mcp')
+ return path.join(os.homedir(), HOME_CHERRY_DIR, 'mcp')
}
/**
diff --git a/src/main/utils/init.ts b/src/main/utils/init.ts
index 63cf69e89b..20884b1eeb 100644
--- a/src/main/utils/init.ts
+++ b/src/main/utils/init.ts
@@ -3,6 +3,7 @@ import os from 'node:os'
import path from 'node:path'
import { isLinux, isPortable, isWin } from '@main/constant'
+import { HOME_CHERRY_DIR } from '@shared/config/constant'
import { app } from 'electron'
// Please don't import any other modules which is not node/electron built-in modules
@@ -17,7 +18,7 @@ function hasWritePermission(path: string) {
}
function getConfigDir() {
- return path.join(os.homedir(), '.cherrystudio', 'config')
+ return path.join(os.homedir(), HOME_CHERRY_DIR, 'config')
}
export function initAppDataDir() {
diff --git a/src/main/utils/mcp.ts b/src/main/utils/mcp.ts
index 23d19806d9..cfa700f2e6 100644
--- a/src/main/utils/mcp.ts
+++ b/src/main/utils/mcp.ts
@@ -1,7 +1,25 @@
-export function buildFunctionCallToolName(serverName: string, toolName: string) {
+export function buildFunctionCallToolName(serverName: string, toolName: string, serverId?: string) {
const sanitizedServer = serverName.trim().replace(/-/g, '_')
const sanitizedTool = toolName.trim().replace(/-/g, '_')
+ // Calculate suffix first to reserve space for it
+ // Suffix format: "_" + 6 alphanumeric chars = 7 chars total
+ let serverIdSuffix = ''
+ if (serverId) {
+ // Take the last 6 characters of the serverId for brevity
+ serverIdSuffix = serverId.slice(-6).replace(/[^a-zA-Z0-9]/g, '')
+
+ // Fallback: if suffix becomes empty (all non-alphanumeric chars), use a simple hash
+ if (!serverIdSuffix) {
+ const hash = serverId.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0)
+ serverIdSuffix = hash.toString(36).slice(-6) || 'x'
+ }
+ }
+
+ // Reserve space for suffix when calculating max base name length
+ const SUFFIX_LENGTH = serverIdSuffix ? serverIdSuffix.length + 1 : 0 // +1 for underscore
+ const MAX_BASE_LENGTH = 63 - SUFFIX_LENGTH
+
// Combine server name and tool name
let name = sanitizedTool
if (!sanitizedTool.includes(sanitizedServer.slice(0, 7))) {
@@ -20,9 +38,9 @@ export function buildFunctionCallToolName(serverName: string, toolName: string)
// Remove consecutive underscores/dashes (optional improvement)
name = name.replace(/[_-]{2,}/g, '_')
- // Truncate to 63 characters maximum
- if (name.length > 63) {
- name = name.slice(0, 63)
+ // Truncate base name BEFORE adding suffix to ensure suffix is never cut off
+ if (name.length > MAX_BASE_LENGTH) {
+ name = name.slice(0, MAX_BASE_LENGTH)
}
// Handle edge case: ensure we still have a valid name if truncation left invalid chars at edges
@@ -30,5 +48,10 @@ export function buildFunctionCallToolName(serverName: string, toolName: string)
name = name.slice(0, -1)
}
+ // Now append the suffix - it will always fit within 63 chars
+ if (serverIdSuffix) {
+ name = `${name}_${serverIdSuffix}`
+ }
+
return name
}
diff --git a/src/main/utils/process.ts b/src/main/utils/process.ts
index f028f2d3c7..f36e86861d 100644
--- a/src/main/utils/process.ts
+++ b/src/main/utils/process.ts
@@ -1,4 +1,5 @@
import { loggerService } from '@logger'
+import { HOME_CHERRY_DIR } from '@shared/config/constant'
import { spawn } from 'child_process'
import fs from 'fs'
import os from 'os'
@@ -46,11 +47,11 @@ export async function getBinaryName(name: string): Promise {
export async function getBinaryPath(name?: string): Promise {
if (!name) {
- return path.join(os.homedir(), '.cherrystudio', 'bin')
+ return path.join(os.homedir(), HOME_CHERRY_DIR, 'bin')
}
const binaryName = await getBinaryName(name)
- const binariesDir = path.join(os.homedir(), '.cherrystudio', 'bin')
+ const binariesDir = path.join(os.homedir(), HOME_CHERRY_DIR, 'bin')
const binariesDirExists = fs.existsSync(binariesDir)
return binariesDirExists ? path.join(binariesDir, binaryName) : binaryName
}
diff --git a/src/preload/index.ts b/src/preload/index.ts
index 861b020f18..26df2c1f20 100644
--- a/src/preload/index.ts
+++ b/src/preload/index.ts
@@ -48,6 +48,16 @@ import type {
} from '../renderer/src/types/plugin'
import type { ActionItem } from '../renderer/src/types/selectionTypes'
+type DirectoryListOptions = {
+ recursive?: boolean
+ maxDepth?: number
+ includeHidden?: boolean
+ includeFiles?: boolean
+ includeDirectories?: boolean
+ maxEntries?: number
+ searchPattern?: string
+}
+
export function tracedInvoke(channel: string, spanContext: SpanContext | undefined, ...args: any[]) {
if (spanContext) {
const data = { type: 'trace', context: spanContext }
@@ -101,6 +111,7 @@ const api = {
setFullScreen: (value: boolean): Promise => ipcRenderer.invoke(IpcChannel.App_SetFullScreen, value),
isFullScreen: (): Promise => ipcRenderer.invoke(IpcChannel.App_IsFullScreen),
getSystemFonts: (): Promise => ipcRenderer.invoke(IpcChannel.App_GetSystemFonts),
+ mockCrashRenderProcess: () => ipcRenderer.invoke(IpcChannel.APP_CrashRenderProcess),
mac: {
isProcessTrusted: (): Promise => ipcRenderer.invoke(IpcChannel.App_MacIsProcessTrusted),
requestProcessTrust: (): Promise => ipcRenderer.invoke(IpcChannel.App_MacRequestProcessTrust)
@@ -111,7 +122,8 @@ const api = {
system: {
getDeviceType: () => ipcRenderer.invoke(IpcChannel.System_GetDeviceType),
getHostname: () => ipcRenderer.invoke(IpcChannel.System_GetHostname),
- getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName)
+ getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName),
+ checkGitBash: (): Promise => ipcRenderer.invoke(IpcChannel.System_CheckGitBash)
},
devTools: {
toggle: () => ipcRenderer.invoke(IpcChannel.System_ToggleDevTools)
@@ -201,6 +213,8 @@ const api = {
openFileWithRelativePath: (file: FileMetadata) => ipcRenderer.invoke(IpcChannel.File_OpenWithRelativePath, file),
isTextFile: (filePath: string): Promise => ipcRenderer.invoke(IpcChannel.File_IsTextFile, filePath),
getDirectoryStructure: (dirPath: string) => ipcRenderer.invoke(IpcChannel.File_GetDirectoryStructure, dirPath),
+ listDirectory: (dirPath: string, options?: DirectoryListOptions) =>
+ ipcRenderer.invoke(IpcChannel.File_ListDirectory, dirPath, options),
checkFileName: (dirPath: string, fileName: string, isFile: boolean) =>
ipcRenderer.invoke(IpcChannel.File_CheckFileName, dirPath, fileName, isFile),
validateNotesDirectory: (dirPath: string) => ipcRenderer.invoke(IpcChannel.File_ValidateNotesDirectory, dirPath),
diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
index 6e4288d241..5de2ac3453 100644
--- a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
+++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
@@ -30,18 +30,22 @@ export class AiSdkToChunkAdapter {
private onSessionUpdate?: (sessionId: string) => void
private responseStartTimestamp: number | null = null
private firstTokenTimestamp: number | null = null
+ private hasTextContent = false
+ private getSessionWasCleared?: () => boolean
constructor(
private onChunk: (chunk: Chunk) => void,
mcpTools: MCPTool[] = [],
accumulate?: boolean,
enableWebSearch?: boolean,
- onSessionUpdate?: (sessionId: string) => void
+ onSessionUpdate?: (sessionId: string) => void,
+ getSessionWasCleared?: () => boolean
) {
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
this.accumulate = accumulate
this.enableWebSearch = enableWebSearch || false
this.onSessionUpdate = onSessionUpdate
+ this.getSessionWasCleared = getSessionWasCleared
}
private markFirstTokenIfNeeded() {
@@ -84,8 +88,9 @@ export class AiSdkToChunkAdapter {
}
this.resetTimingState()
this.responseStartTimestamp = Date.now()
- // Reset link converter state at the start of stream
+ // Reset state at the start of stream
this.isFirstChunk = true
+ this.hasTextContent = false
try {
while (true) {
@@ -129,6 +134,8 @@ export class AiSdkToChunkAdapter {
const agentRawMessage = chunk.rawValue as ClaudeCodeRawValue
if (agentRawMessage.type === 'init' && agentRawMessage.session_id) {
this.onSessionUpdate?.(agentRawMessage.session_id)
+ } else if (agentRawMessage.type === 'compact' && agentRawMessage.session_id) {
+ this.onSessionUpdate?.(agentRawMessage.session_id)
}
this.onChunk({
type: ChunkType.RAW,
@@ -143,6 +150,7 @@ export class AiSdkToChunkAdapter {
})
break
case 'text-delta': {
+ this.hasTextContent = true
const processedText = chunk.text || ''
let finalText: string
@@ -301,6 +309,25 @@ export class AiSdkToChunkAdapter {
}
case 'finish': {
+ // Check if session was cleared (e.g., /clear command) and no text was output
+ const sessionCleared = this.getSessionWasCleared?.() ?? false
+ if (sessionCleared && !this.hasTextContent) {
+ // Inject a "context cleared" message for the user
+ const clearMessage = '✨ Context cleared. Starting fresh conversation.'
+ this.onChunk({
+ type: ChunkType.TEXT_START
+ })
+ this.onChunk({
+ type: ChunkType.TEXT_DELTA,
+ text: clearMessage
+ })
+ this.onChunk({
+ type: ChunkType.TEXT_COMPLETE,
+ text: clearMessage
+ })
+ final.text = clearMessage
+ }
+
const usage = {
completion_tokens: chunk.totalUsage?.outputTokens || 0,
prompt_tokens: chunk.totalUsage?.inputTokens || 0,
@@ -359,14 +386,13 @@ export class AiSdkToChunkAdapter {
case 'error':
this.onChunk({
type: ChunkType.ERROR,
- error:
- chunk.error instanceof AISDKError
- ? chunk.error
- : new ProviderSpecificError({
- message: formatErrorMessage(chunk.error),
- provider: 'unknown',
- cause: chunk.error
- })
+ error: AISDKError.isInstance(chunk.error)
+ ? chunk.error
+ : new ProviderSpecificError({
+ message: formatErrorMessage(chunk.error),
+ provider: 'unknown',
+ cause: chunk.error
+ })
})
break
diff --git a/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts b/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
index 32c7e534e3..b5acbb690b 100644
--- a/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
+++ b/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
@@ -212,8 +212,9 @@ export class ToolCallChunkHandler {
description: toolName,
type: 'builtin'
} as BaseTool
- } else if ((mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool)) {
+ } else if ((mcpTool = this.mcpTools.find((t) => t.id === toolName) as MCPTool)) {
// 如果是客户端执行的 MCP 工具,沿用现有逻辑
+ // toolName is mcpTool.id (registered with id as key in convertMcpToolsToAiSdkTools)
logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
// mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool
// if (!mcpTool) {
diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts
index 800d2ff302..4379547a3c 100644
--- a/src/renderer/src/aiCore/index_new.ts
+++ b/src/renderer/src/aiCore/index_new.ts
@@ -10,13 +10,14 @@
import { createExecutor } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
+import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter'
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
-import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
+import { type Assistant, type GenerateImageParams, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes'
import { SUPPORTED_IMAGE_ENDPOINT_LIST } from '@renderer/utils'
import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
-import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai'
+import { gateway, type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai'
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import LegacyAiProvider from './legacy/index'
@@ -26,11 +27,13 @@ import { buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
import { buildPlugins } from './plugins/PluginBuilder'
import { createAiSdkProvider } from './provider/factory'
import {
+ adaptProvider,
getActualProvider,
isModernSdkSupported,
prepareSpecialProviderConfig,
providerToAiSdkConfig
} from './provider/providerConfig'
+import type { AiSdkConfig } from './types'
const logger = loggerService.withContext('ModernAiProvider')
@@ -43,12 +46,44 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
- private config?: ReturnType
+ private config?: AiSdkConfig
private actualProvider: Provider
private model?: Model
private localProvider: Awaited | null = null
- // 构造函数重载签名
+ /**
+ * Constructor for ModernAiProvider
+ *
+ * @param modelOrProvider - Model or Provider object
+ * @param provider - Optional Provider object (only used when first param is Model)
+ *
+ * @remarks
+ * **Important behavior notes**:
+ *
+ * 1. When called with `(model)`:
+ * - Calls `getActualProvider(model)` to retrieve and format the provider
+ * - URL will be automatically formatted via `formatProviderApiHost`, adding version suffixes like `/v1`
+ *
+ * 2. When called with `(model, provider)`:
+ * - The provided provider will be adapted via `adaptProvider`
+ * - URL formatting behavior depends on the adapted result
+ *
+ * 3. When called with `(provider)`:
+ * - The provider will be adapted via `adaptProvider`
+ * - Used for operations that don't need a model (e.g., fetchModels)
+ *
+ * @example
+ * ```typescript
+ * // Recommended: Auto-format URL
+ * const ai = new ModernAiProvider(model)
+ *
+ * // Provider will be adapted
+ * const ai = new ModernAiProvider(model, customProvider)
+ *
+ * // For operations that don't need a model
+ * const ai = new ModernAiProvider(provider)
+ * ```
+ */
constructor(model: Model, provider?: Provider)
constructor(provider: Provider)
constructor(modelOrProvider: Model | Provider, provider?: Provider)
@@ -56,12 +91,12 @@ export default class ModernAiProvider {
if (this.isModel(modelOrProvider)) {
// 传入的是 Model
this.model = modelOrProvider
- this.actualProvider = provider || getActualProvider(modelOrProvider)
+ this.actualProvider = provider ? adaptProvider({ provider }) : getActualProvider(modelOrProvider)
// 只保存配置,不预先创建executor
this.config = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
} else {
// 传入的是 Provider
- this.actualProvider = modelOrProvider
+ this.actualProvider = adaptProvider({ provider: modelOrProvider })
// model为可选,某些操作(如fetchModels)不需要model
}
@@ -85,9 +120,17 @@ export default class ModernAiProvider {
throw new Error('Model is required for completions. Please use constructor with model parameter.')
}
- // 每次请求时重新生成配置以确保API key轮换生效
- this.config = providerToAiSdkConfig(this.actualProvider, this.model)
- logger.debug('Generated provider config for completions', this.config)
+ // Config is now set in constructor, ApiService handles key rotation before passing provider
+ if (!this.config) {
+ // If config wasn't set in constructor (when provider only), generate it now
+ this.config = providerToAiSdkConfig(this.actualProvider, this.model!)
+ }
+ logger.debug('Using provider config for completions', this.config)
+
+ // 检查 config 是否存在
+ if (!this.config) {
+ throw new Error('Provider config is undefined; cannot proceed with completions')
+ }
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
providerConfig.isImageGenerationEndpoint = true
}
@@ -148,7 +191,8 @@ export default class ModernAiProvider {
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise {
- if (config.isImageGenerationEndpoint) {
+ // ai-gateway不是image/generation 端点,所以就先不走legacy了
+ if (config.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds.gateway) {
// 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
if (!config.uiMessages) {
throw new Error('uiMessages is required for image generation endpoint')
@@ -314,10 +358,10 @@ export default class ModernAiProvider {
}
}
- /**
- * 使用现代化 AI SDK 的图像生成实现,支持流式输出
- * @deprecated 已改为使用 legacy 实现以支持图片编辑等高级功能
- */
+ // /**
+ // * 使用现代化 AI SDK 的图像生成实现,支持流式输出
+ // * @deprecated 已改为使用 legacy 实现以支持图片编辑等高级功能
+ // */
/*
private async modernImageGeneration(
model: ImageModel,
@@ -439,7 +483,12 @@ export default class ModernAiProvider {
// 代理其他方法到原有实现
public async models() {
- return this.legacyProvider.models()
+ if (this.actualProvider.id === SystemProviderIds.gateway) {
+ const gatewayModels = (await gateway.getAvailableModels()).models
+ return normalizeGatewayModels(this.actualProvider, gatewayModels)
+ }
+ const sdkModels = await this.legacyProvider.models()
+ return normalizeSdkModels(this.actualProvider, sdkModels)
}
public async getEmbeddingDimensions(model: Model): Promise {
@@ -450,8 +499,13 @@ export default class ModernAiProvider {
// 如果支持新的 AI SDK,使用现代化实现
if (isModernSdkSupported(this.actualProvider)) {
try {
+ // 确保 config 已定义
+ if (!this.config) {
+ throw new Error('Provider config is undefined; cannot proceed with generateImage')
+ }
+
// 确保本地provider已创建
- if (!this.localProvider) {
+ if (!this.localProvider && this.config) {
this.localProvider = await createAiSdkProvider(this.config)
if (!this.localProvider) {
throw new Error('Local provider not created')
diff --git a/src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts b/src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts
index bc416161c4..ee878f5861 100644
--- a/src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts
+++ b/src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts
@@ -1,6 +1,6 @@
import { loggerService } from '@logger'
-import { isNewApiProvider } from '@renderer/config/providers'
import type { Provider } from '@renderer/types'
+import { isNewApiProvider } from '@renderer/utils/provider'
import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
diff --git a/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts b/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts
index 767cad1294..92f24b4abe 100644
--- a/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts
+++ b/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts
@@ -1,14 +1,15 @@
import { loggerService } from '@logger'
import {
+ getModelSupportedVerbosity,
isFunctionCallingModel,
isNotSupportTemperatureAndTopP,
isOpenAIModel,
isSupportFlexServiceTierModel
} from '@renderer/config/models'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
-import { isSupportServiceTierProvider } from '@renderer/config/providers'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import { getAssistantSettings } from '@renderer/services/AssistantService'
+import type { RootState } from '@renderer/store'
import type {
Assistant,
GenerateImageParams,
@@ -18,7 +19,6 @@ import type {
MCPToolResponse,
MemoryItem,
Model,
- OpenAIVerbosity,
Provider,
ToolCallResponse,
WebSearchProviderResponse,
@@ -32,6 +32,7 @@ import {
OpenAIServiceTiers,
SystemProviderIds
} from '@renderer/types'
+import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import type { Message } from '@renderer/types/newMessage'
import type {
RequestOptions,
@@ -47,6 +48,7 @@ import type {
import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
+import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import { defaultTimeout } from '@shared/config/constant'
import { defaultAppHeaders } from '@shared/utils'
import { isEmpty } from 'lodash'
@@ -242,19 +244,22 @@ export abstract class BaseApiClient<
return serviceTierSetting
}
- protected getVerbosity(): OpenAIVerbosity {
+ protected getVerbosity(model?: Model): OpenAIVerbosity {
try {
- const state = window.store?.getState()
+ const state = window.store?.getState() as RootState
const verbosity = state?.settings?.openAI?.verbosity
- if (verbosity && ['low', 'medium', 'high'].includes(verbosity)) {
- return verbosity
+ // If model is provided, check if the verbosity is supported by the model
+ if (model) {
+ const supportedVerbosity = getModelSupportedVerbosity(model)
+ // Use user's verbosity if supported, otherwise use the first supported option
+ return supportedVerbosity.includes(verbosity) ? verbosity : supportedVerbosity[0]
}
+ return verbosity
} catch (error) {
- logger.warn('Failed to get verbosity from state:', error as Error)
+ logger.warn('Failed to get verbosity from state. Fallback to undefined.', error as Error)
+ return undefined
}
-
- return 'medium'
}
protected getTimeout(model: Model) {
@@ -398,6 +403,9 @@ export abstract class BaseApiClient<
if (!param.name?.trim()) {
return acc
}
+ // Parse JSON type parameters (Legacy API clients)
+ // Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148
+ // The UI stores JSON type params as strings, this function parses them before sending to API
if (param.type === 'json') {
const value = param.value as string
if (value === 'undefined') {
diff --git a/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts b/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts
index 03ec1e1ea2..991c436ca3 100644
--- a/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts
+++ b/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts
@@ -58,10 +58,27 @@ vi.mock('../aws/AwsBedrockAPIClient', () => ({
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
}))
+vi.mock('@renderer/services/AssistantService.ts', () => ({
+ getDefaultAssistant: () => {
+ return {
+ id: 'default',
+ name: 'default',
+ emoji: '😀',
+ prompt: '',
+ topics: [],
+ messages: [],
+ type: 'assistant',
+ regularPhrases: [],
+ settings: {}
+ }
+ }
+}))
+
// Mock the models config to prevent circular dependency issues
vi.mock('@renderer/config/models', () => ({
findTokenLimit: vi.fn(),
isReasoningModel: vi.fn(),
+ isOpenAILLMModel: vi.fn(),
SYSTEM_MODELS: {
silicon: [],
defaultModel: []
diff --git a/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts
index 27e659c1af..9c930a33ec 100644
--- a/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts
+++ b/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts
@@ -46,6 +46,7 @@ import type {
GeminiSdkRawOutput,
GeminiSdkToolCall
} from '@renderer/types/sdk'
+import { getTrailingApiVersion, withoutTrailingApiVersion } from '@renderer/utils'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import {
geminiFunctionCallToMcpTool,
@@ -163,6 +164,10 @@ export class GeminiAPIClient extends BaseApiClient<
return models
}
+ override getBaseURL(): string {
+ return withoutTrailingApiVersion(super.getBaseURL())
+ }
+
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
@@ -188,6 +193,13 @@ export class GeminiAPIClient extends BaseApiClient<
if (this.provider.isVertex) {
return 'v1'
}
+
+ // Extract trailing API version from the URL
+ const trailingVersion = getTrailingApiVersion(this.provider.apiHost || '')
+ if (trailingVersion) {
+ return trailingVersion
+ }
+
return 'v1beta'
}
diff --git a/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts
index 49a96a8f19..fb371d9ae5 100644
--- a/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts
+++ b/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts
@@ -1,7 +1,8 @@
import { GoogleGenAI } from '@google/genai'
import { loggerService } from '@logger'
-import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
+import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import type { Model, Provider, VertexProvider } from '@renderer/types'
+import { isVertexProvider } from '@renderer/utils/provider'
import { isEmpty } from 'lodash'
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts
index 8ff25e356d..cfc9087545 100644
--- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts
+++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts
@@ -10,12 +10,9 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
- getOpenAIWebSearchParams,
getThinkModelType,
- isClaudeReasoningModel,
isDeepSeekHybridInferenceModel,
isDoubaoThinkingAutoModel,
- isGeminiReasoningModel,
isGPT5SeriesModel,
isGrokReasoningModel,
isNotSupportSystemMessageModel,
@@ -39,12 +36,6 @@ import {
MODEL_SUPPORTED_REASONING_EFFORT,
ZHIPU_RESULT_TOKENS
} from '@renderer/config/models'
-import {
- isSupportArrayContentProvider,
- isSupportDeveloperRoleProvider,
- isSupportEnableThinkingProvider,
- isSupportStreamOptionsProvider
-} from '@renderer/config/providers'
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
import { estimateTextTokens } from '@renderer/services/TokenService'
@@ -88,6 +79,12 @@ import {
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
+import {
+ isSupportArrayContentProvider,
+ isSupportDeveloperRoleProvider,
+ isSupportEnableThinkingProvider,
+ isSupportStreamOptionsProvider
+} from '@renderer/utils/provider'
import { t } from 'i18next'
import type { GenericChunk } from '../../middleware/schemas'
@@ -651,7 +648,6 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
logger.warn('No user message. Some providers may not support.')
}
- // poe 需要通过用户消息传递 reasoningEffort
const reasoningEffort = this.getReasoningEffort(assistant, model)
const lastUserMsg = userMessages.findLast((m) => m.role === 'user')
@@ -662,22 +658,6 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
lastUserMsg.content = processPostsuffixQwen3Model(currentContent, qwenThinkModeEnabled)
}
- if (this.provider.id === SystemProviderIds.poe) {
- // 如果以后 poe 支持 reasoning_effort 参数了,可以删掉这部分
- let suffix = ''
- if (isGPT5SeriesModel(model) && reasoningEffort.reasoning_effort) {
- suffix = ` --reasoning_effort ${reasoningEffort.reasoning_effort}`
- } else if (isClaudeReasoningModel(model) && reasoningEffort.thinking?.budget_tokens) {
- suffix = ` --thinking_budget ${reasoningEffort.thinking.budget_tokens}`
- } else if (isGeminiReasoningModel(model) && reasoningEffort.extra_body?.google?.thinking_config) {
- suffix = ` --thinking_budget ${reasoningEffort.extra_body.google.thinking_config.thinking_budget}`
- }
- // FIXME: poe 不支持多个text part,上传文本文件的时候用的不是file part而是text part,因此会出问题
- // 临时解决方案是强制poe用string content,但是其实poe部分支持array
- if (typeof lastUserMsg.content === 'string') {
- lastUserMsg.content += suffix
- }
- }
}
// 4. 最终请求消息
@@ -733,9 +713,11 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
...modalities,
// groq 有不同的 service tier 配置,不符合 openai 接口类型
service_tier: this.getServiceTier(model) as OpenAIServiceTier,
+ // verbosity. getVerbosity ensures the returned value is valid.
+ verbosity: this.getVerbosity(model),
...this.getProviderSpecificParameters(assistant, model),
...reasoningEffort,
- ...getOpenAIWebSearchParams(model, enableWebSearch),
+ // ...getOpenAIWebSearchParams(model, enableWebSearch),
// OpenRouter usage tracking
...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}),
...extra_body,
diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts
index abd1793618..dc97e74a3c 100644
--- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts
+++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts
@@ -11,7 +11,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import type { SettingsState } from '@renderer/store/settings'
-import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
+import { type Assistant, type GenerateImageParams, type Model, type Provider } from '@renderer/types'
import type {
OpenAIResponseSdkMessageParam,
OpenAIResponseSdkParams,
@@ -25,7 +25,8 @@ import type {
OpenAISdkRawOutput,
ReasoningEffortOptionalParams
} from '@renderer/types/sdk'
-import { formatApiHost } from '@renderer/utils/api'
+import { formatApiHost, withoutTrailingSlash } from '@renderer/utils/api'
+import { isOllamaProvider } from '@renderer/utils/provider'
import { BaseApiClient } from '../BaseApiClient'
@@ -48,9 +49,8 @@ export abstract class OpenAIBaseClient<
}
// 仅适用于openai
- override getBaseURL(): string {
- const host = this.provider.apiHost
- return formatApiHost(host)
+ override getBaseURL(isSupportedAPIVerion: boolean = true): string {
+ return formatApiHost(this.provider.apiHost, isSupportedAPIVerion)
}
override async generateImage({
@@ -116,6 +116,34 @@ export abstract class OpenAIBaseClient<
}))
.filter(isSupportedModel)
}
+
+ if (isOllamaProvider(this.provider)) {
+ const baseUrl = withoutTrailingSlash(this.getBaseURL(false))
+ .replace(/\/v1$/, '')
+ .replace(/\/api$/, '')
+ const response = await fetch(`${baseUrl}/api/tags`, {
+ headers: {
+ Authorization: `Bearer ${this.apiKey}`,
+ ...this.defaultHeaders(),
+ ...this.provider.extra_headers
+ }
+ })
+
+ if (!response.ok) {
+ throw new Error(`Ollama server returned ${response.status} ${response.statusText}`)
+ }
+
+ const data = await response.json()
+ if (!data?.models || !Array.isArray(data.models)) {
+ throw new Error('Invalid response from Ollama API: missing models array')
+ }
+
+ return data.models.map((model) => ({
+ id: model.name,
+ object: 'model',
+ owned_by: 'ollama'
+ }))
+ }
const response = await sdk.models.list()
if (this.provider.id === 'together') {
// @ts-ignore key is not typed
@@ -144,6 +172,11 @@ export abstract class OpenAIBaseClient<
}
let apiKeyForSdkInstance = this.apiKey
+ let baseURLForSdkInstance = this.getBaseURL()
+ let headersForSdkInstance = {
+ ...this.defaultHeaders(),
+ ...this.provider.extra_headers
+ }
if (this.provider.id === 'copilot') {
const defaultHeaders = store.getState().copilot.defaultHeaders
@@ -151,6 +184,11 @@ export abstract class OpenAIBaseClient<
// this.provider.apiKey不允许修改
// this.provider.apiKey = token
apiKeyForSdkInstance = token
+ baseURLForSdkInstance = this.getBaseURL(false)
+ headersForSdkInstance = {
+ ...headersForSdkInstance,
+ ...COPILOT_DEFAULT_HEADERS
+ }
}
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
@@ -164,12 +202,8 @@ export abstract class OpenAIBaseClient<
this.sdkInstance = new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: apiKeyForSdkInstance,
- baseURL: this.getBaseURL(),
- defaultHeaders: {
- ...this.defaultHeaders(),
- ...this.provider.extra_headers,
- ...(this.provider.id === 'copilot' ? COPILOT_DEFAULT_HEADERS : {})
- }
+ baseURL: baseURLForSdkInstance,
+ defaultHeaders: headersForSdkInstance
}) as TSdkInstance
}
return this.sdkInstance
diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts
index b9131be661..8356826e26 100644
--- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts
+++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts
@@ -12,7 +12,6 @@ import {
isSupportVerbosityModel,
isVisionModel
} from '@renderer/config/models'
-import { isSupportDeveloperRoleProvider } from '@renderer/config/providers'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type {
FileMetadata,
@@ -43,6 +42,7 @@ import {
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
+import { isSupportDeveloperRoleProvider } from '@renderer/utils/provider'
import { MB } from '@shared/config/constant'
import { t } from 'i18next'
import { isEmpty } from 'lodash'
@@ -90,7 +90,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
if (isOpenAILLMModel(model) && !isOpenAIChatCompletionOnlyModel(model)) {
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
this.provider = { ...this.provider, apiHost: this.formatApiHost() }
- if (this.provider.apiVersion === 'preview') {
+ if (this.provider.apiVersion === 'preview' || this.provider.apiVersion === 'v1') {
return this
} else {
return this.client
@@ -297,7 +297,31 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput {
const content: OpenAI.Responses.ResponseInput = []
- content.push(...response.output)
+ response.output.forEach((item) => {
+ if (item.type !== 'apply_patch_call' && item.type !== 'apply_patch_call_output') {
+ content.push(item)
+ } else if (item.type === 'apply_patch_call') {
+ if (item.operation !== undefined) {
+ const applyPatchToolCall: OpenAI.Responses.ResponseInputItem.ApplyPatchCall = {
+ ...item,
+ operation: item.operation
+ }
+ content.push(applyPatchToolCall)
+ } else {
+ logger.warn('Undefined tool call operation for ApplyPatchToolCall.')
+ }
+ } else if (item.type === 'apply_patch_call_output') {
+ if (item.output !== undefined) {
+ const applyPatchToolCallOutput: OpenAI.Responses.ResponseInputItem.ApplyPatchCallOutput = {
+ ...item,
+ output: item.output === null ? undefined : item.output
+ }
+ content.push(applyPatchToolCallOutput)
+ } else {
+ logger.warn('Undefined tool call operation for ApplyPatchToolCall.')
+ }
+ }
+ })
return content
}
@@ -496,7 +520,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
...(isSupportVerbosityModel(model)
? {
text: {
- verbosity: this.getVerbosity()
+ verbosity: this.getVerbosity(model)
}
}
: {}),
diff --git a/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts b/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts
index 179bb54a1e..02ac6de091 100644
--- a/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts
+++ b/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts
@@ -3,6 +3,7 @@ import { loggerService } from '@logger'
import { isSupportedModel } from '@renderer/config/models'
import type { Provider } from '@renderer/types'
import { objectKeys } from '@renderer/types'
+import { formatApiHost, withoutTrailingApiVersion } from '@renderer/utils'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
@@ -16,11 +17,8 @@ export class OVMSClient extends OpenAIAPIClient {
override async listModels(): Promise {
try {
const sdk = await this.getSdkInstance()
-
- const chatModelsResponse = await sdk.request({
- method: 'get',
- path: '../v1/config'
- })
+ const url = formatApiHost(withoutTrailingApiVersion(this.getBaseURL()), true, 'v1')
+ const chatModelsResponse = await sdk.withOptions({ baseURL: url }).get('/config')
logger.debug(`Chat models response: ${JSON.stringify(chatModelsResponse)}`)
// Parse the config response to extract model information
diff --git a/src/renderer/src/aiCore/legacy/middleware/common/ErrorHandlerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/ErrorHandlerMiddleware.ts
index 7d6a7f631a..c93e42fbb2 100644
--- a/src/renderer/src/aiCore/legacy/middleware/common/ErrorHandlerMiddleware.ts
+++ b/src/renderer/src/aiCore/legacy/middleware/common/ErrorHandlerMiddleware.ts
@@ -1,6 +1,7 @@
import { loggerService } from '@logger'
import { isZhipuModel } from '@renderer/config/models'
import { getStoreProviders } from '@renderer/hooks/useStore'
+import { getDefaultModel } from '@renderer/services/AssistantService'
import type { Chunk } from '@renderer/types/chunk'
import type { CompletionsParams, CompletionsResult } from '../schemas'
@@ -66,7 +67,7 @@ export const ErrorHandlerMiddleware =
}
function handleError(error: any, params: CompletionsParams): any {
- if (isZhipuModel(params.assistant.model) && error.status && !params.enableGenerateImage) {
+ if (isZhipuModel(params.assistant.model || getDefaultModel()) && error.status && !params.enableGenerateImage) {
return handleZhipuError(error)
}
diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
index 3f14917cdd..10a4d59384 100644
--- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
+++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
@@ -1,18 +1,21 @@
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { loggerService } from '@logger'
-import { isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
-import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
+import { isGemini3Model, isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
import type { MCPTool } from '@renderer/types'
-import { type Assistant, type Message, type Model, type Provider } from '@renderer/types'
+import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
+import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
import { isEmpty } from 'lodash'
+import { getAiSdkProviderId } from '../provider/factory'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
import { noThinkMiddleware } from './noThinkMiddleware'
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
+import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
+import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
@@ -217,6 +220,14 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
middleware: noThinkMiddleware()
})
}
+
+ if (config.provider.id === SystemProviderIds.openrouter && config.enableReasoning) {
+ builder.add({
+ name: 'openrouter-reasoning-redaction',
+ middleware: openrouterReasoningMiddleware()
+ })
+ logger.debug('Added OpenRouter reasoning redaction middleware')
+ }
}
/**
@@ -229,6 +240,7 @@ function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: Ai
// Use /think or /no_think suffix to control thinking mode
if (
config.provider &&
+ !isOllamaProvider(config.provider) &&
isSupportedThinkingTokenQwenModel(config.model) &&
!isSupportEnableThinkingProvider(config.provider)
) {
@@ -248,6 +260,15 @@ function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: Ai
middleware: openrouterGenerateImageMiddleware()
})
}
+
+ if (isGemini3Model(config.model)) {
+ const aiSdkId = getAiSdkProviderId(config.provider)
+ builder.add({
+ name: 'skip-gemini3-thought-signature',
+ middleware: skipGeminiThoughtSignatureMiddleware(aiSdkId)
+ })
+ logger.debug('Added skip Gemini3 thought signature middleware')
+ }
}
/**
diff --git a/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts b/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts
new file mode 100644
index 0000000000..9ef3df61e9
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts
@@ -0,0 +1,50 @@
+import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
+import type { LanguageModelMiddleware } from 'ai'
+
+/**
+ * https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
+ *
+ * @returns LanguageModelMiddleware - a middleware filter redacted block
+ */
+export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
+ const REDACTED_BLOCK = '[REDACTED]'
+ return {
+ middlewareVersion: 'v2',
+ wrapGenerate: async ({ doGenerate }) => {
+ const { content, ...rest } = await doGenerate()
+ const modifiedContent = content.map((part) => {
+ if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
+ return {
+ ...part,
+ text: part.text.replace(REDACTED_BLOCK, '')
+ }
+ }
+ return part
+ })
+ return { content: modifiedContent, ...rest }
+ },
+ wrapStream: async ({ doStream }) => {
+ const { stream, ...rest } = await doStream()
+ return {
+ stream: stream.pipeThrough(
+ new TransformStream({
+ transform(
+ chunk: LanguageModelV2StreamPart,
+ controller: TransformStreamDefaultController
+ ) {
+ if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
+ controller.enqueue({
+ ...chunk,
+ delta: chunk.delta.replace(REDACTED_BLOCK, '')
+ })
+ } else {
+ controller.enqueue(chunk)
+ }
+ }
+ })
+ ),
+ ...rest
+ }
+ }
+ }
+}
diff --git a/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts b/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts
new file mode 100644
index 0000000000..da318ea60d
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts
@@ -0,0 +1,36 @@
+import type { LanguageModelMiddleware } from 'ai'
+
+/**
+ * skip Gemini Thought Signature Middleware
+ * 由于多模型客户端请求的复杂性(可以中途切换其他模型),这里选择通过中间件方式添加跳过所有 Gemini3 思考签名
+ * Due to the complexity of multi-model client requests (which can switch to other models mid-process),
+ * it was decided to add a skip for all Gemini3 thinking signatures via middleware.
+ * @param aiSdkId AI SDK Provider ID
+ * @returns LanguageModelMiddleware
+ */
+export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
+ const MAGIC_STRING = 'skip_thought_signature_validator'
+ return {
+ middlewareVersion: 'v2',
+
+ transformParams: async ({ params }) => {
+ const transformedParams = { ...params }
+ // Process messages in prompt
+ if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
+ transformedParams.prompt = transformedParams.prompt.map((message) => {
+ if (typeof message.content !== 'string') {
+ for (const part of message.content) {
+ const googleOptions = part?.providerOptions?.[aiSdkId]
+ if (googleOptions?.thoughtSignature) {
+ googleOptions.thoughtSignature = MAGIC_STRING
+ }
+ }
+ }
+ return message
+ })
+ }
+
+ return transformedParams
+ }
+ }
+}
diff --git a/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts b/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts
new file mode 100644
index 0000000000..2433192cd0
--- /dev/null
+++ b/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts
@@ -0,0 +1,239 @@
+import type { Message, Model } from '@renderer/types'
+import type { FileMetadata } from '@renderer/types/file'
+import { FileTypes } from '@renderer/types/file'
+import {
+ AssistantMessageStatus,
+ type FileMessageBlock,
+ type ImageMessageBlock,
+ MessageBlockStatus,
+ MessageBlockType,
+ type ThinkingMessageBlock,
+ UserMessageStatus
+} from '@renderer/types/newMessage'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const { convertFileBlockToFilePartMock, convertFileBlockToTextPartMock } = vi.hoisted(() => ({
+ convertFileBlockToFilePartMock: vi.fn(),
+ convertFileBlockToTextPartMock: vi.fn()
+}))
+
+vi.mock('../fileProcessor', () => ({
+ convertFileBlockToFilePart: convertFileBlockToFilePartMock,
+ convertFileBlockToTextPart: convertFileBlockToTextPartMock
+}))
+
+const visionModelIds = new Set(['gpt-4o-mini', 'qwen-image-edit'])
+const imageEnhancementModelIds = new Set(['qwen-image-edit'])
+
+vi.mock('@renderer/config/models', () => ({
+ isVisionModel: (model: Model) => visionModelIds.has(model.id),
+ isImageEnhancementModel: (model: Model) => imageEnhancementModelIds.has(model.id)
+}))
+
+type MockableMessage = Message & {
+ __mockContent?: string
+ __mockFileBlocks?: FileMessageBlock[]
+ __mockImageBlocks?: ImageMessageBlock[]
+ __mockThinkingBlocks?: ThinkingMessageBlock[]
+}
+
+vi.mock('@renderer/utils/messageUtils/find', () => ({
+ getMainTextContent: (message: Message) => (message as MockableMessage).__mockContent ?? '',
+ findFileBlocks: (message: Message) => (message as MockableMessage).__mockFileBlocks ?? [],
+ findImageBlocks: (message: Message) => (message as MockableMessage).__mockImageBlocks ?? [],
+ findThinkingBlocks: (message: Message) => (message as MockableMessage).__mockThinkingBlocks ?? []
+}))
+
+import { convertMessagesToSdkMessages, convertMessageToSdkParam } from '../messageConverter'
+
+let messageCounter = 0
+let blockCounter = 0
+
+const createModel = (overrides: Partial = {}): Model => ({
+ id: 'gpt-4o-mini',
+ name: 'GPT-4o mini',
+ provider: 'openai',
+ group: 'openai',
+ ...overrides
+})
+
+const createMessage = (role: Message['role']): MockableMessage =>
+ ({
+ id: `message-${++messageCounter}`,
+ role,
+ assistantId: 'assistant-1',
+ topicId: 'topic-1',
+ createdAt: new Date(2024, 0, 1, 0, 0, messageCounter).toISOString(),
+ status: role === 'assistant' ? AssistantMessageStatus.SUCCESS : UserMessageStatus.SUCCESS,
+ blocks: []
+ }) as MockableMessage
+
+const createFileBlock = (
+ messageId: string,
+ overrides: Partial> & { file?: Partial } = {}
+): FileMessageBlock => {
+ const { file, ...blockOverrides } = overrides
+ const timestamp = new Date(2024, 0, 1, 0, 0, ++blockCounter).toISOString()
+ return {
+ id: blockOverrides.id ?? `file-block-${blockCounter}`,
+ messageId,
+ type: MessageBlockType.FILE,
+ createdAt: blockOverrides.createdAt ?? timestamp,
+ status: blockOverrides.status ?? MessageBlockStatus.SUCCESS,
+ file: {
+ id: file?.id ?? `file-${blockCounter}`,
+ name: file?.name ?? 'document.txt',
+ origin_name: file?.origin_name ?? 'document.txt',
+ path: file?.path ?? '/tmp/document.txt',
+ size: file?.size ?? 1024,
+ ext: file?.ext ?? '.txt',
+ type: file?.type ?? FileTypes.TEXT,
+ created_at: file?.created_at ?? timestamp,
+ count: file?.count ?? 1,
+ ...file
+ },
+ ...blockOverrides
+ }
+}
+
+const createImageBlock = (
+ messageId: string,
+ overrides: Partial> = {}
+): ImageMessageBlock => ({
+ id: overrides.id ?? `image-block-${++blockCounter}`,
+ messageId,
+ type: MessageBlockType.IMAGE,
+ createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
+ status: overrides.status ?? MessageBlockStatus.SUCCESS,
+ url: overrides.url ?? 'https://example.com/image.png',
+ ...overrides
+})
+
+describe('messageConverter', () => {
+ beforeEach(() => {
+ convertFileBlockToFilePartMock.mockReset()
+ convertFileBlockToTextPartMock.mockReset()
+ convertFileBlockToFilePartMock.mockResolvedValue(null)
+ convertFileBlockToTextPartMock.mockResolvedValue(null)
+ messageCounter = 0
+ blockCounter = 0
+ })
+
+ describe('convertMessageToSdkParam', () => {
+ it('includes text and image parts for user messages on vision models', async () => {
+ const model = createModel()
+ const message = createMessage('user')
+ message.__mockContent = 'Describe this picture'
+ message.__mockImageBlocks = [createImageBlock(message.id, { url: 'https://example.com/cat.png' })]
+
+ const result = await convertMessageToSdkParam(message, true, model)
+
+ expect(result).toEqual({
+ role: 'user',
+ content: [
+ { type: 'text', text: 'Describe this picture' },
+ { type: 'image', image: 'https://example.com/cat.png' }
+ ]
+ })
+ })
+
+ it('returns file instructions as a system message when native uploads succeed', async () => {
+ const model = createModel()
+ const message = createMessage('user')
+ message.__mockContent = 'Summarize the PDF'
+ message.__mockFileBlocks = [createFileBlock(message.id)]
+ convertFileBlockToFilePartMock.mockResolvedValueOnce({
+ type: 'file',
+ filename: 'document.pdf',
+ mediaType: 'application/pdf',
+ data: 'fileid://remote-file'
+ })
+
+ const result = await convertMessageToSdkParam(message, false, model)
+
+ expect(result).toEqual([
+ {
+ role: 'system',
+ content: 'fileid://remote-file'
+ },
+ {
+ role: 'user',
+ content: [{ type: 'text', text: 'Summarize the PDF' }]
+ }
+ ])
+ })
+ })
+
+ describe('convertMessagesToSdkMessages', () => {
+ it('appends assistant images to the final user message for image enhancement models', async () => {
+ const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
+ const initialUser = createMessage('user')
+ initialUser.__mockContent = 'Start editing'
+
+ const assistant = createMessage('assistant')
+ assistant.__mockContent = 'Here is the current preview'
+ assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })]
+
+ const finalUser = createMessage('user')
+ finalUser.__mockContent = 'Increase the brightness'
+
+ const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model)
+
+ expect(result).toEqual([
+ {
+ role: 'user',
+ content: [{ type: 'text', text: 'Start editing' }]
+ },
+ {
+ role: 'assistant',
+ content: [{ type: 'text', text: 'Here is the current preview' }]
+ },
+ {
+ role: 'user',
+ content: [
+ { type: 'text', text: 'Increase the brightness' },
+ { type: 'image', image: 'https://example.com/preview.png' }
+ ]
+ }
+ ])
+ })
+
+ it('preserves preceding system instructions when building enhancement payloads', async () => {
+ const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
+ const fileUser = createMessage('user')
+ fileUser.__mockContent = 'Use this document as inspiration'
+ fileUser.__mockFileBlocks = [createFileBlock(fileUser.id, { file: { ext: '.pdf', type: FileTypes.DOCUMENT } })]
+ convertFileBlockToFilePartMock.mockResolvedValueOnce({
+ type: 'file',
+ filename: 'reference.pdf',
+ mediaType: 'application/pdf',
+ data: 'fileid://reference'
+ })
+
+ const assistant = createMessage('assistant')
+ assistant.__mockContent = 'Generated previews ready'
+ assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/reference.png' })]
+
+ const finalUser = createMessage('user')
+ finalUser.__mockContent = 'Apply the edits'
+
+ const result = await convertMessagesToSdkMessages([fileUser, assistant, finalUser], model)
+
+ expect(result).toEqual([
+ { role: 'system', content: 'fileid://reference' },
+ { role: 'user', content: [{ type: 'text', text: 'Use this document as inspiration' }] },
+ {
+ role: 'assistant',
+ content: [{ type: 'text', text: 'Generated previews ready' }]
+ },
+ {
+ role: 'user',
+ content: [
+ { type: 'text', text: 'Apply the edits' },
+ { type: 'image', image: 'https://example.com/reference.png' }
+ ]
+ }
+ ])
+ })
+ })
+})
diff --git a/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts b/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts
new file mode 100644
index 0000000000..70b4ac84b7
--- /dev/null
+++ b/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts
@@ -0,0 +1,218 @@
+import type { Assistant, AssistantSettings, Model, Topic } from '@renderer/types'
+import { TopicType } from '@renderer/types'
+import { defaultTimeout } from '@shared/config/constant'
+import { describe, expect, it, vi } from 'vitest'
+
+import { getTemperature, getTimeout, getTopP } from '../modelParameters'
+
+vi.mock('@renderer/services/AssistantService', () => ({
+ getAssistantSettings: (assistant: Assistant): AssistantSettings => ({
+ contextCount: assistant.settings?.contextCount ?? 4096,
+ temperature: assistant.settings?.temperature ?? 0.7,
+ enableTemperature: assistant.settings?.enableTemperature ?? true,
+ topP: assistant.settings?.topP ?? 1,
+ enableTopP: assistant.settings?.enableTopP ?? false,
+ enableMaxTokens: assistant.settings?.enableMaxTokens ?? false,
+ maxTokens: assistant.settings?.maxTokens,
+ streamOutput: assistant.settings?.streamOutput ?? true,
+ toolUseMode: assistant.settings?.toolUseMode ?? 'prompt',
+ defaultModel: assistant.defaultModel,
+ customParameters: assistant.settings?.customParameters ?? [],
+ reasoning_effort: assistant.settings?.reasoning_effort,
+ reasoning_effort_cache: assistant.settings?.reasoning_effort_cache,
+ qwenThinkMode: assistant.settings?.qwenThinkMode
+ })
+}))
+
+vi.mock('@renderer/hooks/useSettings', () => ({
+ getStoreSetting: vi.fn(),
+ useSettings: vi.fn(() => ({})),
+ useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left', isLeftNavbar: true, isTopNavbar: false }))
+}))
+
+vi.mock('@renderer/hooks/useStore', () => ({
+ getStoreProviders: vi.fn(() => [])
+}))
+
+vi.mock('@renderer/store/settings', () => ({
+ default: (state = { settings: {} }) => state
+}))
+
+vi.mock('@renderer/store/assistants', () => ({
+ default: (state = { assistants: [] }) => state
+}))
+
+const createTopic = (assistantId: string): Topic => ({
+ id: `topic-${assistantId}`,
+ assistantId,
+ name: 'topic',
+ createdAt: new Date().toISOString(),
+ updatedAt: new Date().toISOString(),
+ messages: [],
+ type: TopicType.Chat
+})
+
+const createAssistant = (settings: Assistant['settings'] = {}): Assistant => {
+ const assistantId = 'assistant-1'
+ return {
+ id: assistantId,
+ name: 'Test Assistant',
+ prompt: 'prompt',
+ topics: [createTopic(assistantId)],
+ type: 'assistant',
+ settings
+ }
+}
+
+const createModel = (overrides: Partial = {}): Model => ({
+ id: 'gpt-4o',
+ provider: 'openai',
+ name: 'GPT-4o',
+ group: 'openai',
+ ...overrides
+})
+
+describe('modelParameters', () => {
+ describe('getTemperature', () => {
+ it('returns undefined when reasoning effort is enabled for Claude models', () => {
+ const assistant = createAssistant({ reasoning_effort: 'medium' })
+ const model = createModel({ id: 'claude-opus-4', name: 'Claude Opus 4', provider: 'anthropic', group: 'claude' })
+
+ expect(getTemperature(assistant, model)).toBeUndefined()
+ })
+
+ it('returns undefined for models without temperature/topP support', () => {
+ const assistant = createAssistant({ enableTemperature: true })
+ const model = createModel({ id: 'qwen-mt-large', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
+
+ expect(getTemperature(assistant, model)).toBeUndefined()
+ })
+
+ it('returns undefined for Claude 4.5 reasoning models when only TopP is enabled', () => {
+ const assistant = createAssistant({ enableTopP: true, enableTemperature: false })
+ const model = createModel({
+ id: 'claude-sonnet-4.5',
+ name: 'Claude Sonnet 4.5',
+ provider: 'anthropic',
+ group: 'claude'
+ })
+
+ expect(getTemperature(assistant, model)).toBeUndefined()
+ })
+
+ it('returns configured temperature when enabled', () => {
+ const assistant = createAssistant({ enableTemperature: true, temperature: 0.42 })
+ const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
+
+ expect(getTemperature(assistant, model)).toBe(0.42)
+ })
+
+ it('returns undefined when temperature is disabled', () => {
+ const assistant = createAssistant({ enableTemperature: false, temperature: 0.9 })
+ const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
+
+ expect(getTemperature(assistant, model)).toBeUndefined()
+ })
+
+ it('clamps temperature to max 1.0 for Zhipu models', () => {
+ const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
+ const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
+
+ expect(getTemperature(assistant, model)).toBe(1.0)
+ })
+
+ it('clamps temperature to max 1.0 for Anthropic models', () => {
+ const assistant = createAssistant({ enableTemperature: true, temperature: 1.5 })
+ const model = createModel({
+ id: 'claude-sonnet-3.5',
+ name: 'Claude 3.5 Sonnet',
+ provider: 'anthropic',
+ group: 'claude'
+ })
+
+ expect(getTemperature(assistant, model)).toBe(1.0)
+ })
+
+ it('clamps temperature to max 1.0 for Moonshot models', () => {
+ const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
+ const model = createModel({
+ id: 'moonshot-v1-8k',
+ name: 'Moonshot v1 8k',
+ provider: 'moonshot',
+ group: 'moonshot'
+ })
+
+ expect(getTemperature(assistant, model)).toBe(1.0)
+ })
+
+ it('does not clamp temperature for OpenAI models', () => {
+ const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
+ const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
+
+ expect(getTemperature(assistant, model)).toBe(2.0)
+ })
+
+ it('does not clamp temperature when it is already within limits', () => {
+ const assistant = createAssistant({ enableTemperature: true, temperature: 0.8 })
+ const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
+
+ expect(getTemperature(assistant, model)).toBe(0.8)
+ })
+ })
+
+ describe('getTopP', () => {
+ it('returns undefined when reasoning effort is enabled for Claude models', () => {
+ const assistant = createAssistant({ reasoning_effort: 'high' })
+ const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
+
+ expect(getTopP(assistant, model)).toBeUndefined()
+ })
+
+ it('returns undefined for models without TopP support', () => {
+ const assistant = createAssistant({ enableTopP: true })
+ const model = createModel({ id: 'qwen-mt-small', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
+
+ expect(getTopP(assistant, model)).toBeUndefined()
+ })
+
+ it('returns undefined for Claude 4.5 reasoning models when temperature is enabled', () => {
+ const assistant = createAssistant({ enableTemperature: true })
+ const model = createModel({
+ id: 'claude-opus-4.5',
+ name: 'Claude Opus 4.5',
+ provider: 'anthropic',
+ group: 'claude'
+ })
+
+ expect(getTopP(assistant, model)).toBeUndefined()
+ })
+
+ it('returns configured TopP when enabled', () => {
+ const assistant = createAssistant({ enableTopP: true, topP: 0.73 })
+ const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
+
+ expect(getTopP(assistant, model)).toBe(0.73)
+ })
+
+ it('returns undefined when TopP is disabled', () => {
+ const assistant = createAssistant({ enableTopP: false, topP: 0.5 })
+ const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
+
+ expect(getTopP(assistant, model)).toBeUndefined()
+ })
+ })
+
+ describe('getTimeout', () => {
+ it('uses an extended timeout for flex service tier models', () => {
+ const model = createModel({ id: 'o3-pro', provider: 'openai', group: 'openai' })
+
+ expect(getTimeout(model)).toBe(15 * 1000 * 60)
+ })
+
+ it('falls back to the default timeout otherwise', () => {
+ const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
+
+ expect(getTimeout(model)).toBe(defaultTimeout)
+ })
+ })
+})
diff --git a/src/renderer/src/aiCore/prepareParams/header.ts b/src/renderer/src/aiCore/prepareParams/header.ts
new file mode 100644
index 0000000000..480f13314e
--- /dev/null
+++ b/src/renderer/src/aiCore/prepareParams/header.ts
@@ -0,0 +1,33 @@
+import { isClaude4SeriesModel, isClaude45ReasoningModel } from '@renderer/config/models'
+import { getProviderByModel } from '@renderer/services/AssistantService'
+import type { Assistant, Model } from '@renderer/types'
+import { isToolUseModeFunction } from '@renderer/utils/assistant'
+import { isAwsBedrockProvider, isVertexProvider } from '@renderer/utils/provider'
+
+// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
+const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
+// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
+// const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
+// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
+const WEBSEARCH_HEADER = 'web-search-2025-03-05'
+
+export function addAnthropicHeaders(assistant: Assistant, model: Model): string[] {
+ const anthropicHeaders: string[] = []
+ const provider = getProviderByModel(model)
+ if (
+ isClaude45ReasoningModel(model) &&
+ isToolUseModeFunction(assistant) &&
+ !(isVertexProvider(provider) || isAwsBedrockProvider(provider))
+ ) {
+ anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
+ }
+ if (isClaude4SeriesModel(model)) {
+ if (isVertexProvider(provider) && assistant.enableWebSearch) {
+ anthropicHeaders.push(WEBSEARCH_HEADER)
+ }
+ // We may add it by user preference in assistant.settings instead of always adding it.
+ // See #11540, #11397
+ // anthropicHeaders.push(CONTEXT_100M_HEADER)
+ }
+ return anthropicHeaders
+}
diff --git a/src/renderer/src/aiCore/prepareParams/messageConverter.ts b/src/renderer/src/aiCore/prepareParams/messageConverter.ts
index 72f387d9a4..b0c432ef85 100644
--- a/src/renderer/src/aiCore/prepareParams/messageConverter.ts
+++ b/src/renderer/src/aiCore/prepareParams/messageConverter.ts
@@ -194,20 +194,20 @@ async function convertMessageToAssistantModelMessage(
* This function processes messages and transforms them into the format required by the SDK.
* It handles special cases for vision models and image enhancement models.
*
- * @param messages - Array of messages to convert. Must contain at least 2 messages when using image enhancement models.
+ * @param messages - Array of messages to convert. Must contain at least 3 messages when using image enhancement models for special handling.
* @param model - The model configuration that determines conversion behavior
*
* @returns A promise that resolves to an array of SDK-compatible model messages
*
* @remarks
- * For image enhancement models with 2+ messages:
- * - Expects the second-to-last message (index length-2) to be an assistant message containing image blocks
- * - Expects the last message (index length-1) to be a user message
- * - Extracts images from the assistant message and appends them to the user message content
- * - Returns only the last two processed messages [assistantSdkMessage, userSdkMessage]
+ * For image enhancement models with 3+ messages:
+ * - Examines the last 2 messages to find an assistant message containing image blocks
+ * - If found, extracts images from the assistant message and appends them to the last user message content
+ * - Returns all converted messages (not just the last two) with the images merged into the user message
+ * - Typical pattern: [system?, assistant(image), user] -> [system?, assistant, user(image)]
*
* For other models:
- * - Returns all converted messages in order
+ * - Returns all converted messages in order without special image handling
*
* The function automatically detects vision model capabilities and adjusts conversion accordingly.
*/
@@ -220,29 +220,25 @@ export async function convertMessagesToSdkMessages(messages: Message[], model: M
sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage]))
}
// Special handling for image enhancement models
- // Only keep the last two messages and merge images into the user message
- // [system?, user, assistant, user]
+ // Only merge images into the user message
+ // [system?, assistant(image), user] -> [system?, assistant, user(image)]
if (isImageEnhancementModel(model) && messages.length >= 3) {
const needUpdatedMessages = messages.slice(-2)
- const needUpdatedSdkMessages = sdkMessages.slice(-2)
- const assistantMessage = needUpdatedMessages.filter((m) => m.role === 'assistant')[0]
- const assistantSdkMessage = needUpdatedSdkMessages.filter((m) => m.role === 'assistant')[0]
- const userSdkMessage = needUpdatedSdkMessages.filter((m) => m.role === 'user')[0]
- const systemSdkMessages = sdkMessages.filter((m) => m.role === 'system')
- const imageBlocks = findImageBlocks(assistantMessage)
- const imageParts = await convertImageBlockToImagePart(imageBlocks)
- const parts: Array = []
- if (typeof userSdkMessage.content === 'string') {
- parts.push({ type: 'text', text: userSdkMessage.content })
- parts.push(...imageParts)
- userSdkMessage.content = parts
- } else {
- userSdkMessage.content.push(...imageParts)
+ const assistantMessage = needUpdatedMessages.find((m) => m.role === 'assistant')
+ const userSdkMessage = sdkMessages[sdkMessages.length - 1]
+
+ if (assistantMessage && userSdkMessage?.role === 'user') {
+ const imageBlocks = findImageBlocks(assistantMessage)
+ const imageParts = await convertImageBlockToImagePart(imageBlocks)
+
+ if (imageParts.length > 0) {
+ if (typeof userSdkMessage.content === 'string') {
+ userSdkMessage.content = [{ type: 'text', text: userSdkMessage.content }, ...imageParts]
+ } else if (Array.isArray(userSdkMessage.content)) {
+ userSdkMessage.content.push(...imageParts)
+ }
+ }
}
- if (systemSdkMessages.length > 0) {
- return [systemSdkMessages[0], assistantSdkMessage, userSdkMessage]
- }
- return [assistantSdkMessage, userSdkMessage]
}
return sdkMessages
diff --git a/src/renderer/src/aiCore/prepareParams/modelCapabilities.ts b/src/renderer/src/aiCore/prepareParams/modelCapabilities.ts
index b6e4b25843..4a3c3f4bbf 100644
--- a/src/renderer/src/aiCore/prepareParams/modelCapabilities.ts
+++ b/src/renderer/src/aiCore/prepareParams/modelCapabilities.ts
@@ -85,19 +85,6 @@ export function supportsLargeFileUpload(model: Model): boolean {
})
}
-/**
- * 检查模型是否支持TopP
- */
-export function supportsTopP(model: Model): boolean {
- const provider = getProviderByModel(model)
-
- if (provider?.type === 'anthropic' || model?.endpoint_type === 'anthropic') {
- return false
- }
-
- return true
-}
-
/**
* 获取提供商特定的文件大小限制
*/
diff --git a/src/renderer/src/aiCore/prepareParams/modelParameters.ts b/src/renderer/src/aiCore/prepareParams/modelParameters.ts
index ed3f4fa210..8a1d53a754 100644
--- a/src/renderer/src/aiCore/prepareParams/modelParameters.ts
+++ b/src/renderer/src/aiCore/prepareParams/modelParameters.ts
@@ -6,14 +6,23 @@
import {
isClaude45ReasoningModel,
isClaudeReasoningModel,
+ isMaxTemperatureOneModel,
isNotSupportTemperatureAndTopP,
- isSupportedFlexServiceTier
+ isSupportedFlexServiceTier,
+ isSupportedThinkingTokenClaudeModel
} from '@renderer/config/models'
-import { getAssistantSettings } from '@renderer/services/AssistantService'
+import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, Model } from '@renderer/types'
import { defaultTimeout } from '@shared/config/constant'
+import { getAnthropicThinkingBudget } from '../utils/reasoning'
+
/**
+ * Claude 4.5 推理模型:
+ * - 只启用 temperature → 使用 temperature
+ * - 只启用 top_p → 使用 top_p
+ * - 同时启用 → temperature 生效,top_p 被忽略
+ * - 都不启用 → 都不使用
* 获取温度参数
*/
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
@@ -27,7 +36,11 @@ export function getTemperature(assistant: Assistant, model: Model): number | und
return undefined
}
const assistantSettings = getAssistantSettings(assistant)
- return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined
+ let temperature = assistantSettings?.temperature
+ if (temperature && isMaxTemperatureOneModel(model)) {
+ temperature = Math.min(1, temperature)
+ }
+ return assistantSettings?.enableTemperature ? temperature : undefined
}
/**
@@ -56,3 +69,26 @@ export function getTimeout(model: Model): number {
}
return defaultTimeout
}
+
+export function getMaxTokens(assistant: Assistant, model: Model): number | undefined {
+ // NOTE: ai-sdk会把maxToken和budgetToken加起来
+ const assistantSettings = getAssistantSettings(assistant)
+ const enabledMaxTokens = assistantSettings.enableMaxTokens ?? false
+ let maxTokens = assistantSettings.maxTokens
+
+ // If user hasn't enabled enableMaxTokens, return undefined to let the API use its default value.
+ // Note: Anthropic API requires max_tokens, but that's handled by the Anthropic client with a fallback.
+ if (!enabledMaxTokens || maxTokens === undefined) {
+ return undefined
+ }
+
+ const provider = getProviderByModel(model)
+ if (isSupportedThinkingTokenClaudeModel(model) && ['anthropic', 'aws-bedrock'].includes(provider.type)) {
+ const { reasoning_effort: reasoningEffort } = assistantSettings
+ const budget = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
+ if (budget) {
+ maxTokens -= budget
+ }
+ }
+ return maxTokens
+}
diff --git a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts
index 397c481cf3..cba7fcdb10 100644
--- a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts
+++ b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts
@@ -4,43 +4,69 @@
*/
import { anthropic } from '@ai-sdk/anthropic'
+import { azure } from '@ai-sdk/azure'
import { google } from '@ai-sdk/google'
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
import { vertex } from '@ai-sdk/google-vertex/edge'
-import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
+import { combineHeaders } from '@ai-sdk/provider-utils'
+import type { AnthropicSearchConfig, WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas'
+import type { BaseProviderId } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import {
+ isAnthropicModel,
+ isFixedReasoningModel,
+ isGeminiModel,
isGenerateImageModel,
+ isGrokModel,
+ isOpenAIModel,
isOpenRouterBuiltInWebSearchModel,
- isReasoningModel,
isSupportedReasoningEffortModel,
- isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
-import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
+import { getDefaultModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
-import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
+import type { Model } from '@renderer/types'
+import { type Assistant, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
import { replacePromptVariables } from '@renderer/utils/prompt'
+import { isAIGatewayProvider, isAwsBedrockProvider } from '@renderer/utils/provider'
import type { ModelMessage, Tool } from 'ai'
import { stepCountIs } from 'ai'
import { getAiSdkProviderId } from '../provider/factory'
import { setupToolsConfig } from '../utils/mcp'
import { buildProviderOptions } from '../utils/options'
-import { getAnthropicThinkingBudget } from '../utils/reasoning'
import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch'
-import { supportsTopP } from './modelCapabilities'
-import { getTemperature, getTopP } from './modelParameters'
+import { addAnthropicHeaders } from './header'
+import { getMaxTokens, getTemperature, getTopP } from './modelParameters'
const logger = loggerService.withContext('parameterBuilder')
type ProviderDefinedTool = Extract, { type: 'provider-defined' }>
+function mapVertexAIGatewayModelToProviderId(model: Model): BaseProviderId | undefined {
+ if (isAnthropicModel(model)) {
+ return 'anthropic'
+ }
+ if (isGeminiModel(model)) {
+ return 'google'
+ }
+ if (isGrokModel(model)) {
+ return 'xai'
+ }
+ if (isOpenAIModel(model)) {
+ return 'openai'
+ }
+ logger.warn(
+ `[mapVertexAIGatewayModelToProviderId] Unknown model type for AI Gateway: ${model.id}. Web search will not be enabled.`
+ )
+ return undefined
+}
+
/**
* 构建 AI SDK 流式参数
* 这是主要的参数构建函数,整合所有转换逻辑
@@ -58,7 +84,7 @@ export async function buildStreamTextParams(
timeout?: number
headers?: Record
}
- } = {}
+ }
): Promise<{
params: StreamTextParams
modelId: string
@@ -75,15 +101,13 @@ export async function buildStreamTextParams(
const model = assistant.model || getDefaultModel()
const aiSdkProviderId = getAiSdkProviderId(provider)
- let { maxTokens } = getAssistantSettings(assistant)
-
// 这三个变量透传出来,交给下面启用插件/中间件
// 也可以在外部构建好再传入buildStreamTextParams
// FIXME: qwen3即使关闭思考仍然会导致enableReasoning的结果为true
const enableReasoning =
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
assistant.settings?.reasoning_effort !== undefined) ||
- (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
+ isFixedReasoningModel(model)
// 判断是否使用内置搜索
// 条件:没有外部搜索提供商 && (用户开启了内置搜索 || 模型强制使用内置搜索)
@@ -107,26 +131,21 @@ export async function buildStreamTextParams(
searchWithTime: store.getState().websearch.searchWithTime
}
- const providerOptions = buildProviderOptions(assistant, model, provider, {
+ const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, {
enableReasoning,
enableWebSearch,
enableGenerateImage
})
- // NOTE: ai-sdk会把maxToken和budgetToken加起来
- if (
- enableReasoning &&
- maxTokens !== undefined &&
- isSupportedThinkingTokenClaudeModel(model) &&
- (provider.type === 'anthropic' || provider.type === 'aws-bedrock')
- ) {
- maxTokens -= getAnthropicThinkingBudget(assistant, model)
- }
-
let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined
if (enableWebSearch) {
if (isBaseProvider(aiSdkProviderId)) {
webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(aiSdkProviderId, webSearchConfig, model)
+ } else if (isAIGatewayProvider(provider) || SystemProviderIds.gateway === provider.id) {
+ const aiSdkProviderId = mapVertexAIGatewayModelToProviderId(model)
+ if (aiSdkProviderId) {
+ webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(aiSdkProviderId, webSearchConfig, model)
+ }
}
if (!tools) {
tools = {}
@@ -139,6 +158,17 @@ export async function buildStreamTextParams(
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
}) as ProviderDefinedTool
+ } else if (aiSdkProviderId === 'azure-responses') {
+ tools.web_search_preview = azure.tools.webSearchPreview({
+ searchContextSize: webSearchPluginConfig?.openai!.searchContextSize
+ }) as ProviderDefinedTool
+ } else if (aiSdkProviderId === 'azure-anthropic') {
+ const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
+ const anthropicSearchOptions: AnthropicSearchConfig = {
+ maxUses: webSearchConfig.maxResults,
+ blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
+ }
+ tools.web_search = anthropic.tools.webSearch_20250305(anthropicSearchOptions) as ProviderDefinedTool
}
}
@@ -156,9 +186,10 @@ export async function buildStreamTextParams(
tools.url_context = google.tools.urlContext({}) as ProviderDefinedTool
break
case 'anthropic':
+ case 'azure-anthropic':
case 'google-vertex-anthropic':
tools.web_fetch = (
- aiSdkProviderId === 'anthropic'
+ ['anthropic', 'azure-anthropic'].includes(aiSdkProviderId)
? anthropic.tools.webFetch_20250910({
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
@@ -172,22 +203,35 @@ export async function buildStreamTextParams(
}
}
+ let headers: Record = options.requestOptions?.headers ?? {}
+
+ if (isAnthropicModel(model) && !isAwsBedrockProvider(provider)) {
+ const betaHeaders = addAnthropicHeaders(assistant, model)
+ // Only add the anthropic-beta header if there are actual beta headers to include
+ if (betaHeaders.length > 0) {
+ const newBetaHeaders = { 'anthropic-beta': betaHeaders.join(',') }
+ headers = combineHeaders(headers, newBetaHeaders)
+ }
+ }
+
// 构建基础参数
+ // Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
+ // are extracted from custom parameters and passed directly to streamText()
+ // instead of being placed in providerOptions
const params: StreamTextParams = {
messages: sdkMessages,
- maxOutputTokens: maxTokens,
+ maxOutputTokens: getMaxTokens(assistant, model),
temperature: getTemperature(assistant, model),
+ topP: getTopP(assistant, model),
+ // Include AI SDK standard params extracted from custom parameters
+ ...standardParams,
abortSignal: options.requestOptions?.signal,
- headers: options.requestOptions?.headers,
+ headers,
providerOptions,
stopWhen: stepCountIs(20),
maxRetries: 0
}
- if (supportsTopP(model)) {
- params.topP = getTopP(assistant, model)
- }
-
if (tools) {
params.tools = tools
}
diff --git a/src/renderer/src/aiCore/provider/__tests__/integratedRegistry.test.ts b/src/renderer/src/aiCore/provider/__tests__/integratedRegistry.test.ts
index e26597e2d1..9b2c0639e2 100644
--- a/src/renderer/src/aiCore/provider/__tests__/integratedRegistry.test.ts
+++ b/src/renderer/src/aiCore/provider/__tests__/integratedRegistry.test.ts
@@ -1,4 +1,4 @@
-import type { Provider } from '@renderer/types'
+import type { Model, Provider } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
import { getAiSdkProviderId } from '../factory'
@@ -23,6 +23,26 @@ vi.mock('@cherrystudio/ai-core', () => ({
}
}))
+vi.mock('@renderer/services/AssistantService', () => ({
+ getProviderByModel: vi.fn(),
+ getAssistantSettings: vi.fn(),
+ getDefaultAssistant: vi.fn().mockReturnValue({
+ id: 'default',
+ name: 'Default Assistant',
+ prompt: '',
+ settings: {}
+ })
+}))
+
+vi.mock('@renderer/store/settings', () => ({
+ default: {},
+ settingsSlice: {
+ name: 'settings',
+ reducer: vi.fn(),
+ actions: {}
+ }
+}))
+
// Mock the provider configs
vi.mock('../providerConfigs', () => ({
initializeNewProviders: vi.fn()
@@ -48,6 +68,18 @@ function createTestProvider(id: string, type: string): Provider {
} as Provider
}
+function createAzureProvider(id: string, apiVersion?: string, model?: string): Provider {
+ return {
+ id,
+ type: 'azure-openai',
+ name: `Azure Test ${id}`,
+ apiKey: 'azure-test-key',
+ apiHost: 'azure-test-host',
+ apiVersion,
+ models: [{ id: model || 'gpt-4' } as Model]
+ }
+}
+
describe('Integrated Provider Registry', () => {
describe('Provider ID Resolution', () => {
it('should resolve openrouter provider correctly', () => {
@@ -91,6 +123,24 @@ describe('Integrated Provider Registry', () => {
const result = getAiSdkProviderId(unknownProvider)
expect(result).toBe('unknown-provider')
})
+
+ it('should handle Azure OpenAI providers correctly', () => {
+ const azureProvider = createAzureProvider('azure-test', '2024-02-15', 'gpt-4o')
+ const result = getAiSdkProviderId(azureProvider)
+ expect(result).toBe('azure')
+ })
+
+ it('should handle Azure OpenAI providers response endpoint correctly', () => {
+ const azureProvider = createAzureProvider('azure-test', 'v1', 'gpt-4o')
+ const result = getAiSdkProviderId(azureProvider)
+ expect(result).toBe('azure-responses')
+ })
+
+ it('should handle Azure provider Claude Models', () => {
+ const provider = createTestProvider('azure-anthropic', 'anthropic')
+ const result = getAiSdkProviderId(provider)
+ expect(result).toBe('azure-anthropic')
+ })
})
describe('Backward Compatibility', () => {
diff --git a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts
index 39786231e6..43d3cc52b8 100644
--- a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts
+++ b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts
@@ -12,14 +12,25 @@ vi.mock('@renderer/services/LoggerService', () => ({
}))
vi.mock('@renderer/services/AssistantService', () => ({
- getProviderByModel: vi.fn()
+ getProviderByModel: vi.fn(),
+ getAssistantSettings: vi.fn(),
+ getDefaultAssistant: vi.fn().mockReturnValue({
+ id: 'default',
+ name: 'Default Assistant',
+ prompt: '',
+ settings: {}
+ })
}))
-vi.mock('@renderer/store', () => ({
- default: {
- getState: () => ({ copilot: { defaultHeaders: {} } })
+vi.mock('@renderer/store', () => {
+ const mockGetState = vi.fn()
+ return {
+ default: {
+ getState: mockGetState
+ },
+ __mockGetState: mockGetState
}
-}))
+})
vi.mock('@renderer/utils/api', () => ({
formatApiHost: vi.fn((host, isSupportedAPIVersion = true) => {
@@ -34,7 +45,7 @@ vi.mock('@renderer/utils/api', () => ({
}))
}))
-vi.mock('@renderer/config/providers', async (importOriginal) => {
+vi.mock('@renderer/utils/provider', async (importOriginal) => {
const actual = (await importOriginal()) as any
return {
...actual,
@@ -53,14 +64,27 @@ vi.mock('@renderer/hooks/useVertexAI', () => ({
createVertexProvider: vi.fn()
}))
-import { isCherryAIProvider, isPerplexityProvider } from '@renderer/config/providers'
+vi.mock('@renderer/services/AssistantService', () => ({
+ getProviderByModel: vi.fn(),
+ getAssistantSettings: vi.fn(),
+ getDefaultAssistant: vi.fn().mockReturnValue({
+ id: 'default',
+ name: 'Default Assistant',
+ prompt: '',
+ settings: {}
+ })
+}))
+
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
+import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'
+const { __mockGetState: mockGetState } = vi.mocked(await import('@renderer/store')) as any
+
const createWindowKeyv = () => {
const store = new Map()
return {
@@ -114,6 +138,16 @@ describe('Copilot responses routing', () => {
...(globalThis as any).window,
keyv: createWindowKeyv()
}
+ mockGetState.mockReturnValue({
+ copilot: { defaultHeaders: {} },
+ settings: {
+ openAI: {
+ streamOptions: {
+ includeUsage: undefined
+ }
+ }
+ }
+ })
})
it('detects official GPT-5 Codex identifiers case-insensitively', () => {
@@ -149,6 +183,16 @@ describe('CherryAI provider configuration', () => {
...(globalThis as any).window,
keyv: createWindowKeyv()
}
+ mockGetState.mockReturnValue({
+ copilot: { defaultHeaders: {} },
+ settings: {
+ openAI: {
+ streamOptions: {
+ includeUsage: undefined
+ }
+ }
+ }
+ })
vi.clearAllMocks()
})
@@ -213,6 +257,16 @@ describe('Perplexity provider configuration', () => {
...(globalThis as any).window,
keyv: createWindowKeyv()
}
+ mockGetState.mockReturnValue({
+ copilot: { defaultHeaders: {} },
+ settings: {
+ openAI: {
+ streamOptions: {
+ includeUsage: undefined
+ }
+ }
+ }
+ })
vi.clearAllMocks()
})
@@ -273,3 +327,165 @@ describe('Perplexity provider configuration', () => {
expect(actualProvider.apiHost).toBe('')
})
})
+
+describe('Stream options includeUsage configuration', () => {
+ beforeEach(() => {
+ ;(globalThis as any).window = {
+ ...(globalThis as any).window,
+ keyv: createWindowKeyv()
+ }
+ vi.clearAllMocks()
+ })
+
+ const createOpenAIProvider = (): Provider => ({
+ id: 'openai-compatible',
+ type: 'openai',
+ name: 'OpenAI',
+ apiKey: 'test-key',
+ apiHost: 'https://api.openai.com',
+ models: [],
+ isSystem: true
+ })
+
+ it('uses includeUsage from settings when undefined', () => {
+ mockGetState.mockReturnValue({
+ copilot: { defaultHeaders: {} },
+ settings: {
+ openAI: {
+ streamOptions: {
+ includeUsage: undefined
+ }
+ }
+ }
+ })
+
+ const provider = createOpenAIProvider()
+ const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
+
+ expect(config.options.includeUsage).toBeUndefined()
+ })
+
+ it('uses includeUsage from settings when set to true', () => {
+ mockGetState.mockReturnValue({
+ copilot: { defaultHeaders: {} },
+ settings: {
+ openAI: {
+ streamOptions: {
+ includeUsage: true
+ }
+ }
+ }
+ })
+
+ const provider = createOpenAIProvider()
+ const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
+
+ expect(config.options.includeUsage).toBe(true)
+ })
+
+ it('uses includeUsage from settings when set to false', () => {
+ mockGetState.mockReturnValue({
+ copilot: { defaultHeaders: {} },
+ settings: {
+ openAI: {
+ streamOptions: {
+ includeUsage: false
+ }
+ }
+ }
+ })
+
+ const provider = createOpenAIProvider()
+ const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
+
+ expect(config.options.includeUsage).toBe(false)
+ })
+
+ it('respects includeUsage setting for non-supporting providers', () => {
+ mockGetState.mockReturnValue({
+ copilot: { defaultHeaders: {} },
+ settings: {
+ openAI: {
+ streamOptions: {
+ includeUsage: true
+ }
+ }
+ }
+ })
+
+ const testProvider: Provider = {
+ id: 'test',
+ type: 'openai',
+ name: 'test',
+ apiKey: 'test-key',
+ apiHost: 'https://api.test.com',
+ models: [],
+ isSystem: false,
+ apiOptions: {
+ isNotSupportStreamOptions: true
+ }
+ }
+
+ const config = providerToAiSdkConfig(testProvider, createModel('gpt-4', 'GPT-4', 'test'))
+
+ // Even though setting is true, provider doesn't support it, so includeUsage should be undefined
+ expect(config.options.includeUsage).toBeUndefined()
+ })
+
+ it('uses includeUsage from settings for Copilot provider when set to false', () => {
+ mockGetState.mockReturnValue({
+ copilot: { defaultHeaders: {} },
+ settings: {
+ openAI: {
+ streamOptions: {
+ includeUsage: false
+ }
+ }
+ }
+ })
+
+ const provider = createCopilotProvider()
+ const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))
+
+ expect(config.options.includeUsage).toBe(false)
+ expect(config.providerId).toBe('github-copilot-openai-compatible')
+ })
+
+ it('uses includeUsage from settings for Copilot provider when set to true', () => {
+ mockGetState.mockReturnValue({
+ copilot: { defaultHeaders: {} },
+ settings: {
+ openAI: {
+ streamOptions: {
+ includeUsage: true
+ }
+ }
+ }
+ })
+
+ const provider = createCopilotProvider()
+ const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))
+
+ expect(config.options.includeUsage).toBe(true)
+ expect(config.providerId).toBe('github-copilot-openai-compatible')
+ })
+
+ it('uses includeUsage from settings for Copilot provider when undefined', () => {
+ mockGetState.mockReturnValue({
+ copilot: { defaultHeaders: {} },
+ settings: {
+ openAI: {
+ streamOptions: {
+ includeUsage: undefined
+ }
+ }
+ }
+ })
+
+ const provider = createCopilotProvider()
+ const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))
+
+ expect(config.options.includeUsage).toBeUndefined()
+ expect(config.providerId).toBe('github-copilot-openai-compatible')
+ })
+})
diff --git a/src/renderer/src/aiCore/provider/config/azure-anthropic.ts b/src/renderer/src/aiCore/provider/config/azure-anthropic.ts
new file mode 100644
index 0000000000..c6cb521386
--- /dev/null
+++ b/src/renderer/src/aiCore/provider/config/azure-anthropic.ts
@@ -0,0 +1,22 @@
+import type { Provider } from '@renderer/types'
+
+import { provider2Provider, startsWith } from './helper'
+import type { RuleSet } from './types'
+
+// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
+const AZURE_ANTHROPIC_RULES: RuleSet = {
+ rules: [
+ {
+ match: startsWith('claude'),
+ provider: (provider: Provider) => ({
+ ...provider,
+ type: 'anthropic',
+ apiHost: provider.apiHost + 'anthropic/v1',
+ id: 'azure-anthropic'
+ })
+ }
+ ],
+ fallbackRule: (provider: Provider) => provider
+}
+
+export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)
diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts
index 4cdbfb6d40..ff100051b7 100644
--- a/src/renderer/src/aiCore/provider/factory.ts
+++ b/src/renderer/src/aiCore/provider/factory.ts
@@ -2,8 +2,10 @@ import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } fr
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
+import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
import type { Provider as AiSdkProvider } from 'ai'
+import type { AiSdkConfig } from '../types'
import { initializeNewProviders } from './providerInitialization'
const logger = loggerService.withContext('ProviderFactory')
@@ -54,10 +56,18 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
/**
* 获取AI SDK Provider ID
* 简化版:减少重复逻辑,利用通用解析函数
+ * TODO: 整理函数逻辑
*/
-export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
+export function getAiSdkProviderId(provider: Provider): string {
// 1. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
+ if (isAzureOpenAIProvider(provider)) {
+ if (isAzureResponsesEndpoint(provider)) {
+ return 'azure-responses'
+ } else {
+ return 'azure'
+ }
+ }
if (resolvedFromId) {
return resolvedFromId
}
@@ -73,17 +83,19 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
- // 3. 最后的fallback(通常会成为openai-compatible)
- return provider.id as ProviderId
+ // 3. 最后的fallback(使用provider本身的id)
+ return provider.id
}
-export async function createAiSdkProvider(config) {
+export async function createAiSdkProvider(config: AiSdkConfig): Promise {
let localProvider: Awaited | null = null
try {
if (config.providerId === 'openai' && config.options?.mode === 'chat') {
config.providerId = `${config.providerId}-chat`
} else if (config.providerId === 'azure' && config.options?.mode === 'responses') {
config.providerId = `${config.providerId}-responses`
+ } else if (config.providerId === 'cherryin' && config.options?.mode === 'chat') {
+ config.providerId = 'cherryin-chat'
}
localProvider = await createProviderCore(config.providerId, config.options)
diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts
index a5a52a7eaf..f588032608 100644
--- a/src/renderer/src/aiCore/provider/providerConfig.ts
+++ b/src/renderer/src/aiCore/provider/providerConfig.ts
@@ -1,19 +1,5 @@
-import {
- formatPrivateKey,
- hasProviderConfig,
- ProviderConfigFactory,
- type ProviderId,
- type ProviderSettingsMap
-} from '@cherrystudio/ai-core/provider'
+import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
-import {
- isAnthropicProvider,
- isAzureOpenAIProvider,
- isCherryAIProvider,
- isGeminiProvider,
- isNewApiProvider,
- isPerplexityProvider
-} from '@renderer/config/providers'
import {
getAwsBedrockAccessKeyId,
getAwsBedrockApiKey,
@@ -21,43 +7,37 @@ import {
getAwsBedrockRegion,
getAwsBedrockSecretAccessKey
} from '@renderer/hooks/useAwsBedrock'
-import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
+import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
-import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
-import { cloneDeep } from 'lodash'
+import type { OpenAICompletionsStreamOptions } from '@renderer/types/aiCoreTypes'
+import {
+ formatApiHost,
+ formatAzureOpenAIApiHost,
+ formatOllamaApiHost,
+ formatVertexApiHost,
+ routeToEndpoint
+} from '@renderer/utils/api'
+import {
+ isAnthropicProvider,
+ isAzureOpenAIProvider,
+ isCherryAIProvider,
+ isGeminiProvider,
+ isNewApiProvider,
+ isOllamaProvider,
+ isPerplexityProvider,
+ isSupportStreamOptionsProvider,
+ isVertexProvider
+} from '@renderer/utils/provider'
+import { cloneDeep, isEmpty } from 'lodash'
+import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
+import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory'
-/**
- * 获取轮询的API key
- * 复用legacy架构的多key轮询逻辑
- */
-function getRotatedApiKey(provider: Provider): string {
- const keys = provider.apiKey.split(',').map((key) => key.trim())
- const keyName = `provider:${provider.id}:last_used_key`
-
- if (keys.length === 1) {
- return keys[0]
- }
-
- const lastUsedKey = window.keyv.get(keyName)
- if (!lastUsedKey) {
- window.keyv.set(keyName, keys[0])
- return keys[0]
- }
-
- const currentIndex = keys.indexOf(lastUsedKey)
- const nextIndex = (currentIndex + 1) % keys.length
- const nextKey = keys[nextIndex]
- window.keyv.set(keyName, nextKey)
-
- return nextKey
-}
-
/**
* 处理特殊provider的转换逻辑
*/
@@ -74,15 +54,20 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
return vertexAnthropicProviderCreator(model, provider)
}
}
+ if (isAzureOpenAIProvider(provider)) {
+ return azureAnthropicProviderCreator(model, provider)
+ }
return provider
}
/**
- * 主要用来对齐AISdk的BaseURL格式
- * @param provider
- * @returns
+ * Format and normalize the API host URL for a provider.
+ * Handles provider-specific URL formatting rules (e.g., appending version paths, Azure formatting).
+ *
+ * @param provider - The provider whose API host is to be formatted.
+ * @returns A new provider instance with the formatted API host.
*/
-function formatProviderApiHost(provider: Provider): Provider {
+export function formatProviderApiHost(provider: Provider): Provider {
const formatted = { ...provider }
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
@@ -90,12 +75,15 @@ function formatProviderApiHost(provider: Provider): Provider {
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
+ // AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient
formatted.apiHost = formatApiHost(baseHost)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
+ } else if (isOllamaProvider(formatted)) {
+ formatted.apiHost = formatOllamaApiHost(formatted.apiHost)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
@@ -113,38 +101,56 @@ function formatProviderApiHost(provider: Provider): Provider {
}
/**
- * 获取实际的Provider配置
- * 简化版:将逻辑分解为小函数
+ * Retrieve the effective Provider configuration for the given model.
+ * Applies all necessary transformations (special-provider handling, URL formatting, etc.).
+ *
+ * @param model - The model whose provider is to be resolved.
+ * @returns A new Provider instance with all adaptations applied.
*/
export function getActualProvider(model: Model): Provider {
const baseProvider = getProviderByModel(model)
- // 按顺序处理各种转换
- let actualProvider = cloneDeep(baseProvider)
- actualProvider = handleSpecialProviders(model, actualProvider)
- actualProvider = formatProviderApiHost(actualProvider)
+ return adaptProvider({ provider: baseProvider, model })
+}
- return actualProvider
+/**
+ * Transforms a provider configuration by applying model-specific adaptations and normalizing its API host.
+ * The transformations are applied in the following order:
+ * 1. Model-specific provider handling (e.g., New-API, system providers, Azure OpenAI)
+ * 2. API host formatting (provider-specific URL normalization)
+ *
+ * @param provider - The base provider configuration to transform.
+ * @param model - The model associated with the provider; optional but required for special-provider handling.
+ * @returns A new Provider instance with all transformations applied.
+ */
+export function adaptProvider({ provider, model }: { provider: Provider; model?: Model }): Provider {
+ let adaptedProvider = cloneDeep(provider)
+
+ // Apply transformations in order
+ if (model) {
+ adaptedProvider = handleSpecialProviders(model, adaptedProvider)
+ }
+ adaptedProvider = formatProviderApiHost(adaptedProvider)
+
+ return adaptedProvider
}
/**
* 将 Provider 配置转换为新 AI SDK 格式
* 简化版:利用新的别名映射系统
*/
-export function providerToAiSdkConfig(
- actualProvider: Provider,
- model: Model
-): {
- providerId: ProviderId | 'openai-compatible'
- options: ProviderSettingsMap[keyof ProviderSettingsMap]
-} {
+export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// 构建基础配置
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
const baseConfig = {
baseURL: baseURL,
- apiKey: getRotatedApiKey(actualProvider)
+ apiKey: actualProvider.apiKey
+ }
+ let includeUsage: OpenAICompletionsStreamOptions['include_usage'] = undefined
+ if (isSupportStreamOptionsProvider(actualProvider)) {
+ includeUsage = store.getState().settings.openAI?.streamOptions?.includeUsage
}
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
@@ -157,7 +163,7 @@ export function providerToAiSdkConfig(
...actualProvider.extra_headers
},
name: actualProvider.id,
- includeUsage: true
+ includeUsage
})
return {
@@ -166,12 +172,25 @@ export function providerToAiSdkConfig(
}
}
+ if (isOllamaProvider(actualProvider)) {
+ return {
+ providerId: 'ollama',
+ options: {
+ ...baseConfig,
+ headers: {
+ ...actualProvider.extra_headers,
+ Authorization: !isEmpty(baseConfig.apiKey) ? `Bearer ${baseConfig.apiKey}` : undefined
+ }
+ }
+ }
+ }
+
// 处理OpenAI模式
const extraOptions: any = {}
extraOptions.endpoint = endpoint
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
extraOptions.mode = 'responses'
- } else if (aiSdkProviderId === 'openai') {
+ } else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) {
// OAuth authentication requires using the responses API mode instead of chat mode
if (actualProvider.authType == 'oauth') {
extraOptions.mode = 'responses'
@@ -194,13 +213,12 @@ export function providerToAiSdkConfig(
}
}
// azure
- if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') {
- // extraOptions.apiVersion = actualProvider.apiVersion 默认使用v1,不使用azure endpoint
- if (actualProvider.apiVersion === 'preview') {
- extraOptions.mode = 'responses'
- } else {
- extraOptions.mode = 'chat'
- }
+ // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
+ // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
+ if (aiSdkProviderId === 'azure-responses') {
+ extraOptions.mode = 'responses'
+ } else if (aiSdkProviderId === 'azure') {
+ extraOptions.mode = 'chat'
}
// bedrock
@@ -230,10 +248,17 @@ export function providerToAiSdkConfig(
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
+ // cherryin
+ if (aiSdkProviderId === 'cherryin') {
+ if (model.endpoint_type) {
+ extraOptions.endpointType = model.endpoint_type
+ }
+ }
+
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
- providerId: aiSdkProviderId as ProviderId,
+ providerId: aiSdkProviderId,
options
}
}
@@ -246,7 +271,7 @@ export function providerToAiSdkConfig(
...options,
name: actualProvider.id,
...extraOptions,
- includeUsage: true
+ includeUsage
}
}
}
@@ -318,7 +343,6 @@ export async function prepareSpecialProviderConfig(
...(config.options.headers ? config.options.headers : {}),
'Content-Type': 'application/json',
'anthropic-version': '2023-06-01',
- 'anthropic-beta': 'oauth-2025-04-20',
Authorization: `Bearer ${oauthToken}`
},
baseURL: 'https://api.anthropic.com/v1',
diff --git a/src/renderer/src/aiCore/provider/providerInitialization.ts b/src/renderer/src/aiCore/provider/providerInitialization.ts
index 665f2bd05c..51176c1e60 100644
--- a/src/renderer/src/aiCore/provider/providerInitialization.ts
+++ b/src/renderer/src/aiCore/provider/providerInitialization.ts
@@ -1,5 +1,6 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
+import * as z from 'zod'
const logger = loggerService.withContext('ProviderConfigs')
@@ -32,6 +33,14 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
+ {
+ id: 'azure-anthropic',
+ name: 'Azure AI Anthropic',
+ import: () => import('@ai-sdk/anthropic'),
+ creatorFunctionName: 'createAnthropic',
+ supportsImageGeneration: false,
+ aliases: ['azure-anthropic']
+ },
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',
@@ -71,9 +80,34 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
creatorFunctionName: 'createHuggingFace',
supportsImageGeneration: true,
aliases: ['hf', 'hugging-face']
+ },
+ {
+ id: 'gateway',
+ name: 'Vercel AI Gateway',
+ import: () => import('@ai-sdk/gateway'),
+ creatorFunctionName: 'createGateway',
+ supportsImageGeneration: true,
+ aliases: ['ai-gateway']
+ },
+ {
+ id: 'cerebras',
+ name: 'Cerebras',
+ import: () => import('@ai-sdk/cerebras'),
+ creatorFunctionName: 'createCerebras',
+ supportsImageGeneration: false
+ },
+ {
+ id: 'ollama',
+ name: 'Ollama',
+ import: () => import('ollama-ai-provider-v2'),
+ creatorFunctionName: 'createOllama',
+ supportsImageGeneration: false
}
] as const
+export const registeredNewProviderIds = NEW_PROVIDER_CONFIGS.map((config) => config.id)
+export const registeredNewProviderIdSchema = z.enum(registeredNewProviderIds)
+
/**
* 初始化新的Providers
* 使用aiCore的动态注册功能
diff --git a/src/renderer/src/aiCore/trace/AiSdkSpanAdapter.ts b/src/renderer/src/aiCore/trace/AiSdkSpanAdapter.ts
index 732397de40..0c0e08a03d 100644
--- a/src/renderer/src/aiCore/trace/AiSdkSpanAdapter.ts
+++ b/src/renderer/src/aiCore/trace/AiSdkSpanAdapter.ts
@@ -133,7 +133,7 @@ export class AiSdkSpanAdapter {
// 详细记录转换过程
const operationId = attributes['ai.operationId']
- logger.info('Converting AI SDK span to SpanEntity', {
+ logger.debug('Converting AI SDK span to SpanEntity', {
spanName: spanName,
operationId,
spanTag,
@@ -149,7 +149,7 @@ export class AiSdkSpanAdapter {
})
if (tokenUsage) {
- logger.info('Token usage data found', {
+ logger.debug('Token usage data found', {
spanName: spanName,
operationId,
usage: tokenUsage,
@@ -158,7 +158,7 @@ export class AiSdkSpanAdapter {
}
if (inputs || outputs) {
- logger.info('Input/Output data extracted', {
+ logger.debug('Input/Output data extracted', {
spanName: spanName,
operationId,
hasInputs: !!inputs,
@@ -170,7 +170,7 @@ export class AiSdkSpanAdapter {
}
if (Object.keys(typeSpecificData).length > 0) {
- logger.info('Type-specific data extracted', {
+ logger.debug('Type-specific data extracted', {
spanName: spanName,
operationId,
typeSpecificKeys: Object.keys(typeSpecificData),
@@ -204,7 +204,7 @@ export class AiSdkSpanAdapter {
modelName: modelName || this.extractModelFromAttributes(attributes)
}
- logger.info('AI SDK span successfully converted to SpanEntity', {
+ logger.debug('AI SDK span successfully converted to SpanEntity', {
spanName: spanName,
operationId,
spanId: spanContext.spanId,
@@ -245,8 +245,8 @@ export class AiSdkSpanAdapter {
'gen_ai.usage.output_tokens'
]
- const completionTokens = attributes[inputsTokenKeys.find((key) => attributes[key]) || '']
- const promptTokens = attributes[outputTokenKeys.find((key) => attributes[key]) || '']
+ const promptTokens = attributes[inputsTokenKeys.find((key) => attributes[key]) || '']
+ const completionTokens = attributes[outputTokenKeys.find((key) => attributes[key]) || '']
if (completionTokens !== undefined || promptTokens !== undefined) {
const usage: TokenUsage = {
diff --git a/src/renderer/src/aiCore/trace/__tests__/AiSdkSpanAdapter.test.ts b/src/renderer/src/aiCore/trace/__tests__/AiSdkSpanAdapter.test.ts
new file mode 100644
index 0000000000..4cd6241e64
--- /dev/null
+++ b/src/renderer/src/aiCore/trace/__tests__/AiSdkSpanAdapter.test.ts
@@ -0,0 +1,53 @@
+import type { Span } from '@opentelemetry/api'
+import { SpanKind, SpanStatusCode } from '@opentelemetry/api'
+import { describe, expect, it, vi } from 'vitest'
+
+import { AiSdkSpanAdapter } from '../AiSdkSpanAdapter'
+
+vi.mock('@logger', () => ({
+ loggerService: {
+ withContext: () => ({
+ debug: vi.fn(),
+ error: vi.fn(),
+ info: vi.fn(),
+ warn: vi.fn()
+ })
+ }
+}))
+
+describe('AiSdkSpanAdapter', () => {
+ const createMockSpan = (attributes: Record): Span => {
+ const span = {
+ spanContext: () => ({
+ traceId: 'trace-id',
+ spanId: 'span-id'
+ }),
+ _attributes: attributes,
+ _events: [],
+ name: 'test span',
+ status: { code: SpanStatusCode.OK },
+ kind: SpanKind.CLIENT,
+ startTime: [0, 0] as [number, number],
+ endTime: [0, 1] as [number, number],
+ ended: true,
+ parentSpanId: '',
+ links: []
+ }
+ return span as unknown as Span
+ }
+
+ it('maps prompt and completion usage tokens to the correct fields', () => {
+ const attributes = {
+ 'ai.usage.promptTokens': 321,
+ 'ai.usage.completionTokens': 654
+ }
+
+ const span = createMockSpan(attributes)
+ const result = AiSdkSpanAdapter.convertToSpanEntity({ span })
+
+ expect(result.usage).toBeDefined()
+ expect(result.usage?.prompt_tokens).toBe(321)
+ expect(result.usage?.completion_tokens).toBe(654)
+ expect(result.usage?.total_tokens).toBe(975)
+ })
+})
diff --git a/src/renderer/src/aiCore/types/index.ts b/src/renderer/src/aiCore/types/index.ts
new file mode 100644
index 0000000000..a8a64cf45e
--- /dev/null
+++ b/src/renderer/src/aiCore/types/index.ts
@@ -0,0 +1,15 @@
+/**
+ * This type definition file is only for renderer.
+ * It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
+ * If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
+ * (ai-core package is set as browser-enviroment-only)
+ *
+ * TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
+ */
+
+import type { ProviderSettingsMap } from '@cherrystudio/ai-core/provider'
+
+export type AiSdkConfig = {
+ providerId: string
+ options: ProviderSettingsMap[keyof ProviderSettingsMap]
+}
diff --git a/src/renderer/src/aiCore/utils/__tests__/extractAiSdkStandardParams.test.ts b/src/renderer/src/aiCore/utils/__tests__/extractAiSdkStandardParams.test.ts
new file mode 100644
index 0000000000..288cc2e4a5
--- /dev/null
+++ b/src/renderer/src/aiCore/utils/__tests__/extractAiSdkStandardParams.test.ts
@@ -0,0 +1,652 @@
+/**
+ * extractAiSdkStandardParams Unit Tests
+ * Tests for extracting AI SDK standard parameters from custom parameters
+ */
+
+import { describe, expect, it, vi } from 'vitest'
+
+import { extractAiSdkStandardParams } from '../options'
+
+// Mock logger to prevent errors
+vi.mock('@logger', () => ({
+ loggerService: {
+ withContext: () => ({
+ debug: vi.fn(),
+ error: vi.fn(),
+ warn: vi.fn(),
+ info: vi.fn()
+ })
+ }
+}))
+
+// Mock settings store
+vi.mock('@renderer/store/settings', () => ({
+ default: (state = { settings: {} }) => state
+}))
+
+// Mock hooks to prevent uuid errors
+vi.mock('@renderer/hooks/useSettings', () => ({
+ getStoreSetting: vi.fn(() => ({}))
+}))
+
+// Mock uuid to prevent errors
+vi.mock('uuid', () => ({
+ v4: vi.fn(() => 'test-uuid')
+}))
+
+// Mock AssistantService to prevent uuid errors
+vi.mock('@renderer/services/AssistantService', () => ({
+ getDefaultAssistant: vi.fn(() => ({
+ id: 'test-assistant',
+ name: 'Test Assistant',
+ settings: {}
+ })),
+ getDefaultTopic: vi.fn(() => ({
+ id: 'test-topic',
+ assistantId: 'test-assistant',
+ createdAt: new Date().toISOString()
+ }))
+}))
+
+// Mock provider service
+vi.mock('@renderer/services/ProviderService', () => ({
+ getProviderById: vi.fn(() => ({
+ id: 'test-provider',
+ name: 'Test Provider'
+ }))
+}))
+
+// Mock config modules
+vi.mock('@renderer/config/models', () => ({
+ isOpenAIModel: vi.fn(() => false),
+ isQwenMTModel: vi.fn(() => false),
+ isSupportFlexServiceTierModel: vi.fn(() => false),
+ isSupportVerbosityModel: vi.fn(() => false),
+ getModelSupportedVerbosity: vi.fn(() => [])
+}))
+
+vi.mock('@renderer/config/translate', () => ({
+ mapLanguageToQwenMTModel: vi.fn()
+}))
+
+vi.mock('@renderer/utils/provider', () => ({
+ isSupportServiceTierProvider: vi.fn(() => false),
+ isSupportVerbosityProvider: vi.fn(() => false)
+}))
+
+describe('extractAiSdkStandardParams', () => {
+ describe('Positive cases - Standard parameters extraction', () => {
+ it('should extract all AI SDK standard parameters', () => {
+ const customParams = {
+ maxOutputTokens: 1000,
+ temperature: 0.7,
+ topP: 0.9,
+ topK: 40,
+ presencePenalty: 0.5,
+ frequencyPenalty: 0.3,
+ stopSequences: ['STOP', 'END'],
+ seed: 42
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ maxOutputTokens: 1000,
+ temperature: 0.7,
+ topP: 0.9,
+ topK: 40,
+ presencePenalty: 0.5,
+ frequencyPenalty: 0.3,
+ stopSequences: ['STOP', 'END'],
+ seed: 42
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should extract single standard parameter', () => {
+ const customParams = {
+ temperature: 0.8
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ temperature: 0.8
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should extract topK parameter', () => {
+ const customParams = {
+ topK: 50
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ topK: 50
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should extract frequencyPenalty parameter', () => {
+ const customParams = {
+ frequencyPenalty: 0.6
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ frequencyPenalty: 0.6
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should extract presencePenalty parameter', () => {
+ const customParams = {
+ presencePenalty: 0.4
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ presencePenalty: 0.4
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should extract stopSequences parameter', () => {
+ const customParams = {
+ stopSequences: ['HALT', 'TERMINATE']
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ stopSequences: ['HALT', 'TERMINATE']
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should extract seed parameter', () => {
+ const customParams = {
+ seed: 12345
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ seed: 12345
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should extract maxOutputTokens parameter', () => {
+ const customParams = {
+ maxOutputTokens: 2048
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ maxOutputTokens: 2048
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should extract topP parameter', () => {
+ const customParams = {
+ topP: 0.95
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ topP: 0.95
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+ })
+
+ describe('Negative cases - Provider-specific parameters', () => {
+ it('should place all non-standard parameters in providerParams', () => {
+ const customParams = {
+ customParam: 'value',
+ anotherParam: 123,
+ thirdParam: true
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({})
+ expect(result.providerParams).toStrictEqual({
+ customParam: 'value',
+ anotherParam: 123,
+ thirdParam: true
+ })
+ })
+
+ it('should place single provider-specific parameter in providerParams', () => {
+ const customParams = {
+ reasoningEffort: 'high'
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({})
+ expect(result.providerParams).toStrictEqual({
+ reasoningEffort: 'high'
+ })
+ })
+
+ it('should place model-specific parameter in providerParams', () => {
+ const customParams = {
+ thinking: { type: 'enabled', budgetTokens: 5000 }
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({})
+ expect(result.providerParams).toStrictEqual({
+ thinking: { type: 'enabled', budgetTokens: 5000 }
+ })
+ })
+
+ it('should place serviceTier in providerParams', () => {
+ const customParams = {
+ serviceTier: 'auto'
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({})
+ expect(result.providerParams).toStrictEqual({
+ serviceTier: 'auto'
+ })
+ })
+
+ it('should place textVerbosity in providerParams', () => {
+ const customParams = {
+ textVerbosity: 'high'
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({})
+ expect(result.providerParams).toStrictEqual({
+ textVerbosity: 'high'
+ })
+ })
+ })
+
+ describe('Mixed parameters', () => {
+ it('should correctly separate mixed standard and provider-specific parameters', () => {
+ const customParams = {
+ temperature: 0.7,
+ topK: 40,
+ customParam: 'custom_value',
+ reasoningEffort: 'medium',
+ frequencyPenalty: 0.5,
+ seed: 999
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ temperature: 0.7,
+ topK: 40,
+ frequencyPenalty: 0.5,
+ seed: 999
+ })
+ expect(result.providerParams).toStrictEqual({
+ customParam: 'custom_value',
+ reasoningEffort: 'medium'
+ })
+ })
+
+ it('should handle complex mixed parameters with nested objects', () => {
+ const customParams = {
+ topP: 0.9,
+ presencePenalty: 0.3,
+ thinking: { type: 'enabled', budgetTokens: 5000 },
+ stopSequences: ['STOP'],
+ serviceTier: 'auto',
+ maxOutputTokens: 4096
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ topP: 0.9,
+ presencePenalty: 0.3,
+ stopSequences: ['STOP'],
+ maxOutputTokens: 4096
+ })
+ expect(result.providerParams).toStrictEqual({
+ thinking: { type: 'enabled', budgetTokens: 5000 },
+ serviceTier: 'auto'
+ })
+ })
+
+ it('should handle all standard params with some provider params', () => {
+ const customParams = {
+ maxOutputTokens: 2000,
+ temperature: 0.8,
+ topP: 0.95,
+ topK: 50,
+ presencePenalty: 0.6,
+ frequencyPenalty: 0.4,
+ stopSequences: ['END', 'DONE'],
+ seed: 777,
+ customApiParam: 'value',
+ anotherCustomParam: 123
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ maxOutputTokens: 2000,
+ temperature: 0.8,
+ topP: 0.95,
+ topK: 50,
+ presencePenalty: 0.6,
+ frequencyPenalty: 0.4,
+ stopSequences: ['END', 'DONE'],
+ seed: 777
+ })
+ expect(result.providerParams).toStrictEqual({
+ customApiParam: 'value',
+ anotherCustomParam: 123
+ })
+ })
+ })
+
+ describe('Edge cases', () => {
+ it('should handle empty object', () => {
+ const customParams = {}
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({})
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should handle zero values for numeric parameters', () => {
+ const customParams = {
+ temperature: 0,
+ topK: 0,
+ seed: 0
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ temperature: 0,
+ topK: 0,
+ seed: 0
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should handle negative values for numeric parameters', () => {
+ const customParams = {
+ presencePenalty: -0.5,
+ frequencyPenalty: -0.3,
+ seed: -1
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ presencePenalty: -0.5,
+ frequencyPenalty: -0.3,
+ seed: -1
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should handle empty arrays for stopSequences', () => {
+ const customParams = {
+ stopSequences: []
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ stopSequences: []
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should handle null values in mixed parameters', () => {
+ const customParams = {
+ temperature: 0.7,
+ customNull: null,
+ topK: 40
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ temperature: 0.7,
+ topK: 40
+ })
+ expect(result.providerParams).toStrictEqual({
+ customNull: null
+ })
+ })
+
+ it('should handle undefined values in mixed parameters', () => {
+ const customParams = {
+ temperature: 0.7,
+ customUndefined: undefined,
+ topK: 40
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ temperature: 0.7,
+ topK: 40
+ })
+ expect(result.providerParams).toStrictEqual({
+ customUndefined: undefined
+ })
+ })
+
+ it('should handle boolean values for standard parameters', () => {
+ const customParams = {
+ temperature: 0.7,
+ customBoolean: false,
+ topK: 40
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ temperature: 0.7,
+ topK: 40
+ })
+ expect(result.providerParams).toStrictEqual({
+ customBoolean: false
+ })
+ })
+
+ it('should handle very large numeric values', () => {
+ const customParams = {
+ maxOutputTokens: 999999,
+ seed: 2147483647,
+ topK: 10000
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ maxOutputTokens: 999999,
+ seed: 2147483647,
+ topK: 10000
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+
+ it('should handle decimal values with high precision', () => {
+ const customParams = {
+ temperature: 0.123456789,
+ topP: 0.987654321,
+ presencePenalty: 0.111111111
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ temperature: 0.123456789,
+ topP: 0.987654321,
+ presencePenalty: 0.111111111
+ })
+ expect(result.providerParams).toStrictEqual({})
+ })
+ })
+
+ describe('Case sensitivity', () => {
+ it('should NOT extract parameters with incorrect case - uppercase first letter', () => {
+ const customParams = {
+ Temperature: 0.7,
+ TopK: 40,
+ FrequencyPenalty: 0.5
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({})
+ expect(result.providerParams).toStrictEqual({
+ Temperature: 0.7,
+ TopK: 40,
+ FrequencyPenalty: 0.5
+ })
+ })
+
+ it('should NOT extract parameters with incorrect case - all uppercase', () => {
+ const customParams = {
+ TEMPERATURE: 0.7,
+ TOPK: 40,
+ SEED: 42
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({})
+ expect(result.providerParams).toStrictEqual({
+ TEMPERATURE: 0.7,
+ TOPK: 40,
+ SEED: 42
+ })
+ })
+
+ it('should NOT extract parameters with incorrect case - all lowercase', () => {
+ const customParams = {
+ maxoutputtokens: 1000,
+ frequencypenalty: 0.5,
+ stopsequences: ['STOP']
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({})
+ expect(result.providerParams).toStrictEqual({
+ maxoutputtokens: 1000,
+ frequencypenalty: 0.5,
+ stopsequences: ['STOP']
+ })
+ })
+
+ it('should correctly extract exact case match while rejecting incorrect case', () => {
+ const customParams = {
+ temperature: 0.7,
+ Temperature: 0.8,
+ TEMPERATURE: 0.9,
+ topK: 40,
+ TopK: 50
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ temperature: 0.7,
+ topK: 40
+ })
+ expect(result.providerParams).toStrictEqual({
+ Temperature: 0.8,
+ TEMPERATURE: 0.9,
+ TopK: 50
+ })
+ })
+ })
+
+ describe('Parameter name variations', () => {
+ it('should NOT extract similar but incorrect parameter names', () => {
+ const customParams = {
+ temp: 0.7, // should not match temperature
+ top_k: 40, // should not match topK
+ max_tokens: 1000, // should not match maxOutputTokens
+ freq_penalty: 0.5 // should not match frequencyPenalty
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({})
+ expect(result.providerParams).toStrictEqual({
+ temp: 0.7,
+ top_k: 40,
+ max_tokens: 1000,
+ freq_penalty: 0.5
+ })
+ })
+
+ it('should NOT extract snake_case versions of standard parameters', () => {
+ const customParams = {
+ top_k: 40,
+ top_p: 0.9,
+ presence_penalty: 0.5,
+ frequency_penalty: 0.3,
+ stop_sequences: ['STOP'],
+ max_output_tokens: 1000
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({})
+ expect(result.providerParams).toStrictEqual({
+ top_k: 40,
+ top_p: 0.9,
+ presence_penalty: 0.5,
+ frequency_penalty: 0.3,
+ stop_sequences: ['STOP'],
+ max_output_tokens: 1000
+ })
+ })
+
+ it('should extract exact camelCase parameters only', () => {
+ const customParams = {
+ topK: 40, // correct
+ top_k: 50, // incorrect
+ topP: 0.9, // correct
+ top_p: 0.8, // incorrect
+ frequencyPenalty: 0.5, // correct
+ frequency_penalty: 0.4 // incorrect
+ }
+
+ const result = extractAiSdkStandardParams(customParams)
+
+ expect(result.standardParams).toStrictEqual({
+ topK: 40,
+ topP: 0.9,
+ frequencyPenalty: 0.5
+ })
+ expect(result.providerParams).toStrictEqual({
+ top_k: 50,
+ top_p: 0.8,
+ frequency_penalty: 0.4
+ })
+ })
+ })
+})
diff --git a/src/renderer/src/aiCore/utils/__tests__/image.test.ts b/src/renderer/src/aiCore/utils/__tests__/image.test.ts
new file mode 100644
index 0000000000..1c5381a5ef
--- /dev/null
+++ b/src/renderer/src/aiCore/utils/__tests__/image.test.ts
@@ -0,0 +1,121 @@
+/**
+ * image.ts Unit Tests
+ * Tests for Gemini image generation utilities
+ */
+
+import type { Model, Provider } from '@renderer/types'
+import { SystemProviderIds } from '@renderer/types'
+import { describe, expect, it } from 'vitest'
+
+import { buildGeminiGenerateImageParams, isOpenRouterGeminiGenerateImageModel } from '../image'
+
+describe('image utils', () => {
+ describe('buildGeminiGenerateImageParams', () => {
+ it('should return correct response modalities', () => {
+ const result = buildGeminiGenerateImageParams()
+
+ expect(result).toEqual({
+ responseModalities: ['TEXT', 'IMAGE']
+ })
+ })
+
+ it('should return an object with responseModalities property', () => {
+ const result = buildGeminiGenerateImageParams()
+
+ expect(result).toHaveProperty('responseModalities')
+ expect(Array.isArray(result.responseModalities)).toBe(true)
+ expect(result.responseModalities).toHaveLength(2)
+ })
+ })
+
+ describe('isOpenRouterGeminiGenerateImageModel', () => {
+ const mockOpenRouterProvider: Provider = {
+ id: SystemProviderIds.openrouter,
+ name: 'OpenRouter',
+ apiKey: 'test-key',
+ apiHost: 'https://openrouter.ai/api/v1',
+ isSystem: true
+ } as Provider
+
+ const mockOtherProvider: Provider = {
+ id: SystemProviderIds.openai,
+ name: 'OpenAI',
+ apiKey: 'test-key',
+ apiHost: 'https://api.openai.com/v1',
+ isSystem: true
+ } as Provider
+
+ it('should return true for OpenRouter Gemini 2.5 Flash Image model', () => {
+ const model: Model = {
+ id: 'google/gemini-2.5-flash-image-preview',
+ name: 'Gemini 2.5 Flash Image',
+ provider: SystemProviderIds.openrouter
+ } as Model
+
+ const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
+ expect(result).toBe(true)
+ })
+
+ it('should return false for non-Gemini model on OpenRouter', () => {
+ const model: Model = {
+ id: 'openai/gpt-4',
+ name: 'GPT-4',
+ provider: SystemProviderIds.openrouter
+ } as Model
+
+ const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
+ expect(result).toBe(false)
+ })
+
+ it('should return false for Gemini model on non-OpenRouter provider', () => {
+ const model: Model = {
+ id: 'gemini-2.5-flash-image-preview',
+ name: 'Gemini 2.5 Flash Image',
+ provider: SystemProviderIds.gemini
+ } as Model
+
+ const result = isOpenRouterGeminiGenerateImageModel(model, mockOtherProvider)
+ expect(result).toBe(false)
+ })
+
+ it('should return false for Gemini model without image suffix', () => {
+ const model: Model = {
+ id: 'google/gemini-2.5-flash',
+ name: 'Gemini 2.5 Flash',
+ provider: SystemProviderIds.openrouter
+ } as Model
+
+ const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
+ expect(result).toBe(false)
+ })
+
+ it('should handle model ID with partial match', () => {
+ const model: Model = {
+ id: 'google/gemini-2.5-flash-image-generation',
+ name: 'Gemini Image Gen',
+ provider: SystemProviderIds.openrouter
+ } as Model
+
+ const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
+ expect(result).toBe(true)
+ })
+
+ it('should return false for custom provider', () => {
+ const customProvider: Provider = {
+ id: 'custom-provider-123',
+ name: 'Custom Provider',
+ apiKey: 'test-key',
+ apiHost: 'https://custom.com'
+ } as Provider
+
+ const model: Model = {
+ id: 'gemini-2.5-flash-image-preview',
+ name: 'Gemini 2.5 Flash Image',
+ provider: 'custom-provider-123'
+ } as Model
+
+ const result = isOpenRouterGeminiGenerateImageModel(model, customProvider)
+ expect(result).toBe(false)
+ })
+ })
+})
diff --git a/src/renderer/src/aiCore/utils/__tests__/mcp.test.ts b/src/renderer/src/aiCore/utils/__tests__/mcp.test.ts
new file mode 100644
index 0000000000..dc26a03c80
--- /dev/null
+++ b/src/renderer/src/aiCore/utils/__tests__/mcp.test.ts
@@ -0,0 +1,440 @@
+/**
+ * mcp.ts Unit Tests
+ * Tests for MCP tools configuration and conversion utilities
+ */
+
+import type { MCPTool } from '@renderer/types'
+import type { Tool } from 'ai'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+import { convertMcpToolsToAiSdkTools, setupToolsConfig } from '../mcp'
+
+// Mock dependencies
+vi.mock('@logger', () => ({
+ loggerService: {
+ withContext: () => ({
+ debug: vi.fn(),
+ error: vi.fn(),
+ warn: vi.fn(),
+ info: vi.fn()
+ })
+ }
+}))
+
+vi.mock('@renderer/utils/mcp-tools', () => ({
+ getMcpServerByTool: vi.fn(() => ({ id: 'test-server', autoApprove: false })),
+ isToolAutoApproved: vi.fn(() => false),
+ callMCPTool: vi.fn(async () => ({
+ content: [{ type: 'text', text: 'Tool executed successfully' }],
+ isError: false
+ }))
+}))
+
+vi.mock('@renderer/utils/userConfirmation', () => ({
+ requestToolConfirmation: vi.fn(async () => true)
+}))
+
+describe('mcp utils', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ })
+
+ describe('setupToolsConfig', () => {
+ it('should return undefined when no MCP tools provided', () => {
+ const result = setupToolsConfig()
+ expect(result).toBeUndefined()
+ })
+
+ it('should return undefined when empty MCP tools array provided', () => {
+ const result = setupToolsConfig([])
+ expect(result).toBeUndefined()
+ })
+
+ it('should convert MCP tools to AI SDK tools format', () => {
+ const mcpTools: MCPTool[] = [
+ {
+ id: 'test-tool-1',
+ serverId: 'test-server',
+ serverName: 'test-server',
+ name: 'test-tool',
+ description: 'A test tool',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {
+ query: { type: 'string' }
+ }
+ }
+ }
+ ]
+
+ const result = setupToolsConfig(mcpTools)
+
+ expect(result).not.toBeUndefined()
+ // Tools are now keyed by id (which includes serverId suffix) for uniqueness
+ expect(Object.keys(result!)).toEqual(['test-tool-1'])
+ expect(result!['test-tool-1']).toHaveProperty('description')
+ expect(result!['test-tool-1']).toHaveProperty('inputSchema')
+ expect(result!['test-tool-1']).toHaveProperty('execute')
+ })
+
+ it('should handle multiple MCP tools', () => {
+ const mcpTools: MCPTool[] = [
+ {
+ id: 'tool1-id',
+ serverId: 'server1',
+ serverName: 'server1',
+ name: 'tool1',
+ description: 'First tool',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {}
+ }
+ },
+ {
+ id: 'tool2-id',
+ serverId: 'server2',
+ serverName: 'server2',
+ name: 'tool2',
+ description: 'Second tool',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {}
+ }
+ }
+ ]
+
+ const result = setupToolsConfig(mcpTools)
+
+ expect(result).not.toBeUndefined()
+ expect(Object.keys(result!)).toHaveLength(2)
+ // Tools are keyed by id for uniqueness
+ expect(Object.keys(result!)).toEqual(['tool1-id', 'tool2-id'])
+ })
+ })
+
+ describe('convertMcpToolsToAiSdkTools', () => {
+ it('should convert single MCP tool to AI SDK tool', () => {
+ const mcpTools: MCPTool[] = [
+ {
+ id: 'get-weather-id',
+ serverId: 'weather-server',
+ serverName: 'weather-server',
+ name: 'get-weather',
+ description: 'Get weather information',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {
+ location: { type: 'string' }
+ },
+ required: ['location']
+ }
+ }
+ ]
+
+ const result = convertMcpToolsToAiSdkTools(mcpTools)
+
+ // Tools are keyed by id for uniqueness when multiple server instances exist
+ expect(Object.keys(result)).toEqual(['get-weather-id'])
+
+ const tool = result['get-weather-id'] as Tool
+ expect(tool.description).toBe('Get weather information')
+ expect(tool.inputSchema).toBeDefined()
+ expect(typeof tool.execute).toBe('function')
+ })
+
+ it('should handle tool without description', () => {
+ const mcpTools: MCPTool[] = [
+ {
+ id: 'no-desc-tool-id',
+ serverId: 'test-server',
+ serverName: 'test-server',
+ name: 'no-desc-tool',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {}
+ }
+ }
+ ]
+
+ const result = convertMcpToolsToAiSdkTools(mcpTools)
+
+ expect(Object.keys(result)).toEqual(['no-desc-tool-id'])
+ const tool = result['no-desc-tool-id'] as Tool
+ expect(tool.description).toBe('Tool from test-server')
+ })
+
+ it('should convert empty tools array', () => {
+ const result = convertMcpToolsToAiSdkTools([])
+ expect(result).toEqual({})
+ })
+
+ it('should handle complex input schemas', () => {
+ const mcpTools: MCPTool[] = [
+ {
+ id: 'complex-tool-id',
+ serverId: 'server',
+ serverName: 'server',
+ name: 'complex-tool',
+ description: 'Tool with complex schema',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {
+ name: { type: 'string' },
+ age: { type: 'number' },
+ tags: {
+ type: 'array',
+ items: { type: 'string' }
+ },
+ metadata: {
+ type: 'object',
+ properties: {
+ key: { type: 'string' }
+ }
+ }
+ },
+ required: ['name']
+ }
+ }
+ ]
+
+ const result = convertMcpToolsToAiSdkTools(mcpTools)
+
+ expect(Object.keys(result)).toEqual(['complex-tool-id'])
+ const tool = result['complex-tool-id'] as Tool
+ expect(tool.inputSchema).toBeDefined()
+ expect(typeof tool.execute).toBe('function')
+ })
+
+ it('should preserve tool id with special characters', () => {
+ const mcpTools: MCPTool[] = [
+ {
+ id: 'special-tool-id',
+ serverId: 'server',
+ serverName: 'server',
+ name: 'tool_with-special.chars',
+ description: 'Special chars tool',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {}
+ }
+ }
+ ]
+
+ const result = convertMcpToolsToAiSdkTools(mcpTools)
+ // Tools are keyed by id for uniqueness
+ expect(Object.keys(result)).toEqual(['special-tool-id'])
+ })
+
+ it('should handle multiple tools with different schemas', () => {
+ const mcpTools: MCPTool[] = [
+ {
+ id: 'string-tool-id',
+ serverId: 'server1',
+ serverName: 'server1',
+ name: 'string-tool',
+ description: 'String tool',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {
+ input: { type: 'string' }
+ }
+ }
+ },
+ {
+ id: 'number-tool-id',
+ serverId: 'server2',
+ serverName: 'server2',
+ name: 'number-tool',
+ description: 'Number tool',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {
+ count: { type: 'number' }
+ }
+ }
+ },
+ {
+ id: 'boolean-tool-id',
+ serverId: 'server3',
+ serverName: 'server3',
+ name: 'boolean-tool',
+ description: 'Boolean tool',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {
+ enabled: { type: 'boolean' }
+ }
+ }
+ }
+ ]
+
+ const result = convertMcpToolsToAiSdkTools(mcpTools)
+
+ // Tools are keyed by id for uniqueness
+ expect(Object.keys(result).sort()).toEqual(['boolean-tool-id', 'number-tool-id', 'string-tool-id'])
+ expect(result['string-tool-id']).toBeDefined()
+ expect(result['number-tool-id']).toBeDefined()
+ expect(result['boolean-tool-id']).toBeDefined()
+ })
+ })
+
+ describe('tool execution', () => {
+ it('should execute tool with user confirmation', async () => {
+ const { callMCPTool } = await import('@renderer/utils/mcp-tools')
+ const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
+
+ vi.mocked(requestToolConfirmation).mockResolvedValue(true)
+ vi.mocked(callMCPTool).mockResolvedValue({
+ content: [{ type: 'text', text: 'Success' }],
+ isError: false
+ })
+
+ const mcpTools: MCPTool[] = [
+ {
+ id: 'test-exec-tool-id',
+ serverId: 'test-server',
+ serverName: 'test-server',
+ name: 'test-exec-tool',
+ description: 'Test execution tool',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {}
+ }
+ }
+ ]
+
+ const tools = convertMcpToolsToAiSdkTools(mcpTools)
+ const tool = tools['test-exec-tool-id'] as Tool
+ const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
+
+ expect(requestToolConfirmation).toHaveBeenCalled()
+ expect(callMCPTool).toHaveBeenCalled()
+ expect(result).toEqual({
+ content: [{ type: 'text', text: 'Success' }],
+ isError: false
+ })
+ })
+
+ it('should handle user cancellation', async () => {
+ const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
+ const { callMCPTool } = await import('@renderer/utils/mcp-tools')
+
+ vi.mocked(requestToolConfirmation).mockResolvedValue(false)
+
+ const mcpTools: MCPTool[] = [
+ {
+ id: 'cancelled-tool-id',
+ serverId: 'test-server',
+ serverName: 'test-server',
+ name: 'cancelled-tool',
+ description: 'Tool to cancel',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {}
+ }
+ }
+ ]
+
+ const tools = convertMcpToolsToAiSdkTools(mcpTools)
+ const tool = tools['cancelled-tool-id'] as Tool
+ const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
+
+ expect(requestToolConfirmation).toHaveBeenCalled()
+ expect(callMCPTool).not.toHaveBeenCalled()
+ expect(result).toEqual({
+ content: [
+ {
+ type: 'text',
+ text: 'User declined to execute tool "cancelled-tool".'
+ }
+ ],
+ isError: false
+ })
+ })
+
+ it('should handle tool execution error', async () => {
+ const { callMCPTool } = await import('@renderer/utils/mcp-tools')
+ const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
+
+ vi.mocked(requestToolConfirmation).mockResolvedValue(true)
+ vi.mocked(callMCPTool).mockResolvedValue({
+ content: [{ type: 'text', text: 'Error occurred' }],
+ isError: true
+ })
+
+ const mcpTools: MCPTool[] = [
+ {
+ id: 'error-tool-id',
+ serverId: 'test-server',
+ serverName: 'test-server',
+ name: 'error-tool',
+ description: 'Tool that errors',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {}
+ }
+ }
+ ]
+
+ const tools = convertMcpToolsToAiSdkTools(mcpTools)
+ const tool = tools['error-tool-id'] as Tool
+
+ await expect(
+ tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
+ ).rejects.toEqual({
+ content: [{ type: 'text', text: 'Error occurred' }],
+ isError: true
+ })
+ })
+
+ it('should auto-approve when enabled', async () => {
+ const { callMCPTool, isToolAutoApproved } = await import('@renderer/utils/mcp-tools')
+ const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
+
+ vi.mocked(isToolAutoApproved).mockReturnValue(true)
+ vi.mocked(callMCPTool).mockResolvedValue({
+ content: [{ type: 'text', text: 'Auto-approved success' }],
+ isError: false
+ })
+
+ const mcpTools: MCPTool[] = [
+ {
+ id: 'auto-approve-tool-id',
+ serverId: 'test-server',
+ serverName: 'test-server',
+ name: 'auto-approve-tool',
+ description: 'Auto-approved tool',
+ type: 'mcp',
+ inputSchema: {
+ type: 'object',
+ properties: {}
+ }
+ }
+ ]
+
+ const tools = convertMcpToolsToAiSdkTools(mcpTools)
+ const tool = tools['auto-approve-tool-id'] as Tool
+ const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
+
+ expect(requestToolConfirmation).not.toHaveBeenCalled()
+ expect(callMCPTool).toHaveBeenCalled()
+ expect(result).toEqual({
+ content: [{ type: 'text', text: 'Auto-approved success' }],
+ isError: false
+ })
+ })
+ })
+})
diff --git a/src/renderer/src/aiCore/utils/__tests__/options.test.ts b/src/renderer/src/aiCore/utils/__tests__/options.test.ts
new file mode 100644
index 0000000000..9eeeac725b
--- /dev/null
+++ b/src/renderer/src/aiCore/utils/__tests__/options.test.ts
@@ -0,0 +1,1156 @@
+/**
+ * options.ts Unit Tests
+ * Tests for building provider-specific options
+ */
+
+import type { Assistant, Model, Provider } from '@renderer/types'
+import { OpenAIServiceTiers, SystemProviderIds } from '@renderer/types'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+import { buildProviderOptions } from '../options'
+
+// Mock dependencies
+vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
+ const actual = (await importOriginal()) as object
+ return {
+ ...actual,
+ baseProviderIdSchema: {
+ safeParse: vi.fn((id) => {
+ const baseProviders = [
+ 'openai',
+ 'openai-chat',
+ 'azure',
+ 'azure-responses',
+ 'huggingface',
+ 'anthropic',
+ 'google',
+ 'xai',
+ 'deepseek',
+ 'openrouter',
+ 'openai-compatible',
+ 'cherryin'
+ ]
+ if (baseProviders.includes(id)) {
+ return { success: true, data: id }
+ }
+ return { success: false }
+ })
+ },
+ customProviderIdSchema: {
+ safeParse: vi.fn((id) => {
+ const customProviders = [
+ 'google-vertex',
+ 'google-vertex-anthropic',
+ 'bedrock',
+ 'gateway',
+ 'aihubmix',
+ 'newapi',
+ 'ollama'
+ ]
+ if (customProviders.includes(id)) {
+ return { success: true, data: id }
+ }
+ return { success: false, error: new Error('Invalid provider') }
+ })
+ }
+ }
+})
+
+// Don't mock getAiSdkProviderId - use real implementation for more accurate tests
+
+vi.mock('@renderer/config/models', async (importOriginal) => ({
+ ...(await importOriginal()),
+ isOpenAIModel: vi.fn((model) => model.id.includes('gpt') || model.id.includes('o1')),
+ isQwenMTModel: vi.fn(() => false),
+ isSupportFlexServiceTierModel: vi.fn(() => true),
+ isOpenAILLMModel: vi.fn(() => true),
+ SYSTEM_MODELS: {
+ defaultModel: [
+ { id: 'default-1', name: 'Default 1' },
+ { id: 'default-2', name: 'Default 2' },
+ { id: 'default-3', name: 'Default 3' }
+ ]
+ }
+}))
+
+vi.mock(import('@renderer/utils/provider'), async (importOriginal) => {
+ return {
+ ...(await importOriginal()),
+ isSupportServiceTierProvider: vi.fn((provider) => {
+ return [SystemProviderIds.openai, SystemProviderIds.groq].includes(provider.id)
+ })
+ }
+})
+
+vi.mock('@renderer/store/settings', () => ({
+ default: (state = { settings: {} }) => state
+}))
+
+vi.mock('@renderer/hooks/useSettings', () => ({
+ getStoreSetting: vi.fn((key) => {
+ if (key === 'openAI') {
+ return { summaryText: 'off', verbosity: 'medium' } as any
+ }
+ return {}
+ })
+}))
+
+vi.mock('@renderer/services/AssistantService', () => ({
+ getDefaultAssistant: vi.fn(() => ({
+ id: 'default',
+ name: 'Default Assistant',
+ settings: {}
+ })),
+ getAssistantSettings: vi.fn(() => ({
+ reasoning_effort: 'medium',
+ maxTokens: 4096
+ })),
+ getProviderByModel: vi.fn((model: Model) => ({
+ id: model.provider,
+ name: 'Mock Provider'
+ }))
+}))
+
+vi.mock('../reasoning', () => ({
+ getOpenAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'medium' })),
+ getAnthropicReasoningParams: vi.fn(() => ({
+ thinking: { type: 'enabled', budgetTokens: 5000 }
+ })),
+ getGeminiReasoningParams: vi.fn(() => ({
+ thinkingConfig: { include_thoughts: true }
+ })),
+ getXAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'high' })),
+ getBedrockReasoningParams: vi.fn(() => ({
+ reasoningConfig: { type: 'enabled', budgetTokens: 5000 }
+ })),
+ getReasoningEffort: vi.fn(() => ({ reasoningEffort: 'medium' })),
+ getCustomParameters: vi.fn(() => ({})),
+ extractAiSdkStandardParams: vi.fn((customParams: Record) => {
+ const AI_SDK_STANDARD_PARAMS = ['topK', 'frequencyPenalty', 'presencePenalty', 'stopSequences', 'seed']
+ const standardParams: Record = {}
+ const providerParams: Record = {}
+ for (const [key, value] of Object.entries(customParams)) {
+ if (AI_SDK_STANDARD_PARAMS.includes(key)) {
+ standardParams[key] = value
+ } else {
+ providerParams[key] = value
+ }
+ }
+ return { standardParams, providerParams }
+ })
+}))
+
+vi.mock('../image', () => ({
+ buildGeminiGenerateImageParams: vi.fn(() => ({
+ responseModalities: ['TEXT', 'IMAGE']
+ }))
+}))
+
+vi.mock('../websearch', () => ({
+ getWebSearchParams: vi.fn(() => ({ enable_search: true }))
+}))
+
+vi.mock('../../prepareParams/header', () => ({
+ addAnthropicHeaders: vi.fn(() => ['context-1m-2025-08-07'])
+}))
+
+const ensureWindowApi = () => {
+ const globalWindow = window as any
+ globalWindow.api = globalWindow.api || {}
+ globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
+}
+
+ensureWindowApi()
+
+describe('options utils', () => {
+ const mockAssistant: Assistant = {
+ id: 'test-assistant',
+ name: 'Test Assistant',
+ settings: {}
+ } as Assistant
+
+ const mockModel: Model = {
+ id: 'gpt-4',
+ name: 'GPT-4',
+ provider: SystemProviderIds.openai
+ } as Model
+
+ beforeEach(async () => {
+ vi.clearAllMocks()
+ // Reset getCustomParameters to return empty object by default
+ const { getCustomParameters } = await import('../reasoning')
+ vi.mocked(getCustomParameters).mockReturnValue({})
+ })
+
+ describe('buildProviderOptions', () => {
+ describe('OpenAI provider', () => {
+ const openaiProvider: Provider = {
+ id: SystemProviderIds.openai,
+ name: 'OpenAI',
+ type: 'openai-response',
+ apiKey: 'test-key',
+ apiHost: 'https://api.openai.com/v1',
+ isSystem: true
+ } as Provider
+
+ it('should build basic OpenAI options', () => {
+ const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('openai')
+ expect(result.providerOptions.openai).toBeDefined()
+ expect(result.standardParams).toBeDefined()
+ })
+
+ it('should include reasoning parameters when enabled', () => {
+ const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
+ enableReasoning: true,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions.openai).toHaveProperty('reasoningEffort')
+ expect(result.providerOptions.openai.reasoningEffort).toBe('medium')
+ })
+
+ it('should include service tier when supported', () => {
+ const providerWithServiceTier: Provider = {
+ ...openaiProvider,
+ serviceTier: OpenAIServiceTiers.auto
+ }
+
+ const result = buildProviderOptions(mockAssistant, mockModel, providerWithServiceTier, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions.openai).toHaveProperty('serviceTier')
+ expect(result.providerOptions.openai.serviceTier).toBe(OpenAIServiceTiers.auto)
+ })
+ })
+
+ describe('Anthropic provider', () => {
+ const anthropicProvider: Provider = {
+ id: SystemProviderIds.anthropic,
+ name: 'Anthropic',
+ type: 'anthropic',
+ apiKey: 'test-key',
+ apiHost: 'https://api.anthropic.com',
+ isSystem: true
+ } as Provider
+
+ const anthropicModel: Model = {
+ id: 'claude-3-5-sonnet-20241022',
+ name: 'Claude 3.5 Sonnet',
+ provider: SystemProviderIds.anthropic
+ } as Model
+
+ it('should build basic Anthropic options', () => {
+ const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('anthropic')
+ expect(result.providerOptions.anthropic).toBeDefined()
+ })
+
+ it('should include reasoning parameters when enabled', () => {
+ const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
+ enableReasoning: true,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions.anthropic).toHaveProperty('thinking')
+ expect(result.providerOptions.anthropic.thinking).toEqual({
+ type: 'enabled',
+ budgetTokens: 5000
+ })
+ })
+ })
+
+ describe('Google provider', () => {
+ const googleProvider: Provider = {
+ id: SystemProviderIds.gemini,
+ name: 'Google',
+ type: 'gemini',
+ apiKey: 'test-key',
+ apiHost: 'https://generativelanguage.googleapis.com',
+ isSystem: true,
+ models: [{ id: 'gemini-2.0-flash-exp' }] as Model[]
+ } as Provider
+
+ const googleModel: Model = {
+ id: 'gemini-2.0-flash-exp',
+ name: 'Gemini 2.0 Flash',
+ provider: SystemProviderIds.gemini
+ } as Model
+
+ it('should build basic Google options', () => {
+ const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('google')
+ expect(result.providerOptions.google).toBeDefined()
+ })
+
+ it('should include reasoning parameters when enabled', () => {
+ const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
+ enableReasoning: true,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions.google).toHaveProperty('thinkingConfig')
+ expect(result.providerOptions.google.thinkingConfig).toEqual({
+ include_thoughts: true
+ })
+ })
+
+ it('should include image generation parameters when enabled', () => {
+ const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: true
+ })
+
+ expect(result.providerOptions.google).toHaveProperty('responseModalities')
+ expect(result.providerOptions.google.responseModalities).toEqual(['TEXT', 'IMAGE'])
+ })
+ })
+
+ describe('xAI provider', () => {
+ const xaiProvider = {
+ id: SystemProviderIds.grok,
+ name: 'xAI',
+ type: 'new-api',
+ apiKey: 'test-key',
+ apiHost: 'https://api.x.ai/v1',
+ isSystem: true,
+ models: [] as Model[]
+ } as Provider
+
+ const xaiModel: Model = {
+ id: 'grok-2-latest',
+ name: 'Grok 2',
+ provider: SystemProviderIds.grok
+ } as Model
+
+ it('should build basic xAI options', () => {
+ const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('xai')
+ expect(result.providerOptions.xai).toBeDefined()
+ })
+
+ it('should include reasoning parameters when enabled', () => {
+ const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
+ enableReasoning: true,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions.xai).toHaveProperty('reasoningEffort')
+ expect(result.providerOptions.xai.reasoningEffort).toBe('high')
+ })
+ })
+
+ describe('DeepSeek provider', () => {
+ const deepseekProvider: Provider = {
+ id: SystemProviderIds.deepseek,
+ name: 'DeepSeek',
+ type: 'openai',
+ apiKey: 'test-key',
+ apiHost: 'https://api.deepseek.com',
+ isSystem: true
+ } as Provider
+
+ const deepseekModel: Model = {
+ id: 'deepseek-chat',
+ name: 'DeepSeek Chat',
+ provider: SystemProviderIds.deepseek
+ } as Model
+
+ it('should build basic DeepSeek options', () => {
+ const result = buildProviderOptions(mockAssistant, deepseekModel, deepseekProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+ expect(result.providerOptions).toHaveProperty('deepseek')
+ expect(result.providerOptions.deepseek).toBeDefined()
+ })
+ })
+
+ describe('OpenRouter provider', () => {
+ const openrouterProvider: Provider = {
+ id: SystemProviderIds.openrouter,
+ name: 'OpenRouter',
+ type: 'openai',
+ apiKey: 'test-key',
+ apiHost: 'https://openrouter.ai/api/v1',
+ isSystem: true
+ } as Provider
+
+ const openrouterModel: Model = {
+ id: 'openai/gpt-4',
+ name: 'GPT-4',
+ provider: SystemProviderIds.openrouter
+ } as Model
+
+ it('should build basic OpenRouter options', () => {
+ const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('openrouter')
+ expect(result.providerOptions.openrouter).toBeDefined()
+ })
+
+ it('should include web search parameters when enabled', () => {
+ const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
+ enableReasoning: false,
+ enableWebSearch: true,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions.openrouter).toHaveProperty('enable_search')
+ })
+ })
+
+ describe('Custom parameters', () => {
+ it('should merge custom provider-specific parameters', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ vi.mocked(getCustomParameters).mockReturnValue({
+ custom_param: 'custom_value',
+ another_param: 123
+ })
+
+ const result = buildProviderOptions(
+ mockAssistant,
+ mockModel,
+ {
+ id: SystemProviderIds.openai,
+ name: 'OpenAI',
+ type: 'openai',
+ apiKey: 'test-key',
+ apiHost: 'https://api.openai.com/v1'
+ } as Provider,
+ {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ }
+ )
+
+ expect(result.providerOptions).toStrictEqual({
+ openai: {
+ custom_param: 'custom_value',
+ another_param: 123,
+ serviceTier: undefined,
+ textVerbosity: undefined
+ }
+ })
+ })
+
+ it('should extract AI SDK standard params from custom parameters', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ vi.mocked(getCustomParameters).mockReturnValue({
+ topK: 5,
+ frequencyPenalty: 0.5,
+ presencePenalty: 0.3,
+ seed: 42,
+ custom_param: 'custom_value'
+ })
+
+ const result = buildProviderOptions(
+ mockAssistant,
+ mockModel,
+ {
+ id: SystemProviderIds.gemini,
+ name: 'Google',
+ type: 'gemini',
+ apiKey: 'test-key',
+ apiHost: 'https://generativelanguage.googleapis.com'
+ } as Provider,
+ {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ }
+ )
+
+ // Standard params should be extracted and returned separately
+ expect(result.standardParams).toEqual({
+ topK: 5,
+ frequencyPenalty: 0.5,
+ presencePenalty: 0.3,
+ seed: 42
+ })
+
+ // Provider-specific params should still be in providerOptions
+ expect(result.providerOptions.google).toHaveProperty('custom_param')
+ expect(result.providerOptions.google.custom_param).toBe('custom_value')
+
+ // Standard params should NOT be in providerOptions
+ expect(result.providerOptions.google).not.toHaveProperty('topK')
+ expect(result.providerOptions.google).not.toHaveProperty('frequencyPenalty')
+ expect(result.providerOptions.google).not.toHaveProperty('presencePenalty')
+ expect(result.providerOptions.google).not.toHaveProperty('seed')
+ })
+
+ it('should handle stopSequences in custom parameters', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ vi.mocked(getCustomParameters).mockReturnValue({
+ stopSequences: ['STOP', 'END'],
+ custom_param: 'value'
+ })
+
+ const result = buildProviderOptions(
+ mockAssistant,
+ mockModel,
+ {
+ id: SystemProviderIds.gemini,
+ name: 'Google',
+ type: 'gemini',
+ apiKey: 'test-key',
+ apiHost: 'https://generativelanguage.googleapis.com'
+ } as Provider,
+ {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ }
+ )
+
+ expect(result.standardParams).toEqual({
+ stopSequences: ['STOP', 'END']
+ })
+ expect(result.providerOptions.google).not.toHaveProperty('stopSequences')
+ })
+ })
+
+ describe('Multiple capabilities', () => {
+ const googleProvider = {
+ id: SystemProviderIds.gemini,
+ name: 'Google',
+ type: 'gemini',
+ apiKey: 'test-key',
+ apiHost: 'https://generativelanguage.googleapis.com',
+ isSystem: true,
+ models: [] as Model[]
+ } as Provider
+
+ const googleModel: Model = {
+ id: 'gemini-2.0-flash-exp',
+ name: 'Gemini 2.0 Flash',
+ provider: SystemProviderIds.gemini
+ } as Model
+
+ it('should combine reasoning and image generation', () => {
+ const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
+ enableReasoning: true,
+ enableWebSearch: false,
+ enableGenerateImage: true
+ })
+
+ expect(result.providerOptions.google).toHaveProperty('thinkingConfig')
+ expect(result.providerOptions.google).toHaveProperty('responseModalities')
+ })
+
+ it('should handle all capabilities enabled', () => {
+ const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
+ enableReasoning: true,
+ enableWebSearch: true,
+ enableGenerateImage: true
+ })
+
+ expect(result.providerOptions.google).toBeDefined()
+ expect(Object.keys(result.providerOptions.google).length).toBeGreaterThan(0)
+ })
+ })
+
+ describe('Vertex AI providers', () => {
+ it('should map google-vertex to google', () => {
+ const vertexProvider = {
+ id: 'google-vertex',
+ name: 'Vertex AI',
+ type: 'vertexai',
+ apiKey: 'test-key',
+ apiHost: 'https://vertex-ai.googleapis.com',
+ models: [] as Model[]
+ } as Provider
+
+ const vertexModel: Model = {
+ id: 'gemini-2.0-flash-exp',
+ name: 'Gemini 2.0 Flash',
+ provider: 'google-vertex'
+ } as Model
+
+ const result = buildProviderOptions(mockAssistant, vertexModel, vertexProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('google')
+ })
+
+ it('should map google-vertex-anthropic to anthropic', () => {
+ const vertexAnthropicProvider = {
+ id: 'google-vertex-anthropic',
+ name: 'Vertex AI Anthropic',
+ type: 'vertex-anthropic',
+ apiKey: 'test-key',
+ apiHost: 'https://vertex-ai.googleapis.com',
+ models: [] as Model[]
+ } as Provider
+
+ const vertexModel: Model = {
+ id: 'claude-3-5-sonnet-20241022',
+ name: 'Claude 3.5 Sonnet',
+ provider: 'google-vertex-anthropic'
+ } as Model
+
+ const result = buildProviderOptions(mockAssistant, vertexModel, vertexAnthropicProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('anthropic')
+ })
+ })
+
+ describe('AWS Bedrock provider', () => {
+ const bedrockProvider = {
+ id: 'bedrock',
+ name: 'AWS Bedrock',
+ type: 'aws-bedrock',
+ apiKey: 'test-key',
+ apiHost: 'https://bedrock.us-east-1.amazonaws.com',
+ models: [] as Model[]
+ } as Provider
+
+ const bedrockModel: Model = {
+ id: 'anthropic.claude-sonnet-4-20250514-v1:0',
+ name: 'Claude Sonnet 4',
+ provider: 'bedrock'
+ } as Model
+
+ it('should build basic Bedrock options', () => {
+ const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('bedrock')
+ expect(result.providerOptions.bedrock).toBeDefined()
+ })
+
+ it('should include anthropicBeta when Anthropic headers are needed', async () => {
+ const { addAnthropicHeaders } = await import('../../prepareParams/header')
+ vi.mocked(addAnthropicHeaders).mockReturnValue(['interleaved-thinking-2025-05-14', 'context-1m-2025-08-07'])
+
+ const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions.bedrock).toHaveProperty('anthropicBeta')
+ expect(result.providerOptions.bedrock.anthropicBeta).toEqual([
+ 'interleaved-thinking-2025-05-14',
+ 'context-1m-2025-08-07'
+ ])
+ })
+
+ it('should include reasoning parameters when enabled', () => {
+ const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
+ enableReasoning: true,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions.bedrock).toHaveProperty('reasoningConfig')
+ expect(result.providerOptions.bedrock.reasoningConfig).toEqual({
+ type: 'enabled',
+ budgetTokens: 5000
+ })
+ })
+ })
+
+ describe('AI Gateway provider', () => {
+ const gatewayProvider: Provider = {
+ id: SystemProviderIds.gateway,
+ name: 'Vercel AI Gateway',
+ type: 'gateway',
+ apiKey: 'test-key',
+ apiHost: 'https://gateway.vercel.com',
+ isSystem: true
+ } as Provider
+
+ it('should build OpenAI options for OpenAI models through gateway', () => {
+ const openaiModel: Model = {
+ id: 'openai/gpt-4',
+ name: 'GPT-4',
+ provider: SystemProviderIds.gateway
+ } as Model
+
+ const result = buildProviderOptions(mockAssistant, openaiModel, gatewayProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('openai')
+ expect(result.providerOptions.openai).toBeDefined()
+ })
+
+ it('should build Anthropic options for Anthropic models through gateway', () => {
+ const anthropicModel: Model = {
+ id: 'anthropic/claude-3-5-sonnet-20241022',
+ name: 'Claude 3.5 Sonnet',
+ provider: SystemProviderIds.gateway
+ } as Model
+
+ const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('anthropic')
+ expect(result.providerOptions.anthropic).toBeDefined()
+ })
+
+ it('should build Google options for Gemini models through gateway', () => {
+ const geminiModel: Model = {
+ id: 'google/gemini-2.0-flash-exp',
+ name: 'Gemini 2.0 Flash',
+ provider: SystemProviderIds.gateway
+ } as Model
+
+ const result = buildProviderOptions(mockAssistant, geminiModel, gatewayProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('google')
+ expect(result.providerOptions.google).toBeDefined()
+ })
+
+ it('should build xAI options for Grok models through gateway', () => {
+ const grokModel: Model = {
+ id: 'xai/grok-2-latest',
+ name: 'Grok 2',
+ provider: SystemProviderIds.gateway
+ } as Model
+
+ const result = buildProviderOptions(mockAssistant, grokModel, gatewayProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('xai')
+ expect(result.providerOptions.xai).toBeDefined()
+ })
+
+ it('should include reasoning parameters for Anthropic models when enabled', () => {
+ const anthropicModel: Model = {
+ id: 'anthropic/claude-3-5-sonnet-20241022',
+ name: 'Claude 3.5 Sonnet',
+ provider: SystemProviderIds.gateway
+ } as Model
+
+ const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, {
+ enableReasoning: true,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions.anthropic).toHaveProperty('thinking')
+ expect(result.providerOptions.anthropic.thinking).toEqual({
+ type: 'enabled',
+ budgetTokens: 5000
+ })
+ })
+
+ it('should merge gateway routing options from custom parameters', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ vi.mocked(getCustomParameters).mockReturnValue({
+ gateway: {
+ order: ['vertex', 'anthropic'],
+ only: ['vertex', 'anthropic']
+ }
+ })
+
+ const anthropicModel: Model = {
+ id: 'anthropic/claude-3-5-sonnet-20241022',
+ name: 'Claude 3.5 Sonnet',
+ provider: SystemProviderIds.gateway
+ } as Model
+
+ const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ // Should have both anthropic provider options and gateway routing options
+ expect(result.providerOptions).toHaveProperty('anthropic')
+ expect(result.providerOptions).toHaveProperty('gateway')
+ expect(result.providerOptions.gateway).toEqual({
+ order: ['vertex', 'anthropic'],
+ only: ['vertex', 'anthropic']
+ })
+ })
+
+ it('should combine provider-specific options with gateway routing options', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ vi.mocked(getCustomParameters).mockReturnValue({
+ gateway: {
+ order: ['openai', 'anthropic']
+ }
+ })
+
+ const openaiModel: Model = {
+ id: 'openai/gpt-4',
+ name: 'GPT-4',
+ provider: SystemProviderIds.gateway
+ } as Model
+
+ const result = buildProviderOptions(mockAssistant, openaiModel, gatewayProvider, {
+ enableReasoning: true,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ // Should have OpenAI provider options with reasoning
+ expect(result.providerOptions.openai).toBeDefined()
+ expect(result.providerOptions.openai).toHaveProperty('reasoningEffort')
+
+ // Should also have gateway routing options
+ expect(result.providerOptions.gateway).toBeDefined()
+ expect(result.providerOptions.gateway.order).toEqual(['openai', 'anthropic'])
+ })
+
+ it('should build generic options for unknown model types through gateway', () => {
+ const unknownModel: Model = {
+ id: 'unknown-provider/model-name',
+ name: 'Unknown Model',
+ provider: SystemProviderIds.gateway
+ } as Model
+
+ const result = buildProviderOptions(mockAssistant, unknownModel, gatewayProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ expect(result.providerOptions).toHaveProperty('openai-compatible')
+ expect(result.providerOptions['openai-compatible']).toBeDefined()
+ })
+ })
+
+ describe('Proxy provider custom parameters mapping', () => {
+ it('should map cherryin provider ID to actual AI SDK provider ID (Google)', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ // Mock Cherry In provider that uses Google SDK
+ const cherryinProvider = {
+ id: 'cherryin',
+ name: 'Cherry In',
+ type: 'gemini', // Using Google SDK
+ apiKey: 'test-key',
+ apiHost: 'https://cherryin.com',
+ models: [] as Model[]
+ } as Provider
+
+ const geminiModel: Model = {
+ id: 'gemini-2.0-flash-exp',
+ name: 'Gemini 2.0 Flash',
+ provider: 'cherryin'
+ } as Model
+
+ // User provides custom parameters with Cherry Studio provider ID
+ vi.mocked(getCustomParameters).mockReturnValue({
+ cherryin: {
+ customOption1: 'value1',
+ customOption2: 'value2'
+ }
+ })
+
+ const result = buildProviderOptions(mockAssistant, geminiModel, cherryinProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ // Should map to 'google' AI SDK provider, not 'cherryin'
+ expect(result.providerOptions).toHaveProperty('google')
+ expect(result.providerOptions).not.toHaveProperty('cherryin')
+ expect(result.providerOptions.google).toMatchObject({
+ customOption1: 'value1',
+ customOption2: 'value2'
+ })
+ })
+
+ it('should map cherryin provider ID to actual AI SDK provider ID (OpenAI)', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ // Mock Cherry In provider that uses OpenAI SDK
+ const cherryinProvider = {
+ id: 'cherryin',
+ name: 'Cherry In',
+ type: 'openai-response', // Using OpenAI SDK
+ apiKey: 'test-key',
+ apiHost: 'https://cherryin.com',
+ models: [] as Model[]
+ } as Provider
+
+ const openaiModel: Model = {
+ id: 'gpt-4',
+ name: 'GPT-4',
+ provider: 'cherryin'
+ } as Model
+
+ // User provides custom parameters with Cherry Studio provider ID
+ vi.mocked(getCustomParameters).mockReturnValue({
+ cherryin: {
+ customOpenAIOption: 'openai_value'
+ }
+ })
+
+ const result = buildProviderOptions(mockAssistant, openaiModel, cherryinProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ // Should map to 'openai' AI SDK provider, not 'cherryin'
+ expect(result.providerOptions).toHaveProperty('openai')
+ expect(result.providerOptions).not.toHaveProperty('cherryin')
+ expect(result.providerOptions.openai).toMatchObject({
+ customOpenAIOption: 'openai_value'
+ })
+ })
+
+ it('should allow direct AI SDK provider ID in custom parameters', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ const geminiProvider = {
+ id: SystemProviderIds.gemini,
+ name: 'Google',
+ type: 'gemini',
+ apiKey: 'test-key',
+ apiHost: 'https://generativelanguage.googleapis.com',
+ models: [] as Model[]
+ } as Provider
+
+ const geminiModel: Model = {
+ id: 'gemini-2.0-flash-exp',
+ name: 'Gemini 2.0 Flash',
+ provider: SystemProviderIds.gemini
+ } as Model
+
+ // User provides custom parameters directly with AI SDK provider ID
+ vi.mocked(getCustomParameters).mockReturnValue({
+ google: {
+ directGoogleOption: 'google_value'
+ }
+ })
+
+ const result = buildProviderOptions(mockAssistant, geminiModel, geminiProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ // Should merge directly to 'google' provider
+ expect(result.providerOptions.google).toMatchObject({
+ directGoogleOption: 'google_value'
+ })
+ })
+
+ it('should map gateway provider custom parameters to actual AI SDK provider', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ const gatewayProvider: Provider = {
+ id: SystemProviderIds.gateway,
+ name: 'Vercel AI Gateway',
+ type: 'gateway',
+ apiKey: 'test-key',
+ apiHost: 'https://gateway.vercel.com',
+ isSystem: true
+ } as Provider
+
+ const anthropicModel: Model = {
+ id: 'anthropic/claude-3-5-sonnet-20241022',
+ name: 'Claude 3.5 Sonnet',
+ provider: SystemProviderIds.gateway
+ } as Model
+
+ // User provides both gateway routing options and gateway-scoped custom parameters
+ vi.mocked(getCustomParameters).mockReturnValue({
+ gateway: {
+ order: ['vertex', 'anthropic'],
+ only: ['vertex']
+ },
+ customParam: 'should_go_to_anthropic'
+ })
+
+ const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ // Gateway routing options should be preserved
+ expect(result.providerOptions.gateway).toEqual({
+ order: ['vertex', 'anthropic'],
+ only: ['vertex']
+ })
+
+ // Custom parameters should go to the actual AI SDK provider (anthropic)
+ expect(result.providerOptions.anthropic).toMatchObject({
+ customParam: 'should_go_to_anthropic'
+ })
+ })
+
+ it('should handle mixed custom parameters (AI SDK provider ID + custom params)', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ const openaiProvider: Provider = {
+ id: SystemProviderIds.openai,
+ name: 'OpenAI',
+ type: 'openai-response',
+ apiKey: 'test-key',
+ apiHost: 'https://api.openai.com/v1',
+ isSystem: true
+ } as Provider
+
+ // User provides both direct AI SDK provider params and custom params
+ vi.mocked(getCustomParameters).mockReturnValue({
+ openai: {
+ providerSpecific: 'value1'
+ },
+ customParam1: 'value2',
+ customParam2: 123
+ })
+
+ const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ // Should merge both into 'openai' provider options
+ expect(result.providerOptions.openai).toMatchObject({
+ providerSpecific: 'value1',
+ customParam1: 'value2',
+ customParam2: 123
+ })
+ })
+
+ // Note: For proxy providers like aihubmix/newapi, users should write AI SDK provider ID (google/anthropic)
+ // instead of the Cherry Studio provider ID for custom parameters to work correctly
+
+ it('should handle cherryin fallback to openai-compatible with custom parameters', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ // Mock cherryin provider that falls back to openai-compatible (default case)
+ const cherryinProvider = {
+ id: 'cherryin',
+ name: 'Cherry In',
+ type: 'openai',
+ apiKey: 'test-key',
+ apiHost: 'https://cherryin.com',
+ models: [] as Model[]
+ } as Provider
+
+ const testModel: Model = {
+ id: 'some-model',
+ name: 'Some Model',
+ provider: 'cherryin'
+ } as Model
+
+ // User provides custom parameters with cherryin provider ID
+ vi.mocked(getCustomParameters).mockReturnValue({
+ customCherryinOption: 'cherryin_value'
+ })
+
+ const result = buildProviderOptions(mockAssistant, testModel, cherryinProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ // When cherryin falls back to default case, it should use rawProviderId (cherryin)
+ // User's cherryin params should merge with the provider options
+ expect(result.providerOptions).toHaveProperty('cherryin')
+ expect(result.providerOptions.cherryin).toMatchObject({
+ customCherryinOption: 'cherryin_value'
+ })
+ })
+
+ it('should handle cross-provider configurations', async () => {
+ const { getCustomParameters } = await import('../reasoning')
+
+ const openaiProvider: Provider = {
+ id: SystemProviderIds.openai,
+ name: 'OpenAI',
+ type: 'openai-response',
+ apiKey: 'test-key',
+ apiHost: 'https://api.openai.com/v1',
+ isSystem: true
+ } as Provider
+
+ // User provides parameters for multiple providers
+ // In real usage, anthropic/google params would be treated as regular params for openai provider
+ vi.mocked(getCustomParameters).mockReturnValue({
+ openai: {
+ openaiSpecific: 'openai_value'
+ },
+ customParam: 'value'
+ })
+
+ const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
+ enableReasoning: false,
+ enableWebSearch: false,
+ enableGenerateImage: false
+ })
+
+ // Should have openai provider options with both scoped and custom params
+ expect(result.providerOptions).toHaveProperty('openai')
+ expect(result.providerOptions.openai).toMatchObject({
+ openaiSpecific: 'openai_value',
+ customParam: 'value'
+ })
+ })
+ })
+ })
+})
diff --git a/src/renderer/src/aiCore/utils/__tests__/reasoning.poe.test.ts b/src/renderer/src/aiCore/utils/__tests__/reasoning.poe.test.ts
new file mode 100644
index 0000000000..90876998da
--- /dev/null
+++ b/src/renderer/src/aiCore/utils/__tests__/reasoning.poe.test.ts
@@ -0,0 +1,288 @@
+import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
+import { SystemProviderIds } from '@renderer/types'
+import { describe, expect, it, vi } from 'vitest'
+
+import { getReasoningEffort } from '../reasoning'
+
+// Mock logger
+vi.mock('@logger', () => ({
+ loggerService: {
+ withContext: () => ({
+ warn: vi.fn(),
+ info: vi.fn(),
+ error: vi.fn()
+ })
+ }
+}))
+
+vi.mock('@renderer/store/settings', () => ({
+ default: {},
+ settingsSlice: {
+ name: 'settings',
+ reducer: vi.fn(),
+ actions: {}
+ }
+}))
+
+vi.mock('@renderer/store/assistants', () => {
+ const mockAssistantsSlice = {
+ name: 'assistants',
+ reducer: vi.fn((state = { entities: {}, ids: [] }) => state),
+ actions: {
+ updateTopicUpdatedAt: vi.fn(() => ({ type: 'UPDATE_TOPIC_UPDATED_AT' }))
+ }
+ }
+
+ return {
+ default: mockAssistantsSlice.reducer,
+ updateTopicUpdatedAt: vi.fn(() => ({ type: 'UPDATE_TOPIC_UPDATED_AT' })),
+ assistantsSlice: mockAssistantsSlice
+ }
+})
+
+// Mock provider service
+vi.mock('@renderer/services/AssistantService', () => ({
+ getProviderByModel: (model: Model) => ({
+ id: model.provider,
+ name: 'Poe',
+ type: 'openai'
+ }),
+ getAssistantSettings: (assistant: Assistant) => assistant.settings || {}
+}))
+
+describe('Poe Provider Reasoning Support', () => {
+ const createPoeModel = (id: string): Model => ({
+ id,
+ name: id,
+ provider: SystemProviderIds.poe,
+ group: 'poe'
+ })
+
+ const createAssistant = (reasoning_effort?: ReasoningEffortOption, maxTokens?: number): Assistant => ({
+ id: 'test-assistant',
+ name: 'Test Assistant',
+ emoji: '🤖',
+ prompt: '',
+ topics: [],
+ messages: [],
+ type: 'assistant',
+ regularPhrases: [],
+ settings: {
+ reasoning_effort,
+ maxTokens
+ }
+ })
+
+ describe('GPT-5 Series Models', () => {
+ it('should return reasoning_effort in extra_body for GPT-5 model with low effort', () => {
+ const model = createPoeModel('gpt-5')
+ const assistant = createAssistant('low')
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toEqual({
+ extra_body: {
+ reasoning_effort: 'low'
+ }
+ })
+ })
+
+ it('should return reasoning_effort in extra_body for GPT-5 model with medium effort', () => {
+ const model = createPoeModel('gpt-5')
+ const assistant = createAssistant('medium')
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toEqual({
+ extra_body: {
+ reasoning_effort: 'medium'
+ }
+ })
+ })
+
+ it('should return reasoning_effort in extra_body for GPT-5 model with high effort', () => {
+ const model = createPoeModel('gpt-5')
+ const assistant = createAssistant('high')
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toEqual({
+ extra_body: {
+ reasoning_effort: 'high'
+ }
+ })
+ })
+
+ it('should convert auto to medium for GPT-5 model in extra_body', () => {
+ const model = createPoeModel('gpt-5')
+ const assistant = createAssistant('auto')
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toEqual({
+ extra_body: {
+ reasoning_effort: 'medium'
+ }
+ })
+ })
+
+ it('should return reasoning_effort in extra_body for GPT-5.1 model', () => {
+ const model = createPoeModel('gpt-5.1')
+ const assistant = createAssistant('medium')
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toEqual({
+ extra_body: {
+ reasoning_effort: 'medium'
+ }
+ })
+ })
+ })
+
+ describe('Claude Models', () => {
+ it('should return thinking_budget in extra_body for Claude 3.7 Sonnet', () => {
+ const model = createPoeModel('claude-3.7-sonnet')
+ const assistant = createAssistant('medium', 4096)
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toHaveProperty('extra_body')
+ expect(result.extra_body).toHaveProperty('thinking_budget')
+ expect(typeof result.extra_body?.thinking_budget).toBe('number')
+ expect(result.extra_body?.thinking_budget).toBeGreaterThan(0)
+ })
+
+ it('should return thinking_budget in extra_body for Claude Sonnet 4', () => {
+ const model = createPoeModel('claude-sonnet-4')
+ const assistant = createAssistant('high', 8192)
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toHaveProperty('extra_body')
+ expect(result.extra_body).toHaveProperty('thinking_budget')
+ expect(typeof result.extra_body?.thinking_budget).toBe('number')
+ })
+
+ it('should calculate thinking_budget based on effort ratio and maxTokens', () => {
+ const model = createPoeModel('claude-3.7-sonnet')
+ const assistant = createAssistant('low', 4096)
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result.extra_body?.thinking_budget).toBeGreaterThanOrEqual(1024)
+ })
+ })
+
+ describe('Gemini Models', () => {
+ it('should return thinking_budget in extra_body for Gemini 2.5 Flash', () => {
+ const model = createPoeModel('gemini-2.5-flash')
+ const assistant = createAssistant('medium')
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toHaveProperty('extra_body')
+ expect(result.extra_body).toHaveProperty('thinking_budget')
+ expect(typeof result.extra_body?.thinking_budget).toBe('number')
+ })
+
+ it('should return thinking_budget in extra_body for Gemini 2.5 Pro', () => {
+ const model = createPoeModel('gemini-2.5-pro')
+ const assistant = createAssistant('high')
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toHaveProperty('extra_body')
+ expect(result.extra_body).toHaveProperty('thinking_budget')
+ })
+
+ it('should use -1 for auto effort', () => {
+ const model = createPoeModel('gemini-2.5-flash')
+ const assistant = createAssistant('auto')
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result.extra_body?.thinking_budget).toBe(-1)
+ })
+
+ it('should calculate thinking_budget for non-auto effort', () => {
+ const model = createPoeModel('gemini-2.5-flash')
+ const assistant = createAssistant('low')
+ const result = getReasoningEffort(assistant, model)
+
+ expect(typeof result.extra_body?.thinking_budget).toBe('number')
+ })
+ })
+
+ describe('No Reasoning Effort', () => {
+ it('should return empty object when reasoning_effort is not set', () => {
+ const model = createPoeModel('gpt-5')
+ const assistant = createAssistant(undefined)
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toEqual({})
+ })
+
+ it('should return empty object when reasoning_effort is "none"', () => {
+ const model = createPoeModel('gpt-5')
+ const assistant = createAssistant('none')
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toEqual({})
+ })
+ })
+
+ describe('Non-Reasoning Models', () => {
+ it('should return empty object for non-reasoning models', () => {
+ const model = createPoeModel('gpt-4')
+ const assistant = createAssistant('medium')
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toEqual({})
+ })
+ })
+
+ describe('Edge Cases: Models Without Token Limit Configuration', () => {
+ it('should return empty object for Claude models without token limit configuration', () => {
+ const model = createPoeModel('claude-unknown-variant')
+ const assistant = createAssistant('medium', 4096)
+ const result = getReasoningEffort(assistant, model)
+
+ // Should return empty object when token limit is not found
+ expect(result).toEqual({})
+ expect(result.extra_body?.thinking_budget).toBeUndefined()
+ })
+
+ it('should return empty object for unmatched Poe reasoning models', () => {
+ // A hypothetical reasoning model that doesn't match GPT-5, Claude, or Gemini
+ const model = createPoeModel('some-reasoning-model')
+ // Make it appear as a reasoning model by giving it a name that won't match known categories
+ const assistant = createAssistant('medium')
+ const result = getReasoningEffort(assistant, model)
+
+ // Should return empty object for unmatched models
+ expect(result).toEqual({})
+ })
+
+ it('should fallback to -1 for Gemini models without token limit', () => {
+ // Use a Gemini model variant that won't match any token limit pattern
+ // The current regex patterns cover gemini-.*-flash.*$ and gemini-.*-pro.*$
+ // so we need a model that matches isSupportedThinkingTokenGeminiModel but not THINKING_TOKEN_MAP
+ const model = createPoeModel('gemini-2.5-flash')
+ const assistant = createAssistant('auto')
+ const result = getReasoningEffort(assistant, model)
+
+ // For 'auto' effort, should use -1
+ expect(result.extra_body?.thinking_budget).toBe(-1)
+ })
+
+ it('should enforce minimum 1024 token floor for Claude models', () => {
+ const model = createPoeModel('claude-3.7-sonnet')
+ // Use very small maxTokens to test the minimum floor
+ const assistant = createAssistant('low', 100)
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result.extra_body?.thinking_budget).toBeGreaterThanOrEqual(1024)
+ })
+
+ it('should handle undefined maxTokens for Claude models', () => {
+ const model = createPoeModel('claude-3.7-sonnet')
+ const assistant = createAssistant('medium', undefined)
+ const result = getReasoningEffort(assistant, model)
+
+ expect(result).toHaveProperty('extra_body')
+ expect(result.extra_body).toHaveProperty('thinking_budget')
+ expect(typeof result.extra_body?.thinking_budget).toBe('number')
+ expect(result.extra_body?.thinking_budget).toBeGreaterThanOrEqual(1024)
+ })
+ })
+})
diff --git a/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts b/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts
new file mode 100644
index 0000000000..36253e5c1d
--- /dev/null
+++ b/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts
@@ -0,0 +1,992 @@
+/**
+ * reasoning.ts Unit Tests
+ * Tests for reasoning parameter generation utilities
+ */
+
+import { getStoreSetting } from '@renderer/hooks/useSettings'
+import type { SettingsState } from '@renderer/store/settings'
+import type { Assistant, Model, Provider } from '@renderer/types'
+import { SystemProviderIds } from '@renderer/types'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+import {
+ getAnthropicReasoningParams,
+ getBedrockReasoningParams,
+ getCustomParameters,
+ getGeminiReasoningParams,
+ getOpenAIReasoningParams,
+ getReasoningEffort,
+ getXAIReasoningParams
+} from '../reasoning'
+
+function defaultGetStoreSetting(key: K): SettingsState[K] {
+ if (key === 'openAI') {
+ return {
+ summaryText: 'auto',
+ verbosity: 'medium'
+ } as SettingsState[K]
+ }
+ return undefined as SettingsState[K]
+}
+
+// Mock dependencies
+vi.mock('@logger', () => ({
+ loggerService: {
+ withContext: () => ({
+ debug: vi.fn(),
+ error: vi.fn(),
+ warn: vi.fn(),
+ info: vi.fn()
+ })
+ }
+}))
+
+vi.mock('@renderer/store/settings', () => ({
+ default: (state = { settings: {} }) => state
+}))
+
+vi.mock('@renderer/store/llm', () => ({
+ initialState: {},
+ default: (state = { llm: {} }) => state
+}))
+
+vi.mock('@renderer/config/constant', () => ({
+ DEFAULT_MAX_TOKENS: 4096,
+ isMac: false,
+ isWin: false,
+ TOKENFLUX_HOST: 'mock-host'
+}))
+
+vi.mock('@renderer/utils/provider', () => ({
+ isSupportEnableThinkingProvider: vi.fn((provider) => {
+ return [SystemProviderIds.dashscope, SystemProviderIds.silicon].includes(provider.id)
+ })
+}))
+
+vi.mock('@renderer/config/models', async (importOriginal) => {
+ const actual: any = await importOriginal()
+ return {
+ ...actual,
+ isReasoningModel: vi.fn(() => false),
+ isOpenAIDeepResearchModel: vi.fn(() => false),
+ isOpenAIModel: vi.fn(() => false),
+ isSupportedReasoningEffortOpenAIModel: vi.fn(() => false),
+ isSupportedThinkingTokenQwenModel: vi.fn(() => false),
+ isQwenReasoningModel: vi.fn(() => false),
+ isSupportedThinkingTokenClaudeModel: vi.fn(() => false),
+ isSupportedThinkingTokenGeminiModel: vi.fn(() => false),
+ isSupportedThinkingTokenDoubaoModel: vi.fn(() => false),
+ isSupportedThinkingTokenZhipuModel: vi.fn(() => false),
+ isSupportedReasoningEffortModel: vi.fn(() => false),
+ isDeepSeekHybridInferenceModel: vi.fn(() => false),
+ isSupportedReasoningEffortGrokModel: vi.fn(() => false),
+ getThinkModelType: vi.fn(() => 'default'),
+ isDoubaoSeedAfter251015: vi.fn(() => false),
+ isDoubaoThinkingAutoModel: vi.fn(() => false),
+ isGrok4FastReasoningModel: vi.fn(() => false),
+ isGrokReasoningModel: vi.fn(() => false),
+ isOpenAIReasoningModel: vi.fn(() => false),
+ isQwenAlwaysThinkModel: vi.fn(() => false),
+ isSupportedThinkingTokenHunyuanModel: vi.fn(() => false),
+ isSupportedThinkingTokenModel: vi.fn(() => false),
+ isGPT51SeriesModel: vi.fn(() => false)
+ }
+})
+
+vi.mock('@renderer/hooks/useSettings', () => ({
+ getStoreSetting: vi.fn(defaultGetStoreSetting)
+}))
+
+vi.mock('@renderer/services/AssistantService', () => ({
+ getAssistantSettings: vi.fn((assistant) => ({
+ maxTokens: assistant?.settings?.maxTokens || 4096,
+ reasoning_effort: assistant?.settings?.reasoning_effort
+ })),
+ getProviderByModel: vi.fn((model) => ({
+ id: model.provider,
+ name: 'Test Provider'
+ })),
+ getDefaultAssistant: vi.fn(() => ({
+ id: 'default',
+ name: 'Default Assistant',
+ settings: {}
+ }))
+}))
+
+const ensureWindowApi = () => {
+ const globalWindow = window as any
+ globalWindow.api = globalWindow.api || {}
+ globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
+}
+
+ensureWindowApi()
+
+describe('reasoning utils', () => {
+ beforeEach(() => {
+ vi.resetAllMocks()
+ })
+
+ describe('getReasoningEffort', () => {
+ it('should return empty object for non-reasoning model', async () => {
+ const model: Model = {
+ id: 'gpt-4',
+ name: 'GPT-4',
+ provider: SystemProviderIds.openai
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getReasoningEffort(assistant, model)
+ expect(result).toEqual({})
+ })
+
+ it('should not override reasoning for OpenRouter when reasoning effort undefined', async () => {
+ const { isReasoningModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'anthropic/claude-sonnet-4',
+ name: 'Claude Sonnet 4',
+ provider: SystemProviderIds.openrouter
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getReasoningEffort(assistant, model)
+ expect(result).toEqual({})
+ })
+
+ it('should disable reasoning for OpenRouter when reasoning effort explicitly none', async () => {
+ const { isReasoningModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'anthropic/claude-sonnet-4',
+ name: 'Claude Sonnet 4',
+ provider: SystemProviderIds.openrouter
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'none'
+ }
+ } as Assistant
+
+ const result = getReasoningEffort(assistant, model)
+ expect(result).toEqual({ reasoning: { enabled: false, exclude: true } })
+ })
+
+ it('should handle Qwen models with enable_thinking', async () => {
+ const { isReasoningModel, isSupportedThinkingTokenQwenModel, isQwenReasoningModel } = await import(
+ '@renderer/config/models'
+ )
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(true)
+ vi.mocked(isQwenReasoningModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'qwen-plus',
+ name: 'Qwen Plus',
+ provider: SystemProviderIds.dashscope
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'medium'
+ }
+ } as Assistant
+
+ const result = getReasoningEffort(assistant, model)
+ expect(result).toHaveProperty('enable_thinking')
+ })
+
+ it('should handle Claude models with thinking config', async () => {
+ const {
+ isSupportedThinkingTokenClaudeModel,
+ isReasoningModel,
+ isQwenReasoningModel,
+ isSupportedThinkingTokenGeminiModel,
+ isSupportedThinkingTokenDoubaoModel,
+ isSupportedThinkingTokenZhipuModel,
+ isSupportedReasoningEffortModel
+ } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
+ vi.mocked(isQwenReasoningModel).mockReturnValue(false)
+ vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(false)
+ vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
+ vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
+ vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(false)
+
+ const model: Model = {
+ id: 'claude-3-7-sonnet',
+ name: 'Claude 3.7 Sonnet',
+ provider: SystemProviderIds.anthropic
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'high',
+ maxTokens: 4096
+ }
+ } as Assistant
+
+ const result = getReasoningEffort(assistant, model)
+ expect(result).toEqual({
+ thinking: {
+ type: 'enabled',
+ budget_tokens: expect.any(Number)
+ }
+ })
+ })
+
+ it('should handle Gemini Flash models with thinking budget 0', async () => {
+ const {
+ isSupportedThinkingTokenGeminiModel,
+ isReasoningModel,
+ isQwenReasoningModel,
+ isSupportedThinkingTokenClaudeModel,
+ isSupportedThinkingTokenDoubaoModel,
+ isSupportedThinkingTokenZhipuModel,
+ isOpenAIDeepResearchModel,
+ isSupportedThinkingTokenQwenModel,
+ isSupportedThinkingTokenHunyuanModel,
+ isDeepSeekHybridInferenceModel
+ } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
+ vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
+ vi.mocked(isQwenReasoningModel).mockReturnValue(false)
+ vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
+ vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
+ vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
+ vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(false)
+ vi.mocked(isSupportedThinkingTokenHunyuanModel).mockReturnValue(false)
+ vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(false)
+
+ const model: Model = {
+ id: 'gemini-2.5-flash',
+ name: 'Gemini 2.5 Flash',
+ provider: SystemProviderIds.openai
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'none'
+ }
+ } as Assistant
+
+ const result = getReasoningEffort(assistant, model)
+ expect(result).toEqual({
+ extra_body: {
+ google: {
+ thinking_config: {
+ thinking_budget: 0
+ }
+ }
+ }
+ })
+ })
+
+ it('should handle GPT-5.1 reasoning model with effort levels', async () => {
+ const {
+ isReasoningModel,
+ isOpenAIDeepResearchModel,
+ isSupportedReasoningEffortModel,
+ isGPT51SeriesModel,
+ getThinkModelType
+ } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
+ vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(true)
+ vi.mocked(getThinkModelType).mockReturnValue('gpt5_1')
+ vi.mocked(isGPT51SeriesModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'gpt-5.1',
+ name: 'GPT-5.1',
+ provider: SystemProviderIds.openai
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'none'
+ }
+ } as Assistant
+
+ const result = getReasoningEffort(assistant, model)
+ expect(result).toEqual({
+ reasoningEffort: 'none'
+ })
+ })
+
+ it('should handle DeepSeek hybrid inference models', async () => {
+ const { isReasoningModel, isDeepSeekHybridInferenceModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'deepseek-v3.1',
+ name: 'DeepSeek V3.1',
+ provider: SystemProviderIds.silicon
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'high'
+ }
+ } as Assistant
+
+ const result = getReasoningEffort(assistant, model)
+ expect(result).toEqual({
+ enable_thinking: true
+ })
+ })
+
+ it('should return medium effort for deep research models', async () => {
+ const { isReasoningModel, isOpenAIDeepResearchModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'o3-deep-research',
+ provider: SystemProviderIds.openai
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getReasoningEffort(assistant, model)
+ expect(result).toEqual({ reasoning_effort: 'medium' })
+ })
+
+ it('should return empty for groq provider', async () => {
+ const { getProviderByModel } = await import('@renderer/services/AssistantService')
+
+ vi.mocked(getProviderByModel).mockReturnValue({
+ id: 'groq',
+ name: 'Groq'
+ } as Provider)
+
+ const model: Model = {
+ id: 'groq-model',
+ name: 'Groq Model',
+ provider: 'groq'
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getReasoningEffort(assistant, model)
+ expect(result).toEqual({})
+ })
+ })
+
+ describe('getOpenAIReasoningParams', () => {
+ it('should return empty object for non-reasoning model', async () => {
+ const model: Model = {
+ id: 'gpt-4',
+ name: 'GPT-4',
+ provider: SystemProviderIds.openai
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getOpenAIReasoningParams(assistant, model)
+ expect(result).toEqual({})
+ })
+
+ it('should return empty when no reasoning effort set', async () => {
+ const model: Model = {
+ id: 'o1-preview',
+ name: 'O1 Preview',
+ provider: SystemProviderIds.openai
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getOpenAIReasoningParams(assistant, model)
+ expect(result).toEqual({})
+ })
+
+ it('should return reasoning effort for OpenAI models', async () => {
+ const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
+ '@renderer/config/models'
+ )
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isOpenAIModel).mockReturnValue(true)
+ vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'gpt-5.1',
+ name: 'GPT 5.1',
+ provider: SystemProviderIds.openai
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'high'
+ }
+ } as Assistant
+
+ const result = getOpenAIReasoningParams(assistant, model)
+ expect(result).toEqual({
+ reasoningEffort: 'high',
+ reasoningSummary: 'auto'
+ })
+ })
+
+ it('should include reasoning summary when not o1-pro', async () => {
+ const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
+ '@renderer/config/models'
+ )
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isOpenAIModel).mockReturnValue(true)
+ vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'gpt-5',
+ provider: SystemProviderIds.openai
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'medium'
+ }
+ } as Assistant
+
+ const result = getOpenAIReasoningParams(assistant, model)
+ expect(result).toEqual({
+ reasoningEffort: 'medium',
+ reasoningSummary: 'auto'
+ })
+ })
+
+ it('should not include reasoning summary for o1-pro', async () => {
+ const { isReasoningModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } = await import(
+ '@renderer/config/models'
+ )
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
+ vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
+ vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
+
+ const model: Model = {
+ id: 'o1-pro',
+ name: 'O1 Pro',
+ provider: SystemProviderIds.openai
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'high'
+ }
+ } as Assistant
+
+ const result = getOpenAIReasoningParams(assistant, model)
+ expect(result).toEqual({
+ reasoningEffort: 'high',
+ reasoningSummary: undefined
+ })
+ })
+
+ it('should force medium effort for deep research models', async () => {
+ const { isReasoningModel, isOpenAIModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } =
+ await import('@renderer/config/models')
+ const { getStoreSetting } = await import('@renderer/hooks/useSettings')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isOpenAIModel).mockReturnValue(true)
+ vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
+ vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
+ vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
+
+ const model: Model = {
+ id: 'o3-deep-research',
+ name: 'O3 Mini',
+ provider: SystemProviderIds.openai
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'high'
+ }
+ } as Assistant
+
+ const result = getOpenAIReasoningParams(assistant, model)
+ expect(result).toEqual({
+ reasoningEffort: 'medium',
+ reasoningSummary: 'off'
+ })
+ })
+ })
+
+ describe('getAnthropicReasoningParams', () => {
+ it('should return empty for non-reasoning model', async () => {
+ const { isReasoningModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(false)
+
+ const model: Model = {
+ id: 'claude-3-5-sonnet',
+ name: 'Claude 3.5 Sonnet',
+ provider: SystemProviderIds.anthropic
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getAnthropicReasoningParams(assistant, model)
+ expect(result).toEqual({})
+ })
+
+ it('should return disabled thinking when no reasoning effort', async () => {
+ const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
+
+ const model: Model = {
+ id: 'claude-3-7-sonnet',
+ name: 'Claude 3.7 Sonnet',
+ provider: SystemProviderIds.anthropic
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getAnthropicReasoningParams(assistant, model)
+ expect(result).toEqual({
+ thinking: {
+ type: 'disabled'
+ }
+ })
+ })
+
+ it('should return enabled thinking with budget for Claude models', async () => {
+ const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'claude-3-7-sonnet',
+ name: 'Claude 3.7 Sonnet',
+ provider: SystemProviderIds.anthropic
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'medium',
+ maxTokens: 4096
+ }
+ } as Assistant
+
+ const result = getAnthropicReasoningParams(assistant, model)
+ expect(result).toEqual({
+ thinking: {
+ type: 'enabled',
+ budgetTokens: 2048
+ }
+ })
+ })
+ })
+
+ describe('getGeminiReasoningParams', () => {
+ it('should return empty for non-reasoning model', async () => {
+ const { isReasoningModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(false)
+
+ const model: Model = {
+ id: 'gemini-2.0-flash',
+ name: 'Gemini 2.0 Flash',
+ provider: SystemProviderIds.gemini
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getGeminiReasoningParams(assistant, model)
+ expect(result).toEqual({})
+ })
+
+ it('should disable thinking for Flash models without reasoning effort', async () => {
+ const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'gemini-2.5-flash',
+ name: 'Gemini 2.5 Flash',
+ provider: SystemProviderIds.gemini
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getGeminiReasoningParams(assistant, model)
+ expect(result).toEqual({
+ thinkingConfig: {
+ includeThoughts: false,
+ thinkingBudget: 0
+ }
+ })
+ })
+
+ it('should enable thinking with budget for reasoning effort', async () => {
+ const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'gemini-2.5-pro',
+ name: 'Gemini 2.5 Pro',
+ provider: SystemProviderIds.gemini
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'medium'
+ }
+ } as Assistant
+
+ const result = getGeminiReasoningParams(assistant, model)
+ expect(result).toEqual({
+ thinkingConfig: {
+ thinkingBudget: 16448,
+ includeThoughts: true
+ }
+ })
+ })
+
+ it('should enable thinking without budget for auto effort ratio > 1', async () => {
+ const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'gemini-2.5-pro',
+ name: 'Gemini 2.5 Pro',
+ provider: SystemProviderIds.gemini
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'auto'
+ }
+ } as Assistant
+
+ const result = getGeminiReasoningParams(assistant, model)
+ expect(result).toEqual({
+ thinkingConfig: {
+ includeThoughts: true
+ }
+ })
+ })
+ })
+
+ describe('getXAIReasoningParams', () => {
+ it('should return empty for non-Grok model', async () => {
+ const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
+
+ vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(false)
+
+ const model: Model = {
+ id: 'other-model',
+ name: 'Other Model',
+ provider: SystemProviderIds.grok
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getXAIReasoningParams(assistant, model)
+ expect(result).toEqual({})
+ })
+
+ it('should return empty when no reasoning effort', async () => {
+ const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
+
+ vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'grok-2',
+ name: 'Grok 2',
+ provider: SystemProviderIds.grok
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getXAIReasoningParams(assistant, model)
+ expect(result).toEqual({})
+ })
+
+ it('should return reasoning effort for Grok models', async () => {
+ const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
+
+ vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'grok-3',
+ name: 'Grok 3',
+ provider: SystemProviderIds.grok
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'high'
+ }
+ } as Assistant
+
+ const result = getXAIReasoningParams(assistant, model)
+ expect(result).toHaveProperty('reasoningEffort')
+ expect(result.reasoningEffort).toBe('high')
+ })
+ })
+
+ describe('getBedrockReasoningParams', () => {
+ it('should return empty for non-reasoning model', async () => {
+ const model: Model = {
+ id: 'other-model',
+ name: 'Other Model',
+ provider: 'bedrock'
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getBedrockReasoningParams(assistant, model)
+ expect(result).toEqual({})
+ })
+
+ it('should return empty when no reasoning effort', async () => {
+ const model: Model = {
+ id: 'claude-3-7-sonnet',
+ name: 'Claude 3.7 Sonnet',
+ provider: 'bedrock'
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getBedrockReasoningParams(assistant, model)
+ expect(result).toEqual({})
+ })
+
+ it('should return reasoning config for Claude models on Bedrock', async () => {
+ const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
+
+ vi.mocked(isReasoningModel).mockReturnValue(true)
+ vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
+
+ const model: Model = {
+ id: 'claude-3-7-sonnet',
+ name: 'Claude 3.7 Sonnet',
+ provider: 'bedrock'
+ } as Model
+
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ reasoning_effort: 'medium',
+ maxTokens: 4096
+ }
+ } as Assistant
+
+ const result = getBedrockReasoningParams(assistant, model)
+ expect(result).toEqual({
+ reasoningConfig: {
+ type: 'enabled',
+ budgetTokens: 2048
+ }
+ })
+ })
+ })
+
+ describe('getCustomParameters', () => {
+ it('should return empty object when no custom parameters', async () => {
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {}
+ } as Assistant
+
+ const result = getCustomParameters(assistant)
+ expect(result).toEqual({})
+ })
+
+ it('should return custom parameters as key-value pairs', async () => {
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ customParameters: [
+ { name: 'param1', value: 'value1', type: 'string' },
+ { name: 'param2', value: 123, type: 'number' }
+ ]
+ }
+ } as Assistant
+
+ const result = getCustomParameters(assistant)
+ expect(result).toEqual({
+ param1: 'value1',
+ param2: 123
+ })
+ })
+
+ it('should parse JSON type parameters', async () => {
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ customParameters: [{ name: 'config', value: '{"key": "value"}', type: 'json' }]
+ }
+ } as Assistant
+
+ const result = getCustomParameters(assistant)
+ expect(result).toEqual({
+ config: { key: 'value' }
+ })
+ })
+
+ it('should handle invalid JSON gracefully', async () => {
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ customParameters: [{ name: 'invalid', value: '{invalid json', type: 'json' }]
+ }
+ } as Assistant
+
+ const result = getCustomParameters(assistant)
+ expect(result).toEqual({
+ invalid: '{invalid json'
+ })
+ })
+
+ it('should handle undefined JSON value', async () => {
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ customParameters: [{ name: 'undef', value: 'undefined', type: 'json' }]
+ }
+ } as Assistant
+
+ const result = getCustomParameters(assistant)
+ expect(result).toEqual({
+ undef: undefined
+ })
+ })
+
+ it('should skip parameters with empty names', async () => {
+ const assistant: Assistant = {
+ id: 'test',
+ name: 'Test',
+ settings: {
+ customParameters: [
+ { name: '', value: 'value1', type: 'string' },
+ { name: ' ', value: 'value2', type: 'string' },
+ { name: 'valid', value: 'value3', type: 'string' }
+ ]
+ }
+ } as Assistant
+
+ const result = getCustomParameters(assistant)
+ expect(result).toEqual({
+ valid: 'value3'
+ })
+ })
+ })
+})
diff --git a/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts b/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts
new file mode 100644
index 0000000000..fa5e3c3b36
--- /dev/null
+++ b/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts
@@ -0,0 +1,384 @@
+/**
+ * websearch.ts Unit Tests
+ * Tests for web search parameters generation utilities
+ */
+
+import type { CherryWebSearchConfig } from '@renderer/store/websearch'
+import type { Model } from '@renderer/types'
+import { describe, expect, it, vi } from 'vitest'
+
+import { buildProviderBuiltinWebSearchConfig, getWebSearchParams } from '../websearch'
+
+// Mock dependencies
+vi.mock('@renderer/config/models', () => ({
+ isOpenAIWebSearchChatCompletionOnlyModel: vi.fn((model) => model?.id?.includes('o1-pro') ?? false),
+ isOpenAIDeepResearchModel: vi.fn((model) => model?.id?.includes('o3-mini') ?? false)
+}))
+
+vi.mock('@renderer/utils/blacklistMatchPattern', () => ({
+ mapRegexToPatterns: vi.fn((patterns) => patterns || [])
+}))
+
+describe('websearch utils', () => {
+ describe('getWebSearchParams', () => {
+ it('should return enhancement params for hunyuan provider', () => {
+ const model: Model = {
+ id: 'hunyuan-model',
+ name: 'Hunyuan Model',
+ provider: 'hunyuan'
+ } as Model
+
+ const result = getWebSearchParams(model)
+
+ expect(result).toEqual({
+ enable_enhancement: true,
+ citation: true,
+ search_info: true
+ })
+ })
+
+ it('should return search params for dashscope provider', () => {
+ const model: Model = {
+ id: 'qwen-model',
+ name: 'Qwen Model',
+ provider: 'dashscope'
+ } as Model
+
+ const result = getWebSearchParams(model)
+
+ expect(result).toEqual({
+ enable_search: true,
+ search_options: {
+ forced_search: true
+ }
+ })
+ })
+
+ it('should return web_search_options for OpenAI web search models', () => {
+ const model: Model = {
+ id: 'o1-pro',
+ name: 'O1 Pro',
+ provider: 'openai'
+ } as Model
+
+ const result = getWebSearchParams(model)
+
+ expect(result).toEqual({
+ web_search_options: {}
+ })
+ })
+
+ it('should return empty object for other providers', () => {
+ const model: Model = {
+ id: 'gpt-4',
+ name: 'GPT-4',
+ provider: 'openai'
+ } as Model
+
+ const result = getWebSearchParams(model)
+
+ expect(result).toEqual({})
+ })
+
+ it('should return empty object for custom provider', () => {
+ const model: Model = {
+ id: 'custom-model',
+ name: 'Custom Model',
+ provider: 'custom-provider'
+ } as Model
+
+ const result = getWebSearchParams(model)
+
+ expect(result).toEqual({})
+ })
+ })
+
+ describe('buildProviderBuiltinWebSearchConfig', () => {
+ const defaultWebSearchConfig: CherryWebSearchConfig = {
+ searchWithTime: true,
+ maxResults: 50,
+ excludeDomains: []
+ }
+
+ describe('openai provider', () => {
+ it('should return low search context size for low maxResults', () => {
+ const config: CherryWebSearchConfig = {
+ searchWithTime: true,
+ maxResults: 20,
+ excludeDomains: []
+ }
+
+ const result = buildProviderBuiltinWebSearchConfig('openai', config)
+
+ expect(result).toEqual({
+ openai: {
+ searchContextSize: 'low'
+ }
+ })
+ })
+
+ it('should return medium search context size for medium maxResults', () => {
+ const config: CherryWebSearchConfig = {
+ searchWithTime: true,
+ maxResults: 50,
+ excludeDomains: []
+ }
+
+ const result = buildProviderBuiltinWebSearchConfig('openai', config)
+
+ expect(result).toEqual({
+ openai: {
+ searchContextSize: 'medium'
+ }
+ })
+ })
+
+ it('should return high search context size for high maxResults', () => {
+ const config: CherryWebSearchConfig = {
+ searchWithTime: true,
+ maxResults: 80,
+ excludeDomains: []
+ }
+
+ const result = buildProviderBuiltinWebSearchConfig('openai', config)
+
+ expect(result).toEqual({
+ openai: {
+ searchContextSize: 'high'
+ }
+ })
+ })
+
+ it('should use medium for deep research models regardless of maxResults', () => {
+ const config: CherryWebSearchConfig = {
+ searchWithTime: true,
+ maxResults: 100,
+ excludeDomains: []
+ }
+
+ const model: Model = {
+ id: 'o3-mini',
+ name: 'O3 Mini',
+ provider: 'openai'
+ } as Model
+
+ const result = buildProviderBuiltinWebSearchConfig('openai', config, model)
+
+ expect(result).toEqual({
+ openai: {
+ searchContextSize: 'medium'
+ }
+ })
+ })
+ })
+
+ describe('openai-chat provider', () => {
+ it('should return correct search context size', () => {
+ const config: CherryWebSearchConfig = {
+ searchWithTime: true,
+ maxResults: 50,
+ excludeDomains: []
+ }
+
+ const result = buildProviderBuiltinWebSearchConfig('openai-chat', config)
+
+ expect(result).toEqual({
+ 'openai-chat': {
+ searchContextSize: 'medium'
+ }
+ })
+ })
+
+ it('should handle deep research models', () => {
+ const config: CherryWebSearchConfig = {
+ searchWithTime: true,
+ maxResults: 100,
+ excludeDomains: []
+ }
+
+ const model: Model = {
+ id: 'o3-mini',
+ name: 'O3 Mini',
+ provider: 'openai'
+ } as Model
+
+ const result = buildProviderBuiltinWebSearchConfig('openai-chat', config, model)
+
+ expect(result).toEqual({
+ 'openai-chat': {
+ searchContextSize: 'medium'
+ }
+ })
+ })
+ })
+
+ describe('anthropic provider', () => {
+ it('should return anthropic search options with maxUses', () => {
+ const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
+
+ expect(result).toEqual({
+ anthropic: {
+ maxUses: 50,
+ blockedDomains: undefined
+ }
+ })
+ })
+
+ it('should include blockedDomains when excludeDomains provided', () => {
+ const config: CherryWebSearchConfig = {
+ searchWithTime: true,
+ maxResults: 30,
+ excludeDomains: ['example.com', 'test.com']
+ }
+
+ const result = buildProviderBuiltinWebSearchConfig('anthropic', config)
+
+ expect(result).toEqual({
+ anthropic: {
+ maxUses: 30,
+ blockedDomains: ['example.com', 'test.com']
+ }
+ })
+ })
+
+ it('should not include blockedDomains when empty', () => {
+ const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
+
+ expect(result).toEqual({
+ anthropic: {
+ maxUses: 50,
+ blockedDomains: undefined
+ }
+ })
+ })
+ })
+
+ describe('xai provider', () => {
+ it('should return xai search options', () => {
+ const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
+
+ expect(result).toEqual({
+ xai: {
+ maxSearchResults: 50,
+ returnCitations: true,
+ sources: [{ type: 'web', excludedWebsites: [] }, { type: 'news' }, { type: 'x' }],
+ mode: 'on'
+ }
+ })
+ })
+
+ it('should limit excluded websites to 5', () => {
+ const config: CherryWebSearchConfig = {
+ searchWithTime: true,
+ maxResults: 40,
+ excludeDomains: ['site1.com', 'site2.com', 'site3.com', 'site4.com', 'site5.com', 'site6.com', 'site7.com']
+ }
+
+ const result = buildProviderBuiltinWebSearchConfig('xai', config)
+
+ expect(result?.xai?.sources).toBeDefined()
+ const webSource = result?.xai?.sources?.[0]
+ if (webSource && webSource.type === 'web') {
+ expect(webSource.excludedWebsites).toHaveLength(5)
+ }
+ })
+
+ it('should include all sources types', () => {
+ const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
+
+ expect(result?.xai?.sources).toHaveLength(3)
+ expect(result?.xai?.sources?.[0].type).toBe('web')
+ expect(result?.xai?.sources?.[1].type).toBe('news')
+ expect(result?.xai?.sources?.[2].type).toBe('x')
+ })
+ })
+
+ describe('openrouter provider', () => {
+ it('should return openrouter plugins config', () => {
+ const result = buildProviderBuiltinWebSearchConfig('openrouter', defaultWebSearchConfig)
+
+ expect(result).toEqual({
+ openrouter: {
+ plugins: [
+ {
+ id: 'web',
+ max_results: 50
+ }
+ ]
+ }
+ })
+ })
+
+ it('should respect custom maxResults', () => {
+ const config: CherryWebSearchConfig = {
+ searchWithTime: true,
+ maxResults: 75,
+ excludeDomains: []
+ }
+
+ const result = buildProviderBuiltinWebSearchConfig('openrouter', config)
+
+ expect(result).toEqual({
+ openrouter: {
+ plugins: [
+ {
+ id: 'web',
+ max_results: 75
+ }
+ ]
+ }
+ })
+ })
+ })
+
+ describe('unsupported provider', () => {
+ it('should return empty object for unsupported provider', () => {
+ const result = buildProviderBuiltinWebSearchConfig('unsupported' as any, defaultWebSearchConfig)
+
+ expect(result).toEqual({})
+ })
+
+ it('should return empty object for google provider', () => {
+ const result = buildProviderBuiltinWebSearchConfig('google', defaultWebSearchConfig)
+
+ expect(result).toEqual({})
+ })
+ })
+
+ describe('edge cases', () => {
+ it('should handle maxResults at boundary values', () => {
+ // Test boundary at 33 (low/medium)
+ const config33: CherryWebSearchConfig = { searchWithTime: true, maxResults: 33, excludeDomains: [] }
+ const result33 = buildProviderBuiltinWebSearchConfig('openai', config33)
+ expect(result33?.openai?.searchContextSize).toBe('low')
+
+ // Test boundary at 34 (medium)
+ const config34: CherryWebSearchConfig = { searchWithTime: true, maxResults: 34, excludeDomains: [] }
+ const result34 = buildProviderBuiltinWebSearchConfig('openai', config34)
+ expect(result34?.openai?.searchContextSize).toBe('medium')
+
+ // Test boundary at 66 (medium)
+ const config66: CherryWebSearchConfig = { searchWithTime: true, maxResults: 66, excludeDomains: [] }
+ const result66 = buildProviderBuiltinWebSearchConfig('openai', config66)
+ expect(result66?.openai?.searchContextSize).toBe('medium')
+
+ // Test boundary at 67 (high)
+ const config67: CherryWebSearchConfig = { searchWithTime: true, maxResults: 67, excludeDomains: [] }
+ const result67 = buildProviderBuiltinWebSearchConfig('openai', config67)
+ expect(result67?.openai?.searchContextSize).toBe('high')
+ })
+
+ it('should handle zero maxResults', () => {
+ const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 0, excludeDomains: [] }
+ const result = buildProviderBuiltinWebSearchConfig('openai', config)
+ expect(result?.openai?.searchContextSize).toBe('low')
+ })
+
+ it('should handle very large maxResults', () => {
+ const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 1000, excludeDomains: [] }
+ const result = buildProviderBuiltinWebSearchConfig('openai', config)
+ expect(result?.openai?.searchContextSize).toBe('high')
+ })
+ })
+ })
+})
diff --git a/src/renderer/src/aiCore/utils/mcp.ts b/src/renderer/src/aiCore/utils/mcp.ts
index 84bc661aa0..7d3be9ac96 100644
--- a/src/renderer/src/aiCore/utils/mcp.ts
+++ b/src/renderer/src/aiCore/utils/mcp.ts
@@ -28,7 +28,9 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
const tools: ToolSet = {}
for (const mcpTool of mcpTools) {
- tools[mcpTool.name] = tool({
+ // Use mcpTool.id (which includes serverId suffix) to ensure uniqueness
+ // when multiple instances of the same MCP server type are configured
+ tools[mcpTool.id] = tool({
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
execute: async (params, { toolCallId }) => {
diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts
index 60d9b1e098..8ec46c9df2 100644
--- a/src/renderer/src/aiCore/utils/options.ts
+++ b/src/renderer/src/aiCore/utils/options.ts
@@ -1,18 +1,47 @@
+import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
+import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
+import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
+import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
+import type { XaiProviderOptions } from '@ai-sdk/xai'
import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider'
-import { isOpenAIModel, isQwenMTModel, isSupportFlexServiceTierModel } from '@renderer/config/models'
-import { isSupportServiceTierProvider } from '@renderer/config/providers'
-import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
-import type { Assistant, Model, Provider } from '@renderer/types'
+import { loggerService } from '@logger'
import {
+ getModelSupportedVerbosity,
+ isAnthropicModel,
+ isGeminiModel,
+ isGrokModel,
+ isOpenAIModel,
+ isQwenMTModel,
+ isSupportFlexServiceTierModel,
+ isSupportVerbosityModel
+} from '@renderer/config/models'
+import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
+import { getStoreSetting } from '@renderer/hooks/useSettings'
+import { getProviderById } from '@renderer/services/ProviderService'
+import {
+ type Assistant,
+ type GroqServiceTier,
GroqServiceTiers,
+ type GroqSystemProvider,
isGroqServiceTier,
+ isGroqSystemProvider,
isOpenAIServiceTier,
isTranslateAssistant,
+ type Model,
+ type NotGroqProvider,
+ type OpenAIServiceTier,
OpenAIServiceTiers,
+ type Provider,
+ type ServiceTier,
SystemProviderIds
} from '@renderer/types'
+import { type AiSdkParam, isAiSdkParam, type OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
+import { isSupportServiceTierProvider, isSupportVerbosityProvider } from '@renderer/utils/provider'
+import type { JSONValue } from 'ai'
import { t } from 'i18next'
+import type { OllamaCompletionProviderOptions } from 'ollama-ai-provider-v2'
+import { addAnthropicHeaders } from '../prepareParams/header'
import { getAiSdkProviderId } from '../provider/factory'
import { buildGeminiGenerateImageParams } from './image'
import {
@@ -26,8 +55,33 @@ import {
} from './reasoning'
import { getWebSearchParams } from './websearch'
-// copy from BaseApiClient.ts
-const getServiceTier = (model: Model, provider: Provider) => {
+const logger = loggerService.withContext('aiCore.utils.options')
+
+function toOpenAIServiceTier(model: Model, serviceTier: ServiceTier): OpenAIServiceTier {
+ if (
+ !isOpenAIServiceTier(serviceTier) ||
+ (serviceTier === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
+ ) {
+ return undefined
+ } else {
+ return serviceTier
+ }
+}
+
+function toGroqServiceTier(model: Model, serviceTier: ServiceTier): GroqServiceTier {
+ if (
+ !isGroqServiceTier(serviceTier) ||
+ (serviceTier === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
+ ) {
+ return undefined
+ } else {
+ return serviceTier
+ }
+}
+
+function getServiceTier(model: Model, provider: T): GroqServiceTier
+function getServiceTier(model: Model, provider: T): OpenAIServiceTier
+function getServiceTier(model: Model, provider: T): OpenAIServiceTier | GroqServiceTier {
const serviceTierSetting = provider.serviceTier
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
@@ -35,30 +89,64 @@ const getServiceTier = (model: Model, provider: Provider) => {
}
// 处理不同供应商需要 fallback 到默认值的情况
- if (provider.id === SystemProviderIds.groq) {
- if (
- !isGroqServiceTier(serviceTierSetting) ||
- (serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
- ) {
- return undefined
- }
+ if (isGroqSystemProvider(provider)) {
+ return toGroqServiceTier(model, serviceTierSetting)
} else {
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
- if (
- !isOpenAIServiceTier(serviceTierSetting) ||
- (serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
- ) {
- return undefined
+ return toOpenAIServiceTier(model, serviceTierSetting)
+ }
+}
+
+function getVerbosity(model: Model): OpenAIVerbosity {
+ if (!isSupportVerbosityModel(model) || !isSupportVerbosityProvider(getProviderById(model.provider)!)) {
+ return undefined
+ }
+ const openAI = getStoreSetting('openAI')
+
+ const userVerbosity = openAI.verbosity
+
+ if (userVerbosity) {
+ const supportedVerbosity = getModelSupportedVerbosity(model)
+ // Use user's verbosity if supported, otherwise use the first supported option
+ const verbosity = supportedVerbosity.includes(userVerbosity) ? userVerbosity : supportedVerbosity[0]
+ return verbosity
+ }
+ return undefined
+}
+
+/**
+ * Extract AI SDK standard parameters from custom parameters
+ * These parameters should be passed directly to streamText() instead of providerOptions
+ */
+export function extractAiSdkStandardParams(customParams: Record): {
+ standardParams: Partial>
+ providerParams: Record
+} {
+ const standardParams: Partial> = {}
+ const providerParams: Record = {}
+
+ for (const [key, value] of Object.entries(customParams)) {
+ if (isAiSdkParam(key)) {
+ standardParams[key] = value
+ } else {
+ providerParams[key] = value
}
}
- return serviceTierSetting
+ return { standardParams, providerParams }
}
/**
* 构建 AI SDK 的 providerOptions
* 按 provider 类型分离,保持类型安全
- * 返回格式:{ 'providerId': providerOptions }
+ * 返回格式:{
+ * providerOptions: { 'providerId': providerOptions },
+ * standardParams: { topK, frequencyPenalty, presencePenalty, stopSequences, seed }
+ * }
+ *
+ * Custom parameters are split into two categories:
+ * 1. AI SDK standard parameters (topK, frequencyPenalty, etc.) - returned separately to be passed to streamText()
+ * 2. Provider-specific parameters - merged into providerOptions
*/
export function buildProviderOptions(
assistant: Assistant,
@@ -69,12 +157,16 @@ export function buildProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
-): Record {
+): {
+ providerOptions: Record>
+ standardParams: Partial>
+} {
const rawProviderId = getAiSdkProviderId(actualProvider)
+ logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities, rawProviderId })
// 构建 provider 特定的选项
let providerSpecificOptions: Record = {}
- const serviceTierSetting = getServiceTier(model, actualProvider)
- providerSpecificOptions.serviceTier = serviceTierSetting
+ const serviceTier = getServiceTier(model, actualProvider)
+ const textVerbosity = getVerbosity(model)
// 根据 provider 类型分离构建逻辑
const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
if (success) {
@@ -84,14 +176,16 @@ export function buildProviderOptions(
case 'openai-chat':
case 'azure':
case 'azure-responses':
- providerSpecificOptions = {
- ...buildOpenAIProviderOptions(assistant, model, capabilities),
- serviceTier: serviceTierSetting
+ {
+ providerSpecificOptions = buildOpenAIProviderOptions(
+ assistant,
+ model,
+ capabilities,
+ serviceTier,
+ textVerbosity
+ )
}
break
- case 'huggingface':
- providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
- break
case 'anthropic':
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break
@@ -107,12 +201,26 @@ export function buildProviderOptions(
case 'openrouter':
case 'openai-compatible': {
// 对于其他 provider,使用通用的构建逻辑
+ const genericOptions = buildGenericProviderOptions(rawProviderId, assistant, model, capabilities)
providerSpecificOptions = {
- ...buildGenericProviderOptions(assistant, model, capabilities),
- serviceTier: serviceTierSetting
+ [rawProviderId]: {
+ ...genericOptions[rawProviderId],
+ serviceTier,
+ textVerbosity
+ }
}
break
}
+ case 'cherryin':
+ providerSpecificOptions = buildCherryInProviderOptions(
+ assistant,
+ model,
+ capabilities,
+ actualProvider,
+ serviceTier,
+ textVerbosity
+ )
+ break
default:
throw new Error(`Unsupported base provider ${baseProviderId}`)
}
@@ -125,39 +233,119 @@ export function buildProviderOptions(
case 'google-vertex':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break
+ case 'azure-anthropic':
case 'google-vertex-anthropic':
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break
case 'bedrock':
providerSpecificOptions = buildBedrockProviderOptions(assistant, model, capabilities)
break
+ case 'huggingface':
+ providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
+ break
+ case SystemProviderIds.ollama:
+ providerSpecificOptions = buildOllamaProviderOptions(assistant, capabilities)
+ break
+ case SystemProviderIds.gateway:
+ providerSpecificOptions = buildAIGatewayOptions(assistant, model, capabilities, serviceTier, textVerbosity)
+ break
default:
// 对于其他 provider,使用通用的构建逻辑
+ providerSpecificOptions = buildGenericProviderOptions(rawProviderId, assistant, model, capabilities)
+ // Merge serviceTier and textVerbosity
providerSpecificOptions = {
- ...buildGenericProviderOptions(assistant, model, capabilities),
- serviceTier: serviceTierSetting
+ ...providerSpecificOptions,
+ [rawProviderId]: {
+ ...providerSpecificOptions[rawProviderId],
+ serviceTier,
+ textVerbosity
+ }
}
}
} else {
throw error
}
}
+ logger.debug('Built providerSpecificOptions', { providerSpecificOptions })
+ /**
+ * Retrieve custom parameters and separate standard parameters from provider-specific parameters.
+ */
+ const customParams = getCustomParameters(assistant)
+ const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
+ logger.debug('Extracted standardParams and providerParams', { standardParams, providerParams })
- // 合并自定义参数到 provider 特定的选项中
- providerSpecificOptions = {
- ...providerSpecificOptions,
- ...getCustomParameters(assistant)
+ /**
+ * Get the actual AI SDK provider ID(s) from the already-built providerSpecificOptions.
+ * For proxy providers (cherryin, aihubmix, newapi), this will be the actual SDK provider (e.g., 'google', 'openai', 'anthropic')
+ * For regular providers, this will be the provider itself
+ */
+ const actualAiSdkProviderIds = Object.keys(providerSpecificOptions)
+ const primaryAiSdkProviderId = actualAiSdkProviderIds[0] // Use the first one as primary for non-scoped params
+
+ /**
+ * Merge custom parameters into providerSpecificOptions.
+ * Simple logic:
+ * 1. If key is in actualAiSdkProviderIds → merge directly (user knows the actual AI SDK provider ID)
+ * 2. If key == rawProviderId:
+ * - If it's gateway/ollama → preserve (they need their own config for routing/options)
+ * - Otherwise → map to primary (this is a proxy provider like cherryin)
+ * 3. Otherwise → treat as regular parameter, merge to primary provider
+ *
+ * Example:
+ * - User writes `cherryin: { opt: 'val' }` → mapped to `google: { opt: 'val' }` (case 2, proxy)
+ * - User writes `gateway: { order: [...] }` → stays as `gateway: { order: [...] }` (case 2, routing config)
+ * - User writes `google: { opt: 'val' }` → stays as `google: { opt: 'val' }` (case 1)
+ * - User writes `customKey: 'val'` → merged to `google: { customKey: 'val' }` (case 3)
+ */
+ for (const key of Object.keys(providerParams)) {
+ if (actualAiSdkProviderIds.includes(key)) {
+ // Case 1: Key is an actual AI SDK provider ID - merge directly
+ providerSpecificOptions = {
+ ...providerSpecificOptions,
+ [key]: {
+ ...providerSpecificOptions[key],
+ ...providerParams[key]
+ }
+ }
+ } else if (key === rawProviderId && !actualAiSdkProviderIds.includes(rawProviderId)) {
+ // Case 2: Key is the current provider (not in actualAiSdkProviderIds, so it's a proxy or special provider)
+ // Gateway is special: it needs routing config preserved
+ if (key === SystemProviderIds.gateway) {
+ // Preserve gateway config for routing
+ providerSpecificOptions = {
+ ...providerSpecificOptions,
+ [key]: {
+ ...providerSpecificOptions[key],
+ ...providerParams[key]
+ }
+ }
+ } else {
+ // Proxy provider (cherryin, etc.) - map to actual AI SDK provider
+ providerSpecificOptions = {
+ ...providerSpecificOptions,
+ [primaryAiSdkProviderId]: {
+ ...providerSpecificOptions[primaryAiSdkProviderId],
+ ...providerParams[key]
+ }
+ }
+ }
+ } else {
+ // Case 3: Regular parameter - merge to primary provider
+ providerSpecificOptions = {
+ ...providerSpecificOptions,
+ [primaryAiSdkProviderId]: {
+ ...providerSpecificOptions[primaryAiSdkProviderId],
+ [key]: providerParams[key]
+ }
+ }
+ }
}
- // vertex需要映射到google或anthropic
- const rawProviderKey =
- {
- 'google-vertex': 'google',
- 'google-vertex-anthropic': 'anthropic'
- }[rawProviderId] || rawProviderId
+ logger.debug('Final providerSpecificOptions after merging providerParams', { providerSpecificOptions })
- // 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
+ // 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数
return {
- [rawProviderKey]: providerSpecificOptions
+ providerOptions: providerSpecificOptions,
+ standardParams
}
}
@@ -171,10 +359,12 @@ function buildOpenAIProviderOptions(
enableReasoning: boolean
enableWebSearch: boolean
enableGenerateImage: boolean
- }
-): Record {
+ },
+ serviceTier: OpenAIServiceTier,
+ textVerbosity?: OpenAIVerbosity
+): Record {
const { enableReasoning } = capabilities
- let providerOptions: Record = {}
+ let providerOptions: OpenAIResponsesProviderOptions = {}
// OpenAI 推理参数
if (enableReasoning) {
const reasoningParams = getOpenAIReasoningParams(assistant, model)
@@ -183,7 +373,37 @@ function buildOpenAIProviderOptions(
...reasoningParams
}
}
- return providerOptions
+ const provider = getProviderById(model.provider)
+
+ if (!provider) {
+ throw new Error(`Provider ${model.provider} not found`)
+ }
+
+ if (isSupportVerbosityModel(model) && isSupportVerbosityProvider(provider)) {
+ const openAI = getStoreSetting<'openAI'>('openAI')
+ const userVerbosity = openAI?.verbosity
+
+ if (userVerbosity && ['low', 'medium', 'high'].includes(userVerbosity)) {
+ const supportedVerbosity = getModelSupportedVerbosity(model)
+ // Use user's verbosity if supported, otherwise use the first supported option
+ const verbosity = supportedVerbosity.includes(userVerbosity) ? userVerbosity : supportedVerbosity[0]
+
+ providerOptions = {
+ ...providerOptions,
+ textVerbosity: verbosity
+ }
+ }
+ }
+
+ providerOptions = {
+ ...providerOptions,
+ serviceTier,
+ textVerbosity
+ }
+
+ return {
+ openai: providerOptions
+ }
}
/**
@@ -197,9 +417,9 @@ function buildAnthropicProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
-): Record {
+): Record {
const { enableReasoning } = capabilities
- let providerOptions: Record = {}
+ let providerOptions: AnthropicProviderOptions = {}
// Anthropic 推理参数
if (enableReasoning) {
@@ -210,7 +430,11 @@ function buildAnthropicProviderOptions(
}
}
- return providerOptions
+ return {
+ anthropic: {
+ ...providerOptions
+ }
+ }
}
/**
@@ -224,9 +448,9 @@ function buildGeminiProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
-): Record {
+): Record {
const { enableReasoning, enableGenerateImage } = capabilities
- let providerOptions: Record