diff --git a/.github/workflows/auto-i18n.yml b/.github/workflows/auto-i18n.yml index ea9f05ae03..7537c4d4a3 100644 --- a/.github/workflows/auto-i18n.yml +++ b/.github/workflows/auto-i18n.yml @@ -23,7 +23,7 @@ jobs: steps: - name: 🐈‍⬛ Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -54,7 +54,7 @@ jobs: yarn install - name: 🏃‍♀️ Translate - run: yarn sync:i18n && yarn auto:i18n + run: yarn i18n:sync && yarn i18n:translate - name: 🔍 Format run: yarn format @@ -73,7 +73,7 @@ jobs: - name: 🚀 Create Pull Request if changes exist if: steps.git_status.outputs.has_changes == 'true' - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v8 with: token: ${{ secrets.GITHUB_TOKEN }} # Use the built-in GITHUB_TOKEN for bot actions commit-message: "feat(bot): Weekly automated script run" diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml index cc6d28817f..cc24438768 100644 --- a/.github/workflows/claude-code-review.yml +++ b/.github/workflows/claude-code-review.yml @@ -27,7 +27,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 1 diff --git a/.github/workflows/claude-translator.yml b/.github/workflows/claude-translator.yml index 23f359021d..71c2e0b87f 100644 --- a/.github/workflows/claude-translator.yml +++ b/.github/workflows/claude-translator.yml @@ -32,7 +32,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 1 diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 82c7b4393b..be018fb5bb 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -37,7 +37,7 @@ jobs: actions: read # Required for Claude to read CI results on PRs steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 1 diff --git a/.github/workflows/github-issue-tracker.yml b/.github/workflows/github-issue-tracker.yml index 32bd393145..a628f9f13c 100644 --- a/.github/workflows/github-issue-tracker.yml +++ b/.github/workflows/github-issue-tracker.yml @@ -19,7 +19,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Check Beijing Time id: check_time @@ -42,7 +42,7 @@ jobs: - name: Add pending label if in quiet hours if: steps.check_time.outputs.should_delay == 'true' - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: script: | github.rest.issues.addLabels({ @@ -118,7 +118,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Node.js uses: actions/setup-node@v6 diff --git a/.github/workflows/nightly-build.yml b/.github/workflows/nightly-build.yml index 523a670064..eb28b91c63 100644 --- a/.github/workflows/nightly-build.yml +++ b/.github/workflows/nightly-build.yml @@ -51,7 +51,7 @@ jobs: steps: - name: Check out Git repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: ref: main diff --git a/.github/workflows/pr-ci.yml b/.github/workflows/pr-ci.yml index aa273cc56e..1f7bf7d784 100644 --- a/.github/workflows/pr-ci.yml +++ b/.github/workflows/pr-ci.yml @@ -21,7 +21,7 @@ jobs: steps: - name: Check out Git repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Install Node.js uses: actions/setup-node@v6 @@ -58,7 +58,7 @@ jobs: run: yarn typecheck - name: i18n Check - run: yarn check:i18n + run: yarn i18n:check - name: Test run: yarn test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8bbb46ee67..4488b1b9d3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -25,7 +25,7 @@ jobs: steps: - name: Check out Git repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 diff --git a/.github/workflows/sync-to-gitcode.yml b/.github/workflows/sync-to-gitcode.yml new file mode 100644 index 0000000000..53ecae445b --- /dev/null +++ b/.github/workflows/sync-to-gitcode.yml @@ -0,0 +1,305 @@ +name: Sync Release to GitCode + +on: + release: + types: [published] + workflow_dispatch: + inputs: + tag: + description: 'Release tag (e.g. v1.0.0)' + required: true + clean: + description: 'Clean node_modules before build' + type: boolean + default: false + +permissions: + contents: read + +jobs: + build-and-sync-to-gitcode: + runs-on: [self-hosted, windows-signing] + steps: + - name: Get tag name + id: get-tag + shell: bash + run: | + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + echo "tag=${{ github.event.inputs.tag }}" >> $GITHUB_OUTPUT + else + echo "tag=${{ github.event.release.tag_name }}" >> $GITHUB_OUTPUT + fi + + - name: Check out Git repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + ref: ${{ steps.get-tag.outputs.tag }} + + - name: Set package.json version + shell: bash + run: | + TAG="${{ steps.get-tag.outputs.tag }}" + VERSION="${TAG#v}" + npm version "$VERSION" --no-git-tag-version --allow-same-version + + - name: Install Node.js + uses: actions/setup-node@v6 + with: + node-version: 22 + + - name: Install corepack + shell: bash + run: corepack enable && corepack prepare yarn@4.9.1 --activate + + - name: Clean node_modules + if: ${{ github.event.inputs.clean == 'true' }} + shell: bash + run: rm -rf node_modules + + - name: Install Dependencies + shell: bash + run: yarn install + + - name: Build Windows with code signing + shell: bash + run: yarn build:win + env: + WIN_SIGN: true + CHERRY_CERT_PATH: ${{ secrets.CHERRY_CERT_PATH }} + CHERRY_CERT_KEY: ${{ secrets.CHERRY_CERT_KEY }} + CHERRY_CERT_CSP: ${{ secrets.CHERRY_CERT_CSP }} + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + NODE_OPTIONS: --max-old-space-size=8192 + MAIN_VITE_CHERRYAI_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYAI_CLIENT_SECRET }} + MAIN_VITE_MINERU_API_KEY: ${{ secrets.MAIN_VITE_MINERU_API_KEY }} + RENDERER_VITE_AIHUBMIX_SECRET: ${{ secrets.RENDERER_VITE_AIHUBMIX_SECRET }} + RENDERER_VITE_PPIO_APP_SECRET: ${{ secrets.RENDERER_VITE_PPIO_APP_SECRET }} + + - name: List built Windows artifacts + shell: bash + run: | + echo "Built Windows artifacts:" + ls -la dist/*.exe dist/*.blockmap dist/latest*.yml + + - name: Download GitHub release assets + shell: bash + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TAG_NAME: ${{ steps.get-tag.outputs.tag }} + run: | + echo "Downloading release assets for $TAG_NAME..." + mkdir -p release-assets + cd release-assets + + # Download all assets from the release + gh release download "$TAG_NAME" \ + --repo "${{ github.repository }}" \ + --pattern "*" \ + --skip-existing + + echo "Downloaded GitHub release assets:" + ls -la + + - name: Replace Windows files with signed versions + shell: bash + run: | + echo "Replacing Windows files with signed versions..." + + # Verify signed files exist first + if ! ls dist/*.exe 1>/dev/null 2>&1; then + echo "ERROR: No signed .exe files found in dist/" + exit 1 + fi + + # Remove unsigned Windows files from downloaded assets + # *.exe, *.exe.blockmap, latest.yml (Windows only) + rm -f release-assets/*.exe release-assets/*.exe.blockmap release-assets/latest.yml 2>/dev/null || true + + # Copy signed Windows files with error checking + cp dist/*.exe release-assets/ || { echo "ERROR: Failed to copy .exe files"; exit 1; } + cp dist/*.exe.blockmap release-assets/ || { echo "ERROR: Failed to copy .blockmap files"; exit 1; } + cp dist/latest.yml release-assets/ || { echo "ERROR: Failed to copy latest.yml"; exit 1; } + + echo "Final release assets:" + ls -la release-assets/ + + - name: Get release info + id: release-info + shell: bash + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TAG_NAME: ${{ steps.get-tag.outputs.tag }} + LANG: C.UTF-8 + LC_ALL: C.UTF-8 + run: | + # Always use gh cli to avoid special character issues + RELEASE_NAME=$(gh release view "$TAG_NAME" --repo "${{ github.repository }}" --json name -q '.name') + # Use delimiter to safely handle special characters in release name + { + echo 'name<> $GITHUB_OUTPUT + # Extract releaseNotes from electron-builder.yml (from releaseNotes: | to end of file, remove 4-space indent) + sed -n '/releaseNotes: |/,$ { /releaseNotes: |/d; s/^ //; p }' electron-builder.yml > release_body.txt + + - name: Create GitCode release and upload files + shell: bash + env: + GITCODE_TOKEN: ${{ secrets.GITCODE_TOKEN }} + GITCODE_OWNER: ${{ vars.GITCODE_OWNER }} + GITCODE_REPO: ${{ vars.GITCODE_REPO }} + GITCODE_API_URL: ${{ vars.GITCODE_API_URL }} + TAG_NAME: ${{ steps.get-tag.outputs.tag }} + RELEASE_NAME: ${{ steps.release-info.outputs.name }} + LANG: C.UTF-8 + LC_ALL: C.UTF-8 + run: | + # Validate required environment variables + if [ -z "$GITCODE_TOKEN" ]; then + echo "ERROR: GITCODE_TOKEN is not set" + exit 1 + fi + if [ -z "$GITCODE_OWNER" ]; then + echo "ERROR: GITCODE_OWNER is not set" + exit 1 + fi + if [ -z "$GITCODE_REPO" ]; then + echo "ERROR: GITCODE_REPO is not set" + exit 1 + fi + + API_URL="${GITCODE_API_URL:-https://api.gitcode.com/api/v5}" + + echo "Creating GitCode release..." + echo "Tag: $TAG_NAME" + echo "Repo: $GITCODE_OWNER/$GITCODE_REPO" + + # Step 1: Create release + # Use --rawfile to read body directly from file, avoiding shell variable encoding issues + jq -n \ + --arg tag "$TAG_NAME" \ + --arg name "$RELEASE_NAME" \ + --rawfile body release_body.txt \ + '{ + tag_name: $tag, + name: $name, + body: $body, + target_commitish: "main" + }' > /tmp/release_payload.json + + RELEASE_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \ + --connect-timeout 30 --max-time 60 \ + "${API_URL}/repos/${GITCODE_OWNER}/${GITCODE_REPO}/releases" \ + -H "Content-Type: application/json; charset=utf-8" \ + -H "Authorization: Bearer ${GITCODE_TOKEN}" \ + --data-binary "@/tmp/release_payload.json") + + HTTP_CODE=$(echo "$RELEASE_RESPONSE" | tail -n1) + RESPONSE_BODY=$(echo "$RELEASE_RESPONSE" | sed '$d') + + if [ "$HTTP_CODE" -ge 200 ] && [ "$HTTP_CODE" -lt 300 ]; then + echo "Release created successfully" + else + echo "Warning: Release creation returned HTTP $HTTP_CODE" + echo "$RESPONSE_BODY" + exit 1 + fi + + # Step 2: Upload files to release + echo "Uploading files to GitCode release..." + + # Function to upload a single file with retry + upload_file() { + local file="$1" + local filename=$(basename "$file") + local max_retries=3 + local retry=0 + local curl_status=0 + + echo "Uploading: $filename" + + # URL encode the filename + encoded_filename=$(printf '%s' "$filename" | jq -sRr @uri) + + while [ $retry -lt $max_retries ]; do + # Get upload URL + curl_status=0 + UPLOAD_INFO=$(curl -s --connect-timeout 30 --max-time 60 \ + -H "Authorization: Bearer ${GITCODE_TOKEN}" \ + "${API_URL}/repos/${GITCODE_OWNER}/${GITCODE_REPO}/releases/${TAG_NAME}/upload_url?file_name=${encoded_filename}") || curl_status=$? + + if [ $curl_status -eq 0 ]; then + UPLOAD_URL=$(echo "$UPLOAD_INFO" | jq -r '.url // empty') + + if [ -n "$UPLOAD_URL" ]; then + # Write headers to temp file to avoid shell escaping issues + echo "$UPLOAD_INFO" | jq -r '.headers | to_entries[] | "header = \"" + .key + ": " + .value + "\""' > /tmp/upload_headers.txt + + # Upload file using PUT with headers from file + curl_status=0 + UPLOAD_RESPONSE=$(curl -s -w "\n%{http_code}" -X PUT \ + -K /tmp/upload_headers.txt \ + --data-binary "@${file}" \ + "$UPLOAD_URL") || curl_status=$? + + if [ $curl_status -eq 0 ]; then + HTTP_CODE=$(echo "$UPLOAD_RESPONSE" | tail -n1) + RESPONSE_BODY=$(echo "$UPLOAD_RESPONSE" | sed '$d') + + if [ "$HTTP_CODE" -ge 200 ] && [ "$HTTP_CODE" -lt 300 ]; then + echo " Uploaded: $filename" + return 0 + else + echo " Failed (HTTP $HTTP_CODE), retry $((retry + 1))/$max_retries" + echo " Response: $RESPONSE_BODY" + fi + else + echo " Upload request failed (curl exit $curl_status), retry $((retry + 1))/$max_retries" + fi + else + echo " Failed to get upload URL, retry $((retry + 1))/$max_retries" + echo " Response: $UPLOAD_INFO" + fi + else + echo " Failed to get upload URL (curl exit $curl_status), retry $((retry + 1))/$max_retries" + echo " Response: $UPLOAD_INFO" + fi + + retry=$((retry + 1)) + [ $retry -lt $max_retries ] && sleep 3 + done + + echo " Failed: $filename after $max_retries retries" + exit 1 + } + + # Upload non-yml/json files first + for file in release-assets/*; do + if [ -f "$file" ]; then + filename=$(basename "$file") + if [[ ! "$filename" =~ \.(yml|yaml|json)$ ]]; then + upload_file "$file" + fi + fi + done + + # Upload yml/json files last + for file in release-assets/*; do + if [ -f "$file" ]; then + filename=$(basename "$file") + if [[ "$filename" =~ \.(yml|yaml|json)$ ]]; then + upload_file "$file" + fi + fi + done + + echo "GitCode release sync completed!" + + - name: Cleanup temp files + if: always() + shell: bash + run: | + rm -f /tmp/release_payload.json /tmp/upload_headers.txt release_body.txt + rm -rf release-assets/ diff --git a/.github/workflows/update-app-upgrade-config.yml b/.github/workflows/update-app-upgrade-config.yml index 7470bb0b6c..8b0b198008 100644 --- a/.github/workflows/update-app-upgrade-config.yml +++ b/.github/workflows/update-app-upgrade-config.yml @@ -19,10 +19,9 @@ on: permissions: contents: write - pull-requests: write jobs: - propose-update: + update-config: runs-on: ubuntu-latest if: github.event_name == 'workflow_dispatch' || (github.event_name == 'release' && github.event.release.draft == false) @@ -135,7 +134,7 @@ jobs: - name: Checkout default branch if: steps.check.outputs.should_run == 'true' - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: ref: ${{ github.event.repository.default_branch }} path: main @@ -143,7 +142,7 @@ jobs: - name: Checkout x-files/app-upgrade-config branch if: steps.check.outputs.should_run == 'true' - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: ref: x-files/app-upgrade-config path: cs @@ -187,25 +186,20 @@ jobs: echo "changed=true" >> "$GITHUB_OUTPUT" fi - - name: Create pull request + - name: Commit and push changes if: steps.check.outputs.should_run == 'true' && steps.diff.outputs.changed == 'true' - uses: peter-evans/create-pull-request@v7 - with: - path: cs - base: x-files/app-upgrade-config - branch: chore/update-app-upgrade-config/${{ steps.meta.outputs.safe_tag }} - commit-message: "🤖 chore: sync app-upgrade-config for ${{ steps.meta.outputs.tag }}" - title: "chore: update app-upgrade-config for ${{ steps.meta.outputs.tag }}" - body: | - Automated update triggered by `${{ steps.meta.outputs.trigger }}`. + working-directory: cs + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git add app-upgrade-config.json + git commit -m "chore: sync app-upgrade-config for ${{ steps.meta.outputs.tag }}" -m "Automated update triggered by \`${{ steps.meta.outputs.trigger }}\`. - - Source tag: `${{ steps.meta.outputs.tag }}` - - Pre-release: `${{ steps.meta.outputs.prerelease }}` - - Latest: `${{ steps.meta.outputs.latest }}` - - Workflow run: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }} - labels: | - automation - app-upgrade + - Source tag: \`${{ steps.meta.outputs.tag }}\` + - Pre-release: \`${{ steps.meta.outputs.prerelease }}\` + - Latest: \`${{ steps.meta.outputs.latest }}\` + - Workflow run: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" + git push origin x-files/app-upgrade-config - name: No changes detected if: steps.check.outputs.should_run == 'true' && steps.diff.outputs.changed != 'true' diff --git a/.yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch b/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch similarity index 64% rename from .yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch rename to .yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch index 3015e702ed..67403a6575 100644 --- a/.yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch +++ b/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch @@ -1,5 +1,5 @@ diff --git a/dist/index.js b/dist/index.js -index 51ce7e423934fb717cb90245cdfcdb3dae6780e6..0f7f7009e2f41a79a8669d38c8a44867bbff5e1f 100644 +index d004b415c5841a1969705823614f395265ea5a8a..6b1e0dad4610b0424393ecc12e9114723bbe316b 100644 --- a/dist/index.js +++ b/dist/index.js @@ -474,7 +474,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) { @@ -12,7 +12,7 @@ index 51ce7e423934fb717cb90245cdfcdb3dae6780e6..0f7f7009e2f41a79a8669d38c8a44867 // src/google-generative-ai-options.ts diff --git a/dist/index.mjs b/dist/index.mjs -index f4b77e35c0cbfece85a3ef0d4f4e67aa6dde6271..8d2fecf8155a226006a0bde72b00b6036d4014b6 100644 +index 1780dd2391b7f42224a0b8048c723d2f81222c44..1f12ed14399d6902107ce9b435d7d8e6cc61e06b 100644 --- a/dist/index.mjs +++ b/dist/index.mjs @@ -480,7 +480,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) { @@ -24,3 +24,14 @@ index f4b77e35c0cbfece85a3ef0d4f4e67aa6dde6271..8d2fecf8155a226006a0bde72b00b603 } // src/google-generative-ai-options.ts +@@ -1909,8 +1909,7 @@ function createGoogleGenerativeAI(options = {}) { + } + var google = createGoogleGenerativeAI(); + export { +- VERSION, + createGoogleGenerativeAI, +- google ++ google, VERSION + }; + //# sourceMappingURL=index.mjs.map +\ No newline at end of file diff --git a/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch b/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch deleted file mode 100644 index 2a13c33a78..0000000000 --- a/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch +++ /dev/null @@ -1,140 +0,0 @@ -diff --git a/dist/index.js b/dist/index.js -index 73045a7d38faafdc7f7d2cd79d7ff0e2b031056b..8d948c9ac4ea4b474db9ef3c5491961e7fcf9a07 100644 ---- a/dist/index.js -+++ b/dist/index.js -@@ -421,6 +421,17 @@ var OpenAICompatibleChatLanguageModel = class { - text: reasoning - }); - } -+ if (choice.message.images) { -+ for (const image of choice.message.images) { -+ const match1 = image.image_url.url.match(/^data:([^;]+)/) -+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); -+ content.push({ -+ type: 'file', -+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', -+ data: match2 ? match2[1] : image.image_url.url, -+ }); -+ } -+ } - if (choice.message.tool_calls != null) { - for (const toolCall of choice.message.tool_calls) { - content.push({ -@@ -598,6 +609,17 @@ var OpenAICompatibleChatLanguageModel = class { - delta: delta.content - }); - } -+ if (delta.images) { -+ for (const image of delta.images) { -+ const match1 = image.image_url.url.match(/^data:([^;]+)/) -+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); -+ controller.enqueue({ -+ type: 'file', -+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', -+ data: match2 ? match2[1] : image.image_url.url, -+ }); -+ } -+ } - if (delta.tool_calls != null) { - for (const toolCallDelta of delta.tool_calls) { - const index = toolCallDelta.index; -@@ -765,6 +787,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({ - arguments: import_v43.z.string() - }) - }) -+ ).nullish(), -+ images: import_v43.z.array( -+ import_v43.z.object({ -+ type: import_v43.z.literal('image_url'), -+ image_url: import_v43.z.object({ -+ url: import_v43.z.string(), -+ }) -+ }) - ).nullish() - }), - finish_reason: import_v43.z.string().nullish() -@@ -795,6 +825,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union( - arguments: import_v43.z.string().nullish() - }) - }) -+ ).nullish(), -+ images: import_v43.z.array( -+ import_v43.z.object({ -+ type: import_v43.z.literal('image_url'), -+ image_url: import_v43.z.object({ -+ url: import_v43.z.string(), -+ }) -+ }) - ).nullish() - }).nullish(), - finish_reason: import_v43.z.string().nullish() -diff --git a/dist/index.mjs b/dist/index.mjs -index 1c2b9560bbfbfe10cb01af080aeeed4ff59db29c..2c8ddc4fc9bfc5e7e06cfca105d197a08864c427 100644 ---- a/dist/index.mjs -+++ b/dist/index.mjs -@@ -405,6 +405,17 @@ var OpenAICompatibleChatLanguageModel = class { - text: reasoning - }); - } -+ if (choice.message.images) { -+ for (const image of choice.message.images) { -+ const match1 = image.image_url.url.match(/^data:([^;]+)/) -+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); -+ content.push({ -+ type: 'file', -+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', -+ data: match2 ? match2[1] : image.image_url.url, -+ }); -+ } -+ } - if (choice.message.tool_calls != null) { - for (const toolCall of choice.message.tool_calls) { - content.push({ -@@ -582,6 +593,17 @@ var OpenAICompatibleChatLanguageModel = class { - delta: delta.content - }); - } -+ if (delta.images) { -+ for (const image of delta.images) { -+ const match1 = image.image_url.url.match(/^data:([^;]+)/) -+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); -+ controller.enqueue({ -+ type: 'file', -+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', -+ data: match2 ? match2[1] : image.image_url.url, -+ }); -+ } -+ } - if (delta.tool_calls != null) { - for (const toolCallDelta of delta.tool_calls) { - const index = toolCallDelta.index; -@@ -749,6 +771,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({ - arguments: z3.string() - }) - }) -+ ).nullish(), -+ images: z3.array( -+ z3.object({ -+ type: z3.literal('image_url'), -+ image_url: z3.object({ -+ url: z3.string(), -+ }) -+ }) - ).nullish() - }), - finish_reason: z3.string().nullish() -@@ -779,6 +809,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([ - arguments: z3.string().nullish() - }) - }) -+ ).nullish(), -+ images: z3.array( -+ z3.object({ -+ type: z3.literal('image_url'), -+ image_url: z3.object({ -+ url: z3.string(), -+ }) -+ }) - ).nullish() - }).nullish(), - finish_reason: z3.string().nullish() diff --git a/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch b/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch new file mode 100644 index 0000000000..3178ffee5e --- /dev/null +++ b/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch @@ -0,0 +1,266 @@ +diff --git a/dist/index.d.ts b/dist/index.d.ts +index 48e2f6263c6ee4c75d7e5c28733e64f6ebe92200..00d0729c4a3cbf9a48e8e1e962c7e2b256b75eba 100644 +--- a/dist/index.d.ts ++++ b/dist/index.d.ts +@@ -7,6 +7,7 @@ declare const openaiCompatibleProviderOptions: z.ZodObject<{ + user: z.ZodOptional; + reasoningEffort: z.ZodOptional; + textVerbosity: z.ZodOptional; ++ sendReasoning: z.ZodOptional; + }, z.core.$strip>; + type OpenAICompatibleProviderOptions = z.infer; + +diff --git a/dist/index.js b/dist/index.js +index da237bb35b7fa8e24b37cd861ee73dfc51cdfc72..b3060fbaf010e30b64df55302807828e5bfe0f9a 100644 +--- a/dist/index.js ++++ b/dist/index.js +@@ -41,7 +41,7 @@ function getOpenAIMetadata(message) { + var _a, _b; + return (_b = (_a = message == null ? void 0 : message.providerOptions) == null ? void 0 : _a.openaiCompatible) != null ? _b : {}; + } +-function convertToOpenAICompatibleChatMessages(prompt) { ++function convertToOpenAICompatibleChatMessages({prompt, options}) { + const messages = []; + for (const { role, content, ...message } of prompt) { + const metadata = getOpenAIMetadata({ ...message }); +@@ -91,6 +91,7 @@ function convertToOpenAICompatibleChatMessages(prompt) { + } + case "assistant": { + let text = ""; ++ let reasoning_text = ""; + const toolCalls = []; + for (const part of content) { + const partMetadata = getOpenAIMetadata(part); +@@ -99,6 +100,12 @@ function convertToOpenAICompatibleChatMessages(prompt) { + text += part.text; + break; + } ++ case "reasoning": { ++ if (options.sendReasoning) { ++ reasoning_text += part.text; ++ } ++ break; ++ } + case "tool-call": { + toolCalls.push({ + id: part.toolCallId, +@@ -116,6 +123,7 @@ function convertToOpenAICompatibleChatMessages(prompt) { + messages.push({ + role: "assistant", + content: text, ++ reasoning_content: reasoning_text || undefined, + tool_calls: toolCalls.length > 0 ? toolCalls : void 0, + ...metadata + }); +@@ -200,7 +208,8 @@ var openaiCompatibleProviderOptions = import_v4.z.object({ + /** + * Controls the verbosity of the generated text. Defaults to `medium`. + */ +- textVerbosity: import_v4.z.string().optional() ++ textVerbosity: import_v4.z.string().optional(), ++ sendReasoning: import_v4.z.boolean().optional() + }); + + // src/openai-compatible-error.ts +@@ -378,7 +387,7 @@ var OpenAICompatibleChatLanguageModel = class { + reasoning_effort: compatibleOptions.reasoningEffort, + verbosity: compatibleOptions.textVerbosity, + // messages: +- messages: convertToOpenAICompatibleChatMessages(prompt), ++ messages: convertToOpenAICompatibleChatMessages({prompt, options: compatibleOptions}), + // tools: + tools: openaiTools, + tool_choice: openaiToolChoice +@@ -421,6 +430,17 @@ var OpenAICompatibleChatLanguageModel = class { + text: reasoning + }); + } ++ if (choice.message.images) { ++ for (const image of choice.message.images) { ++ const match1 = image.image_url.url.match(/^data:([^;]+)/) ++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); ++ content.push({ ++ type: 'file', ++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', ++ data: match2 ? match2[1] : image.image_url.url, ++ }); ++ } ++ } + if (choice.message.tool_calls != null) { + for (const toolCall of choice.message.tool_calls) { + content.push({ +@@ -598,6 +618,17 @@ var OpenAICompatibleChatLanguageModel = class { + delta: delta.content + }); + } ++ if (delta.images) { ++ for (const image of delta.images) { ++ const match1 = image.image_url.url.match(/^data:([^;]+)/) ++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); ++ controller.enqueue({ ++ type: 'file', ++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', ++ data: match2 ? match2[1] : image.image_url.url, ++ }); ++ } ++ } + if (delta.tool_calls != null) { + for (const toolCallDelta of delta.tool_calls) { + const index = toolCallDelta.index; +@@ -765,6 +796,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({ + arguments: import_v43.z.string() + }) + }) ++ ).nullish(), ++ images: import_v43.z.array( ++ import_v43.z.object({ ++ type: import_v43.z.literal('image_url'), ++ image_url: import_v43.z.object({ ++ url: import_v43.z.string(), ++ }) ++ }) + ).nullish() + }), + finish_reason: import_v43.z.string().nullish() +@@ -795,6 +834,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union( + arguments: import_v43.z.string().nullish() + }) + }) ++ ).nullish(), ++ images: import_v43.z.array( ++ import_v43.z.object({ ++ type: import_v43.z.literal('image_url'), ++ image_url: import_v43.z.object({ ++ url: import_v43.z.string(), ++ }) ++ }) + ).nullish() + }).nullish(), + finish_reason: import_v43.z.string().nullish() +diff --git a/dist/index.mjs b/dist/index.mjs +index a809a7aa0e148bfd43e01dd7b018568b151c8ad5..565b605eeacd9830b2b0e817e58ad0c5700264de 100644 +--- a/dist/index.mjs ++++ b/dist/index.mjs +@@ -23,7 +23,7 @@ function getOpenAIMetadata(message) { + var _a, _b; + return (_b = (_a = message == null ? void 0 : message.providerOptions) == null ? void 0 : _a.openaiCompatible) != null ? _b : {}; + } +-function convertToOpenAICompatibleChatMessages(prompt) { ++function convertToOpenAICompatibleChatMessages({prompt, options}) { + const messages = []; + for (const { role, content, ...message } of prompt) { + const metadata = getOpenAIMetadata({ ...message }); +@@ -73,6 +73,7 @@ function convertToOpenAICompatibleChatMessages(prompt) { + } + case "assistant": { + let text = ""; ++ let reasoning_text = ""; + const toolCalls = []; + for (const part of content) { + const partMetadata = getOpenAIMetadata(part); +@@ -81,6 +82,12 @@ function convertToOpenAICompatibleChatMessages(prompt) { + text += part.text; + break; + } ++ case "reasoning": { ++ if (options.sendReasoning) { ++ reasoning_text += part.text; ++ } ++ break; ++ } + case "tool-call": { + toolCalls.push({ + id: part.toolCallId, +@@ -98,6 +105,7 @@ function convertToOpenAICompatibleChatMessages(prompt) { + messages.push({ + role: "assistant", + content: text, ++ reasoning_content: reasoning_text || undefined, + tool_calls: toolCalls.length > 0 ? toolCalls : void 0, + ...metadata + }); +@@ -182,7 +190,8 @@ var openaiCompatibleProviderOptions = z.object({ + /** + * Controls the verbosity of the generated text. Defaults to `medium`. + */ +- textVerbosity: z.string().optional() ++ textVerbosity: z.string().optional(), ++ sendReasoning: z.boolean().optional() + }); + + // src/openai-compatible-error.ts +@@ -362,7 +371,7 @@ var OpenAICompatibleChatLanguageModel = class { + reasoning_effort: compatibleOptions.reasoningEffort, + verbosity: compatibleOptions.textVerbosity, + // messages: +- messages: convertToOpenAICompatibleChatMessages(prompt), ++ messages: convertToOpenAICompatibleChatMessages({prompt, options: compatibleOptions}), + // tools: + tools: openaiTools, + tool_choice: openaiToolChoice +@@ -405,6 +414,17 @@ var OpenAICompatibleChatLanguageModel = class { + text: reasoning + }); + } ++ if (choice.message.images) { ++ for (const image of choice.message.images) { ++ const match1 = image.image_url.url.match(/^data:([^;]+)/) ++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); ++ content.push({ ++ type: 'file', ++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', ++ data: match2 ? match2[1] : image.image_url.url, ++ }); ++ } ++ } + if (choice.message.tool_calls != null) { + for (const toolCall of choice.message.tool_calls) { + content.push({ +@@ -582,6 +602,17 @@ var OpenAICompatibleChatLanguageModel = class { + delta: delta.content + }); + } ++ if (delta.images) { ++ for (const image of delta.images) { ++ const match1 = image.image_url.url.match(/^data:([^;]+)/) ++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); ++ controller.enqueue({ ++ type: 'file', ++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', ++ data: match2 ? match2[1] : image.image_url.url, ++ }); ++ } ++ } + if (delta.tool_calls != null) { + for (const toolCallDelta of delta.tool_calls) { + const index = toolCallDelta.index; +@@ -749,6 +780,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({ + arguments: z3.string() + }) + }) ++ ).nullish(), ++ images: z3.array( ++ z3.object({ ++ type: z3.literal('image_url'), ++ image_url: z3.object({ ++ url: z3.string(), ++ }) ++ }) + ).nullish() + }), + finish_reason: z3.string().nullish() +@@ -779,6 +818,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([ + arguments: z3.string().nullish() + }) + }) ++ ).nullish(), ++ images: z3.array( ++ z3.object({ ++ type: z3.literal('image_url'), ++ image_url: z3.object({ ++ url: z3.string(), ++ }) ++ }) + ).nullish() + }).nullish(), + finish_reason: z3.string().nullish() diff --git a/.yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch b/.yarn/patches/@ai-sdk-openai-npm-2.0.85-27483d1d6a.patch similarity index 84% rename from .yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch rename to .yarn/patches/@ai-sdk-openai-npm-2.0.85-27483d1d6a.patch index 973ddc62ac..6fbe30e080 100644 --- a/.yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch +++ b/.yarn/patches/@ai-sdk-openai-npm-2.0.85-27483d1d6a.patch @@ -1,8 +1,8 @@ diff --git a/dist/index.js b/dist/index.js -index bf900591bf2847a3253fe441aad24c06da19c6c1..c1d9bb6fefa2df1383339324073db0a70ea2b5a2 100644 +index 130094d194ea1e8e7d3027d07d82465741192124..4d13dcee8c962ca9ee8f1c3d748f8ffe6a3cfb47 100644 --- a/dist/index.js +++ b/dist/index.js -@@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)( +@@ -290,6 +290,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)( message: import_v42.z.object({ role: import_v42.z.literal("assistant").nullish(), content: import_v42.z.string().nullish(), @@ -10,7 +10,7 @@ index bf900591bf2847a3253fe441aad24c06da19c6c1..c1d9bb6fefa2df1383339324073db0a7 tool_calls: import_v42.z.array( import_v42.z.object({ id: import_v42.z.string().nullish(), -@@ -340,6 +341,7 @@ var openaiChatChunkSchema = (0, import_provider_utils3.lazyValidator)( +@@ -356,6 +357,7 @@ var openaiChatChunkSchema = (0, import_provider_utils3.lazyValidator)( delta: import_v42.z.object({ role: import_v42.z.enum(["assistant"]).nullish(), content: import_v42.z.string().nullish(), @@ -18,7 +18,7 @@ index bf900591bf2847a3253fe441aad24c06da19c6c1..c1d9bb6fefa2df1383339324073db0a7 tool_calls: import_v42.z.array( import_v42.z.object({ index: import_v42.z.number(), -@@ -795,6 +797,13 @@ var OpenAIChatLanguageModel = class { +@@ -814,6 +816,13 @@ var OpenAIChatLanguageModel = class { if (text != null && text.length > 0) { content.push({ type: "text", text }); } @@ -32,7 +32,7 @@ index bf900591bf2847a3253fe441aad24c06da19c6c1..c1d9bb6fefa2df1383339324073db0a7 for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) { content.push({ type: "tool-call", -@@ -876,6 +885,7 @@ var OpenAIChatLanguageModel = class { +@@ -895,6 +904,7 @@ var OpenAIChatLanguageModel = class { }; let metadataExtracted = false; let isActiveText = false; @@ -40,7 +40,7 @@ index bf900591bf2847a3253fe441aad24c06da19c6c1..c1d9bb6fefa2df1383339324073db0a7 const providerMetadata = { openai: {} }; return { stream: response.pipeThrough( -@@ -933,6 +943,21 @@ var OpenAIChatLanguageModel = class { +@@ -952,6 +962,21 @@ var OpenAIChatLanguageModel = class { return; } const delta = choice.delta; @@ -62,7 +62,7 @@ index bf900591bf2847a3253fe441aad24c06da19c6c1..c1d9bb6fefa2df1383339324073db0a7 if (delta.content != null) { if (!isActiveText) { controller.enqueue({ type: "text-start", id: "0" }); -@@ -1045,6 +1070,9 @@ var OpenAIChatLanguageModel = class { +@@ -1064,6 +1089,9 @@ var OpenAIChatLanguageModel = class { } }, flush(controller) { diff --git a/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.53-4b77f4cf29.patch b/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.62-23ae56f8c8.patch similarity index 92% rename from .yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.53-4b77f4cf29.patch rename to .yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.62-23ae56f8c8.patch index 4481b58f32..62ab767576 100644 --- a/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.53-4b77f4cf29.patch +++ b/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.62-23ae56f8c8.patch @@ -1,5 +1,5 @@ diff --git a/sdk.mjs b/sdk.mjs -index bf429a344b7d59f70aead16b639f949b07688a81..f77d50cc5d3fb04292cb3ac7fa7085d02dcc628f 100755 +index dea7766a3432a1e809f12d6daba4f2834a219689..e0b02ef73da177ba32b903887d7bbbeaa08cc6d3 100755 --- a/sdk.mjs +++ b/sdk.mjs @@ -6250,7 +6250,7 @@ function createAbortController(maxListeners = DEFAULT_MAX_LISTENERS) { @@ -11,7 +11,7 @@ index bf429a344b7d59f70aead16b639f949b07688a81..f77d50cc5d3fb04292cb3ac7fa7085d0 import { createInterface } from "readline"; // ../src/utils/fsOperations.ts -@@ -6619,18 +6619,11 @@ class ProcessTransport { +@@ -6644,18 +6644,11 @@ class ProcessTransport { const errorMessage = isNativeBinary(pathToClaudeCodeExecutable) ? `Claude Code native binary not found at ${pathToClaudeCodeExecutable}. Please ensure Claude Code is installed via native installer or specify a valid path with options.pathToClaudeCodeExecutable.` : `Claude Code executable not found at ${pathToClaudeCodeExecutable}. Is options.pathToClaudeCodeExecutable set?`; throw new ReferenceError(errorMessage); } diff --git a/.yarn/patches/ollama-ai-provider-v2-npm-1.5.5-8bef249af9.patch b/.yarn/patches/ollama-ai-provider-v2-npm-1.5.5-8bef249af9.patch new file mode 100644 index 0000000000..c306bef6e5 --- /dev/null +++ b/.yarn/patches/ollama-ai-provider-v2-npm-1.5.5-8bef249af9.patch @@ -0,0 +1,145 @@ +diff --git a/dist/index.d.ts b/dist/index.d.ts +index 8dd9b498050dbecd8dd6b901acf1aa8ca38a49af..ed644349c9d38fe2a66b2fb44214f7c18eb97f89 100644 +--- a/dist/index.d.ts ++++ b/dist/index.d.ts +@@ -4,7 +4,7 @@ import { z } from 'zod/v4'; + + type OllamaChatModelId = "athene-v2" | "athene-v2:72b" | "aya-expanse" | "aya-expanse:8b" | "aya-expanse:32b" | "codegemma" | "codegemma:2b" | "codegemma:7b" | "codellama" | "codellama:7b" | "codellama:13b" | "codellama:34b" | "codellama:70b" | "codellama:code" | "codellama:python" | "command-r" | "command-r:35b" | "command-r-plus" | "command-r-plus:104b" | "command-r7b" | "command-r7b:7b" | "deepseek-r1" | "deepseek-r1:1.5b" | "deepseek-r1:7b" | "deepseek-r1:8b" | "deepseek-r1:14b" | "deepseek-r1:32b" | "deepseek-r1:70b" | "deepseek-r1:671b" | "deepseek-coder-v2" | "deepseek-coder-v2:16b" | "deepseek-coder-v2:236b" | "deepseek-v3" | "deepseek-v3:671b" | "devstral" | "devstral:24b" | "dolphin3" | "dolphin3:8b" | "exaone3.5" | "exaone3.5:2.4b" | "exaone3.5:7.8b" | "exaone3.5:32b" | "falcon2" | "falcon2:11b" | "falcon3" | "falcon3:1b" | "falcon3:3b" | "falcon3:7b" | "falcon3:10b" | "firefunction-v2" | "firefunction-v2:70b" | "gemma" | "gemma:2b" | "gemma:7b" | "gemma2" | "gemma2:2b" | "gemma2:9b" | "gemma2:27b" | "gemma3" | "gemma3:1b" | "gemma3:4b" | "gemma3:12b" | "gemma3:27b" | "granite3-dense" | "granite3-dense:2b" | "granite3-dense:8b" | "granite3-guardian" | "granite3-guardian:2b" | "granite3-guardian:8b" | "granite3-moe" | "granite3-moe:1b" | "granite3-moe:3b" | "granite3.1-dense" | "granite3.1-dense:2b" | "granite3.1-dense:8b" | "granite3.1-moe" | "granite3.1-moe:1b" | "granite3.1-moe:3b" | "llama2" | "llama2:7b" | "llama2:13b" | "llama2:70b" | "llama3" | "llama3:8b" | "llama3:70b" | "llama3-chatqa" | "llama3-chatqa:8b" | "llama3-chatqa:70b" | "llama3-gradient" | "llama3-gradient:8b" | "llama3-gradient:70b" | "llama3.1" | "llama3.1:8b" | "llama3.1:70b" | "llama3.1:405b" | "llama3.2" | "llama3.2:1b" | "llama3.2:3b" | "llama3.2-vision" | "llama3.2-vision:11b" | "llama3.2-vision:90b" | "llama3.3" | "llama3.3:70b" | "llama4" | "llama4:16x17b" | "llama4:128x17b" | "llama-guard3" | "llama-guard3:1b" | "llama-guard3:8b" | "llava" | "llava:7b" | "llava:13b" | "llava:34b" | "llava-llama3" | "llava-llama3:8b" | "llava-phi3" | "llava-phi3:3.8b" | "marco-o1" | "marco-o1:7b" | "mistral" | "mistral:7b" | "mistral-large" | "mistral-large:123b" | "mistral-nemo" | "mistral-nemo:12b" | "mistral-small" | "mistral-small:22b" | "mixtral" | "mixtral:8x7b" | "mixtral:8x22b" | "moondream" | "moondream:1.8b" | "openhermes" | "openhermes:v2.5" | "nemotron" | "nemotron:70b" | "nemotron-mini" | "nemotron-mini:4b" | "olmo" | "olmo:7b" | "olmo:13b" | "opencoder" | "opencoder:1.5b" | "opencoder:8b" | "phi3" | "phi3:3.8b" | "phi3:14b" | "phi3.5" | "phi3.5:3.8b" | "phi4" | "phi4:14b" | "qwen" | "qwen:7b" | "qwen:14b" | "qwen:32b" | "qwen:72b" | "qwen:110b" | "qwen2" | "qwen2:0.5b" | "qwen2:1.5b" | "qwen2:7b" | "qwen2:72b" | "qwen2.5" | "qwen2.5:0.5b" | "qwen2.5:1.5b" | "qwen2.5:3b" | "qwen2.5:7b" | "qwen2.5:14b" | "qwen2.5:32b" | "qwen2.5:72b" | "qwen2.5-coder" | "qwen2.5-coder:0.5b" | "qwen2.5-coder:1.5b" | "qwen2.5-coder:3b" | "qwen2.5-coder:7b" | "qwen2.5-coder:14b" | "qwen2.5-coder:32b" | "qwen3" | "qwen3:0.6b" | "qwen3:1.7b" | "qwen3:4b" | "qwen3:8b" | "qwen3:14b" | "qwen3:30b" | "qwen3:32b" | "qwen3:235b" | "qwq" | "qwq:32b" | "sailor2" | "sailor2:1b" | "sailor2:8b" | "sailor2:20b" | "shieldgemma" | "shieldgemma:2b" | "shieldgemma:9b" | "shieldgemma:27b" | "smallthinker" | "smallthinker:3b" | "smollm" | "smollm:135m" | "smollm:360m" | "smollm:1.7b" | "tinyllama" | "tinyllama:1.1b" | "tulu3" | "tulu3:8b" | "tulu3:70b" | (string & {}); + declare const ollamaProviderOptions: z.ZodObject<{ +- think: z.ZodOptional; ++ think: z.ZodOptional, z.ZodLiteral<"medium">, z.ZodLiteral<"high">]>>; + options: z.ZodOptional; + repeat_last_n: z.ZodOptional; +@@ -27,9 +27,11 @@ interface OllamaCompletionSettings { + * the model's thinking from the model's output. When disabled, the model will not think + * and directly output the content. + * ++ * For gpt-oss models, you can also use 'low', 'medium', or 'high' to control the depth of thinking. ++ * + * Only supported by certain models like DeepSeek R1 and Qwen 3. + */ +- think?: boolean; ++ think?: boolean | 'low' | 'medium' | 'high'; + /** + * Echo back the prompt in addition to the completion. + */ +@@ -146,7 +148,7 @@ declare const ollamaEmbeddingProviderOptions: z.ZodObject<{ + type OllamaEmbeddingProviderOptions = z.infer; + + declare const ollamaCompletionProviderOptions: z.ZodObject<{ +- think: z.ZodOptional; ++ think: z.ZodOptional, z.ZodLiteral<"medium">, z.ZodLiteral<"high">]>>; + user: z.ZodOptional; + suffix: z.ZodOptional; + echo: z.ZodOptional; +diff --git a/dist/index.js b/dist/index.js +index 35b5142ce8476ce2549ed7c2ec48e7d8c46c90d9..2ef64dc9a4c2be043e6af608241a6a8309a5a69f 100644 +--- a/dist/index.js ++++ b/dist/index.js +@@ -158,7 +158,7 @@ function getResponseMetadata({ + + // src/completion/ollama-completion-language-model.ts + var ollamaCompletionProviderOptions = import_v42.z.object({ +- think: import_v42.z.boolean().optional(), ++ think: import_v42.z.union([import_v42.z.boolean(), import_v42.z.literal('low'), import_v42.z.literal('medium'), import_v42.z.literal('high')]).optional(), + user: import_v42.z.string().optional(), + suffix: import_v42.z.string().optional(), + echo: import_v42.z.boolean().optional() +@@ -662,7 +662,7 @@ function convertToOllamaChatMessages({ + const images = content.filter((part) => part.type === "file" && part.mediaType.startsWith("image/")).map((part) => part.data); + messages.push({ + role: "user", +- content: userText.length > 0 ? userText : [], ++ content: userText.length > 0 ? userText : '', + images: images.length > 0 ? images : void 0 + }); + break; +@@ -813,9 +813,11 @@ var ollamaProviderOptions = import_v44.z.object({ + * the model's thinking from the model's output. When disabled, the model will not think + * and directly output the content. + * ++ * For gpt-oss models, you can also use 'low', 'medium', or 'high' to control the depth of thinking. ++ * + * Only supported by certain models like DeepSeek R1 and Qwen 3. + */ +- think: import_v44.z.boolean().optional(), ++ think: import_v44.z.union([import_v44.z.boolean(), import_v44.z.literal('low'), import_v44.z.literal('medium'), import_v44.z.literal('high')]).optional(), + options: import_v44.z.object({ + num_ctx: import_v44.z.number().optional(), + repeat_last_n: import_v44.z.number().optional(), +@@ -929,14 +931,16 @@ var OllamaRequestBuilder = class { + prompt, + systemMessageMode: "system" + }), +- temperature, +- top_p: topP, + max_output_tokens: maxOutputTokens, + ...(responseFormat == null ? void 0 : responseFormat.type) === "json" && { + format: responseFormat.schema != null ? responseFormat.schema : "json" + }, + think: (_a = ollamaOptions == null ? void 0 : ollamaOptions.think) != null ? _a : false, +- options: (_b = ollamaOptions == null ? void 0 : ollamaOptions.options) != null ? _b : void 0 ++ options: { ++ ...temperature !== void 0 && { temperature }, ++ ...topP !== void 0 && { top_p: topP }, ++ ...((_b = ollamaOptions == null ? void 0 : ollamaOptions.options) != null ? _b : {}) ++ } + }; + } + }; +diff --git a/dist/index.mjs b/dist/index.mjs +index e2a634a78d80ac9542f2cc4f96cf2291094b10cf..67b23efce3c1cf4f026693d3ff9246988a3ef26e 100644 +--- a/dist/index.mjs ++++ b/dist/index.mjs +@@ -144,7 +144,7 @@ function getResponseMetadata({ + + // src/completion/ollama-completion-language-model.ts + var ollamaCompletionProviderOptions = z2.object({ +- think: z2.boolean().optional(), ++ think: z2.union([z2.boolean(), z2.literal('low'), z2.literal('medium'), z2.literal('high')]).optional(), + user: z2.string().optional(), + suffix: z2.string().optional(), + echo: z2.boolean().optional() +@@ -662,7 +662,7 @@ function convertToOllamaChatMessages({ + const images = content.filter((part) => part.type === "file" && part.mediaType.startsWith("image/")).map((part) => part.data); + messages.push({ + role: "user", +- content: userText.length > 0 ? userText : [], ++ content: userText.length > 0 ? userText : '', + images: images.length > 0 ? images : void 0 + }); + break; +@@ -815,9 +815,11 @@ var ollamaProviderOptions = z4.object({ + * the model's thinking from the model's output. When disabled, the model will not think + * and directly output the content. + * ++ * For gpt-oss models, you can also use 'low', 'medium', or 'high' to control the depth of thinking. ++ * + * Only supported by certain models like DeepSeek R1 and Qwen 3. + */ +- think: z4.boolean().optional(), ++ think: z4.union([z4.boolean(), z4.literal('low'), z4.literal('medium'), z4.literal('high')]).optional(), + options: z4.object({ + num_ctx: z4.number().optional(), + repeat_last_n: z4.number().optional(), +@@ -931,14 +933,16 @@ var OllamaRequestBuilder = class { + prompt, + systemMessageMode: "system" + }), +- temperature, +- top_p: topP, + max_output_tokens: maxOutputTokens, + ...(responseFormat == null ? void 0 : responseFormat.type) === "json" && { + format: responseFormat.schema != null ? responseFormat.schema : "json" + }, + think: (_a = ollamaOptions == null ? void 0 : ollamaOptions.think) != null ? _a : false, +- options: (_b = ollamaOptions == null ? void 0 : ollamaOptions.options) != null ? _b : void 0 ++ options: { ++ ...temperature !== void 0 && { temperature }, ++ ...topP !== void 0 && { top_p: topP }, ++ ...((_b = ollamaOptions == null ? void 0 : ollamaOptions.options) != null ? _b : {}) ++ } + }; + } + }; diff --git a/CLAUDE.md b/CLAUDE.md index c96fc0e403..c68187db93 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,7 +28,7 @@ When creating a Pull Request, you MUST: - **Development**: `yarn dev` - Runs Electron app in development mode with hot reload - **Debug**: `yarn debug` - Starts with debugging enabled, use `chrome://inspect` to attach debugger - **Build Check**: `yarn build:check` - **REQUIRED** before commits (lint + test + typecheck) - - If having i18n sort issues, run `yarn sync:i18n` first to sync template + - If having i18n sort issues, run `yarn i18n:sync` first to sync template - If having formatting issues, run `yarn format` first - **Test**: `yarn test` - Run all tests (Vitest) across main and renderer processes - **Single Test**: @@ -40,20 +40,23 @@ When creating a Pull Request, you MUST: ## Project Architecture ### Electron Structure + - **Main Process** (`src/main/`): Node.js backend with services (MCP, Knowledge, Storage, etc.) - **Renderer Process** (`src/renderer/`): React UI with Redux state management - **Preload Scripts** (`src/preload/`): Secure IPC bridge ### Key Components + - **AI Core** (`src/renderer/src/aiCore/`): Middleware pipeline for multiple AI providers. - **Services** (`src/main/services/`): MCPService, KnowledgeService, WindowService, etc. - **Build System**: Electron-Vite with experimental rolldown-vite, yarn workspaces. - **State Management**: Redux Toolkit (`src/renderer/src/store/`) for predictable state. ### Logging + ```typescript -import { loggerService } from '@logger' -const logger = loggerService.withContext('moduleName') +import { loggerService } from "@logger"; +const logger = loggerService.withContext("moduleName"); // Renderer: loggerService.initWindowSource('windowName') first -logger.info('message', CONTEXT) +logger.info("message", CONTEXT); ``` diff --git a/README.md b/README.md index f790c10cbd..781e9299e5 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ -

English | 中文 | Official Site | Documents | Development | Feedback

+

English | 中文 | Official Site | Documents | Development | Feedback

@@ -242,12 +242,12 @@ The Enterprise Edition addresses core challenges in team collaboration by centra ## Version Comparison -| Feature | Community Edition | Enterprise Edition | -| :---------------- | :----------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------- | -| **Open Source** | ✅ Yes | ⭕️ Partially released to customers | +| Feature | Community Edition | Enterprise Edition | +| :---------------- | :----------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------- | +| **Open Source** | ✅ Yes | ⭕️ Partially released to customers | | **Cost** | [AGPL-3.0 License](https://github.com/CherryHQ/cherry-studio?tab=AGPL-3.0-1-ov-file) | Buyout / Subscription Fee | -| **Admin Backend** | — | ● Centralized **Model** Access
● **Employee** Management
● Shared **Knowledge Base**
● **Access** Control
● **Data** Backup | -| **Server** | — | ✅ Dedicated Private Deployment | +| **Admin Backend** | — | ● Centralized **Model** Access
● **Employee** Management
● Shared **Knowledge Base**
● **Access** Control
● **Data** Backup | +| **Server** | — | ✅ Dedicated Private Deployment | ## Get the Enterprise Edition @@ -275,7 +275,7 @@ We believe the Enterprise Edition will become your team's AI productivity engine # 📊 GitHub Stats -![Stats](https://repobeats.axiom.co/api/embed/a693f2e5f773eed620f70031e974552156c7f397.svg 'Repobeats analytics image') +![Stats](https://repobeats.axiom.co/api/embed/a693f2e5f773eed620f70031e974552156c7f397.svg "Repobeats analytics image") # ⭐️ Star History diff --git a/biome.jsonc b/biome.jsonc index 705b1e01f3..6f925f5af2 100644 --- a/biome.jsonc +++ b/biome.jsonc @@ -23,7 +23,7 @@ }, "files": { "ignoreUnknown": false, - "includes": ["**", "!**/.claude/**", "!**/.vscode/**"], + "includes": ["**", "!**/.claude/**", "!**/.vscode/**", "!**/.conductor/**"], "maxSize": 2097152 }, "formatter": { diff --git a/build/nsis-installer.nsh b/build/nsis-installer.nsh index 769ccaaa19..e644e18f3d 100644 --- a/build/nsis-installer.nsh +++ b/build/nsis-installer.nsh @@ -12,8 +12,13 @@ ; https://github.com/electron-userland/electron-builder/issues/1122 !ifndef BUILD_UNINSTALLER + ; Check VC++ Redistributable based on architecture stored in $1 Function checkVCRedist - ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\x64" "Installed" + ${If} $1 == "arm64" + ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\ARM64" "Installed" + ${Else} + ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\x64" "Installed" + ${EndIf} FunctionEnd Function checkArchitectureCompatibility @@ -97,29 +102,47 @@ Call checkVCRedist ${If} $0 != "1" - MessageBox MB_YESNO "\ - NOTE: ${PRODUCT_NAME} requires $\r$\n\ - 'Microsoft Visual C++ Redistributable'$\r$\n\ - to function properly.$\r$\n$\r$\n\ - Download and install now?" /SD IDYES IDYES InstallVCRedist IDNO DontInstall - InstallVCRedist: - inetc::get /CAPTION " " /BANNER "Downloading Microsoft Visual C++ Redistributable..." "https://aka.ms/vs/17/release/vc_redist.x64.exe" "$TEMP\vc_redist.x64.exe" - ExecWait "$TEMP\vc_redist.x64.exe /install /norestart" - ;IfErrors InstallError ContinueInstall ; vc_redist exit code is unreliable :( - Call checkVCRedist - ${If} $0 == "1" - Goto ContinueInstall - ${EndIf} + ; VC++ is required - install automatically since declining would abort anyway + ; Select download URL based on system architecture (stored in $1) + ${If} $1 == "arm64" + StrCpy $2 "https://aka.ms/vs/17/release/vc_redist.arm64.exe" + StrCpy $3 "$TEMP\vc_redist.arm64.exe" + ${Else} + StrCpy $2 "https://aka.ms/vs/17/release/vc_redist.x64.exe" + StrCpy $3 "$TEMP\vc_redist.x64.exe" + ${EndIf} - ;InstallError: - MessageBox MB_ICONSTOP "\ - There was an unexpected error installing$\r$\n\ - Microsoft Visual C++ Redistributable.$\r$\n\ - The installation of ${PRODUCT_NAME} cannot continue." - DontInstall: + inetc::get /CAPTION " " /BANNER "Downloading Microsoft Visual C++ Redistributable..." \ + $2 $3 /END + Pop $0 ; Get download status from inetc::get + ${If} $0 != "OK" + MessageBox MB_ICONSTOP|MB_YESNO "\ + Failed to download Microsoft Visual C++ Redistributable.$\r$\n$\r$\n\ + Error: $0$\r$\n$\r$\n\ + Would you like to open the download page in your browser?$\r$\n\ + $2" IDYES openDownloadUrl IDNO skipDownloadUrl + openDownloadUrl: + ExecShell "open" $2 + skipDownloadUrl: Abort + ${EndIf} + + ExecWait "$3 /install /quiet /norestart" + ; Note: vc_redist exit code is unreliable, verify via registry check instead + + Call checkVCRedist + ${If} $0 != "1" + MessageBox MB_ICONSTOP|MB_YESNO "\ + Microsoft Visual C++ Redistributable installation failed.$\r$\n$\r$\n\ + Would you like to open the download page in your browser?$\r$\n\ + $2$\r$\n$\r$\n\ + The installation of ${PRODUCT_NAME} cannot continue." IDYES openInstallUrl IDNO skipInstallUrl + openInstallUrl: + ExecShell "open" $2 + skipInstallUrl: + Abort + ${EndIf} ${EndIf} - ContinueInstall: Pop $4 Pop $3 Pop $2 diff --git a/docs/en/guides/development.md b/docs/en/guides/development.md index fe67742768..032a515f61 100644 --- a/docs/en/guides/development.md +++ b/docs/en/guides/development.md @@ -36,7 +36,7 @@ yarn install ### ENV ```bash -copy .env.example .env +cp .env.example .env ``` ### Start diff --git a/docs/en/guides/i18n.md b/docs/en/guides/i18n.md index a3284e3ab9..7fccfde695 100644 --- a/docs/en/guides/i18n.md +++ b/docs/en/guides/i18n.md @@ -71,7 +71,7 @@ Tools like i18n Ally cannot parse dynamic content within template strings, resul ```javascript // Not recommended - Plugin cannot resolve -const message = t(`fruits.${fruit}`) +const message = t(`fruits.${fruit}`); ``` #### 2. **No Real-time Rendering in Editor** @@ -91,14 +91,14 @@ For example: ```ts // src/renderer/src/i18n/label.ts const themeModeKeyMap = { - dark: 'settings.theme.dark', - light: 'settings.theme.light', - system: 'settings.theme.system' -} as const + dark: "settings.theme.dark", + light: "settings.theme.light", + system: "settings.theme.system", +} as const; export const getThemeModeLabel = (key: string): string => { - return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key -} + return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key; +}; ``` By avoiding template strings, you gain better developer experience, more reliable translation checks, and a more maintainable codebase. @@ -107,7 +107,7 @@ By avoiding template strings, you gain better developer experience, more reliabl The project includes several scripts to automate i18n-related tasks: -### `check:i18n` - Validate i18n Structure +### `i18n:check` - Validate i18n Structure This script checks: @@ -116,10 +116,10 @@ This script checks: - Whether keys are properly sorted ```bash -yarn check:i18n +yarn i18n:check ``` -### `sync:i18n` - Synchronize JSON Structure and Sort Order +### `i18n:sync` - Synchronize JSON Structure and Sort Order This script uses `zh-cn.json` as the source of truth to sync structure across all language files, including: @@ -128,14 +128,14 @@ This script uses `zh-cn.json` as the source of truth to sync structure across al 3. Sorting keys automatically ```bash -yarn sync:i18n +yarn i18n:sync ``` -### `auto:i18n` - Automatically Translate Pending Texts +### `i18n:translate` - Automatically Translate Pending Texts This script fills in texts marked as `[to be translated]` using machine translation. -Typically, after adding new texts in `zh-cn.json`, run `sync:i18n`, then `auto:i18n` to complete translations. +Typically, after adding new texts in `zh-cn.json`, run `i18n:sync`, then `i18n:translate` to complete translations. Before using this script, set the required environment variables: @@ -148,30 +148,20 @@ MODEL="qwen-plus-latest" Alternatively, add these variables directly to your `.env` file. ```bash -yarn auto:i18n -``` - -### `update:i18n` - Object-level Translation Update - -Updates translations in language files under `src/renderer/src/i18n/translate` at the object level, preserving existing translations and only updating new content. - -**Not recommended** — prefer `auto:i18n` for translation tasks. - -```bash -yarn update:i18n +yarn i18n:translate ``` ### Workflow 1. During development, first add the required text in `zh-cn.json` 2. Confirm it displays correctly in the Chinese environment -3. Run `yarn sync:i18n` to propagate the keys to other language files -4. Run `yarn auto:i18n` to perform machine translation +3. Run `yarn i18n:sync` to propagate the keys to other language files +4. Run `yarn i18n:translate` to perform machine translation 5. Grab a coffee and let the magic happen! ## Best Practices 1. **Use Chinese as Source Language**: All development starts in Chinese, then translates to other languages. -2. **Run Check Script Before Commit**: Use `yarn check:i18n` to catch i18n issues early. +2. **Run Check Script Before Commit**: Use `yarn i18n:check` to catch i18n issues early. 3. **Translate in Small Increments**: Avoid accumulating a large backlog of untranslated content. 4. **Keep Keys Semantically Clear**: Keys should clearly express their purpose, e.g., `user.profile.avatar.upload.error` diff --git a/docs/en/references/fuzzy-search.md b/docs/en/references/fuzzy-search.md new file mode 100644 index 0000000000..11c2002cb9 --- /dev/null +++ b/docs/en/references/fuzzy-search.md @@ -0,0 +1,129 @@ +# Fuzzy Search for File List + +This document describes the fuzzy search implementation for file listing in Cherry Studio. + +## Overview + +The fuzzy search feature allows users to find files by typing partial or approximate file names/paths. It uses a two-tier file filtering strategy (ripgrep glob pre-filtering with greedy substring fallback) combined with subsequence-based scoring for optimal performance and flexibility. + +## Features + +- **Ripgrep Glob Pre-filtering**: Primary filtering using glob patterns for fast native-level filtering +- **Greedy Substring Matching**: Fallback file filtering strategy when ripgrep glob pre-filtering returns no results +- **Subsequence-based Segment Scoring**: During scoring, path segments gain additional weight when query characters appear in order +- **Relevance Scoring**: Results are sorted by a relevance score derived from multiple factors + +## Matching Strategies + +### 1. Ripgrep Glob Pre-filtering (Primary) + +The query is converted to a glob pattern for ripgrep to do initial filtering: + +``` +Query: "updater" +Glob: "*u*p*d*a*t*e*r*" +``` + +This leverages ripgrep's native performance for the initial file filtering. + +### 2. Greedy Substring Matching (Fallback) + +When the glob pre-filter returns no results, the system falls back to greedy substring matching. This allows more flexible matching: + +``` +Query: "updatercontroller" +File: "packages/update/src/node/updateController.ts" + +Matching process: +1. Find "update" (longest match from start) +2. Remaining "rcontroller" → find "r" then "controller" +3. All parts matched → Success +``` + +## Scoring Algorithm + +Results are ranked by a relevance score based on named constants defined in `FileStorage.ts`: + +| Constant | Value | Description | +|----------|-------|-------------| +| `SCORE_FILENAME_STARTS` | 100 | Filename starts with query (highest priority) | +| `SCORE_FILENAME_CONTAINS` | 80 | Filename contains exact query substring | +| `SCORE_SEGMENT_MATCH` | 60 | Per path segment that matches query | +| `SCORE_WORD_BOUNDARY` | 20 | Query matches start of a word | +| `SCORE_CONSECUTIVE_CHAR` | 15 | Per consecutive character match | +| `PATH_LENGTH_PENALTY_FACTOR` | 4 | Logarithmic penalty for longer paths | + +### Scoring Strategy + +The scoring prioritizes: +1. **Filename matches** (highest): Files where the query appears in the filename are most relevant +2. **Path segment matches**: Multiple matching segments indicate stronger relevance +3. **Word boundaries**: Matching at word starts (e.g., "upd" matching "update") is preferred +4. **Consecutive matches**: Longer consecutive character sequences score higher +5. **Path length**: Shorter paths are preferred (logarithmic penalty prevents long paths from dominating) + +### Example Scoring + +For query `updater`: + +| File | Score Factors | +|------|---------------| +| `RCUpdater.js` | Short path + filename contains "updater" | +| `updateController.ts` | Multiple segment matches | +| `UpdaterHelper.plist` | Long path penalty | + +## Configuration + +### DirectoryListOptions + +```typescript +interface DirectoryListOptions { + recursive?: boolean // Default: true + maxDepth?: number // Default: 10 + includeHidden?: boolean // Default: false + includeFiles?: boolean // Default: true + includeDirectories?: boolean // Default: true + maxEntries?: number // Default: 20 + searchPattern?: string // Default: '.' + fuzzy?: boolean // Default: true +} +``` + +## Usage + +```typescript +// Basic fuzzy search +const files = await window.api.file.listDirectory(dirPath, { + searchPattern: 'updater', + fuzzy: true, + maxEntries: 20 +}) + +// Disable fuzzy search (exact glob matching) +const files = await window.api.file.listDirectory(dirPath, { + searchPattern: 'update', + fuzzy: false +}) +``` + +## Performance Considerations + +1. **Ripgrep Pre-filtering**: Most queries are handled by ripgrep's native glob matching, which is extremely fast +2. **Fallback Only When Needed**: Greedy substring matching (which loads all files) only runs when glob matching returns empty results +3. **Result Limiting**: Only top 20 results are returned by default +4. **Excluded Directories**: Common large directories are automatically excluded: + - `node_modules` + - `.git` + - `dist`, `build` + - `.next`, `.nuxt` + - `coverage`, `.cache` + +## Implementation Details + +The implementation is located in `src/main/services/FileStorage.ts`: + +- `queryToGlobPattern()`: Converts query to ripgrep glob pattern +- `isFuzzyMatch()`: Subsequence matching algorithm +- `isGreedySubstringMatch()`: Greedy substring matching fallback +- `getFuzzyMatchScore()`: Calculates relevance score +- `listDirectoryWithRipgrep()`: Main search orchestration diff --git a/docs/zh/README.md b/docs/zh/README.md index f8a1f1ab8c..c4adeb4901 100644 --- a/docs/zh/README.md +++ b/docs/zh/README.md @@ -34,7 +34,7 @@

- English | 中文 | 官方网站 | 文档 | 开发 | 反馈
+ English | 中文 | 官方网站 | 文档 | 开发 | 反馈

@@ -281,7 +281,7 @@ https://docs.cherry-ai.com # 📊 GitHub 统计 -![Stats](https://repobeats.axiom.co/api/embed/a693f2e5f773eed620f70031e974552156c7f397.svg 'Repobeats analytics image') +![Stats](https://repobeats.axiom.co/api/embed/a693f2e5f773eed620f70031e974552156c7f397.svg "Repobeats analytics image") # ⭐️ Star 记录 diff --git a/docs/zh/guides/development.md b/docs/zh/guides/development.md index fe67742768..032a515f61 100644 --- a/docs/zh/guides/development.md +++ b/docs/zh/guides/development.md @@ -36,7 +36,7 @@ yarn install ### ENV ```bash -copy .env.example .env +cp .env.example .env ``` ### Start diff --git a/docs/zh/guides/i18n.md b/docs/zh/guides/i18n.md index 82624d35c8..c8a8ccc66b 100644 --- a/docs/zh/guides/i18n.md +++ b/docs/zh/guides/i18n.md @@ -1,17 +1,17 @@ # 如何优雅地做好 i18n -## 使用i18n ally插件提升开发体验 +## 使用 i18n ally 插件提升开发体验 -i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反馈,帮助开发者更早发现文案缺失和错译问题。 +i18n ally 是一个强大的 VSCode 插件,它能在开发阶段提供实时反馈,帮助开发者更早发现文案缺失和错译问题。 项目中已经配置好了插件设置,直接安装即可。 ### 开发时优势 - **实时预览**:翻译文案会直接显示在编辑器中 -- **错误检测**:自动追踪标记出缺失的翻译或未使用的key -- **快速跳转**:可通过key直接跳转到定义处(Ctrl/Cmd + click) -- **自动补全**:输入i18n key时提供自动补全建议 +- **错误检测**:自动追踪标记出缺失的翻译或未使用的 key +- **快速跳转**:可通过 key 直接跳转到定义处(Ctrl/Cmd + click) +- **自动补全**:输入 i18n key 时提供自动补全建议 ### 效果展示 @@ -23,9 +23,9 @@ i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反 ## i18n 约定 -### **绝对避免使用flat格式** +### **绝对避免使用 flat 格式** -绝对避免使用flat格式,如`"add.button.tip": "添加"`。应采用清晰的嵌套结构: +绝对避免使用 flat 格式,如`"add.button.tip": "添加"`。应采用清晰的嵌套结构: ```json // 错误示例 - flat结构 @@ -52,14 +52,14 @@ i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反 #### 为什么要使用嵌套结构 1. **自然分组**:通过对象结构天然能将相关上下文的文案分到一个组别中 -2. **插件要求**:i18n ally 插件需要嵌套或flat格式其一的文件才能正常分析 +2. **插件要求**:i18n ally 插件需要嵌套或 flat 格式其一的文件才能正常分析 ### **避免在`t()`中使用模板字符串** -**强烈建议避免使用模板字符串**进行动态插值。虽然模板字符串在JavaScript开发中非常方便,但在国际化场景下会带来一系列问题。 +**强烈建议避免使用模板字符串**进行动态插值。虽然模板字符串在 JavaScript 开发中非常方便,但在国际化场景下会带来一系列问题。 1. **插件无法跟踪** - i18n ally等工具无法解析模板字符串中的动态内容,导致: + i18n ally 等工具无法解析模板字符串中的动态内容,导致: - 无法正确显示实时预览 - 无法检测翻译缺失 @@ -67,11 +67,11 @@ i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反 ```javascript // 不推荐 - 插件无法解析 - const message = t(`fruits.${fruit}`) + const message = t(`fruits.${fruit}`); ``` 2. **编辑器无法实时渲染** - 在IDE中,模板字符串会显示为原始代码而非最终翻译结果,降低了开发体验。 + 在 IDE 中,模板字符串会显示为原始代码而非最终翻译结果,降低了开发体验。 3. **更难以维护** 由于插件无法跟踪这样的文案,编辑器中也无法渲染,开发者必须人工确认语言文件中是否存在相应的文案。 @@ -85,36 +85,36 @@ i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反 ```ts // src/renderer/src/i18n/label.ts const themeModeKeyMap = { - dark: 'settings.theme.dark', - light: 'settings.theme.light', - system: 'settings.theme.system' -} as const + dark: "settings.theme.dark", + light: "settings.theme.light", + system: "settings.theme.system", +} as const; export const getThemeModeLabel = (key: string): string => { - return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key -} + return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key; +}; ``` 通过避免模板字符串,可以获得更好的开发体验、更可靠的翻译检查以及更易维护的代码库。 ## 自动化脚本 -项目中有一系列脚本来自动化i18n相关任务: +项目中有一系列脚本来自动化 i18n 相关任务: -### `check:i18n` - 检查i18n结构 +### `i18n:check` - 检查 i18n 结构 此脚本会检查: - 所有语言文件是否为嵌套结构 -- 是否存在缺失的key -- 是否存在多余的key +- 是否存在缺失的 key +- 是否存在多余的 key - 是否已经有序 ```bash -yarn check:i18n +yarn i18n:check ``` -### `sync:i18n` - 同步json结构与排序 +### `i18n:sync` - 同步 json 结构与排序 此脚本以`zh-cn.json`文件为基准,将结构同步到其他语言文件,包括: @@ -123,14 +123,14 @@ yarn check:i18n 3. 自动排序 ```bash -yarn sync:i18n +yarn i18n:sync ``` -### `auto:i18n` - 自动翻译待翻译文本 +### `i18n:translate` - 自动翻译待翻译文本 次脚本自动将标记为待翻译的文本通过机器翻译填充。 -通常,在`zh-cn.json`中添加所需文案后,执行`sync:i18n`即可自动完成翻译。 +通常,在`zh-cn.json`中添加所需文案后,执行`i18n:sync`即可自动完成翻译。 使用该脚本前,需要配置环境变量,例如: @@ -143,29 +143,19 @@ MODEL="qwen-plus-latest" 你也可以通过直接编辑`.env`文件来添加环境变量。 ```bash -yarn auto:i18n -``` - -### `update:i18n` - 对象级别翻译更新 - -对`src/renderer/src/i18n/translate`中的语言文件进行对象级别的翻译更新,保留已有翻译,只更新新增内容。 - -**不建议**使用该脚本,更推荐使用`auto:i18n`进行翻译。 - -```bash -yarn update:i18n +yarn i18n:translate ``` ### 工作流 1. 开发阶段,先在`zh-cn.json`中添加所需文案 -2. 确认在中文环境下显示无误后,使用`yarn sync:i18n`将文案同步到其他语言文件 -3. 使用`yarn auto:i18n`进行自动翻译 +2. 确认在中文环境下显示无误后,使用`yarn i18n:sync`将文案同步到其他语言文件 +3. 使用`yarn i18n:translate`进行自动翻译 4. 喝杯咖啡,等翻译完成吧! ## 最佳实践 1. **以中文为源语言**:所有开发首先使用中文,再翻译为其他语言 -2. **提交前运行检查脚本**:使用`yarn check:i18n`检查i18n是否有问题 +2. **提交前运行检查脚本**:使用`yarn i18n:check`检查 i18n 是否有问题 3. **小步提交翻译**:避免积累大量未翻译文本 -4. **保持key语义明确**:key应能清晰表达其用途,如`user.profile.avatar.upload.error` +4. **保持 key 语义明确**:key 应能清晰表达其用途,如`user.profile.avatar.upload.error` diff --git a/docs/zh/references/fuzzy-search.md b/docs/zh/references/fuzzy-search.md new file mode 100644 index 0000000000..d28d189928 --- /dev/null +++ b/docs/zh/references/fuzzy-search.md @@ -0,0 +1,129 @@ +# 文件列表模糊搜索 + +本文档描述了 Cherry Studio 中文件列表的模糊搜索实现。 + +## 概述 + +模糊搜索功能允许用户通过输入部分或近似的文件名/路径来查找文件。它使用两层文件过滤策略(ripgrep glob 预过滤 + 贪婪子串匹配回退),结合基于子序列的评分,以获得最佳性能和灵活性。 + +## 功能特性 + +- **Ripgrep Glob 预过滤**:使用 glob 模式进行快速原生级过滤的主要过滤策略 +- **贪婪子串匹配**:当 ripgrep glob 预过滤无结果时的回退文件过滤策略 +- **基于子序列的段评分**:评分时,当查询字符按顺序出现时,路径段获得额外权重 +- **相关性评分**:结果按多因素相关性分数排序 + +## 匹配策略 + +### 1. Ripgrep Glob 预过滤(主要) + +查询被转换为 glob 模式供 ripgrep 进行初始过滤: + +``` +查询: "updater" +Glob: "*u*p*d*a*t*e*r*" +``` + +这利用了 ripgrep 的原生性能进行初始文件过滤。 + +### 2. 贪婪子串匹配(回退) + +当 glob 预过滤无结果时,系统回退到贪婪子串匹配。这允许更灵活的匹配: + +``` +查询: "updatercontroller" +文件: "packages/update/src/node/updateController.ts" + +匹配过程: +1. 找到 "update"(从开头的最长匹配) +2. 剩余 "rcontroller" → 找到 "r" 然后 "controller" +3. 所有部分都匹配 → 成功 +``` + +## 评分算法 + +结果根据 `FileStorage.ts` 中定义的命名常量进行相关性分数排名: + +| 常量 | 值 | 描述 | +|------|-----|------| +| `SCORE_FILENAME_STARTS` | 100 | 文件名以查询开头(最高优先级)| +| `SCORE_FILENAME_CONTAINS` | 80 | 文件名包含精确查询子串 | +| `SCORE_SEGMENT_MATCH` | 60 | 每个匹配查询的路径段 | +| `SCORE_WORD_BOUNDARY` | 20 | 查询匹配单词开头 | +| `SCORE_CONSECUTIVE_CHAR` | 15 | 每个连续字符匹配 | +| `PATH_LENGTH_PENALTY_FACTOR` | 4 | 较长路径的对数惩罚 | + +### 评分策略 + +评分优先级: +1. **文件名匹配**(最高):查询出现在文件名中的文件最相关 +2. **路径段匹配**:多个匹配段表示更强的相关性 +3. **词边界**:在单词开头匹配(如 "upd" 匹配 "update")更优先 +4. **连续匹配**:更长的连续字符序列得分更高 +5. **路径长度**:较短路径更优先(对数惩罚防止长路径主导评分) + +### 评分示例 + +对于查询 `updater`: + +| 文件 | 评分因素 | +|------|----------| +| `RCUpdater.js` | 短路径 + 文件名包含 "updater" | +| `updateController.ts` | 多个路径段匹配 | +| `UpdaterHelper.plist` | 长路径惩罚 | + +## 配置 + +### DirectoryListOptions + +```typescript +interface DirectoryListOptions { + recursive?: boolean // 默认: true + maxDepth?: number // 默认: 10 + includeHidden?: boolean // 默认: false + includeFiles?: boolean // 默认: true + includeDirectories?: boolean // 默认: true + maxEntries?: number // 默认: 20 + searchPattern?: string // 默认: '.' + fuzzy?: boolean // 默认: true +} +``` + +## 使用方法 + +```typescript +// 基本模糊搜索 +const files = await window.api.file.listDirectory(dirPath, { + searchPattern: 'updater', + fuzzy: true, + maxEntries: 20 +}) + +// 禁用模糊搜索(精确 glob 匹配) +const files = await window.api.file.listDirectory(dirPath, { + searchPattern: 'update', + fuzzy: false +}) +``` + +## 性能考虑 + +1. **Ripgrep 预过滤**:大多数查询由 ripgrep 的原生 glob 匹配处理,速度极快 +2. **仅在需要时回退**:贪婪子串匹配(加载所有文件)仅在 glob 匹配返回空结果时运行 +3. **结果限制**:默认只返回前 20 个结果 +4. **排除目录**:自动排除常见的大型目录: + - `node_modules` + - `.git` + - `dist`、`build` + - `.next`、`.nuxt` + - `coverage`、`.cache` + +## 实现细节 + +实现位于 `src/main/services/FileStorage.ts`: + +- `queryToGlobPattern()`:将查询转换为 ripgrep glob 模式 +- `isFuzzyMatch()`:子序列匹配算法 +- `isGreedySubstringMatch()`:贪婪子串匹配回退 +- `getFuzzyMatchScore()`:计算相关性分数 +- `listDirectoryWithRipgrep()`:主搜索协调 diff --git a/docs/zh/references/lan-transfer-protocol.md b/docs/zh/references/lan-transfer-protocol.md new file mode 100644 index 0000000000..a4c01a23c5 --- /dev/null +++ b/docs/zh/references/lan-transfer-protocol.md @@ -0,0 +1,850 @@ +# Cherry Studio 局域网传输协议规范 + +> 版本: 1.0 +> 最后更新: 2025-12 + +本文档定义了 Cherry Studio 桌面客户端(Electron)与移动端(Expo)之间的局域网文件传输协议。 + +--- + +## 目录 + +1. [协议概述](#1-协议概述) +2. [服务发现(Bonjour/mDNS)](#2-服务发现bonjourmdns) +3. [TCP 连接与握手](#3-tcp-连接与握手) +4. [消息格式规范](#4-消息格式规范) +5. [文件传输协议](#5-文件传输协议) +6. [心跳与连接保活](#6-心跳与连接保活) +7. [错误处理](#7-错误处理) +8. [常量与配置](#8-常量与配置) +9. [完整时序图](#9-完整时序图) +10. [移动端实现指南](#10-移动端实现指南) + +--- + +## 1. 协议概述 + +### 1.1 架构角色 + +| 角色 | 平台 | 职责 | +| -------------------- | --------------- | ---------------------------- | +| **Client(客户端)** | Electron 桌面端 | 扫描服务、发起连接、发送文件 | +| **Server(服务端)** | Expo 移动端 | 发布服务、接受连接、接收文件 | + +### 1.2 协议栈(v1) + +``` +┌─────────────────────────────────────┐ +│ 应用层(文件传输) │ +├─────────────────────────────────────┤ +│ 消息层(控制: JSON \n) │ +│ (数据: 二进制帧) │ +├─────────────────────────────────────┤ +│ 传输层(TCP) │ +├─────────────────────────────────────┤ +│ 发现层(Bonjour/mDNS) │ +└─────────────────────────────────────┘ +``` + +### 1.3 通信流程概览 + +``` +1. 服务发现 → 移动端发布 mDNS 服务,桌面端扫描发现 +2. TCP 握手 → 建立连接,交换设备信息(`version=1`) +3. 文件传输 → 控制消息使用 JSON,`file_chunk` 使用二进制帧分块传输 +4. 连接保活 → ping/pong 心跳 +``` + +--- + +## 2. 服务发现(Bonjour/mDNS) + +### 2.1 服务类型 + +| 属性 | 值 | +| ------------ | -------------------- | +| 服务类型 | `cherrystudio` | +| 协议 | `tcp` | +| 完整服务标识 | `_cherrystudio._tcp` | + +### 2.2 服务发布(移动端) + +移动端需要通过 mDNS/Bonjour 发布服务: + +```typescript +// 服务发布参数 +{ + name: "Cherry Studio Mobile", // 设备名称 + type: "cherrystudio", // 服务类型 + protocol: "tcp", // 协议 + port: 53317, // TCP 监听端口 + txt: { // TXT 记录(可选) + version: "1", + platform: "ios" // 或 "android" + } +} +``` + +### 2.3 服务发现(桌面端) + +桌面端扫描并解析服务信息: + +```typescript +// 发现的服务信息结构 +type LocalTransferPeer = { + id: string; // 唯一标识符 + name: string; // 设备名称 + host?: string; // 主机名 + fqdn?: string; // 完全限定域名 + port?: number; // TCP 端口 + type?: string; // 服务类型 + protocol?: "tcp" | "udp"; // 协议 + addresses: string[]; // IP 地址列表 + txt?: Record; // TXT 记录 + updatedAt: number; // 发现时间戳 +}; +``` + +### 2.4 IP 地址选择策略 + +当服务有多个 IP 地址时,优先选择 IPv4: + +```typescript +// 优先选择 IPv4 地址 +const preferredAddress = addresses.find((addr) => isIPv4(addr)) || addresses[0]; +``` + +--- + +## 3. TCP 连接与握手 + +### 3.1 连接建立 + +1. 客户端使用发现的 `host:port` 建立 TCP 连接 +2. 连接成功后立即发送握手消息 +3. 等待服务端响应握手确认 + +### 3.2 握手消息(协议版本 v1) + +#### Client → Server: `handshake` + +```typescript +type LanTransferHandshakeMessage = { + type: "handshake"; + deviceName: string; // 设备名称 + version: string; // 协议版本,当前为 "1" + platform?: string; // 平台:'darwin' | 'win32' | 'linux' + appVersion?: string; // 应用版本 +}; +``` + +**示例:** + +```json +{ + "type": "handshake", + "deviceName": "Cherry Studio 1.7.2", + "version": "1", + "platform": "darwin", + "appVersion": "1.7.2" +} +``` + +### 4. 消息格式规范(混合协议) + +v1 使用"控制 JSON + 二进制数据帧"的混合协议(流式传输模式,无 per-chunk ACK): + +- **控制消息**(握手、心跳、file_start/ack、file_end、file_complete):UTF-8 JSON,`\n` 分隔 +- **数据消息**(`file_chunk`):二进制帧,使用 Magic + 总长度做分帧,不经 Base64 + +### 4.1 控制消息编码(JSON + `\n`) + +| 属性 | 规范 | +| ---------- | ------------ | +| 编码格式 | UTF-8 | +| 序列化格式 | JSON | +| 消息分隔符 | `\n`(0x0A) | + +```typescript +function sendControlMessage(socket: Socket, message: object): void { + socket.write(`${JSON.stringify(message)}\n`); +} +``` + +### 4.2 `file_chunk` 二进制帧格式 + +为解决 TCP 分包/粘包并消除 Base64 开销,`file_chunk` 采用带总长度的二进制帧: + +``` +┌──────────┬──────────┬────────┬───────────────┬──────────────┬────────────┬───────────┐ +│ Magic │ TotalLen │ Type │ TransferId Len│ TransferId │ ChunkIdx │ Data │ +│ 0x43 0x53│ (4B BE) │ 0x01 │ (2B BE) │ (UTF-8) │ (4B BE) │ (raw) │ +└──────────┴──────────┴────────┴───────────────┴──────────────┴────────────┴───────────┘ +``` + +| 字段 | 大小 | 说明 | +| -------------- | ---- | ------------------------------------------- | +| Magic | 2B | 常量 `0x43 0x53` ("CS"), 用于区分 JSON 消息 | +| TotalLen | 4B | Big-endian,帧总长度(不含 Magic/TotalLen) | +| Type | 1B | `0x01` 代表 `file_chunk` | +| TransferId Len | 2B | Big-endian,transferId 字符串长度 | +| TransferId | nB | UTF-8 transferId(长度由上一字段给出) | +| ChunkIdx | 4B | Big-endian,块索引,从 0 开始 | +| Data | mB | 原始文件二进制数据(未编码) | + +> 计算帧总长度:`TotalLen = 1 + 2 + transferIdLen + 4 + dataLen`(即 Type~Data 的长度和)。 + +### 4.3 消息解析策略 + +1. 读取 socket 数据到缓冲区; +2. 若前两字节为 `0x43 0x53` → 按二进制帧解析: + - 至少需要 6 字节头(Magic + TotalLen),不足则等待更多数据 + - 读取 `TotalLen` 判断帧整体长度,缓冲区不足则继续等待 + - 解析 Type/TransferId/ChunkIdx/Data,并传入文件接收逻辑 +3. 否则若首字节为 `{` → 按 JSON + `\n` 解析控制消息 +4. 其它数据丢弃 1 字节并继续循环,避免阻塞。 + +### 4.4 消息类型汇总(v1) + +| 类型 | 方向 | 编码 | 用途 | +| ---------------- | --------------- | -------- | ----------------------- | +| `handshake` | Client → Server | JSON+\n | 握手请求(version=1) | +| `handshake_ack` | Server → Client | JSON+\n | 握手响应 | +| `ping` | Client → Server | JSON+\n | 心跳请求 | +| `pong` | Server → Client | JSON+\n | 心跳响应 | +| `file_start` | Client → Server | JSON+\n | 开始文件传输 | +| `file_start_ack` | Server → Client | JSON+\n | 文件传输确认 | +| `file_chunk` | Client → Server | 二进制帧 | 文件数据块(无 Base64,流式无 per-chunk ACK) | +| `file_end` | Client → Server | JSON+\n | 文件传输结束 | +| `file_complete` | Server → Client | JSON+\n | 传输完成结果 | + +``` +{"type":"message_type",...其他字段...}\n +``` + +--- + +## 5. 文件传输协议 + +### 5.1 传输流程 + +``` +Client (Sender) Server (Receiver) + | | + |──── 1. file_start ────────────────>| + | (文件元数据) | + | | + |<─── 2. file_start_ack ─────────────| + | (接受/拒绝) | + | | + |══════ 循环发送数据块(流式,无 ACK) ═════| + | | + |──── 3. file_chunk [0] ────────────>| + | | + |──── 3. file_chunk [1] ────────────>| + | | + | ... 重复直到所有块发送完成 ... | + | | + |══════════════════════════════════════ + | | + |──── 5. file_end ──────────────────>| + | (所有块已发送) | + | | + |<─── 6. file_complete ──────────────| + | (最终结果) | +``` + +### 5.2 消息定义 + +#### 5.2.1 `file_start` - 开始传输 + +**方向:** Client → Server + +```typescript +type LanTransferFileStartMessage = { + type: "file_start"; + transferId: string; // UUID,唯一传输标识 + fileName: string; // 文件名(含扩展名) + fileSize: number; // 文件总字节数 + mimeType: string; // MIME 类型 + checksum: string; // 整个文件的 SHA-256 哈希(hex) + totalChunks: number; // 总数据块数 + chunkSize: number; // 每块大小(字节) +}; +``` + +**示例:** + +```json +{ + "type": "file_start", + "transferId": "550e8400-e29b-41d4-a716-446655440000", + "fileName": "backup.zip", + "fileSize": 524288000, + "mimeType": "application/zip", + "checksum": "a1b2c3d4e5f6789012345678901234567890abcdef1234567890abcdef123456", + "totalChunks": 8192, + "chunkSize": 65536 +} +``` + +#### 5.2.2 `file_start_ack` - 传输确认 + +**方向:** Server → Client + +```typescript +type LanTransferFileStartAckMessage = { + type: "file_start_ack"; + transferId: string; // 对应的传输 ID + accepted: boolean; // 是否接受传输 + message?: string; // 拒绝原因 +}; +``` + +**接受示例:** + +```json +{ + "type": "file_start_ack", + "transferId": "550e8400-e29b-41d4-a716-446655440000", + "accepted": true +} +``` + +**拒绝示例:** + +```json +{ + "type": "file_start_ack", + "transferId": "550e8400-e29b-41d4-a716-446655440000", + "accepted": false, + "message": "Insufficient storage space" +} +``` + +#### 5.2.3 `file_chunk` - 数据块 + +**方向:** Client → Server(**二进制帧**,见 4.2) + +- 不再使用 JSON/`\n`,也不再使用 Base64 +- 帧结构:`Magic` + `TotalLen` + `Type` + `TransferId` + `ChunkIdx` + `Data` +- `Type` 固定 `0x01`,`Data` 为原始文件二进制数据 +- 传输完整性依赖 `file_start.checksum`(全文件 SHA-256);分块校验和可选,不在帧中发送 + +#### 5.2.4 `file_chunk_ack` - 数据块确认(v1 流式不使用) + +v1 采用流式传输,不发送 per-chunk ACK。本节类型仅保留作为向后兼容参考,实际不会发送。 + +#### 5.2.5 `file_end` - 传输结束 + +**方向:** Client → Server + +```typescript +type LanTransferFileEndMessage = { + type: "file_end"; + transferId: string; // 传输 ID +}; +``` + +**示例:** + +```json +{ + "type": "file_end", + "transferId": "550e8400-e29b-41d4-a716-446655440000" +} +``` + +#### 5.2.6 `file_complete` - 传输完成 + +**方向:** Server → Client + +```typescript +type LanTransferFileCompleteMessage = { + type: "file_complete"; + transferId: string; // 传输 ID + success: boolean; // 是否成功 + filePath?: string; // 保存路径(成功时) + error?: string; // 错误信息(失败时) +}; +``` + +**成功示例:** + +```json +{ + "type": "file_complete", + "transferId": "550e8400-e29b-41d4-a716-446655440000", + "success": true, + "filePath": "/storage/emulated/0/Documents/backup.zip" +} +``` + +**失败示例:** + +```json +{ + "type": "file_complete", + "transferId": "550e8400-e29b-41d4-a716-446655440000", + "success": false, + "error": "File checksum verification failed" +} +``` + +### 5.3 校验和算法 + +#### 整个文件校验和(保持不变) + +```typescript +async function calculateFileChecksum(filePath: string): Promise { + const hash = crypto.createHash("sha256"); + const stream = fs.createReadStream(filePath); + + for await (const chunk of stream) { + hash.update(chunk); + } + + return hash.digest("hex"); +} +``` + +#### 数据块校验和 + +v1 默认 **不传输分块校验和**,依赖最终文件 checksum。若需要,可在应用层自定义(非协议字段)。 + +### 5.4 校验流程 + +**发送端(Client):** + +1. 发送前计算整个文件的 SHA-256 → `file_start.checksum` +2. 分块直接发送原始二进制(无 Base64) + +**接收端(Server):** + +1. 收到 `file_chunk` 后直接使用二进制数据 +2. 边收边落盘并增量计算 SHA-256(推荐) +3. 所有块接收完成后,计算/完成增量哈希,得到最终 SHA-256 +4. 与 `file_start.checksum` 比对,结果写入 `file_complete` + +### 5.5 数据块大小计算 + +```typescript +const CHUNK_SIZE = 512 * 1024; // 512KB + +const totalChunks = Math.ceil(fileSize / CHUNK_SIZE); + +// 最后一个块可能小于 CHUNK_SIZE +const lastChunkSize = fileSize % CHUNK_SIZE || CHUNK_SIZE; +``` + +--- + +## 6. 心跳与连接保活 + +### 6.1 心跳消息 + +#### `ping` + +**方向:** Client → Server + +```typescript +type LanTransferPingMessage = { + type: "ping"; + payload?: string; // 可选载荷 +}; +``` + +```json +{ + "type": "ping", + "payload": "heartbeat" +} +``` + +#### `pong` + +**方向:** Server → Client + +```typescript +type LanTransferPongMessage = { + type: "pong"; + received: boolean; // 确认收到 + payload?: string; // 回传 ping 的载荷 +}; +``` + +```json +{ + "type": "pong", + "received": true, + "payload": "heartbeat" +} +``` + +### 6.2 心跳策略 + +- 握手成功后立即发送一次 `ping` 验证连接 +- 可选:定期发送心跳保持连接活跃 +- `pong` 应返回 `ping` 中的 `payload`(可选) + +--- + +## 7. 错误处理 + +### 7.1 超时配置 + +| 操作 | 超时时间 | 说明 | +| ---------- | -------- | --------------------- | +| TCP 连接 | 10 秒 | 连接建立超时 | +| 握手等待 | 10 秒 | 等待 `handshake_ack` | +| 传输完成 | 60 秒 | 等待 `file_complete` | + +### 7.2 错误场景处理 + +| 场景 | Client 处理 | Server 处理 | +| --------------- | ------------------ | ---------------------- | +| TCP 连接失败 | 通知 UI,允许重试 | - | +| 握手超时 | 断开连接,通知 UI | 关闭 socket | +| 握手被拒绝 | 显示拒绝原因 | - | +| 数据块处理失败 | 中止传输,清理状态 | 清理临时文件 | +| 连接意外断开 | 清理状态,通知 UI | 清理临时文件 | +| 存储空间不足 | - | 发送 `accepted: false` | + +### 7.3 资源清理 + +**Client 端:** + +```typescript +function cleanup(): void { + // 1. 销毁文件读取流 + if (readStream) { + readStream.destroy(); + } + // 2. 清理传输状态 + activeTransfer = undefined; + // 3. 关闭 socket(如需要) + socket?.destroy(); +} +``` + +**Server 端:** + +```typescript +function cleanup(): void { + // 1. 关闭文件写入流 + if (writeStream) { + writeStream.end(); + } + // 2. 删除未完成的临时文件 + if (tempFilePath) { + fs.unlinkSync(tempFilePath); + } + // 3. 清理传输状态 + activeTransfer = undefined; +} +``` + +--- + +## 8. 常量与配置 + +### 8.1 协议常量 + +```typescript +// 协议版本(v1 = 控制 JSON + 二进制 chunk + 流式传输) +export const LAN_TRANSFER_PROTOCOL_VERSION = "1"; + +// 服务发现 +export const LAN_TRANSFER_SERVICE_TYPE = "cherrystudio"; +export const LAN_TRANSFER_SERVICE_FULL_NAME = "_cherrystudio._tcp"; + +// TCP 端口 +export const LAN_TRANSFER_TCP_PORT = 53317; + +// 文件传输(与二进制帧一致) +export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024; // 512KB +export const LAN_TRANSFER_GLOBAL_TIMEOUT_MS = 10 * 60 * 1000; // 10 分钟 + +// 超时设置 +export const LAN_TRANSFER_HANDSHAKE_TIMEOUT_MS = 10_000; // 10秒 +export const LAN_TRANSFER_CHUNK_TIMEOUT_MS = 30_000; // 30秒 +export const LAN_TRANSFER_COMPLETE_TIMEOUT_MS = 60_000; // 60秒 +``` + +### 8.2 支持的文件类型 + +当前仅支持 ZIP 文件: + +```typescript +export const LAN_TRANSFER_ALLOWED_EXTENSIONS = [".zip"]; +export const LAN_TRANSFER_ALLOWED_MIME_TYPES = [ + "application/zip", + "application/x-zip-compressed", +]; +``` + +--- + +## 9. 完整时序图 + +### 9.1 完整传输流程(v1,流式传输) + +``` +┌─────────┐ ┌─────────┐ ┌─────────┐ +│ Renderer│ │ Main │ │ Mobile │ +│ (UI) │ │ Process │ │ Server │ +└────┬────┘ └────┬────┘ └────┬────┘ + │ │ │ + │ ════════════ 服务发现阶段 ════════════ │ + │ │ │ + │ startScan() │ │ + │────────────────────────────────────>│ │ + │ │ mDNS browse │ + │ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─>│ + │ │ │ + │ │<─ ─ ─ service discovered ─ ─ ─ ─ ─ ─│ + │ │ │ + │<────── onServicesUpdated ───────────│ │ + │ │ │ + │ ════════════ 握手连接阶段 ════════════ │ + │ │ │ + │ connect(peer) │ │ + │────────────────────────────────────>│ │ + │ │──────── TCP Connect ───────────────>│ + │ │ │ + │ │──────── handshake ─────────────────>│ + │ │ │ + │ │<─────── handshake_ack ──────────────│ + │ │ │ + │ │──────── ping ──────────────────────>│ + │ │<─────── pong ───────────────────────│ + │ │ │ + │<────── connect result ──────────────│ │ + │ │ │ + │ ════════════ 文件传输阶段 ════════════ │ + │ │ │ + │ sendFile(path) │ │ + │────────────────────────────────────>│ │ + │ │──────── file_start ────────────────>│ + │ │ │ + │ │<─────── file_start_ack ─────────────│ + │ │ │ + │ │ │ + │ │══════ 循环发送数据块 ═══════════════│ + │ │ │ + │ │──────── file_chunk[0] (binary) ────>│ + │<────── progress event ──────────────│ │ + │ │ │ + │ │──────── file_chunk[1] (binary) ────>│ + │<────── progress event ──────────────│ │ + │ │ │ + │ │ ... 重复 ... │ + │ │ │ + │ │══════════════════════════════════════│ + │ │ │ + │ │──────── file_end ──────────────────>│ + │ │ │ + │ │<─────── file_complete ──────────────│ + │ │ │ + │<────── complete event ──────────────│ │ + │<────── sendFile result ─────────────│ │ + │ │ │ +``` + +--- + +## 10. 移动端实现指南(v1 要点) + +### 10.1 必须实现的功能 + +1. **mDNS 服务发布** + + - 发布 `_cherrystudio._tcp` 服务 + - 提供 TCP 端口号 `53317` + - 可选:TXT 记录(版本、平台信息) + +2. **TCP 服务端** + + - 监听指定端口 + - 支持单连接或多连接 + +3. **消息解析** + + - 控制消息:UTF-8 + `\n` JSON + - 数据消息:二进制帧(Magic+TotalLen 分帧) + +4. **握手处理** + + - 验证 `handshake` 消息 + - 发送 `handshake_ack` 响应 + - 响应 `ping` 消息 + +5. **文件接收(流式模式)** + - 解析 `file_start`,准备接收 + - 接收 `file_chunk` 二进制帧,直接写入文件/缓冲并增量哈希 + - v1 不发送 per-chunk ACK(流式传输) + - 处理 `file_end`,完成增量哈希并校验 checksum + - 发送 `file_complete` 结果 + +### 10.2 推荐的库 + +**React Native / Expo:** + +- mDNS: `react-native-zeroconf` 或 `@homielab/react-native-bonjour` +- TCP: `react-native-tcp-socket` +- Crypto: `expo-crypto` 或 `react-native-quick-crypto` + +### 10.3 接收端伪代码 + +```typescript +class FileReceiver { + private transfer?: { + id: string; + fileName: string; + fileSize: number; + checksum: string; + totalChunks: number; + receivedChunks: number; + tempPath: string; + // v1: 边收边写文件,避免大文件 OOM + // stream: FileSystem writable stream (平台相关封装) + }; + + handleMessage(message: any) { + switch (message.type) { + case "handshake": + this.handleHandshake(message); + break; + case "ping": + this.sendPong(message); + break; + case "file_start": + this.handleFileStart(message); + break; + // v1: file_chunk 为二进制帧,不再走 JSON 分支 + case "file_end": + this.handleFileEnd(message); + break; + } + } + + handleFileStart(msg: LanTransferFileStartMessage) { + // 1. 检查存储空间 + // 2. 创建临时文件 + // 3. 初始化传输状态 + // 4. 发送 file_start_ack + } + + // v1: 二进制帧处理在 socket data 流中解析,随后调用 handleBinaryFileChunk + handleBinaryFileChunk(transferId: string, chunkIndex: number, data: Buffer) { + // 直接使用二进制数据,按 chunkSize/lastChunk 计算长度 + // 写入文件流并更新增量 SHA-256 + this.transfer.receivedChunks++; + // v1: 流式传输,不发送 per-chunk ACK + } + + handleFileEnd(msg: LanTransferFileEndMessage) { + // 1. 合并所有数据块 + // 2. 验证完整文件 checksum + // 3. 写入最终位置 + // 4. 发送 file_complete + } +} +``` + +--- + +## 附录 A:TypeScript 类型定义 + +完整的类型定义位于 `packages/shared/config/types.ts`: + +```typescript +// 握手消息 +export interface LanTransferHandshakeMessage { + type: "handshake"; + deviceName: string; + version: string; + platform?: string; + appVersion?: string; +} + +export interface LanTransferHandshakeAckMessage { + type: "handshake_ack"; + accepted: boolean; + message?: string; +} + +// 心跳消息 +export interface LanTransferPingMessage { + type: "ping"; + payload?: string; +} + +export interface LanTransferPongMessage { + type: "pong"; + received: boolean; + payload?: string; +} + +// 文件传输消息 (Client -> Server) +export interface LanTransferFileStartMessage { + type: "file_start"; + transferId: string; + fileName: string; + fileSize: number; + mimeType: string; + checksum: string; + totalChunks: number; + chunkSize: number; +} + +export interface LanTransferFileChunkMessage { + type: "file_chunk"; + transferId: string; + chunkIndex: number; + data: string; // Base64 encoded (v1: 二进制帧模式下不使用) +} + +export interface LanTransferFileEndMessage { + type: "file_end"; + transferId: string; +} + +// 文件传输响应消息 (Server -> Client) +export interface LanTransferFileStartAckMessage { + type: "file_start_ack"; + transferId: string; + accepted: boolean; + message?: string; +} + +// v1 流式不发送 per-chunk ACK,以下类型仅用于向后兼容参考 +export interface LanTransferFileChunkAckMessage { + type: "file_chunk_ack"; + transferId: string; + chunkIndex: number; + received: boolean; + error?: string; +} + +export interface LanTransferFileCompleteMessage { + type: "file_complete"; + transferId: string; + success: boolean; + filePath?: string; + error?: string; +} + +// 常量 +export const LAN_TRANSFER_TCP_PORT = 53317; +export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024; +export const LAN_TRANSFER_CHUNK_TIMEOUT_MS = 30_000; +``` + +--- + +## 附录 B:版本历史 + +| 版本 | 日期 | 变更 | +| ---- | ------- | ---------------------------------------- | +| 1.0 | 2025-12 | 初始发布版本,支持二进制帧格式与流式传输 | diff --git a/electron-builder.yml b/electron-builder.yml index 5e63e7231d..bf7b7b4e91 100644 --- a/electron-builder.yml +++ b/electron-builder.yml @@ -134,108 +134,44 @@ artifactBuildCompleted: scripts/artifact-build-completed.js releaseInfo: releaseNotes: | - A New Era of Intelligence with Cherry Studio 1.7.1 + Cherry Studio 1.7.9 - New Features & Bug Fixes - Today we're releasing Cherry Studio 1.7.1 — our most ambitious update yet, introducing Agent: autonomous AI that thinks, plans, and acts. + ✨ New Features + - [Agent] Add 302.AI provider support + - [Browser] Browser data now persists and supports multiple tabs + - [Language] Add Romanian language support + - [Search] Add fuzzy search for file list + - [Models] Add latest Zhipu models + - [Image] Improve text-to-image functionality - For years, AI assistants have been reactive — waiting for your commands, responding to your questions. With Agent, we're changing that. Now, AI can truly work alongside you: understanding complex goals, breaking them into steps, and executing them independently. - - This is what we've been building toward. And it's just the beginning. - - 🤖 Meet Agent - Imagine having a brilliant colleague who never sleeps. Give Agent a goal — write a report, analyze data, refactor code — and watch it work. It reasons through problems, breaks them into steps, calls the right tools, and adapts when things change. - - - **Think → Plan → Act**: From goal to execution, fully autonomous - - **Deep Reasoning**: Multi-turn thinking that solves real problems - - **Tool Mastery**: File operations, web search, code execution, and more - - **Skill Plugins**: Extend with custom commands and capabilities - - **You Stay in Control**: Real-time approval for sensitive actions - - **Full Visibility**: Every thought, every decision, fully transparent - - 🌐 Expanding Ecosystem - - **New Providers**: HuggingFace, Mistral, CherryIN, AI Gateway, Intel OVMS, Didi MCP - - **New Models**: Claude 4.5 Haiku, DeepSeek v3.2, GLM-4.6, Doubao, Ling series - - **MCP Integration**: Alibaba Cloud, ModelScope, Higress, MCP.so, TokenFlux and more - - 📚 Smarter Knowledge Base - - **OpenMinerU**: Self-hosted document processing - - **Full-Text Search**: Find anything instantly across your notes - - **Enhanced Tool Selection**: Smarter configuration for better AI assistance - - 📝 Notes, Reimagined - - Full-text search with highlighted results - - AI-powered smart rename - - Export as image - - Auto-wrap for tables - - 🖼️ Image & OCR - - Intel OVMS painting capabilities - - Intel OpenVINO NPU-accelerated OCR - - 🌍 Now in 10+ Languages - - Added German support - - Enhanced internationalization - - ⚡ Faster & More Polished - - Electron 38 upgrade - - New MCP management interface - - Dozens of UI refinements - - ❤️ Fully Open Source - Commercial restrictions removed. Cherry Studio now follows standard AGPL v3 — free for teams of any size. - - The Agent Era is here. We can't wait to see what you'll create. + 🐛 Bug Fixes + - [Mac] Fix mini window unexpected closing issue + - [Preview] Fix HTML preview controls not working in fullscreen + - [Translate] Fix translation duplicate execution issue + - [Zoom] Fix page zoom reset issue during navigation + - [Agent] Fix crash when switching between agent and assistant + - [Agent] Fix navigation in agent mode + - [Copy] Fix markdown copy button issue + - [Windows] Fix compatibility issues on non-Windows systems - Cherry Studio 1.7.1:开启智能新纪元 + Cherry Studio 1.7.9 - 新功能与问题修复 - 今天,我们正式发布 Cherry Studio 1.7.1 —— 迄今最具雄心的版本,带来全新的 Agent:能够自主思考、规划和行动的 AI。 + ✨ 新功能 + - [Agent] 新增 302.AI 服务商支持 + - [浏览器] 浏览器数据现在可以保存,支持多标签页 + - [语言] 新增罗马尼亚语支持 + - [搜索] 文件列表新增模糊搜索功能 + - [模型] 新增最新智谱模型 + - [图片] 优化文生图功能 - 多年来,AI 助手一直是被动的——等待你的指令,回应你的问题。Agent 改变了这一切。现在,AI 能够真正与你并肩工作:理解复杂目标,将其拆解为步骤,并独立执行。 - - 这是我们一直在构建的未来。而这,仅仅是开始。 - - 🤖 认识 Agent - 想象一位永不疲倦的得力伙伴。给 Agent 一个目标——撰写报告、分析数据、重构代码——然后看它工作。它会推理问题、拆解步骤、调用工具,并在情况变化时灵活应对。 - - - **思考 → 规划 → 行动**:从目标到执行,全程自主 - - **深度推理**:多轮思考,解决真实问题 - - **工具大师**:文件操作、网络搜索、代码执行,样样精通 - - **技能插件**:自定义命令,无限扩展 - - **你掌控全局**:敏感操作,实时审批 - - **完全透明**:每一步思考,每一个决策,清晰可见 - - 🌐 生态持续壮大 - - **新增服务商**:Hugging Face、Mistral、Perplexity、SophNet、AI Gateway、Cerebras AI - - **新增模型**:Gemini 3、Gemini 3 Pro(支持图像预览)、GPT-5.1、Claude Opus 4.5 - - **MCP 集成**:百炼、魔搭、Higress、MCP.so、TokenFlux 等平台 - - 📚 更智能的知识库 - - **OpenMinerU**:本地自部署文档处理 - - **全文搜索**:笔记内容一搜即达 - - **增强工具选择**:更智能的配置,更好的 AI 协助 - - 📝 笔记,焕然一新 - - 全文搜索,结果高亮 - - AI 智能重命名 - - 导出为图片 - - 表格自动换行 - - 🖼️ 图像与 OCR - - Intel OVMS 绘图能力 - - Intel OpenVINO NPU 加速 OCR - - 🌍 支持 10+ 种语言 - - 新增德语支持 - - 全面增强国际化 - - ⚡ 更快、更精致 - - 升级 Electron 38 - - 新的 MCP 管理界面 - - 数十处 UI 细节打磨 - - ❤️ 完全开源 - 商用限制已移除。Cherry Studio 现遵循标准 AGPL v3 协议——任意规模团队均可自由使用。 - - Agent 纪元已至。期待你的创造。 + 🐛 问题修复 + - [Mac] 修复迷你窗口意外关闭的问题 + - [预览] 修复全屏模式下 HTML 预览控件无法使用的问题 + - [翻译] 修复翻译重复执行的问题 + - [缩放] 修复页面导航时缩放被重置的问题 + - [智能体] 修复在智能体和助手间切换时崩溃的问题 + - [智能体] 修复智能体模式下的导航问题 + - [复制] 修复 Markdown 复制按钮问题 + - [兼容性] 修复非 Windows 系统的兼容性问题 diff --git a/electron.vite.config.ts b/electron.vite.config.ts index 172d48ca9a..89c0cf2f9b 100644 --- a/electron.vite.config.ts +++ b/electron.vite.config.ts @@ -1,6 +1,6 @@ import react from '@vitejs/plugin-react-swc' import { CodeInspectorPlugin } from 'code-inspector-plugin' -import { defineConfig, externalizeDepsPlugin } from 'electron-vite' +import { defineConfig } from 'electron-vite' import { resolve } from 'path' import { visualizer } from 'rollup-plugin-visualizer' @@ -17,7 +17,7 @@ const isProd = process.env.NODE_ENV === 'production' export default defineConfig({ main: { - plugins: [externalizeDepsPlugin(), ...visualizerPlugin('main')], + plugins: [...visualizerPlugin('main')], resolve: { alias: { '@main': resolve('src/main'), @@ -51,8 +51,7 @@ export default defineConfig({ plugins: [ react({ tsDecorators: true - }), - externalizeDepsPlugin() + }) ], resolve: { alias: { diff --git a/eslint.config.mjs b/eslint.config.mjs index 64fdefa1dc..9eb20d1238 100644 --- a/eslint.config.mjs +++ b/eslint.config.mjs @@ -61,6 +61,7 @@ export default defineConfig([ 'tests/**', '.yarn/**', '.gitignore', + '.conductor/**', 'scripts/cloudflare-worker.js', 'src/main/integration/nutstore/sso/lib/**', 'src/main/integration/cherryai/index.js', diff --git a/package.json b/package.json index fd5eb0151d..6dddf4fd4a 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "CherryStudio", - "version": "1.7.1", + "version": "1.7.9", "private": true, "description": "A powerful AI assistant for producer.", "main": "./out/main/index.js", @@ -27,6 +27,7 @@ "scripts": { "start": "electron-vite preview", "dev": "dotenv electron-vite dev", + "dev:watch": "dotenv electron-vite dev -- -w", "debug": "electron-vite -- --inspect --sourcemap --remote-debugging-port=9222", "build": "npm run typecheck && electron-vite build", "build:check": "yarn lint && yarn test", @@ -53,10 +54,10 @@ "typecheck": "concurrently -n \"node,web\" -c \"cyan,magenta\" \"npm run typecheck:node\" \"npm run typecheck:web\"", "typecheck:node": "tsgo --noEmit -p tsconfig.node.json --composite false", "typecheck:web": "tsgo --noEmit -p tsconfig.web.json --composite false", - "check:i18n": "dotenv -e .env -- tsx scripts/check-i18n.ts", - "sync:i18n": "dotenv -e .env -- tsx scripts/sync-i18n.ts", - "update:i18n": "dotenv -e .env -- tsx scripts/update-i18n.ts", - "auto:i18n": "dotenv -e .env -- tsx scripts/auto-translate-i18n.ts", + "i18n:check": "dotenv -e .env -- tsx scripts/check-i18n.ts", + "i18n:sync": "dotenv -e .env -- tsx scripts/sync-i18n.ts", + "i18n:translate": "dotenv -e .env -- tsx scripts/auto-translate-i18n.ts", + "i18n:all": "yarn i18n:check && yarn i18n:sync && yarn i18n:translate", "update:languages": "tsx scripts/update-languages.ts", "update:upgrade-config": "tsx scripts/update-app-upgrade-config.ts", "test": "vitest run --silent", @@ -70,7 +71,7 @@ "test:e2e": "yarn playwright test", "test:lint": "oxlint --deny-warnings && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --cache", "test:scripts": "vitest scripts", - "lint": "oxlint --fix && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --cache && yarn typecheck && yarn check:i18n && yarn format:check", + "lint": "oxlint --fix && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --cache && yarn typecheck && yarn i18n:check && yarn format:check", "format": "biome format --write && biome lint --write", "format:check": "biome format && biome lint", "prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky", @@ -81,12 +82,13 @@ "release:ai-sdk-provider": "yarn workspace @cherrystudio/ai-sdk-provider version patch --immediate && yarn workspace @cherrystudio/ai-sdk-provider build && yarn workspace @cherrystudio/ai-sdk-provider npm publish --access public" }, "dependencies": { - "@anthropic-ai/claude-agent-sdk": "patch:@anthropic-ai/claude-agent-sdk@npm%3A0.1.53#~/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.53-4b77f4cf29.patch", + "@anthropic-ai/claude-agent-sdk": "patch:@anthropic-ai/claude-agent-sdk@npm%3A0.1.62#~/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.62-23ae56f8c8.patch", "@libsql/client": "0.14.0", "@libsql/win32-x64-msvc": "^0.4.7", "@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch", "@paymoapp/electron-shutdown-handler": "^1.1.2", "@strongtz/win32-arm64-msvc": "^0.4.7", + "bonjour-service": "^1.3.0", "emoji-picker-element-data": "^1", "express": "^5.1.0", "font-list": "^2.0.0", @@ -97,10 +99,8 @@ "node-stream-zip": "^1.15.0", "officeparser": "^4.2.0", "os-proxy-config": "^1.1.2", - "qrcode.react": "^4.2.0", "selection-hook": "^1.0.12", "sharp": "^0.34.3", - "socket.io": "^4.8.1", "swagger-jsdoc": "^6.2.8", "swagger-ui-express": "^5.0.1", "tesseract.js": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch", @@ -114,11 +114,11 @@ "@ai-sdk/anthropic": "^2.0.49", "@ai-sdk/cerebras": "^1.0.31", "@ai-sdk/gateway": "^2.0.15", - "@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.43#~/.yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch", - "@ai-sdk/google-vertex": "^3.0.79", + "@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.49#~/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch", + "@ai-sdk/google-vertex": "^3.0.94", "@ai-sdk/huggingface": "^0.0.10", "@ai-sdk/mistral": "^2.0.24", - "@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.72#~/.yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch", + "@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.85#~/.yarn/patches/@ai-sdk-openai-npm-2.0.85-27483d1d6a.patch", "@ai-sdk/perplexity": "^2.0.20", "@ai-sdk/test-server": "^0.0.1", "@ant-design/v5-patch-for-react-19": "^1.0.3", @@ -142,7 +142,7 @@ "@cherrystudio/embedjs-ollama": "^0.1.31", "@cherrystudio/embedjs-openai": "^0.1.31", "@cherrystudio/extension-table-plus": "workspace:^", - "@cherrystudio/openai": "^6.9.0", + "@cherrystudio/openai": "^6.12.0", "@dnd-kit/core": "^6.3.1", "@dnd-kit/modifiers": "^9.0.0", "@dnd-kit/sortable": "^10.0.0", @@ -274,7 +274,7 @@ "electron-reload": "^2.0.0-alpha.1", "electron-store": "^8.2.0", "electron-updater": "patch:electron-updater@npm%3A6.7.0#~/.yarn/patches/electron-updater-npm-6.7.0-47b11bb0d4.patch", - "electron-vite": "4.0.1", + "electron-vite": "5.0.0", "electron-window-state": "^5.0.3", "emittery": "^1.0.3", "emoji-picker-element": "^1.22.1", @@ -318,7 +318,7 @@ "motion": "^12.10.5", "notion-helper": "^1.3.22", "npx-scope-finder": "^1.2.0", - "ollama-ai-provider-v2": "^1.5.5", + "ollama-ai-provider-v2": "patch:ollama-ai-provider-v2@npm%3A1.5.5#~/.yarn/patches/ollama-ai-provider-v2-npm-1.5.5-8bef249af9.patch", "oxlint": "^1.22.0", "oxlint-tsgolint": "^0.2.0", "p-queue": "^8.1.0", @@ -371,7 +371,7 @@ "undici": "6.21.2", "unified": "^11.0.5", "uuid": "^13.0.0", - "vite": "npm:rolldown-vite@7.1.5", + "vite": "npm:rolldown-vite@7.3.0", "vitest": "^3.2.4", "webdav": "^5.8.0", "winston": "^3.17.0", @@ -401,7 +401,7 @@ "pkce-challenge@npm:^4.1.0": "patch:pkce-challenge@npm%3A4.1.0#~/.yarn/patches/pkce-challenge-npm-4.1.0-fbc51695a3.patch", "tar-fs": "^2.1.4", "undici": "6.21.2", - "vite": "npm:rolldown-vite@7.1.5", + "vite": "npm:rolldown-vite@7.3.0", "tesseract.js@npm:*": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch", "@ai-sdk/openai@npm:^2.0.52": "patch:@ai-sdk/openai@npm%3A2.0.52#~/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch", "@img/sharp-darwin-arm64": "0.34.3", @@ -414,9 +414,12 @@ "@langchain/openai@npm:>=0.1.0 <0.6.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch", "@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch", "@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch", - "@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.72#~/.yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch", + "@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.85#~/.yarn/patches/@ai-sdk-openai-npm-2.0.85-27483d1d6a.patch", "@ai-sdk/google@npm:^2.0.40": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch", - "@ai-sdk/openai-compatible@npm:^1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch" + "@ai-sdk/openai-compatible@npm:^1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch", + "@ai-sdk/google@npm:2.0.49": "patch:@ai-sdk/google@npm%3A2.0.49#~/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch", + "@ai-sdk/openai-compatible@npm:1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch", + "@ai-sdk/openai-compatible@npm:^1.0.19": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch" }, "packageManager": "yarn@4.9.1", "lint-staged": { diff --git a/packages/ai-sdk-provider/package.json b/packages/ai-sdk-provider/package.json index 25864f3b1f..e635f93aeb 100644 --- a/packages/ai-sdk-provider/package.json +++ b/packages/ai-sdk-provider/package.json @@ -41,7 +41,7 @@ "ai": "^5.0.26" }, "dependencies": { - "@ai-sdk/openai-compatible": "^1.0.28", + "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch", "@ai-sdk/provider": "^2.0.0", "@ai-sdk/provider-utils": "^3.0.17" }, diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json index a648dcf3c7..e73a843b1d 100644 --- a/packages/aiCore/package.json +++ b/packages/aiCore/package.json @@ -40,9 +40,9 @@ }, "dependencies": { "@ai-sdk/anthropic": "^2.0.49", - "@ai-sdk/azure": "^2.0.74", + "@ai-sdk/azure": "^2.0.87", "@ai-sdk/deepseek": "^1.0.31", - "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch", + "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch", "@ai-sdk/provider": "^2.0.0", "@ai-sdk/provider-utils": "^3.0.17", "@ai-sdk/xai": "^2.0.36", diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts index 59a425712c..c30c2015f6 100644 --- a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts @@ -62,7 +62,7 @@ export class StreamEventManager { const recursiveResult = await context.recursiveCall(recursiveParams) if (recursiveResult && recursiveResult.fullStream) { - await this.pipeRecursiveStream(controller, recursiveResult.fullStream, context) + await this.pipeRecursiveStream(controller, recursiveResult.fullStream) } else { console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult) } @@ -74,11 +74,7 @@ export class StreamEventManager { /** * 将递归流的数据传递到当前流 */ - private async pipeRecursiveStream( - controller: StreamController, - recursiveStream: ReadableStream, - context?: AiRequestContext - ): Promise { + private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise { const reader = recursiveStream.getReader() try { while (true) { @@ -86,18 +82,14 @@ export class StreamEventManager { if (done) { break } + if (value.type === 'start') { + continue + } + if (value.type === 'finish') { - // 迭代的流不发finish,但需要累加其 usage - if (value.usage && context?.accumulatedUsage) { - this.accumulateUsage(context.accumulatedUsage, value.usage) - } break } - // 对于 finish-step 类型,累加其 usage - if (value.type === 'finish-step' && value.usage && context?.accumulatedUsage) { - this.accumulateUsage(context.accumulatedUsage, value.usage) - } - // 将递归流的数据传递到当前流 + controller.enqueue(value) } } finally { @@ -135,10 +127,8 @@ export class StreamEventManager { // 构建新的对话消息 const newMessages: ModelMessage[] = [ ...(context.originalParams.messages || []), - { - role: 'assistant', - content: textBuffer - }, + // 只有当 textBuffer 有内容时才添加 assistant 消息,避免空消息导致 API 错误 + ...(textBuffer ? [{ role: 'assistant' as const, content: textBuffer }] : []), { role: 'user', content: toolResultsText @@ -161,7 +151,7 @@ export class StreamEventManager { /** * 累加 usage 数据 */ - private accumulateUsage(target: any, source: any): void { + accumulateUsage(target: any, source: any): void { if (!target || !source) return // 累加各种 token 类型 diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts index 274fdcee5c..224cee05ae 100644 --- a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts @@ -22,10 +22,10 @@ const TOOL_USE_TAG_CONFIG: TagConfig = { } /** - * 默认系统提示符模板(提取自 Cherry Studio) + * 默认系统提示符模板 */ -const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \\ -You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use. +export const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \ +You can use one or more tools per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use. ## Tool Use Formatting @@ -74,10 +74,13 @@ Here are the rules you should always follow to solve your task: 4. Never re-do a tool call that you previously did with the exact same parameters. 5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format. +## Response rules + +Respond in the language of the user's query, unless the user instructions specify additional requirements for the language to be used. + # User Instructions {{ USER_SYSTEM_PROMPT }} - -Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.` +` /** * 默认工具使用示例(提取自 Cherry Studio) @@ -411,7 +414,10 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => { } } - // 如果没有执行工具调用,直接传递原始finish-step事件 + // 如果没有执行工具调用,累加 usage 后透传 finish-step 事件 + if (chunk.usage && context.accumulatedUsage) { + streamEventManager.accumulateUsage(context.accumulatedUsage, chunk.usage) + } controller.enqueue(chunk) // 清理状态 diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts index 61e6f49b81..6e313bdd27 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts @@ -6,6 +6,7 @@ import { type Tool } from 'ai' import { createOpenRouterOptions, createXaiOptions, mergeProviderOptions } from '../../../options' import type { ProviderOptionsMap } from '../../../options/types' +import type { AiRequestContext } from '../../' import type { OpenRouterSearchConfig } from './openrouter' /** @@ -95,28 +96,84 @@ export type WebSearchToolInputSchema = { 'openai-chat': InferToolInput } -export const switchWebSearchTool = (config: WebSearchPluginConfig, params: any) => { - if (config.openai) { - if (!params.tools) params.tools = {} - params.tools.web_search = openai.tools.webSearch(config.openai) - } else if (config['openai-chat']) { - if (!params.tools) params.tools = {} - params.tools.web_search_preview = openai.tools.webSearchPreview(config['openai-chat']) - } else if (config.anthropic) { - if (!params.tools) params.tools = {} - params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic) - } else if (config.google) { - // case 'google-vertex': - if (!params.tools) params.tools = {} - params.tools.web_search = google.tools.googleSearch(config.google || {}) - } else if (config.xai) { - const searchOptions = createXaiOptions({ - searchParameters: { ...config.xai, mode: 'on' } - }) - params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) - } else if (config.openrouter) { - const searchOptions = createOpenRouterOptions(config.openrouter) - params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) +/** + * Helper function to ensure params.tools object exists + */ +const ensureToolsObject = (params: any) => { + if (!params.tools) params.tools = {} +} + +/** + * Helper function to apply tool-based web search configuration + */ +const applyToolBasedSearch = (params: any, toolName: string, toolInstance: any) => { + ensureToolsObject(params) + params.tools[toolName] = toolInstance +} + +/** + * Helper function to apply provider options-based web search configuration + */ +const applyProviderOptionsSearch = (params: any, searchOptions: any) => { + params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) +} + +export const switchWebSearchTool = (config: WebSearchPluginConfig, params: any, context?: AiRequestContext) => { + const providerId = context?.providerId + + // Provider-specific configuration map + const providerHandlers: Record void> = { + openai: () => { + const cfg = config.openai ?? DEFAULT_WEB_SEARCH_CONFIG.openai + applyToolBasedSearch(params, 'web_search', openai.tools.webSearch(cfg)) + }, + 'openai-chat': () => { + const cfg = (config['openai-chat'] ?? DEFAULT_WEB_SEARCH_CONFIG['openai-chat']) as OpenAISearchPreviewConfig + applyToolBasedSearch(params, 'web_search_preview', openai.tools.webSearchPreview(cfg)) + }, + anthropic: () => { + const cfg = config.anthropic ?? DEFAULT_WEB_SEARCH_CONFIG.anthropic + applyToolBasedSearch(params, 'web_search', anthropic.tools.webSearch_20250305(cfg)) + }, + google: () => { + const cfg = (config.google ?? DEFAULT_WEB_SEARCH_CONFIG.google) as GoogleSearchConfig + applyToolBasedSearch(params, 'web_search', google.tools.googleSearch(cfg)) + }, + xai: () => { + const cfg = config.xai ?? DEFAULT_WEB_SEARCH_CONFIG.xai + const searchOptions = createXaiOptions({ searchParameters: { ...cfg, mode: 'on' } }) + applyProviderOptionsSearch(params, searchOptions) + }, + openrouter: () => { + const cfg = (config.openrouter ?? DEFAULT_WEB_SEARCH_CONFIG.openrouter) as OpenRouterSearchConfig + const searchOptions = createOpenRouterOptions(cfg) + applyProviderOptionsSearch(params, searchOptions) + } } + + // Try provider-specific handler first + const handler = providerId && providerHandlers[providerId] + if (handler) { + handler() + return params + } + + // Fallback: apply based on available config keys (prioritized order) + const fallbackOrder: Array = [ + 'openai', + 'openai-chat', + 'anthropic', + 'google', + 'xai', + 'openrouter' + ] + + for (const key of fallbackOrder) { + if (config[key]) { + providerHandlers[key]() + break + } + } + return params } diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts index a46df7dd4c..e02fd179fe 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts @@ -17,8 +17,22 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR name: 'webSearch', enforce: 'pre', - transformParams: async (params: any) => { - switchWebSearchTool(config, params) + transformParams: async (params: any, context) => { + let { providerId } = context + + // For cherryin providers, extract the actual provider from the model's provider string + // Expected format: "cherryin.{actualProvider}" (e.g., "cherryin.gemini") + if (providerId === 'cherryin' || providerId === 'cherryin-chat') { + const provider = params.model?.provider + if (provider && typeof provider === 'string' && provider.includes('.')) { + const extractedProviderId = provider.split('.')[1] + if (extractedProviderId) { + providerId = extractedProviderId + } + } + } + + switchWebSearchTool(config, params, { ...context, providerId }) return params } }) diff --git a/packages/shared/IpcChannel.ts b/packages/shared/IpcChannel.ts index 27d208a2a8..7771d4d87a 100644 --- a/packages/shared/IpcChannel.ts +++ b/packages/shared/IpcChannel.ts @@ -55,6 +55,8 @@ export enum IpcChannel { Webview_SetOpenLinkExternal = 'webview:set-open-link-external', Webview_SetSpellCheckEnabled = 'webview:set-spell-check-enabled', Webview_SearchHotkey = 'webview:search-hotkey', + Webview_PrintToPDF = 'webview:print-to-pdf', + Webview_SaveAsHTML = 'webview:save-as-html', // Open Open_Path = 'open:path', @@ -90,6 +92,8 @@ export enum IpcChannel { Mcp_AbortTool = 'mcp:abort-tool', Mcp_GetServerVersion = 'mcp:get-server-version', Mcp_Progress = 'mcp:progress', + Mcp_GetServerLogs = 'mcp:get-server-logs', + Mcp_ServerLog = 'mcp:server-log', // Python Python_Execute = 'python:execute', @@ -232,6 +236,8 @@ export enum IpcChannel { Backup_ListS3Files = 'backup:listS3Files', Backup_DeleteS3File = 'backup:deleteS3File', Backup_CheckS3Connection = 'backup:checkS3Connection', + Backup_CreateLanTransferBackup = 'backup:createLanTransferBackup', + Backup_DeleteTempBackup = 'backup:deleteTempBackup', // zip Zip_Compress = 'zip:compress', @@ -242,6 +248,9 @@ export enum IpcChannel { System_GetHostname = 'system:getHostname', System_GetCpuName = 'system:getCpuName', System_CheckGitBash = 'system:checkGitBash', + System_GetGitBashPath = 'system:getGitBashPath', + System_GetGitBashPathInfo = 'system:getGitBashPathInfo', + System_SetGitBashPath = 'system:setGitBashPath', // DevTools System_ToggleDevTools = 'system:toggleDevTools', @@ -296,6 +305,8 @@ export enum IpcChannel { Selection_ActionWindowClose = 'selection:action-window-close', Selection_ActionWindowMinimize = 'selection:action-window-minimize', Selection_ActionWindowPin = 'selection:action-window-pin', + // [Windows only] Electron bug workaround - can be removed once https://github.com/electron/electron/issues/48554 is fixed + Selection_ActionWindowResize = 'selection:action-window-resize', Selection_ProcessAction = 'selection:process-action', Selection_UpdateActionData = 'selection:update-action-data', @@ -310,6 +321,7 @@ export enum IpcChannel { Memory_DeleteUser = 'memory:delete-user', Memory_DeleteAllMemoriesForUser = 'memory:delete-all-memories-for-user', Memory_GetUsersList = 'memory:get-users-list', + Memory_MigrateMemoryDb = 'memory:migrate-memory-db', // TRACE TRACE_SAVE_DATA = 'trace:saveData', @@ -355,6 +367,7 @@ export enum IpcChannel { OCR_ListProviders = 'ocr:list-providers', // OVMS + Ovms_IsSupported = 'ovms:is-supported', Ovms_AddModel = 'ovms:add-model', Ovms_StopAddModel = 'ovms:stop-addmodel', Ovms_GetModels = 'ovms:get-models', @@ -375,10 +388,14 @@ export enum IpcChannel { ClaudeCodePlugin_ReadContent = 'claudeCodePlugin:read-content', ClaudeCodePlugin_WriteContent = 'claudeCodePlugin:write-content', - // WebSocket - WebSocket_Start = 'webSocket:start', - WebSocket_Stop = 'webSocket:stop', - WebSocket_Status = 'webSocket:status', - WebSocket_SendFile = 'webSocket:send-file', - WebSocket_GetAllCandidates = 'webSocket:get-all-candidates' + // Local Transfer + LocalTransfer_ListServices = 'local-transfer:list', + LocalTransfer_StartScan = 'local-transfer:start-scan', + LocalTransfer_StopScan = 'local-transfer:stop-scan', + LocalTransfer_ServicesUpdated = 'local-transfer:services-updated', + LocalTransfer_Connect = 'local-transfer:connect', + LocalTransfer_Disconnect = 'local-transfer:disconnect', + LocalTransfer_ClientEvent = 'local-transfer:client-event', + LocalTransfer_SendFile = 'local-transfer:send-file', + LocalTransfer_CancelTransfer = 'local-transfer:cancel-transfer' } diff --git a/packages/shared/anthropic/index.ts b/packages/shared/anthropic/index.ts index bff143d118..b9e9cb8846 100644 --- a/packages/shared/anthropic/index.ts +++ b/packages/shared/anthropic/index.ts @@ -88,16 +88,11 @@ export function getSdkClient( } }) } - let baseURL = + const baseURL = provider.type === 'anthropic' ? provider.apiHost : (provider.anthropicApiHost && provider.anthropicApiHost.trim()) || provider.apiHost - // Anthropic SDK automatically appends /v1 to all endpoints (like /v1/messages, /v1/models) - // We need to strip api version from baseURL to avoid duplication (e.g., /v3/v1/models) - // formatProviderApiHost adds /v1 for AI SDK compatibility, but Anthropic SDK needs it removed - baseURL = baseURL.replace(/\/v\d+(?:alpha|beta)?(?=\/|$)/i, '') - logger.debug('Anthropic API baseURL', { baseURL, providerId: provider.id }) if (provider.id === 'aihubmix') { diff --git a/packages/shared/config/constant.ts b/packages/shared/config/constant.ts index 1e02ce7706..af0191f4fa 100644 --- a/packages/shared/config/constant.ts +++ b/packages/shared/config/constant.ts @@ -488,3 +488,11 @@ export const MACOS_TERMINALS_WITH_COMMANDS: TerminalConfigWithCommand[] = [ // resources/scripts should be maintained manually export const HOME_CHERRY_DIR = '.cherrystudio' + +// Git Bash path configuration types +export type GitBashPathSource = 'manual' | 'auto' + +export interface GitBashPathInfo { + path: string | null + source: GitBashPathSource | null +} diff --git a/packages/shared/config/types.ts b/packages/shared/config/types.ts index 8fba6399f8..56f746b0d5 100644 --- a/packages/shared/config/types.ts +++ b/packages/shared/config/types.ts @@ -23,6 +23,14 @@ export type MCPProgressEvent = { progress: number // 0-1 range } +export type MCPServerLogEntry = { + timestamp: number + level: 'debug' | 'info' | 'warn' | 'error' | 'stderr' | 'stdout' + message: string + data?: any + source?: string +} + export type WebviewKeyEvent = { webviewId: number key: string @@ -44,3 +52,196 @@ export interface WebSocketCandidatesResponse { interface: string priority: number } + +export type LocalTransferPeer = { + id: string + name: string + host?: string + fqdn?: string + port?: number + type?: string + protocol?: 'tcp' | 'udp' + addresses: string[] + txt?: Record + updatedAt: number +} + +export type LocalTransferState = { + services: LocalTransferPeer[] + isScanning: boolean + lastScanStartedAt?: number + lastUpdatedAt: number + lastError?: string +} + +export type LanHandshakeRequestMessage = { + type: 'handshake' + deviceName: string + version: string + platform?: string + appVersion?: string +} + +export type LanHandshakeAckMessage = { + type: 'handshake_ack' + accepted: boolean + message?: string +} + +export type LocalTransferConnectPayload = { + peerId: string + metadata?: Record + timeoutMs?: number +} + +export type LanClientEvent = + | { + type: 'ping_sent' + payload: string + timestamp: number + peerId?: string + peerName?: string + } + | { + type: 'pong' + payload?: string + received?: boolean + timestamp: number + peerId?: string + peerName?: string + } + | { + type: 'socket_closed' + reason?: string + timestamp: number + peerId?: string + peerName?: string + } + | { + type: 'error' + message: string + timestamp: number + peerId?: string + peerName?: string + } + | { + type: 'file_transfer_progress' + transferId: string + fileName: string + bytesSent: number + totalBytes: number + chunkIndex: number + totalChunks: number + progress: number // 0-100 + speed: number // bytes/sec + timestamp: number + peerId?: string + peerName?: string + } + | { + type: 'file_transfer_complete' + transferId: string + fileName: string + success: boolean + filePath?: string + error?: string + timestamp: number + peerId?: string + peerName?: string + } + +// ============================================================================= +// LAN File Transfer Protocol Types +// ============================================================================= + +// Constants for file transfer +export const LAN_TRANSFER_TCP_PORT = 53317 +export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024 // 512KB +export const LAN_TRANSFER_MAX_FILE_SIZE = 500 * 1024 * 1024 // 500MB +export const LAN_TRANSFER_COMPLETE_TIMEOUT_MS = 60_000 // 60s - wait for file_complete after file_end +export const LAN_TRANSFER_GLOBAL_TIMEOUT_MS = 10 * 60 * 1000 // 10 minutes - global transfer timeout + +// Binary protocol constants (v1) +export const LAN_TRANSFER_PROTOCOL_VERSION = '1' +export const LAN_BINARY_FRAME_MAGIC = 0x4353 // "CS" as uint16 +export const LAN_BINARY_TYPE_FILE_CHUNK = 0x01 + +// Messages from Electron (Client/Sender) to Mobile (Server/Receiver) + +/** Request to start file transfer */ +export type LanFileStartMessage = { + type: 'file_start' + transferId: string + fileName: string + fileSize: number + mimeType: string // 'application/zip' + checksum: string // SHA-256 of entire file + totalChunks: number + chunkSize: number +} + +/** + * File chunk data (JSON format) + * @deprecated Use binary frame format in protocol v1. This type is kept for reference only. + */ +export type LanFileChunkMessage = { + type: 'file_chunk' + transferId: string + chunkIndex: number + data: string // Base64 encoded + chunkChecksum: string // SHA-256 of this chunk +} + +/** Notification that all chunks have been sent */ +export type LanFileEndMessage = { + type: 'file_end' + transferId: string +} + +/** Request to cancel file transfer */ +export type LanFileCancelMessage = { + type: 'file_cancel' + transferId: string + reason?: string +} + +// Messages from Mobile (Server/Receiver) to Electron (Client/Sender) + +/** Acknowledgment of file transfer request */ +export type LanFileStartAckMessage = { + type: 'file_start_ack' + transferId: string + accepted: boolean + message?: string // Rejection reason +} + +/** + * Acknowledgment of file chunk received + * @deprecated Protocol v1 uses streaming mode without per-chunk acknowledgment. + * This type is kept for backward compatibility reference only. + */ +export type LanFileChunkAckMessage = { + type: 'file_chunk_ack' + transferId: string + chunkIndex: number + received: boolean + message?: string +} + +/** Final result of file transfer */ +export type LanFileCompleteMessage = { + type: 'file_complete' + transferId: string + success: boolean + filePath?: string // Path where file was saved on mobile + error?: string + // Enhanced error diagnostics + errorCode?: 'CHECKSUM_MISMATCH' | 'INCOMPLETE_TRANSFER' | 'DISK_ERROR' | 'CANCELLED' + receivedChunks?: number + receivedBytes?: number +} + +/** Payload for sending a file via IPC */ +export type LanFileSendPayload = { + filePath: string +} diff --git a/packages/shared/utils.ts b/packages/shared/utils.ts index a14f78958d..7e90624aba 100644 --- a/packages/shared/utils.ts +++ b/packages/shared/utils.ts @@ -35,3 +35,56 @@ export const defaultAppHeaders = () => { // return value // } // } + +/** + * Extracts the trailing API version segment from a URL path. + * + * This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL. + * Only versions at the end of the path are extracted, not versions in the middle. + * The returned version string does not include leading or trailing slashes. + * + * @param {string} url - The URL string to parse. + * @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found. + * + * @example + * getTrailingApiVersion('https://api.example.com/v1') // 'v1' + * getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta' + * getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end) + * getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta' + * getTrailingApiVersion('https://api.example.com') // undefined + */ +export function getTrailingApiVersion(url: string): string | undefined { + const match = url.match(TRAILING_VERSION_REGEX) + + if (match) { + // Extract version without leading slash and trailing slash + return match[0].replace(/^\//, '').replace(/\/$/, '') + } + + return undefined +} + +/** + * Matches an API version at the end of a URL (with optional trailing slash). + * Used to detect and extract versions only from the trailing position. + */ +const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i + +/** + * Removes the trailing API version segment from a URL path. + * + * This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL. + * Only versions at the end of the path are removed, not versions in the middle. + * + * @param {string} url - The URL string to process. + * @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found. + * + * @example + * withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com' + * withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com' + * withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change) + * withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com' + */ +export function withoutTrailingApiVersion(url: string): string { + return url.replace(TRAILING_VERSION_REGEX, '') +} diff --git a/resources/scripts/install-ovms.js b/resources/scripts/install-ovms.js index e4a5cf0444..8ccd522b01 100644 --- a/resources/scripts/install-ovms.js +++ b/resources/scripts/install-ovms.js @@ -6,12 +6,12 @@ const { downloadWithPowerShell } = require('./download') // Base URL for downloading OVMS binaries const OVMS_RELEASE_BASE_URL = - 'https://storage.openvinotoolkit.org/repositories/openvino_model_server/packages/2025.3.0/ovms_windows_python_on.zip' -const OVMS_EX_URL = 'https://gitcode.com/gcw_ggDjjkY3/kjfile/releases/download/download/ovms_25.3_ex.zip' + 'https://storage.openvinotoolkit.org/repositories/openvino_model_server/packages/2025.4.1/ovms_windows_python_on.zip' +const OVMS_EX_URL = 'https://gitcode.com/gcw_ggDjjkY3/kjfile/releases/download/download/ovms_25.4_ex.zip' /** * error code: - * 101: Unsupported CPU (not Intel Ultra) + * 101: Unsupported CPU (not Intel) * 102: Unsupported platform (not Windows) * 103: Download failed * 104: Installation failed @@ -213,8 +213,8 @@ async function installOvms() { console.log(`CPU Name: ${cpuName}`) // Check if CPU name contains "Ultra" - if (!cpuName.toLowerCase().includes('intel') || !cpuName.toLowerCase().includes('ultra')) { - console.error('OVMS installation requires an Intel(R) Core(TM) Ultra CPU.') + if (!cpuName.toLowerCase().includes('intel')) { + console.error('OVMS installation requires an Intel CPU.') return 101 } diff --git a/scripts/auto-translate-i18n.ts b/scripts/auto-translate-i18n.ts index 7a1bea6f35..41bb14a0a1 100644 --- a/scripts/auto-translate-i18n.ts +++ b/scripts/auto-translate-i18n.ts @@ -50,7 +50,7 @@ Usage Instructions: - pt-pt (Portuguese) Run Command: -yarn auto:i18n +yarn i18n:translate Performance Optimization Recommendations: - For stable API services: MAX_CONCURRENT_TRANSLATIONS=8, TRANSLATION_DELAY_MS=50 @@ -152,7 +152,8 @@ const languageMap = { 'es-es': 'Spanish', 'fr-fr': 'French', 'pt-pt': 'Portuguese', - 'de-de': 'German' + 'de-de': 'German', + 'ro-ro': 'Romanian' } const PROMPT = ` diff --git a/scripts/check-i18n.ts b/scripts/check-i18n.ts index 5735474106..ac1adc3de8 100644 --- a/scripts/check-i18n.ts +++ b/scripts/check-i18n.ts @@ -145,7 +145,7 @@ export function main() { console.log('i18n 检查已通过') } catch (e) { console.error(e) - throw new Error(`检查未通过。尝试运行 yarn sync:i18n 以解决问题。`) + throw new Error(`检查未通过。尝试运行 yarn i18n:sync 以解决问题。`) } } diff --git a/scripts/win-sign.js b/scripts/win-sign.js index f9b37c3aed..cdbfe11e17 100644 --- a/scripts/win-sign.js +++ b/scripts/win-sign.js @@ -5,9 +5,17 @@ exports.default = async function (configuration) { const { path } = configuration if (configuration.path) { try { + const certPath = process.env.CHERRY_CERT_PATH + const keyContainer = process.env.CHERRY_CERT_KEY + const csp = process.env.CHERRY_CERT_CSP + + if (!certPath || !keyContainer || !csp) { + throw new Error('CHERRY_CERT_PATH, CHERRY_CERT_KEY or CHERRY_CERT_CSP is not set') + } + console.log('Start code signing...') console.log('Signing file:', path) - const signCommand = `signtool sign /tr http://timestamp.comodoca.com /td sha256 /fd sha256 /a /v "${path}"` + const signCommand = `signtool sign /tr http://timestamp.comodoca.com /td sha256 /fd sha256 /v /f "${certPath}" /csp "${csp}" /k "${keyContainer}" "${path}"` execSync(signCommand, { stdio: 'inherit' }) console.log('Code signing completed') } catch (error) { diff --git a/src/main/index.ts b/src/main/index.ts index 56750e6b61..536485a490 100644 --- a/src/main/index.ts +++ b/src/main/index.ts @@ -19,7 +19,9 @@ import { agentService } from './services/agents' import { apiServerService } from './services/ApiServerService' import { appMenuService } from './services/AppMenuService' import { configManager } from './services/ConfigManager' +import { lanTransferClientService } from './services/lanTransfer' import mcpService from './services/MCPService' +import { localTransferService } from './services/LocalTransferService' import { nodeTraceService } from './services/NodeTraceService' import powerMonitorService from './services/PowerMonitorService' import { @@ -35,6 +37,7 @@ import { versionService } from './services/VersionService' import { windowService } from './services/WindowService' import { initWebviewHotkeys } from './services/WebviewService' import { runAsyncFunction } from './utils' +import { isOvmsSupported } from './services/OvmsManager' const logger = loggerService.withContext('MainEntry') @@ -155,7 +158,8 @@ if (!app.requestSingleInstanceLock()) { registerShortcuts(mainWindow) - registerIpc(mainWindow, app) + await registerIpc(mainWindow, app) + localTransferService.startDiscovery({ resetList: true }) replaceDevtoolsFont(mainWindow) @@ -237,16 +241,29 @@ if (!app.requestSingleInstanceLock()) { if (selectionService) { selectionService.quit() } + + lanTransferClientService.dispose() + localTransferService.dispose() }) app.on('will-quit', async () => { // 简单的资源清理,不阻塞退出流程 + if (isOvmsSupported) { + const { ovmsManager } = await import('./services/OvmsManager') + if (ovmsManager) { + await ovmsManager.stopOvms() + } else { + logger.warn('Unexpected behavior: undefined ovmsManager, but OVMS should be supported.') + } + } + try { await mcpService.cleanup() await apiServerService.stop() } catch (error) { logger.warn('Error cleaning up MCP service:', error as Error) } + // finish the logger logger.finish() }) diff --git a/src/main/ipc.ts b/src/main/ipc.ts index 3b7736a6d2..e9cd17dd13 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -6,11 +6,19 @@ import { loggerService } from '@logger' import { isLinux, isMac, isPortable, isWin } from '@main/constant' import { generateSignature } from '@main/integration/cherryai' import anthropicService from '@main/services/AnthropicService' -import { findGitBash, getBinaryPath, isBinaryExists, runInstallScript } from '@main/utils/process' +import { + autoDiscoverGitBash, + getBinaryPath, + getGitBashPathInfo, + isBinaryExists, + runInstallScript, + validateGitBashPath +} from '@main/utils/process' import { handleZoomFactor } from '@main/utils/zoom' import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core' import type { UpgradeChannel } from '@shared/config/constant' import { MIN_WINDOW_HEIGHT, MIN_WINDOW_WIDTH } from '@shared/config/constant' +import type { LocalTransferConnectPayload } from '@shared/config/types' import { IpcChannel } from '@shared/IpcChannel' import type { PluginError } from '@types' import type { @@ -35,13 +43,15 @@ import appService from './services/AppService' import AppUpdater from './services/AppUpdater' import BackupManager from './services/BackupManager' import { codeToolsService } from './services/CodeToolsService' -import { configManager } from './services/ConfigManager' +import { ConfigKeys, configManager } from './services/ConfigManager' import CopilotService from './services/CopilotService' import DxtService from './services/DxtService' import { ExportService } from './services/ExportService' import { fileStorage as fileManager } from './services/FileStorage' import FileService from './services/FileSystemService' import KnowledgeService from './services/KnowledgeService' +import { lanTransferClientService } from './services/lanTransfer' +import { localTransferService } from './services/LocalTransferService' import mcpService from './services/MCPService' import MemoryService from './services/memory/MemoryService' import { openTraceWindow, setTraceWindowTitle } from './services/NodeTraceService' @@ -49,7 +59,7 @@ import NotificationService from './services/NotificationService' import * as NutstoreService from './services/NutstoreService' import ObsidianVaultService from './services/ObsidianVaultService' import { ocrService } from './services/ocr/OcrService' -import OvmsManager from './services/OvmsManager' +import { isOvmsSupported } from './services/OvmsManager' import powerMonitorService from './services/PowerMonitorService' import { proxyManager } from './services/ProxyManager' import { pythonService } from './services/PythonService' @@ -73,7 +83,6 @@ import { import storeSyncService from './services/StoreSyncService' import { themeService } from './services/ThemeService' import VertexAIService from './services/VertexAIService' -import WebSocketService from './services/WebSocketService' import { setOpenLinkExternal } from './services/WebviewService' import { windowService } from './services/WindowService' import { calculateDirectorySize, getResourcePath } from './utils' @@ -88,6 +97,7 @@ import { untildify } from './utils/file' import { updateAppDataConfig } from './utils/init' +import { getCpuName, getDeviceType, getHostname } from './utils/system' import { compress, decompress } from './utils/zip' const logger = loggerService.withContext('IPC') @@ -98,7 +108,6 @@ const obsidianVaultService = new ObsidianVaultService() const vertexAIService = VertexAIService.getInstance() const memoryService = MemoryService.getInstance() const dxtService = new DxtService() -const ovmsManager = new OvmsManager() const pluginService = PluginService.getInstance() function normalizeError(error: unknown): Error { @@ -112,7 +121,7 @@ function extractPluginError(error: unknown): PluginError | null { return null } -export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { +export async function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { const appUpdater = new AppUpdater() const notificationService = new NotificationService() @@ -490,17 +499,17 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text)) // system - ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux')) - ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname()) - ipcMain.handle(IpcChannel.System_GetCpuName, () => require('os').cpus()[0].model) + ipcMain.handle(IpcChannel.System_GetDeviceType, getDeviceType) + ipcMain.handle(IpcChannel.System_GetHostname, getHostname) + ipcMain.handle(IpcChannel.System_GetCpuName, getCpuName) ipcMain.handle(IpcChannel.System_CheckGitBash, () => { if (!isWin) { return true // Non-Windows systems don't need Git Bash } try { - const bashPath = findGitBash() - + // Use autoDiscoverGitBash to handle auto-discovery and persistence + const bashPath = autoDiscoverGitBash() if (bashPath) { logger.info('Git Bash is available', { path: bashPath }) return true @@ -513,6 +522,46 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { return false } }) + + ipcMain.handle(IpcChannel.System_GetGitBashPath, () => { + if (!isWin) { + return null + } + + const customPath = configManager.get(ConfigKeys.GitBashPath) as string | undefined + return customPath ?? null + }) + + // Returns { path, source } where source is 'manual' | 'auto' | null + ipcMain.handle(IpcChannel.System_GetGitBashPathInfo, () => { + return getGitBashPathInfo() + }) + + ipcMain.handle(IpcChannel.System_SetGitBashPath, (_, newPath: string | null) => { + if (!isWin) { + return false + } + + if (!newPath) { + // Clear manual setting and re-run auto-discovery + configManager.set(ConfigKeys.GitBashPath, null) + configManager.set(ConfigKeys.GitBashPathSource, null) + // Re-run auto-discovery to restore auto-discovered path if available + autoDiscoverGitBash() + return true + } + + const validated = validateGitBashPath(newPath) + if (!validated) { + return false + } + + // Set path with 'manual' source + configManager.set(ConfigKeys.GitBashPath, validated) + configManager.set(ConfigKeys.GitBashPathSource, 'manual') + return true + }) + ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => { const win = BrowserWindow.fromWebContents(e.sender) win && win.webContents.toggleDevTools() @@ -536,6 +585,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.Backup_ListS3Files, backupManager.listS3Files.bind(backupManager)) ipcMain.handle(IpcChannel.Backup_DeleteS3File, backupManager.deleteS3File.bind(backupManager)) ipcMain.handle(IpcChannel.Backup_CheckS3Connection, backupManager.checkS3Connection.bind(backupManager)) + ipcMain.handle(IpcChannel.Backup_CreateLanTransferBackup, backupManager.createLanTransferBackup.bind(backupManager)) + ipcMain.handle(IpcChannel.Backup_DeleteTempBackup, backupManager.deleteTempBackup.bind(backupManager)) // file ipcMain.handle(IpcChannel.File_Open, fileManager.open.bind(fileManager)) @@ -638,36 +689,19 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.KnowledgeBase_Check_Quota, KnowledgeService.checkQuota.bind(KnowledgeService)) // memory - ipcMain.handle(IpcChannel.Memory_Add, async (_, messages, config) => { - return await memoryService.add(messages, config) - }) - ipcMain.handle(IpcChannel.Memory_Search, async (_, query, config) => { - return await memoryService.search(query, config) - }) - ipcMain.handle(IpcChannel.Memory_List, async (_, config) => { - return await memoryService.list(config) - }) - ipcMain.handle(IpcChannel.Memory_Delete, async (_, id) => { - return await memoryService.delete(id) - }) - ipcMain.handle(IpcChannel.Memory_Update, async (_, id, memory, metadata) => { - return await memoryService.update(id, memory, metadata) - }) - ipcMain.handle(IpcChannel.Memory_Get, async (_, memoryId) => { - return await memoryService.get(memoryId) - }) - ipcMain.handle(IpcChannel.Memory_SetConfig, async (_, config) => { - memoryService.setConfig(config) - }) - ipcMain.handle(IpcChannel.Memory_DeleteUser, async (_, userId) => { - return await memoryService.deleteUser(userId) - }) - ipcMain.handle(IpcChannel.Memory_DeleteAllMemoriesForUser, async (_, userId) => { - return await memoryService.deleteAllMemoriesForUser(userId) - }) - ipcMain.handle(IpcChannel.Memory_GetUsersList, async () => { - return await memoryService.getUsersList() - }) + ipcMain.handle(IpcChannel.Memory_Add, (_, messages, config) => memoryService.add(messages, config)) + ipcMain.handle(IpcChannel.Memory_Search, (_, query, config) => memoryService.search(query, config)) + ipcMain.handle(IpcChannel.Memory_List, (_, config) => memoryService.list(config)) + ipcMain.handle(IpcChannel.Memory_Delete, (_, id) => memoryService.delete(id)) + ipcMain.handle(IpcChannel.Memory_Update, (_, id, memory, metadata) => memoryService.update(id, memory, metadata)) + ipcMain.handle(IpcChannel.Memory_Get, (_, memoryId) => memoryService.get(memoryId)) + ipcMain.handle(IpcChannel.Memory_SetConfig, (_, config) => memoryService.setConfig(config)) + ipcMain.handle(IpcChannel.Memory_DeleteUser, (_, userId) => memoryService.deleteUser(userId)) + ipcMain.handle(IpcChannel.Memory_DeleteAllMemoriesForUser, (_, userId) => + memoryService.deleteAllMemoriesForUser(userId) + ) + ipcMain.handle(IpcChannel.Memory_GetUsersList, () => memoryService.getUsersList()) + ipcMain.handle(IpcChannel.Memory_MigrateMemoryDb, () => memoryService.migrateMemoryDb()) // window ipcMain.handle(IpcChannel.Windows_SetMinimumSize, (_, width: number, height: number) => { @@ -768,6 +802,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.Mcp_CheckConnectivity, mcpService.checkMcpConnectivity) ipcMain.handle(IpcChannel.Mcp_AbortTool, mcpService.abortTool) ipcMain.handle(IpcChannel.Mcp_GetServerVersion, mcpService.getServerVersion) + ipcMain.handle(IpcChannel.Mcp_GetServerLogs, mcpService.getServerLogs) // DXT upload handler ipcMain.handle(IpcChannel.Mcp_UploadDxt, async (event, fileBuffer: ArrayBuffer, fileName: string) => { @@ -826,8 +861,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ) // search window - ipcMain.handle(IpcChannel.SearchWindow_Open, async (_, uid: string) => { - await searchService.openSearchWindow(uid) + ipcMain.handle(IpcChannel.SearchWindow_Open, async (_, uid: string, show?: boolean) => { + await searchService.openSearchWindow(uid, show) }) ipcMain.handle(IpcChannel.SearchWindow_Close, async (_, uid: string) => { await searchService.closeSearchWindow(uid) @@ -846,6 +881,17 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { webview.session.setSpellCheckerEnabled(isEnable) }) + // Webview print and save handlers + ipcMain.handle(IpcChannel.Webview_PrintToPDF, async (_, webviewId: number) => { + const { printWebviewToPDF } = await import('./services/WebviewService') + return await printWebviewToPDF(webviewId) + }) + + ipcMain.handle(IpcChannel.Webview_SaveAsHTML, async (_, webviewId: number) => { + const { saveWebviewAsHTML } = await import('./services/WebviewService') + return await saveWebviewAsHTML(webviewId) + }) + // store sync storeSyncService.registerIpcHandler() @@ -932,15 +978,36 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.OCR_ListProviders, () => ocrService.listProviderIds()) // OVMS - ipcMain.handle(IpcChannel.Ovms_AddModel, (_, modelName: string, modelId: string, modelSource: string, task: string) => - ovmsManager.addModel(modelName, modelId, modelSource, task) - ) - ipcMain.handle(IpcChannel.Ovms_StopAddModel, () => ovmsManager.stopAddModel()) - ipcMain.handle(IpcChannel.Ovms_GetModels, () => ovmsManager.getModels()) - ipcMain.handle(IpcChannel.Ovms_IsRunning, () => ovmsManager.initializeOvms()) - ipcMain.handle(IpcChannel.Ovms_GetStatus, () => ovmsManager.getOvmsStatus()) - ipcMain.handle(IpcChannel.Ovms_RunOVMS, () => ovmsManager.runOvms()) - ipcMain.handle(IpcChannel.Ovms_StopOVMS, () => ovmsManager.stopOvms()) + ipcMain.handle(IpcChannel.Ovms_IsSupported, () => isOvmsSupported) + if (isOvmsSupported) { + const { ovmsManager } = await import('./services/OvmsManager') + if (ovmsManager) { + ipcMain.handle( + IpcChannel.Ovms_AddModel, + (_, modelName: string, modelId: string, modelSource: string, task: string) => + ovmsManager.addModel(modelName, modelId, modelSource, task) + ) + ipcMain.handle(IpcChannel.Ovms_StopAddModel, () => ovmsManager.stopAddModel()) + ipcMain.handle(IpcChannel.Ovms_GetModels, () => ovmsManager.getModels()) + ipcMain.handle(IpcChannel.Ovms_IsRunning, () => ovmsManager.initializeOvms()) + ipcMain.handle(IpcChannel.Ovms_GetStatus, () => ovmsManager.getOvmsStatus()) + ipcMain.handle(IpcChannel.Ovms_RunOVMS, () => ovmsManager.runOvms()) + ipcMain.handle(IpcChannel.Ovms_StopOVMS, () => ovmsManager.stopOvms()) + } else { + logger.error('Unexpected behavior: undefined ovmsManager, but OVMS should be supported.') + } + } else { + const fallback = () => { + throw new Error('OVMS is only supported on Windows with intel CPU.') + } + ipcMain.handle(IpcChannel.Ovms_AddModel, fallback) + ipcMain.handle(IpcChannel.Ovms_StopAddModel, fallback) + ipcMain.handle(IpcChannel.Ovms_GetModels, fallback) + ipcMain.handle(IpcChannel.Ovms_IsRunning, fallback) + ipcMain.handle(IpcChannel.Ovms_GetStatus, fallback) + ipcMain.handle(IpcChannel.Ovms_RunOVMS, fallback) + ipcMain.handle(IpcChannel.Ovms_StopOVMS, fallback) + } // CherryAI ipcMain.handle(IpcChannel.Cherryai_GetSignature, (_, params) => generateSignature(params)) @@ -997,12 +1064,18 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { } catch (error) { const pluginError = extractPluginError(error) if (pluginError) { - logger.error('Failed to list installed plugins', { agentId, error: pluginError }) + logger.error('Failed to list installed plugins', { + agentId, + error: pluginError + }) return { success: false, error: pluginError } } const err = normalizeError(error) - logger.error('Failed to list installed plugins', { agentId, error: err }) + logger.error('Failed to list installed plugins', { + agentId, + error: err + }) return { success: false, error: { @@ -1058,12 +1131,17 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { } }) - // WebSocket - ipcMain.handle(IpcChannel.WebSocket_Start, WebSocketService.start) - ipcMain.handle(IpcChannel.WebSocket_Stop, WebSocketService.stop) - ipcMain.handle(IpcChannel.WebSocket_Status, WebSocketService.getStatus) - ipcMain.handle(IpcChannel.WebSocket_SendFile, WebSocketService.sendFile) - ipcMain.handle(IpcChannel.WebSocket_GetAllCandidates, WebSocketService.getAllCandidates) + ipcMain.handle(IpcChannel.LocalTransfer_ListServices, () => localTransferService.getState()) + ipcMain.handle(IpcChannel.LocalTransfer_StartScan, () => localTransferService.startDiscovery({ resetList: true })) + ipcMain.handle(IpcChannel.LocalTransfer_StopScan, () => localTransferService.stopDiscovery()) + ipcMain.handle(IpcChannel.LocalTransfer_Connect, (_, payload: LocalTransferConnectPayload) => + lanTransferClientService.connectAndHandshake(payload) + ) + ipcMain.handle(IpcChannel.LocalTransfer_Disconnect, () => lanTransferClientService.disconnect()) + ipcMain.handle(IpcChannel.LocalTransfer_SendFile, (_, payload: { filePath: string }) => + lanTransferClientService.sendFile(payload.filePath) + ) + ipcMain.handle(IpcChannel.LocalTransfer_CancelTransfer, () => lanTransferClientService.cancelTransfer()) ipcMain.handle(IpcChannel.APP_CrashRenderProcess, () => { mainWindow.webContents.forcefullyCrashRenderer() diff --git a/src/main/mcpServers/__tests__/browser.test.ts b/src/main/mcpServers/__tests__/browser.test.ts new file mode 100644 index 0000000000..800d03d7c5 --- /dev/null +++ b/src/main/mcpServers/__tests__/browser.test.ts @@ -0,0 +1,372 @@ +import { describe, expect, it, vi } from 'vitest' + +vi.mock('node:fs', () => ({ + default: { + existsSync: vi.fn(() => false), + mkdirSync: vi.fn() + }, + existsSync: vi.fn(() => false), + mkdirSync: vi.fn() +})) + +vi.mock('electron', () => { + const sendCommand = vi.fn(async (command: string, params?: { expression?: string }) => { + if (command === 'Runtime.evaluate') { + if (params?.expression === 'document.documentElement.outerHTML') { + return { result: { value: '

Test

Content

' } } + } + if (params?.expression === 'document.body.innerText') { + return { result: { value: 'Test\nContent' } } + } + return { result: { value: 'ok' } } + } + return {} + }) + + const debuggerObj = { + isAttached: vi.fn(() => true), + attach: vi.fn(), + detach: vi.fn(), + sendCommand + } + + const createWebContents = () => ({ + debugger: debuggerObj, + setUserAgent: vi.fn(), + getURL: vi.fn(() => 'https://example.com/'), + getTitle: vi.fn(async () => 'Example Title'), + loadURL: vi.fn(async () => {}), + once: vi.fn(), + removeListener: vi.fn(), + on: vi.fn(), + isDestroyed: vi.fn(() => false), + canGoBack: vi.fn(() => false), + canGoForward: vi.fn(() => false), + goBack: vi.fn(), + goForward: vi.fn(), + reload: vi.fn(), + executeJavaScript: vi.fn(async () => null), + setWindowOpenHandler: vi.fn() + }) + + const windows: any[] = [] + const views: any[] = [] + + class MockBrowserWindow { + private destroyed = false + public webContents = createWebContents() + public isDestroyed = vi.fn(() => this.destroyed) + public close = vi.fn(() => { + this.destroyed = true + }) + public destroy = vi.fn(() => { + this.destroyed = true + }) + public on = vi.fn() + public setBrowserView = vi.fn() + public addBrowserView = vi.fn() + public removeBrowserView = vi.fn() + public getContentSize = vi.fn(() => [1200, 800]) + public show = vi.fn() + + constructor() { + windows.push(this) + } + } + + class MockBrowserView { + public webContents = createWebContents() + public setBounds = vi.fn() + public setAutoResize = vi.fn() + public destroy = vi.fn() + + constructor() { + views.push(this) + } + } + + const app = { + isReady: vi.fn(() => true), + whenReady: vi.fn(async () => {}), + on: vi.fn(), + getPath: vi.fn((key: string) => { + if (key === 'userData') return '/mock/userData' + if (key === 'temp') return '/tmp' + return '/mock/unknown' + }), + getAppPath: vi.fn(() => '/mock/app'), + setPath: vi.fn() + } + + const nativeTheme = { + on: vi.fn(), + shouldUseDarkColors: false + } + + return { + BrowserWindow: MockBrowserWindow as any, + BrowserView: MockBrowserView as any, + app, + nativeTheme, + __mockDebugger: debuggerObj, + __mockSendCommand: sendCommand, + __mockWindows: windows, + __mockViews: views + } +}) + +import { CdpBrowserController } from '../browser' + +describe('CdpBrowserController', () => { + it('executes single-line code via Runtime.evaluate', async () => { + const controller = new CdpBrowserController() + const result = await controller.execute('1+1') + expect(result).toBe('ok') + }) + + it('opens a URL in normal mode and returns current page info', async () => { + const controller = new CdpBrowserController() + const result = await controller.open('https://foo.bar/', 5000, false) + expect(result.currentUrl).toBe('https://example.com/') + expect(result.title).toBe('Example Title') + }) + + it('opens a URL in private mode', async () => { + const controller = new CdpBrowserController() + const result = await controller.open('https://foo.bar/', 5000, true) + expect(result.currentUrl).toBe('https://example.com/') + expect(result.title).toBe('Example Title') + }) + + it('reuses session for execute and supports multiline', async () => { + const controller = new CdpBrowserController() + await controller.open('https://foo.bar/', 5000, false) + const result = await controller.execute('const a=1; const b=2; a+b;', 5000, false) + expect(result).toBe('ok') + }) + + it('normal and private modes are isolated', async () => { + const controller = new CdpBrowserController() + await controller.open('https://foo.bar/', 5000, false) + await controller.open('https://foo.bar/', 5000, true) + const normalResult = await controller.execute('1+1', 5000, false) + const privateResult = await controller.execute('1+1', 5000, true) + expect(normalResult).toBe('ok') + expect(privateResult).toBe('ok') + }) + + it('fetches URL and returns html format with tabId', async () => { + const controller = new CdpBrowserController() + const result = await controller.fetch('https://example.com/', 'html') + expect(result.tabId).toBeDefined() + expect(result.content).toBe('

Test

Content

') + }) + + it('fetches URL and returns txt format with tabId', async () => { + const controller = new CdpBrowserController() + const result = await controller.fetch('https://example.com/', 'txt') + expect(result.tabId).toBeDefined() + expect(result.content).toBe('Test\nContent') + }) + + it('fetches URL and returns markdown format (default) with tabId', async () => { + const controller = new CdpBrowserController() + const result = await controller.fetch('https://example.com/') + expect(result.tabId).toBeDefined() + expect(typeof result.content).toBe('string') + expect(result.content).toContain('Test') + }) + + it('fetches URL in private mode with tabId', async () => { + const controller = new CdpBrowserController() + const result = await controller.fetch('https://example.com/', 'html', 10000, true) + expect(result.tabId).toBeDefined() + expect(result.content).toBe('

Test

Content

') + }) + + describe('Multi-tab support', () => { + it('creates new tab with newTab parameter', async () => { + const controller = new CdpBrowserController() + const result1 = await controller.open('https://site1.com/', 5000, false, true) + const result2 = await controller.open('https://site2.com/', 5000, false, true) + + expect(result1.tabId).toBeDefined() + expect(result2.tabId).toBeDefined() + expect(result1.tabId).not.toBe(result2.tabId) + }) + + it('reuses same tab without newTab parameter', async () => { + const controller = new CdpBrowserController() + const result1 = await controller.open('https://site1.com/', 5000, false) + const result2 = await controller.open('https://site2.com/', 5000, false) + + expect(result1.tabId).toBe(result2.tabId) + }) + + it('fetches in new tab with newTab parameter', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + const tabs = await controller.listTabs(false) + const initialTabCount = tabs.length + + await controller.fetch('https://other.com/', 'html', 10000, false, true) + const tabsAfter = await controller.listTabs(false) + + expect(tabsAfter.length).toBe(initialTabCount + 1) + }) + }) + + describe('Tab management', () => { + it('lists tabs in a window', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + + const tabs = await controller.listTabs(false) + expect(tabs.length).toBeGreaterThan(0) + expect(tabs[0].tabId).toBeDefined() + }) + + it('lists tabs separately for normal and private modes', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + await controller.open('https://example.com/', 5000, true) + + const normalTabs = await controller.listTabs(false) + const privateTabs = await controller.listTabs(true) + + expect(normalTabs.length).toBe(1) + expect(privateTabs.length).toBe(1) + expect(normalTabs[0].tabId).not.toBe(privateTabs[0].tabId) + }) + + it('closes specific tab', async () => { + const controller = new CdpBrowserController() + const result1 = await controller.open('https://site1.com/', 5000, false, true) + await controller.open('https://site2.com/', 5000, false, true) + + const tabsBefore = await controller.listTabs(false) + expect(tabsBefore.length).toBe(2) + + await controller.closeTab(false, result1.tabId) + + const tabsAfter = await controller.listTabs(false) + expect(tabsAfter.length).toBe(1) + expect(tabsAfter.find((t) => t.tabId === result1.tabId)).toBeUndefined() + }) + + it('switches active tab', async () => { + const controller = new CdpBrowserController() + const result1 = await controller.open('https://site1.com/', 5000, false, true) + const result2 = await controller.open('https://site2.com/', 5000, false, true) + + await controller.switchTab(false, result1.tabId) + await controller.switchTab(false, result2.tabId) + }) + + it('throws error when switching to non-existent tab', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + + await expect(controller.switchTab(false, 'non-existent-tab')).rejects.toThrow('Tab non-existent-tab not found') + }) + }) + + describe('Reset behavior', () => { + it('resets specific tab only', async () => { + const controller = new CdpBrowserController() + const result1 = await controller.open('https://site1.com/', 5000, false, true) + await controller.open('https://site2.com/', 5000, false, true) + + await controller.reset(false, result1.tabId) + + const tabs = await controller.listTabs(false) + expect(tabs.length).toBe(1) + }) + + it('resets specific window only', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + await controller.open('https://example.com/', 5000, true) + + await controller.reset(false) + + const normalTabs = await controller.listTabs(false) + const privateTabs = await controller.listTabs(true) + + expect(normalTabs.length).toBe(0) + expect(privateTabs.length).toBe(1) + }) + + it('resets all windows', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + await controller.open('https://example.com/', 5000, true) + + await controller.reset() + + const normalTabs = await controller.listTabs(false) + const privateTabs = await controller.listTabs(true) + + expect(normalTabs.length).toBe(0) + expect(privateTabs.length).toBe(0) + }) + }) + + describe('showWindow parameter', () => { + it('passes showWindow parameter through open', async () => { + const controller = new CdpBrowserController() + const result = await controller.open('https://example.com/', 5000, false, false, true) + expect(result.currentUrl).toBe('https://example.com/') + expect(result.tabId).toBeDefined() + }) + + it('passes showWindow parameter through fetch', async () => { + const controller = new CdpBrowserController() + const result = await controller.fetch('https://example.com/', 'html', 10000, false, false, true) + expect(result.tabId).toBeDefined() + expect(result.content).toBe('

Test

Content

') + }) + + it('passes showWindow parameter through createTab', async () => { + const controller = new CdpBrowserController() + const { tabId, view } = await controller.createTab(false, true) + expect(tabId).toBeDefined() + expect(view).toBeDefined() + }) + + it('shows existing window when showWindow=true on subsequent calls', async () => { + const controller = new CdpBrowserController() + // First call creates window + await controller.open('https://example.com/', 5000, false, false, false) + // Second call with showWindow=true should show existing window + const result = await controller.open('https://example.com/', 5000, false, false, true) + expect(result.currentUrl).toBe('https://example.com/') + }) + }) + + describe('Window limits and eviction', () => { + it('respects maxWindows limit', async () => { + const controller = new CdpBrowserController({ maxWindows: 1 }) + await controller.open('https://example.com/', 5000, false) + await controller.open('https://example.com/', 5000, true) + + const normalTabs = await controller.listTabs(false) + const privateTabs = await controller.listTabs(true) + + expect(privateTabs.length).toBe(1) + expect(normalTabs.length).toBe(0) + }) + + it('cleans up idle windows on next access', async () => { + const controller = new CdpBrowserController({ idleTimeoutMs: 1 }) + await controller.open('https://example.com/', 5000, false) + + await new Promise((r) => setTimeout(r, 10)) + + await controller.open('https://example.com/', 5000, true) + + const normalTabs = await controller.listTabs(false) + expect(normalTabs.length).toBe(0) + }) + }) +}) diff --git a/src/main/mcpServers/browser/README.md b/src/main/mcpServers/browser/README.md new file mode 100644 index 0000000000..27d1307782 --- /dev/null +++ b/src/main/mcpServers/browser/README.md @@ -0,0 +1,177 @@ +# Browser MCP Server + +A Model Context Protocol (MCP) server for controlling browser windows via Chrome DevTools Protocol (CDP). + +## Features + +### ✨ User Data Persistence +- **Normal mode (default)**: Cookies, localStorage, and sessionStorage persist across browser restarts +- **Private mode**: Ephemeral browsing - no data persists (like incognito mode) + +### 🔄 Window Management +- Two browsing modes: normal (persistent) and private (ephemeral) +- Lazy idle timeout cleanup (cleaned on next window access) +- Maximum window limits to prevent resource exhaustion + +> **Note**: Normal mode uses a global `persist:default` partition shared by all clients. This means login sessions and stored data are accessible to any code using the MCP server. + +## Architecture + +### How It Works +``` +Normal Mode (BrowserWindow) +├─ Persistent Storage (partition: persist:default) ← Global, shared across all clients +└─ Tabs (BrowserView) ← created via newTab or automatically + +Private Mode (BrowserWindow) +├─ Ephemeral Storage (partition: private) ← No disk persistence +└─ Tabs (BrowserView) ← created via newTab or automatically +``` + +- **One Window Per Mode**: Normal and private modes each have their own window +- **Multi-Tab Support**: Use `newTab: true` for parallel URL requests +- **Storage Isolation**: Normal and private modes have completely separate storage + +## Available Tools + +### `open` +Open a URL in a browser window. Optionally return page content. +```json +{ + "url": "https://example.com", + "format": "markdown", + "timeout": 10000, + "privateMode": false, + "newTab": false, + "showWindow": false +} +``` +- `format`: If set (`html`, `txt`, `markdown`, `json`), returns page content in that format along with tabId. If not set, just opens the page and returns navigation info. +- `newTab`: Set to `true` to open in a new tab (required for parallel requests) +- `showWindow`: Set to `true` to display the browser window (useful for debugging) +- Returns (without format): `{ currentUrl, title, tabId }` +- Returns (with format): `{ tabId, content }` where content is in the specified format + +### `execute` +Execute JavaScript code in the page context. +```json +{ + "code": "document.title", + "timeout": 5000, + "privateMode": false, + "tabId": "optional-tab-id" +} +``` +- `tabId`: Target a specific tab (from `open` response) + +### `reset` +Reset browser windows and tabs. +```json +{ + "privateMode": false, + "tabId": "optional-tab-id" +} +``` +- Omit all parameters to close all windows +- Set `privateMode` to close a specific window +- Set both `privateMode` and `tabId` to close a specific tab only + +## Usage Examples + +### Basic Navigation +```typescript +// Open a URL in normal mode (data persists) +await controller.open('https://example.com') +``` + +### Fetch Page Content +```typescript +// Open URL and get content as markdown +await open({ url: 'https://example.com', format: 'markdown' }) + +// Open URL and get raw HTML +await open({ url: 'https://example.com', format: 'html' }) +``` + +### Multi-Tab / Parallel Requests +```typescript +// Open multiple URLs in parallel using newTab +const [page1, page2] = await Promise.all([ + controller.open('https://site1.com', 10000, false, true), // newTab: true + controller.open('https://site2.com', 10000, false, true) // newTab: true +]) + +// Execute on specific tab +await controller.execute('document.title', 5000, false, page1.tabId) + +// Close specific tab when done +await controller.reset(false, page1.tabId) +``` + +### Private Browsing +```typescript +// Open a URL in private mode (no data persistence) +await controller.open('https://example.com', 10000, true) + +// Cookies and localStorage won't persist after reset +``` + +### Data Persistence (Normal Mode) +```typescript +// Set data +await controller.open('https://example.com', 10000, false) +await controller.execute('localStorage.setItem("key", "value")', 5000, false) + +// Close window +await controller.reset(false) + +// Reopen - data persists! +await controller.open('https://example.com', 10000, false) +const value = await controller.execute('localStorage.getItem("key")', 5000, false) +// Returns: "value" +``` + +### No Persistence (Private Mode) +```typescript +// Set data in private mode +await controller.open('https://example.com', 10000, true) +await controller.execute('localStorage.setItem("key", "value")', 5000, true) + +// Close private window +await controller.reset(true) + +// Reopen - data is gone! +await controller.open('https://example.com', 10000, true) +const value = await controller.execute('localStorage.getItem("key")', 5000, true) +// Returns: null +``` + +## Configuration + +```typescript +const controller = new CdpBrowserController({ + maxWindows: 5, // Maximum concurrent windows + idleTimeoutMs: 5 * 60 * 1000 // 5 minutes idle timeout (lazy cleanup) +}) +``` + +> **Note on Idle Timeout**: Idle windows are cleaned up lazily when the next window is created or accessed, not on a background timer. + +## Best Practices + +1. **Use Normal Mode for Authentication**: When you need to stay logged in across sessions +2. **Use Private Mode for Sensitive Operations**: When you don't want data to persist +3. **Use `newTab: true` for Parallel Requests**: Avoid race conditions when fetching multiple URLs +4. **Resource Cleanup**: Call `reset()` when done, or `reset(privateMode, tabId)` to close specific tabs +5. **Error Handling**: All tool handlers return error responses on failure +6. **Timeout Configuration**: Adjust timeouts based on page complexity + +## Technical Details + +- **CDP Version**: 1.3 +- **User Agent**: Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:145.0) Gecko/20100101 Firefox/145.0 +- **Storage**: + - Normal mode: `persist:default` (disk-persisted, global) + - Private mode: `private` (memory only) +- **Window Size**: 1200x800 (default) +- **Visibility**: Windows hidden by default (use `showWindow: true` to display) diff --git a/src/main/mcpServers/browser/constants.ts b/src/main/mcpServers/browser/constants.ts new file mode 100644 index 0000000000..2b10943f8e --- /dev/null +++ b/src/main/mcpServers/browser/constants.ts @@ -0,0 +1,3 @@ +export const TAB_BAR_HEIGHT = 92 // Height for Chrome-style tab bar (42px) + address bar (50px) +export const SESSION_KEY_DEFAULT = 'default' +export const SESSION_KEY_PRIVATE = 'private' diff --git a/src/main/mcpServers/browser/controller.ts b/src/main/mcpServers/browser/controller.ts new file mode 100644 index 0000000000..9e0f5220ca --- /dev/null +++ b/src/main/mcpServers/browser/controller.ts @@ -0,0 +1,909 @@ +import { titleBarOverlayDark, titleBarOverlayLight } from '@main/config' +import { isMac } from '@main/constant' +import { randomUUID } from 'crypto' +import { app, BrowserView, BrowserWindow, nativeTheme } from 'electron' +import TurndownService from 'turndown' + +import { SESSION_KEY_DEFAULT, SESSION_KEY_PRIVATE, TAB_BAR_HEIGHT } from './constants' +import { TAB_BAR_HTML } from './tabbar-html' +import { logger, type TabInfo, userAgent, type WindowInfo } from './types' + +/** + * Controller for managing browser windows via Chrome DevTools Protocol (CDP). + * Supports two modes: normal (persistent) and private (ephemeral). + * Normal mode persists user data (cookies, localStorage, etc.) globally across all clients. + * Private mode is ephemeral - data is cleared when the window closes. + */ +export class CdpBrowserController { + private windows: Map = new Map() + private readonly maxWindows: number + private readonly idleTimeoutMs: number + private readonly turndownService: TurndownService + + constructor(options?: { maxWindows?: number; idleTimeoutMs?: number }) { + this.maxWindows = options?.maxWindows ?? 5 + this.idleTimeoutMs = options?.idleTimeoutMs ?? 5 * 60 * 1000 + this.turndownService = new TurndownService() + + // Listen for theme changes and update all tab bars + nativeTheme.on('updated', () => { + const isDark = nativeTheme.shouldUseDarkColors + for (const windowInfo of this.windows.values()) { + if (windowInfo.tabBarView && !windowInfo.tabBarView.webContents.isDestroyed()) { + windowInfo.tabBarView.webContents.executeJavaScript(`window.setTheme(${isDark})`).catch(() => { + // Ignore errors if tab bar is not ready + }) + } + } + }) + } + + private getWindowKey(privateMode: boolean): string { + return privateMode ? SESSION_KEY_PRIVATE : SESSION_KEY_DEFAULT + } + + private getPartition(privateMode: boolean): string { + return privateMode ? SESSION_KEY_PRIVATE : `persist:${SESSION_KEY_DEFAULT}` + } + + private async ensureAppReady() { + if (!app.isReady()) { + await app.whenReady() + } + } + + private touchWindow(windowKey: string) { + const windowInfo = this.windows.get(windowKey) + if (windowInfo) windowInfo.lastActive = Date.now() + } + + private touchTab(windowKey: string, tabId: string) { + const windowInfo = this.windows.get(windowKey) + if (windowInfo) { + const tab = windowInfo.tabs.get(tabId) + if (tab) tab.lastActive = Date.now() + windowInfo.lastActive = Date.now() + } + } + + private closeTabInternal(windowInfo: WindowInfo, tabId: string) { + try { + const tab = windowInfo.tabs.get(tabId) + if (!tab) return + + if (!tab.view.webContents.isDestroyed()) { + if (tab.view.webContents.debugger.isAttached()) { + tab.view.webContents.debugger.detach() + } + } + + // Remove view from window + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.removeBrowserView(tab.view) + } + + // Destroy the view using safe cast + const viewWithDestroy = tab.view as BrowserView & { destroy?: () => void } + if (viewWithDestroy.destroy) { + viewWithDestroy.destroy() + } + } catch (error) { + logger.warn('Error closing tab', { error, windowKey: windowInfo.windowKey, tabId }) + } + } + + private async ensureDebuggerAttached(dbg: Electron.Debugger, sessionKey: string) { + if (!dbg.isAttached()) { + try { + logger.info('Attaching debugger', { sessionKey }) + dbg.attach('1.3') + await dbg.sendCommand('Page.enable') + await dbg.sendCommand('Runtime.enable') + logger.info('Debugger attached and domains enabled') + } catch (error) { + logger.error('Failed to attach debugger', { error }) + throw error + } + } + } + + private sweepIdle() { + const now = Date.now() + const windowKeys = Array.from(this.windows.keys()) + for (const windowKey of windowKeys) { + const windowInfo = this.windows.get(windowKey) + if (!windowInfo) continue + if (now - windowInfo.lastActive > this.idleTimeoutMs) { + const tabIds = Array.from(windowInfo.tabs.keys()) + for (const tabId of tabIds) { + this.closeTabInternal(windowInfo, tabId) + } + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.close() + } + this.windows.delete(windowKey) + } + } + } + + private evictIfNeeded(newWindowKey: string) { + if (this.windows.size < this.maxWindows) return + let lruKey: string | null = null + let lruTime = Number.POSITIVE_INFINITY + for (const [key, windowInfo] of this.windows.entries()) { + if (key === newWindowKey) continue + if (windowInfo.lastActive < lruTime) { + lruTime = windowInfo.lastActive + lruKey = key + } + } + if (lruKey) { + const windowInfo = this.windows.get(lruKey) + if (windowInfo) { + for (const [tabId] of windowInfo.tabs.entries()) { + this.closeTabInternal(windowInfo, tabId) + } + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.close() + } + } + this.windows.delete(lruKey) + logger.info('Evicted window to respect maxWindows', { evicted: lruKey }) + } + } + + private sendTabBarUpdate(windowInfo: WindowInfo) { + if (!windowInfo.tabBarView || !windowInfo.tabBarView.webContents || windowInfo.tabBarView.webContents.isDestroyed()) + return + + const tabs = Array.from(windowInfo.tabs.values()).map((tab) => ({ + id: tab.id, + title: tab.title || 'New Tab', + url: tab.url, + isActive: tab.id === windowInfo.activeTabId + })) + + let activeUrl = '' + let canGoBack = false + let canGoForward = false + + if (windowInfo.activeTabId) { + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (activeTab && !activeTab.view.webContents.isDestroyed()) { + activeUrl = activeTab.view.webContents.getURL() + canGoBack = activeTab.view.webContents.canGoBack() + canGoForward = activeTab.view.webContents.canGoForward() + } + } + + const script = `window.updateTabs(${JSON.stringify(tabs)}, ${JSON.stringify(activeUrl)}, ${canGoBack}, ${canGoForward})` + windowInfo.tabBarView.webContents.executeJavaScript(script).catch((error) => { + logger.debug('Tab bar update failed', { error, windowKey: windowInfo.windowKey }) + }) + } + + private handleNavigateAction(windowInfo: WindowInfo, url: string) { + if (!windowInfo.activeTabId) return + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (!activeTab || activeTab.view.webContents.isDestroyed()) return + + let finalUrl = url.trim() + if (!/^https?:\/\//i.test(finalUrl)) { + if (/^[a-zA-Z0-9][a-zA-Z0-9-]*\.[a-zA-Z]{2,}/.test(finalUrl) || finalUrl.includes('.')) { + finalUrl = 'https://' + finalUrl + } else { + finalUrl = 'https://www.google.com/search?q=' + encodeURIComponent(finalUrl) + } + } + + activeTab.view.webContents.loadURL(finalUrl).catch((error) => { + logger.warn('Navigation failed in tab bar', { error, url: finalUrl, tabId: windowInfo.activeTabId }) + }) + } + + private handleBackAction(windowInfo: WindowInfo) { + if (!windowInfo.activeTabId) return + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (!activeTab || activeTab.view.webContents.isDestroyed()) return + + if (activeTab.view.webContents.canGoBack()) { + activeTab.view.webContents.goBack() + } + } + + private handleForwardAction(windowInfo: WindowInfo) { + if (!windowInfo.activeTabId) return + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (!activeTab || activeTab.view.webContents.isDestroyed()) return + + if (activeTab.view.webContents.canGoForward()) { + activeTab.view.webContents.goForward() + } + } + + private handleRefreshAction(windowInfo: WindowInfo) { + if (!windowInfo.activeTabId) return + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (!activeTab || activeTab.view.webContents.isDestroyed()) return + + activeTab.view.webContents.reload() + } + + private setupTabBarMessageHandler(windowInfo: WindowInfo) { + if (!windowInfo.tabBarView) return + + windowInfo.tabBarView.webContents.on('console-message', (_event, _level, message) => { + try { + const parsed = JSON.parse(message) + if (parsed?.channel === 'tabbar-action' && parsed?.payload) { + this.handleTabBarAction(windowInfo, parsed.payload) + } + } catch { + // Not a JSON message, ignore + } + }) + + windowInfo.tabBarView.webContents + .executeJavaScript(` + (function() { + window.addEventListener('message', function(e) { + if (e.data && e.data.channel === 'tabbar-action') { + console.log(JSON.stringify(e.data)); + } + }); + })(); + `) + .catch((error) => { + logger.debug('Tab bar message handler setup failed', { error, windowKey: windowInfo.windowKey }) + }) + } + + private handleTabBarAction(windowInfo: WindowInfo, action: { type: string; tabId?: string; url?: string }) { + if (action.type === 'switch' && action.tabId) { + this.switchTab(windowInfo.privateMode, action.tabId).catch((error) => { + logger.warn('Tab switch failed', { error, tabId: action.tabId, windowKey: windowInfo.windowKey }) + }) + } else if (action.type === 'close' && action.tabId) { + this.closeTab(windowInfo.privateMode, action.tabId).catch((error) => { + logger.warn('Tab close failed', { error, tabId: action.tabId, windowKey: windowInfo.windowKey }) + }) + } else if (action.type === 'new') { + this.createTab(windowInfo.privateMode, true) + .then(({ tabId }) => this.switchTab(windowInfo.privateMode, tabId)) + .catch((error) => { + logger.warn('New tab creation failed', { error, windowKey: windowInfo.windowKey }) + }) + } else if (action.type === 'navigate' && action.url) { + this.handleNavigateAction(windowInfo, action.url) + } else if (action.type === 'back') { + this.handleBackAction(windowInfo) + } else if (action.type === 'forward') { + this.handleForwardAction(windowInfo) + } else if (action.type === 'refresh') { + this.handleRefreshAction(windowInfo) + } else if (action.type === 'window-minimize') { + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.minimize() + } + } else if (action.type === 'window-maximize') { + if (!windowInfo.window.isDestroyed()) { + if (windowInfo.window.isMaximized()) { + windowInfo.window.unmaximize() + } else { + windowInfo.window.maximize() + } + } + } else if (action.type === 'window-close') { + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.close() + } + } + } + + private createTabBarView(windowInfo: WindowInfo): BrowserView { + const tabBarView = new BrowserView({ + webPreferences: { + contextIsolation: false, + sandbox: false, + nodeIntegration: false + } + }) + + windowInfo.window.addBrowserView(tabBarView) + const [width] = windowInfo.window.getContentSize() + tabBarView.setBounds({ x: 0, y: 0, width, height: TAB_BAR_HEIGHT }) + tabBarView.setAutoResize({ width: true, height: false }) + tabBarView.webContents.loadURL(`data:text/html;charset=utf-8,${encodeURIComponent(TAB_BAR_HTML)}`) + + tabBarView.webContents.on('did-finish-load', () => { + // Initialize platform for proper styling + const platform = isMac ? 'mac' : process.platform === 'win32' ? 'win' : 'linux' + tabBarView.webContents.executeJavaScript(`window.initPlatform('${platform}')`).catch((error) => { + logger.debug('Platform init failed', { error, windowKey: windowInfo.windowKey }) + }) + // Initialize theme + const isDark = nativeTheme.shouldUseDarkColors + tabBarView.webContents.executeJavaScript(`window.setTheme(${isDark})`).catch((error) => { + logger.debug('Theme init failed', { error, windowKey: windowInfo.windowKey }) + }) + this.setupTabBarMessageHandler(windowInfo) + this.sendTabBarUpdate(windowInfo) + }) + + return tabBarView + } + + private async createBrowserWindow( + windowKey: string, + privateMode: boolean, + showWindow = false + ): Promise { + await this.ensureAppReady() + + const partition = this.getPartition(privateMode) + + const win = new BrowserWindow({ + show: showWindow, + width: 1200, + height: 800, + ...(isMac + ? { + titleBarStyle: 'hidden', + titleBarOverlay: nativeTheme.shouldUseDarkColors ? titleBarOverlayDark : titleBarOverlayLight, + trafficLightPosition: { x: 8, y: 13 } + } + : { + frame: false // Frameless window for Windows and Linux + }), + webPreferences: { + contextIsolation: true, + sandbox: true, + nodeIntegration: false, + devTools: true, + partition + } + }) + + win.on('closed', () => { + const windowInfo = this.windows.get(windowKey) + if (windowInfo) { + const tabIds = Array.from(windowInfo.tabs.keys()) + for (const tabId of tabIds) { + this.closeTabInternal(windowInfo, tabId) + } + this.windows.delete(windowKey) + } + }) + + return win + } + + private async getOrCreateWindow(privateMode: boolean, showWindow = false): Promise { + await this.ensureAppReady() + this.sweepIdle() + + const windowKey = this.getWindowKey(privateMode) + + let windowInfo = this.windows.get(windowKey) + if (!windowInfo) { + this.evictIfNeeded(windowKey) + const window = await this.createBrowserWindow(windowKey, privateMode, showWindow) + windowInfo = { + windowKey, + privateMode, + window, + tabs: new Map(), + activeTabId: null, + lastActive: Date.now(), + tabBarView: undefined + } + this.windows.set(windowKey, windowInfo) + const tabBarView = this.createTabBarView(windowInfo) + windowInfo.tabBarView = tabBarView + + // Register resize listener once per window (not per tab) + // Capture windowKey to look up fresh windowInfo on each resize + windowInfo.window.on('resize', () => { + const info = this.windows.get(windowKey) + if (info) this.updateViewBounds(info) + }) + + logger.info('Created new window', { windowKey, privateMode }) + } else if (showWindow && !windowInfo.window.isDestroyed()) { + windowInfo.window.show() + } + + this.touchWindow(windowKey) + return windowInfo + } + + private updateViewBounds(windowInfo: WindowInfo) { + if (windowInfo.window.isDestroyed()) return + + const [width, height] = windowInfo.window.getContentSize() + + // Update tab bar bounds + if (windowInfo.tabBarView && !windowInfo.tabBarView.webContents.isDestroyed()) { + windowInfo.tabBarView.setBounds({ x: 0, y: 0, width, height: TAB_BAR_HEIGHT }) + } + + // Update active tab view bounds + if (windowInfo.activeTabId) { + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (activeTab && !activeTab.view.webContents.isDestroyed()) { + activeTab.view.setBounds({ + x: 0, + y: TAB_BAR_HEIGHT, + width, + height: Math.max(0, height - TAB_BAR_HEIGHT) + }) + } + } + } + + /** + * Creates a new tab in the window + * @param privateMode - If true, uses private browsing mode (default: false) + * @param showWindow - If true, shows the browser window (default: false) + * @returns Tab ID and view + */ + public async createTab(privateMode = false, showWindow = false): Promise<{ tabId: string; view: BrowserView }> { + const windowInfo = await this.getOrCreateWindow(privateMode, showWindow) + const tabId = randomUUID() + const partition = this.getPartition(privateMode) + + const view = new BrowserView({ + webPreferences: { + contextIsolation: true, + sandbox: true, + nodeIntegration: false, + devTools: true, + partition + } + }) + + view.webContents.setUserAgent(userAgent) + + const windowKey = windowInfo.windowKey + view.webContents.on('did-start-loading', () => logger.info(`did-start-loading`, { windowKey, tabId })) + view.webContents.on('dom-ready', () => logger.info(`dom-ready`, { windowKey, tabId })) + view.webContents.on('did-finish-load', () => logger.info(`did-finish-load`, { windowKey, tabId })) + view.webContents.on('did-fail-load', (_e, code, desc) => logger.warn('Navigation failed', { code, desc })) + + view.webContents.on('destroyed', () => { + windowInfo.tabs.delete(tabId) + if (windowInfo.activeTabId === tabId) { + windowInfo.activeTabId = windowInfo.tabs.keys().next().value ?? null + if (windowInfo.activeTabId) { + const newActiveTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (newActiveTab && !windowInfo.window.isDestroyed()) { + windowInfo.window.addBrowserView(newActiveTab.view) + this.updateViewBounds(windowInfo) + } + } + } + this.sendTabBarUpdate(windowInfo) + }) + + view.webContents.on('page-title-updated', (_event, title) => { + tabInfo.title = title + this.sendTabBarUpdate(windowInfo) + }) + + view.webContents.on('did-navigate', (_event, url) => { + tabInfo.url = url + this.sendTabBarUpdate(windowInfo) + }) + + view.webContents.on('did-navigate-in-page', (_event, url) => { + tabInfo.url = url + this.sendTabBarUpdate(windowInfo) + }) + + // Handle new window requests (e.g., target="_blank" links) - open in new tab instead + view.webContents.setWindowOpenHandler(({ url }) => { + // Create a new tab and navigate to the URL + this.createTab(privateMode, true) + .then(({ tabId: newTabId }) => { + return this.switchTab(privateMode, newTabId).then(() => { + const newTab = windowInfo.tabs.get(newTabId) + if (newTab && !newTab.view.webContents.isDestroyed()) { + newTab.view.webContents.loadURL(url) + } + }) + }) + .catch((error) => { + logger.warn('Failed to open link in new tab', { error, url }) + }) + return { action: 'deny' } + }) + + const tabInfo: TabInfo = { + id: tabId, + view, + url: '', + title: '', + lastActive: Date.now() + } + + windowInfo.tabs.set(tabId, tabInfo) + + // Set as active tab and add to window + if (!windowInfo.activeTabId || windowInfo.tabs.size === 1) { + windowInfo.activeTabId = tabId + windowInfo.window.addBrowserView(view) + this.updateViewBounds(windowInfo) + } + + this.sendTabBarUpdate(windowInfo) + logger.info('Created new tab', { windowKey, tabId, privateMode }) + return { tabId, view } + } + + /** + * Gets an existing tab or creates a new one + * @param privateMode - Whether to use private browsing mode + * @param tabId - Optional specific tab ID to use + * @param newTab - If true, always create a new tab (useful for parallel requests) + * @param showWindow - If true, shows the browser window (default: false) + */ + private async getTab( + privateMode: boolean, + tabId?: string, + newTab?: boolean, + showWindow = false + ): Promise<{ tabId: string; tab: TabInfo }> { + const windowInfo = await this.getOrCreateWindow(privateMode, showWindow) + + // If newTab is requested, create a fresh tab + if (newTab) { + const { tabId: freshTabId } = await this.createTab(privateMode, showWindow) + const tab = windowInfo.tabs.get(freshTabId) + if (!tab) { + throw new Error(`Tab ${freshTabId} was created but not found - it may have been closed`) + } + return { tabId: freshTabId, tab } + } + + if (tabId) { + const tab = windowInfo.tabs.get(tabId) + if (tab && !tab.view.webContents.isDestroyed()) { + this.touchTab(windowInfo.windowKey, tabId) + return { tabId, tab } + } + } + + // Use active tab or create new one + if (windowInfo.activeTabId) { + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (activeTab && !activeTab.view.webContents.isDestroyed()) { + this.touchTab(windowInfo.windowKey, windowInfo.activeTabId) + return { tabId: windowInfo.activeTabId, tab: activeTab } + } + } + + // Create new tab + const { tabId: newTabId } = await this.createTab(privateMode, showWindow) + const tab = windowInfo.tabs.get(newTabId) + if (!tab) { + throw new Error(`Tab ${newTabId} was created but not found - it may have been closed`) + } + return { tabId: newTabId, tab } + } + + /** + * Opens a URL in a browser window and waits for navigation to complete. + * @param url - The URL to navigate to + * @param timeout - Navigation timeout in milliseconds (default: 10000) + * @param privateMode - If true, uses private browsing mode (default: false) + * @param newTab - If true, always creates a new tab (useful for parallel requests) + * @param showWindow - If true, shows the browser window (default: false) + * @returns Object containing the current URL, page title, and tab ID after navigation + */ + public async open(url: string, timeout = 10000, privateMode = false, newTab = false, showWindow = false) { + const { tabId: actualTabId, tab } = await this.getTab(privateMode, undefined, newTab, showWindow) + const view = tab.view + const windowKey = this.getWindowKey(privateMode) + + logger.info('Loading URL', { url, windowKey, tabId: actualTabId, privateMode }) + const { webContents } = view + this.touchTab(windowKey, actualTabId) + + let resolved = false + let timeoutHandle: ReturnType | undefined + let onFinish: () => void + let onDomReady: () => void + let onFail: (_event: Electron.Event, code: number, desc: string) => void + + const cleanup = () => { + if (timeoutHandle) clearTimeout(timeoutHandle) + webContents.removeListener('did-finish-load', onFinish) + webContents.removeListener('did-fail-load', onFail) + webContents.removeListener('dom-ready', onDomReady) + } + + const loadPromise = new Promise((resolve, reject) => { + onFinish = () => { + if (resolved) return + resolved = true + cleanup() + resolve() + } + onDomReady = () => { + if (resolved) return + resolved = true + cleanup() + resolve() + } + onFail = (_event: Electron.Event, code: number, desc: string) => { + if (resolved) return + resolved = true + cleanup() + reject(new Error(`Navigation failed (${code}): ${desc}`)) + } + webContents.once('did-finish-load', onFinish) + webContents.once('dom-ready', onDomReady) + webContents.once('did-fail-load', onFail) + }) + + const timeoutPromise = new Promise((_, reject) => { + timeoutHandle = setTimeout(() => reject(new Error('Navigation timed out')), timeout) + }) + + try { + await Promise.race([view.webContents.loadURL(url), loadPromise, timeoutPromise]) + } finally { + cleanup() + } + + const currentUrl = webContents.getURL() + const title = await webContents.getTitle() + + // Update tab info + tab.url = currentUrl + tab.title = title + + return { currentUrl, title, tabId: actualTabId } + } + + /** + * Executes JavaScript code in the page context using Chrome DevTools Protocol. + * @param code - JavaScript code to evaluate in the page + * @param timeout - Execution timeout in milliseconds (default: 5000) + * @param privateMode - If true, targets the private browsing window (default: false) + * @param tabId - Optional specific tab ID to target; if omitted, uses the active tab + * @returns The result value from the evaluated code, or null if no value returned + */ + public async execute(code: string, timeout = 5000, privateMode = false, tabId?: string) { + const { tabId: actualTabId, tab } = await this.getTab(privateMode, tabId) + const windowKey = this.getWindowKey(privateMode) + this.touchTab(windowKey, actualTabId) + const dbg = tab.view.webContents.debugger + + await this.ensureDebuggerAttached(dbg, windowKey) + + let timeoutHandle: ReturnType | undefined + const evalPromise = dbg.sendCommand('Runtime.evaluate', { + expression: code, + awaitPromise: true, + returnByValue: true + }) + + try { + const result = await Promise.race([ + evalPromise, + new Promise((_, reject) => { + timeoutHandle = setTimeout(() => reject(new Error('Execution timed out')), timeout) + }) + ]) + + const evalResult = result as any + + if (evalResult?.exceptionDetails) { + const message = evalResult.exceptionDetails.exception?.description || 'Unknown script error' + logger.warn('Runtime.evaluate raised exception', { message }) + throw new Error(message) + } + + const value = evalResult?.result?.value ?? evalResult?.result?.description ?? null + return value + } finally { + if (timeoutHandle) clearTimeout(timeoutHandle) + } + } + + public async reset(privateMode?: boolean, tabId?: string) { + if (privateMode !== undefined && tabId) { + const windowKey = this.getWindowKey(privateMode) + const windowInfo = this.windows.get(windowKey) + if (windowInfo) { + this.closeTabInternal(windowInfo, tabId) + windowInfo.tabs.delete(tabId) + + // If no tabs left, close the window + if (windowInfo.tabs.size === 0) { + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.close() + } + this.windows.delete(windowKey) + logger.info('Browser CDP window closed (last tab closed)', { windowKey, tabId }) + return + } + + if (windowInfo.activeTabId === tabId) { + windowInfo.activeTabId = windowInfo.tabs.keys().next().value ?? null + if (windowInfo.activeTabId) { + const newActiveTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (newActiveTab && !windowInfo.window.isDestroyed()) { + windowInfo.window.addBrowserView(newActiveTab.view) + this.updateViewBounds(windowInfo) + } + } + } + this.sendTabBarUpdate(windowInfo) + } + logger.info('Browser CDP tab reset', { windowKey, tabId }) + return + } + + if (privateMode !== undefined) { + const windowKey = this.getWindowKey(privateMode) + const windowInfo = this.windows.get(windowKey) + if (windowInfo) { + const tabIds = Array.from(windowInfo.tabs.keys()) + for (const tid of tabIds) { + this.closeTabInternal(windowInfo, tid) + } + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.close() + } + } + this.windows.delete(windowKey) + logger.info('Browser CDP window reset', { windowKey, privateMode }) + return + } + + const allWindowInfos = Array.from(this.windows.values()) + for (const windowInfo of allWindowInfos) { + const tabIds = Array.from(windowInfo.tabs.keys()) + for (const tid of tabIds) { + this.closeTabInternal(windowInfo, tid) + } + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.close() + } + } + this.windows.clear() + logger.info('Browser CDP context reset (all windows)') + } + + /** + * Fetches a URL and returns content in the specified format. + * @param url - The URL to fetch + * @param format - Output format: 'html', 'txt', 'markdown', or 'json' (default: 'markdown') + * @param timeout - Navigation timeout in milliseconds (default: 10000) + * @param privateMode - If true, uses private browsing mode (default: false) + * @param newTab - If true, always creates a new tab (useful for parallel requests) + * @param showWindow - If true, shows the browser window (default: false) + * @returns Object with tabId and content in the requested format. For 'json', content is parsed object or { data: rawContent } if parsing fails + */ + public async fetch( + url: string, + format: 'html' | 'txt' | 'markdown' | 'json' = 'markdown', + timeout = 10000, + privateMode = false, + newTab = false, + showWindow = false + ): Promise<{ tabId: string; content: string | object }> { + const { tabId } = await this.open(url, timeout, privateMode, newTab, showWindow) + + const { tab } = await this.getTab(privateMode, tabId, false, showWindow) + const dbg = tab.view.webContents.debugger + const windowKey = this.getWindowKey(privateMode) + + await this.ensureDebuggerAttached(dbg, windowKey) + + let expression: string + if (format === 'json' || format === 'txt') { + expression = 'document.body.innerText' + } else { + expression = 'document.documentElement.outerHTML' + } + + let timeoutHandle: ReturnType | undefined + try { + const result = (await Promise.race([ + dbg.sendCommand('Runtime.evaluate', { + expression, + returnByValue: true + }), + new Promise((_, reject) => { + timeoutHandle = setTimeout(() => reject(new Error('Fetch content timed out')), timeout) + }) + ])) as { result?: { value?: string } } + + const rawContent = result?.result?.value ?? '' + + let content: string | object + if (format === 'markdown') { + content = this.turndownService.turndown(rawContent) + } else if (format === 'json') { + try { + content = JSON.parse(rawContent) + } catch (parseError) { + logger.warn('JSON parse failed, returning raw content', { + url, + contentLength: rawContent.length, + error: parseError + }) + content = { data: rawContent } + } + } else { + content = rawContent + } + + return { tabId, content } + } finally { + if (timeoutHandle) clearTimeout(timeoutHandle) + } + } + + /** + * Lists all tabs in a window + * @param privateMode - If true, lists tabs from private window (default: false) + */ + public async listTabs(privateMode = false): Promise> { + const windowKey = this.getWindowKey(privateMode) + const windowInfo = this.windows.get(windowKey) + if (!windowInfo) return [] + + return Array.from(windowInfo.tabs.values()).map((tab) => ({ + tabId: tab.id, + url: tab.url, + title: tab.title + })) + } + + /** + * Closes a specific tab + * @param privateMode - If true, closes tab from private window (default: false) + * @param tabId - Tab identifier to close + */ + public async closeTab(privateMode: boolean, tabId: string) { + await this.reset(privateMode, tabId) + } + + /** + * Switches the active tab + * @param privateMode - If true, switches tab in private window (default: false) + * @param tabId - Tab identifier to switch to + */ + public async switchTab(privateMode: boolean, tabId: string) { + const windowKey = this.getWindowKey(privateMode) + const windowInfo = this.windows.get(windowKey) + if (!windowInfo) throw new Error(`Window not found for ${privateMode ? 'private' : 'normal'} mode`) + + const tab = windowInfo.tabs.get(tabId) + if (!tab) throw new Error(`Tab ${tabId} not found`) + + // Remove previous active tab view (but NOT the tabBarView) + if (windowInfo.activeTabId && windowInfo.activeTabId !== tabId) { + const prevTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (prevTab && !windowInfo.window.isDestroyed()) { + windowInfo.window.removeBrowserView(prevTab.view) + } + } + + windowInfo.activeTabId = tabId + + // Add the new active tab view + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.addBrowserView(tab.view) + this.updateViewBounds(windowInfo) + } + + this.touchTab(windowKey, tabId) + this.sendTabBarUpdate(windowInfo) + logger.info('Switched active tab', { windowKey, tabId, privateMode }) + } +} diff --git a/src/main/mcpServers/browser/index.ts b/src/main/mcpServers/browser/index.ts new file mode 100644 index 0000000000..fbdb0a0f6e --- /dev/null +++ b/src/main/mcpServers/browser/index.ts @@ -0,0 +1,3 @@ +export { CdpBrowserController } from './controller' +export { BrowserServer } from './server' +export { BrowserServer as default } from './server' diff --git a/src/main/mcpServers/browser/server.ts b/src/main/mcpServers/browser/server.ts new file mode 100644 index 0000000000..3e889a7b66 --- /dev/null +++ b/src/main/mcpServers/browser/server.ts @@ -0,0 +1,50 @@ +import type { Server } from '@modelcontextprotocol/sdk/server/index.js' +import { Server as MCServer } from '@modelcontextprotocol/sdk/server/index.js' +import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js' +import { app } from 'electron' + +import { CdpBrowserController } from './controller' +import { toolDefinitions, toolHandlers } from './tools' + +export class BrowserServer { + public server: Server + private controller = new CdpBrowserController() + + constructor() { + const server = new MCServer( + { + name: '@cherry/browser', + version: '0.1.0' + }, + { + capabilities: { + resources: {}, + tools: {} + } + } + ) + + server.setRequestHandler(ListToolsRequestSchema, async () => { + return { + tools: toolDefinitions + } + }) + + server.setRequestHandler(CallToolRequestSchema, async (request) => { + const { name, arguments: args } = request.params + const handler = toolHandlers[name] + if (!handler) { + throw new Error('Tool not found') + } + return handler(this.controller, args) + }) + + app.on('before-quit', () => { + void this.controller.reset() + }) + + this.server = server + } +} + +export default BrowserServer diff --git a/src/main/mcpServers/browser/tabbar-html.ts b/src/main/mcpServers/browser/tabbar-html.ts new file mode 100644 index 0000000000..4a1bec0e0d --- /dev/null +++ b/src/main/mcpServers/browser/tabbar-html.ts @@ -0,0 +1,567 @@ +export const TAB_BAR_HTML = ` + + + + + + +
+
+
+ +
+
+ +
+ + + +
+
+
+ + + +
+ +
+
+ + +` diff --git a/src/main/mcpServers/browser/tools/execute.ts b/src/main/mcpServers/browser/tools/execute.ts new file mode 100644 index 0000000000..09cd79f2d1 --- /dev/null +++ b/src/main/mcpServers/browser/tools/execute.ts @@ -0,0 +1,52 @@ +import * as z from 'zod' + +import type { CdpBrowserController } from '../controller' +import { logger } from '../types' +import { errorResponse, successResponse } from './utils' + +export const ExecuteSchema = z.object({ + code: z.string().describe('JavaScript code to run in page context'), + timeout: z.number().default(5000).describe('Execution timeout in ms (default: 5000)'), + privateMode: z.boolean().optional().describe('Target private session (default: false)'), + tabId: z.string().optional().describe('Target specific tab by ID') +}) + +export const executeToolDefinition = { + name: 'execute', + description: + 'Run JavaScript in the currently open page. Use after open to: click elements, fill forms, extract content (document.body.innerText), or interact with the page. The page must be opened first with open or fetch.', + inputSchema: { + type: 'object', + properties: { + code: { + type: 'string', + description: + 'JavaScript to evaluate. Examples: document.body.innerText (get text), document.querySelector("button").click() (click), document.title (get title)' + }, + timeout: { + type: 'number', + description: 'Execution timeout in ms (default: 5000)' + }, + privateMode: { + type: 'boolean', + description: 'Target private session (default: false)' + }, + tabId: { + type: 'string', + description: 'Target specific tab by ID (from open response)' + } + }, + required: ['code'] + } +} + +export async function handleExecute(controller: CdpBrowserController, args: unknown) { + const { code, timeout, privateMode, tabId } = ExecuteSchema.parse(args) + try { + const value = await controller.execute(code, timeout, privateMode ?? false, tabId) + return successResponse(typeof value === 'string' ? value : JSON.stringify(value)) + } catch (error) { + logger.error('Execute failed', { error, code: code.slice(0, 100), privateMode, tabId }) + return errorResponse(error as Error) + } +} diff --git a/src/main/mcpServers/browser/tools/index.ts b/src/main/mcpServers/browser/tools/index.ts new file mode 100644 index 0000000000..5ba6fcae6d --- /dev/null +++ b/src/main/mcpServers/browser/tools/index.ts @@ -0,0 +1,22 @@ +export { ExecuteSchema, executeToolDefinition, handleExecute } from './execute' +export { handleOpen, OpenSchema, openToolDefinition } from './open' +export { handleReset, resetToolDefinition } from './reset' + +import type { CdpBrowserController } from '../controller' +import { executeToolDefinition, handleExecute } from './execute' +import { handleOpen, openToolDefinition } from './open' +import { handleReset, resetToolDefinition } from './reset' + +export const toolDefinitions = [openToolDefinition, executeToolDefinition, resetToolDefinition] + +export const toolHandlers: Record< + string, + ( + controller: CdpBrowserController, + args: unknown + ) => Promise<{ content: { type: string; text: string }[]; isError: boolean }> +> = { + open: handleOpen, + execute: handleExecute, + reset: handleReset +} diff --git a/src/main/mcpServers/browser/tools/open.ts b/src/main/mcpServers/browser/tools/open.ts new file mode 100644 index 0000000000..6ea9ec9e48 --- /dev/null +++ b/src/main/mcpServers/browser/tools/open.ts @@ -0,0 +1,81 @@ +import * as z from 'zod' + +import type { CdpBrowserController } from '../controller' +import { logger } from '../types' +import { errorResponse, successResponse } from './utils' + +export const OpenSchema = z.object({ + url: z.url().describe('URL to navigate to'), + format: z + .enum(['html', 'txt', 'markdown', 'json']) + .optional() + .describe('If set, return page content in this format. If not set, just open the page and return tabId.'), + timeout: z.number().optional().describe('Navigation timeout in ms (default: 10000)'), + privateMode: z.boolean().optional().describe('Use incognito mode, no data persisted (default: false)'), + newTab: z.boolean().optional().describe('Open in new tab, required for parallel requests (default: false)'), + showWindow: z.boolean().optional().default(true).describe('Show browser window (default: true)') +}) + +export const openToolDefinition = { + name: 'open', + description: + 'Navigate to a URL in a browser window. If format is specified, returns { tabId, content } with page content in that format. Otherwise, returns { currentUrl, title, tabId } for subsequent operations with execute tool. Set newTab=true when opening multiple URLs in parallel.', + inputSchema: { + type: 'object', + properties: { + url: { + type: 'string', + description: 'URL to navigate to' + }, + format: { + type: 'string', + enum: ['html', 'txt', 'markdown', 'json'], + description: 'If set, return page content in this format. If not set, just open the page and return tabId.' + }, + timeout: { + type: 'number', + description: 'Navigation timeout in ms (default: 10000)' + }, + privateMode: { + type: 'boolean', + description: 'Use incognito mode, no data persisted (default: false)' + }, + newTab: { + type: 'boolean', + description: 'Open in new tab, required for parallel requests (default: false)' + }, + showWindow: { + type: 'boolean', + description: 'Show browser window (default: true)' + } + }, + required: ['url'] + } +} + +export async function handleOpen(controller: CdpBrowserController, args: unknown) { + try { + const { url, format, timeout, privateMode, newTab, showWindow } = OpenSchema.parse(args) + + if (format) { + const { tabId, content } = await controller.fetch( + url, + format, + timeout ?? 10000, + privateMode ?? false, + newTab ?? false, + showWindow + ) + return successResponse(JSON.stringify({ tabId, content })) + } else { + const res = await controller.open(url, timeout ?? 10000, privateMode ?? false, newTab ?? false, showWindow) + return successResponse(JSON.stringify(res)) + } + } catch (error) { + logger.error('Open failed', { + error, + url: args && typeof args === 'object' && 'url' in args ? args.url : undefined + }) + return errorResponse(error instanceof Error ? error : String(error)) + } +} diff --git a/src/main/mcpServers/browser/tools/reset.ts b/src/main/mcpServers/browser/tools/reset.ts new file mode 100644 index 0000000000..fe67b74b1d --- /dev/null +++ b/src/main/mcpServers/browser/tools/reset.ts @@ -0,0 +1,43 @@ +import * as z from 'zod' + +import type { CdpBrowserController } from '../controller' +import { logger } from '../types' +import { errorResponse, successResponse } from './utils' + +export const ResetSchema = z.object({ + privateMode: z.boolean().optional().describe('true=private window, false=normal window, omit=all windows'), + tabId: z.string().optional().describe('Close specific tab only (requires privateMode)') +}) + +export const resetToolDefinition = { + name: 'reset', + description: + 'Close browser windows and clear state. Call when done browsing to free resources. Omit all parameters to close everything.', + inputSchema: { + type: 'object', + properties: { + privateMode: { + type: 'boolean', + description: 'true=reset private window only, false=reset normal window only, omit=reset all' + }, + tabId: { + type: 'string', + description: 'Close specific tab only (requires privateMode to be set)' + } + } + } +} + +export async function handleReset(controller: CdpBrowserController, args: unknown) { + try { + const { privateMode, tabId } = ResetSchema.parse(args) + await controller.reset(privateMode, tabId) + return successResponse('reset') + } catch (error) { + logger.error('Reset failed', { + error, + privateMode: args && typeof args === 'object' && 'privateMode' in args ? args.privateMode : undefined + }) + return errorResponse(error instanceof Error ? error : String(error)) + } +} diff --git a/src/main/mcpServers/browser/tools/utils.ts b/src/main/mcpServers/browser/tools/utils.ts new file mode 100644 index 0000000000..f5272ac81c --- /dev/null +++ b/src/main/mcpServers/browser/tools/utils.ts @@ -0,0 +1,14 @@ +export function successResponse(text: string) { + return { + content: [{ type: 'text', text }], + isError: false + } +} + +export function errorResponse(error: Error | string) { + const message = error instanceof Error ? error.message : error + return { + content: [{ type: 'text', text: message }], + isError: true + } +} diff --git a/src/main/mcpServers/browser/types.ts b/src/main/mcpServers/browser/types.ts new file mode 100644 index 0000000000..a59fe59665 --- /dev/null +++ b/src/main/mcpServers/browser/types.ts @@ -0,0 +1,24 @@ +import { loggerService } from '@logger' +import type { BrowserView, BrowserWindow } from 'electron' + +export const logger = loggerService.withContext('MCPBrowserCDP') +export const userAgent = + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36' + +export interface TabInfo { + id: string + view: BrowserView + url: string + title: string + lastActive: number +} + +export interface WindowInfo { + windowKey: string + privateMode: boolean + window: BrowserWindow + tabs: Map + activeTabId: string | null + lastActive: number + tabBarView?: BrowserView +} diff --git a/src/main/mcpServers/factory.ts b/src/main/mcpServers/factory.ts index 2323701e49..909901c1c8 100644 --- a/src/main/mcpServers/factory.ts +++ b/src/main/mcpServers/factory.ts @@ -4,6 +4,7 @@ import type { BuiltinMCPServerName } from '@types' import { BuiltinMCPServerNames } from '@types' import BraveSearchServer from './brave-search' +import BrowserServer from './browser' import DiDiMcpServer from './didi-mcp' import DifyKnowledgeServer from './dify-knowledge' import FetchServer from './fetch' @@ -35,7 +36,7 @@ export function createInMemoryMCPServer( return new FetchServer().server } case BuiltinMCPServerNames.filesystem: { - return new FileSystemServer(args).server + return new FileSystemServer(envs.WORKSPACE_ROOT).server } case BuiltinMCPServerNames.difyKnowledge: { const difyKey = envs.DIFY_KEY @@ -48,6 +49,9 @@ export function createInMemoryMCPServer( const apiKey = envs.DIDI_API_KEY return new DiDiMcpServer(apiKey).server } + case BuiltinMCPServerNames.browser: { + return new BrowserServer().server + } default: throw new Error(`Unknown in-memory MCP server: ${name}`) } diff --git a/src/main/mcpServers/filesystem.ts b/src/main/mcpServers/filesystem.ts deleted file mode 100644 index ba10783881..0000000000 --- a/src/main/mcpServers/filesystem.ts +++ /dev/null @@ -1,652 +0,0 @@ -// port https://github.com/modelcontextprotocol/servers/blob/main/src/filesystem/index.ts - -import { loggerService } from '@logger' -import { Server } from '@modelcontextprotocol/sdk/server/index.js' -import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js' -import { createTwoFilesPatch } from 'diff' -import fs from 'fs/promises' -import { minimatch } from 'minimatch' -import os from 'os' -import path from 'path' -import * as z from 'zod' - -const logger = loggerService.withContext('MCP:FileSystemServer') - -// Normalize all paths consistently -function normalizePath(p: string): string { - return path.normalize(p) -} - -function expandHome(filepath: string): string { - if (filepath.startsWith('~/') || filepath === '~') { - return path.join(os.homedir(), filepath.slice(1)) - } - return filepath -} - -// Security utilities -async function validatePath(allowedDirectories: string[], requestedPath: string): Promise { - const expandedPath = expandHome(requestedPath) - const absolute = path.isAbsolute(expandedPath) - ? path.resolve(expandedPath) - : path.resolve(process.cwd(), expandedPath) - - const normalizedRequested = normalizePath(absolute) - - // Check if path is within allowed directories - const isAllowed = allowedDirectories.some((dir) => normalizedRequested.startsWith(dir)) - if (!isAllowed) { - throw new Error( - `Access denied - path outside allowed directories: ${absolute} not in ${allowedDirectories.join(', ')}` - ) - } - - // Handle symlinks by checking their real path - try { - const realPath = await fs.realpath(absolute) - const normalizedReal = normalizePath(realPath) - const isRealPathAllowed = allowedDirectories.some((dir) => normalizedReal.startsWith(dir)) - if (!isRealPathAllowed) { - throw new Error('Access denied - symlink target outside allowed directories') - } - return realPath - } catch (error) { - // For new files that don't exist yet, verify parent directory - const parentDir = path.dirname(absolute) - try { - const realParentPath = await fs.realpath(parentDir) - const normalizedParent = normalizePath(realParentPath) - const isParentAllowed = allowedDirectories.some((dir) => normalizedParent.startsWith(dir)) - if (!isParentAllowed) { - throw new Error('Access denied - parent directory outside allowed directories') - } - return absolute - } catch { - throw new Error(`Parent directory does not exist: ${parentDir}`) - } - } -} - -// Schema definitions -const ReadFileArgsSchema = z.object({ - path: z.string() -}) - -const ReadMultipleFilesArgsSchema = z.object({ - paths: z.array(z.string()) -}) - -const WriteFileArgsSchema = z.object({ - path: z.string(), - content: z.string() -}) - -const EditOperation = z.object({ - oldText: z.string().describe('Text to search for - must match exactly'), - newText: z.string().describe('Text to replace with') -}) - -const EditFileArgsSchema = z.object({ - path: z.string(), - edits: z.array(EditOperation), - dryRun: z.boolean().default(false).describe('Preview changes using git-style diff format') -}) - -const CreateDirectoryArgsSchema = z.object({ - path: z.string() -}) - -const ListDirectoryArgsSchema = z.object({ - path: z.string() -}) - -const DirectoryTreeArgsSchema = z.object({ - path: z.string() -}) - -const MoveFileArgsSchema = z.object({ - source: z.string(), - destination: z.string() -}) - -const SearchFilesArgsSchema = z.object({ - path: z.string(), - pattern: z.string(), - excludePatterns: z.array(z.string()).optional().default([]) -}) - -const GetFileInfoArgsSchema = z.object({ - path: z.string() -}) - -interface FileInfo { - size: number - created: Date - modified: Date - accessed: Date - isDirectory: boolean - isFile: boolean - permissions: string -} - -// Tool implementations -async function getFileStats(filePath: string): Promise { - const stats = await fs.stat(filePath) - return { - size: stats.size, - created: stats.birthtime, - modified: stats.mtime, - accessed: stats.atime, - isDirectory: stats.isDirectory(), - isFile: stats.isFile(), - permissions: stats.mode.toString(8).slice(-3) - } -} - -async function searchFiles( - allowedDirectories: string[], - rootPath: string, - pattern: string, - excludePatterns: string[] = [] -): Promise { - const results: string[] = [] - - async function search(currentPath: string) { - const entries = await fs.readdir(currentPath, { withFileTypes: true }) - - for (const entry of entries) { - const fullPath = path.join(currentPath, entry.name) - - try { - // Validate each path before processing - await validatePath(allowedDirectories, fullPath) - - // Check if path matches any exclude pattern - const relativePath = path.relative(rootPath, fullPath) - const shouldExclude = excludePatterns.some((pattern) => { - const globPattern = pattern.includes('*') ? pattern : `**/${pattern}/**` - return minimatch(relativePath, globPattern, { dot: true }) - }) - - if (shouldExclude) { - continue - } - - if (entry.name.toLowerCase().includes(pattern.toLowerCase())) { - results.push(fullPath) - } - - if (entry.isDirectory()) { - await search(fullPath) - } - } catch (error) { - // Skip invalid paths during search - } - } - } - - await search(rootPath) - return results -} - -// file editing and diffing utilities -function normalizeLineEndings(text: string): string { - return text.replace(/\r\n/g, '\n') -} - -function createUnifiedDiff(originalContent: string, newContent: string, filepath: string = 'file'): string { - // Ensure consistent line endings for diff - const normalizedOriginal = normalizeLineEndings(originalContent) - const normalizedNew = normalizeLineEndings(newContent) - - return createTwoFilesPatch(filepath, filepath, normalizedOriginal, normalizedNew, 'original', 'modified') -} - -async function applyFileEdits( - filePath: string, - edits: Array<{ oldText: string; newText: string }>, - dryRun = false -): Promise { - // Read file content and normalize line endings - const content = normalizeLineEndings(await fs.readFile(filePath, 'utf-8')) - - // Apply edits sequentially - let modifiedContent = content - for (const edit of edits) { - const normalizedOld = normalizeLineEndings(edit.oldText) - const normalizedNew = normalizeLineEndings(edit.newText) - - // If exact match exists, use it - if (modifiedContent.includes(normalizedOld)) { - modifiedContent = modifiedContent.replace(normalizedOld, normalizedNew) - continue - } - - // Otherwise, try line-by-line matching with flexibility for whitespace - const oldLines = normalizedOld.split('\n') - const contentLines = modifiedContent.split('\n') - let matchFound = false - - for (let i = 0; i <= contentLines.length - oldLines.length; i++) { - const potentialMatch = contentLines.slice(i, i + oldLines.length) - - // Compare lines with normalized whitespace - const isMatch = oldLines.every((oldLine, j) => { - const contentLine = potentialMatch[j] - return oldLine.trim() === contentLine.trim() - }) - - if (isMatch) { - // Preserve original indentation of first line - const originalIndent = contentLines[i].match(/^\s*/)?.[0] || '' - const newLines = normalizedNew.split('\n').map((line, j) => { - if (j === 0) return originalIndent + line.trimStart() - // For subsequent lines, try to preserve relative indentation - const oldIndent = oldLines[j]?.match(/^\s*/)?.[0] || '' - const newIndent = line.match(/^\s*/)?.[0] || '' - if (oldIndent && newIndent) { - const relativeIndent = newIndent.length - oldIndent.length - return originalIndent + ' '.repeat(Math.max(0, relativeIndent)) + line.trimStart() - } - return line - }) - - contentLines.splice(i, oldLines.length, ...newLines) - modifiedContent = contentLines.join('\n') - matchFound = true - break - } - } - - if (!matchFound) { - throw new Error(`Could not find exact match for edit:\n${edit.oldText}`) - } - } - - // Create unified diff - const diff = createUnifiedDiff(content, modifiedContent, filePath) - - // Format diff with appropriate number of backticks - let numBackticks = 3 - while (diff.includes('`'.repeat(numBackticks))) { - numBackticks++ - } - const formattedDiff = `${'`'.repeat(numBackticks)}diff\n${diff}${'`'.repeat(numBackticks)}\n\n` - - if (!dryRun) { - await fs.writeFile(filePath, modifiedContent, 'utf-8') - } - - return formattedDiff -} - -class FileSystemServer { - public server: Server - private allowedDirectories: string[] - constructor(allowedDirs: string[]) { - if (!Array.isArray(allowedDirs) || allowedDirs.length === 0) { - throw new Error('No allowed directories provided, please specify at least one directory in args') - } - - this.allowedDirectories = allowedDirs.map((dir) => normalizePath(path.resolve(expandHome(dir)))) - - // Validate that all directories exist and are accessible - this.validateDirs().catch((error) => { - logger.error('Error validating allowed directories:', error) - throw new Error(`Error validating allowed directories: ${error}`) - }) - - this.server = new Server( - { - name: 'secure-filesystem-server', - version: '0.2.0' - }, - { - capabilities: { - tools: {} - } - } - ) - this.initialize() - } - - async validateDirs() { - // Validate that all directories exist and are accessible - await Promise.all( - this.allowedDirectories.map(async (dir) => { - try { - const stats = await fs.stat(expandHome(dir)) - if (!stats.isDirectory()) { - logger.error(`Error: ${dir} is not a directory`) - throw new Error(`Error: ${dir} is not a directory`) - } - } catch (error: any) { - logger.error(`Error accessing directory ${dir}:`, error) - throw new Error(`Error accessing directory ${dir}:`, error) - } - }) - ) - } - - initialize() { - // Tool handlers - this.server.setRequestHandler(ListToolsRequestSchema, async () => { - return { - tools: [ - { - name: 'read_file', - description: - 'Read the complete contents of a file from the file system. ' + - 'Handles various text encodings and provides detailed error messages ' + - 'if the file cannot be read. Use this tool when you need to examine ' + - 'the contents of a single file. Only works within allowed directories.', - inputSchema: z.toJSONSchema(ReadFileArgsSchema) - }, - { - name: 'read_multiple_files', - description: - 'Read the contents of multiple files simultaneously. This is more ' + - 'efficient than reading files one by one when you need to analyze ' + - "or compare multiple files. Each file's content is returned with its " + - "path as a reference. Failed reads for individual files won't stop " + - 'the entire operation. Only works within allowed directories.', - inputSchema: z.toJSONSchema(ReadMultipleFilesArgsSchema) - }, - { - name: 'write_file', - description: - 'Create a new file or completely overwrite an existing file with new content. ' + - 'Use with caution as it will overwrite existing files without warning. ' + - 'Handles text content with proper encoding. Only works within allowed directories.', - inputSchema: z.toJSONSchema(WriteFileArgsSchema) - }, - { - name: 'edit_file', - description: - 'Make line-based edits to a text file. Each edit replaces exact line sequences ' + - 'with new content. Returns a git-style diff showing the changes made. ' + - 'Only works within allowed directories.', - inputSchema: z.toJSONSchema(EditFileArgsSchema) - }, - { - name: 'create_directory', - description: - 'Create a new directory or ensure a directory exists. Can create multiple ' + - 'nested directories in one operation. If the directory already exists, ' + - 'this operation will succeed silently. Perfect for setting up directory ' + - 'structures for projects or ensuring required paths exist. Only works within allowed directories.', - inputSchema: z.toJSONSchema(CreateDirectoryArgsSchema) - }, - { - name: 'list_directory', - description: - 'Get a detailed listing of all files and directories in a specified path. ' + - 'Results clearly distinguish between files and directories with [FILE] and [DIR] ' + - 'prefixes. This tool is essential for understanding directory structure and ' + - 'finding specific files within a directory. Only works within allowed directories.', - inputSchema: z.toJSONSchema(ListDirectoryArgsSchema) - }, - { - name: 'directory_tree', - description: - 'Get a recursive tree view of files and directories as a JSON structure. ' + - "Each entry includes 'name', 'type' (file/directory), and 'children' for directories. " + - 'Files have no children array, while directories always have a children array (which may be empty). ' + - 'The output is formatted with 2-space indentation for readability. Only works within allowed directories.', - inputSchema: z.toJSONSchema(DirectoryTreeArgsSchema) - }, - { - name: 'move_file', - description: - 'Move or rename files and directories. Can move files between directories ' + - 'and rename them in a single operation. If the destination exists, the ' + - 'operation will fail. Works across different directories and can be used ' + - 'for simple renaming within the same directory. Both source and destination must be within allowed directories.', - inputSchema: z.toJSONSchema(MoveFileArgsSchema) - }, - { - name: 'search_files', - description: - 'Recursively search for files and directories matching a pattern. ' + - 'Searches through all subdirectories from the starting path. The search ' + - 'is case-insensitive and matches partial names. Returns full paths to all ' + - "matching items. Great for finding files when you don't know their exact location. " + - 'Only searches within allowed directories.', - inputSchema: z.toJSONSchema(SearchFilesArgsSchema) - }, - { - name: 'get_file_info', - description: - 'Retrieve detailed metadata about a file or directory. Returns comprehensive ' + - 'information including size, creation time, last modified time, permissions, ' + - 'and type. This tool is perfect for understanding file characteristics ' + - 'without reading the actual content. Only works within allowed directories.', - inputSchema: z.toJSONSchema(GetFileInfoArgsSchema) - }, - { - name: 'list_allowed_directories', - description: - 'Returns the list of directories that this server is allowed to access. ' + - 'Use this to understand which directories are available before trying to access files.', - inputSchema: { - type: 'object', - properties: {}, - required: [] - } - } - ] - } - }) - - this.server.setRequestHandler(CallToolRequestSchema, async (request) => { - try { - const { name, arguments: args } = request.params - - switch (name) { - case 'read_file': { - const parsed = ReadFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for read_file: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const content = await fs.readFile(validPath, 'utf-8') - return { - content: [{ type: 'text', text: content }] - } - } - - case 'read_multiple_files': { - const parsed = ReadMultipleFilesArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for read_multiple_files: ${parsed.error}`) - } - const results = await Promise.all( - parsed.data.paths.map(async (filePath: string) => { - try { - const validPath = await validatePath(this.allowedDirectories, filePath) - const content = await fs.readFile(validPath, 'utf-8') - return `${filePath}:\n${content}\n` - } catch (error) { - const errorMessage = error instanceof Error ? error.message : String(error) - return `${filePath}: Error - ${errorMessage}` - } - }) - ) - return { - content: [{ type: 'text', text: results.join('\n---\n') }] - } - } - - case 'write_file': { - const parsed = WriteFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for write_file: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - await fs.writeFile(validPath, parsed.data.content, 'utf-8') - return { - content: [{ type: 'text', text: `Successfully wrote to ${parsed.data.path}` }] - } - } - - case 'edit_file': { - const parsed = EditFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for edit_file: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const result = await applyFileEdits(validPath, parsed.data.edits, parsed.data.dryRun) - return { - content: [{ type: 'text', text: result }] - } - } - - case 'create_directory': { - const parsed = CreateDirectoryArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for create_directory: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - await fs.mkdir(validPath, { recursive: true }) - return { - content: [{ type: 'text', text: `Successfully created directory ${parsed.data.path}` }] - } - } - - case 'list_directory': { - const parsed = ListDirectoryArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for list_directory: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const entries = await fs.readdir(validPath, { withFileTypes: true }) - const formatted = entries - .map((entry) => `${entry.isDirectory() ? '[DIR]' : '[FILE]'} ${entry.name}`) - .join('\n') - return { - content: [{ type: 'text', text: formatted }] - } - } - - case 'directory_tree': { - const parsed = DirectoryTreeArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for directory_tree: ${parsed.error}`) - } - - interface TreeEntry { - name: string - type: 'file' | 'directory' - children?: TreeEntry[] - } - - async function buildTree(allowedDirectories: string[], currentPath: string): Promise { - const validPath = await validatePath(allowedDirectories, currentPath) - const entries = await fs.readdir(validPath, { withFileTypes: true }) - const result: TreeEntry[] = [] - - for (const entry of entries) { - const entryData: TreeEntry = { - name: entry.name, - type: entry.isDirectory() ? 'directory' : 'file' - } - - if (entry.isDirectory()) { - const subPath = path.join(currentPath, entry.name) - entryData.children = await buildTree(allowedDirectories, subPath) - } - - result.push(entryData) - } - - return result - } - - const treeData = await buildTree(this.allowedDirectories, parsed.data.path) - return { - content: [ - { - type: 'text', - text: JSON.stringify(treeData, null, 2) - } - ] - } - } - - case 'move_file': { - const parsed = MoveFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for move_file: ${parsed.error}`) - } - const validSourcePath = await validatePath(this.allowedDirectories, parsed.data.source) - const validDestPath = await validatePath(this.allowedDirectories, parsed.data.destination) - await fs.rename(validSourcePath, validDestPath) - return { - content: [ - { type: 'text', text: `Successfully moved ${parsed.data.source} to ${parsed.data.destination}` } - ] - } - } - - case 'search_files': { - const parsed = SearchFilesArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for search_files: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const results = await searchFiles( - this.allowedDirectories, - validPath, - parsed.data.pattern, - parsed.data.excludePatterns - ) - return { - content: [{ type: 'text', text: results.length > 0 ? results.join('\n') : 'No matches found' }] - } - } - - case 'get_file_info': { - const parsed = GetFileInfoArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for get_file_info: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const info = await getFileStats(validPath) - return { - content: [ - { - type: 'text', - text: Object.entries(info) - .map(([key, value]) => `${key}: ${value}`) - .join('\n') - } - ] - } - } - - case 'list_allowed_directories': { - return { - content: [ - { - type: 'text', - text: `Allowed directories:\n${this.allowedDirectories.join('\n')}` - } - ] - } - } - - default: - throw new Error(`Unknown tool: ${name}`) - } - } catch (error) { - const errorMessage = error instanceof Error ? error.message : String(error) - return { - content: [{ type: 'text', text: `Error: ${errorMessage}` }], - isError: true - } - } - }) - } -} - -export default FileSystemServer diff --git a/src/main/mcpServers/filesystem/index.ts b/src/main/mcpServers/filesystem/index.ts new file mode 100644 index 0000000000..cec4c31cdf --- /dev/null +++ b/src/main/mcpServers/filesystem/index.ts @@ -0,0 +1,2 @@ +// Re-export FileSystemServer to maintain existing import pattern +export { default, FileSystemServer } from './server' diff --git a/src/main/mcpServers/filesystem/server.ts b/src/main/mcpServers/filesystem/server.ts new file mode 100644 index 0000000000..164ba0c9c4 --- /dev/null +++ b/src/main/mcpServers/filesystem/server.ts @@ -0,0 +1,118 @@ +import { Server } from '@modelcontextprotocol/sdk/server/index.js' +import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js' +import { app } from 'electron' +import fs from 'fs/promises' +import path from 'path' + +import { + deleteToolDefinition, + editToolDefinition, + globToolDefinition, + grepToolDefinition, + handleDeleteTool, + handleEditTool, + handleGlobTool, + handleGrepTool, + handleLsTool, + handleReadTool, + handleWriteTool, + lsToolDefinition, + readToolDefinition, + writeToolDefinition +} from './tools' +import { logger } from './types' + +export class FileSystemServer { + public server: Server + private baseDir: string + + constructor(baseDir?: string) { + if (baseDir && path.isAbsolute(baseDir)) { + this.baseDir = baseDir + logger.info(`Using provided baseDir for filesystem MCP: ${baseDir}`) + } else { + const userData = app.getPath('userData') + this.baseDir = path.join(userData, 'Data', 'Workspace') + logger.info(`Using default workspace for filesystem MCP baseDir: ${this.baseDir}`) + } + + this.server = new Server( + { + name: 'filesystem-server', + version: '2.0.0' + }, + { + capabilities: { + tools: {} + } + } + ) + + this.initialize() + } + + async initialize() { + try { + await fs.mkdir(this.baseDir, { recursive: true }) + } catch (error) { + logger.error('Failed to create filesystem MCP baseDir', { error, baseDir: this.baseDir }) + } + + // Register tool list handler + this.server.setRequestHandler(ListToolsRequestSchema, async () => { + return { + tools: [ + globToolDefinition, + lsToolDefinition, + grepToolDefinition, + readToolDefinition, + editToolDefinition, + writeToolDefinition, + deleteToolDefinition + ] + } + }) + + // Register tool call handler + this.server.setRequestHandler(CallToolRequestSchema, async (request) => { + try { + const { name, arguments: args } = request.params + + switch (name) { + case 'glob': + return await handleGlobTool(args, this.baseDir) + + case 'ls': + return await handleLsTool(args, this.baseDir) + + case 'grep': + return await handleGrepTool(args, this.baseDir) + + case 'read': + return await handleReadTool(args, this.baseDir) + + case 'edit': + return await handleEditTool(args, this.baseDir) + + case 'write': + return await handleWriteTool(args, this.baseDir) + + case 'delete': + return await handleDeleteTool(args, this.baseDir) + + default: + throw new Error(`Unknown tool: ${name}`) + } + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + logger.error(`Tool execution error for ${request.params.name}:`, { error }) + return { + content: [{ type: 'text', text: `Error: ${errorMessage}` }], + isError: true + } + } + }) + } +} + +export default FileSystemServer diff --git a/src/main/mcpServers/filesystem/tools/delete.ts b/src/main/mcpServers/filesystem/tools/delete.ts new file mode 100644 index 0000000000..83becc4f17 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/delete.ts @@ -0,0 +1,93 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { logger, validatePath } from '../types' + +// Schema definition +export const DeleteToolSchema = z.object({ + path: z.string().describe('The path to the file or directory to delete'), + recursive: z.boolean().optional().describe('For directories, whether to delete recursively (default: false)') +}) + +// Tool definition with detailed description +export const deleteToolDefinition = { + name: 'delete', + description: `Deletes a file or directory from the filesystem. + +CAUTION: This operation cannot be undone! + +- For files: simply provide the path +- For empty directories: provide the path +- For non-empty directories: set recursive=true +- The path must be an absolute path, not a relative path +- Always verify the path before deleting to avoid data loss`, + inputSchema: z.toJSONSchema(DeleteToolSchema) +} + +// Handler implementation +export async function handleDeleteTool(args: unknown, baseDir: string) { + const parsed = DeleteToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for delete: ${parsed.error}`) + } + + const targetPath = parsed.data.path + const validPath = await validatePath(targetPath, baseDir) + const recursive = parsed.data.recursive || false + + // Check if path exists and get stats + let stats + try { + stats = await fs.stat(validPath) + } catch (error: any) { + if (error.code === 'ENOENT') { + throw new Error(`Path not found: ${targetPath}`) + } + throw error + } + + const isDirectory = stats.isDirectory() + const relativePath = path.relative(baseDir, validPath) + + // Perform deletion + try { + if (isDirectory) { + if (recursive) { + // Delete directory recursively + await fs.rm(validPath, { recursive: true, force: true }) + } else { + // Try to delete empty directory + await fs.rmdir(validPath) + } + } else { + // Delete file + await fs.unlink(validPath) + } + } catch (error: any) { + if (error.code === 'ENOTEMPTY') { + throw new Error(`Directory not empty: ${targetPath}. Use recursive=true to delete non-empty directories.`) + } + throw new Error(`Failed to delete: ${error.message}`) + } + + // Log the operation + logger.info('Path deleted', { + path: validPath, + type: isDirectory ? 'directory' : 'file', + recursive: isDirectory ? recursive : undefined + }) + + // Format output + const itemType = isDirectory ? 'Directory' : 'File' + const recursiveNote = isDirectory && recursive ? ' (recursive)' : '' + + return { + content: [ + { + type: 'text', + text: `${itemType} deleted${recursiveNote}: ${relativePath}` + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/edit.ts b/src/main/mcpServers/filesystem/tools/edit.ts new file mode 100644 index 0000000000..c1a0e637ce --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/edit.ts @@ -0,0 +1,130 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { logger, replaceWithFuzzyMatch, validatePath } from '../types' + +// Schema definition +export const EditToolSchema = z.object({ + file_path: z.string().describe('The path to the file to modify'), + old_string: z.string().describe('The text to replace'), + new_string: z.string().describe('The text to replace it with'), + replace_all: z.boolean().optional().default(false).describe('Replace all occurrences of old_string (default false)') +}) + +// Tool definition with detailed description +export const editToolDefinition = { + name: 'edit', + description: `Performs exact string replacements in files. + +- You must use the 'read' tool at least once before editing +- The file_path must be an absolute path, not a relative path +- Preserve exact indentation from read output (after the line number prefix) +- Never include line number prefixes in old_string or new_string +- ALWAYS prefer editing existing files over creating new ones +- The edit will FAIL if old_string is not found in the file +- The edit will FAIL if old_string appears multiple times (provide more context or use replace_all) +- The edit will FAIL if old_string equals new_string +- Use replace_all to rename variables or replace all occurrences`, + inputSchema: z.toJSONSchema(EditToolSchema) +} + +// Handler implementation +export async function handleEditTool(args: unknown, baseDir: string) { + const parsed = EditToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for edit: ${parsed.error}`) + } + + const { file_path: filePath, old_string: oldString, new_string: newString, replace_all: replaceAll } = parsed.data + + // Validate path + const validPath = await validatePath(filePath, baseDir) + + // Check if file exists + try { + const stats = await fs.stat(validPath) + if (!stats.isFile()) { + throw new Error(`Path is not a file: ${filePath}`) + } + } catch (error: any) { + if (error.code === 'ENOENT') { + // If old_string is empty, this is a create new file operation + if (oldString === '') { + // Create parent directory if needed + const parentDir = path.dirname(validPath) + await fs.mkdir(parentDir, { recursive: true }) + + // Write the new content + await fs.writeFile(validPath, newString, 'utf-8') + + logger.info('File created', { path: validPath }) + + const relativePath = path.relative(baseDir, validPath) + return { + content: [ + { + type: 'text', + text: `Created new file: ${relativePath}\nLines: ${newString.split('\n').length}` + } + ] + } + } + throw new Error(`File not found: ${filePath}`) + } + throw error + } + + // Read current content + const content = await fs.readFile(validPath, 'utf-8') + + // Handle special case: old_string is empty (create file with content) + if (oldString === '') { + await fs.writeFile(validPath, newString, 'utf-8') + + logger.info('File overwritten', { path: validPath }) + + const relativePath = path.relative(baseDir, validPath) + return { + content: [ + { + type: 'text', + text: `Overwrote file: ${relativePath}\nLines: ${newString.split('\n').length}` + } + ] + } + } + + // Perform the replacement with fuzzy matching + const newContent = replaceWithFuzzyMatch(content, oldString, newString, replaceAll) + + // Write the modified content + await fs.writeFile(validPath, newContent, 'utf-8') + + logger.info('File edited', { + path: validPath, + replaceAll + }) + + // Generate a simple diff summary + const oldLines = content.split('\n').length + const newLines = newContent.split('\n').length + const lineDiff = newLines - oldLines + + const relativePath = path.relative(baseDir, validPath) + let diffSummary = `Edited: ${relativePath}` + if (lineDiff > 0) { + diffSummary += `\n+${lineDiff} lines` + } else if (lineDiff < 0) { + diffSummary += `\n${lineDiff} lines` + } + + return { + content: [ + { + type: 'text', + text: diffSummary + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/glob.ts b/src/main/mcpServers/filesystem/tools/glob.ts new file mode 100644 index 0000000000..d6a6b4a757 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/glob.ts @@ -0,0 +1,149 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import type { FileInfo } from '../types' +import { logger, MAX_FILES_LIMIT, runRipgrep, validatePath } from '../types' + +// Schema definition +export const GlobToolSchema = z.object({ + pattern: z.string().describe('The glob pattern to match files against'), + path: z + .string() + .optional() + .describe('The directory to search in (must be absolute path). Defaults to the base directory') +}) + +// Tool definition with detailed description +export const globToolDefinition = { + name: 'glob', + description: `Fast file pattern matching tool that works with any codebase size. + +- Supports glob patterns like "**/*.js" or "src/**/*.ts" +- Returns matching absolute file paths sorted by modification time (newest first) +- Use this when you need to find files by name patterns +- Patterns without "/" (e.g., "*.txt") match files at ANY depth in the directory tree +- Patterns with "/" (e.g., "src/*.ts") match relative to the search path +- Pattern syntax: * (any chars), ** (any path), {a,b} (alternatives), ? (single char) +- Results are limited to 100 files +- The path parameter must be an absolute path if specified +- If path is not specified, defaults to the base directory +- IMPORTANT: Omit the path field for the default directory (don't use "undefined" or "null")`, + inputSchema: z.toJSONSchema(GlobToolSchema) +} + +// Handler implementation +export async function handleGlobTool(args: unknown, baseDir: string) { + const parsed = GlobToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for glob: ${parsed.error}`) + } + + const searchPath = parsed.data.path || baseDir + const validPath = await validatePath(searchPath, baseDir) + + // Verify the search directory exists + try { + const stats = await fs.stat(validPath) + if (!stats.isDirectory()) { + throw new Error(`Path is not a directory: ${validPath}`) + } + } catch (error: unknown) { + if (error && typeof error === 'object' && 'code' in error && error.code === 'ENOENT') { + throw new Error(`Directory not found: ${validPath}`) + } + throw error + } + + // Validate pattern + const pattern = parsed.data.pattern.trim() + if (!pattern) { + throw new Error('Pattern cannot be empty') + } + + const files: FileInfo[] = [] + let truncated = false + + // Build ripgrep arguments for file listing using --glob=pattern format + const rgArgs: string[] = [ + '--files', + '--follow', + '--hidden', + `--glob=${pattern}`, + '--glob=!.git/*', + '--glob=!node_modules/*', + '--glob=!dist/*', + '--glob=!build/*', + '--glob=!__pycache__/*', + validPath + ] + + // Use ripgrep for file listing + logger.debug('Running ripgrep with args', { rgArgs }) + const rgResult = await runRipgrep(rgArgs) + logger.debug('Ripgrep result', { + ok: rgResult.ok, + exitCode: rgResult.exitCode, + stdoutLength: rgResult.stdout.length, + stdoutPreview: rgResult.stdout.slice(0, 500) + }) + + // Process results if we have stdout content + // Exit code 2 can indicate partial errors (e.g., permission denied on some dirs) but still have valid results + if (rgResult.ok && rgResult.stdout.length > 0) { + const lines = rgResult.stdout.split('\n').filter(Boolean) + logger.debug('Parsed lines from ripgrep', { lineCount: lines.length, lines }) + + for (const line of lines) { + if (files.length >= MAX_FILES_LIMIT) { + truncated = true + break + } + + const filePath = line.trim() + if (!filePath) continue + + const absolutePath = path.isAbsolute(filePath) ? filePath : path.resolve(validPath, filePath) + + try { + const stats = await fs.stat(absolutePath) + files.push({ + path: absolutePath, + type: 'file', // ripgrep --files only returns files + size: stats.size, + modified: stats.mtime + }) + } catch (error) { + logger.debug('Failed to stat file from ripgrep output, skipping', { file: absolutePath, error }) + } + } + } + + // Sort by modification time (newest first) + files.sort((a, b) => { + const aTime = a.modified ? a.modified.getTime() : 0 + const bTime = b.modified ? b.modified.getTime() : 0 + return bTime - aTime + }) + + // Format output - always use absolute paths + const output: string[] = [] + if (files.length === 0) { + output.push(`No files found matching pattern "${parsed.data.pattern}" in ${validPath}`) + } else { + output.push(...files.map((f) => f.path)) + if (truncated) { + output.push('') + output.push(`(Results truncated to ${MAX_FILES_LIMIT} files. Consider using a more specific pattern.)`) + } + } + + return { + content: [ + { + type: 'text', + text: output.join('\n') + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/grep.ts b/src/main/mcpServers/filesystem/tools/grep.ts new file mode 100644 index 0000000000..d822db9d88 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/grep.ts @@ -0,0 +1,266 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import type { GrepMatch } from '../types' +import { isBinaryFile, MAX_GREP_MATCHES, MAX_LINE_LENGTH, runRipgrep, validatePath } from '../types' + +// Schema definition +export const GrepToolSchema = z.object({ + pattern: z.string().describe('The regex pattern to search for in file contents'), + path: z + .string() + .optional() + .describe('The directory to search in (must be absolute path). Defaults to the base directory'), + include: z.string().optional().describe('File pattern to include in the search (e.g. "*.js", "*.{ts,tsx}")') +}) + +// Tool definition with detailed description +export const grepToolDefinition = { + name: 'grep', + description: `Fast content search tool that works with any codebase size. + +- Searches file contents using regular expressions +- Supports full regex syntax (e.g., "log.*Error", "function\\s+\\w+") +- Filter files by pattern with include (e.g., "*.js", "*.{ts,tsx}") +- Returns absolute file paths and line numbers with matching content +- Results are limited to 100 matches +- Binary files are automatically skipped +- Common directories (node_modules, .git, dist) are excluded +- The path parameter must be an absolute path if specified +- If path is not specified, defaults to the base directory`, + inputSchema: z.toJSONSchema(GrepToolSchema) +} + +// Handler implementation +export async function handleGrepTool(args: unknown, baseDir: string) { + const parsed = GrepToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for grep: ${parsed.error}`) + } + + const data = parsed.data + + if (!data.pattern) { + throw new Error('Pattern is required for grep') + } + + const searchPath = data.path || baseDir + const validPath = await validatePath(searchPath, baseDir) + + const matches: GrepMatch[] = [] + let truncated = false + let regex: RegExp + + // Build ripgrep arguments + const rgArgs: string[] = [ + '--no-heading', + '--line-number', + '--color', + 'never', + '--ignore-case', + '--glob', + '!.git/**', + '--glob', + '!node_modules/**', + '--glob', + '!dist/**', + '--glob', + '!build/**', + '--glob', + '!__pycache__/**' + ] + + if (data.include) { + for (const pat of data.include + .split(',') + .map((p) => p.trim()) + .filter(Boolean)) { + rgArgs.push('--glob', pat) + } + } + + rgArgs.push(data.pattern) + rgArgs.push(validPath) + + try { + regex = new RegExp(data.pattern, 'gi') + } catch (error) { + throw new Error(`Invalid regex pattern: ${data.pattern}`) + } + + async function searchFile(filePath: string): Promise { + if (matches.length >= MAX_GREP_MATCHES) { + truncated = true + return + } + + try { + // Skip binary files + if (await isBinaryFile(filePath)) { + return + } + + const content = await fs.readFile(filePath, 'utf-8') + const lines = content.split('\n') + + lines.forEach((line, index) => { + if (matches.length >= MAX_GREP_MATCHES) { + truncated = true + return + } + + if (regex.test(line)) { + // Truncate long lines + const truncatedLine = line.length > MAX_LINE_LENGTH ? line.substring(0, MAX_LINE_LENGTH) + '...' : line + + matches.push({ + file: filePath, + line: index + 1, + content: truncatedLine.trim() + }) + } + }) + } catch (error) { + // Skip files we can't read + } + } + + async function searchDirectory(dir: string): Promise { + if (matches.length >= MAX_GREP_MATCHES) { + truncated = true + return + } + + try { + const entries = await fs.readdir(dir, { withFileTypes: true }) + + for (const entry of entries) { + if (matches.length >= MAX_GREP_MATCHES) { + truncated = true + break + } + + const fullPath = path.join(dir, entry.name) + + // Skip common ignore patterns + if (entry.name.startsWith('.') && entry.name !== '.env.example') { + continue + } + if (['node_modules', 'dist', 'build', '__pycache__', '.git'].includes(entry.name)) { + continue + } + + if (entry.isFile()) { + // Check if file matches include pattern + if (data.include) { + const includePatterns = data.include.split(',').map((p) => p.trim()) + const fileName = path.basename(fullPath) + const matchesInclude = includePatterns.some((pattern) => { + // Simple glob pattern matching + const regexPattern = pattern + .replace(/\*/g, '.*') + .replace(/\?/g, '.') + .replace(/\{([^}]+)\}/g, (_, group) => `(${group.split(',').join('|')})`) + return new RegExp(`^${regexPattern}$`).test(fileName) + }) + if (!matchesInclude) { + continue + } + } + + await searchFile(fullPath) + } else if (entry.isDirectory()) { + await searchDirectory(fullPath) + } + } + } catch (error) { + // Skip directories we can't read + } + } + + // Perform the search + let usedRipgrep = false + try { + const rgResult = await runRipgrep(rgArgs) + if (rgResult.ok && rgResult.exitCode !== null && rgResult.exitCode !== 2) { + usedRipgrep = true + const lines = rgResult.stdout.split('\n').filter(Boolean) + for (const line of lines) { + if (matches.length >= MAX_GREP_MATCHES) { + truncated = true + break + } + + const firstColon = line.indexOf(':') + const secondColon = line.indexOf(':', firstColon + 1) + if (firstColon === -1 || secondColon === -1) continue + + const filePart = line.slice(0, firstColon) + const linePart = line.slice(firstColon + 1, secondColon) + const contentPart = line.slice(secondColon + 1) + const lineNum = Number.parseInt(linePart, 10) + if (!Number.isFinite(lineNum)) continue + + const absoluteFilePath = path.isAbsolute(filePart) ? filePart : path.resolve(baseDir, filePart) + const truncatedLine = + contentPart.length > MAX_LINE_LENGTH ? contentPart.substring(0, MAX_LINE_LENGTH) + '...' : contentPart + + matches.push({ + file: absoluteFilePath, + line: lineNum, + content: truncatedLine.trim() + }) + } + } + } catch { + usedRipgrep = false + } + + if (!usedRipgrep) { + const stats = await fs.stat(validPath) + if (stats.isFile()) { + await searchFile(validPath) + } else { + await searchDirectory(validPath) + } + } + + // Format output + const output: string[] = [] + + if (matches.length === 0) { + output.push('No matches found') + } else { + // Group matches by file + const fileGroups = new Map() + matches.forEach((match) => { + if (!fileGroups.has(match.file)) { + fileGroups.set(match.file, []) + } + fileGroups.get(match.file)!.push(match) + }) + + // Format grouped matches - always use absolute paths + fileGroups.forEach((fileMatches, filePath) => { + output.push(`\n${filePath}:`) + fileMatches.forEach((match) => { + output.push(` ${match.line}: ${match.content}`) + }) + }) + + if (truncated) { + output.push('') + output.push(`(Results truncated to ${MAX_GREP_MATCHES} matches. Consider using a more specific pattern or path.)`) + } + } + + return { + content: [ + { + type: 'text', + text: output.join('\n') + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/index.ts b/src/main/mcpServers/filesystem/tools/index.ts new file mode 100644 index 0000000000..2e02d613c4 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/index.ts @@ -0,0 +1,8 @@ +// Export all tool definitions and handlers +export { deleteToolDefinition, handleDeleteTool } from './delete' +export { editToolDefinition, handleEditTool } from './edit' +export { globToolDefinition, handleGlobTool } from './glob' +export { grepToolDefinition, handleGrepTool } from './grep' +export { handleLsTool, lsToolDefinition } from './ls' +export { handleReadTool, readToolDefinition } from './read' +export { handleWriteTool, writeToolDefinition } from './write' diff --git a/src/main/mcpServers/filesystem/tools/ls.ts b/src/main/mcpServers/filesystem/tools/ls.ts new file mode 100644 index 0000000000..22672c9fb9 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/ls.ts @@ -0,0 +1,150 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { MAX_FILES_LIMIT, validatePath } from '../types' + +// Schema definition +export const LsToolSchema = z.object({ + path: z.string().optional().describe('The directory to list (must be absolute path). Defaults to the base directory'), + recursive: z.boolean().optional().describe('Whether to list directories recursively (default: false)') +}) + +// Tool definition with detailed description +export const lsToolDefinition = { + name: 'ls', + description: `Lists files and directories in a specified path. + +- Returns a tree-like structure with icons (📁 directories, 📄 files) +- Shows the absolute directory path in the header +- Entries are sorted alphabetically with directories first +- Can list recursively with recursive=true (up to 5 levels deep) +- Common directories (node_modules, dist, .git) are excluded +- Hidden files (starting with .) are excluded except .env.example +- Results are limited to 100 entries +- The path parameter must be an absolute path if specified +- If path is not specified, defaults to the base directory`, + inputSchema: z.toJSONSchema(LsToolSchema) +} + +// Handler implementation +export async function handleLsTool(args: unknown, baseDir: string) { + const parsed = LsToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for ls: ${parsed.error}`) + } + + const targetPath = parsed.data.path || baseDir + const validPath = await validatePath(targetPath, baseDir) + const recursive = parsed.data.recursive || false + + interface TreeNode { + name: string + type: 'file' | 'directory' + children?: TreeNode[] + } + + let fileCount = 0 + let truncated = false + + async function buildTree(dirPath: string, depth: number = 0): Promise { + if (fileCount >= MAX_FILES_LIMIT) { + truncated = true + return [] + } + + try { + const entries = await fs.readdir(dirPath, { withFileTypes: true }) + const nodes: TreeNode[] = [] + + // Sort entries: directories first, then files, alphabetically + entries.sort((a, b) => { + if (a.isDirectory() && !b.isDirectory()) return -1 + if (!a.isDirectory() && b.isDirectory()) return 1 + return a.name.localeCompare(b.name) + }) + + for (const entry of entries) { + if (fileCount >= MAX_FILES_LIMIT) { + truncated = true + break + } + + // Skip hidden files and common ignore patterns + if (entry.name.startsWith('.') && entry.name !== '.env.example') { + continue + } + if (['node_modules', 'dist', 'build', '__pycache__'].includes(entry.name)) { + continue + } + + fileCount++ + const node: TreeNode = { + name: entry.name, + type: entry.isDirectory() ? 'directory' : 'file' + } + + if (entry.isDirectory() && recursive && depth < 5) { + // Limit depth to prevent infinite recursion + const childPath = path.join(dirPath, entry.name) + node.children = await buildTree(childPath, depth + 1) + } + + nodes.push(node) + } + + return nodes + } catch (error) { + return [] + } + } + + // Build the tree + const tree = await buildTree(validPath) + + // Format as text output + function formatTree(nodes: TreeNode[], prefix: string = ''): string[] { + const lines: string[] = [] + + nodes.forEach((node, index) => { + const isLastNode = index === nodes.length - 1 + const connector = isLastNode ? '└── ' : '├── ' + const icon = node.type === 'directory' ? '📁 ' : '📄 ' + + lines.push(prefix + connector + icon + node.name) + + if (node.children && node.children.length > 0) { + const childPrefix = prefix + (isLastNode ? ' ' : '│ ') + lines.push(...formatTree(node.children, childPrefix)) + } + }) + + return lines + } + + // Generate output + const output: string[] = [] + output.push(`Directory: ${validPath}`) + output.push('') + + if (tree.length === 0) { + output.push('(empty directory)') + } else { + const treeLines = formatTree(tree, '') + output.push(...treeLines) + + if (truncated) { + output.push('') + output.push(`(Results truncated to ${MAX_FILES_LIMIT} files. Consider listing a more specific directory.)`) + } + } + + return { + content: [ + { + type: 'text', + text: output.join('\n') + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/read.ts b/src/main/mcpServers/filesystem/tools/read.ts new file mode 100644 index 0000000000..460c88dda4 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/read.ts @@ -0,0 +1,101 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { DEFAULT_READ_LIMIT, isBinaryFile, MAX_LINE_LENGTH, validatePath } from '../types' + +// Schema definition +export const ReadToolSchema = z.object({ + file_path: z.string().describe('The path to the file to read'), + offset: z.number().optional().describe('The line number to start reading from (1-based)'), + limit: z.number().optional().describe('The number of lines to read (defaults to 2000)') +}) + +// Tool definition with detailed description +export const readToolDefinition = { + name: 'read', + description: `Reads a file from the local filesystem. + +- Assumes this tool can read all files on the machine +- The file_path parameter must be an absolute path, not a relative path +- By default, reads up to 2000 lines starting from the beginning +- You can optionally specify a line offset and limit for long files +- Any lines longer than 2000 characters will be truncated +- Results are returned with line numbers starting at 1 +- Binary files are detected and rejected with an error +- Empty files return a warning`, + inputSchema: z.toJSONSchema(ReadToolSchema) +} + +// Handler implementation +export async function handleReadTool(args: unknown, baseDir: string) { + const parsed = ReadToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for read: ${parsed.error}`) + } + + const filePath = parsed.data.file_path + const validPath = await validatePath(filePath, baseDir) + + // Check if file exists + try { + const stats = await fs.stat(validPath) + if (!stats.isFile()) { + throw new Error(`Path is not a file: ${filePath}`) + } + } catch (error: any) { + if (error.code === 'ENOENT') { + throw new Error(`File not found: ${filePath}`) + } + throw error + } + + // Check if file is binary + if (await isBinaryFile(validPath)) { + throw new Error(`Cannot read binary file: ${filePath}`) + } + + // Read file content + const content = await fs.readFile(validPath, 'utf-8') + const lines = content.split('\n') + + // Apply offset and limit + const offset = (parsed.data.offset || 1) - 1 // Convert to 0-based + const limit = parsed.data.limit || DEFAULT_READ_LIMIT + + if (offset < 0 || offset >= lines.length) { + throw new Error(`Invalid offset: ${offset + 1}. File has ${lines.length} lines.`) + } + + const selectedLines = lines.slice(offset, offset + limit) + + // Format output with line numbers and truncate long lines + const output: string[] = [] + const relativePath = path.relative(baseDir, validPath) + + output.push(`File: ${relativePath}`) + if (offset > 0 || limit < lines.length) { + output.push(`Lines ${offset + 1} to ${Math.min(offset + limit, lines.length)} of ${lines.length}`) + } + output.push('') + + selectedLines.forEach((line, index) => { + const lineNumber = offset + index + 1 + const truncatedLine = line.length > MAX_LINE_LENGTH ? line.substring(0, MAX_LINE_LENGTH) + '...' : line + output.push(`${lineNumber.toString().padStart(6)}\t${truncatedLine}`) + }) + + if (offset + limit < lines.length) { + output.push('') + output.push(`(${lines.length - (offset + limit)} more lines not shown)`) + } + + return { + content: [ + { + type: 'text', + text: output.join('\n') + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/write.ts b/src/main/mcpServers/filesystem/tools/write.ts new file mode 100644 index 0000000000..2898f2f874 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/write.ts @@ -0,0 +1,83 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { logger, validatePath } from '../types' + +// Schema definition +export const WriteToolSchema = z.object({ + file_path: z.string().describe('The path to the file to write'), + content: z.string().describe('The content to write to the file') +}) + +// Tool definition with detailed description +export const writeToolDefinition = { + name: 'write', + description: `Writes a file to the local filesystem. + +- This tool will overwrite the existing file if one exists at the path +- You MUST use the read tool first to understand what you're overwriting +- ALWAYS prefer using the 'edit' tool for existing files +- NEVER proactively create documentation files unless explicitly requested +- Parent directories will be created automatically if they don't exist +- The file_path must be an absolute path, not a relative path`, + inputSchema: z.toJSONSchema(WriteToolSchema) +} + +// Handler implementation +export async function handleWriteTool(args: unknown, baseDir: string) { + const parsed = WriteToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for write: ${parsed.error}`) + } + + const filePath = parsed.data.file_path + const validPath = await validatePath(filePath, baseDir) + + // Create parent directory if it doesn't exist + const parentDir = path.dirname(validPath) + try { + await fs.mkdir(parentDir, { recursive: true }) + } catch (error: any) { + if (error.code !== 'EEXIST') { + throw new Error(`Failed to create parent directory: ${error.message}`) + } + } + + // Check if file exists (for logging) + let isOverwrite = false + try { + await fs.stat(validPath) + isOverwrite = true + } catch { + // File doesn't exist, that's fine + } + + // Write the file + try { + await fs.writeFile(validPath, parsed.data.content, 'utf-8') + } catch (error: any) { + throw new Error(`Failed to write file: ${error.message}`) + } + + // Log the operation + logger.info('File written', { + path: validPath, + overwrite: isOverwrite, + size: parsed.data.content.length + }) + + // Format output + const relativePath = path.relative(baseDir, validPath) + const action = isOverwrite ? 'Updated' : 'Created' + const lines = parsed.data.content.split('\n').length + + return { + content: [ + { + type: 'text', + text: `${action} file: ${relativePath}\n` + `Size: ${parsed.data.content.length} bytes\n` + `Lines: ${lines}` + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/types.ts b/src/main/mcpServers/filesystem/types.ts new file mode 100644 index 0000000000..922fe0b23a --- /dev/null +++ b/src/main/mcpServers/filesystem/types.ts @@ -0,0 +1,627 @@ +import { loggerService } from '@logger' +import { isMac, isWin } from '@main/constant' +import { spawn } from 'child_process' +import fs from 'fs/promises' +import os from 'os' +import path from 'path' + +export const logger = loggerService.withContext('MCP:FileSystemServer') + +// Constants +export const MAX_LINE_LENGTH = 2000 +export const DEFAULT_READ_LIMIT = 2000 +export const MAX_FILES_LIMIT = 100 +export const MAX_GREP_MATCHES = 100 + +// Common types +export interface FileInfo { + path: string + type: 'file' | 'directory' + size?: number + modified?: Date +} + +export interface GrepMatch { + file: string + line: number + content: string +} + +// Utility functions for path handling +export function normalizePath(p: string): string { + return path.normalize(p) +} + +export function expandHome(filepath: string): string { + if (filepath.startsWith('~/') || filepath === '~') { + return path.join(os.homedir(), filepath.slice(1)) + } + return filepath +} + +// Security validation +export async function validatePath(requestedPath: string, baseDir?: string): Promise { + const expandedPath = expandHome(requestedPath) + const root = baseDir ?? process.cwd() + const absolute = path.isAbsolute(expandedPath) ? path.resolve(expandedPath) : path.resolve(root, expandedPath) + + // Handle symlinks by checking their real path + try { + const realPath = await fs.realpath(absolute) + return normalizePath(realPath) + } catch (error) { + // For new files that don't exist yet, verify parent directory + const parentDir = path.dirname(absolute) + try { + const realParentPath = await fs.realpath(parentDir) + normalizePath(realParentPath) + return normalizePath(absolute) + } catch { + return normalizePath(absolute) + } + } +} + +// ============================================================================ +// Edit Tool Utilities - Fuzzy matching replacers from opencode +// ============================================================================ + +export type Replacer = (content: string, find: string) => Generator + +// Similarity thresholds for block anchor fallback matching +const SINGLE_CANDIDATE_SIMILARITY_THRESHOLD = 0.0 +const MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD = 0.3 + +/** + * Levenshtein distance algorithm implementation + */ +function levenshtein(a: string, b: string): number { + if (a === '' || b === '') { + return Math.max(a.length, b.length) + } + const matrix = Array.from({ length: a.length + 1 }, (_, i) => + Array.from({ length: b.length + 1 }, (_, j) => (i === 0 ? j : j === 0 ? i : 0)) + ) + + for (let i = 1; i <= a.length; i++) { + for (let j = 1; j <= b.length; j++) { + const cost = a[i - 1] === b[j - 1] ? 0 : 1 + matrix[i][j] = Math.min(matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + cost) + } + } + return matrix[a.length][b.length] +} + +export const SimpleReplacer: Replacer = function* (_content, find) { + yield find +} + +export const LineTrimmedReplacer: Replacer = function* (content, find) { + const originalLines = content.split('\n') + const searchLines = find.split('\n') + + if (searchLines[searchLines.length - 1] === '') { + searchLines.pop() + } + + for (let i = 0; i <= originalLines.length - searchLines.length; i++) { + let matches = true + + for (let j = 0; j < searchLines.length; j++) { + const originalTrimmed = originalLines[i + j].trim() + const searchTrimmed = searchLines[j].trim() + + if (originalTrimmed !== searchTrimmed) { + matches = false + break + } + } + + if (matches) { + let matchStartIndex = 0 + for (let k = 0; k < i; k++) { + matchStartIndex += originalLines[k].length + 1 + } + + let matchEndIndex = matchStartIndex + for (let k = 0; k < searchLines.length; k++) { + matchEndIndex += originalLines[i + k].length + if (k < searchLines.length - 1) { + matchEndIndex += 1 + } + } + + yield content.substring(matchStartIndex, matchEndIndex) + } + } +} + +export const BlockAnchorReplacer: Replacer = function* (content, find) { + const originalLines = content.split('\n') + const searchLines = find.split('\n') + + if (searchLines.length < 3) { + return + } + + if (searchLines[searchLines.length - 1] === '') { + searchLines.pop() + } + + const firstLineSearch = searchLines[0].trim() + const lastLineSearch = searchLines[searchLines.length - 1].trim() + const searchBlockSize = searchLines.length + + const candidates: Array<{ startLine: number; endLine: number }> = [] + for (let i = 0; i < originalLines.length; i++) { + if (originalLines[i].trim() !== firstLineSearch) { + continue + } + + for (let j = i + 2; j < originalLines.length; j++) { + if (originalLines[j].trim() === lastLineSearch) { + candidates.push({ startLine: i, endLine: j }) + break + } + } + } + + if (candidates.length === 0) { + return + } + + if (candidates.length === 1) { + const { startLine, endLine } = candidates[0] + const actualBlockSize = endLine - startLine + 1 + + let similarity = 0 + const linesToCheck = Math.min(searchBlockSize - 2, actualBlockSize - 2) + + if (linesToCheck > 0) { + for (let j = 1; j < searchBlockSize - 1 && j < actualBlockSize - 1; j++) { + const originalLine = originalLines[startLine + j].trim() + const searchLine = searchLines[j].trim() + const maxLen = Math.max(originalLine.length, searchLine.length) + if (maxLen === 0) { + continue + } + const distance = levenshtein(originalLine, searchLine) + similarity += (1 - distance / maxLen) / linesToCheck + + if (similarity >= SINGLE_CANDIDATE_SIMILARITY_THRESHOLD) { + break + } + } + } else { + similarity = 1.0 + } + + if (similarity >= SINGLE_CANDIDATE_SIMILARITY_THRESHOLD) { + let matchStartIndex = 0 + for (let k = 0; k < startLine; k++) { + matchStartIndex += originalLines[k].length + 1 + } + let matchEndIndex = matchStartIndex + for (let k = startLine; k <= endLine; k++) { + matchEndIndex += originalLines[k].length + if (k < endLine) { + matchEndIndex += 1 + } + } + yield content.substring(matchStartIndex, matchEndIndex) + } + return + } + + let bestMatch: { startLine: number; endLine: number } | null = null + let maxSimilarity = -1 + + for (const candidate of candidates) { + const { startLine, endLine } = candidate + const actualBlockSize = endLine - startLine + 1 + + let similarity = 0 + const linesToCheck = Math.min(searchBlockSize - 2, actualBlockSize - 2) + + if (linesToCheck > 0) { + for (let j = 1; j < searchBlockSize - 1 && j < actualBlockSize - 1; j++) { + const originalLine = originalLines[startLine + j].trim() + const searchLine = searchLines[j].trim() + const maxLen = Math.max(originalLine.length, searchLine.length) + if (maxLen === 0) { + continue + } + const distance = levenshtein(originalLine, searchLine) + similarity += 1 - distance / maxLen + } + similarity /= linesToCheck + } else { + similarity = 1.0 + } + + if (similarity > maxSimilarity) { + maxSimilarity = similarity + bestMatch = candidate + } + } + + if (maxSimilarity >= MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD && bestMatch) { + const { startLine, endLine } = bestMatch + let matchStartIndex = 0 + for (let k = 0; k < startLine; k++) { + matchStartIndex += originalLines[k].length + 1 + } + let matchEndIndex = matchStartIndex + for (let k = startLine; k <= endLine; k++) { + matchEndIndex += originalLines[k].length + if (k < endLine) { + matchEndIndex += 1 + } + } + yield content.substring(matchStartIndex, matchEndIndex) + } +} + +export const WhitespaceNormalizedReplacer: Replacer = function* (content, find) { + const normalizeWhitespace = (text: string) => text.replace(/\s+/g, ' ').trim() + const normalizedFind = normalizeWhitespace(find) + + const lines = content.split('\n') + for (let i = 0; i < lines.length; i++) { + const line = lines[i] + if (normalizeWhitespace(line) === normalizedFind) { + yield line + } else { + const normalizedLine = normalizeWhitespace(line) + if (normalizedLine.includes(normalizedFind)) { + const words = find.trim().split(/\s+/) + if (words.length > 0) { + const pattern = words.map((word) => word.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')).join('\\s+') + try { + const regex = new RegExp(pattern) + const match = line.match(regex) + if (match) { + yield match[0] + } + } catch { + // Invalid regex pattern, skip + } + } + } + } + } + + const findLines = find.split('\n') + if (findLines.length > 1) { + for (let i = 0; i <= lines.length - findLines.length; i++) { + const block = lines.slice(i, i + findLines.length) + if (normalizeWhitespace(block.join('\n')) === normalizedFind) { + yield block.join('\n') + } + } + } +} + +export const IndentationFlexibleReplacer: Replacer = function* (content, find) { + const removeIndentation = (text: string) => { + const lines = text.split('\n') + const nonEmptyLines = lines.filter((line) => line.trim().length > 0) + if (nonEmptyLines.length === 0) return text + + const minIndent = Math.min( + ...nonEmptyLines.map((line) => { + const match = line.match(/^(\s*)/) + return match ? match[1].length : 0 + }) + ) + + return lines.map((line) => (line.trim().length === 0 ? line : line.slice(minIndent))).join('\n') + } + + const normalizedFind = removeIndentation(find) + const contentLines = content.split('\n') + const findLines = find.split('\n') + + for (let i = 0; i <= contentLines.length - findLines.length; i++) { + const block = contentLines.slice(i, i + findLines.length).join('\n') + if (removeIndentation(block) === normalizedFind) { + yield block + } + } +} + +export const EscapeNormalizedReplacer: Replacer = function* (content, find) { + const unescapeString = (str: string): string => { + return str.replace(/\\(n|t|r|'|"|`|\\|\n|\$)/g, (match, capturedChar) => { + switch (capturedChar) { + case 'n': + return '\n' + case 't': + return '\t' + case 'r': + return '\r' + case "'": + return "'" + case '"': + return '"' + case '`': + return '`' + case '\\': + return '\\' + case '\n': + return '\n' + case '$': + return '$' + default: + return match + } + }) + } + + const unescapedFind = unescapeString(find) + + if (content.includes(unescapedFind)) { + yield unescapedFind + } + + const lines = content.split('\n') + const findLines = unescapedFind.split('\n') + + for (let i = 0; i <= lines.length - findLines.length; i++) { + const block = lines.slice(i, i + findLines.length).join('\n') + const unescapedBlock = unescapeString(block) + + if (unescapedBlock === unescapedFind) { + yield block + } + } +} + +export const TrimmedBoundaryReplacer: Replacer = function* (content, find) { + const trimmedFind = find.trim() + + if (trimmedFind === find) { + return + } + + if (content.includes(trimmedFind)) { + yield trimmedFind + } + + const lines = content.split('\n') + const findLines = find.split('\n') + + for (let i = 0; i <= lines.length - findLines.length; i++) { + const block = lines.slice(i, i + findLines.length).join('\n') + + if (block.trim() === trimmedFind) { + yield block + } + } +} + +export const ContextAwareReplacer: Replacer = function* (content, find) { + const findLines = find.split('\n') + if (findLines.length < 3) { + return + } + + if (findLines[findLines.length - 1] === '') { + findLines.pop() + } + + const contentLines = content.split('\n') + + const firstLine = findLines[0].trim() + const lastLine = findLines[findLines.length - 1].trim() + + for (let i = 0; i < contentLines.length; i++) { + if (contentLines[i].trim() !== firstLine) continue + + for (let j = i + 2; j < contentLines.length; j++) { + if (contentLines[j].trim() === lastLine) { + const blockLines = contentLines.slice(i, j + 1) + const block = blockLines.join('\n') + + if (blockLines.length === findLines.length) { + let matchingLines = 0 + let totalNonEmptyLines = 0 + + for (let k = 1; k < blockLines.length - 1; k++) { + const blockLine = blockLines[k].trim() + const findLine = findLines[k].trim() + + if (blockLine.length > 0 || findLine.length > 0) { + totalNonEmptyLines++ + if (blockLine === findLine) { + matchingLines++ + } + } + } + + if (totalNonEmptyLines === 0 || matchingLines / totalNonEmptyLines >= 0.5) { + yield block + break + } + } + break + } + } + } +} + +export const MultiOccurrenceReplacer: Replacer = function* (content, find) { + let startIndex = 0 + + while (true) { + const index = content.indexOf(find, startIndex) + if (index === -1) break + + yield find + startIndex = index + find.length + } +} + +/** + * All replacers in order of specificity + */ +export const ALL_REPLACERS: Replacer[] = [ + SimpleReplacer, + LineTrimmedReplacer, + BlockAnchorReplacer, + WhitespaceNormalizedReplacer, + IndentationFlexibleReplacer, + EscapeNormalizedReplacer, + TrimmedBoundaryReplacer, + ContextAwareReplacer, + MultiOccurrenceReplacer +] + +/** + * Replace oldString with newString in content using fuzzy matching + */ +export function replaceWithFuzzyMatch( + content: string, + oldString: string, + newString: string, + replaceAll = false +): string { + if (oldString === newString) { + throw new Error('old_string and new_string must be different') + } + + let notFound = true + + for (const replacer of ALL_REPLACERS) { + for (const search of replacer(content, oldString)) { + const index = content.indexOf(search) + if (index === -1) continue + notFound = false + if (replaceAll) { + return content.replaceAll(search, newString) + } + const lastIndex = content.lastIndexOf(search) + if (index !== lastIndex) continue + return content.substring(0, index) + newString + content.substring(index + search.length) + } + } + + if (notFound) { + throw new Error('old_string not found in content') + } + throw new Error( + 'Found multiple matches for old_string. Provide more surrounding lines in old_string to identify the correct match.' + ) +} + +// ============================================================================ +// Binary File Detection +// ============================================================================ + +// Check if a file is likely binary +export async function isBinaryFile(filePath: string): Promise { + try { + const buffer = Buffer.alloc(4096) + const fd = await fs.open(filePath, 'r') + const { bytesRead } = await fd.read(buffer, 0, buffer.length, 0) + await fd.close() + + if (bytesRead === 0) return false + + const view = buffer.subarray(0, bytesRead) + + let zeroBytes = 0 + let evenZeros = 0 + let oddZeros = 0 + let nonPrintable = 0 + + for (let i = 0; i < view.length; i++) { + const b = view[i] + + if (b === 0) { + zeroBytes++ + if (i % 2 === 0) evenZeros++ + else oddZeros++ + continue + } + + // treat common whitespace as printable + if (b === 9 || b === 10 || b === 13) continue + + // basic ASCII printable range + if (b >= 32 && b <= 126) continue + + // bytes >= 128 are likely part of UTF-8 sequences; count as printable + if (b >= 128) continue + + nonPrintable++ + } + + // If there are lots of null bytes, it's probably binary unless it looks like UTF-16 text. + if (zeroBytes > 0) { + const evenSlots = Math.ceil(view.length / 2) + const oddSlots = Math.floor(view.length / 2) + const evenZeroRatio = evenSlots > 0 ? evenZeros / evenSlots : 0 + const oddZeroRatio = oddSlots > 0 ? oddZeros / oddSlots : 0 + + // UTF-16LE/BE tends to have zeros on every other byte. + if (evenZeroRatio > 0.7 || oddZeroRatio > 0.7) return false + + if (zeroBytes / view.length > 0.05) return true + } + + // Heuristic: too many non-printable bytes => binary. + return nonPrintable / view.length > 0.3 + } catch { + return false + } +} + +// ============================================================================ +// Ripgrep Utilities +// ============================================================================ + +export interface RipgrepResult { + ok: boolean + stdout: string + exitCode: number | null +} + +export function getRipgrepAddonPath(): string { + const pkgJsonPath = require.resolve('@anthropic-ai/claude-agent-sdk/package.json') + const pkgRoot = path.dirname(pkgJsonPath) + const platform = isMac ? 'darwin' : isWin ? 'win32' : 'linux' + const arch = process.arch === 'arm64' ? 'arm64' : 'x64' + return path.join(pkgRoot, 'vendor', 'ripgrep', `${arch}-${platform}`, 'ripgrep.node') +} + +export async function runRipgrep(args: string[]): Promise { + const addonPath = getRipgrepAddonPath() + const childScript = `const { ripgrepMain } = require(process.env.RIPGREP_ADDON_PATH); process.exit(ripgrepMain(process.argv.slice(1)));` + + return new Promise((resolve) => { + const child = spawn(process.execPath, ['--eval', childScript, 'rg', ...args], { + cwd: process.cwd(), + env: { + ...process.env, + ELECTRON_RUN_AS_NODE: '1', + RIPGREP_ADDON_PATH: addonPath + }, + stdio: ['ignore', 'pipe', 'pipe'] + }) + + let stdout = '' + + child.stdout?.on('data', (chunk) => { + stdout += chunk.toString('utf-8') + }) + + child.on('error', () => { + resolve({ ok: false, stdout: '', exitCode: null }) + }) + + child.on('close', (code) => { + resolve({ ok: true, stdout, exitCode: code }) + }) + }) +} diff --git a/src/main/services/BackupManager.ts b/src/main/services/BackupManager.ts index f331254fdf..e08bbd4d7b 100644 --- a/src/main/services/BackupManager.ts +++ b/src/main/services/BackupManager.ts @@ -1,3 +1,19 @@ +/** + * @deprecated Scheduled for removal in v2.0.0 + * -------------------------------------------------------------------------- + * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex) + * -------------------------------------------------------------------------- + * STOP: Feature PRs affecting this file are currently BLOCKED. + * Only critical bug fixes are accepted during this migration phase. + * + * This file is being refactored to v2 standards. + * Any non-critical changes will conflict with the ongoing work. + * + * 🔗 Context & Status: + * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954 + * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162 + * -------------------------------------------------------------------------- + */ import { loggerService } from '@logger' import { IpcChannel } from '@shared/IpcChannel' import type { WebDavConfig } from '@types' @@ -767,6 +783,56 @@ class BackupManager { const s3Client = this.getS3Storage(s3Config) return await s3Client.checkConnection() } + + /** + * Create a temporary backup for LAN transfer + * Creates a lightweight backup (skipBackupFile=true) in the temp directory + * Returns the path to the created ZIP file + */ + async createLanTransferBackup(_: Electron.IpcMainInvokeEvent, data: string): Promise { + const timestamp = new Date() + .toISOString() + .replace(/[-:T.Z]/g, '') + .slice(0, 12) + const fileName = `cherry-studio.${timestamp}.zip` + const tempPath = path.join(app.getPath('temp'), 'cherry-studio', 'lan-transfer') + + // Ensure temp directory exists + await fs.ensureDir(tempPath) + + // Create backup with skipBackupFile=true (no Data folder) + const backupedFilePath = await this.backup(_, fileName, data, tempPath, true) + + logger.info(`[BackupManager] Created LAN transfer backup at: ${backupedFilePath}`) + return backupedFilePath + } + + /** + * Delete a temporary backup file after LAN transfer completes + */ + async deleteTempBackup(_: Electron.IpcMainInvokeEvent, filePath: string): Promise { + try { + // Security check: only allow deletion within temp directory + const tempBase = path.normalize(path.join(app.getPath('temp'), 'cherry-studio', 'lan-transfer')) + const resolvedPath = path.normalize(path.resolve(filePath)) + + // Use normalized paths with trailing separator to prevent prefix attacks (e.g., /temp-evil) + if (!resolvedPath.startsWith(tempBase + path.sep) && resolvedPath !== tempBase) { + logger.warn(`[BackupManager] Attempted to delete file outside temp directory: ${filePath}`) + return false + } + + if (await fs.pathExists(resolvedPath)) { + await fs.remove(resolvedPath) + logger.info(`[BackupManager] Deleted temp backup: ${resolvedPath}`) + return true + } + return false + } catch (error) { + logger.error('[BackupManager] Failed to delete temp backup:', error as Error) + return false + } + } } export default BackupManager diff --git a/src/main/services/CacheService.ts b/src/main/services/CacheService.ts index d2984a9984..4f2e2f8b20 100644 --- a/src/main/services/CacheService.ts +++ b/src/main/services/CacheService.ts @@ -1,3 +1,19 @@ +/** + * @deprecated Scheduled for removal in v2.0.0 + * -------------------------------------------------------------------------- + * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex) + * -------------------------------------------------------------------------- + * STOP: Feature PRs affecting this file are currently BLOCKED. + * Only critical bug fixes are accepted during this migration phase. + * + * This file is being refactored to v2 standards. + * Any non-critical changes will conflict with the ongoing work. + * + * 🔗 Context & Status: + * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954 + * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162 + * -------------------------------------------------------------------------- + */ interface CacheItem { data: T timestamp: number diff --git a/src/main/services/ConfigManager.ts b/src/main/services/ConfigManager.ts index 61e285ac1b..98537c85a1 100644 --- a/src/main/services/ConfigManager.ts +++ b/src/main/services/ConfigManager.ts @@ -1,3 +1,19 @@ +/** + * @deprecated Scheduled for removal in v2.0.0 + * -------------------------------------------------------------------------- + * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex) + * -------------------------------------------------------------------------- + * STOP: Feature PRs affecting this file are currently BLOCKED. + * Only critical bug fixes are accepted during this migration phase. + * + * This file is being refactored to v2 standards. + * Any non-critical changes will conflict with the ongoing work. + * + * 🔗 Context & Status: + * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954 + * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162 + * -------------------------------------------------------------------------- + */ import type { UpgradeChannel } from '@shared/config/constant' import { defaultLanguage, ZOOM_SHORTCUTS } from '@shared/config/constant' import type { LanguageVarious, Shortcut } from '@types' @@ -31,7 +47,9 @@ export enum ConfigKeys { DisableHardwareAcceleration = 'disableHardwareAcceleration', Proxy = 'proxy', EnableDeveloperMode = 'enableDeveloperMode', - ClientId = 'clientId' + ClientId = 'clientId', + GitBashPath = 'gitBashPath', + GitBashPathSource = 'gitBashPathSource' // 'manual' | 'auto' | null } export class ConfigManager { diff --git a/src/main/services/FileStorage.ts b/src/main/services/FileStorage.ts index 95bf7da7e7..c6f67adaae 100644 --- a/src/main/services/FileStorage.ts +++ b/src/main/services/FileStorage.ts @@ -3,7 +3,7 @@ import { checkName, findCommonRoot, getFilesDir, - getFileType, + getFileType as getFileTypeByExt, getName, getNotesDir, getTempDir, @@ -12,13 +12,13 @@ import { } from '@main/utils/file' import { documentExts, imageExts, KB, MB } from '@shared/config/constant' import type { FileMetadata, NotesTreeNode } from '@types' +import { FileTypes } from '@types' import chardet from 'chardet' import type { FSWatcher } from 'chokidar' import chokidar from 'chokidar' import * as crypto from 'crypto' import type { OpenDialogOptions, OpenDialogReturnValue, SaveDialogOptions, SaveDialogReturnValue } from 'electron' -import { app } from 'electron' -import { dialog, net, shell } from 'electron' +import { app, dialog, net, shell } from 'electron' import * as fs from 'fs' import { writeFileSync } from 'fs' import { readFile } from 'fs/promises' @@ -132,16 +132,18 @@ interface DirectoryListOptions { includeDirectories?: boolean maxEntries?: number searchPattern?: string + fuzzy?: boolean } const DEFAULT_DIRECTORY_LIST_OPTIONS: Required = { recursive: true, - maxDepth: 3, + maxDepth: 10, includeHidden: false, includeFiles: true, includeDirectories: true, - maxEntries: 10, - searchPattern: '.' + maxEntries: 20, + searchPattern: '.', + fuzzy: true } class FileStorage { @@ -165,7 +167,7 @@ class FileStorage { fs.mkdirSync(this.storageDir, { recursive: true }) } if (!fs.existsSync(this.notesDir)) { - fs.mkdirSync(this.storageDir, { recursive: true }) + fs.mkdirSync(this.notesDir, { recursive: true }) } if (!fs.existsSync(this.tempDir)) { fs.mkdirSync(this.tempDir, { recursive: true }) @@ -187,7 +189,7 @@ class FileStorage { }) } - findDuplicateFile = async (filePath: string): Promise => { + private findDuplicateFile = async (filePath: string): Promise => { const stats = fs.statSync(filePath) logger.debug(`stats: ${stats}, filePath: ${filePath}`) const fileSize = stats.size @@ -206,6 +208,8 @@ class FileStorage { if (originalHash === storedHash) { const ext = path.extname(file) const id = path.basename(file, ext) + const type = await this.getFileType(filePath) + return { id, origin_name: file, @@ -214,7 +218,7 @@ class FileStorage { created_at: storedStats.birthtime.toISOString(), size: storedStats.size, ext, - type: getFileType(ext), + type, count: 2 } } @@ -224,6 +228,13 @@ class FileStorage { return null } + public getFileType = async (filePath: string): Promise => { + const ext = path.extname(filePath) + const fileType = getFileTypeByExt(ext) + + return fileType === FileTypes.OTHER && (await this._isTextFile(filePath)) ? FileTypes.TEXT : fileType + } + public selectFile = async ( _: Electron.IpcMainInvokeEvent, options?: OpenDialogOptions @@ -243,7 +254,7 @@ class FileStorage { const fileMetadataPromises = result.filePaths.map(async (filePath) => { const stats = fs.statSync(filePath) const ext = path.extname(filePath) - const fileType = getFileType(ext) + const fileType = await this.getFileType(filePath) return { id: uuidv4(), @@ -309,7 +320,7 @@ class FileStorage { } const stats = await fs.promises.stat(destPath) - const fileType = getFileType(ext) + const fileType = await this.getFileType(destPath) const fileMetadata: FileMetadata = { id: uuid, @@ -334,8 +345,7 @@ class FileStorage { } const stats = fs.statSync(filePath) - const ext = path.extname(filePath) - const fileType = getFileType(ext) + const fileType = await this.getFileType(filePath) return { id: uuidv4(), @@ -344,7 +354,7 @@ class FileStorage { path: filePath, created_at: stats.birthtime.toISOString(), size: stats.size, - ext: ext, + ext: path.extname(filePath), type: fileType, count: 1 } @@ -692,7 +702,7 @@ class FileStorage { created_at: new Date().toISOString(), size: buffer.length, ext: ext.slice(1), - type: getFileType(ext), + type: getFileTypeByExt(ext), count: 1 } } catch (error) { @@ -742,7 +752,7 @@ class FileStorage { created_at: new Date().toISOString(), size: stats.size, ext: ext.slice(1), - type: getFileType(ext), + type: getFileTypeByExt(ext), count: 1 } } catch (error) { @@ -1040,10 +1050,226 @@ class FileStorage { } /** - * Search files by content pattern + * Fuzzy match: checks if all characters in query appear in text in order (case-insensitive) + * Example: "updater" matches "packages/update/src/node/updateController.ts" */ - private async searchByContent(resolvedPath: string, options: Required): Promise { - const args: string[] = ['-l'] + private isFuzzyMatch(text: string, query: string): boolean { + let i = 0 // text index + let j = 0 // query index + const textLower = text.toLowerCase() + const queryLower = query.toLowerCase() + + while (i < textLower.length && j < queryLower.length) { + if (textLower[i] === queryLower[j]) { + j++ + } + i++ + } + return j === queryLower.length + } + + /** + * Scoring constants for fuzzy match relevance ranking + * Higher values = higher priority in search results + */ + private static readonly SCORE_SEGMENT_MATCH = 60 // Per path segment that matches query + private static readonly SCORE_FILENAME_CONTAINS = 80 // Filename contains exact query substring + private static readonly SCORE_FILENAME_STARTS = 100 // Filename starts with query (highest priority) + private static readonly SCORE_CONSECUTIVE_CHAR = 15 // Per consecutive character match + private static readonly SCORE_WORD_BOUNDARY = 20 // Query matches start of a word + private static readonly PATH_LENGTH_PENALTY_FACTOR = 4 // Logarithmic penalty multiplier for longer paths + + /** + * Calculate fuzzy match score (higher is better) + * Scoring factors: + * - Consecutive character matches (bonus) + * - Match at word boundaries (bonus) + * - Shorter path length (bonus) + * - Match in filename vs directory (bonus) + */ + private getFuzzyMatchScore(filePath: string, query: string): number { + const pathLower = filePath.toLowerCase() + const queryLower = query.toLowerCase() + const fileName = filePath.split('/').pop() || '' + const fileNameLower = fileName.toLowerCase() + + let score = 0 + + // Count how many times query-related words appear in path segments + const pathSegments = pathLower.split(/[/\\]/) + let segmentMatchCount = 0 + for (const segment of pathSegments) { + if (this.isFuzzyMatch(segment, queryLower)) { + segmentMatchCount++ + } + } + score += segmentMatchCount * FileStorage.SCORE_SEGMENT_MATCH + + // Bonus for filename starting with query (stronger than generic "contains") + if (fileNameLower.startsWith(queryLower)) { + score += FileStorage.SCORE_FILENAME_STARTS + } else if (fileNameLower.includes(queryLower)) { + // Bonus for exact substring match in filename (e.g., "updater" in "RCUpdater.js") + score += FileStorage.SCORE_FILENAME_CONTAINS + } + + // Calculate consecutive match bonus + let i = 0 + let j = 0 + let consecutiveCount = 0 + let maxConsecutive = 0 + + while (i < pathLower.length && j < queryLower.length) { + if (pathLower[i] === queryLower[j]) { + consecutiveCount++ + maxConsecutive = Math.max(maxConsecutive, consecutiveCount) + j++ + } else { + consecutiveCount = 0 + } + i++ + } + score += maxConsecutive * FileStorage.SCORE_CONSECUTIVE_CHAR + + // Bonus for word boundary matches (e.g., "upd" matches start of "update") + // Only count once to avoid inflating scores for paths with repeated patterns + const boundaryPrefix = queryLower.slice(0, Math.min(3, queryLower.length)) + const words = pathLower.split(/[/\\._-]/) + for (const word of words) { + if (word.startsWith(boundaryPrefix)) { + score += FileStorage.SCORE_WORD_BOUNDARY + break + } + } + + // Penalty for longer paths (prefer shorter, more specific matches) + // Use logarithmic scaling to prevent long paths from dominating the score + // A 50-char path gets ~-16 penalty, 100-char gets ~-18, 200-char gets ~-21 + score -= Math.log(filePath.length + 1) * FileStorage.PATH_LENGTH_PENALTY_FACTOR + + return score + } + + /** + * Convert query to glob pattern for ripgrep pre-filtering + * e.g., "updater" -> "*u*p*d*a*t*e*r*" + */ + private queryToGlobPattern(query: string): string { + // Escape special glob characters (including ! for negation) + const escaped = query.replace(/[[\]{}()*+?.,\\^$|#!]/g, '\\$&') + // Convert to fuzzy glob: each char separated by * + return '*' + escaped.split('').join('*') + '*' + } + + /** + * Greedy substring match: check if all characters in query can be matched + * by finding consecutive substrings in text (not necessarily single chars) + * e.g., "updatercontroller" matches "updateController" by: + * "update" + "r" (from Controller) + "controller" + */ + private isGreedySubstringMatch(text: string, query: string): boolean { + const textLower = text.toLowerCase() + const queryLower = query.toLowerCase() + + let queryIndex = 0 + let searchStart = 0 + + while (queryIndex < queryLower.length) { + // Try to find the longest matching substring starting at queryIndex + let bestMatchLen = 0 + let bestMatchPos = -1 + + for (let len = queryLower.length - queryIndex; len >= 1; len--) { + const substr = queryLower.slice(queryIndex, queryIndex + len) + const foundAt = textLower.indexOf(substr, searchStart) + if (foundAt !== -1) { + bestMatchLen = len + bestMatchPos = foundAt + break // Found longest possible match + } + } + + if (bestMatchLen === 0) { + // No substring match found, query cannot be matched + return false + } + + queryIndex += bestMatchLen + searchStart = bestMatchPos + bestMatchLen + } + + return true + } + + /** + * Calculate greedy substring match score (higher is better) + * Rewards: fewer match fragments, shorter match span, matches in filename + */ + private getGreedyMatchScore(filePath: string, query: string): number { + const textLower = filePath.toLowerCase() + const queryLower = query.toLowerCase() + const fileName = filePath.split('/').pop() || '' + const fileNameLower = fileName.toLowerCase() + + let queryIndex = 0 + let searchStart = 0 + let fragmentCount = 0 + let firstMatchPos = -1 + let lastMatchEnd = 0 + + while (queryIndex < queryLower.length) { + let bestMatchLen = 0 + let bestMatchPos = -1 + + for (let len = queryLower.length - queryIndex; len >= 1; len--) { + const substr = queryLower.slice(queryIndex, queryIndex + len) + const foundAt = textLower.indexOf(substr, searchStart) + if (foundAt !== -1) { + bestMatchLen = len + bestMatchPos = foundAt + break + } + } + + if (bestMatchLen === 0) { + return -Infinity // No match + } + + fragmentCount++ + if (firstMatchPos === -1) firstMatchPos = bestMatchPos + lastMatchEnd = bestMatchPos + bestMatchLen + queryIndex += bestMatchLen + searchStart = lastMatchEnd + } + + const matchSpan = lastMatchEnd - firstMatchPos + let score = 0 + + // Fewer fragments = better (single continuous match is best) + // Max bonus when fragmentCount=1, decreases as fragments increase + score += Math.max(0, 100 - (fragmentCount - 1) * 30) + + // Shorter span relative to query length = better (tighter match) + // Perfect match: span equals query length + const spanRatio = queryLower.length / matchSpan + score += spanRatio * 50 + + // Bonus for match in filename + if (this.isGreedySubstringMatch(fileNameLower, queryLower)) { + score += 80 + } + + // Penalty for longer paths + score -= Math.log(filePath.length + 1) * 4 + + return score + } + + /** + * Build common ripgrep arguments for file listing + */ + private buildRipgrepBaseArgs(options: Required, resolvedPath: string): string[] { + const args: string[] = ['--files'] // Handle hidden files if (!options.includeHidden) { @@ -1070,82 +1296,74 @@ class FileStorage { args.push('--max-depth', options.maxDepth.toString()) } - // Handle max count - if (options.maxEntries > 0) { - args.push('--max-count', options.maxEntries.toString()) - } - - // Add search pattern (search in content) - args.push(options.searchPattern) - - // Add the directory path args.push(resolvedPath) - const { exitCode, output } = await executeRipgrep(args) - - // Exit code 0 means files found, 1 means no files found (still success), 2+ means error - if (exitCode >= 2) { - throw new Error(`Ripgrep failed with exit code ${exitCode}: ${output}`) - } - - // Parse ripgrep output (already sorted by relevance) - const results = output - .split('\n') - .filter((line) => line.trim()) - .map((line) => line.replace(/\\/g, '/')) - .slice(0, options.maxEntries) - - return results + return args } private async listDirectoryWithRipgrep( resolvedPath: string, options: Required ): Promise { - const maxEntries = options.maxEntries + // Fuzzy search mode: use ripgrep glob for pre-filtering, then score in JS + if (options.fuzzy && options.searchPattern && options.searchPattern !== '.') { + const args = this.buildRipgrepBaseArgs(options, resolvedPath) - // Step 1: Search by filename first + // Insert glob pattern before the path (last element) + const globPattern = this.queryToGlobPattern(options.searchPattern) + args.splice(args.length - 1, 0, '--iglob', globPattern) + + const { exitCode, output } = await executeRipgrep(args) + + if (exitCode >= 2) { + throw new Error(`Ripgrep failed with exit code ${exitCode}: ${output}`) + } + + const filteredFiles = output + .split('\n') + .filter((line) => line.trim()) + .map((line) => line.replace(/\\/g, '/')) + + // If fuzzy glob found results, validate fuzzy match, sort and return + if (filteredFiles.length > 0) { + return filteredFiles + .filter((file) => this.isFuzzyMatch(file, options.searchPattern)) + .map((file) => ({ file, score: this.getFuzzyMatchScore(file, options.searchPattern) })) + .sort((a, b) => b.score - a.score) + .slice(0, options.maxEntries) + .map((item) => item.file) + } + + // Fallback: if no results, try greedy substring match on all files + logger.debug('Fuzzy glob returned no results, falling back to greedy substring match') + const fallbackArgs = this.buildRipgrepBaseArgs(options, resolvedPath) + + const fallbackResult = await executeRipgrep(fallbackArgs) + + if (fallbackResult.exitCode >= 2) { + return [] + } + + const allFiles = fallbackResult.output + .split('\n') + .filter((line) => line.trim()) + .map((line) => line.replace(/\\/g, '/')) + + const greedyMatched = allFiles.filter((file) => this.isGreedySubstringMatch(file, options.searchPattern)) + + return greedyMatched + .map((file) => ({ file, score: this.getGreedyMatchScore(file, options.searchPattern) })) + .sort((a, b) => b.score - a.score) + .slice(0, options.maxEntries) + .map((item) => item.file) + } + + // Fallback: search by filename only (non-fuzzy mode) logger.debug('Searching by filename pattern', { pattern: options.searchPattern, path: resolvedPath }) const filenameResults = await this.searchByFilename(resolvedPath, options) logger.debug('Found matches by filename', { count: filenameResults.length }) - - // If we have enough filename matches, return them - if (filenameResults.length >= maxEntries) { - return filenameResults.slice(0, maxEntries) - } - - // Step 2: If filename matches are less than maxEntries, search by content to fill up - logger.debug('Filename matches insufficient, searching by content to fill up', { - filenameCount: filenameResults.length, - needed: maxEntries - filenameResults.length - }) - - // Adjust maxEntries for content search to get enough results - const contentOptions = { - ...options, - maxEntries: maxEntries - filenameResults.length + 20 // Request extra to account for duplicates - } - - const contentResults = await this.searchByContent(resolvedPath, contentOptions) - - logger.debug('Found matches by content', { count: contentResults.length }) - - // Combine results: filename matches first, then content matches (deduplicated) - const combined = [...filenameResults] - const filenameSet = new Set(filenameResults) - - for (const filePath of contentResults) { - if (!filenameSet.has(filePath)) { - combined.push(filePath) - if (combined.length >= maxEntries) { - break - } - } - } - - logger.debug('Combined results', { total: combined.length, filenameCount: filenameResults.length }) - return combined.slice(0, maxEntries) + return filenameResults.slice(0, options.maxEntries) } public validateNotesDirectory = async (_: Electron.IpcMainInvokeEvent, dirPath: string): Promise => { @@ -1319,7 +1537,7 @@ class FileStorage { await fs.promises.writeFile(destPath, buffer) const stats = await fs.promises.stat(destPath) - const fileType = getFileType(ext) + const fileType = await this.getFileType(destPath) return { id: uuid, @@ -1612,6 +1830,10 @@ class FileStorage { } public isTextFile = async (_: Electron.IpcMainInvokeEvent, filePath: string): Promise => { + return this._isTextFile(filePath) + } + + private _isTextFile = async (filePath: string): Promise => { try { const isBinary = await isBinaryFile(filePath) if (isBinary) { diff --git a/src/main/services/LocalTransferService.ts b/src/main/services/LocalTransferService.ts new file mode 100644 index 0000000000..bc2743757c --- /dev/null +++ b/src/main/services/LocalTransferService.ts @@ -0,0 +1,207 @@ +import { loggerService } from '@logger' +import type { LocalTransferPeer, LocalTransferState } from '@shared/config/types' +import { IpcChannel } from '@shared/IpcChannel' +import type { Browser, Service } from 'bonjour-service' +import Bonjour from 'bonjour-service' + +import { windowService } from './WindowService' + +const SERVICE_TYPE = 'cherrystudio' +const SERVICE_PROTOCOL = 'tcp' as const + +const logger = loggerService.withContext('LocalTransferService') + +type StartDiscoveryOptions = { + resetList?: boolean +} + +class LocalTransferService { + private static instance: LocalTransferService + private bonjour: Bonjour | null = null + private browser: Browser | null = null + private services = new Map() + private isScanning = false + private lastScanStartedAt?: number + private lastUpdatedAt = Date.now() + private lastError?: string + + private constructor() {} + + public static getInstance(): LocalTransferService { + if (!LocalTransferService.instance) { + LocalTransferService.instance = new LocalTransferService() + } + return LocalTransferService.instance + } + + public startDiscovery(options?: StartDiscoveryOptions): LocalTransferState { + if (options?.resetList) { + this.services.clear() + } + + this.isScanning = true + this.lastScanStartedAt = Date.now() + this.lastUpdatedAt = Date.now() + this.lastError = undefined + this.restartBrowser() + this.broadcastState() + return this.getState() + } + + public stopDiscovery(): LocalTransferState { + if (this.browser) { + try { + this.browser.stop() + } catch (error) { + logger.warn('Failed to stop local transfer browser', error as Error) + } + } + this.isScanning = false + this.lastUpdatedAt = Date.now() + this.broadcastState() + return this.getState() + } + + public getState(): LocalTransferState { + const services = Array.from(this.services.values()).sort((a, b) => a.name.localeCompare(b.name)) + return { + services, + isScanning: this.isScanning, + lastScanStartedAt: this.lastScanStartedAt, + lastUpdatedAt: this.lastUpdatedAt, + lastError: this.lastError + } + } + + public getPeerById(id: string): LocalTransferPeer | undefined { + return this.services.get(id) + } + + public dispose(): void { + this.stopDiscovery() + this.services.clear() + this.browser?.removeAllListeners() + this.browser = null + if (this.bonjour) { + try { + this.bonjour.destroy() + } catch (error) { + logger.warn('Failed to destroy Bonjour instance', error as Error) + } + this.bonjour = null + } + } + + private getBonjour(): Bonjour { + if (!this.bonjour) { + this.bonjour = new Bonjour() + } + return this.bonjour + } + + private restartBrowser(): void { + // Clean up existing browser + if (this.browser) { + this.browser.removeAllListeners() + try { + this.browser.stop() + } catch (error) { + logger.warn('Error while stopping Bonjour browser', error as Error) + } + this.browser = null + } + + // Destroy and recreate Bonjour instance to prevent socket leaks + if (this.bonjour) { + try { + this.bonjour.destroy() + } catch (error) { + logger.warn('Error while destroying Bonjour instance', error as Error) + } + this.bonjour = null + } + + const browser = this.getBonjour().find({ type: SERVICE_TYPE, protocol: SERVICE_PROTOCOL }) + this.browser = browser + this.bindBrowserEvents(browser) + + try { + browser.start() + logger.info('Local transfer discovery started') + } catch (error) { + const err = error instanceof Error ? error : new Error(String(error)) + this.lastError = err.message + logger.error('Failed to start local transfer discovery', err) + } + } + + private bindBrowserEvents(browser: Browser) { + browser.on('up', (service) => { + const peer = this.normalizeService(service) + logger.info(`LAN peer detected: ${peer.name} (${peer.addresses.join(', ')})`) + this.services.set(peer.id, peer) + this.lastUpdatedAt = Date.now() + this.broadcastState() + }) + + browser.on('down', (service) => { + const key = this.buildServiceKey(service.fqdn || service.name, service.host, service.port) + if (this.services.delete(key)) { + logger.info(`LAN peer removed: ${service.name}`) + this.lastUpdatedAt = Date.now() + this.broadcastState() + } + }) + + browser.on('error', (error) => { + const err = error instanceof Error ? error : new Error(String(error)) + logger.error('Local transfer discovery error', err) + this.lastError = err.message + this.broadcastState() + }) + } + + private normalizeService(service: Service): LocalTransferPeer { + const addressCandidates = [...(service.addresses || []), service.referer?.address].filter( + (value): value is string => typeof value === 'string' && value.length > 0 + ) + const addresses = Array.from(new Set(addressCandidates)) + const txtEntries = Object.entries(service.txt || {}) + const txt = + txtEntries.length > 0 + ? Object.fromEntries( + txtEntries.map(([key, value]) => [key, value === undefined || value === null ? '' : String(value)]) + ) + : undefined + + const peer: LocalTransferPeer = { + id: this.buildServiceKey(service.fqdn || service.name, service.host, service.port), + name: service.name, + host: service.host, + fqdn: service.fqdn, + port: service.port, + type: service.type, + protocol: service.protocol, + addresses, + txt, + updatedAt: Date.now() + } + + return peer + } + + private buildServiceKey(name?: string, host?: string, port?: number): string { + const raw = [name, host, port?.toString()].filter(Boolean).join('-') + return raw || `service-${Date.now()}` + } + + private broadcastState() { + const mainWindow = windowService.getMainWindow() + if (!mainWindow || mainWindow.isDestroyed()) { + return + } + mainWindow.webContents.send(IpcChannel.LocalTransfer_ServicesUpdated, this.getState()) + } +} + +export const localTransferService = LocalTransferService.getInstance() diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 3925376226..7d36e6d7e3 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -6,7 +6,7 @@ import { loggerService } from '@logger' import { createInMemoryMCPServer } from '@main/mcpServers/factory' import { makeSureDirExists, removeEnvProxy } from '@main/utils' import { buildFunctionCallToolName } from '@main/utils/mcp' -import { getBinaryName, getBinaryPath } from '@main/utils/process' +import { findCommandInShellEnv, getBinaryName, getBinaryPath, isBinaryExists } from '@main/utils/process' import getLoginShellEnvironment from '@main/utils/shell-env' import { TraceMethod, withSpanFunc } from '@mcp-trace/trace-core' import { Client } from '@modelcontextprotocol/sdk/client/index.js' @@ -33,6 +33,7 @@ import { import { nanoid } from '@reduxjs/toolkit' import { HOME_CHERRY_DIR } from '@shared/config/constant' import type { MCPProgressEvent } from '@shared/config/types' +import type { MCPServerLogEntry } from '@shared/config/types' import { IpcChannel } from '@shared/IpcChannel' import { defaultAppHeaders } from '@shared/utils' import { @@ -56,6 +57,7 @@ import { CacheService } from './CacheService' import DxtService from './DxtService' import { CallBackServer } from './mcp/oauth/callback' import { McpOAuthClientProvider } from './mcp/oauth/provider' +import { ServerLogBuffer } from './mcp/ServerLogBuffer' import { windowService } from './WindowService' // Generic type for caching wrapped functions @@ -142,6 +144,7 @@ class McpService { private pendingClients: Map> = new Map() private dxtService = new DxtService() private activeToolCalls: Map = new Map() + private serverLogs = new ServerLogBuffer(200) constructor() { this.initClient = this.initClient.bind(this) @@ -159,6 +162,7 @@ class McpService { this.cleanup = this.cleanup.bind(this) this.checkMcpConnectivity = this.checkMcpConnectivity.bind(this) this.getServerVersion = this.getServerVersion.bind(this) + this.getServerLogs = this.getServerLogs.bind(this) } private getServerKey(server: MCPServer): string { @@ -172,6 +176,19 @@ class McpService { }) } + private emitServerLog(server: MCPServer, entry: MCPServerLogEntry) { + const serverKey = this.getServerKey(server) + this.serverLogs.append(serverKey, entry) + const mainWindow = windowService.getMainWindow() + if (mainWindow) { + mainWindow.webContents.send(IpcChannel.Mcp_ServerLog, { ...entry, serverId: server.id }) + } + } + + public getServerLogs(_: Electron.IpcMainInvokeEvent, server: MCPServer): MCPServerLogEntry[] { + return this.serverLogs.get(this.getServerKey(server)) + } + async initClient(server: MCPServer): Promise { const serverKey = this.getServerKey(server) @@ -232,6 +249,26 @@ class McpService { StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport > => { // Create appropriate transport based on configuration + + // Special case for nowledgeMem - uses HTTP transport instead of in-memory + if (isBuiltinMCPServer(server) && server.name === BuiltinMCPServerNames.nowledgeMem) { + const nowledgeMemUrl = 'http://127.0.0.1:14242/mcp' + const options: StreamableHTTPClientTransportOptions = { + fetch: async (url, init) => { + return net.fetch(typeof url === 'string' ? url : url.toString(), init) + }, + requestInit: { + headers: { + ...defaultAppHeaders(), + APP: 'Cherry Studio' + } + }, + authProvider + } + getServerLogger(server).debug(`Using StreamableHTTPClientTransport for ${server.name}`) + return new StreamableHTTPClientTransport(new URL(nowledgeMemUrl), options) + } + if (isBuiltinMCPServer(server) && server.name !== BuiltinMCPServerNames.mcpAutoInstall) { getServerLogger(server).debug(`Using in-memory transport`) const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() @@ -281,6 +318,10 @@ class McpService { } else if (server.command) { let cmd = server.command + // Get login shell environment first - needed for command detection and server execution + // Note: getLoginShellEnvironment() is memoized, so subsequent calls are fast + const loginShellEnv = await getLoginShellEnvironment() + // For DXT servers, use resolved configuration with platform overrides and variable substitution if (server.dxtPath) { const resolvedConfig = this.dxtService.getResolvedMcpConfig(server.dxtPath) @@ -302,18 +343,45 @@ class McpService { } if (server.command === 'npx') { - cmd = await getBinaryPath('bun') - getServerLogger(server).debug(`Using command`, { command: cmd }) + // First, check if npx is available in user's shell environment + const npxPath = await findCommandInShellEnv('npx', loginShellEnv) - // add -x to args if args exist - if (args && args.length > 0) { - if (!args.includes('-y')) { - args.unshift('-y') - } - if (!args.includes('x')) { - args.unshift('x') + if (npxPath) { + // Use system npx + cmd = npxPath + getServerLogger(server).debug(`Using system npx`, { command: cmd }) + } else { + // System npx not found, try bundled bun as fallback + getServerLogger(server).debug(`System npx not found, checking for bundled bun`) + + if (await isBinaryExists('bun')) { + // Fall back to bundled bun + cmd = await getBinaryPath('bun') + getServerLogger(server).info(`Using bundled bun as fallback (npx not found in PATH)`, { + command: cmd + }) + + // Transform args for bun x format + if (args && args.length > 0) { + if (!args.includes('-y')) { + args.unshift('-y') + } + if (!args.includes('x')) { + args.unshift('x') + } + } + } else { + // Neither npx nor bun available + throw new Error( + 'npx not found in PATH and bundled bun is not available. This may indicate an installation issue.\n' + + 'Please either:\n' + + '1. Install Node.js (which includes npx) from https://nodejs.org\n' + + '2. Run the MCP dependencies installer from Settings\n' + + '3. Restart the application if you recently installed Node.js' + ) } } + if (server.registryUrl) { server.env = { ...server.env, @@ -328,7 +396,35 @@ class McpService { } } } else if (server.command === 'uvx' || server.command === 'uv') { - cmd = await getBinaryPath(server.command) + // First, check if uvx/uv is available in user's shell environment + const uvPath = await findCommandInShellEnv(server.command, loginShellEnv) + + if (uvPath) { + // Use system uvx/uv + cmd = uvPath + getServerLogger(server).debug(`Using system ${server.command}`, { command: cmd }) + } else { + // System command not found, try bundled version as fallback + getServerLogger(server).debug(`System ${server.command} not found, checking for bundled version`) + + if (await isBinaryExists(server.command)) { + // Fall back to bundled version + cmd = await getBinaryPath(server.command) + getServerLogger(server).info(`Using bundled ${server.command} as fallback (not found in PATH)`, { + command: cmd + }) + } else { + // Neither system nor bundled available + throw new Error( + `${server.command} not found in PATH and bundled version is not available. This may indicate an installation issue.\n` + + 'Please either:\n' + + '1. Install uv from https://github.com/astral-sh/uv\n' + + '2. Run the MCP dependencies installer from Settings\n' + + `3. Restart the application if you recently installed ${server.command}` + ) + } + } + if (server.registryUrl) { server.env = { ...server.env, @@ -339,8 +435,6 @@ class McpService { } getServerLogger(server).debug(`Starting server`, { command: cmd, args }) - // Logger.info(`[MCP] Environment variables for server:`, server.env) - const loginShellEnv = await getLoginShellEnvironment() // Bun not support proxy https://github.com/oven-sh/bun/issues/16812 if (cmd.includes('bun')) { @@ -366,9 +460,18 @@ class McpService { } const stdioTransport = new StdioClientTransport(transportOptions) - stdioTransport.stderr?.on('data', (data) => - getServerLogger(server).debug(`Stdio stderr`, { data: data.toString() }) - ) + stdioTransport.stderr?.on('data', (data) => { + const msg = data.toString() + getServerLogger(server).debug(`Stdio stderr`, { data: msg }) + this.emitServerLog(server, { + timestamp: Date.now(), + level: 'stderr', + message: msg.trim(), + source: 'stdio' + }) + }) + // StdioClientTransport does not expose stdout as a readable stream for raw logging + // (stdout is reserved for JSON-RPC). Avoid attaching a listener that would never fire. return stdioTransport } else { throw new Error('Either baseUrl or command must be provided') @@ -436,6 +539,13 @@ class McpService { } } + this.emitServerLog(server, { + timestamp: Date.now(), + level: 'info', + message: 'Server connected', + source: 'client' + }) + // Store the new client in the cache this.clients.set(serverKey, client) @@ -446,9 +556,22 @@ class McpService { this.clearServerCache(serverKey) logger.debug(`Activated server: ${server.name}`) + this.emitServerLog(server, { + timestamp: Date.now(), + level: 'info', + message: 'Server activated', + source: 'client' + }) return client } catch (error) { getServerLogger(server).error(`Error activating server ${server.name}`, error as Error) + this.emitServerLog(server, { + timestamp: Date.now(), + level: 'error', + message: `Error activating server: ${(error as Error)?.message}`, + data: redactSensitive(error), + source: 'client' + }) throw error } } finally { @@ -506,6 +629,16 @@ class McpService { // Set up logging message notification handler client.setNotificationHandler(LoggingMessageNotificationSchema, async (notification) => { logger.debug(`Message from server ${server.name}:`, notification.params) + const msg = notification.params?.message + if (msg) { + this.emitServerLog(server, { + timestamp: Date.now(), + level: (notification.params?.level as MCPServerLogEntry['level']) || 'info', + message: typeof msg === 'string' ? msg : JSON.stringify(msg), + data: redactSensitive(notification.params?.data), + source: notification.params?.logger || 'server' + }) + } }) getServerLogger(server).debug(`Set up notification handlers`) @@ -540,6 +673,7 @@ class McpService { this.clients.delete(serverKey) // Clear all caches for this server this.clearServerCache(serverKey) + this.serverLogs.remove(serverKey) } else { logger.warn(`No client found for server`, { serverKey }) } @@ -548,6 +682,12 @@ class McpService { async stopServer(_: Electron.IpcMainInvokeEvent, server: MCPServer) { const serverKey = this.getServerKey(server) getServerLogger(server).debug(`Stopping server`) + this.emitServerLog(server, { + timestamp: Date.now(), + level: 'info', + message: 'Stopping server', + source: 'client' + }) await this.closeClient(serverKey) } @@ -574,6 +714,12 @@ class McpService { async restartServer(_: Electron.IpcMainInvokeEvent, server: MCPServer) { getServerLogger(server).debug(`Restarting server`) const serverKey = this.getServerKey(server) + this.emitServerLog(server, { + timestamp: Date.now(), + level: 'info', + message: 'Restarting server', + source: 'client' + }) await this.closeClient(serverKey) // Clear cache before restarting to ensure fresh data this.clearServerCache(serverKey) @@ -606,9 +752,22 @@ class McpService { // Attempt to list tools as a way to check connectivity await client.listTools() getServerLogger(server).debug(`Connectivity check successful`) + this.emitServerLog(server, { + timestamp: Date.now(), + level: 'info', + message: 'Connectivity check successful', + source: 'connectivity' + }) return true } catch (error) { getServerLogger(server).error(`Connectivity check failed`, error as Error) + this.emitServerLog(server, { + timestamp: Date.now(), + level: 'error', + message: `Connectivity check failed: ${(error as Error).message}`, + data: redactSensitive(error), + source: 'connectivity' + }) // Close the client if connectivity check fails to ensure a clean state for the next attempt const serverKey = this.getServerKey(server) await this.closeClient(serverKey) @@ -626,7 +785,7 @@ class McpService { ...tool, inputSchema: z.parse(MCPToolInputSchema, tool.inputSchema), outputSchema: tool.outputSchema ? z.parse(MCPToolOutputSchema, tool.outputSchema) : undefined, - id: buildFunctionCallToolName(server.name, tool.name, server.id), + id: buildFunctionCallToolName(server.name, tool.name), serverId: server.id, serverName: server.name, type: 'mcp' diff --git a/src/main/services/OvmsManager.ts b/src/main/services/OvmsManager.ts index 3a32d74ecf..67d6d9a9df 100644 --- a/src/main/services/OvmsManager.ts +++ b/src/main/services/OvmsManager.ts @@ -3,6 +3,8 @@ import { homedir } from 'node:os' import { promisify } from 'node:util' import { loggerService } from '@logger' +import { isWin } from '@main/constant' +import { getCpuName } from '@main/utils/system' import { HOME_CHERRY_DIR } from '@shared/config/constant' import * as fs from 'fs-extra' import * as path from 'path' @@ -11,6 +13,8 @@ const logger = loggerService.withContext('OvmsManager') const execAsync = promisify(exec) +export const isOvmsSupported = isWin && getCpuName().toLowerCase().includes('intel') + interface OvmsProcess { pid: number path: string @@ -29,6 +33,12 @@ interface OvmsConfig { class OvmsManager { private ovms: OvmsProcess | null = null + constructor() { + if (!isOvmsSupported) { + throw new Error('OVMS Manager is only supported on Windows platform with Intel CPU.') + } + } + /** * Recursively terminate a process and all its child processes * @param pid Process ID to terminate @@ -102,32 +112,10 @@ class OvmsManager { */ public async stopOvms(): Promise<{ success: boolean; message?: string }> { try { - // Check if OVMS process is running - const psCommand = `Get-Process -Name "ovms" -ErrorAction SilentlyContinue | Select-Object Id, Path | ConvertTo-Json` - const { stdout } = await execAsync(`powershell -Command "${psCommand}"`) - - if (!stdout.trim()) { - logger.info('OVMS process is not running') - return { success: true, message: 'OVMS process is not running' } - } - - const processes = JSON.parse(stdout) - const processList = Array.isArray(processes) ? processes : [processes] - - if (processList.length === 0) { - logger.info('OVMS process is not running') - return { success: true, message: 'OVMS process is not running' } - } - - // Terminate all OVMS processes using terminalProcess - for (const process of processList) { - const result = await this.terminalProcess(process.Id) - if (!result.success) { - logger.error(`Failed to terminate OVMS process with PID: ${process.Id}, ${result.message}`) - return { success: false, message: `Failed to terminate OVMS process: ${result.message}` } - } - logger.info(`Terminated OVMS process with PID: ${process.Id}`) - } + // close the OVMS process + await execAsync( + `powershell -Command "Get-WmiObject Win32_Process | Where-Object { $_.CommandLine -like 'ovms.exe*' } | ForEach-Object { Stop-Process -Id $_.ProcessId -Force }"` + ) // Reset the ovms instance this.ovms = null @@ -584,4 +572,5 @@ class OvmsManager { } } -export default OvmsManager +// Export singleton instance +export const ovmsManager = isOvmsSupported ? new OvmsManager() : undefined diff --git a/src/main/services/ReduxService.ts b/src/main/services/ReduxService.ts index cdbaff42bf..8880691a24 100644 --- a/src/main/services/ReduxService.ts +++ b/src/main/services/ReduxService.ts @@ -1,3 +1,19 @@ +/** + * @deprecated Scheduled for removal in v2.0.0 + * -------------------------------------------------------------------------- + * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex) + * -------------------------------------------------------------------------- + * STOP: Feature PRs affecting this file are currently BLOCKED. + * Only critical bug fixes are accepted during this migration phase. + * + * This file is being refactored to v2 standards. + * Any non-critical changes will conflict with the ongoing work. + * + * 🔗 Context & Status: + * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954 + * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162 + * -------------------------------------------------------------------------- + */ import { loggerService } from '@logger' import { IpcChannel } from '@shared/IpcChannel' import { ipcMain } from 'electron' diff --git a/src/main/services/SearchService.ts b/src/main/services/SearchService.ts index 8a4e42099a..6c69f80889 100644 --- a/src/main/services/SearchService.ts +++ b/src/main/services/SearchService.ts @@ -14,38 +14,36 @@ export class SearchService { return SearchService.instance } - constructor() { - // Initialize the service - } - - private async createNewSearchWindow(uid: string): Promise { + private async createNewSearchWindow(uid: string, show: boolean = false): Promise { const newWindow = new BrowserWindow({ - width: 800, - height: 600, - show: false, + width: 1280, + height: 768, + show, webPreferences: { nodeIntegration: true, contextIsolation: false, devTools: is.dev } }) - newWindow.webContents.session.webRequest.onBeforeSendHeaders({ urls: ['*://*/*'] }, (details, callback) => { - const headers = { - ...details.requestHeaders, - 'User-Agent': - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' - } - callback({ requestHeaders: headers }) - }) + this.searchWindows[uid] = newWindow - newWindow.on('closed', () => { - delete this.searchWindows[uid] - }) + newWindow.on('closed', () => delete this.searchWindows[uid]) + + newWindow.webContents.userAgent = + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36' + return newWindow } - public async openSearchWindow(uid: string): Promise { - await this.createNewSearchWindow(uid) + public async openSearchWindow(uid: string, show: boolean = false): Promise { + const existingWindow = this.searchWindows[uid] + + if (existingWindow) { + show && existingWindow.show() + return + } + + await this.createNewSearchWindow(uid, show) } public async closeSearchWindow(uid: string): Promise { diff --git a/src/main/services/SelectionService.ts b/src/main/services/SelectionService.ts index a096dfcfd7..629e67401c 100644 --- a/src/main/services/SelectionService.ts +++ b/src/main/services/SelectionService.ts @@ -1393,6 +1393,56 @@ export class SelectionService { actionWindow.setAlwaysOnTop(isPinned) } + /** + * [Windows only] Manual window resize handler + * + * ELECTRON BUG WORKAROUND: + * In Electron, when using `frame: false` + `transparent: true`, the native window + * resize functionality is broken on Windows. This is a known Electron bug. + * See: https://github.com/electron/electron/issues/48554 + * + * This method can be removed once the Electron bug is fixed. + */ + public resizeActionWindow(actionWindow: BrowserWindow, deltaX: number, deltaY: number, direction: string): void { + const bounds = actionWindow.getBounds() + const minWidth = 300 + const minHeight = 200 + + let { x, y, width, height } = bounds + + // Handle horizontal resize + if (direction.includes('e')) { + width = Math.max(minWidth, width + deltaX) + } + if (direction.includes('w')) { + const newWidth = Math.max(minWidth, width - deltaX) + if (newWidth !== width) { + x = x + (width - newWidth) + width = newWidth + } + } + + // Handle vertical resize + if (direction.includes('s')) { + height = Math.max(minHeight, height + deltaY) + } + if (direction.includes('n')) { + const newHeight = Math.max(minHeight, height - deltaY) + if (newHeight !== height) { + y = y + (height - newHeight) + height = newHeight + } + } + + actionWindow.setBounds({ x, y, width, height }) + + // [Windows only] Update remembered window size for custom resize + // setBounds() may not trigger the 'resized' event, so we need to update manually + if (this.isRemeberWinSize) { + this.lastActionWindowSize = { width, height } + } + } + /** * Update trigger mode behavior * Switches between selection-based and alt-key based triggering @@ -1510,6 +1560,18 @@ export class SelectionService { } }) + // [Windows only] Electron bug workaround - can be removed once fixed + // See: https://github.com/electron/electron/issues/48554 + ipcMain.handle( + IpcChannel.Selection_ActionWindowResize, + (event, deltaX: number, deltaY: number, direction: string) => { + const actionWindow = BrowserWindow.fromWebContents(event.sender) + if (actionWindow) { + selectionService?.resizeActionWindow(actionWindow, deltaX, deltaY, direction) + } + } + ) + this.isIpcHandlerRegistered = true } diff --git a/src/main/services/ShortcutService.ts b/src/main/services/ShortcutService.ts index 583dbbd95c..c99e0b5dc0 100644 --- a/src/main/services/ShortcutService.ts +++ b/src/main/services/ShortcutService.ts @@ -1,3 +1,19 @@ +/** + * @deprecated Scheduled for removal in v2.0.0 + * -------------------------------------------------------------------------- + * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex) + * -------------------------------------------------------------------------- + * STOP: Feature PRs affecting this file are currently BLOCKED. + * Only critical bug fixes are accepted during this migration phase. + * + * This file is being refactored to v2 standards. + * Any non-critical changes will conflict with the ongoing work. + * + * 🔗 Context & Status: + * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954 + * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162 + * -------------------------------------------------------------------------- + */ import { loggerService } from '@logger' import { handleZoomFactor } from '@main/utils/zoom' import type { Shortcut } from '@types' @@ -35,6 +51,15 @@ function getShortcutHandler(shortcut: Shortcut) { } case 'mini_window': return () => { + // 在处理器内部检查QuickAssistant状态,而不是在注册时检查 + const quickAssistantEnabled = configManager.getEnableQuickAssistant() + logger.info(`mini_window shortcut triggered, QuickAssistant enabled: ${quickAssistantEnabled}`) + + if (!quickAssistantEnabled) { + logger.warn('QuickAssistant is disabled, ignoring mini_window shortcut trigger') + return + } + windowService.toggleMiniWindow() } case 'selection_assistant_toggle': @@ -190,11 +215,10 @@ export function registerShortcuts(window: BrowserWindow) { break case 'mini_window': - //available only when QuickAssistant enabled - if (!configManager.getEnableQuickAssistant()) { - return - } + // 移除注册时的条件检查,在处理器内部进行检查 + logger.info(`Processing mini_window shortcut, enabled: ${shortcut.enabled}`) showMiniWindowAccelerator = formatShortcutKey(shortcut.shortcut) + logger.debug(`Mini window accelerator set to: ${showMiniWindowAccelerator}`) break case 'selection_assistant_toggle': diff --git a/src/main/services/StoreSyncService.ts b/src/main/services/StoreSyncService.ts index 57f07195b6..6013afdd57 100644 --- a/src/main/services/StoreSyncService.ts +++ b/src/main/services/StoreSyncService.ts @@ -1,3 +1,19 @@ +/** + * @deprecated Scheduled for removal in v2.0.0 + * -------------------------------------------------------------------------- + * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex) + * -------------------------------------------------------------------------- + * STOP: Feature PRs affecting this file are currently BLOCKED. + * Only critical bug fixes are accepted during this migration phase. + * + * This file is being refactored to v2 standards. + * Any non-critical changes will conflict with the ongoing work. + * + * 🔗 Context & Status: + * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954 + * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162 + * -------------------------------------------------------------------------- + */ import { IpcChannel } from '@shared/IpcChannel' import type { StoreSyncAction } from '@types' import { BrowserWindow, ipcMain } from 'electron' diff --git a/src/main/services/WebSocketService.ts b/src/main/services/WebSocketService.ts deleted file mode 100644 index e52919e96a..0000000000 --- a/src/main/services/WebSocketService.ts +++ /dev/null @@ -1,359 +0,0 @@ -import { loggerService } from '@logger' -import type { WebSocketCandidatesResponse, WebSocketStatusResponse } from '@shared/config/types' -import * as fs from 'fs' -import { networkInterfaces } from 'os' -import * as path from 'path' -import type { Socket } from 'socket.io' -import { Server } from 'socket.io' - -import { windowService } from './WindowService' - -const logger = loggerService.withContext('WebSocketService') - -class WebSocketService { - private io: Server | null = null - private isStarted = false - private port = 7017 - private connectedClients = new Set() - - private getLocalIpAddress(): string | undefined { - const interfaces = networkInterfaces() - - // 按优先级排序的网络接口名称模式 - const interfacePriority = [ - // macOS: 以太网/Wi-Fi 优先 - /^en[0-9]+$/, // en0, en1 (以太网/Wi-Fi) - /^(en|eth)[0-9]+$/, // 以太网接口 - /^wlan[0-9]+$/, // 无线接口 - // Windows: 以太网/Wi-Fi 优先 - /^(Ethernet|Wi-Fi|Local Area Connection)/, - /^(Wi-Fi|无线网络连接)/, - // Linux: 以太网/Wi-Fi 优先 - /^(eth|enp|wlp|wlan)[0-9]+/, - // 虚拟化接口(低优先级) - /^bridge[0-9]+$/, // Docker bridge - /^veth[0-9]+$/, // Docker veth - /^docker[0-9]+/, // Docker interfaces - /^br-[0-9a-f]+/, // Docker bridge - /^vmnet[0-9]+$/, // VMware - /^vboxnet[0-9]+$/, // VirtualBox - // VPN 隧道接口(低优先级) - /^utun[0-9]+$/, // macOS VPN - /^tun[0-9]+$/, // Linux/Unix VPN - /^tap[0-9]+$/, // TAP interfaces - /^tailscale[0-9]*$/, // Tailscale VPN - /^wg[0-9]+$/ // WireGuard VPN - ] - - const candidates: Array<{ interface: string; address: string; priority: number }> = [] - - for (const [name, ifaces] of Object.entries(interfaces)) { - for (const iface of ifaces || []) { - if (iface.family === 'IPv4' && !iface.internal) { - // 计算接口优先级 - let priority = 999 // 默认最低优先级 - for (let i = 0; i < interfacePriority.length; i++) { - if (interfacePriority[i].test(name)) { - priority = i - break - } - } - - candidates.push({ - interface: name, - address: iface.address, - priority - }) - } - } - } - - if (candidates.length === 0) { - logger.warn('无法获取局域网 IP,使用默认 IP: 127.0.0.1') - return '127.0.0.1' - } - - // 按优先级排序,选择优先级最高的 - candidates.sort((a, b) => a.priority - b.priority) - const best = candidates[0] - - logger.info(`获取局域网 IP: ${best.address} (interface: ${best.interface})`) - return best.address - } - - public start = async (): Promise<{ success: boolean; port?: number; error?: string }> => { - if (this.isStarted && this.io) { - return { success: true, port: this.port } - } - - try { - this.io = new Server(this.port, { - cors: { - origin: '*', - methods: ['GET', 'POST'] - }, - transports: ['websocket', 'polling'], - allowEIO3: true, - pingTimeout: 60000, - pingInterval: 25000 - }) - - this.io.on('connection', (socket: Socket) => { - this.connectedClients.add(socket.id) - - const mainWindow = windowService.getMainWindow() - if (!mainWindow) { - logger.error('Main window is null, cannot send connection event') - } else { - mainWindow.webContents.send('websocket-client-connected', { - connected: true, - clientId: socket.id - }) - logger.info(`Connection event sent to renderer, total clients: ${this.connectedClients.size}`) - } - - socket.on('message', (data) => { - logger.info('Received message from mobile:', data) - mainWindow?.webContents.send('websocket-message-received', data) - socket.emit('message_received', { success: true }) - }) - - socket.on('disconnect', () => { - logger.info(`Client disconnected: ${socket.id}`) - this.connectedClients.delete(socket.id) - - if (this.connectedClients.size === 0) { - mainWindow?.webContents.send('websocket-client-connected', { - connected: false, - clientId: socket.id - }) - } - }) - }) - - // Engine 层面的事件监听 - this.io.engine.on('connection_error', (err) => { - logger.error('Engine connection error:', err) - }) - - this.io.engine.on('connection', (rawSocket) => { - const remoteAddr = rawSocket.request.connection.remoteAddress - logger.info(`[Engine] Raw connection from: ${remoteAddr}`) - logger.info(`[Engine] Transport: ${rawSocket.transport.name}`) - - rawSocket.on('packet', (packet: { type: string; data?: any }) => { - logger.info( - `[Engine] ← Packet from ${remoteAddr}: type="${packet.type}"`, - packet.data ? { data: packet.data } : {} - ) - }) - - rawSocket.on('packetCreate', (packet: { type: string; data?: any }) => { - logger.info(`[Engine] → Packet to ${remoteAddr}: type="${packet.type}"`) - }) - - rawSocket.on('close', (reason: string) => { - logger.warn(`[Engine] Connection closed from ${remoteAddr}, reason: ${reason}`) - }) - - rawSocket.on('error', (error: Error) => { - logger.error(`[Engine] Connection error from ${remoteAddr}:`, error) - }) - }) - - // Socket.IO 握手失败监听 - this.io.on('connection_error', (err) => { - logger.error('[Socket.IO] Connection error during handshake:', err) - }) - - this.isStarted = true - logger.info(`WebSocket server started on port ${this.port}`) - - return { success: true, port: this.port } - } catch (error) { - logger.error('Failed to start WebSocket server:', error as Error) - return { - success: false, - error: error instanceof Error ? error.message : 'Unknown error' - } - } - } - - public stop = async (): Promise<{ success: boolean }> => { - if (!this.isStarted || !this.io) { - return { success: true } - } - - try { - await new Promise((resolve) => { - this.io!.close(() => { - resolve() - }) - }) - - this.io = null - this.isStarted = false - this.connectedClients.clear() - logger.info('WebSocket server stopped') - - return { success: true } - } catch (error) { - logger.error('Failed to stop WebSocket server:', error as Error) - return { success: false } - } - } - - public getStatus = async (): Promise => { - return { - isRunning: this.isStarted, - port: this.isStarted ? this.port : undefined, - ip: this.isStarted ? this.getLocalIpAddress() : undefined, - clientConnected: this.connectedClients.size > 0 - } - } - - public getAllCandidates = async (): Promise => { - const interfaces = networkInterfaces() - - // 按优先级排序的网络接口名称模式 - const interfacePriority = [ - // macOS: 以太网/Wi-Fi 优先 - /^en[0-9]+$/, // en0, en1 (以太网/Wi-Fi) - /^(en|eth)[0-9]+$/, // 以太网接口 - /^wlan[0-9]+$/, // 无线接口 - // Windows: 以太网/Wi-Fi 优先 - /^(Ethernet|Wi-Fi|Local Area Connection)/, - /^(Wi-Fi|无线网络连接)/, - // Linux: 以太网/Wi-Fi 优先 - /^(eth|enp|wlp|wlan)[0-9]+/, - // 虚拟化接口(低优先级) - /^bridge[0-9]+$/, // Docker bridge - /^veth[0-9]+$/, // Docker veth - /^docker[0-9]+/, // Docker interfaces - /^br-[0-9a-f]+/, // Docker bridge - /^vmnet[0-9]+$/, // VMware - /^vboxnet[0-9]+$/, // VirtualBox - // VPN 隧道接口(低优先级) - /^utun[0-9]+$/, // macOS VPN - /^tun[0-9]+$/, // Linux/Unix VPN - /^tap[0-9]+$/, // TAP interfaces - /^tailscale[0-9]*$/, // Tailscale VPN - /^wg[0-9]+$/ // WireGuard VPN - ] - - const candidates: Array<{ host: string; interface: string; priority: number }> = [] - - for (const [name, ifaces] of Object.entries(interfaces)) { - for (const iface of ifaces || []) { - if (iface.family === 'IPv4' && !iface.internal) { - // 计算接口优先级 - let priority = 999 // 默认最低优先级 - for (let i = 0; i < interfacePriority.length; i++) { - if (interfacePriority[i].test(name)) { - priority = i - break - } - } - - candidates.push({ - host: iface.address, - interface: name, - priority - }) - - logger.debug(`Found interface: ${name} -> ${iface.address} (priority: ${priority})`) - } - } - } - - // 按优先级排序返回 - candidates.sort((a, b) => a.priority - b.priority) - logger.info( - `Found ${candidates.length} IP candidates: ${candidates.map((c) => `${c.host}(${c.interface})`).join(', ')}` - ) - return candidates - } - - public sendFile = async ( - _: Electron.IpcMainInvokeEvent, - filePath: string - ): Promise<{ success: boolean; error?: string }> => { - if (!this.isStarted || !this.io) { - const errorMsg = 'WebSocket server is not running.' - logger.error(errorMsg) - return { success: false, error: errorMsg } - } - - if (this.connectedClients.size === 0) { - const errorMsg = 'No client connected.' - logger.error(errorMsg) - return { success: false, error: errorMsg } - } - - const mainWindow = windowService.getMainWindow() - - return new Promise((resolve, reject) => { - const stats = fs.statSync(filePath) - const totalSize = stats.size - const filename = path.basename(filePath) - const stream = fs.createReadStream(filePath) - let bytesSent = 0 - const startTime = Date.now() - - logger.info(`Starting file transfer: ${filename} (${this.formatFileSize(totalSize)})`) - - // 向客户端发送文件开始的信号,包含文件名和总大小 - this.io!.emit('zip-file-start', { filename, totalSize }) - - stream.on('data', (chunk) => { - bytesSent += chunk.length - const progress = (bytesSent / totalSize) * 100 - - // 向客户端发送文件块 - this.io!.emit('zip-file-chunk', chunk) - - // 向渲染进程发送进度更新 - mainWindow?.webContents.send('file-send-progress', { progress }) - - // 每10%记录一次进度 - if (Math.floor(progress) % 10 === 0) { - const elapsed = (Date.now() - startTime) / 1000 - const speed = elapsed > 0 ? bytesSent / elapsed : 0 - logger.info(`Transfer progress: ${Math.floor(progress)}% (${this.formatFileSize(speed)}/s)`) - } - }) - - stream.on('end', () => { - const totalTime = (Date.now() - startTime) / 1000 - const avgSpeed = totalTime > 0 ? totalSize / totalTime : 0 - logger.info( - `File transfer completed: ${filename} in ${totalTime.toFixed(1)}s (${this.formatFileSize(avgSpeed)}/s)` - ) - - // 确保发送100%的进度 - mainWindow?.webContents.send('file-send-progress', { progress: 100 }) - // 向客户端发送文件结束的信号 - this.io!.emit('zip-file-end') - resolve({ success: true }) - }) - - stream.on('error', (error) => { - logger.error(`File transfer failed: ${filename}`, error) - reject({ - success: false, - error: error instanceof Error ? error.message : 'Unknown error' - }) - }) - }) - } - - private formatFileSize(bytes: number): string { - if (bytes === 0) return '0 B' - const k = 1024 - const sizes = ['B', 'KB', 'MB', 'GB'] - const i = Math.floor(Math.log(bytes) / Math.log(k)) - return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i] - } -} - -export default new WebSocketService() diff --git a/src/main/services/WebviewService.ts b/src/main/services/WebviewService.ts index fb2049de74..7af008bd7a 100644 --- a/src/main/services/WebviewService.ts +++ b/src/main/services/WebviewService.ts @@ -1,5 +1,6 @@ import { IpcChannel } from '@shared/IpcChannel' -import { app, session, shell, webContents } from 'electron' +import { app, dialog, session, shell, webContents } from 'electron' +import { promises as fs } from 'fs' /** * init the useragent of the webview session @@ -53,11 +54,17 @@ const attachKeyboardHandler = (contents: Electron.WebContents) => { return } - const isFindShortcut = (input.control || input.meta) && key === 'f' - const isEscape = key === 'escape' - const isEnter = key === 'enter' + // Helper to check if this is a shortcut we handle + const isHandledShortcut = (k: string) => { + const isFindShortcut = (input.control || input.meta) && k === 'f' + const isPrintShortcut = (input.control || input.meta) && k === 'p' + const isSaveShortcut = (input.control || input.meta) && k === 's' + const isEscape = k === 'escape' + const isEnter = k === 'enter' + return isFindShortcut || isPrintShortcut || isSaveShortcut || isEscape || isEnter + } - if (!isFindShortcut && !isEscape && !isEnter) { + if (!isHandledShortcut(key)) { return } @@ -66,11 +73,20 @@ const attachKeyboardHandler = (contents: Electron.WebContents) => { return } + const isFindShortcut = (input.control || input.meta) && key === 'f' + const isPrintShortcut = (input.control || input.meta) && key === 'p' + const isSaveShortcut = (input.control || input.meta) && key === 's' + // Always prevent Cmd/Ctrl+F to override the guest page's native find dialog if (isFindShortcut) { event.preventDefault() } + // Prevent default print/save dialogs and handle them with custom logic + if (isPrintShortcut || isSaveShortcut) { + event.preventDefault() + } + // Send the hotkey event to the renderer // The renderer will decide whether to preventDefault for Escape and Enter // based on whether the search bar is visible @@ -100,3 +116,130 @@ export function initWebviewHotkeys() { attachKeyboardHandler(contents) }) } + +/** + * Print webview content to PDF + * @param webviewId The webview webContents id + * @returns Path to saved PDF file or null if user cancelled + */ +export async function printWebviewToPDF(webviewId: number): Promise { + const webview = webContents.fromId(webviewId) + if (!webview) { + throw new Error('Webview not found') + } + + try { + // Get the page title for default filename + const pageTitle = await webview.executeJavaScript('document.title || "webpage"').catch(() => 'webpage') + // Sanitize filename by removing invalid characters + const sanitizedTitle = pageTitle.replace(/[<>:"/\\|?*]/g, '-').substring(0, 100) + const defaultFilename = sanitizedTitle ? `${sanitizedTitle}.pdf` : `webpage-${Date.now()}.pdf` + + // Show save dialog + const { canceled, filePath } = await dialog.showSaveDialog({ + title: 'Save as PDF', + defaultPath: defaultFilename, + filters: [{ name: 'PDF Files', extensions: ['pdf'] }] + }) + + if (canceled || !filePath) { + return null + } + + // Generate PDF with settings to capture full page + const pdfData = await webview.printToPDF({ + margins: { + marginType: 'default' + }, + printBackground: true, + landscape: false, + pageSize: 'A4', + preferCSSPageSize: true + }) + + // Save PDF to file + await fs.writeFile(filePath, pdfData) + + return filePath + } catch (error) { + throw new Error(`Failed to print to PDF: ${(error as Error).message}`) + } +} + +/** + * Save webview content as HTML + * @param webviewId The webview webContents id + * @returns Path to saved HTML file or null if user cancelled + */ +export async function saveWebviewAsHTML(webviewId: number): Promise { + const webview = webContents.fromId(webviewId) + if (!webview) { + throw new Error('Webview not found') + } + + try { + // Get the page title for default filename + const pageTitle = await webview.executeJavaScript('document.title || "webpage"').catch(() => 'webpage') + // Sanitize filename by removing invalid characters + const sanitizedTitle = pageTitle.replace(/[<>:"/\\|?*]/g, '-').substring(0, 100) + const defaultFilename = sanitizedTitle ? `${sanitizedTitle}.html` : `webpage-${Date.now()}.html` + + // Show save dialog + const { canceled, filePath } = await dialog.showSaveDialog({ + title: 'Save as HTML', + defaultPath: defaultFilename, + filters: [ + { name: 'HTML Files', extensions: ['html', 'htm'] }, + { name: 'All Files', extensions: ['*'] } + ] + }) + + if (canceled || !filePath) { + return null + } + + // Get the HTML content with safe error handling + const html = await webview.executeJavaScript(` + (() => { + try { + // Build complete DOCTYPE string if present + let doctype = ''; + if (document.doctype) { + const dt = document.doctype; + doctype = ''; + } + return doctype + (document.documentElement?.outerHTML || ''); + } catch (error) { + // Fallback: just return the HTML without DOCTYPE if there's an error + return document.documentElement?.outerHTML || ''; + } + })() + `) + + // Save HTML to file + await fs.writeFile(filePath, html, 'utf-8') + + return filePath + } catch (error) { + throw new Error(`Failed to save as HTML: ${(error as Error).message}`) + } +} diff --git a/src/main/services/WindowService.ts b/src/main/services/WindowService.ts index 3f96497e63..cda99cc37a 100644 --- a/src/main/services/WindowService.ts +++ b/src/main/services/WindowService.ts @@ -255,6 +255,12 @@ export class WindowService { } private setupWebContentsHandlers(mainWindow: BrowserWindow) { + // Fix for Electron bug where zoom resets during in-page navigation (route changes) + // This complements the resize-based workaround by catching navigation events + mainWindow.webContents.on('did-navigate-in-page', () => { + mainWindow.webContents.setZoomFactor(configManager.getZoomFactor()) + }) + mainWindow.webContents.on('will-navigate', (event, url) => { if (url.includes('localhost:517')) { return @@ -516,7 +522,9 @@ export class WindowService { miniWindowState.manage(this.miniWindow) //miniWindow should show in current desktop - this.miniWindow?.setVisibleOnAllWorkspaces(true, { visibleOnFullScreen: true }) + this.miniWindow?.setVisibleOnAllWorkspaces(true, { + visibleOnFullScreen: true + }) //make miniWindow always on top of fullscreen apps with level set //[mac] level higher than 'floating' will cover the pinyin input method this.miniWindow.setAlwaysOnTop(true, 'floating') @@ -635,6 +643,11 @@ export class WindowService { return } else if (isMac) { this.miniWindow.hide() + const majorVersion = parseInt(process.getSystemVersion().split('.')[0], 10) + if (majorVersion >= 26) { + // on macOS 26+, the popup of the mimiWindow would not change the focus to previous application. + return + } if (!this.wasMainWindowFocused) { app.hide() } diff --git a/src/main/services/__tests__/BackupManager.deleteTempBackup.test.ts b/src/main/services/__tests__/BackupManager.deleteTempBackup.test.ts new file mode 100644 index 0000000000..062d140eb5 --- /dev/null +++ b/src/main/services/__tests__/BackupManager.deleteTempBackup.test.ts @@ -0,0 +1,274 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Use vi.hoisted to define mocks that are available during hoisting +const { mockLogger } = vi.hoisted(() => ({ + mockLogger: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn() + } +})) + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => mockLogger + } +})) + +vi.mock('electron', () => ({ + app: { + getPath: vi.fn((key: string) => { + if (key === 'temp') return '/tmp' + if (key === 'userData') return '/mock/userData' + return '/mock/unknown' + }) + } +})) + +vi.mock('fs-extra', () => ({ + default: { + pathExists: vi.fn(), + remove: vi.fn(), + ensureDir: vi.fn(), + copy: vi.fn(), + readdir: vi.fn(), + stat: vi.fn(), + readFile: vi.fn(), + writeFile: vi.fn(), + createWriteStream: vi.fn(), + createReadStream: vi.fn() + }, + pathExists: vi.fn(), + remove: vi.fn(), + ensureDir: vi.fn(), + copy: vi.fn(), + readdir: vi.fn(), + stat: vi.fn(), + readFile: vi.fn(), + writeFile: vi.fn(), + createWriteStream: vi.fn(), + createReadStream: vi.fn() +})) + +vi.mock('../WindowService', () => ({ + windowService: { + getMainWindow: vi.fn() + } +})) + +vi.mock('../WebDav', () => ({ + default: vi.fn() +})) + +vi.mock('../S3Storage', () => ({ + default: vi.fn() +})) + +vi.mock('../../utils', () => ({ + getDataPath: vi.fn(() => '/mock/data') +})) + +vi.mock('archiver', () => ({ + default: vi.fn() +})) + +vi.mock('node-stream-zip', () => ({ + default: vi.fn() +})) + +// Import after mocks +import * as fs from 'fs-extra' + +import BackupManager from '../BackupManager' + +describe('BackupManager.deleteTempBackup - Security Tests', () => { + let backupManager: BackupManager + + beforeEach(() => { + vi.clearAllMocks() + backupManager = new BackupManager() + }) + + describe('Normal Operations', () => { + it('should delete valid file in allowed directory', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const validPath = '/tmp/cherry-studio/lan-transfer/backup.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath) + + expect(result).toBe(true) + expect(fs.remove).toHaveBeenCalledWith(validPath) + expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('Deleted temp backup')) + }) + + it('should delete file in nested subdirectory', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const nestedPath = '/tmp/cherry-studio/lan-transfer/sub/dir/file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, nestedPath) + + expect(result).toBe(true) + expect(fs.remove).toHaveBeenCalledWith(nestedPath) + }) + + it('should return false when file does not exist', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(false as never) + + const missingPath = '/tmp/cherry-studio/lan-transfer/missing.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, missingPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + }) + + describe('Path Traversal Attacks', () => { + it('should block basic directory traversal attack (../../../../etc/passwd)', async () => { + const attackPath = '/tmp/cherry-studio/lan-transfer/../../../../etc/passwd' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.pathExists).not.toHaveBeenCalled() + expect(fs.remove).not.toHaveBeenCalled() + expect(mockLogger.warn).toHaveBeenCalledWith(expect.stringContaining('outside temp directory')) + }) + + it('should block absolute path escape (/etc/passwd)', async () => { + const attackPath = '/etc/passwd' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + expect(mockLogger.warn).toHaveBeenCalled() + }) + + it('should block traversal with multiple slashes', async () => { + const attackPath = '/tmp/cherry-studio/lan-transfer/../../../etc/passwd' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + + it('should block relative path traversal from current directory', async () => { + const attackPath = '../../../etc/passwd' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + + it('should block traversal to parent directory', async () => { + const attackPath = '/tmp/cherry-studio/lan-transfer/../backup/secret.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + }) + + describe('Prefix Attacks', () => { + it('should block similar prefix attack (lan-transfer-evil)', async () => { + const attackPath = '/tmp/cherry-studio/lan-transfer-evil/file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + expect(mockLogger.warn).toHaveBeenCalled() + }) + + it('should block path without separator (lan-transferx)', async () => { + const attackPath = '/tmp/cherry-studio/lan-transferx' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + + it('should block different temp directory prefix', async () => { + const attackPath = '/tmp-evil/cherry-studio/lan-transfer/file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + }) + + describe('Error Handling', () => { + it('should return false and log error on permission denied', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockRejectedValue(new Error('EACCES: permission denied') as never) + + const validPath = '/tmp/cherry-studio/lan-transfer/file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath) + + expect(result).toBe(false) + expect(mockLogger.error).toHaveBeenCalledWith(expect.stringContaining('Failed to delete'), expect.any(Error)) + }) + + it('should return false on fs.pathExists error', async () => { + vi.mocked(fs.pathExists).mockRejectedValue(new Error('ENOENT') as never) + + const validPath = '/tmp/cherry-studio/lan-transfer/file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath) + + expect(result).toBe(false) + expect(mockLogger.error).toHaveBeenCalled() + }) + + it('should handle empty path string', async () => { + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, '') + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + }) + + describe('Edge Cases', () => { + it('should allow deletion of the temp directory itself', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const tempDir = '/tmp/cherry-studio/lan-transfer' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, tempDir) + + expect(result).toBe(true) + expect(fs.remove).toHaveBeenCalledWith(tempDir) + }) + + it('should handle path with trailing slash', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const pathWithSlash = '/tmp/cherry-studio/lan-transfer/sub/' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, pathWithSlash) + + // path.normalize removes trailing slash + expect(result).toBe(true) + }) + + it('should handle file with special characters in name', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const specialPath = '/tmp/cherry-studio/lan-transfer/file with spaces & (special).zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, specialPath) + + expect(result).toBe(true) + expect(fs.remove).toHaveBeenCalled() + }) + + it('should handle path with double slashes', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const doubleSlashPath = '/tmp/cherry-studio//lan-transfer//file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, doubleSlashPath) + + // path.normalize handles double slashes + expect(result).toBe(true) + }) + }) +}) diff --git a/src/main/services/__tests__/LocalTransferService.test.ts b/src/main/services/__tests__/LocalTransferService.test.ts new file mode 100644 index 0000000000..d00c7c269b --- /dev/null +++ b/src/main/services/__tests__/LocalTransferService.test.ts @@ -0,0 +1,481 @@ +import { EventEmitter } from 'events' +import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest' + +// Create mock objects before vi.mock calls +const mockLogger = { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn() +} + +let mockMainWindow: { + isDestroyed: Mock + webContents: { send: Mock } +} | null = null + +let mockBrowser: EventEmitter & { + start: Mock + stop: Mock + removeAllListeners: Mock +} + +let mockBonjour: { + find: Mock + destroy: Mock +} + +// Mock dependencies before importing the service +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => mockLogger + } +})) + +vi.mock('../WindowService', () => ({ + windowService: { + getMainWindow: vi.fn(() => mockMainWindow) + } +})) + +vi.mock('bonjour-service', () => ({ + default: vi.fn(() => mockBonjour) +})) + +describe('LocalTransferService', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.resetModules() + + // Reset mock objects + mockMainWindow = { + isDestroyed: vi.fn(() => false), + webContents: { send: vi.fn() } + } + + mockBrowser = Object.assign(new EventEmitter(), { + start: vi.fn(), + stop: vi.fn(), + removeAllListeners: vi.fn() + }) + + mockBonjour = { + find: vi.fn(() => mockBrowser), + destroy: vi.fn() + } + }) + + afterEach(() => { + vi.resetAllMocks() + }) + + describe('startDiscovery', () => { + it('should set isScanning to true and start browser', async () => { + const { localTransferService } = await import('../LocalTransferService') + + const state = localTransferService.startDiscovery() + + expect(state.isScanning).toBe(true) + expect(state.lastScanStartedAt).toBeDefined() + expect(mockBonjour.find).toHaveBeenCalledWith({ type: 'cherrystudio', protocol: 'tcp' }) + expect(mockBrowser.start).toHaveBeenCalled() + }) + + it('should clear services when resetList is true', async () => { + const { localTransferService } = await import('../LocalTransferService') + + // First, start discovery and add a service + localTransferService.startDiscovery() + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + fqdn: 'test.local' + }) + + expect(localTransferService.getState().services).toHaveLength(1) + + // Now restart with resetList + const state = localTransferService.startDiscovery({ resetList: true }) + + expect(state.services).toHaveLength(0) + }) + + it('should broadcast state after starting discovery', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + expect(mockMainWindow?.webContents.send).toHaveBeenCalled() + }) + + it('should handle browser.start() error', async () => { + mockBrowser.start.mockImplementation(() => { + throw new Error('Failed to start mDNS') + }) + + const { localTransferService } = await import('../LocalTransferService') + + const state = localTransferService.startDiscovery() + + expect(state.lastError).toBe('Failed to start mDNS') + expect(mockLogger.error).toHaveBeenCalled() + }) + }) + + describe('stopDiscovery', () => { + it('should set isScanning to false and stop browser', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + const state = localTransferService.stopDiscovery() + + expect(state.isScanning).toBe(false) + expect(mockBrowser.stop).toHaveBeenCalled() + }) + + it('should handle browser.stop() error gracefully', async () => { + mockBrowser.stop.mockImplementation(() => { + throw new Error('Stop failed') + }) + + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + // Should not throw + expect(() => localTransferService.stopDiscovery()).not.toThrow() + expect(mockLogger.warn).toHaveBeenCalled() + }) + + it('should broadcast state after stopping', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + vi.clearAllMocks() + + localTransferService.stopDiscovery() + + expect(mockMainWindow?.webContents.send).toHaveBeenCalled() + }) + }) + + describe('browser events', () => { + it('should add service on "up" event', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + fqdn: 'test.local', + type: 'cherrystudio', + protocol: 'tcp' + }) + + const state = localTransferService.getState() + expect(state.services).toHaveLength(1) + expect(state.services[0].name).toBe('Test Service') + expect(state.services[0].port).toBe(12345) + expect(state.services[0].addresses).toContain('192.168.1.100') + }) + + it('should remove service on "down" event', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + // Add service + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + fqdn: 'test.local' + }) + + expect(localTransferService.getState().services).toHaveLength(1) + + // Remove service + mockBrowser.emit('down', { + name: 'Test Service', + host: 'localhost', + port: 12345, + fqdn: 'test.local' + }) + + expect(localTransferService.getState().services).toHaveLength(0) + expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('removed')) + }) + + it('should set lastError on "error" event', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('error', new Error('Discovery failed')) + + const state = localTransferService.getState() + expect(state.lastError).toBe('Discovery failed') + expect(mockLogger.error).toHaveBeenCalled() + }) + + it('should handle non-Error objects in error event', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('error', 'String error message') + + const state = localTransferService.getState() + expect(state.lastError).toBe('String error message') + }) + }) + + describe('getState', () => { + it('should return sorted services by name', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Zebra Service', + host: 'host1', + port: 1001, + addresses: ['192.168.1.1'] + }) + + mockBrowser.emit('up', { + name: 'Alpha Service', + host: 'host2', + port: 1002, + addresses: ['192.168.1.2'] + }) + + const state = localTransferService.getState() + expect(state.services[0].name).toBe('Alpha Service') + expect(state.services[1].name).toBe('Zebra Service') + }) + + it('should include all state properties', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + const state = localTransferService.getState() + + expect(state).toHaveProperty('services') + expect(state).toHaveProperty('isScanning') + expect(state).toHaveProperty('lastScanStartedAt') + expect(state).toHaveProperty('lastUpdatedAt') + }) + }) + + describe('getPeerById', () => { + it('should return peer when exists', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + fqdn: 'test.local' + }) + + const services = localTransferService.getState().services + const peer = localTransferService.getPeerById(services[0].id) + + expect(peer).toBeDefined() + expect(peer?.name).toBe('Test Service') + }) + + it('should return undefined when peer does not exist', async () => { + const { localTransferService } = await import('../LocalTransferService') + + const peer = localTransferService.getPeerById('non-existent-id') + + expect(peer).toBeUndefined() + }) + }) + + describe('normalizeService', () => { + it('should deduplicate addresses', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100', '192.168.1.100', '10.0.0.1'], + referer: { address: '192.168.1.100' } + }) + + const services = localTransferService.getState().services + expect(services[0].addresses).toHaveLength(2) + expect(services[0].addresses).toContain('192.168.1.100') + expect(services[0].addresses).toContain('10.0.0.1') + }) + + it('should filter empty addresses', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100', '', null as any] + }) + + const services = localTransferService.getState().services + expect(services[0].addresses).toEqual(['192.168.1.100']) + }) + + it('should convert txt null/undefined values to empty strings', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + txt: { + version: '1.0', + nullValue: null, + undefinedValue: undefined, + numberValue: 42 + } + }) + + const services = localTransferService.getState().services + expect(services[0].txt).toEqual({ + version: '1.0', + nullValue: '', + undefinedValue: '', + numberValue: '42' + }) + }) + + it('should not include txt when empty', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + txt: {} + }) + + const services = localTransferService.getState().services + expect(services[0].txt).toBeUndefined() + }) + }) + + describe('dispose', () => { + it('should clean up all resources', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'] + }) + + localTransferService.dispose() + + expect(localTransferService.getState().services).toHaveLength(0) + expect(localTransferService.getState().isScanning).toBe(false) + expect(mockBrowser.removeAllListeners).toHaveBeenCalled() + expect(mockBonjour.destroy).toHaveBeenCalled() + }) + + it('should handle bonjour.destroy() error gracefully', async () => { + mockBonjour.destroy.mockImplementation(() => { + throw new Error('Destroy failed') + }) + + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + // Should not throw + expect(() => localTransferService.dispose()).not.toThrow() + expect(mockLogger.warn).toHaveBeenCalled() + }) + + it('should be safe to call multiple times', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + expect(() => { + localTransferService.dispose() + localTransferService.dispose() + }).not.toThrow() + }) + }) + + describe('broadcastState', () => { + it('should not throw when main window is null', async () => { + mockMainWindow = null + + const { localTransferService } = await import('../LocalTransferService') + + // Should not throw + expect(() => localTransferService.startDiscovery()).not.toThrow() + }) + + it('should not throw when main window is destroyed', async () => { + mockMainWindow = { + isDestroyed: vi.fn(() => true), + webContents: { send: vi.fn() } + } + + const { localTransferService } = await import('../LocalTransferService') + + // Should not throw + expect(() => localTransferService.startDiscovery()).not.toThrow() + expect(mockMainWindow.webContents.send).not.toHaveBeenCalled() + }) + }) + + describe('restartBrowser', () => { + it('should destroy old bonjour instance to prevent socket leaks', async () => { + const { localTransferService } = await import('../LocalTransferService') + + // First start + localTransferService.startDiscovery() + expect(mockBonjour.destroy).not.toHaveBeenCalled() + + // Restart - should destroy old instance + localTransferService.startDiscovery() + expect(mockBonjour.destroy).toHaveBeenCalled() + }) + + it('should remove all listeners from old browser', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + localTransferService.startDiscovery() + + expect(mockBrowser.removeAllListeners).toHaveBeenCalled() + }) + }) +}) diff --git a/src/main/services/__tests__/ServerLogBuffer.test.ts b/src/main/services/__tests__/ServerLogBuffer.test.ts new file mode 100644 index 0000000000..0b7abe91e8 --- /dev/null +++ b/src/main/services/__tests__/ServerLogBuffer.test.ts @@ -0,0 +1,29 @@ +import { describe, expect, it } from 'vitest' + +import { ServerLogBuffer } from '../mcp/ServerLogBuffer' + +describe('ServerLogBuffer', () => { + it('keeps a bounded number of entries per server', () => { + const buffer = new ServerLogBuffer(3) + const key = 'srv' + + buffer.append(key, { timestamp: 1, level: 'info', message: 'a' }) + buffer.append(key, { timestamp: 2, level: 'info', message: 'b' }) + buffer.append(key, { timestamp: 3, level: 'info', message: 'c' }) + buffer.append(key, { timestamp: 4, level: 'info', message: 'd' }) + + const logs = buffer.get(key) + expect(logs).toHaveLength(3) + expect(logs[0].message).toBe('b') + expect(logs[2].message).toBe('d') + }) + + it('isolates entries by server key', () => { + const buffer = new ServerLogBuffer(5) + buffer.append('one', { timestamp: 1, level: 'info', message: 'a' }) + buffer.append('two', { timestamp: 2, level: 'info', message: 'b' }) + + expect(buffer.get('one')).toHaveLength(1) + expect(buffer.get('two')).toHaveLength(1) + }) +}) diff --git a/src/main/services/agents/BaseService.ts b/src/main/services/agents/BaseService.ts index 78bf72a952..e30814bb6f 100644 --- a/src/main/services/agents/BaseService.ts +++ b/src/main/services/agents/BaseService.ts @@ -2,6 +2,7 @@ import { loggerService } from '@logger' import { mcpApiService } from '@main/apiServer/services/mcp' import type { ModelValidationError } from '@main/apiServer/utils' import { validateModelId } from '@main/apiServer/utils' +import { buildFunctionCallToolName } from '@main/utils/mcp' import type { AgentType, MCPTool, SlashCommand, Tool } from '@types' import { objectKeys } from '@types' import fs from 'fs' @@ -14,6 +15,17 @@ import { builtinSlashCommands } from './services/claudecode/commands' import { builtinTools } from './services/claudecode/tools' const logger = loggerService.withContext('BaseService') +const MCP_TOOL_ID_PREFIX = 'mcp__' +const MCP_TOOL_LEGACY_PREFIX = 'mcp_' + +const buildMcpToolId = (serverId: string, toolName: string) => `${MCP_TOOL_ID_PREFIX}${serverId}__${toolName}` +const toLegacyMcpToolId = (toolId: string) => { + if (!toolId.startsWith(MCP_TOOL_ID_PREFIX)) { + return null + } + const rawId = toolId.slice(MCP_TOOL_ID_PREFIX.length) + return `${MCP_TOOL_LEGACY_PREFIX}${rawId.replace(/__/g, '_')}` +} /** * Base service class providing shared utilities for all agent-related services. @@ -35,8 +47,12 @@ export abstract class BaseService { 'slash_commands' ] - public async listMcpTools(agentType: AgentType, ids?: string[]): Promise { + public async listMcpTools( + agentType: AgentType, + ids?: string[] + ): Promise<{ tools: Tool[]; legacyIdMap: Map }> { const tools: Tool[] = [] + const legacyIdMap = new Map() if (agentType === 'claude-code') { tools.push(...builtinTools) } @@ -46,13 +62,21 @@ export abstract class BaseService { const server = await mcpApiService.getServerInfo(id) if (server) { server.tools.forEach((tool: MCPTool) => { + const canonicalId = buildFunctionCallToolName(server.name, tool.name) + const serverIdBasedId = buildMcpToolId(id, tool.name) + const legacyId = toLegacyMcpToolId(serverIdBasedId) + tools.push({ - id: `mcp_${id}_${tool.name}`, + id: canonicalId, name: tool.name, type: 'mcp', description: tool.description || '', requirePermissions: true }) + legacyIdMap.set(serverIdBasedId, canonicalId) + if (legacyId) { + legacyIdMap.set(legacyId, canonicalId) + } }) } } catch (error) { @@ -64,7 +88,53 @@ export abstract class BaseService { } } - return tools + return { tools, legacyIdMap } + } + + /** + * Normalize MCP tool IDs in allowed_tools to the current format. + * + * Legacy formats: + * - "mcp____" (double underscore separators, server ID based) + * - "mcp__" (single underscore separators) + * Current format: "mcp____" (double underscore separators). + * + * This keeps persisted data compatible without requiring a database migration. + */ + protected normalizeAllowedTools( + allowedTools: string[] | undefined, + tools: Tool[], + legacyIdMap?: Map + ): string[] | undefined { + if (!allowedTools || allowedTools.length === 0) { + return allowedTools + } + + const resolvedLegacyIdMap = new Map() + + if (legacyIdMap) { + for (const [legacyId, canonicalId] of legacyIdMap) { + resolvedLegacyIdMap.set(legacyId, canonicalId) + } + } + + for (const tool of tools) { + if (tool.type !== 'mcp') { + continue + } + const legacyId = toLegacyMcpToolId(tool.id) + if (!legacyId) { + continue + } + resolvedLegacyIdMap.set(legacyId, tool.id) + } + + if (resolvedLegacyIdMap.size === 0) { + return allowedTools + } + + const normalized = allowedTools.map((toolId) => resolvedLegacyIdMap.get(toolId) ?? toolId) + return Array.from(new Set(normalized)) } public async listSlashCommands(agentType: AgentType): Promise { @@ -78,7 +148,7 @@ export abstract class BaseService { * Get database instance * Automatically waits for initialization to complete */ - protected async getDatabase() { + public async getDatabase() { const dbManager = await DatabaseManager.getInstance() return dbManager.getDatabase() } diff --git a/src/main/services/agents/database/DatabaseManager.ts b/src/main/services/agents/database/DatabaseManager.ts index f4b13971c7..913f9e4a66 100644 --- a/src/main/services/agents/database/DatabaseManager.ts +++ b/src/main/services/agents/database/DatabaseManager.ts @@ -1,3 +1,19 @@ +/** + * @deprecated Scheduled for removal in v2.0.0 + * -------------------------------------------------------------------------- + * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex) + * -------------------------------------------------------------------------- + * STOP: Feature PRs affecting this file are currently BLOCKED. + * Only critical bug fixes are accepted during this migration phase. + * + * This file is being refactored to v2 standards. + * Any non-critical changes will conflict with the ongoing work. + * + * 🔗 Context & Status: + * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954 + * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162 + * -------------------------------------------------------------------------- + */ import { type Client, createClient } from '@libsql/client' import { loggerService } from '@logger' import type { LibSQLDatabase } from 'drizzle-orm/libsql' diff --git a/src/main/services/agents/drizzle.config.ts b/src/main/services/agents/drizzle.config.ts index e12518c069..7278883c11 100644 --- a/src/main/services/agents/drizzle.config.ts +++ b/src/main/services/agents/drizzle.config.ts @@ -1,3 +1,19 @@ +/** + * @deprecated Scheduled for removal in v2.0.0 + * -------------------------------------------------------------------------- + * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex) + * -------------------------------------------------------------------------- + * STOP: Feature PRs affecting this file are currently BLOCKED. + * Only critical bug fixes are accepted during this migration phase. + * + * This file is being refactored to v2 standards. + * Any non-critical changes will conflict with the ongoing work. + * + * 🔗 Context & Status: + * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954 + * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162 + * -------------------------------------------------------------------------- + */ /** * Drizzle Kit configuration for agents database */ diff --git a/src/main/services/agents/services/AgentService.ts b/src/main/services/agents/services/AgentService.ts index 2faa87bb45..7542c1935b 100644 --- a/src/main/services/agents/services/AgentService.ts +++ b/src/main/services/agents/services/AgentService.ts @@ -89,7 +89,9 @@ export class AgentService extends BaseService { } const agent = this.deserializeJsonFields(result[0]) as GetAgentResponse - agent.tools = await this.listMcpTools(agent.type, agent.mcps) + const { tools, legacyIdMap } = await this.listMcpTools(agent.type, agent.mcps) + agent.tools = tools + agent.allowed_tools = this.normalizeAllowedTools(agent.allowed_tools, agent.tools, legacyIdMap) // Load installed_plugins from cache file instead of database const workdir = agent.accessible_paths?.[0] @@ -134,7 +136,9 @@ export class AgentService extends BaseService { const agents = result.map((row) => this.deserializeJsonFields(row)) as GetAgentResponse[] for (const agent of agents) { - agent.tools = await this.listMcpTools(agent.type, agent.mcps) + const { tools, legacyIdMap } = await this.listMcpTools(agent.type, agent.mcps) + agent.tools = tools + agent.allowed_tools = this.normalizeAllowedTools(agent.allowed_tools, agent.tools, legacyIdMap) } return { agents, total: totalResult[0].count } diff --git a/src/main/services/agents/services/SessionService.ts b/src/main/services/agents/services/SessionService.ts index d933ef8dd9..90b32bb31c 100644 --- a/src/main/services/agents/services/SessionService.ts +++ b/src/main/services/agents/services/SessionService.ts @@ -156,7 +156,9 @@ export class SessionService extends BaseService { } const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse - session.tools = await this.listMcpTools(session.agent_type, session.mcps) + const { tools, legacyIdMap } = await this.listMcpTools(session.agent_type, session.mcps) + session.tools = tools + session.allowed_tools = this.normalizeAllowedTools(session.allowed_tools, session.tools, legacyIdMap) // If slash_commands is not in database yet (e.g., first invoke before init message), // fall back to builtin + local commands. Otherwise, use the merged commands from database. @@ -202,6 +204,12 @@ export class SessionService extends BaseService { const sessions = result.map((row) => this.deserializeJsonFields(row)) as GetAgentSessionResponse[] + for (const session of sessions) { + const { tools, legacyIdMap } = await this.listMcpTools(session.agent_type, session.mcps) + session.tools = tools + session.allowed_tools = this.normalizeAllowedTools(session.allowed_tools, session.tools, legacyIdMap) + } + return { sessions, total } } diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts index e5cefadd68..69266f5a61 100644 --- a/src/main/services/agents/services/claudecode/index.ts +++ b/src/main/services/agents/services/claudecode/index.ts @@ -15,7 +15,10 @@ import { query } from '@anthropic-ai/claude-agent-sdk' import { loggerService } from '@logger' import { config as apiConfigService } from '@main/apiServer/config' import { validateModelId } from '@main/apiServer/utils' +import { isWin } from '@main/constant' +import { autoDiscoverGitBash } from '@main/utils/process' import getLoginShellEnvironment from '@main/utils/shell-env' +import { withoutTrailingApiVersion } from '@shared/utils' import { app } from 'electron' import type { GetAgentSessionResponse } from '../..' @@ -107,6 +110,16 @@ class ClaudeCodeService implements AgentServiceInterface { Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy')) ) as Record + // Auto-discover Git Bash path on Windows (already logs internally) + const customGitBashPath = isWin ? autoDiscoverGitBash() : null + + // Claude Agent SDK builds the final endpoint as `${ANTHROPIC_BASE_URL}/v1/messages`. + // To avoid malformed URLs like `/v1/v1/messages`, we normalize the provider host + // by stripping any trailing API version (e.g. `/v1`). + const anthropicBaseUrl = withoutTrailingApiVersion( + modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost + ) + const env = { ...loginShellEnvWithoutProxies, // TODO: fix the proxy api server @@ -115,7 +128,7 @@ class ClaudeCodeService implements AgentServiceInterface { // ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`, ANTHROPIC_API_KEY: modelInfo.provider.apiKey, ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey, - ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost, + ANTHROPIC_BASE_URL: anthropicBaseUrl, ANTHROPIC_MODEL: modelInfo.modelId, ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId, ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId, @@ -126,7 +139,8 @@ class ClaudeCodeService implements AgentServiceInterface { // Set CLAUDE_CONFIG_DIR to app's userData directory to avoid path encoding issues // on Windows when the username contains non-ASCII characters (e.g., Chinese characters) // This prevents the SDK from using the user's home directory which may have encoding problems - CLAUDE_CONFIG_DIR: path.join(app.getPath('userData'), '.claude') + CLAUDE_CONFIG_DIR: path.join(app.getPath('userData'), '.claude'), + ...(customGitBashPath ? { CLAUDE_CODE_GIT_BASH_PATH: customGitBashPath } : {}) } const errorChunks: string[] = [] diff --git a/src/main/services/agents/tests/BaseService.test.ts b/src/main/services/agents/tests/BaseService.test.ts new file mode 100644 index 0000000000..fe2f4e103a --- /dev/null +++ b/src/main/services/agents/tests/BaseService.test.ts @@ -0,0 +1,91 @@ +import type { Tool } from '@types' +import { describe, expect, it, vi } from 'vitest' + +vi.mock('@main/apiServer/services/mcp', () => ({ + mcpApiService: { + getServerInfo: vi.fn() + } +})) + +vi.mock('@main/apiServer/utils', () => ({ + validateModelId: vi.fn() +})) + +import { BaseService } from '../BaseService' + +class TestBaseService extends BaseService { + public normalize( + allowedTools: string[] | undefined, + tools: Tool[], + legacyIdMap?: Map + ): string[] | undefined { + return this.normalizeAllowedTools(allowedTools, tools, legacyIdMap) + } +} + +const buildMcpTool = (id: string): Tool => ({ + id, + name: id, + type: 'mcp', + description: 'test tool', + requirePermissions: true +}) + +describe('BaseService.normalizeAllowedTools', () => { + const service = new TestBaseService() + + it('returns undefined or empty inputs unchanged', () => { + expect(service.normalize(undefined, [])).toBeUndefined() + expect(service.normalize([], [])).toEqual([]) + }) + + it('normalizes legacy MCP tool IDs and deduplicates entries', () => { + const tools: Tool[] = [ + buildMcpTool('mcp__server_one__tool_one'), + buildMcpTool('mcp__server_two__tool_two'), + { id: 'custom_tool', name: 'custom_tool', type: 'custom' } + ] + + const legacyIdMap = new Map([ + ['mcp__server-1__tool-one', 'mcp__server_one__tool_one'], + ['mcp_server-1_tool-one', 'mcp__server_one__tool_one'], + ['mcp__server-2__tool-two', 'mcp__server_two__tool_two'] + ]) + + const allowedTools = [ + 'mcp__server-1__tool-one', + 'mcp_server-1_tool-one', + 'mcp_server_one_tool_one', + 'mcp__server_one__tool_one', + 'custom_tool', + 'mcp__server_two__tool_two', + 'mcp_server_two_tool_two', + 'mcp__server-2__tool-two' + ] + + expect(service.normalize(allowedTools, tools, legacyIdMap)).toEqual([ + 'mcp__server_one__tool_one', + 'custom_tool', + 'mcp__server_two__tool_two' + ]) + }) + + it('keeps legacy IDs when no matching MCP tool exists', () => { + const tools: Tool[] = [buildMcpTool('mcp__server_one__tool_one')] + const legacyIdMap = new Map([['mcp__server-1__tool-one', 'mcp__server_one__tool_one']]) + + const allowedTools = ['mcp__unknown__tool', 'mcp__server_one__tool_one'] + + expect(service.normalize(allowedTools, tools, legacyIdMap)).toEqual([ + 'mcp__unknown__tool', + 'mcp__server_one__tool_one' + ]) + }) + + it('returns allowed tools unchanged when no MCP tools are available', () => { + const allowedTools = ['custom_tool', 'builtin_tool'] + const tools: Tool[] = [{ id: 'custom_tool', name: 'custom_tool', type: 'custom' }] + + expect(service.normalize(allowedTools, tools)).toEqual(allowedTools) + }) +}) diff --git a/src/main/services/lanTransfer/LanTransferClientService.ts b/src/main/services/lanTransfer/LanTransferClientService.ts new file mode 100644 index 0000000000..a6da2f1a20 --- /dev/null +++ b/src/main/services/lanTransfer/LanTransferClientService.ts @@ -0,0 +1,525 @@ +import * as crypto from 'node:crypto' +import { createConnection, type Socket } from 'node:net' + +import { loggerService } from '@logger' +import type { + LanClientEvent, + LanFileCompleteMessage, + LanHandshakeAckMessage, + LocalTransferConnectPayload, + LocalTransferPeer +} from '@shared/config/types' +import { LAN_TRANSFER_GLOBAL_TIMEOUT_MS } from '@shared/config/types' +import { IpcChannel } from '@shared/IpcChannel' + +import { localTransferService } from '../LocalTransferService' +import { windowService } from '../WindowService' +import { + abortTransfer, + buildHandshakeMessage, + calculateFileChecksum, + cleanupTransfer, + createDataHandler, + createTransferState, + formatFileSize, + HANDSHAKE_PROTOCOL_VERSION, + pickHost, + sendFileEnd, + sendFileStart, + sendTestPing, + streamFileChunks, + validateFile, + waitForFileComplete, + waitForFileStartAck +} from './handlers' +import { ResponseManager } from './responseManager' +import type { ActiveFileTransfer, ConnectionContext, FileTransferContext } from './types' + +const DEFAULT_HANDSHAKE_TIMEOUT_MS = 10_000 + +const logger = loggerService.withContext('LanTransferClientService') + +/** + * LAN Transfer Client Service + * + * Handles outgoing file transfers to LAN peers via TCP. + * Protocol v1 with streaming mode (no per-chunk acknowledgment). + */ +class LanTransferClientService { + private socket: Socket | null = null + private currentPeer?: LocalTransferPeer + private dataHandler?: ReturnType + private responseManager = new ResponseManager() + private isConnecting = false + private activeTransfer?: ActiveFileTransfer + private lastConnectOptions?: LocalTransferConnectPayload + private consecutiveJsonErrors = 0 + private static readonly MAX_CONSECUTIVE_JSON_ERRORS = 3 + private reconnectPromise: Promise | null = null + + constructor() { + this.responseManager.setTimeoutCallback(() => void this.disconnect()) + } + + /** + * Connect to a LAN peer and perform handshake. + */ + public async connectAndHandshake(options: LocalTransferConnectPayload): Promise { + if (this.isConnecting) { + throw new Error('LAN transfer client is busy') + } + + const peer = localTransferService.getPeerById(options.peerId) + if (!peer) { + throw new Error('Selected LAN peer is no longer available') + } + if (!peer.port) { + throw new Error('Selected peer does not expose a TCP port') + } + + const host = pickHost(peer) + if (!host) { + throw new Error('Unable to resolve a reachable host for the peer') + } + + await this.disconnect() + this.isConnecting = true + + return new Promise((resolve, reject) => { + const socket = createConnection({ host, port: peer.port as number }, () => { + logger.info(`Connected to LAN peer ${peer.name} (${host}:${peer.port})`) + socket.setKeepAlive(true, 30_000) + this.socket = socket + this.currentPeer = peer + this.attachSocketListeners(socket) + + this.responseManager.waitForResponse( + 'handshake_ack', + options.timeoutMs ?? DEFAULT_HANDSHAKE_TIMEOUT_MS, + (payload) => { + const ack = payload as LanHandshakeAckMessage + if (!ack.accepted) { + const message = ack.message || 'Handshake rejected by remote device' + logger.warn(`Handshake rejected by ${peer.name}: ${message}`) + this.broadcastClientEvent({ + type: 'error', + message, + timestamp: Date.now() + }) + reject(new Error(message)) + void this.disconnect() + return + } + logger.info(`Handshake accepted by ${peer.name}`) + socket.setTimeout(0) + this.isConnecting = false + this.lastConnectOptions = options + sendTestPing(this.createConnectionContext()) + resolve(ack) + }, + (error) => { + this.isConnecting = false + reject(error) + } + ) + + const handshakeMessage = buildHandshakeMessage() + this.sendControlMessage(handshakeMessage) + }) + + socket.setTimeout(options.timeoutMs ?? DEFAULT_HANDSHAKE_TIMEOUT_MS, () => { + const error = new Error('Handshake timed out') + logger.error('LAN transfer socket timeout', error) + this.broadcastClientEvent({ + type: 'error', + message: error.message, + timestamp: Date.now() + }) + reject(error) + socket.destroy(error) + void this.disconnect() + }) + + socket.once('error', (error) => { + logger.error('LAN transfer socket error', error as Error) + const message = error instanceof Error ? error.message : String(error) + this.broadcastClientEvent({ + type: 'error', + message, + timestamp: Date.now() + }) + this.isConnecting = false + reject(error instanceof Error ? error : new Error(message)) + void this.disconnect() + }) + + socket.once('close', () => { + logger.info('LAN transfer socket closed') + if (this.socket === socket) { + this.socket = null + this.dataHandler?.resetBuffer() + this.responseManager.rejectAll(new Error('LAN transfer socket closed')) + this.currentPeer = undefined + abortTransfer(this.activeTransfer, new Error('LAN transfer socket closed')) + } + this.isConnecting = false + this.broadcastClientEvent({ + type: 'socket_closed', + reason: 'connection_closed', + timestamp: Date.now() + }) + }) + }) + } + + /** + * Disconnect from the current peer. + */ + public async disconnect(): Promise { + const socket = this.socket + if (!socket) { + return + } + + this.socket = null + this.dataHandler?.resetBuffer() + this.currentPeer = undefined + this.responseManager.rejectAll(new Error('LAN transfer socket disconnected')) + abortTransfer(this.activeTransfer, new Error('LAN transfer socket disconnected')) + + const DISCONNECT_TIMEOUT_MS = 3000 + await new Promise((resolve) => { + const timeout = setTimeout(() => { + logger.warn('Disconnect timeout, forcing cleanup') + socket.removeAllListeners() + resolve() + }, DISCONNECT_TIMEOUT_MS) + + socket.once('close', () => { + clearTimeout(timeout) + resolve() + }) + + socket.destroy() + }) + } + + /** + * Dispose the service and clean up all resources. + */ + public dispose(): void { + this.responseManager.rejectAll(new Error('LAN transfer client disposed')) + cleanupTransfer(this.activeTransfer) + this.activeTransfer = undefined + if (this.socket) { + this.socket.destroy() + this.socket = null + } + this.dataHandler?.resetBuffer() + this.isConnecting = false + } + + /** + * Send a ZIP file to the connected peer. + */ + public async sendFile(filePath: string): Promise { + await this.ensureConnection() + + if (this.activeTransfer) { + throw new Error('A file transfer is already in progress') + } + + // Validate file + const { stats, fileName } = await validateFile(filePath) + + // Calculate checksum + logger.info('Calculating file checksum...') + const checksum = await calculateFileChecksum(filePath) + logger.info(`File checksum: ${checksum.substring(0, 16)}...`) + + // Connection can drop while validating/checking file; ensure it is still ready before starting transfer. + await this.ensureConnection() + + // Initialize transfer state + const transferId = crypto.randomUUID() + this.activeTransfer = createTransferState(transferId, fileName, stats.size, checksum) + + logger.info( + `Starting file transfer: ${fileName} (${formatFileSize(stats.size)}, ${this.activeTransfer.totalChunks} chunks)` + ) + + // Global timeout + const globalTimeoutError = new Error('Transfer timed out (global timeout exceeded)') + const globalTimeoutHandle = setTimeout(() => { + logger.warn('Global transfer timeout exceeded, aborting transfer', { transferId, fileName }) + abortTransfer(this.activeTransfer, globalTimeoutError) + }, LAN_TRANSFER_GLOBAL_TIMEOUT_MS) + + try { + const result = await this.performFileTransfer(filePath, transferId, fileName) + return result + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + logger.error(`File transfer failed: ${message}`) + + this.broadcastClientEvent({ + type: 'file_transfer_complete', + transferId, + fileName, + success: false, + error: message, + timestamp: Date.now() + }) + + throw error + } finally { + clearTimeout(globalTimeoutHandle) + cleanupTransfer(this.activeTransfer) + this.activeTransfer = undefined + } + } + + /** + * Cancel the current file transfer. + */ + public cancelTransfer(): void { + if (!this.activeTransfer) { + logger.warn('No active transfer to cancel') + return + } + + const { transferId, fileName } = this.activeTransfer + logger.info(`Cancelling file transfer: ${fileName}`) + + this.activeTransfer.isCancelled = true + + try { + this.sendControlMessage({ + type: 'file_cancel', + transferId, + reason: 'Cancelled by user' + }) + } catch (error) { + // Expected when connection is already broken + logger.warn('Failed to send cancel message', error as Error) + } + + abortTransfer(this.activeTransfer, new Error('Transfer cancelled by user')) + } + + // ============================================================================= + // Private Methods + // ============================================================================= + + private async ensureConnection(): Promise { + // Check socket is valid and writable (not just undestroyed) + if (this.socket && !this.socket.destroyed && this.socket.writable && this.currentPeer) { + return + } + + if (!this.lastConnectOptions) { + throw new Error('No active connection. Please connect to a peer first.') + } + + // Prevent concurrent reconnection attempts + if (this.reconnectPromise) { + logger.debug('Waiting for existing reconnection attempt...') + await this.reconnectPromise + return + } + + logger.info('Connection lost, attempting to reconnect...') + this.reconnectPromise = this.connectAndHandshake(this.lastConnectOptions) + .then(() => { + // Handshake succeeded, connection restored + }) + .finally(() => { + this.reconnectPromise = null + }) + + await this.reconnectPromise + } + + private async performFileTransfer( + filePath: string, + transferId: string, + fileName: string + ): Promise { + const transfer = this.activeTransfer! + const ctx = this.createFileTransferContext() + + // Step 1: Send file_start + sendFileStart(ctx, transfer) + + // Step 2: Wait for file_start_ack + const startAck = await waitForFileStartAck(ctx, transferId, transfer.abortController.signal) + if (!startAck.accepted) { + throw new Error(startAck.message || 'Transfer rejected by receiver') + } + logger.info('Received file_start_ack: accepted') + + // Step 3: Stream file chunks + await streamFileChunks(this.socket!, filePath, transfer, transfer.abortController.signal, (bytesSent, chunkIndex) => + this.onTransferProgress(transfer, bytesSent, chunkIndex) + ) + + // Step 4: Send file_end + sendFileEnd(ctx, transferId) + + // Step 5: Wait for file_complete + const result = await waitForFileComplete(ctx, transferId, transfer.abortController.signal) + logger.info(`File transfer ${result.success ? 'completed' : 'failed'}`) + + // Broadcast completion + this.broadcastClientEvent({ + type: 'file_transfer_complete', + transferId, + fileName, + success: result.success, + filePath: result.filePath, + error: result.error, + timestamp: Date.now() + }) + + return result + } + + private onTransferProgress(transfer: ActiveFileTransfer, bytesSent: number, chunkIndex: number): void { + const progress = (bytesSent / transfer.fileSize) * 100 + const elapsed = (Date.now() - transfer.startedAt) / 1000 + const speed = elapsed > 0 ? bytesSent / elapsed : 0 + + this.broadcastClientEvent({ + type: 'file_transfer_progress', + transferId: transfer.transferId, + fileName: transfer.fileName, + bytesSent, + totalBytes: transfer.fileSize, + chunkIndex, + totalChunks: transfer.totalChunks, + progress: Math.round(progress * 100) / 100, + speed, + timestamp: Date.now() + }) + } + + private attachSocketListeners(socket: Socket): void { + this.dataHandler = createDataHandler((line) => this.handleControlLine(line)) + socket.on('data', (chunk: Buffer) => { + try { + this.dataHandler?.handleData(chunk) + } catch (error) { + logger.error('Data handler error', error as Error) + void this.disconnect() + } + }) + } + + private handleControlLine(line: string): void { + let payload: Record + try { + payload = JSON.parse(line) + this.consecutiveJsonErrors = 0 // Reset on successful parse + } catch { + this.consecutiveJsonErrors++ + logger.warn('Received invalid JSON control message', { line, consecutiveErrors: this.consecutiveJsonErrors }) + + if (this.consecutiveJsonErrors >= LanTransferClientService.MAX_CONSECUTIVE_JSON_ERRORS) { + const message = `Protocol error: ${this.consecutiveJsonErrors} consecutive invalid messages, disconnecting` + logger.error(message) + this.broadcastClientEvent({ + type: 'error', + message, + timestamp: Date.now() + }) + void this.disconnect() + } + return + } + + const type = payload?.type as string | undefined + if (!type) { + logger.warn('Received control message without type', payload) + return + } + + // Try to resolve a pending response + const transferId = payload?.transferId as string | undefined + const chunkIndex = payload?.chunkIndex as number | undefined + if (this.responseManager.tryResolve(type, payload, transferId, chunkIndex)) { + return + } + + logger.info('Received control message', payload) + + if (type === 'pong') { + this.broadcastClientEvent({ + type: 'pong', + payload: payload?.payload as string | undefined, + received: payload?.received as boolean | undefined, + timestamp: Date.now() + }) + return + } + + // Ignore late-arriving file transfer messages + const fileTransferMessageTypes = ['file_start_ack', 'file_complete'] + if (fileTransferMessageTypes.includes(type)) { + logger.debug('Ignoring late file transfer message', { type, payload }) + return + } + + this.broadcastClientEvent({ + type: 'error', + message: `Unexpected control message type: ${type}`, + timestamp: Date.now() + }) + } + + private sendControlMessage(message: Record): void { + if (!this.socket || this.socket.destroyed || !this.socket.writable) { + throw new Error('Socket is not connected') + } + const payload = JSON.stringify(message) + this.socket.write(`${payload}\n`) + } + + private createConnectionContext(): ConnectionContext { + return { + socket: this.socket, + currentPeer: this.currentPeer, + sendControlMessage: (msg) => this.sendControlMessage(msg), + broadcastClientEvent: (event) => this.broadcastClientEvent(event) + } + } + + private createFileTransferContext(): FileTransferContext { + return { + ...this.createConnectionContext(), + activeTransfer: this.activeTransfer, + setActiveTransfer: (transfer) => { + this.activeTransfer = transfer + }, + waitForResponse: (type, timeoutMs, resolve, reject, transferId, chunkIndex, abortSignal) => { + this.responseManager.waitForResponse(type, timeoutMs, resolve, reject, transferId, chunkIndex, abortSignal) + } + } + } + + private broadcastClientEvent(event: LanClientEvent): void { + const mainWindow = windowService.getMainWindow() + if (!mainWindow || mainWindow.isDestroyed()) { + return + } + mainWindow.webContents.send(IpcChannel.LocalTransfer_ClientEvent, { + ...event, + peerId: event.peerId ?? this.currentPeer?.id, + peerName: event.peerName ?? this.currentPeer?.name + }) + } +} + +export const lanTransferClientService = new LanTransferClientService() + +// Re-export for backward compatibility +export { HANDSHAKE_PROTOCOL_VERSION } diff --git a/src/main/services/lanTransfer/__tests__/LanTransferClientService.test.ts b/src/main/services/lanTransfer/__tests__/LanTransferClientService.test.ts new file mode 100644 index 0000000000..16f188aa93 --- /dev/null +++ b/src/main/services/lanTransfer/__tests__/LanTransferClientService.test.ts @@ -0,0 +1,133 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock dependencies before importing the service +vi.mock('node:net', async (importOriginal) => { + const actual = (await importOriginal()) as Record + return { + ...actual, + createConnection: vi.fn() + } +}) + +vi.mock('electron', () => ({ + app: { + getName: vi.fn(() => 'Cherry Studio'), + getVersion: vi.fn(() => '1.0.0') + } +})) + +vi.mock('../../LocalTransferService', () => ({ + localTransferService: { + getPeerById: vi.fn() + } +})) + +vi.mock('../../WindowService', () => ({ + windowService: { + getMainWindow: vi.fn(() => ({ + isDestroyed: () => false, + webContents: { + send: vi.fn() + } + })) + } +})) + +// Import after mocks +import { localTransferService } from '../../LocalTransferService' + +describe('LanTransferClientService', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.resetModules() + }) + + afterEach(() => { + vi.resetAllMocks() + }) + + describe('connectAndHandshake - validation', () => { + it('should throw error when peer is not found', async () => { + vi.mocked(localTransferService.getPeerById).mockReturnValue(undefined) + + const { lanTransferClientService } = await import('../LanTransferClientService') + + await expect( + lanTransferClientService.connectAndHandshake({ + peerId: 'non-existent' + }) + ).rejects.toThrow('Selected LAN peer is no longer available') + }) + + it('should throw error when peer has no port', async () => { + vi.mocked(localTransferService.getPeerById).mockReturnValue({ + id: 'test-peer', + name: 'Test Peer', + addresses: ['192.168.1.100'], + updatedAt: Date.now() + }) + + const { lanTransferClientService } = await import('../LanTransferClientService') + + await expect( + lanTransferClientService.connectAndHandshake({ + peerId: 'test-peer' + }) + ).rejects.toThrow('Selected peer does not expose a TCP port') + }) + + it('should throw error when no reachable host', async () => { + vi.mocked(localTransferService.getPeerById).mockReturnValue({ + id: 'test-peer', + name: 'Test Peer', + port: 12345, + addresses: [], + updatedAt: Date.now() + }) + + const { lanTransferClientService } = await import('../LanTransferClientService') + + await expect( + lanTransferClientService.connectAndHandshake({ + peerId: 'test-peer' + }) + ).rejects.toThrow('Unable to resolve a reachable host for the peer') + }) + }) + + describe('cancelTransfer', () => { + it('should not throw when no active transfer', async () => { + const { lanTransferClientService } = await import('../LanTransferClientService') + + // Should not throw, just log warning + expect(() => lanTransferClientService.cancelTransfer()).not.toThrow() + }) + }) + + describe('dispose', () => { + it('should clean up resources without throwing', async () => { + const { lanTransferClientService } = await import('../LanTransferClientService') + + // Should not throw + expect(() => lanTransferClientService.dispose()).not.toThrow() + }) + }) + + describe('sendFile', () => { + it('should throw error when not connected', async () => { + const { lanTransferClientService } = await import('../LanTransferClientService') + + await expect(lanTransferClientService.sendFile('/path/to/file.zip')).rejects.toThrow( + 'No active connection. Please connect to a peer first.' + ) + }) + }) + + describe('HANDSHAKE_PROTOCOL_VERSION', () => { + it('should export protocol version', async () => { + const { HANDSHAKE_PROTOCOL_VERSION } = await import('../LanTransferClientService') + + expect(HANDSHAKE_PROTOCOL_VERSION).toBe('1') + }) + }) +}) diff --git a/src/main/services/lanTransfer/__tests__/binaryProtocol.test.ts b/src/main/services/lanTransfer/__tests__/binaryProtocol.test.ts new file mode 100644 index 0000000000..c485a33098 --- /dev/null +++ b/src/main/services/lanTransfer/__tests__/binaryProtocol.test.ts @@ -0,0 +1,103 @@ +import { EventEmitter } from 'node:events' +import type { Socket } from 'node:net' + +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { BINARY_TYPE_FILE_CHUNK, sendBinaryChunk } from '../binaryProtocol' + +describe('binaryProtocol', () => { + describe('sendBinaryChunk', () => { + let mockSocket: Socket + let writtenBuffers: Buffer[] + + beforeEach(() => { + writtenBuffers = [] + mockSocket = Object.assign(new EventEmitter(), { + destroyed: false, + writable: true, + write: vi.fn((buffer: Buffer) => { + writtenBuffers.push(Buffer.from(buffer)) + return true + }), + cork: vi.fn(), + uncork: vi.fn() + }) as unknown as Socket + }) + + it('should send binary chunk with correct frame format', () => { + const transferId = 'test-uuid-1234' + const chunkIndex = 5 + const data = Buffer.from('test data chunk') + + const result = sendBinaryChunk(mockSocket, transferId, chunkIndex, data) + + expect(result).toBe(true) + expect(mockSocket.cork).toHaveBeenCalled() + expect(mockSocket.uncork).toHaveBeenCalled() + expect(mockSocket.write).toHaveBeenCalledTimes(2) + + // Verify header structure + const header = writtenBuffers[0] + + // Magic bytes "CS" + expect(header[0]).toBe(0x43) + expect(header[1]).toBe(0x53) + + // Type byte + const typeOffset = 2 + 4 // magic + totalLen + expect(header[typeOffset]).toBe(BINARY_TYPE_FILE_CHUNK) + + // TransferId length + const tidLenOffset = typeOffset + 1 + const tidLen = header.readUInt16BE(tidLenOffset) + expect(tidLen).toBe(Buffer.from(transferId).length) + + // ChunkIndex + const chunkIdxOffset = tidLenOffset + 2 + tidLen + expect(header.readUInt32BE(chunkIdxOffset)).toBe(chunkIndex) + + // Data buffer + expect(writtenBuffers[1].toString()).toBe('test data chunk') + }) + + it('should return false when socket write returns false (backpressure)', () => { + ;(mockSocket.write as ReturnType).mockReturnValueOnce(false) + + const result = sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data')) + + expect(result).toBe(false) + }) + + it('should correctly calculate totalLen in frame header', () => { + const transferId = 'uuid-1234' + const data = Buffer.from('chunk data here') + + sendBinaryChunk(mockSocket, transferId, 0, data) + + const header = writtenBuffers[0] + const totalLen = header.readUInt32BE(2) // After magic bytes + + // totalLen = type(1) + tidLen(2) + tid(n) + idx(4) + data(m) + const expectedTotalLen = 1 + 2 + Buffer.from(transferId).length + 4 + data.length + expect(totalLen).toBe(expectedTotalLen) + }) + + it('should throw error when socket is not writable', () => { + ;(mockSocket as any).writable = false + + expect(() => sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data'))).toThrow('Socket is not writable') + }) + + it('should throw error when socket is destroyed', () => { + ;(mockSocket as any).destroyed = true + + expect(() => sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data'))).toThrow('Socket is not writable') + }) + }) + + describe('BINARY_TYPE_FILE_CHUNK', () => { + it('should be 0x01', () => { + expect(BINARY_TYPE_FILE_CHUNK).toBe(0x01) + }) + }) +}) diff --git a/src/main/services/lanTransfer/__tests__/handlers/connection.test.ts b/src/main/services/lanTransfer/__tests__/handlers/connection.test.ts new file mode 100644 index 0000000000..3983e538d3 --- /dev/null +++ b/src/main/services/lanTransfer/__tests__/handlers/connection.test.ts @@ -0,0 +1,265 @@ +import { EventEmitter } from 'node:events' +import type { Socket } from 'node:net' + +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import { + buildHandshakeMessage, + createDataHandler, + getAbortError, + HANDSHAKE_PROTOCOL_VERSION, + pickHost, + waitForSocketDrain +} from '../../handlers/connection' + +// Mock electron app +vi.mock('electron', () => ({ + app: { + getName: vi.fn(() => 'Cherry Studio'), + getVersion: vi.fn(() => '1.0.0') + } +})) + +describe('connection handlers', () => { + describe('buildHandshakeMessage', () => { + it('should build handshake message with correct structure', () => { + const message = buildHandshakeMessage() + + expect(message.type).toBe('handshake') + expect(message.deviceName).toBe('Cherry Studio') + expect(message.version).toBe(HANDSHAKE_PROTOCOL_VERSION) + expect(message.appVersion).toBe('1.0.0') + expect(typeof message.platform).toBe('string') + }) + + it('should use protocol version 1', () => { + expect(HANDSHAKE_PROTOCOL_VERSION).toBe('1') + }) + }) + + describe('pickHost', () => { + it('should prefer IPv4 addresses', () => { + const peer = { + id: '1', + name: 'Test', + addresses: ['fe80::1', '192.168.1.100', '::1'], + updatedAt: Date.now() + } + + expect(pickHost(peer)).toBe('192.168.1.100') + }) + + it('should fall back to first address if no IPv4', () => { + const peer = { + id: '1', + name: 'Test', + addresses: ['fe80::1', '::1'], + updatedAt: Date.now() + } + + expect(pickHost(peer)).toBe('fe80::1') + }) + + it('should fall back to host property if no addresses', () => { + const peer = { + id: '1', + name: 'Test', + host: 'example.local', + addresses: [], + updatedAt: Date.now() + } + + expect(pickHost(peer)).toBe('example.local') + }) + + it('should return undefined if no addresses or host', () => { + const peer = { + id: '1', + name: 'Test', + addresses: [], + updatedAt: Date.now() + } + + expect(pickHost(peer)).toBeUndefined() + }) + }) + + describe('createDataHandler', () => { + it('should parse complete lines from buffer', () => { + const lines: string[] = [] + const handler = createDataHandler((line) => lines.push(line)) + + handler.handleData(Buffer.from('{"type":"test"}\n')) + + expect(lines).toEqual(['{"type":"test"}']) + }) + + it('should handle partial lines across multiple chunks', () => { + const lines: string[] = [] + const handler = createDataHandler((line) => lines.push(line)) + + handler.handleData(Buffer.from('{"type":')) + handler.handleData(Buffer.from('"test"}\n')) + + expect(lines).toEqual(['{"type":"test"}']) + }) + + it('should handle multiple lines in single chunk', () => { + const lines: string[] = [] + const handler = createDataHandler((line) => lines.push(line)) + + handler.handleData(Buffer.from('{"a":1}\n{"b":2}\n')) + + expect(lines).toEqual(['{"a":1}', '{"b":2}']) + }) + + it('should reset buffer', () => { + const lines: string[] = [] + const handler = createDataHandler((line) => lines.push(line)) + + handler.handleData(Buffer.from('partial')) + handler.resetBuffer() + handler.handleData(Buffer.from('{"complete":true}\n')) + + expect(lines).toEqual(['{"complete":true}']) + }) + + it('should trim whitespace from lines', () => { + const lines: string[] = [] + const handler = createDataHandler((line) => lines.push(line)) + + handler.handleData(Buffer.from(' {"type":"test"} \n')) + + expect(lines).toEqual(['{"type":"test"}']) + }) + + it('should skip empty lines', () => { + const lines: string[] = [] + const handler = createDataHandler((line) => lines.push(line)) + + handler.handleData(Buffer.from('\n\n{"type":"test"}\n\n')) + + expect(lines).toEqual(['{"type":"test"}']) + }) + + it('should throw error when buffer exceeds MAX_LINE_BUFFER_SIZE', () => { + const handler = createDataHandler(vi.fn()) + + // Create a buffer larger than 1MB (MAX_LINE_BUFFER_SIZE) + const largeData = 'x'.repeat(1024 * 1024 + 1) + + expect(() => handler.handleData(Buffer.from(largeData))).toThrow('Control message too large') + }) + + it('should reset buffer after exceeding MAX_LINE_BUFFER_SIZE', () => { + const lines: string[] = [] + const handler = createDataHandler((line) => lines.push(line)) + + // Create a buffer larger than 1MB + const largeData = 'x'.repeat(1024 * 1024 + 1) + + try { + handler.handleData(Buffer.from(largeData)) + } catch { + // Expected error + } + + // Buffer should be reset, so lineBuffer should be empty + expect(handler.lineBuffer).toBe('') + }) + }) + + describe('waitForSocketDrain', () => { + let mockSocket: Socket & EventEmitter + + beforeEach(() => { + mockSocket = Object.assign(new EventEmitter(), { + destroyed: false, + writable: true, + write: vi.fn(), + off: vi.fn(), + removeAllListeners: vi.fn() + }) as unknown as Socket & EventEmitter + }) + + afterEach(() => { + vi.resetAllMocks() + }) + + it('should throw error when abort signal is already aborted', async () => { + const abortController = new AbortController() + abortController.abort(new Error('Already aborted')) + + await expect(waitForSocketDrain(mockSocket, abortController.signal)).rejects.toThrow('Already aborted') + }) + + it('should throw error when socket is destroyed', async () => { + ;(mockSocket as any).destroyed = true + const abortController = new AbortController() + + await expect(waitForSocketDrain(mockSocket, abortController.signal)).rejects.toThrow('Socket is closed') + }) + + it('should resolve when drain event is emitted', async () => { + const abortController = new AbortController() + + const drainPromise = waitForSocketDrain(mockSocket, abortController.signal) + + // Emit drain event after a short delay + setImmediate(() => mockSocket.emit('drain')) + + await expect(drainPromise).resolves.toBeUndefined() + }) + + it('should reject when close event is emitted', async () => { + const abortController = new AbortController() + + const drainPromise = waitForSocketDrain(mockSocket, abortController.signal) + + setImmediate(() => mockSocket.emit('close')) + + await expect(drainPromise).rejects.toThrow('Socket closed while waiting for drain') + }) + + it('should reject when error event is emitted', async () => { + const abortController = new AbortController() + + const drainPromise = waitForSocketDrain(mockSocket, abortController.signal) + + setImmediate(() => mockSocket.emit('error', new Error('Network error'))) + + await expect(drainPromise).rejects.toThrow('Network error') + }) + + it('should reject when abort signal is triggered', async () => { + const abortController = new AbortController() + + const drainPromise = waitForSocketDrain(mockSocket, abortController.signal) + + setImmediate(() => abortController.abort(new Error('User cancelled'))) + + await expect(drainPromise).rejects.toThrow('User cancelled') + }) + }) + + describe('getAbortError', () => { + it('should return Error reason directly', () => { + const originalError = new Error('Original') + const signal = { aborted: true, reason: originalError } as AbortSignal + + expect(getAbortError(signal, 'Fallback')).toBe(originalError) + }) + + it('should create Error from string reason', () => { + const signal = { aborted: true, reason: 'String reason' } as AbortSignal + + expect(getAbortError(signal, 'Fallback').message).toBe('String reason') + }) + + it('should use fallback for empty reason', () => { + const signal = { aborted: true, reason: '' } as AbortSignal + + expect(getAbortError(signal, 'Fallback').message).toBe('Fallback') + }) + }) +}) diff --git a/src/main/services/lanTransfer/__tests__/handlers/fileTransfer.test.ts b/src/main/services/lanTransfer/__tests__/handlers/fileTransfer.test.ts new file mode 100644 index 0000000000..814fd2f5c9 --- /dev/null +++ b/src/main/services/lanTransfer/__tests__/handlers/fileTransfer.test.ts @@ -0,0 +1,216 @@ +import { EventEmitter } from 'node:events' +import type * as fs from 'node:fs' +import type { Socket } from 'node:net' + +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import { + abortTransfer, + cleanupTransfer, + createTransferState, + formatFileSize, + streamFileChunks +} from '../../handlers/fileTransfer' +import type { ActiveFileTransfer } from '../../types' + +// Mock binaryProtocol +vi.mock('../../binaryProtocol', () => ({ + sendBinaryChunk: vi.fn().mockReturnValue(true) +})) + +// Mock connection handlers +vi.mock('./connection', () => ({ + waitForSocketDrain: vi.fn().mockResolvedValue(undefined), + getAbortError: vi.fn((signal, fallback) => { + const reason = (signal as AbortSignal & { reason?: unknown }).reason + if (reason instanceof Error) return reason + if (typeof reason === 'string' && reason.length > 0) return new Error(reason) + return new Error(fallback) + }) +})) + +// Note: validateFile and calculateFileChecksum tests are skipped because +// the test environment has globally mocked node:fs and node:os modules. +// These functions are tested through integration tests instead. + +describe('fileTransfer handlers', () => { + describe('createTransferState', () => { + it('should create transfer state with correct defaults', () => { + const state = createTransferState('uuid-123', 'test.zip', 1024000, 'abc123') + + expect(state.transferId).toBe('uuid-123') + expect(state.fileName).toBe('test.zip') + expect(state.fileSize).toBe(1024000) + expect(state.checksum).toBe('abc123') + expect(state.bytesSent).toBe(0) + expect(state.currentChunk).toBe(0) + expect(state.isCancelled).toBe(false) + expect(state.abortController).toBeInstanceOf(AbortController) + }) + + it('should calculate totalChunks based on chunk size', () => { + // 512KB chunk size + const state = createTransferState('id', 'test.zip', 1024 * 1024, 'checksum') // 1MB + + expect(state.totalChunks).toBe(2) // 1MB / 512KB = 2 + }) + }) + + describe('abortTransfer', () => { + it('should abort transfer and destroy stream', () => { + const mockStream = { + destroyed: false, + destroy: vi.fn() + } as unknown as fs.ReadStream + + const transfer: ActiveFileTransfer = { + transferId: 'test', + fileName: 'test.zip', + fileSize: 1000, + checksum: 'abc', + totalChunks: 1, + chunkSize: 512000, + bytesSent: 0, + currentChunk: 0, + startedAt: Date.now(), + stream: mockStream, + isCancelled: false, + abortController: new AbortController() + } + + const error = new Error('Test abort') + abortTransfer(transfer, error) + + expect(transfer.isCancelled).toBe(true) + expect(transfer.abortController.signal.aborted).toBe(true) + expect(mockStream.destroy).toHaveBeenCalledWith(error) + }) + + it('should handle undefined transfer', () => { + expect(() => abortTransfer(undefined, new Error('test'))).not.toThrow() + }) + + it('should not abort already aborted controller', () => { + const transfer: ActiveFileTransfer = { + transferId: 'test', + fileName: 'test.zip', + fileSize: 1000, + checksum: 'abc', + totalChunks: 1, + chunkSize: 512000, + bytesSent: 0, + currentChunk: 0, + startedAt: Date.now(), + isCancelled: false, + abortController: new AbortController() + } + + transfer.abortController.abort() + + // Should not throw when aborting again + expect(() => abortTransfer(transfer, new Error('test'))).not.toThrow() + }) + }) + + describe('cleanupTransfer', () => { + it('should cleanup transfer resources', () => { + const mockStream = { + destroyed: false, + destroy: vi.fn() + } as unknown as fs.ReadStream + + const transfer: ActiveFileTransfer = { + transferId: 'test', + fileName: 'test.zip', + fileSize: 1000, + checksum: 'abc', + totalChunks: 1, + chunkSize: 512000, + bytesSent: 0, + currentChunk: 0, + startedAt: Date.now(), + stream: mockStream, + isCancelled: false, + abortController: new AbortController() + } + + cleanupTransfer(transfer) + + expect(transfer.abortController.signal.aborted).toBe(true) + expect(mockStream.destroy).toHaveBeenCalled() + }) + + it('should handle undefined transfer', () => { + expect(() => cleanupTransfer(undefined)).not.toThrow() + }) + }) + + describe('formatFileSize', () => { + it('should format 0 bytes', () => { + expect(formatFileSize(0)).toBe('0 B') + }) + + it('should format bytes', () => { + expect(formatFileSize(500)).toBe('500 B') + }) + + it('should format kilobytes', () => { + expect(formatFileSize(1024)).toBe('1 KB') + expect(formatFileSize(2048)).toBe('2 KB') + }) + + it('should format megabytes', () => { + expect(formatFileSize(1024 * 1024)).toBe('1 MB') + expect(formatFileSize(5 * 1024 * 1024)).toBe('5 MB') + }) + + it('should format gigabytes', () => { + expect(formatFileSize(1024 * 1024 * 1024)).toBe('1 GB') + }) + + it('should format with decimal precision', () => { + expect(formatFileSize(1536)).toBe('1.5 KB') + expect(formatFileSize(1.5 * 1024 * 1024)).toBe('1.5 MB') + }) + }) + + // Note: streamFileChunks tests require careful mocking of fs.createReadStream + // which is globally mocked in the test environment. These tests verify the + // streaming logic works correctly with mock streams. + describe('streamFileChunks', () => { + let mockSocket: Socket & EventEmitter + let mockProgress: ReturnType + + beforeEach(() => { + vi.clearAllMocks() + + mockSocket = Object.assign(new EventEmitter(), { + destroyed: false, + writable: true, + write: vi.fn().mockReturnValue(true), + cork: vi.fn(), + uncork: vi.fn() + }) as unknown as Socket & EventEmitter + + mockProgress = vi.fn() + }) + + afterEach(() => { + vi.resetAllMocks() + }) + + it('should throw when abort signal is already aborted', async () => { + const transfer = createTransferState('test-id', 'test.zip', 1024, 'checksum') + transfer.abortController.abort(new Error('Already cancelled')) + + await expect( + streamFileChunks(mockSocket, '/fake/path.zip', transfer, transfer.abortController.signal, mockProgress) + ).rejects.toThrow() + }) + + // Note: Full integration testing of streamFileChunks with actual file streaming + // requires a real file system, which cannot be easily mocked in ESM. + // The abort signal test above verifies the early abort path. + // Additional streaming tests are covered through integration tests. + }) +}) diff --git a/src/main/services/lanTransfer/__tests__/responseManager.test.ts b/src/main/services/lanTransfer/__tests__/responseManager.test.ts new file mode 100644 index 0000000000..170ee2de8c --- /dev/null +++ b/src/main/services/lanTransfer/__tests__/responseManager.test.ts @@ -0,0 +1,177 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import { ResponseManager } from '../responseManager' + +describe('ResponseManager', () => { + let manager: ResponseManager + + beforeEach(() => { + vi.useFakeTimers() + manager = new ResponseManager() + }) + + afterEach(() => { + vi.useRealTimers() + }) + + describe('buildResponseKey', () => { + it('should build key with type only', () => { + expect(manager.buildResponseKey('handshake_ack')).toBe('handshake_ack') + }) + + it('should build key with type and transferId', () => { + expect(manager.buildResponseKey('file_start_ack', 'uuid-123')).toBe('file_start_ack:uuid-123') + }) + + it('should build key with type, transferId, and chunkIndex', () => { + expect(manager.buildResponseKey('file_chunk_ack', 'uuid-123', 5)).toBe('file_chunk_ack:uuid-123:5') + }) + }) + + describe('waitForResponse', () => { + it('should resolve when tryResolve is called with matching key', async () => { + const resolvePromise = new Promise((resolve, reject) => { + manager.waitForResponse('handshake_ack', 5000, resolve, reject) + }) + + const payload = { type: 'handshake_ack', accepted: true } + const resolved = manager.tryResolve('handshake_ack', payload) + + expect(resolved).toBe(true) + await expect(resolvePromise).resolves.toEqual(payload) + }) + + it('should reject on timeout', async () => { + const resolvePromise = new Promise((resolve, reject) => { + manager.waitForResponse('handshake_ack', 1000, resolve, reject) + }) + + vi.advanceTimersByTime(1001) + + await expect(resolvePromise).rejects.toThrow('Timeout waiting for handshake_ack') + }) + + it('should call onTimeout callback when timeout occurs', async () => { + const onTimeout = vi.fn() + manager.setTimeoutCallback(onTimeout) + + const resolvePromise = new Promise((resolve, reject) => { + manager.waitForResponse('test', 1000, resolve, reject) + }) + + vi.advanceTimersByTime(1001) + + await expect(resolvePromise).rejects.toThrow() + expect(onTimeout).toHaveBeenCalled() + }) + + it('should reject when abort signal is triggered', async () => { + const abortController = new AbortController() + + const resolvePromise = new Promise((resolve, reject) => { + manager.waitForResponse('test', 10000, resolve, reject, undefined, undefined, abortController.signal) + }) + + abortController.abort(new Error('User cancelled')) + + await expect(resolvePromise).rejects.toThrow('User cancelled') + }) + + it('should replace existing response with same key', async () => { + const firstReject = vi.fn() + const secondResolve = vi.fn() + const secondReject = vi.fn() + + manager.waitForResponse('test', 5000, vi.fn(), firstReject) + manager.waitForResponse('test', 5000, secondResolve, secondReject) + + // First should be cleared (no rejection since it's replaced) + const payload = { type: 'test' } + manager.tryResolve('test', payload) + + expect(secondResolve).toHaveBeenCalledWith(payload) + }) + }) + + describe('tryResolve', () => { + it('should return false when no matching response', () => { + expect(manager.tryResolve('nonexistent', {})).toBe(false) + }) + + it('should match with transferId', async () => { + const resolvePromise = new Promise((resolve, reject) => { + manager.waitForResponse('file_start_ack', 5000, resolve, reject, 'uuid-123') + }) + + const payload = { type: 'file_start_ack', transferId: 'uuid-123' } + manager.tryResolve('file_start_ack', payload, 'uuid-123') + + await expect(resolvePromise).resolves.toEqual(payload) + }) + }) + + describe('rejectAll', () => { + it('should reject all pending responses', async () => { + const promises = [ + new Promise((resolve, reject) => { + manager.waitForResponse('test1', 5000, resolve, reject) + }), + new Promise((resolve, reject) => { + manager.waitForResponse('test2', 5000, resolve, reject, 'uuid') + }) + ] + + manager.rejectAll(new Error('Connection closed')) + + await expect(promises[0]).rejects.toThrow('Connection closed') + await expect(promises[1]).rejects.toThrow('Connection closed') + }) + }) + + describe('clearPendingResponse', () => { + it('should clear specific response by key', () => { + manager.waitForResponse('test', 5000, vi.fn(), vi.fn()) + + manager.clearPendingResponse('test') + + expect(manager.tryResolve('test', {})).toBe(false) + }) + + it('should clear all responses when no key provided', () => { + manager.waitForResponse('test1', 5000, vi.fn(), vi.fn()) + manager.waitForResponse('test2', 5000, vi.fn(), vi.fn()) + + manager.clearPendingResponse() + + expect(manager.tryResolve('test1', {})).toBe(false) + expect(manager.tryResolve('test2', {})).toBe(false) + }) + }) + + describe('getAbortError', () => { + it('should return Error reason directly', () => { + const originalError = new Error('Original error') + const signal = { aborted: true, reason: originalError } as AbortSignal + + const error = manager.getAbortError(signal, 'Fallback') + + expect(error).toBe(originalError) + }) + + it('should create Error from string reason', () => { + const signal = { aborted: true, reason: 'String reason' } as AbortSignal + + const error = manager.getAbortError(signal, 'Fallback') + + expect(error.message).toBe('String reason') + }) + + it('should use fallback message when no reason', () => { + const signal = { aborted: true } as AbortSignal + + const error = manager.getAbortError(signal, 'Fallback message') + + expect(error.message).toBe('Fallback message') + }) + }) +}) diff --git a/src/main/services/lanTransfer/binaryProtocol.ts b/src/main/services/lanTransfer/binaryProtocol.ts new file mode 100644 index 0000000000..864a8b95bd --- /dev/null +++ b/src/main/services/lanTransfer/binaryProtocol.ts @@ -0,0 +1,67 @@ +import type { Socket } from 'node:net' + +/** + * Binary protocol constants (v1) + */ +export const BINARY_TYPE_FILE_CHUNK = 0x01 + +/** + * Send file chunk as binary frame (protocol v1 - streaming mode) + * + * Frame format: + * ``` + * ┌──────────┬──────────┬──────────┬───────────────┬──────────────┬────────────┬───────────┐ + * │ Magic │ TotalLen │ Type │ TransferId Len│ TransferId │ ChunkIdx │ Data │ + * │ 0x43 0x53│ (4B BE) │ 0x01 │ (2B BE) │ (variable) │ (4B BE) │ (raw) │ + * └──────────┴──────────┴──────────┴───────────────┴──────────────┴────────────┴───────────┘ + * ``` + * + * @param socket - TCP socket to write to + * @param transferId - UUID of the transfer + * @param chunkIndex - Index of the chunk (0-based) + * @param data - Raw chunk data buffer + * @returns true if data was buffered, false if backpressure should be applied + */ +export function sendBinaryChunk(socket: Socket, transferId: string, chunkIndex: number, data: Buffer): boolean { + if (!socket || socket.destroyed || !socket.writable) { + throw new Error('Socket is not writable') + } + + const tidBuffer = Buffer.from(transferId, 'utf8') + const tidLen = tidBuffer.length + + // totalLen = type(1) + tidLen(2) + tid(n) + idx(4) + data(m) + const totalLen = 1 + 2 + tidLen + 4 + data.length + + const header = Buffer.allocUnsafe(2 + 4 + 1 + 2 + tidLen + 4) + let offset = 0 + + // Magic (2 bytes): "CS" + header[offset++] = 0x43 + header[offset++] = 0x53 + + // TotalLen (4 bytes, Big-Endian) + header.writeUInt32BE(totalLen, offset) + offset += 4 + + // Type (1 byte) + header[offset++] = BINARY_TYPE_FILE_CHUNK + + // TransferId length (2 bytes, Big-Endian) + header.writeUInt16BE(tidLen, offset) + offset += 2 + + // TransferId (variable) + tidBuffer.copy(header, offset) + offset += tidLen + + // ChunkIndex (4 bytes, Big-Endian) + header.writeUInt32BE(chunkIndex, offset) + + socket.cork() + const wroteHeader = socket.write(header) + const wroteData = socket.write(data) + socket.uncork() + + return wroteHeader && wroteData +} diff --git a/src/main/services/lanTransfer/handlers/connection.ts b/src/main/services/lanTransfer/handlers/connection.ts new file mode 100644 index 0000000000..5a53eeb373 --- /dev/null +++ b/src/main/services/lanTransfer/handlers/connection.ts @@ -0,0 +1,162 @@ +import { isIP, type Socket } from 'node:net' +import { platform } from 'node:os' + +import { loggerService } from '@logger' +import type { LanHandshakeRequestMessage, LocalTransferPeer } from '@shared/config/types' +import { app } from 'electron' + +import type { ConnectionContext } from '../types' + +export const HANDSHAKE_PROTOCOL_VERSION = '1' + +/** Maximum size for line buffer to prevent memory exhaustion from malicious peers */ +const MAX_LINE_BUFFER_SIZE = 1024 * 1024 // 1MB limit for control messages + +const logger = loggerService.withContext('LanTransferConnection') + +/** + * Build a handshake request message with device info. + */ +export function buildHandshakeMessage(): LanHandshakeRequestMessage { + return { + type: 'handshake', + deviceName: app.getName(), + version: HANDSHAKE_PROTOCOL_VERSION, + platform: platform(), + appVersion: app.getVersion() + } +} + +/** + * Pick the best host address from a peer's available addresses. + * Prefers IPv4 addresses over IPv6. + */ +export function pickHost(peer: LocalTransferPeer): string | undefined { + const preferred = peer.addresses?.find((addr) => isIP(addr) === 4) || peer.addresses?.[0] + return preferred || peer.host +} + +/** + * Send a test ping message after successful handshake. + */ +export function sendTestPing(ctx: ConnectionContext): void { + const payload = 'hello world' + try { + ctx.sendControlMessage({ type: 'ping', payload }) + logger.info('Sent LAN ping test payload') + ctx.broadcastClientEvent({ + type: 'ping_sent', + payload, + timestamp: Date.now() + }) + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + logger.error('Failed to send LAN test ping', error as Error) + ctx.broadcastClientEvent({ + type: 'error', + message, + timestamp: Date.now() + }) + } +} + +/** + * Attach data listener to socket for receiving control messages. + * Returns a function to parse the line buffer. + */ +export function createDataHandler(onControlLine: (line: string) => void): { + lineBuffer: string + handleData: (chunk: Buffer) => void + resetBuffer: () => void +} { + let lineBuffer = '' + + return { + get lineBuffer() { + return lineBuffer + }, + handleData(chunk: Buffer) { + lineBuffer += chunk.toString('utf8') + + // Prevent memory exhaustion from malicious peers sending data without newlines + if (lineBuffer.length > MAX_LINE_BUFFER_SIZE) { + logger.error('Line buffer exceeded maximum size, resetting') + lineBuffer = '' + throw new Error('Control message too large') + } + + let newlineIndex = lineBuffer.indexOf('\n') + while (newlineIndex !== -1) { + const line = lineBuffer.slice(0, newlineIndex).trim() + lineBuffer = lineBuffer.slice(newlineIndex + 1) + if (line.length > 0) { + onControlLine(line) + } + newlineIndex = lineBuffer.indexOf('\n') + } + }, + resetBuffer() { + lineBuffer = '' + } + } +} + +/** + * Wait for socket to drain (backpressure handling). + */ +export async function waitForSocketDrain(socket: Socket, abortSignal: AbortSignal): Promise { + if (abortSignal.aborted) { + throw getAbortError(abortSignal, 'Transfer aborted while waiting for socket drain') + } + if (socket.destroyed) { + throw new Error('Socket is closed') + } + + await new Promise((resolve, reject) => { + const cleanup = () => { + socket.off('drain', onDrain) + socket.off('close', onClose) + socket.off('error', onError) + abortSignal.removeEventListener('abort', onAbort) + } + + const onDrain = () => { + cleanup() + resolve() + } + + const onClose = () => { + cleanup() + reject(new Error('Socket closed while waiting for drain')) + } + + const onError = (error: Error) => { + cleanup() + reject(error) + } + + const onAbort = () => { + cleanup() + reject(getAbortError(abortSignal, 'Transfer aborted while waiting for socket drain')) + } + + socket.once('drain', onDrain) + socket.once('close', onClose) + socket.once('error', onError) + abortSignal.addEventListener('abort', onAbort, { once: true }) + }) +} + +/** + * Get the error from an abort signal, or create a fallback error. + */ +export function getAbortError(signal: AbortSignal, fallbackMessage: string): Error { + const reason = (signal as AbortSignal & { reason?: unknown }).reason + if (reason instanceof Error) { + return reason + } + if (typeof reason === 'string' && reason.length > 0) { + return new Error(reason) + } + return new Error(fallbackMessage) +} diff --git a/src/main/services/lanTransfer/handlers/fileTransfer.ts b/src/main/services/lanTransfer/handlers/fileTransfer.ts new file mode 100644 index 0000000000..c469a58421 --- /dev/null +++ b/src/main/services/lanTransfer/handlers/fileTransfer.ts @@ -0,0 +1,267 @@ +import * as crypto from 'node:crypto' +import * as fs from 'node:fs' +import type { Socket } from 'node:net' +import * as path from 'node:path' + +import { loggerService } from '@logger' +import type { + LanFileCompleteMessage, + LanFileEndMessage, + LanFileStartAckMessage, + LanFileStartMessage +} from '@shared/config/types' +import { + LAN_TRANSFER_CHUNK_SIZE, + LAN_TRANSFER_COMPLETE_TIMEOUT_MS, + LAN_TRANSFER_MAX_FILE_SIZE +} from '@shared/config/types' + +import { sendBinaryChunk } from '../binaryProtocol' +import type { ActiveFileTransfer, FileTransferContext } from '../types' +import { getAbortError, waitForSocketDrain } from './connection' + +const DEFAULT_FILE_START_ACK_TIMEOUT_MS = 30_000 // 30s for file_start_ack + +const logger = loggerService.withContext('LanTransferFileHandler') + +/** + * Validate a file for transfer. + * Checks existence, type, extension, and size limits. + */ +export async function validateFile(filePath: string): Promise<{ stats: fs.Stats; fileName: string }> { + let stats: fs.Stats + try { + stats = await fs.promises.stat(filePath) + } catch (error) { + const nodeError = error as NodeJS.ErrnoException + if (nodeError.code === 'ENOENT') { + throw new Error(`File not found: ${filePath}`) + } else if (nodeError.code === 'EACCES') { + throw new Error(`Permission denied: ${filePath}`) + } else if (nodeError.code === 'ENOTDIR') { + throw new Error(`Invalid path: ${filePath}`) + } else { + throw new Error(`Cannot access file: ${filePath} (${nodeError.code || 'unknown error'})`) + } + } + + if (!stats.isFile()) { + throw new Error('Path is not a file') + } + + const fileName = path.basename(filePath) + const ext = path.extname(fileName).toLowerCase() + if (ext !== '.zip') { + throw new Error('Only ZIP files are supported') + } + + if (stats.size > LAN_TRANSFER_MAX_FILE_SIZE) { + throw new Error(`File too large. Maximum size is ${formatFileSize(LAN_TRANSFER_MAX_FILE_SIZE)}`) + } + + return { stats, fileName } +} + +/** + * Calculate SHA-256 checksum of a file. + */ +export async function calculateFileChecksum(filePath: string): Promise { + return new Promise((resolve, reject) => { + const hash = crypto.createHash('sha256') + const stream = fs.createReadStream(filePath) + stream.on('data', (data) => hash.update(data)) + stream.on('end', () => resolve(hash.digest('hex'))) + stream.on('error', reject) + }) +} + +/** + * Create initial transfer state for a new file transfer. + */ +export function createTransferState( + transferId: string, + fileName: string, + fileSize: number, + checksum: string +): ActiveFileTransfer { + const chunkSize = LAN_TRANSFER_CHUNK_SIZE + const totalChunks = Math.ceil(fileSize / chunkSize) + + return { + transferId, + fileName, + fileSize, + checksum, + totalChunks, + chunkSize, + bytesSent: 0, + currentChunk: 0, + startedAt: Date.now(), + isCancelled: false, + abortController: new AbortController() + } +} + +/** + * Send file_start message to receiver. + */ +export function sendFileStart(ctx: FileTransferContext, transfer: ActiveFileTransfer): void { + const startMessage: LanFileStartMessage = { + type: 'file_start', + transferId: transfer.transferId, + fileName: transfer.fileName, + fileSize: transfer.fileSize, + mimeType: 'application/zip', + checksum: transfer.checksum, + totalChunks: transfer.totalChunks, + chunkSize: transfer.chunkSize + } + ctx.sendControlMessage(startMessage) + logger.info('Sent file_start message') +} + +/** + * Wait for file_start_ack from receiver. + */ +export function waitForFileStartAck( + ctx: FileTransferContext, + transferId: string, + abortSignal?: AbortSignal +): Promise { + return new Promise((resolve, reject) => { + ctx.waitForResponse( + 'file_start_ack', + DEFAULT_FILE_START_ACK_TIMEOUT_MS, + (payload) => resolve(payload as LanFileStartAckMessage), + reject, + transferId, + undefined, + abortSignal + ) + }) +} + +/** + * Wait for file_complete from receiver after all chunks sent. + */ +export function waitForFileComplete( + ctx: FileTransferContext, + transferId: string, + abortSignal?: AbortSignal +): Promise { + return new Promise((resolve, reject) => { + ctx.waitForResponse( + 'file_complete', + LAN_TRANSFER_COMPLETE_TIMEOUT_MS, + (payload) => resolve(payload as LanFileCompleteMessage), + reject, + transferId, + undefined, + abortSignal + ) + }) +} + +/** + * Send file_end message to receiver. + */ +export function sendFileEnd(ctx: FileTransferContext, transferId: string): void { + const endMessage: LanFileEndMessage = { + type: 'file_end', + transferId + } + ctx.sendControlMessage(endMessage) + logger.info('Sent file_end message') +} + +/** + * Stream file chunks to the receiver (v1 streaming mode - no per-chunk acknowledgment). + */ +export async function streamFileChunks( + socket: Socket, + filePath: string, + transfer: ActiveFileTransfer, + abortSignal: AbortSignal, + onProgress: (bytesSent: number, chunkIndex: number) => void +): Promise { + const { chunkSize, transferId } = transfer + + const stream = fs.createReadStream(filePath, { highWaterMark: chunkSize }) + transfer.stream = stream + + let chunkIndex = 0 + let bytesSent = 0 + + try { + for await (const chunk of stream) { + if (abortSignal.aborted) { + throw getAbortError(abortSignal, 'Transfer aborted') + } + + const buffer = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk) + bytesSent += buffer.length + + // Send chunk as binary frame (v1 streaming) with backpressure handling + const canContinue = sendBinaryChunk(socket, transferId, chunkIndex, buffer) + if (!canContinue) { + await waitForSocketDrain(socket, abortSignal) + } + + // Update progress + transfer.bytesSent = bytesSent + transfer.currentChunk = chunkIndex + + onProgress(bytesSent, chunkIndex) + chunkIndex++ + } + + logger.info(`File streaming completed: ${chunkIndex} chunks sent`) + } catch (error) { + logger.error('File streaming failed', error as Error) + throw error + } +} + +/** + * Abort an active transfer and clean up resources. + */ +export function abortTransfer(transfer: ActiveFileTransfer | undefined, error: Error): void { + if (!transfer) { + return + } + + transfer.isCancelled = true + if (!transfer.abortController.signal.aborted) { + transfer.abortController.abort(error) + } + if (transfer.stream && !transfer.stream.destroyed) { + transfer.stream.destroy(error) + } +} + +/** + * Clean up transfer resources without error. + */ +export function cleanupTransfer(transfer: ActiveFileTransfer | undefined): void { + if (!transfer) { + return + } + + if (!transfer.abortController.signal.aborted) { + transfer.abortController.abort() + } + if (transfer.stream && !transfer.stream.destroyed) { + transfer.stream.destroy() + } +} + +/** + * Format bytes into human-readable size string. + */ +export function formatFileSize(bytes: number): string { + if (bytes === 0) return '0 B' + const k = 1024 + const sizes = ['B', 'KB', 'MB', 'GB'] + const i = Math.floor(Math.log(bytes) / Math.log(k)) + return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i] +} diff --git a/src/main/services/lanTransfer/handlers/index.ts b/src/main/services/lanTransfer/handlers/index.ts new file mode 100644 index 0000000000..33620d188c --- /dev/null +++ b/src/main/services/lanTransfer/handlers/index.ts @@ -0,0 +1,22 @@ +export { + buildHandshakeMessage, + createDataHandler, + getAbortError, + HANDSHAKE_PROTOCOL_VERSION, + pickHost, + sendTestPing, + waitForSocketDrain +} from './connection' +export { + abortTransfer, + calculateFileChecksum, + cleanupTransfer, + createTransferState, + formatFileSize, + sendFileEnd, + sendFileStart, + streamFileChunks, + validateFile, + waitForFileComplete, + waitForFileStartAck +} from './fileTransfer' diff --git a/src/main/services/lanTransfer/index.ts b/src/main/services/lanTransfer/index.ts new file mode 100644 index 0000000000..12f3c38afc --- /dev/null +++ b/src/main/services/lanTransfer/index.ts @@ -0,0 +1,21 @@ +/** + * LAN Transfer Client Module + * + * Protocol: v1.0 (streaming mode) + * + * Features: + * - Binary frame format for file chunks (no base64 overhead) + * - Streaming mode (no per-chunk acknowledgment) + * - JSON messages for control flow (handshake, file_start, file_end, etc.) + * - Global timeout protection + * - Backpressure handling + * + * Binary Frame Format: + * ┌──────────┬──────────┬──────────┬───────────────┬──────────────┬────────────┬───────────┐ + * │ Magic │ TotalLen │ Type │ TransferId Len│ TransferId │ ChunkIdx │ Data │ + * │ 0x43 0x53│ (4B BE) │ 0x01 │ (2B BE) │ (variable) │ (4B BE) │ (raw) │ + * └──────────┴──────────┴──────────┴───────────────┴──────────────┴────────────┴───────────┘ + */ + +export { HANDSHAKE_PROTOCOL_VERSION, lanTransferClientService } from './LanTransferClientService' +export type { ActiveFileTransfer, ConnectionContext, FileTransferContext, PendingResponse } from './types' diff --git a/src/main/services/lanTransfer/responseManager.ts b/src/main/services/lanTransfer/responseManager.ts new file mode 100644 index 0000000000..74d5196dba --- /dev/null +++ b/src/main/services/lanTransfer/responseManager.ts @@ -0,0 +1,144 @@ +import type { PendingResponse } from './types' + +/** + * Manages pending response handlers for awaiting control messages. + * Handles timeouts, abort signals, and cleanup. + */ +export class ResponseManager { + private pendingResponses = new Map() + private onTimeout?: () => void + + /** + * Set a callback to be called when a response times out. + * Typically used to trigger disconnect on timeout. + */ + setTimeoutCallback(callback: () => void): void { + this.onTimeout = callback + } + + /** + * Build a composite key for identifying pending responses. + */ + buildResponseKey(type: string, transferId?: string, chunkIndex?: number): string { + const parts = [type] + if (transferId !== undefined) parts.push(transferId) + if (chunkIndex !== undefined) parts.push(String(chunkIndex)) + return parts.join(':') + } + + /** + * Register a response listener with timeout and optional abort signal. + */ + waitForResponse( + type: string, + timeoutMs: number, + resolve: (payload: unknown) => void, + reject: (error: Error) => void, + transferId?: string, + chunkIndex?: number, + abortSignal?: AbortSignal + ): void { + const responseKey = this.buildResponseKey(type, transferId, chunkIndex) + + // Clear any existing response with the same key + this.clearPendingResponse(responseKey) + + const timeoutHandle = setTimeout(() => { + this.clearPendingResponse(responseKey) + const error = new Error(`Timeout waiting for ${type}`) + reject(error) + this.onTimeout?.() + }, timeoutMs) + + const pending: PendingResponse = { + type, + transferId, + chunkIndex, + resolve, + reject, + timeoutHandle, + abortSignal + } + + if (abortSignal) { + const abortListener = () => { + this.clearPendingResponse(responseKey) + reject(this.getAbortError(abortSignal, `Aborted while waiting for ${type}`)) + } + pending.abortListener = abortListener + abortSignal.addEventListener('abort', abortListener, { once: true }) + } + + this.pendingResponses.set(responseKey, pending) + } + + /** + * Try to resolve a pending response by type and optional identifiers. + * Returns true if a matching response was found and resolved. + */ + tryResolve(type: string, payload: unknown, transferId?: string, chunkIndex?: number): boolean { + const responseKey = this.buildResponseKey(type, transferId, chunkIndex) + const pendingResponse = this.pendingResponses.get(responseKey) + + if (pendingResponse) { + const resolver = pendingResponse.resolve + this.clearPendingResponse(responseKey) + resolver(payload) + return true + } + + return false + } + + /** + * Clear a single pending response by key, or all responses if no key provided. + */ + clearPendingResponse(key?: string): void { + if (key) { + const pending = this.pendingResponses.get(key) + if (pending?.timeoutHandle) { + clearTimeout(pending.timeoutHandle) + } + if (pending?.abortSignal && pending.abortListener) { + pending.abortSignal.removeEventListener('abort', pending.abortListener) + } + this.pendingResponses.delete(key) + } else { + // Clear all pending responses + for (const pending of this.pendingResponses.values()) { + if (pending.timeoutHandle) { + clearTimeout(pending.timeoutHandle) + } + if (pending.abortSignal && pending.abortListener) { + pending.abortSignal.removeEventListener('abort', pending.abortListener) + } + } + this.pendingResponses.clear() + } + } + + /** + * Reject all pending responses with the given error. + */ + rejectAll(error: Error): void { + for (const key of Array.from(this.pendingResponses.keys())) { + const pending = this.pendingResponses.get(key) + this.clearPendingResponse(key) + pending?.reject(error) + } + } + + /** + * Get the abort error from an abort signal, or create a fallback error. + */ + getAbortError(signal: AbortSignal, fallbackMessage: string): Error { + const reason = (signal as AbortSignal & { reason?: unknown }).reason + if (reason instanceof Error) { + return reason + } + if (typeof reason === 'string' && reason.length > 0) { + return new Error(reason) + } + return new Error(fallbackMessage) + } +} diff --git a/src/main/services/lanTransfer/types.ts b/src/main/services/lanTransfer/types.ts new file mode 100644 index 0000000000..52be660af3 --- /dev/null +++ b/src/main/services/lanTransfer/types.ts @@ -0,0 +1,65 @@ +import type * as fs from 'node:fs' +import type { Socket } from 'node:net' + +import type { LanClientEvent, LocalTransferPeer } from '@shared/config/types' + +/** + * Pending response handler for awaiting control messages + */ +export type PendingResponse = { + type: string + transferId?: string + chunkIndex?: number + resolve: (payload: unknown) => void + reject: (error: Error) => void + timeoutHandle?: NodeJS.Timeout + abortSignal?: AbortSignal + abortListener?: () => void +} + +/** + * Active file transfer state tracking + */ +export type ActiveFileTransfer = { + transferId: string + fileName: string + fileSize: number + checksum: string + totalChunks: number + chunkSize: number + bytesSent: number + currentChunk: number + startedAt: number + stream?: fs.ReadStream + isCancelled: boolean + abortController: AbortController +} + +/** + * Context interface for connection handlers + * Provides access to service methods without circular dependencies + */ +export type ConnectionContext = { + socket: Socket | null + currentPeer?: LocalTransferPeer + sendControlMessage: (message: Record) => void + broadcastClientEvent: (event: LanClientEvent) => void +} + +/** + * Context interface for file transfer handlers + * Extends connection context with transfer-specific methods + */ +export type FileTransferContext = ConnectionContext & { + activeTransfer?: ActiveFileTransfer + setActiveTransfer: (transfer: ActiveFileTransfer | undefined) => void + waitForResponse: ( + type: string, + timeoutMs: number, + resolve: (payload: unknown) => void, + reject: (error: Error) => void, + transferId?: string, + chunkIndex?: number, + abortSignal?: AbortSignal + ) => void +} diff --git a/src/main/services/mcp/ServerLogBuffer.ts b/src/main/services/mcp/ServerLogBuffer.ts new file mode 100644 index 0000000000..01c45f373f --- /dev/null +++ b/src/main/services/mcp/ServerLogBuffer.ts @@ -0,0 +1,36 @@ +export type MCPServerLogEntry = { + timestamp: number + level: 'debug' | 'info' | 'warn' | 'error' | 'stderr' | 'stdout' + message: string + data?: any + source?: string +} + +/** + * Lightweight ring buffer for per-server MCP logs. + */ +export class ServerLogBuffer { + private maxEntries: number + private logs: Map = new Map() + + constructor(maxEntries = 200) { + this.maxEntries = maxEntries + } + + append(serverKey: string, entry: MCPServerLogEntry) { + const list = this.logs.get(serverKey) ?? [] + list.push(entry) + if (list.length > this.maxEntries) { + list.splice(0, list.length - this.maxEntries) + } + this.logs.set(serverKey, list) + } + + get(serverKey: string): MCPServerLogEntry[] { + return [...(this.logs.get(serverKey) ?? [])] + } + + remove(serverKey: string) { + this.logs.delete(serverKey) + } +} diff --git a/src/main/services/mcp/oauth/callback.ts b/src/main/services/mcp/oauth/callback.ts index c13ecd5c07..7da7544585 100644 --- a/src/main/services/mcp/oauth/callback.ts +++ b/src/main/services/mcp/oauth/callback.ts @@ -128,8 +128,8 @@ export class CallBackServer { }) return new Promise((resolve, reject) => { - server.listen(port, () => { - logger.info(`OAuth callback server listening on port ${port}`) + server.listen(port, '127.0.0.1', () => { + logger.info(`OAuth callback server listening on 127.0.0.1:${port}`) resolve(server) }) diff --git a/src/main/services/memory/MemoryService.ts b/src/main/services/memory/MemoryService.ts index 3466e2c3c6..101dd54294 100644 --- a/src/main/services/memory/MemoryService.ts +++ b/src/main/services/memory/MemoryService.ts @@ -1,7 +1,9 @@ import type { Client } from '@libsql/client' import { createClient } from '@libsql/client' import { loggerService } from '@logger' +import { DATA_PATH } from '@main/config' import Embeddings from '@main/knowledge/embedjs/embeddings/Embeddings' +import { makeSureDirExists } from '@main/utils' import type { AddMemoryOptions, AssistantMessage, @@ -13,6 +15,7 @@ import type { } from '@types' import crypto from 'crypto' import { app } from 'electron' +import fs from 'fs' import path from 'path' import { MemoryQueries } from './queries' @@ -71,6 +74,21 @@ export class MemoryService { return MemoryService.instance } + /** + * Migrate the memory database from the old path to the new path + * If the old memory database exists, rename it to the new path + */ + public migrateMemoryDb(): void { + const oldMemoryDbPath = path.join(app.getPath('userData'), 'memories.db') + const memoryDbPath = path.join(DATA_PATH, 'Memory', 'memories.db') + + makeSureDirExists(path.dirname(memoryDbPath)) + + if (fs.existsSync(oldMemoryDbPath)) { + fs.renameSync(oldMemoryDbPath, memoryDbPath) + } + } + /** * Initialize the database connection and create tables */ @@ -80,11 +98,12 @@ export class MemoryService { } try { - const userDataPath = app.getPath('userData') - const dbPath = path.join(userDataPath, 'memories.db') + const memoryDbPath = path.join(DATA_PATH, 'Memory', 'memories.db') + + makeSureDirExists(path.dirname(memoryDbPath)) this.db = createClient({ - url: `file:${dbPath}`, + url: `file:${memoryDbPath}`, intMode: 'number' }) @@ -168,12 +187,13 @@ export class MemoryService { // Generate embedding if model is configured let embedding: number[] | null = null - const embedderApiClient = this.config?.embedderApiClient - if (embedderApiClient) { + const embeddingModel = this.config?.embeddingModel + + if (embeddingModel) { try { embedding = await this.generateEmbedding(trimmedMemory) logger.debug( - `Generated embedding for restored memory with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})` + `Generated embedding for restored memory with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})` ) } catch (error) { logger.error('Failed to generate embedding for restored memory:', error as Error) @@ -211,11 +231,11 @@ export class MemoryService { // Generate embedding if model is configured let embedding: number[] | null = null - if (this.config?.embedderApiClient) { + if (this.config?.embeddingModel) { try { embedding = await this.generateEmbedding(trimmedMemory) logger.debug( - `Generated embedding with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})` + `Generated embedding with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})` ) // Check for similar memories using vector similarity @@ -300,7 +320,7 @@ export class MemoryService { try { // If we have an embedder model configured, use vector search - if (this.config?.embedderApiClient) { + if (this.config?.embeddingModel) { try { const queryEmbedding = await this.generateEmbedding(query) return await this.hybridSearch(query, queryEmbedding, { limit, userId, agentId, filters }) @@ -497,11 +517,11 @@ export class MemoryService { // Generate new embedding if model is configured let embedding: number[] | null = null - if (this.config?.embedderApiClient) { + if (this.config?.embeddingModel) { try { embedding = await this.generateEmbedding(memory) logger.debug( - `Updated embedding with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})` + `Updated embedding with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})` ) } catch (error) { logger.error('Failed to generate embedding for update:', error as Error) @@ -710,21 +730,22 @@ export class MemoryService { * Generate embedding for text */ private async generateEmbedding(text: string): Promise { - if (!this.config?.embedderApiClient) { + if (!this.config?.embeddingModel) { throw new Error('Embedder model not configured') } try { // Initialize embeddings instance if needed if (!this.embeddings) { - if (!this.config.embedderApiClient) { + if (!this.config.embeddingApiClient) { throw new Error('Embedder provider not configured') } this.embeddings = new Embeddings({ - embedApiClient: this.config.embedderApiClient, - dimensions: this.config.embedderDimensions + embedApiClient: this.config.embeddingApiClient, + dimensions: this.config.embeddingDimensions }) + await this.embeddings.init() } diff --git a/src/main/utils/__tests__/mcp.test.ts b/src/main/utils/__tests__/mcp.test.ts index b1a35f925e..706a44bc84 100644 --- a/src/main/utils/__tests__/mcp.test.ts +++ b/src/main/utils/__tests__/mcp.test.ts @@ -3,194 +3,223 @@ import { describe, expect, it } from 'vitest' import { buildFunctionCallToolName } from '../mcp' describe('buildFunctionCallToolName', () => { - describe('basic functionality', () => { - it('should combine server name and tool name', () => { + describe('basic format', () => { + it('should return format mcp__{server}__{tool}', () => { const result = buildFunctionCallToolName('github', 'search_issues') - expect(result).toContain('github') - expect(result).toContain('search') + expect(result).toBe('mcp__github__search_issues') }) - it('should sanitize names by replacing dashes with underscores', () => { - const result = buildFunctionCallToolName('my-server', 'my-tool') - // Input dashes are replaced, but the separator between server and tool is a dash - expect(result).toBe('my_serv-my_tool') - expect(result).toContain('_') - }) - - it('should handle empty server names gracefully', () => { - const result = buildFunctionCallToolName('', 'tool') - expect(result).toBeTruthy() + it('should handle simple server and tool names', () => { + expect(buildFunctionCallToolName('fetch', 'get_page')).toBe('mcp__fetch__get_page') + expect(buildFunctionCallToolName('database', 'query')).toBe('mcp__database__query') + expect(buildFunctionCallToolName('cherry_studio', 'search')).toBe('mcp__cherry_studio__search') }) }) - describe('uniqueness with serverId', () => { - it('should generate different IDs for same server name but different serverIds', () => { - const serverId1 = 'server-id-123456' - const serverId2 = 'server-id-789012' - const serverName = 'github' - const toolName = 'search_repos' - - const result1 = buildFunctionCallToolName(serverName, toolName, serverId1) - const result2 = buildFunctionCallToolName(serverName, toolName, serverId2) - - expect(result1).not.toBe(result2) - expect(result1).toContain('123456') - expect(result2).toContain('789012') + describe('valid JavaScript identifier', () => { + it('should always start with mcp__ prefix (valid JS identifier start)', () => { + const result = buildFunctionCallToolName('123server', '456tool') + expect(result).toMatch(/^mcp__/) + expect(result).toBe('mcp__123server__456tool') }) - it('should generate same ID when serverId is not provided', () => { + it('should only contain alphanumeric chars and underscores', () => { + const result = buildFunctionCallToolName('my-server', 'my-tool') + expect(result).toBe('mcp__my_server__my_tool') + expect(result).toMatch(/^[a-zA-Z][a-zA-Z0-9_]*$/) + }) + + it('should be a valid JavaScript identifier', () => { + const testCases = [ + ['github', 'create_issue'], + ['my-server', 'fetch-data'], + ['test@server', 'tool#name'], + ['server.name', 'tool.action'], + ['123abc', 'def456'] + ] + + for (const [server, tool] of testCases) { + const result = buildFunctionCallToolName(server, tool) + // Valid JS identifiers match this pattern + expect(result).toMatch(/^[a-zA-Z_][a-zA-Z0-9_]*$/) + } + }) + }) + + describe('character sanitization', () => { + it('should replace dashes with underscores', () => { + const result = buildFunctionCallToolName('my-server', 'my-tool-name') + expect(result).toBe('mcp__my_server__my_tool_name') + }) + + it('should replace special characters with underscores', () => { + const result = buildFunctionCallToolName('test@server!', 'tool#name$') + expect(result).toBe('mcp__test_server__tool_name') + }) + + it('should replace dots with underscores', () => { + const result = buildFunctionCallToolName('server.name', 'tool.action') + expect(result).toBe('mcp__server_name__tool_action') + }) + + it('should replace spaces with underscores', () => { + const result = buildFunctionCallToolName('my server', 'my tool') + expect(result).toBe('mcp__my_server__my_tool') + }) + + it('should collapse consecutive underscores', () => { + const result = buildFunctionCallToolName('my--server', 'my___tool') + expect(result).toBe('mcp__my_server__my_tool') + expect(result).not.toMatch(/_{3,}/) + }) + + it('should trim leading and trailing underscores from parts', () => { + const result = buildFunctionCallToolName('_server_', '_tool_') + expect(result).toBe('mcp__server__tool') + }) + + it('should handle names with only special characters', () => { + const result = buildFunctionCallToolName('---', '###') + expect(result).toBe('mcp____') + }) + }) + + describe('length constraints', () => { + it('should not exceed 63 characters', () => { + const longServerName = 'a'.repeat(50) + const longToolName = 'b'.repeat(50) + const result = buildFunctionCallToolName(longServerName, longToolName) + + expect(result.length).toBeLessThanOrEqual(63) + }) + + it('should truncate server name to max 20 chars', () => { + const longServerName = 'abcdefghijklmnopqrstuvwxyz' // 26 chars + const result = buildFunctionCallToolName(longServerName, 'tool') + + expect(result).toBe('mcp__abcdefghijklmnopqrst__tool') + expect(result).toContain('abcdefghijklmnopqrst') // First 20 chars + expect(result).not.toContain('uvwxyz') // Truncated + }) + + it('should truncate tool name to max 35 chars', () => { + const longToolName = 'a'.repeat(40) + const result = buildFunctionCallToolName('server', longToolName) + + const expectedTool = 'a'.repeat(35) + expect(result).toBe(`mcp__server__${expectedTool}`) + }) + + it('should not end with underscores after truncation', () => { + // Create a name that would end with underscores after truncation + const longServerName = 'a'.repeat(20) + const longToolName = 'b'.repeat(35) + '___extra' + const result = buildFunctionCallToolName(longServerName, longToolName) + + expect(result).not.toMatch(/_+$/) + expect(result.length).toBeLessThanOrEqual(63) + }) + + it('should handle max length edge case exactly', () => { + // mcp__ (5) + server (20) + __ (2) + tool (35) = 62 chars + const server = 'a'.repeat(20) + const tool = 'b'.repeat(35) + const result = buildFunctionCallToolName(server, tool) + + expect(result.length).toBe(62) + expect(result).toBe(`mcp__${'a'.repeat(20)}__${'b'.repeat(35)}`) + }) + }) + + describe('edge cases', () => { + it('should handle empty server name', () => { + const result = buildFunctionCallToolName('', 'tool') + expect(result).toBe('mcp____tool') + }) + + it('should handle empty tool name', () => { + const result = buildFunctionCallToolName('server', '') + expect(result).toBe('mcp__server__') + }) + + it('should handle both empty names', () => { + const result = buildFunctionCallToolName('', '') + expect(result).toBe('mcp____') + }) + + it('should handle whitespace-only names', () => { + const result = buildFunctionCallToolName(' ', ' ') + expect(result).toBe('mcp____') + }) + + it('should trim whitespace from names', () => { + const result = buildFunctionCallToolName(' server ', ' tool ') + expect(result).toBe('mcp__server__tool') + }) + + it('should handle unicode characters', () => { + const result = buildFunctionCallToolName('服务器', '工具') + // Unicode chars are replaced with underscores, then collapsed + expect(result).toMatch(/^mcp__/) + }) + + it('should handle mixed case', () => { + const result = buildFunctionCallToolName('MyServer', 'MyTool') + expect(result).toBe('mcp__MyServer__MyTool') + }) + }) + + describe('deterministic output', () => { + it('should produce consistent results for same input', () => { const serverName = 'github' const toolName = 'search_repos' const result1 = buildFunctionCallToolName(serverName, toolName) const result2 = buildFunctionCallToolName(serverName, toolName) + const result3 = buildFunctionCallToolName(serverName, toolName) expect(result1).toBe(result2) + expect(result2).toBe(result3) }) - it('should include serverId suffix when provided', () => { - const serverId = 'abc123def456' - const result = buildFunctionCallToolName('server', 'tool', serverId) + it('should produce different results for different inputs', () => { + const result1 = buildFunctionCallToolName('server1', 'tool') + const result2 = buildFunctionCallToolName('server2', 'tool') + const result3 = buildFunctionCallToolName('server', 'tool1') + const result4 = buildFunctionCallToolName('server', 'tool2') - // Should include last 6 chars of serverId - expect(result).toContain('ef456') - }) - }) - - describe('character sanitization', () => { - it('should replace invalid characters with underscores', () => { - const result = buildFunctionCallToolName('test@server', 'tool#name') - expect(result).not.toMatch(/[@#]/) - expect(result).toMatch(/^[a-zA-Z0-9_-]+$/) - }) - - it('should ensure name starts with a letter', () => { - const result = buildFunctionCallToolName('123server', '456tool') - expect(result).toMatch(/^[a-zA-Z]/) - }) - - it('should handle consecutive underscores/dashes', () => { - const result = buildFunctionCallToolName('my--server', 'my__tool') - expect(result).not.toMatch(/[_-]{2,}/) - }) - }) - - describe('length constraints', () => { - it('should truncate names longer than 63 characters', () => { - const longServerName = 'a'.repeat(50) - const longToolName = 'b'.repeat(50) - const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456') - - expect(result.length).toBeLessThanOrEqual(63) - }) - - it('should not end with underscore or dash after truncation', () => { - const longServerName = 'a'.repeat(50) - const longToolName = 'b'.repeat(50) - const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456') - - expect(result).not.toMatch(/[_-]$/) - }) - - it('should preserve serverId suffix even with long server/tool names', () => { - const longServerName = 'a'.repeat(50) - const longToolName = 'b'.repeat(50) - const serverId = 'server-id-xyz789' - - const result = buildFunctionCallToolName(longServerName, longToolName, serverId) - - // The suffix should be preserved and not truncated - expect(result).toContain('xyz789') - expect(result.length).toBeLessThanOrEqual(63) - }) - - it('should ensure two long-named servers with different IDs produce different results', () => { - const longServerName = 'a'.repeat(50) - const longToolName = 'b'.repeat(50) - const serverId1 = 'server-id-abc123' - const serverId2 = 'server-id-def456' - - const result1 = buildFunctionCallToolName(longServerName, longToolName, serverId1) - const result2 = buildFunctionCallToolName(longServerName, longToolName, serverId2) - - // Both should be within limit - expect(result1.length).toBeLessThanOrEqual(63) - expect(result2.length).toBeLessThanOrEqual(63) - - // They should be different due to preserved suffix expect(result1).not.toBe(result2) - }) - }) - - describe('edge cases with serverId', () => { - it('should handle serverId with only non-alphanumeric characters', () => { - const serverId = '------' // All dashes - const result = buildFunctionCallToolName('server', 'tool', serverId) - - // Should still produce a valid unique suffix via fallback hash - expect(result).toBeTruthy() - expect(result.length).toBeLessThanOrEqual(63) - expect(result).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/) - // Should have a suffix (underscore followed by something) - expect(result).toMatch(/_[a-z0-9]+$/) - }) - - it('should produce different results for different non-alphanumeric serverIds', () => { - const serverId1 = '------' - const serverId2 = '!!!!!!' - - const result1 = buildFunctionCallToolName('server', 'tool', serverId1) - const result2 = buildFunctionCallToolName('server', 'tool', serverId2) - - // Should be different because the hash fallback produces different values - expect(result1).not.toBe(result2) - }) - - it('should handle empty string serverId differently from undefined', () => { - const resultWithEmpty = buildFunctionCallToolName('server', 'tool', '') - const resultWithUndefined = buildFunctionCallToolName('server', 'tool', undefined) - - // Empty string is falsy, so both should behave the same (no suffix) - expect(resultWithEmpty).toBe(resultWithUndefined) - }) - - it('should handle serverId with mixed alphanumeric and special chars', () => { - const serverId = 'ab@#cd' // Mixed chars, last 6 chars contain some alphanumeric - const result = buildFunctionCallToolName('server', 'tool', serverId) - - // Should extract alphanumeric chars: 'abcd' from 'ab@#cd' - expect(result).toContain('abcd') + expect(result3).not.toBe(result4) }) }) describe('real-world scenarios', () => { - it('should handle GitHub MCP server instances correctly', () => { - const serverName = 'github' - const toolName = 'search_repositories' - - const githubComId = 'server-github-com-abc123' - const gheId = 'server-ghe-internal-xyz789' - - const tool1 = buildFunctionCallToolName(serverName, toolName, githubComId) - const tool2 = buildFunctionCallToolName(serverName, toolName, gheId) - - // Should be different - expect(tool1).not.toBe(tool2) - - // Both should be valid identifiers - expect(tool1).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/) - expect(tool2).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/) - - // Both should be <= 63 chars - expect(tool1.length).toBeLessThanOrEqual(63) - expect(tool2.length).toBeLessThanOrEqual(63) + it('should handle GitHub MCP server', () => { + expect(buildFunctionCallToolName('github', 'create_issue')).toBe('mcp__github__create_issue') + expect(buildFunctionCallToolName('github', 'search_repositories')).toBe('mcp__github__search_repositories') + expect(buildFunctionCallToolName('github', 'get_pull_request')).toBe('mcp__github__get_pull_request') }) - it('should handle tool names that already include server name prefix', () => { - const result = buildFunctionCallToolName('github', 'github_search_repos') - expect(result).toBeTruthy() - // Should not double the server name - expect(result.split('github').length - 1).toBeLessThanOrEqual(2) + it('should handle filesystem MCP server', () => { + expect(buildFunctionCallToolName('filesystem', 'read_file')).toBe('mcp__filesystem__read_file') + expect(buildFunctionCallToolName('filesystem', 'write_file')).toBe('mcp__filesystem__write_file') + expect(buildFunctionCallToolName('filesystem', 'list_directory')).toBe('mcp__filesystem__list_directory') + }) + + it('should handle hyphenated server names (common in npm packages)', () => { + expect(buildFunctionCallToolName('cherry-fetch', 'get_page')).toBe('mcp__cherry_fetch__get_page') + expect(buildFunctionCallToolName('mcp-server-github', 'search')).toBe('mcp__mcp_server_github__search') + }) + + it('should handle scoped npm package style names', () => { + const result = buildFunctionCallToolName('@anthropic/mcp-server', 'chat') + expect(result).toBe('mcp__anthropic_mcp_server__chat') + }) + + it('should handle tools with long descriptive names', () => { + const result = buildFunctionCallToolName('github', 'search_repositories_by_language_and_stars') + expect(result.length).toBeLessThanOrEqual(63) + expect(result).toMatch(/^mcp__github__search_repositories_by_lan/) }) }) }) diff --git a/src/main/utils/__tests__/process.test.ts b/src/main/utils/__tests__/process.test.ts index 45c0f8b42b..373a66c48f 100644 --- a/src/main/utils/__tests__/process.test.ts +++ b/src/main/utils/__tests__/process.test.ts @@ -1,9 +1,28 @@ -import { execFileSync } from 'child_process' +import { configManager } from '@main/services/ConfigManager' +import { execFileSync, spawn } from 'child_process' +import { EventEmitter } from 'events' import fs from 'fs' import path from 'path' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { findExecutable, findGitBash } from '../process' +import { + autoDiscoverGitBash, + findCommandInShellEnv, + findExecutable, + findGitBash, + validateGitBashPath +} from '../process' + +// Mock configManager +vi.mock('@main/services/ConfigManager', () => ({ + ConfigKeys: { + GitBashPath: 'gitBashPath' + }, + configManager: { + get: vi.fn(), + set: vi.fn() + } +})) // Mock dependencies vi.mock('child_process') @@ -289,7 +308,133 @@ describe.skipIf(process.platform !== 'win32')('process utilities', () => { }) }) + describe('validateGitBashPath', () => { + it('returns null when path is null', () => { + const result = validateGitBashPath(null) + + expect(result).toBeNull() + }) + + it('returns null when path is undefined', () => { + const result = validateGitBashPath(undefined) + + expect(result).toBeNull() + }) + + it('returns normalized path when valid bash.exe exists', () => { + const customPath = 'C:\\PortableGit\\bin\\bash.exe' + vi.mocked(fs.existsSync).mockImplementation((p) => p === 'C:\\PortableGit\\bin\\bash.exe') + + const result = validateGitBashPath(customPath) + + expect(result).toBe('C:\\PortableGit\\bin\\bash.exe') + }) + + it('returns null when file does not exist', () => { + vi.mocked(fs.existsSync).mockReturnValue(false) + + const result = validateGitBashPath('C:\\missing\\bash.exe') + + expect(result).toBeNull() + }) + + it('returns null when path is not bash.exe', () => { + const customPath = 'C:\\PortableGit\\bin\\git.exe' + vi.mocked(fs.existsSync).mockReturnValue(true) + + const result = validateGitBashPath(customPath) + + expect(result).toBeNull() + }) + }) + describe('findGitBash', () => { + describe('customPath parameter', () => { + beforeEach(() => { + delete process.env.CLAUDE_CODE_GIT_BASH_PATH + }) + + it('uses customPath when valid', () => { + const customPath = 'C:\\CustomGit\\bin\\bash.exe' + vi.mocked(fs.existsSync).mockImplementation((p) => p === customPath) + + const result = findGitBash(customPath) + + expect(result).toBe(customPath) + expect(execFileSync).not.toHaveBeenCalled() + }) + + it('falls back when customPath is invalid', () => { + const customPath = 'C:\\Invalid\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + + vi.mocked(fs.existsSync).mockImplementation((p) => { + if (p === customPath) return false + if (p === gitPath) return true + if (p === bashPath) return true + return false + }) + + vi.mocked(execFileSync).mockReturnValue(gitPath) + + const result = findGitBash(customPath) + + expect(result).toBe(bashPath) + }) + + it('prioritizes customPath over env override', () => { + const customPath = 'C:\\CustomGit\\bin\\bash.exe' + const envPath = 'C:\\EnvGit\\bin\\bash.exe' + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + + vi.mocked(fs.existsSync).mockImplementation((p) => p === customPath || p === envPath) + + const result = findGitBash(customPath) + + expect(result).toBe(customPath) + }) + }) + + describe('env override', () => { + beforeEach(() => { + delete process.env.CLAUDE_CODE_GIT_BASH_PATH + }) + + it('uses CLAUDE_CODE_GIT_BASH_PATH when valid', () => { + const envPath = 'C:\\OverrideGit\\bin\\bash.exe' + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + + vi.mocked(fs.existsSync).mockImplementation((p) => p === envPath) + + const result = findGitBash() + + expect(result).toBe(envPath) + expect(execFileSync).not.toHaveBeenCalled() + }) + + it('falls back when CLAUDE_CODE_GIT_BASH_PATH is invalid', () => { + const envPath = 'C:\\Invalid\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + + vi.mocked(fs.existsSync).mockImplementation((p) => { + if (p === envPath) return false + if (p === gitPath) return true + if (p === bashPath) return true + return false + }) + + vi.mocked(execFileSync).mockReturnValue(gitPath) + + const result = findGitBash() + + expect(result).toBe(bashPath) + }) + }) + describe('git.exe path derivation', () => { it('should derive bash.exe from standard Git installation (Git/cmd/git.exe)', () => { const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' @@ -569,4 +714,525 @@ describe.skipIf(process.platform !== 'win32')('process utilities', () => { }) }) }) + + describe('autoDiscoverGitBash', () => { + const originalEnvVar = process.env.CLAUDE_CODE_GIT_BASH_PATH + + beforeEach(() => { + vi.mocked(configManager.get).mockReset() + vi.mocked(configManager.set).mockReset() + delete process.env.CLAUDE_CODE_GIT_BASH_PATH + }) + + afterEach(() => { + // Restore original environment variable + if (originalEnvVar !== undefined) { + process.env.CLAUDE_CODE_GIT_BASH_PATH = originalEnvVar + } else { + delete process.env.CLAUDE_CODE_GIT_BASH_PATH + } + }) + + /** + * Helper to mock fs.existsSync with a set of valid paths + */ + const mockExistingPaths = (...validPaths: string[]) => { + vi.mocked(fs.existsSync).mockImplementation((p) => validPaths.includes(p as string)) + } + + describe('with no existing config path', () => { + it('should discover and persist Git Bash path when not configured', () => { + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(bashPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath) + }) + + it('should return null and not persist when Git Bash is not found', () => { + vi.mocked(configManager.get).mockReturnValue(undefined) + vi.mocked(fs.existsSync).mockReturnValue(false) + vi.mocked(execFileSync).mockImplementation(() => { + throw new Error('Not found') + }) + + const result = autoDiscoverGitBash() + + expect(result).toBeNull() + expect(configManager.set).not.toHaveBeenCalled() + }) + }) + + describe('environment variable precedence', () => { + it('should use env var over valid config path', () => { + const envPath = 'C:\\EnvGit\\bin\\bash.exe' + const configPath = 'C:\\ConfigGit\\bin\\bash.exe' + + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + vi.mocked(configManager.get).mockReturnValue(configPath) + mockExistingPaths(envPath, configPath) + + const result = autoDiscoverGitBash() + + // Env var should take precedence + expect(result).toBe(envPath) + // Should not persist env var path (it's a runtime override) + expect(configManager.set).not.toHaveBeenCalled() + }) + + it('should fall back to config path when env var is invalid', () => { + const envPath = 'C:\\Invalid\\bash.exe' + const configPath = 'C:\\ConfigGit\\bin\\bash.exe' + + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + vi.mocked(configManager.get).mockReturnValue(configPath) + // Env path is invalid (doesn't exist), only config path exists + mockExistingPaths(configPath) + + const result = autoDiscoverGitBash() + + // Should fall back to config path + expect(result).toBe(configPath) + expect(configManager.set).not.toHaveBeenCalled() + }) + + it('should fall back to auto-discovery when both env var and config are invalid', () => { + const envPath = 'C:\\InvalidEnv\\bash.exe' + const configPath = 'C:\\InvalidConfig\\bash.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + process.env.ProgramFiles = 'C:\\Program Files' + vi.mocked(configManager.get).mockReturnValue(configPath) + // Both env and config paths are invalid, only standard Git exists + mockExistingPaths(gitPath, discoveredPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(discoveredPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath) + }) + }) + + describe('with valid existing config path', () => { + it('should validate and return existing path without re-discovering', () => { + const existingPath = 'C:\\CustomGit\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + mockExistingPaths(existingPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(existingPath) + // Should not call findGitBash or persist again + expect(configManager.set).not.toHaveBeenCalled() + // Should not call execFileSync (which findGitBash would use for discovery) + expect(execFileSync).not.toHaveBeenCalled() + }) + + it('should not override existing valid config with auto-discovery', () => { + const existingPath = 'C:\\CustomGit\\bin\\bash.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + mockExistingPaths(existingPath, discoveredPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(existingPath) + expect(configManager.set).not.toHaveBeenCalled() + }) + }) + + describe('with invalid existing config path', () => { + it('should attempt auto-discovery when existing path does not exist', () => { + const existingPath = 'C:\\NonExistent\\bin\\bash.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + process.env.ProgramFiles = 'C:\\Program Files' + // Invalid path doesn't exist, but Git is installed at standard location + mockExistingPaths(gitPath, discoveredPath) + + const result = autoDiscoverGitBash() + + // Should discover and return the new path + expect(result).toBe(discoveredPath) + // Should persist the discovered path (overwrites invalid) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath) + }) + + it('should attempt auto-discovery when existing path is not bash.exe', () => { + const existingPath = 'C:\\CustomGit\\bin\\git.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + process.env.ProgramFiles = 'C:\\Program Files' + // Invalid path exists but is not bash.exe (validation will fail) + // Git is installed at standard location + mockExistingPaths(existingPath, gitPath, discoveredPath) + + const result = autoDiscoverGitBash() + + // Should discover and return the new path + expect(result).toBe(discoveredPath) + // Should persist the discovered path (overwrites invalid) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath) + }) + + it('should return null when existing path is invalid and discovery fails', () => { + const existingPath = 'C:\\NonExistent\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + vi.mocked(fs.existsSync).mockReturnValue(false) + vi.mocked(execFileSync).mockImplementation(() => { + throw new Error('Not found') + }) + + const result = autoDiscoverGitBash() + + // Both validation and discovery failed + expect(result).toBeNull() + // Should not persist when discovery fails + expect(configManager.set).not.toHaveBeenCalled() + }) + }) + + describe('config persistence verification', () => { + it('should persist discovered path with correct config key', () => { + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + autoDiscoverGitBash() + + // Verify the exact call to configManager.set + expect(configManager.set).toHaveBeenCalledTimes(1) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath) + }) + + it('should persist on each discovery when config remains undefined', () => { + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + autoDiscoverGitBash() + autoDiscoverGitBash() + + // Each call discovers and persists since config remains undefined (mocked) + expect(configManager.set).toHaveBeenCalledTimes(2) + }) + }) + + describe('real-world scenarios', () => { + it('should discover and persist standard Git for Windows installation', () => { + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(bashPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath) + }) + + it('should discover portable Git via where.exe and persist', () => { + const gitPath = 'D:\\PortableApps\\Git\\bin\\git.exe' + const bashPath = 'D:\\PortableApps\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + + vi.mocked(fs.existsSync).mockImplementation((p) => { + const pathStr = p?.toString() || '' + // Common git paths don't exist + if (pathStr.includes('Program Files\\Git\\cmd\\git.exe')) return false + if (pathStr.includes('Program Files (x86)\\Git\\cmd\\git.exe')) return false + // Portable bash path exists + if (pathStr === bashPath) return true + return false + }) + + vi.mocked(execFileSync).mockReturnValue(gitPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(bashPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath) + }) + + it('should respect user-configured path over auto-discovery', () => { + const userConfiguredPath = 'D:\\MyGit\\bin\\bash.exe' + const systemPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(userConfiguredPath) + mockExistingPaths(userConfiguredPath, systemPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(userConfiguredPath) + expect(configManager.set).not.toHaveBeenCalled() + // Verify findGitBash was not called for discovery + expect(execFileSync).not.toHaveBeenCalled() + }) + }) + }) +}) + +/** + * Helper to create a mock child process for spawn + */ +function createMockChildProcess() { + const mockChild = new EventEmitter() as EventEmitter & { + stdout: EventEmitter + stderr: EventEmitter + kill: ReturnType + } + mockChild.stdout = new EventEmitter() + mockChild.stderr = new EventEmitter() + mockChild.kill = vi.fn() + return mockChild +} + +describe('findCommandInShellEnv', () => { + beforeEach(() => { + vi.clearAllMocks() + // Reset path.isAbsolute to real implementation for these tests + vi.mocked(path.isAbsolute).mockImplementation((p) => p.startsWith('/') || /^[A-Z]:/i.test(p)) + }) + + describe('command name validation', () => { + it('should reject empty command name', async () => { + const result = await findCommandInShellEnv('', {}) + expect(result).toBeNull() + expect(spawn).not.toHaveBeenCalled() + }) + + it('should reject command names with shell metacharacters', async () => { + const maliciousCommands = [ + 'npx; rm -rf /', + 'npx && malicious', + 'npx | cat /etc/passwd', + 'npx`whoami`', + '$(whoami)', + 'npx\nmalicious' + ] + + for (const cmd of maliciousCommands) { + const result = await findCommandInShellEnv(cmd, {}) + expect(result).toBeNull() + expect(spawn).not.toHaveBeenCalled() + } + }) + + it('should reject command names starting with hyphen', async () => { + const result = await findCommandInShellEnv('-npx', {}) + expect(result).toBeNull() + expect(spawn).not.toHaveBeenCalled() + }) + + it('should reject path traversal attempts', async () => { + const pathTraversalCommands = ['../npx', '../../malicious', 'foo/bar', 'foo\\bar'] + + for (const cmd of pathTraversalCommands) { + const result = await findCommandInShellEnv(cmd, {}) + expect(result).toBeNull() + expect(spawn).not.toHaveBeenCalled() + } + }) + + it('should reject command names exceeding max length', async () => { + const longCommand = 'a'.repeat(129) + const result = await findCommandInShellEnv(longCommand, {}) + expect(result).toBeNull() + expect(spawn).not.toHaveBeenCalled() + }) + + it('should accept valid command names', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + // Don't await - just start the call + const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' }) + + // Simulate command not found + mockChild.emit('close', 1) + + const result = await resultPromise + expect(result).toBeNull() + expect(spawn).toHaveBeenCalled() + }) + + it('should accept command names with underscores and hyphens', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('my_command-name', { PATH: '/usr/bin' }) + mockChild.emit('close', 1) + + await resultPromise + expect(spawn).toHaveBeenCalled() + }) + + it('should accept command names at max length (128 chars)', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const maxLengthCommand = 'a'.repeat(128) + const resultPromise = findCommandInShellEnv(maxLengthCommand, { PATH: '/usr/bin' }) + mockChild.emit('close', 1) + + await resultPromise + expect(spawn).toHaveBeenCalled() + }) + }) + + describe.skipIf(process.platform === 'win32')('Unix/macOS behavior', () => { + it('should find command and return absolute path', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' }) + + // Simulate successful command -v output + mockChild.stdout.emit('data', '/usr/local/bin/npx\n') + mockChild.emit('close', 0) + + const result = await resultPromise + expect(result).toBe('/usr/local/bin/npx') + expect(spawn).toHaveBeenCalledWith('/bin/sh', ['-c', 'command -v "$1"', '--', 'npx'], expect.any(Object)) + }) + + it('should return null for non-absolute paths (aliases/builtins)', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('cd', { PATH: '/usr/bin' }) + + // Simulate builtin output (just command name) + mockChild.stdout.emit('data', 'cd\n') + mockChild.emit('close', 0) + + const result = await resultPromise + expect(result).toBeNull() + }) + + it('should return null when command not found', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('nonexistent', { PATH: '/usr/bin' }) + + // Simulate command not found (exit code 1) + mockChild.emit('close', 1) + + const result = await resultPromise + expect(result).toBeNull() + }) + + it('should handle spawn errors gracefully', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' }) + + // Simulate spawn error + mockChild.emit('error', new Error('spawn failed')) + + const result = await resultPromise + expect(result).toBeNull() + }) + + it('should handle timeout gracefully', async () => { + vi.useFakeTimers() + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' }) + + // Fast-forward past timeout (5000ms) + vi.advanceTimersByTime(6000) + + const result = await resultPromise + expect(result).toBeNull() + expect(mockChild.kill).toHaveBeenCalledWith('SIGKILL') + + vi.useRealTimers() + }) + }) + + describe.skipIf(process.platform !== 'win32')('Windows behavior', () => { + it('should find .exe files via where command', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' }) + + // Simulate where output + mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.exe\r\n') + mockChild.emit('close', 0) + + const result = await resultPromise + expect(result).toBe('C:\\Program Files\\nodejs\\npx.exe') + expect(spawn).toHaveBeenCalledWith('where', ['npx'], expect.any(Object)) + }) + + it('should reject .cmd files on Windows', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' }) + + // Simulate where output with only .cmd file + mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.cmd\r\n') + mockChild.emit('close', 0) + + const result = await resultPromise + expect(result).toBeNull() + }) + + it('should prefer .exe over .cmd when both exist', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' }) + + // Simulate where output with both .cmd and .exe + mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.cmd\r\nC:\\Program Files\\nodejs\\npx.exe\r\n') + mockChild.emit('close', 0) + + const result = await resultPromise + expect(result).toBe('C:\\Program Files\\nodejs\\npx.exe') + }) + + it('should handle spawn errors gracefully', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' }) + + // Simulate spawn error + mockChild.emit('error', new Error('spawn failed')) + + const result = await resultPromise + expect(result).toBeNull() + }) + }) }) diff --git a/src/main/utils/locales.ts b/src/main/utils/locales.ts index b41cba7c75..afaf48b20f 100644 --- a/src/main/utils/locales.ts +++ b/src/main/utils/locales.ts @@ -8,6 +8,7 @@ import esES from '../../renderer/src/i18n/translate/es-es.json' import frFR from '../../renderer/src/i18n/translate/fr-fr.json' import JaJP from '../../renderer/src/i18n/translate/ja-jp.json' import ptPT from '../../renderer/src/i18n/translate/pt-pt.json' +import roRO from '../../renderer/src/i18n/translate/ro-ro.json' import RuRu from '../../renderer/src/i18n/translate/ru-ru.json' const locales = Object.fromEntries( @@ -21,7 +22,8 @@ const locales = Object.fromEntries( ['el-GR', elGR], ['es-ES', esES], ['fr-FR', frFR], - ['pt-PT', ptPT] + ['pt-PT', ptPT], + ['ro-RO', roRO] ].map(([locale, translation]) => [locale, { translation }]) ) diff --git a/src/main/utils/mcp.ts b/src/main/utils/mcp.ts index cfa700f2e6..34eb0e63e7 100644 --- a/src/main/utils/mcp.ts +++ b/src/main/utils/mcp.ts @@ -1,56 +1,28 @@ -export function buildFunctionCallToolName(serverName: string, toolName: string, serverId?: string) { - const sanitizedServer = serverName.trim().replace(/-/g, '_') - const sanitizedTool = toolName.trim().replace(/-/g, '_') +/** + * Builds a valid JavaScript function name for MCP tool calls. + * Format: mcp__{server_name}__{tool_name} + * + * @param serverName - The MCP server name + * @param toolName - The tool name from the server + * @returns A valid JS identifier in format mcp__{server}__{tool}, max 63 chars + */ +export function buildFunctionCallToolName(serverName: string, toolName: string): string { + // Sanitize to valid JS identifier chars (alphanumeric + underscore only) + const sanitize = (str: string): string => + str + .trim() + .replace(/[^a-zA-Z0-9]/g, '_') // Replace all non-alphanumeric with underscore + .replace(/_{2,}/g, '_') // Collapse multiple underscores + .replace(/^_+|_+$/g, '') // Trim leading/trailing underscores - // Calculate suffix first to reserve space for it - // Suffix format: "_" + 6 alphanumeric chars = 7 chars total - let serverIdSuffix = '' - if (serverId) { - // Take the last 6 characters of the serverId for brevity - serverIdSuffix = serverId.slice(-6).replace(/[^a-zA-Z0-9]/g, '') + const server = sanitize(serverName).slice(0, 20) // Keep server name short + const tool = sanitize(toolName).slice(0, 35) // More room for tool name - // Fallback: if suffix becomes empty (all non-alphanumeric chars), use a simple hash - if (!serverIdSuffix) { - const hash = serverId.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0) - serverIdSuffix = hash.toString(36).slice(-6) || 'x' - } - } + let name = `mcp__${server}__${tool}` - // Reserve space for suffix when calculating max base name length - const SUFFIX_LENGTH = serverIdSuffix ? serverIdSuffix.length + 1 : 0 // +1 for underscore - const MAX_BASE_LENGTH = 63 - SUFFIX_LENGTH - - // Combine server name and tool name - let name = sanitizedTool - if (!sanitizedTool.includes(sanitizedServer.slice(0, 7))) { - name = `${sanitizedServer.slice(0, 7) || ''}-${sanitizedTool || ''}` - } - - // Replace invalid characters with underscores or dashes - // Keep a-z, A-Z, 0-9, underscores and dashes - name = name.replace(/[^a-zA-Z0-9_-]/g, '_') - - // Ensure name starts with a letter or underscore (for valid JavaScript identifier) - if (!/^[a-zA-Z]/.test(name)) { - name = `tool-${name}` - } - - // Remove consecutive underscores/dashes (optional improvement) - name = name.replace(/[_-]{2,}/g, '_') - - // Truncate base name BEFORE adding suffix to ensure suffix is never cut off - if (name.length > MAX_BASE_LENGTH) { - name = name.slice(0, MAX_BASE_LENGTH) - } - - // Handle edge case: ensure we still have a valid name if truncation left invalid chars at edges - if (name.endsWith('_') || name.endsWith('-')) { - name = name.slice(0, -1) - } - - // Now append the suffix - it will always fit within 63 chars - if (serverIdSuffix) { - name = `${name}_${serverIdSuffix}` + // Ensure max 63 chars and clean trailing underscores + if (name.length > 63) { + name = name.slice(0, 63).replace(/_+$/, '') } return name diff --git a/src/main/utils/process.ts b/src/main/utils/process.ts index b59a37a048..7c928949fe 100644 --- a/src/main/utils/process.ts +++ b/src/main/utils/process.ts @@ -1,4 +1,5 @@ import { loggerService } from '@logger' +import type { GitBashPathInfo, GitBashPathSource } from '@shared/config/constant' import { HOME_CHERRY_DIR } from '@shared/config/constant' import { execFileSync, spawn } from 'child_process' import fs from 'fs' @@ -6,6 +7,7 @@ import os from 'os' import path from 'path' import { isWin } from '../constant' +import { ConfigKeys, configManager } from '../services/ConfigManager' import { getResourcePath } from '.' const logger = loggerService.withContext('Utils:Process') @@ -59,7 +61,146 @@ export async function getBinaryPath(name?: string): Promise { export async function isBinaryExists(name: string): Promise { const cmd = await getBinaryPath(name) - return await fs.existsSync(cmd) + return fs.existsSync(cmd) +} + +// Timeout for command lookup operations (in milliseconds) +const COMMAND_LOOKUP_TIMEOUT_MS = 5000 + +// Regex to validate command names - must start with alphanumeric or underscore, max 128 chars +const VALID_COMMAND_NAME_REGEX = /^[a-zA-Z0-9_][a-zA-Z0-9_-]{0,127}$/ + +// Maximum output size to prevent buffer overflow (10KB) +const MAX_OUTPUT_SIZE = 10240 + +/** + * Check if a command is available in the user's login shell environment + * @param command - Command name to check (e.g., 'npx', 'uvx') + * @param loginShellEnv - The login shell environment from getLoginShellEnvironment() + * @returns Full path to the command if found, null otherwise + */ +export async function findCommandInShellEnv( + command: string, + loginShellEnv: Record +): Promise { + // Validate command name to prevent command injection + if (!VALID_COMMAND_NAME_REGEX.test(command)) { + logger.warn(`Invalid command name '${command}' - must only contain alphanumeric characters, underscore, or hyphen`) + return null + } + + return new Promise((resolve) => { + let resolved = false + + const safeResolve = (value: string | null) => { + if (resolved) return + resolved = true + resolve(value) + } + + if (isWin) { + // On Windows, use 'where' command + const child = spawn('where', [command], { + env: loginShellEnv, + stdio: ['ignore', 'pipe', 'pipe'] + }) + + let output = '' + const timeoutId = setTimeout(() => { + if (resolved) return + child.kill('SIGKILL') + logger.debug(`Timeout checking command '${command}' on Windows`) + safeResolve(null) + }, COMMAND_LOOKUP_TIMEOUT_MS) + + child.stdout.on('data', (data) => { + if (output.length < MAX_OUTPUT_SIZE) { + output += data.toString() + } + }) + + child.on('close', (code) => { + clearTimeout(timeoutId) + if (resolved) return + + if (code === 0 && output.trim()) { + const paths = output.trim().split(/\r?\n/) + // Only accept .exe files on Windows - .cmd/.bat files cannot be executed + // with spawn({ shell: false }) which is used by MCP SDK's StdioClientTransport + const exePath = paths.find((p) => p.toLowerCase().endsWith('.exe')) + if (exePath) { + logger.debug(`Found command '${command}' at: ${exePath}`) + safeResolve(exePath) + } else { + logger.debug(`Command '${command}' found but not as .exe (${paths[0]}), treating as not found`) + safeResolve(null) + } + } else { + logger.debug(`Command '${command}' not found in shell environment`) + safeResolve(null) + } + }) + + child.on('error', (error) => { + clearTimeout(timeoutId) + if (resolved) return + logger.warn(`Error checking command '${command}':`, { error, platform: 'windows' }) + safeResolve(null) + }) + } else { + // Unix/Linux/macOS: use 'command -v' which is POSIX standard + // Use /bin/sh for reliability - it's POSIX compliant and always available + // This avoids issues with user's custom shell (csh, fish, etc.) + // SECURITY: Use positional parameter $1 to prevent command injection + const child = spawn('/bin/sh', ['-c', 'command -v "$1"', '--', command], { + env: loginShellEnv, + stdio: ['ignore', 'pipe', 'pipe'] + }) + + let output = '' + const timeoutId = setTimeout(() => { + if (resolved) return + child.kill('SIGKILL') + logger.debug(`Timeout checking command '${command}'`) + safeResolve(null) + }, COMMAND_LOOKUP_TIMEOUT_MS) + + child.stdout.on('data', (data) => { + if (output.length < MAX_OUTPUT_SIZE) { + output += data.toString() + } + }) + + child.on('close', (code) => { + clearTimeout(timeoutId) + if (resolved) return + + if (code === 0 && output.trim()) { + const commandPath = output.trim().split('\n')[0] + + // Validate the output is an absolute path (not an alias, function, or builtin) + // command -v can return just the command name for aliases/builtins + if (path.isAbsolute(commandPath)) { + logger.debug(`Found command '${command}' at: ${commandPath}`) + safeResolve(commandPath) + } else { + logger.debug(`Command '${command}' resolved to non-path '${commandPath}', treating as not found`) + safeResolve(null) + } + } else { + logger.debug(`Command '${command}' not found in shell environment`) + safeResolve(null) + } + }) + + child.on('error', (error) => { + clearTimeout(timeoutId) + if (resolved) return + logger.warn(`Error checking command '${command}':`, { error, platform: 'unix' }) + safeResolve(null) + }) + } + }) } /** @@ -131,15 +272,37 @@ export function findExecutable(name: string): string | null { /** * Find Git Bash executable on Windows + * @param customPath - Optional custom path from config * @returns Full path to bash.exe or null if not found */ -export function findGitBash(): string | null { +export function findGitBash(customPath?: string | null): string | null { // Git Bash is Windows-only if (!isWin) { return null } - // 1. Find git.exe and derive bash.exe path + // 1. Check custom path from config first + if (customPath) { + const validated = validateGitBashPath(customPath) + if (validated) { + logger.debug('Using custom Git Bash path from config', { path: validated }) + return validated + } + logger.warn('Custom Git Bash path provided but invalid', { path: customPath }) + } + + // 2. Check environment variable override + const envOverride = process.env.CLAUDE_CODE_GIT_BASH_PATH + if (envOverride) { + const validated = validateGitBashPath(envOverride) + if (validated) { + logger.debug('Using CLAUDE_CODE_GIT_BASH_PATH override for bash.exe', { path: validated }) + return validated + } + logger.warn('CLAUDE_CODE_GIT_BASH_PATH provided but path is invalid', { path: envOverride }) + } + + // 3. Find git.exe and derive bash.exe path const gitPath = findExecutable('git') if (gitPath) { // Try multiple possible locations for bash.exe relative to git.exe @@ -164,7 +327,7 @@ export function findGitBash(): string | null { }) } - // 2. Fallback: check common Git Bash paths directly + // 4. Fallback: check common Git Bash paths directly const commonBashPaths = [ path.join(process.env.ProgramFiles || 'C:\\Program Files', 'Git', 'bin', 'bash.exe'), path.join(process.env['ProgramFiles(x86)'] || 'C:\\Program Files (x86)', 'Git', 'bin', 'bash.exe'), @@ -181,3 +344,99 @@ export function findGitBash(): string | null { logger.debug('Git Bash not found - checked git derivation and common paths') return null } + +export function validateGitBashPath(customPath?: string | null): string | null { + if (!customPath) { + return null + } + + const resolved = path.resolve(customPath) + + if (!fs.existsSync(resolved)) { + logger.warn('Custom Git Bash path does not exist', { path: resolved }) + return null + } + + const isExe = resolved.toLowerCase().endsWith('bash.exe') + if (!isExe) { + logger.warn('Custom Git Bash path is not bash.exe', { path: resolved }) + return null + } + + logger.debug('Validated custom Git Bash path', { path: resolved }) + return resolved +} + +/** + * Auto-discover and persist Git Bash path if not already configured + * Only called when Git Bash is actually needed + * + * Precedence order: + * 1. CLAUDE_CODE_GIT_BASH_PATH environment variable (highest - runtime override) + * 2. Configured path from settings (manual or auto) + * 3. Auto-discovery via findGitBash (only if no valid config exists) + */ +export function autoDiscoverGitBash(): string | null { + if (!isWin) { + return null + } + + // 1. Check environment variable override first (highest priority) + const envOverride = process.env.CLAUDE_CODE_GIT_BASH_PATH + if (envOverride) { + const validated = validateGitBashPath(envOverride) + if (validated) { + logger.debug('Using CLAUDE_CODE_GIT_BASH_PATH override', { path: validated }) + return validated + } + logger.warn('CLAUDE_CODE_GIT_BASH_PATH provided but path is invalid', { path: envOverride }) + } + + // 2. Check if a path is already configured + const existingPath = configManager.get(ConfigKeys.GitBashPath) + const existingSource = configManager.get(ConfigKeys.GitBashPathSource) + + if (existingPath) { + const validated = validateGitBashPath(existingPath) + if (validated) { + return validated + } + // Existing path is invalid, try to auto-discover + logger.warn('Existing Git Bash path is invalid, attempting auto-discovery', { + path: existingPath, + source: existingSource + }) + } + + // 3. Try to find Git Bash via auto-discovery + const discoveredPath = findGitBash() + if (discoveredPath) { + // Persist the discovered path with 'auto' source + configManager.set(ConfigKeys.GitBashPath, discoveredPath) + configManager.set(ConfigKeys.GitBashPathSource, 'auto') + logger.info('Auto-discovered Git Bash path', { path: discoveredPath }) + } + + return discoveredPath +} + +/** + * Get Git Bash path info including source + * If no path is configured, triggers auto-discovery first + */ +export function getGitBashPathInfo(): GitBashPathInfo { + if (!isWin) { + return { path: null, source: null } + } + + let path = configManager.get(ConfigKeys.GitBashPath) ?? null + let source = configManager.get(ConfigKeys.GitBashPathSource) ?? null + + // If no path configured, trigger auto-discovery (handles upgrade from old versions) + if (!path) { + path = autoDiscoverGitBash() + source = path ? 'auto' : null + } + + return { path, source } +} diff --git a/src/main/utils/system.ts b/src/main/utils/system.ts new file mode 100644 index 0000000000..2cd9e4bf22 --- /dev/null +++ b/src/main/utils/system.ts @@ -0,0 +1,19 @@ +import os from 'node:os' + +import { isMac, isWin } from '@main/constant' + +export const getDeviceType = () => (isMac ? 'mac' : isWin ? 'windows' : 'linux') + +export const getHostname = () => os.hostname() + +export const getCpuName = () => { + try { + const cpus = os.cpus() + if (!cpus || cpus.length === 0 || !cpus[0].model) { + return 'Unknown CPU' + } + return cpus[0].model + } catch { + return 'Unknown CPU' + } +} diff --git a/src/preload/index.ts b/src/preload/index.ts index 0a3b141110..b27ecc68d4 100644 --- a/src/preload/index.ts +++ b/src/preload/index.ts @@ -2,9 +2,18 @@ import type { PermissionUpdate } from '@anthropic-ai/claude-agent-sdk' import { electronAPI } from '@electron-toolkit/preload' import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core' import type { SpanContext } from '@opentelemetry/api' -import type { TerminalConfig, UpgradeChannel } from '@shared/config/constant' +import type { GitBashPathInfo, TerminalConfig, UpgradeChannel } from '@shared/config/constant' import type { LogLevel, LogSourceWithContext } from '@shared/config/logger' -import type { FileChangeEvent, WebviewKeyEvent } from '@shared/config/types' +import type { + FileChangeEvent, + LanClientEvent, + LanFileCompleteMessage, + LanHandshakeAckMessage, + LocalTransferConnectPayload, + LocalTransferState, + WebviewKeyEvent +} from '@shared/config/types' +import type { MCPServerLogEntry } from '@shared/config/types' import { IpcChannel } from '@shared/IpcChannel' import type { Notification } from '@types' import type { @@ -123,7 +132,11 @@ const api = { getDeviceType: () => ipcRenderer.invoke(IpcChannel.System_GetDeviceType), getHostname: () => ipcRenderer.invoke(IpcChannel.System_GetHostname), getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName), - checkGitBash: (): Promise => ipcRenderer.invoke(IpcChannel.System_CheckGitBash) + checkGitBash: (): Promise => ipcRenderer.invoke(IpcChannel.System_CheckGitBash), + getGitBashPath: (): Promise => ipcRenderer.invoke(IpcChannel.System_GetGitBashPath), + getGitBashPathInfo: (): Promise => ipcRenderer.invoke(IpcChannel.System_GetGitBashPathInfo), + setGitBashPath: (newPath: string | null): Promise => + ipcRenderer.invoke(IpcChannel.System_SetGitBashPath, newPath) }, devTools: { toggle: () => ipcRenderer.invoke(IpcChannel.System_ToggleDevTools) @@ -167,7 +180,11 @@ const api = { listS3Files: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_ListS3Files, s3Config), deleteS3File: (fileName: string, s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_DeleteS3File, fileName, s3Config), - checkS3Connection: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_CheckS3Connection, s3Config) + checkS3Connection: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_CheckS3Connection, s3Config), + createLanTransferBackup: (data: string): Promise => + ipcRenderer.invoke(IpcChannel.Backup_CreateLanTransferBackup, data), + deleteTempBackup: (filePath: string): Promise => + ipcRenderer.invoke(IpcChannel.Backup_DeleteTempBackup, filePath) }, file: { select: (options?: OpenDialogOptions): Promise => @@ -338,7 +355,8 @@ const api = { deleteUser: (userId: string) => ipcRenderer.invoke(IpcChannel.Memory_DeleteUser, userId), deleteAllMemoriesForUser: (userId: string) => ipcRenderer.invoke(IpcChannel.Memory_DeleteAllMemoriesForUser, userId), - getUsersList: () => ipcRenderer.invoke(IpcChannel.Memory_GetUsersList) + getUsersList: () => ipcRenderer.invoke(IpcChannel.Memory_GetUsersList), + migrateMemoryDb: () => ipcRenderer.invoke(IpcChannel.Memory_MigrateMemoryDb) }, window: { setMinimumSize: (width: number, height: number) => @@ -367,6 +385,7 @@ const api = { ipcRenderer.invoke(IpcChannel.VertexAI_ClearAuthCache, projectId, clientEmail) }, ovms: { + isSupported: (): Promise => ipcRenderer.invoke(IpcChannel.Ovms_IsSupported), addModel: (modelName: string, modelId: string, modelSource: string, task: string) => ipcRenderer.invoke(IpcChannel.Ovms_AddModel, modelName, modelId, modelSource, task), stopAddModel: () => ipcRenderer.invoke(IpcChannel.Ovms_StopAddModel), @@ -417,7 +436,16 @@ const api = { }, abortTool: (callId: string) => ipcRenderer.invoke(IpcChannel.Mcp_AbortTool, callId), getServerVersion: (server: MCPServer): Promise => - ipcRenderer.invoke(IpcChannel.Mcp_GetServerVersion, server) + ipcRenderer.invoke(IpcChannel.Mcp_GetServerVersion, server), + getServerLogs: (server: MCPServer): Promise => + ipcRenderer.invoke(IpcChannel.Mcp_GetServerLogs, server), + onServerLog: (callback: (log: MCPServerLogEntry & { serverId?: string }) => void) => { + const listener = (_event: Electron.IpcRendererEvent, log: MCPServerLogEntry & { serverId?: string }) => { + callback(log) + } + ipcRenderer.on(IpcChannel.Mcp_ServerLog, listener) + return () => ipcRenderer.off(IpcChannel.Mcp_ServerLog, listener) + } }, python: { execute: (script: string, context?: Record, timeout?: number) => @@ -460,7 +488,7 @@ const api = { ipcRenderer.invoke(IpcChannel.Nutstore_GetDirectoryContents, token, path) }, searchService: { - openSearchWindow: (uid: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_Open, uid), + openSearchWindow: (uid: string, show?: boolean) => ipcRenderer.invoke(IpcChannel.SearchWindow_Open, uid, show), closeSearchWindow: (uid: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_Close, uid), openUrlInSearchWindow: (uid: string, url: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_OpenUrl, uid, url) }, @@ -469,6 +497,8 @@ const api = { ipcRenderer.invoke(IpcChannel.Webview_SetOpenLinkExternal, webviewId, isExternal), setSpellCheckEnabled: (webviewId: number, isEnable: boolean) => ipcRenderer.invoke(IpcChannel.Webview_SetSpellCheckEnabled, webviewId, isEnable), + printToPDF: (webviewId: number) => ipcRenderer.invoke(IpcChannel.Webview_PrintToPDF, webviewId), + saveAsHTML: (webviewId: number) => ipcRenderer.invoke(IpcChannel.Webview_SaveAsHTML, webviewId), onFindShortcut: (callback: (payload: WebviewKeyEvent) => void) => { const listener = (_event: Electron.IpcRendererEvent, payload: WebviewKeyEvent) => { callback(payload) @@ -501,7 +531,10 @@ const api = { ipcRenderer.invoke(IpcChannel.Selection_ProcessAction, actionItem, isFullScreen), closeActionWindow: () => ipcRenderer.invoke(IpcChannel.Selection_ActionWindowClose), minimizeActionWindow: () => ipcRenderer.invoke(IpcChannel.Selection_ActionWindowMinimize), - pinActionWindow: (isPinned: boolean) => ipcRenderer.invoke(IpcChannel.Selection_ActionWindowPin, isPinned) + pinActionWindow: (isPinned: boolean) => ipcRenderer.invoke(IpcChannel.Selection_ActionWindowPin, isPinned), + // [Windows only] Electron bug workaround - can be removed once https://github.com/electron/electron/issues/48554 is fixed + resizeActionWindow: (deltaX: number, deltaY: number, direction: string) => + ipcRenderer.invoke(IpcChannel.Selection_ActionWindowResize, deltaX, deltaY, direction) }, agentTools: { respondToPermission: (payload: { @@ -615,12 +648,32 @@ const api = { writeContent: (options: WritePluginContentOptions): Promise> => ipcRenderer.invoke(IpcChannel.ClaudeCodePlugin_WriteContent, options) }, - webSocket: { - start: () => ipcRenderer.invoke(IpcChannel.WebSocket_Start), - stop: () => ipcRenderer.invoke(IpcChannel.WebSocket_Stop), - status: () => ipcRenderer.invoke(IpcChannel.WebSocket_Status), - sendFile: (filePath: string) => ipcRenderer.invoke(IpcChannel.WebSocket_SendFile, filePath), - getAllCandidates: () => ipcRenderer.invoke(IpcChannel.WebSocket_GetAllCandidates) + localTransfer: { + getState: (): Promise => ipcRenderer.invoke(IpcChannel.LocalTransfer_ListServices), + startScan: (): Promise => ipcRenderer.invoke(IpcChannel.LocalTransfer_StartScan), + stopScan: (): Promise => ipcRenderer.invoke(IpcChannel.LocalTransfer_StopScan), + connect: (payload: LocalTransferConnectPayload): Promise => + ipcRenderer.invoke(IpcChannel.LocalTransfer_Connect, payload), + disconnect: (): Promise => ipcRenderer.invoke(IpcChannel.LocalTransfer_Disconnect), + onServicesUpdated: (callback: (state: LocalTransferState) => void): (() => void) => { + const channel = IpcChannel.LocalTransfer_ServicesUpdated + const listener = (_: Electron.IpcRendererEvent, state: LocalTransferState) => callback(state) + ipcRenderer.on(channel, listener) + return () => { + ipcRenderer.removeListener(channel, listener) + } + }, + onClientEvent: (callback: (event: LanClientEvent) => void): (() => void) => { + const channel = IpcChannel.LocalTransfer_ClientEvent + const listener = (_: Electron.IpcRendererEvent, event: LanClientEvent) => callback(event) + ipcRenderer.on(channel, listener) + return () => { + ipcRenderer.removeListener(channel, listener) + } + }, + sendFile: (filePath: string): Promise => + ipcRenderer.invoke(IpcChannel.LocalTransfer_SendFile, { filePath }), + cancelTransfer: (): Promise => ipcRenderer.invoke(IpcChannel.LocalTransfer_CancelTransfer) } } diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts index 5de2ac3453..5d418de08b 100644 --- a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts +++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts @@ -120,6 +120,21 @@ export class AiSdkToChunkAdapter { } } + /** + * 如果有累积的思考内容,发送 THINKING_COMPLETE chunk 并清空 + * @param final 包含 reasoningContent 的状态对象 + * @returns 是否发送了 THINKING_COMPLETE chunk + */ + private emitThinkingCompleteIfNeeded(final: { reasoningContent: string; [key: string]: any }) { + if (final.reasoningContent) { + this.onChunk({ + type: ChunkType.THINKING_COMPLETE, + text: final.reasoningContent + }) + final.reasoningContent = '' + } + } + /** * 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调 * @param chunk AI SDK 的 chunk 数据 @@ -145,6 +160,9 @@ export class AiSdkToChunkAdapter { } // === 文本相关事件 === case 'text-start': + // 如果有未完成的思考内容,先生成 THINKING_COMPLETE + // 这处理了某些提供商不发送 reasoning-end 事件的情况 + this.emitThinkingCompleteIfNeeded(final) this.onChunk({ type: ChunkType.TEXT_START }) @@ -215,11 +233,7 @@ export class AiSdkToChunkAdapter { }) break case 'reasoning-end': - this.onChunk({ - type: ChunkType.THINKING_COMPLETE, - text: final.reasoningContent || '' - }) - final.reasoningContent = '' + this.emitThinkingCompleteIfNeeded(final) break // === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) === diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index 4379547a3c..5c84a7254e 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -91,7 +91,9 @@ export default class ModernAiProvider { if (this.isModel(modelOrProvider)) { // 传入的是 Model this.model = modelOrProvider - this.actualProvider = provider ? adaptProvider({ provider }) : getActualProvider(modelOrProvider) + this.actualProvider = provider + ? adaptProvider({ provider, model: modelOrProvider }) + : getActualProvider(modelOrProvider) // 只保存配置,不预先创建executor this.config = providerToAiSdkConfig(this.actualProvider, modelOrProvider) } else { diff --git a/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts b/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts index 92f24b4abe..5d435b9074 100644 --- a/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts @@ -2,9 +2,10 @@ import { loggerService } from '@logger' import { getModelSupportedVerbosity, isFunctionCallingModel, - isNotSupportTemperatureAndTopP, isOpenAIModel, - isSupportFlexServiceTierModel + isSupportFlexServiceTierModel, + isSupportTemperatureModel, + isSupportTopPModel } from '@renderer/config/models' import { REFERENCE_PROMPT } from '@renderer/config/prompts' import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio' @@ -200,7 +201,7 @@ export abstract class BaseApiClient< } public getTemperature(assistant: Assistant, model: Model): number | undefined { - if (isNotSupportTemperatureAndTopP(model)) { + if (!isSupportTemperatureModel(model)) { return undefined } const assistantSettings = getAssistantSettings(assistant) @@ -208,7 +209,7 @@ export abstract class BaseApiClient< } public getTopP(assistant: Assistant, model: Model): number | undefined { - if (isNotSupportTemperatureAndTopP(model)) { + if (!isSupportTopPModel(model)) { return undefined } const assistantSettings = getAssistantSettings(assistant) diff --git a/src/renderer/src/aiCore/legacy/clients/__tests__/OpenAIBaseClient.azureEndpoint.test.ts b/src/renderer/src/aiCore/legacy/clients/__tests__/OpenAIBaseClient.azureEndpoint.test.ts new file mode 100644 index 0000000000..e3b2ef2676 --- /dev/null +++ b/src/renderer/src/aiCore/legacy/clients/__tests__/OpenAIBaseClient.azureEndpoint.test.ts @@ -0,0 +1,38 @@ +import { describe, expect, it } from 'vitest' + +import { normalizeAzureOpenAIEndpoint } from '../openai/azureOpenAIEndpoint' + +describe('normalizeAzureOpenAIEndpoint', () => { + it.each([ + { + apiHost: 'https://example.openai.azure.com/openai', + expectedEndpoint: 'https://example.openai.azure.com' + }, + { + apiHost: 'https://example.openai.azure.com/openai/', + expectedEndpoint: 'https://example.openai.azure.com' + }, + { + apiHost: 'https://example.openai.azure.com/openai/v1', + expectedEndpoint: 'https://example.openai.azure.com' + }, + { + apiHost: 'https://example.openai.azure.com/openai/v1/', + expectedEndpoint: 'https://example.openai.azure.com' + }, + { + apiHost: 'https://example.openai.azure.com', + expectedEndpoint: 'https://example.openai.azure.com' + }, + { + apiHost: 'https://example.openai.azure.com/', + expectedEndpoint: 'https://example.openai.azure.com' + }, + { + apiHost: 'https://example.openai.azure.com/OPENAI/V1', + expectedEndpoint: 'https://example.openai.azure.com' + } + ])('strips trailing /openai from $apiHost', ({ apiHost, expectedEndpoint }) => { + expect(normalizeAzureOpenAIEndpoint(apiHost)).toBe(expectedEndpoint) + }) +}) diff --git a/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts index 15f3cf1007..9b63b77ddf 100644 --- a/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts @@ -124,7 +124,8 @@ export class AnthropicAPIClient extends BaseApiClient< override async listModels(): Promise { const sdk = (await this.getSdkInstance()) as Anthropic - const response = await sdk.models.list() + // prevent auto appended /v1. It's included in baseUrl. + const response = await sdk.models.list({ path: '/models' }) return response.data } diff --git a/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts index 9c930a33ec..d7f14326f6 100644 --- a/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts @@ -46,7 +46,6 @@ import type { GeminiSdkRawOutput, GeminiSdkToolCall } from '@renderer/types/sdk' -import { getTrailingApiVersion, withoutTrailingApiVersion } from '@renderer/utils' import { isToolUseModeFunction } from '@renderer/utils/assistant' import { geminiFunctionCallToMcpTool, @@ -56,6 +55,7 @@ import { } from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { defaultTimeout, MB } from '@shared/config/constant' +import { getTrailingApiVersion, withoutTrailingApiVersion } from '@shared/utils' import { t } from 'i18next' import type { GenericChunk } from '../../middleware/schemas' @@ -173,13 +173,15 @@ export class GeminiAPIClient extends BaseApiClient< return this.sdkInstance } + const apiVersion = this.getApiVersion() + this.sdkInstance = new GoogleGenAI({ vertexai: false, apiKey: this.apiKey, - apiVersion: this.getApiVersion(), + apiVersion, httpOptions: { baseUrl: this.getBaseURL(), - apiVersion: this.getApiVersion(), + apiVersion, headers: { ...this.provider.extra_headers } @@ -200,7 +202,7 @@ export class GeminiAPIClient extends BaseApiClient< return trailingVersion } - return 'v1beta' + return '' } /** diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts index cfc9087545..73a5bed4fe 100644 --- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts @@ -10,7 +10,7 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { findTokenLimit, GEMINI_FLASH_MODEL_REGEX, - getThinkModelType, + getModelSupportedReasoningEffortOptions, isDeepSeekHybridInferenceModel, isDoubaoThinkingAutoModel, isGPT5SeriesModel, @@ -33,7 +33,6 @@ import { isSupportedThinkingTokenQwenModel, isSupportedThinkingTokenZhipuModel, isVisionModel, - MODEL_SUPPORTED_REASONING_EFFORT, ZHIPU_RESULT_TOKENS } from '@renderer/config/models' import { mapLanguageToQwenMTModel } from '@renderer/config/translate' @@ -143,6 +142,10 @@ export class OpenAIAPIClient extends OpenAIBaseClient< return { thinking: { type: reasoningEffort ? 'enabled' : 'disabled' } } } + if (reasoningEffort === 'default') { + return {} + } + if (!reasoningEffort) { // DeepSeek hybrid inference models, v3.1 and maybe more in the future // 不同的 provider 有不同的思考控制方式,在这里统一解决 @@ -304,16 +307,15 @@ export class OpenAIAPIClient extends OpenAIBaseClient< // Grok models/Perplexity models/OpenAI models if (isSupportedReasoningEffortModel(model)) { // 检查模型是否支持所选选项 - const modelType = getThinkModelType(model) - const supportedOptions = MODEL_SUPPORTED_REASONING_EFFORT[modelType] - if (supportedOptions.includes(reasoningEffort)) { + const supportedOptions = getModelSupportedReasoningEffortOptions(model)?.filter((option) => option !== 'default') + if (supportedOptions?.includes(reasoningEffort)) { return { reasoning_effort: reasoningEffort } } else { // 如果不支持,fallback到第一个支持的值 return { - reasoning_effort: supportedOptions[0] + reasoning_effort: supportedOptions?.[0] } } } diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts index dc97e74a3c..efc3f4f7ce 100644 --- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts @@ -25,10 +25,11 @@ import type { OpenAISdkRawOutput, ReasoningEffortOptionalParams } from '@renderer/types/sdk' -import { formatApiHost, withoutTrailingSlash } from '@renderer/utils/api' +import { withoutTrailingSlash } from '@renderer/utils/api' import { isOllamaProvider } from '@renderer/utils/provider' import { BaseApiClient } from '../BaseApiClient' +import { normalizeAzureOpenAIEndpoint } from './azureOpenAIEndpoint' const logger = loggerService.withContext('OpenAIBaseClient') @@ -49,8 +50,9 @@ export abstract class OpenAIBaseClient< } // 仅适用于openai - override getBaseURL(isSupportedAPIVerion: boolean = true): string { - return formatApiHost(this.provider.apiHost, isSupportedAPIVerion) + override getBaseURL(): string { + // apiHost is formatted when called by AiProvider + return this.provider.apiHost } override async generateImage({ @@ -68,7 +70,7 @@ export abstract class OpenAIBaseClient< const sdk = await this.getSdkInstance() const response = (await sdk.request({ method: 'post', - path: '/images/generations', + path: '/v1/images/generations', signal, body: { model, @@ -87,7 +89,11 @@ export abstract class OpenAIBaseClient< } override async getEmbeddingDimensions(model: Model): Promise { - const sdk = await this.getSdkInstance() + let sdk: OpenAI = await this.getSdkInstance() + if (isOllamaProvider(this.provider)) { + const embedBaseUrl = `${this.provider.apiHost.replace(/(\/(api|v1))\/?$/, '')}/v1` + sdk = sdk.withOptions({ baseURL: embedBaseUrl }) + } const data = await sdk.embeddings.create({ model: model.id, @@ -100,6 +106,17 @@ export abstract class OpenAIBaseClient< override async listModels(): Promise { try { const sdk = await this.getSdkInstance() + if (this.provider.id === 'openrouter') { + // https://openrouter.ai/docs/api/api-reference/embeddings/list-embeddings-models + const embedBaseUrl = 'https://openrouter.ai/api/v1/embeddings' + const embedSdk = sdk.withOptions({ baseURL: embedBaseUrl }) + const modelPromise = sdk.models.list() + const embedModelPromise = embedSdk.models.list() + const [modelResponse, embedModelResponse] = await Promise.all([modelPromise, embedModelPromise]) + const models = [...modelResponse.data, ...embedModelResponse.data] + const uniqueModels = Array.from(new Map(models.map((model) => [model.id, model])).values()) + return uniqueModels.filter(isSupportedModel) + } if (this.provider.id === 'github') { // GitHub Models 其 models 和 chat completions 两个接口的 baseUrl 不一样 const baseUrl = 'https://models.github.ai/catalog/' @@ -118,7 +135,7 @@ export abstract class OpenAIBaseClient< } if (isOllamaProvider(this.provider)) { - const baseUrl = withoutTrailingSlash(this.getBaseURL(false)) + const baseUrl = withoutTrailingSlash(this.getBaseURL()) .replace(/\/v1$/, '') .replace(/\/api$/, '') const response = await fetch(`${baseUrl}/api/tags`, { @@ -173,6 +190,7 @@ export abstract class OpenAIBaseClient< let apiKeyForSdkInstance = this.apiKey let baseURLForSdkInstance = this.getBaseURL() + logger.debug('baseURLForSdkInstance', { baseURLForSdkInstance }) let headersForSdkInstance = { ...this.defaultHeaders(), ...this.provider.extra_headers @@ -184,7 +202,7 @@ export abstract class OpenAIBaseClient< // this.provider.apiKey不允许修改 // this.provider.apiKey = token apiKeyForSdkInstance = token - baseURLForSdkInstance = this.getBaseURL(false) + baseURLForSdkInstance = this.getBaseURL() headersForSdkInstance = { ...headersForSdkInstance, ...COPILOT_DEFAULT_HEADERS @@ -196,7 +214,7 @@ export abstract class OpenAIBaseClient< dangerouslyAllowBrowser: true, apiKey: apiKeyForSdkInstance, apiVersion: this.provider.apiVersion, - endpoint: this.provider.apiHost + endpoint: normalizeAzureOpenAIEndpoint(this.provider.apiHost) }) as TSdkInstance } else { this.sdkInstance = new OpenAI({ diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts index 8356826e26..b4f63e2bce 100644 --- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts @@ -122,6 +122,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient< if (this.sdkInstance) { return this.sdkInstance } + const baseUrl = this.getBaseURL() if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') { return new AzureOpenAI({ @@ -134,7 +135,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient< return new OpenAI({ dangerouslyAllowBrowser: true, apiKey: this.apiKey, - baseURL: this.getBaseURL(), + baseURL: baseUrl, defaultHeaders: { ...this.defaultHeaders(), ...this.provider.extra_headers diff --git a/src/renderer/src/aiCore/legacy/clients/openai/azureOpenAIEndpoint.ts b/src/renderer/src/aiCore/legacy/clients/openai/azureOpenAIEndpoint.ts new file mode 100644 index 0000000000..777dbe74d7 --- /dev/null +++ b/src/renderer/src/aiCore/legacy/clients/openai/azureOpenAIEndpoint.ts @@ -0,0 +1,4 @@ +export function normalizeAzureOpenAIEndpoint(apiHost: string): string { + const normalizedHost = apiHost.replace(/\/+$/, '') + return normalizedHost.replace(/\/openai(?:\/v1)?$/i, '') +} diff --git a/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts b/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts index 02ac6de091..4936b693ee 100644 --- a/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts @@ -3,7 +3,8 @@ import { loggerService } from '@logger' import { isSupportedModel } from '@renderer/config/models' import type { Provider } from '@renderer/types' import { objectKeys } from '@renderer/types' -import { formatApiHost, withoutTrailingApiVersion } from '@renderer/utils' +import { formatApiHost } from '@renderer/utils' +import { withoutTrailingApiVersion } from '@shared/utils' import { OpenAIAPIClient } from '../openai/OpenAIApiClient' diff --git a/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts index ea6c141e31..9c590996f1 100644 --- a/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts @@ -66,6 +66,11 @@ export class ZhipuAPIClient extends OpenAIAPIClient { public async listModels(): Promise { const models = [ + 'glm-4.7', + 'glm-4.6', + 'glm-4.6v', + 'glm-4.6v-flash', + 'glm-4.6v-flashx', 'glm-4.5', 'glm-4.5-x', 'glm-4.5-air', diff --git a/src/renderer/src/aiCore/legacy/index.ts b/src/renderer/src/aiCore/legacy/index.ts index da6cdb6726..7c5f5211d9 100644 --- a/src/renderer/src/aiCore/legacy/index.ts +++ b/src/renderer/src/aiCore/legacy/index.ts @@ -2,7 +2,6 @@ import { loggerService } from '@logger' import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory' import type { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient' import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models' -import { getProviderByModel } from '@renderer/services/AssistantService' import { withSpanResult } from '@renderer/services/SpanManagerService' import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' import type { GenerateImageParams, Model, Provider } from '@renderer/types' @@ -160,9 +159,6 @@ export default class AiProvider { public async getEmbeddingDimensions(model: Model): Promise { try { // Use the SDK instance to test embedding capabilities - if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') { - this.apiClient = this.apiClient.getClient(model) as BaseApiClient - } const dimensions = await this.apiClient.getEmbeddingDimensions(model) return dimensions } catch (error) { diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts index 10a4d59384..b2a796bd33 100644 --- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts @@ -7,7 +7,6 @@ import type { Chunk } from '@renderer/types/chunk' import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider' import type { LanguageModelMiddleware } from 'ai' import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai' -import { isEmpty } from 'lodash' import { getAiSdkProviderId } from '../provider/factory' import { isOpenRouterGeminiGenerateImageModel } from '../utils/image' @@ -16,7 +15,6 @@ import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMidd import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware' import { qwenThinkingMiddleware } from './qwenThinkingMiddleware' import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware' -import { toolChoiceMiddleware } from './toolChoiceMiddleware' const logger = loggerService.withContext('AiSdkMiddlewareBuilder') @@ -136,15 +134,6 @@ export class AiSdkMiddlewareBuilder { export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] { const builder = new AiSdkMiddlewareBuilder() - // 0. 知识库强制调用中间件(必须在最前面,确保第一轮强制调用知识库) - if (!isEmpty(config.assistant?.knowledge_bases?.map((base) => base.id)) && config.knowledgeRecognition !== 'on') { - builder.add({ - name: 'force-knowledge-first', - middleware: toolChoiceMiddleware('builtin_knowledge_search') - }) - logger.debug('Added toolChoice middleware to force knowledge base search on first round') - } - // 1. 根据provider添加特定中间件 if (config.provider) { addProviderSpecificMiddlewares(builder, config) diff --git a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts index 6be577f194..5b095a4461 100644 --- a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts +++ b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts @@ -31,7 +31,7 @@ import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool' const logger = loggerService.withContext('SearchOrchestrationPlugin') -const getMessageContent = (message: ModelMessage) => { +export const getMessageContent = (message: ModelMessage) => { if (typeof message.content === 'string') return message.content return message.content.reduce((acc, part) => { if (part.type === 'text') { @@ -266,14 +266,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) // 判断是否需要各种搜索 const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) - const knowledgeRecognition = assistant.knowledgeRecognition || 'on' + const knowledgeRecognition = assistant.knowledgeRecognition || 'off' const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) const shouldWebSearch = !!assistant.webSearchProviderId const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on' const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory // 执行意图分析 - if (shouldWebSearch || hasKnowledgeBase) { + if (shouldWebSearch || shouldKnowledgeSearch) { const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, { shouldWebSearch, shouldKnowledgeSearch, @@ -330,41 +330,25 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) // 📚 知识库搜索工具配置 const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) - const knowledgeRecognition = assistant.knowledgeRecognition || 'on' + const knowledgeRecognition = assistant.knowledgeRecognition || 'off' + const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on' - if (hasKnowledgeBase) { - if (knowledgeRecognition === 'off') { - // off 模式:直接添加知识库搜索工具,使用用户消息作为搜索关键词 + if (shouldKnowledgeSearch) { + // on 模式:根据意图识别结果决定是否添加工具 + const needsKnowledgeSearch = + analysisResult?.knowledge && + analysisResult.knowledge.question && + analysisResult.knowledge.question[0] !== 'not_needed' + + if (needsKnowledgeSearch && analysisResult.knowledge) { + // logger.info('📚 Adding knowledge search tool (intent-based)') const userMessage = userMessages[context.requestId] - const fallbackKeywords = { - question: [getMessageContent(userMessage) || 'search'], - rewrite: getMessageContent(userMessage) || 'search' - } - // logger.info('📚 Adding knowledge search tool (force mode)') params.tools['builtin_knowledge_search'] = knowledgeSearchTool( assistant, - fallbackKeywords, + analysisResult.knowledge, getMessageContent(userMessage), topicId ) - // params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' } - } else { - // on 模式:根据意图识别结果决定是否添加工具 - const needsKnowledgeSearch = - analysisResult?.knowledge && - analysisResult.knowledge.question && - analysisResult.knowledge.question[0] !== 'not_needed' - - if (needsKnowledgeSearch && analysisResult.knowledge) { - // logger.info('📚 Adding knowledge search tool (intent-based)') - const userMessage = userMessages[context.requestId] - params.tools['builtin_knowledge_search'] = knowledgeSearchTool( - assistant, - analysisResult.knowledge, - getMessageContent(userMessage), - topicId - ) - } } } diff --git a/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts b/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts index 2433192cd0..2a69f3bcef 100644 --- a/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts +++ b/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts @@ -109,6 +109,20 @@ const createImageBlock = ( ...overrides }) +const createThinkingBlock = ( + messageId: string, + overrides: Partial> = {} +): ThinkingMessageBlock => ({ + id: overrides.id ?? `thinking-block-${++blockCounter}`, + messageId, + type: MessageBlockType.THINKING, + createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(), + status: overrides.status ?? MessageBlockStatus.SUCCESS, + content: overrides.content ?? 'Let me think...', + thinking_millsec: overrides.thinking_millsec ?? 1000, + ...overrides +}) + describe('messageConverter', () => { beforeEach(() => { convertFileBlockToFilePartMock.mockReset() @@ -137,6 +151,73 @@ describe('messageConverter', () => { }) }) + it('extracts base64 data from data URLs and preserves mediaType', async () => { + const model = createModel() + const message = createMessage('user') + message.__mockContent = 'Check this image' + message.__mockImageBlocks = [createImageBlock(message.id, { url: '' })] + + const result = await convertMessageToSdkParam(message, true, model) + + expect(result).toEqual({ + role: 'user', + content: [ + { type: 'text', text: 'Check this image' }, + { type: 'image', image: 'iVBORw0KGgoAAAANS', mediaType: 'image/png' } + ] + }) + }) + + it('handles data URLs without mediaType gracefully', async () => { + const model = createModel() + const message = createMessage('user') + message.__mockContent = 'Check this' + message.__mockImageBlocks = [createImageBlock(message.id, { url: 'data:;base64,AAABBBCCC' })] + + const result = await convertMessageToSdkParam(message, true, model) + + expect(result).toEqual({ + role: 'user', + content: [ + { type: 'text', text: 'Check this' }, + { type: 'image', image: 'AAABBBCCC' } + ] + }) + }) + + it('skips malformed data URLs without comma separator', async () => { + const model = createModel() + const message = createMessage('user') + message.__mockContent = 'Malformed data url' + message.__mockImageBlocks = [createImageBlock(message.id, { url: 'data:image/pngAAABBB' })] + + const result = await convertMessageToSdkParam(message, true, model) + + expect(result).toEqual({ + role: 'user', + content: [ + { type: 'text', text: 'Malformed data url' } + // Malformed data URL is excluded from the content + ] + }) + }) + + it('handles multiple large base64 images without stack overflow', async () => { + const model = createModel() + const message = createMessage('user') + // Create large base64 strings (~500KB each) to simulate real-world large images + const largeBase64 = 'A'.repeat(500_000) + message.__mockContent = 'Check these images' + message.__mockImageBlocks = [ + createImageBlock(message.id, { url: `data:image/png;base64,${largeBase64}` }), + createImageBlock(message.id, { url: `data:image/png;base64,${largeBase64}` }), + createImageBlock(message.id, { url: `data:image/png;base64,${largeBase64}` }) + ] + + // Should not throw RangeError: Maximum call stack size exceeded + await expect(convertMessageToSdkParam(message, true, model)).resolves.toBeDefined() + }) + it('returns file instructions as a system message when native uploads succeed', async () => { const model = createModel() const message = createMessage('user') @@ -162,10 +243,27 @@ describe('messageConverter', () => { } ]) }) + + it('includes reasoning parts for assistant messages with thinking blocks', async () => { + const model = createModel() + const message = createMessage('assistant') + message.__mockContent = 'Here is my answer' + message.__mockThinkingBlocks = [createThinkingBlock(message.id, { content: 'Let me think...' })] + + const result = await convertMessageToSdkParam(message, false, model) + + expect(result).toEqual({ + role: 'assistant', + content: [ + { type: 'text', text: 'Here is my answer' }, + { type: 'reasoning', text: 'Let me think...' } + ] + }) + }) }) describe('convertMessagesToSdkMessages', () => { - it('appends assistant images to the final user message for image enhancement models', async () => { + it('collapses to [system?, user(image)] for image enhancement models', async () => { const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' }) const initialUser = createMessage('user') initialUser.__mockContent = 'Start editing' @@ -180,14 +278,6 @@ describe('messageConverter', () => { const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model) expect(result).toEqual([ - { - role: 'user', - content: [{ type: 'text', text: 'Start editing' }] - }, - { - role: 'assistant', - content: [{ type: 'text', text: 'Here is the current preview' }] - }, { role: 'user', content: [ @@ -198,7 +288,7 @@ describe('messageConverter', () => { ]) }) - it('preserves preceding system instructions when building enhancement payloads', async () => { + it('preserves system messages and collapses others for enhancement payloads', async () => { const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' }) const fileUser = createMessage('user') fileUser.__mockContent = 'Use this document as inspiration' @@ -221,11 +311,6 @@ describe('messageConverter', () => { expect(result).toEqual([ { role: 'system', content: 'fileid://reference' }, - { role: 'user', content: [{ type: 'text', text: 'Use this document as inspiration' }] }, - { - role: 'assistant', - content: [{ type: 'text', text: 'Generated previews ready' }] - }, { role: 'user', content: [ @@ -235,5 +320,120 @@ describe('messageConverter', () => { } ]) }) + + it('handles no previous assistant message with images', async () => { + const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' }) + const user1 = createMessage('user') + user1.__mockContent = 'Start' + + const user2 = createMessage('user') + user2.__mockContent = 'Continue without images' + + const result = await convertMessagesToSdkMessages([user1, user2], model) + + expect(result).toEqual([ + { + role: 'user', + content: [{ type: 'text', text: 'Continue without images' }] + } + ]) + }) + + it('handles assistant message without images', async () => { + const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' }) + const user1 = createMessage('user') + user1.__mockContent = 'Start' + + const assistant = createMessage('assistant') + assistant.__mockContent = 'Text only response' + assistant.__mockImageBlocks = [] + + const user2 = createMessage('user') + user2.__mockContent = 'Follow up' + + const result = await convertMessagesToSdkMessages([user1, assistant, user2], model) + + expect(result).toEqual([ + { + role: 'user', + content: [{ type: 'text', text: 'Follow up' }] + } + ]) + }) + + it('handles multiple assistant messages by using the most recent one', async () => { + const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' }) + const user1 = createMessage('user') + user1.__mockContent = 'Start' + + const assistant1 = createMessage('assistant') + assistant1.__mockContent = 'First response' + assistant1.__mockImageBlocks = [createImageBlock(assistant1.id, { url: 'https://example.com/old.png' })] + + const user2 = createMessage('user') + user2.__mockContent = 'Continue' + + const assistant2 = createMessage('assistant') + assistant2.__mockContent = 'Second response' + assistant2.__mockImageBlocks = [createImageBlock(assistant2.id, { url: 'https://example.com/new.png' })] + + const user3 = createMessage('user') + user3.__mockContent = 'Final request' + + const result = await convertMessagesToSdkMessages([user1, assistant1, user2, assistant2, user3], model) + + expect(result).toEqual([ + { + role: 'user', + content: [ + { type: 'text', text: 'Final request' }, + { type: 'image', image: 'https://example.com/new.png' } + ] + } + ]) + }) + + it('handles conversation ending with assistant message', async () => { + const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' }) + const user = createMessage('user') + user.__mockContent = 'Start' + + const assistant = createMessage('assistant') + assistant.__mockContent = 'Response with image' + assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/image.png' })] + + const result = await convertMessagesToSdkMessages([user, assistant], model) + + // The user message is the last user message, but since the assistant comes after, + // there's no "previous" assistant message (search starts from messages.length-2 backwards) + expect(result).toEqual([ + { + role: 'user', + content: [{ type: 'text', text: 'Start' }] + } + ]) + }) + + it('handles empty content in last user message', async () => { + const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' }) + const user1 = createMessage('user') + user1.__mockContent = 'Start' + + const assistant = createMessage('assistant') + assistant.__mockContent = 'Here is the preview' + assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })] + + const user2 = createMessage('user') + user2.__mockContent = '' + + const result = await convertMessagesToSdkMessages([user1, assistant, user2], model) + + expect(result).toEqual([ + { + role: 'user', + content: [{ type: 'image', image: 'https://example.com/preview.png' }] + } + ]) + }) }) }) diff --git a/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts b/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts index 70b4ac84b7..a4f345e3e5 100644 --- a/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts +++ b/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts @@ -18,7 +18,7 @@ vi.mock('@renderer/services/AssistantService', () => ({ toolUseMode: assistant.settings?.toolUseMode ?? 'prompt', defaultModel: assistant.defaultModel, customParameters: assistant.settings?.customParameters ?? [], - reasoning_effort: assistant.settings?.reasoning_effort, + reasoning_effort: assistant.settings?.reasoning_effort ?? 'default', reasoning_effort_cache: assistant.settings?.reasoning_effort_cache, qwenThinkMode: assistant.settings?.qwenThinkMode }) diff --git a/src/renderer/src/aiCore/prepareParams/messageConverter.ts b/src/renderer/src/aiCore/prepareParams/messageConverter.ts index b0c432ef85..c3798c1f43 100644 --- a/src/renderer/src/aiCore/prepareParams/messageConverter.ts +++ b/src/renderer/src/aiCore/prepareParams/messageConverter.ts @@ -3,10 +3,12 @@ * 将 Cherry Studio 消息格式转换为 AI SDK 消息格式 */ +import type { ReasoningPart } from '@ai-sdk/provider-utils' import { loggerService } from '@logger' import { isImageEnhancementModel, isVisionModel } from '@renderer/config/models' import type { Message, Model } from '@renderer/types' import type { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage' +import { parseDataUrlMediaType } from '@renderer/utils/image' import { findFileBlocks, findImageBlocks, @@ -59,23 +61,29 @@ async function convertImageBlockToImagePart(imageBlocks: ImageMessageBlock[]): P mediaType: image.mime }) } catch (error) { - logger.warn('Failed to load image:', error as Error) + logger.error('Failed to load image file, image will be excluded from message:', { + fileId: imageBlock.file.id, + fileName: imageBlock.file.origin_name, + error: error as Error + }) } } else if (imageBlock.url) { - const isBase64 = imageBlock.url.startsWith('data:') - if (isBase64) { - const base64 = imageBlock.url.match(/^data:[^;]*;base64,(.+)$/)![1] - const mimeMatch = imageBlock.url.match(/^data:([^;]+)/) - parts.push({ - type: 'image', - image: base64, - mediaType: mimeMatch ? mimeMatch[1] : 'image/png' - }) + const url = imageBlock.url + const isDataUrl = url.startsWith('data:') + if (isDataUrl) { + const { mediaType } = parseDataUrlMediaType(url) + const commaIndex = url.indexOf(',') + if (commaIndex === -1) { + logger.error('Malformed data URL detected (missing comma separator), image will be excluded:', { + urlPrefix: url.slice(0, 50) + '...' + }) + continue + } + const base64Data = url.slice(commaIndex + 1) + parts.push({ type: 'image', image: base64Data, ...(mediaType ? { mediaType } : {}) }) } else { - parts.push({ - type: 'image', - image: imageBlock.url - }) + // For remote URLs we keep payload minimal to match existing expectations. + parts.push({ type: 'image', image: url }) } } } @@ -156,13 +164,13 @@ async function convertMessageToAssistantModelMessage( thinkingBlocks: ThinkingMessageBlock[], model?: Model ): Promise { - const parts: Array = [] + const parts: Array = [] if (content) { parts.push({ type: 'text', text: content }) } for (const thinkingBlock of thinkingBlocks) { - parts.push({ type: 'text', text: thinkingBlock.content }) + parts.push({ type: 'reasoning', text: thinkingBlock.content }) } for (const fileBlock of fileBlocks) { @@ -194,17 +202,20 @@ async function convertMessageToAssistantModelMessage( * This function processes messages and transforms them into the format required by the SDK. * It handles special cases for vision models and image enhancement models. * - * @param messages - Array of messages to convert. Must contain at least 3 messages when using image enhancement models for special handling. + * @param messages - Array of messages to convert. * @param model - The model configuration that determines conversion behavior * * @returns A promise that resolves to an array of SDK-compatible model messages * * @remarks - * For image enhancement models with 3+ messages: - * - Examines the last 2 messages to find an assistant message containing image blocks - * - If found, extracts images from the assistant message and appends them to the last user message content - * - Returns all converted messages (not just the last two) with the images merged into the user message - * - Typical pattern: [system?, assistant(image), user] -> [system?, assistant, user(image)] + * For image enhancement models: + * - Collapses the conversation into [system?, user(image)] format + * - Searches backwards through all messages to find the most recent assistant message with images + * - Preserves all system messages (including ones generated from file uploads like 'fileid://...') + * - Extracts the last user message content and merges images from the previous assistant message + * - Returns only the collapsed messages: system messages (if any) followed by a single user message + * - If no user message is found, returns only system messages + * - Typical pattern: [system?, user, assistant(image), user] -> [system?, user(image)] * * For other models: * - Returns all converted messages in order without special image handling @@ -220,25 +231,66 @@ export async function convertMessagesToSdkMessages(messages: Message[], model: M sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage])) } // Special handling for image enhancement models - // Only merge images into the user message - // [system?, assistant(image), user] -> [system?, assistant, user(image)] - if (isImageEnhancementModel(model) && messages.length >= 3) { - const needUpdatedMessages = messages.slice(-2) - const assistantMessage = needUpdatedMessages.find((m) => m.role === 'assistant') - const userSdkMessage = sdkMessages[sdkMessages.length - 1] + // Target behavior: Collapse the conversation into [system?, user(image)]. + // Explanation of why we don't simply use slice: + // 1) We need to preserve all system messages: During the convertMessageToSdkParam process, native file uploads may insert `system(fileid://...)`. + // Directly slicing the original messages or already converted sdkMessages could easily result in missing these system instructions. + // Therefore, we first perform a full conversion and then aggregate the system messages afterward. + // 2) The conversion process may split messages: A single user message might be broken into two SDK messages—[system, user]. + // Slicing either side could lead to obtaining semantically incorrect fragments (e.g., only the split-out system message). + // 3) The “previous assistant message” is not necessarily the second-to-last one: There might be system messages or other message blocks inserted in between, + // making a simple slice(-2) assumption too rigid. Here, we trace back from the end of the original messages to locate the most recent assistant message, which better aligns with business semantics. + // 4) This is a “collapse” rather than a simple “slice”: Ultimately, we need to synthesize a new user message + // (with text from the last user message and images from the previous assistant message). Using slice can only extract subarrays, + // which still require reassembly; constructing directly according to the target structure is clearer and more reliable. + if (isImageEnhancementModel(model)) { + // Collect all system messages (including ones generated from file uploads) + const systemMessages = sdkMessages.filter((m): m is SystemModelMessage => m.role === 'system') - if (assistantMessage && userSdkMessage?.role === 'user') { - const imageBlocks = findImageBlocks(assistantMessage) - const imageParts = await convertImageBlockToImagePart(imageBlocks) + // Find the last user message (SDK converted) + const lastUserSdkIndex = (() => { + for (let i = sdkMessages.length - 1; i >= 0; i--) { + if (sdkMessages[i].role === 'user') return i + } + return -1 + })() - if (imageParts.length > 0) { - if (typeof userSdkMessage.content === 'string') { - userSdkMessage.content = [{ type: 'text', text: userSdkMessage.content }, ...imageParts] - } else if (Array.isArray(userSdkMessage.content)) { - userSdkMessage.content.push(...imageParts) - } + const lastUserSdk = lastUserSdkIndex >= 0 ? (sdkMessages[lastUserSdkIndex] as UserModelMessage) : null + + // Find the nearest preceding assistant message in original messages + let prevAssistant: Message | null = null + for (let i = messages.length - 2; i >= 0; i--) { + if (messages[i].role === 'assistant') { + prevAssistant = messages[i] + break } } + + // Build the final user content parts + let finalUserParts: Array = [] + if (lastUserSdk) { + if (typeof lastUserSdk.content === 'string') { + finalUserParts.push({ type: 'text', text: lastUserSdk.content }) + } else if (Array.isArray(lastUserSdk.content)) { + finalUserParts = [...lastUserSdk.content] + } + } + + // Append images from the previous assistant message if any + if (prevAssistant) { + const imageBlocks = findImageBlocks(prevAssistant) + const imageParts = await convertImageBlockToImagePart(imageBlocks) + if (imageParts.length > 0) { + finalUserParts.push(...imageParts) + } + } + + // If we couldn't find a last user message, fall back to returning collected system messages only + if (!lastUserSdk) { + return systemMessages + } + + return [...systemMessages, { role: 'user', content: finalUserParts }] } return sdkMessages diff --git a/src/renderer/src/aiCore/prepareParams/modelParameters.ts b/src/renderer/src/aiCore/prepareParams/modelParameters.ts index 8a1d53a754..58b4834f53 100644 --- a/src/renderer/src/aiCore/prepareParams/modelParameters.ts +++ b/src/renderer/src/aiCore/prepareParams/modelParameters.ts @@ -4,60 +4,90 @@ */ import { - isClaude45ReasoningModel, isClaudeReasoningModel, isMaxTemperatureOneModel, - isNotSupportTemperatureAndTopP, isSupportedFlexServiceTier, - isSupportedThinkingTokenClaudeModel + isSupportedThinkingTokenClaudeModel, + isSupportTemperatureModel, + isSupportTopPModel, + isTemperatureTopPMutuallyExclusiveModel } from '@renderer/config/models' -import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService' +import { + DEFAULT_ASSISTANT_SETTINGS, + getAssistantSettings, + getProviderByModel +} from '@renderer/services/AssistantService' import type { Assistant, Model } from '@renderer/types' import { defaultTimeout } from '@shared/config/constant' import { getAnthropicThinkingBudget } from '../utils/reasoning' /** - * Claude 4.5 推理模型: - * - 只启用 temperature → 使用 temperature - * - 只启用 top_p → 使用 top_p - * - 同时启用 → temperature 生效,top_p 被忽略 - * - 都不启用 → 都不使用 - * 获取温度参数 + * Retrieves the temperature parameter, adapting it based on assistant.settings and model capabilities. + * - Disabled for Claude reasoning models when reasoning effort is set. + * - Disabled for models that do not support temperature. + * - Disabled for Claude 4.5 reasoning models when TopP is enabled and temperature is disabled. + * Otherwise, returns the temperature value if the assistant has temperature enabled. + */ export function getTemperature(assistant: Assistant, model: Model): number | undefined { if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { return undefined } + + if (!isSupportTemperatureModel(model, assistant)) { + return undefined + } + if ( - isNotSupportTemperatureAndTopP(model) || - (isClaude45ReasoningModel(model) && assistant.settings?.enableTopP && !assistant.settings?.enableTemperature) + isTemperatureTopPMutuallyExclusiveModel(model) && + assistant.settings?.enableTopP && + !assistant.settings?.enableTemperature ) { return undefined } + + return getTemperatureValue(assistant, model) +} + +function getTemperatureValue(assistant: Assistant, model: Model): number | undefined { const assistantSettings = getAssistantSettings(assistant) let temperature = assistantSettings?.temperature if (temperature && isMaxTemperatureOneModel(model)) { temperature = Math.min(1, temperature) } - return assistantSettings?.enableTemperature ? temperature : undefined + + // FIXME: assistant.settings.enableTemperature should be always a boolean value. + const enableTemperature = assistantSettings?.enableTemperature ?? DEFAULT_ASSISTANT_SETTINGS.enableTemperature + return enableTemperature ? temperature : undefined } /** - * 获取 TopP 参数 + * Retrieves the TopP parameter, adapting it based on assistant.settings and model capabilities. + * - Disabled for Claude reasoning models when reasoning effort is set. + * - Disabled for models that do not support TopP. + * - Disabled for Claude 4.5 reasoning models when temperature is explicitly enabled. + * Otherwise, returns the TopP value if the assistant has TopP enabled. */ export function getTopP(assistant: Assistant, model: Model): number | undefined { if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { return undefined } - if ( - isNotSupportTemperatureAndTopP(model) || - (isClaude45ReasoningModel(model) && assistant.settings?.enableTemperature) - ) { + if (!isSupportTopPModel(model, assistant)) { return undefined } + if (isTemperatureTopPMutuallyExclusiveModel(model) && assistant.settings?.enableTemperature) { + return undefined + } + + return getTopPValue(assistant) +} + +function getTopPValue(assistant: Assistant): number | undefined { const assistantSettings = getAssistantSettings(assistant) - return assistantSettings?.enableTopP ? assistantSettings?.topP : undefined + // FIXME: assistant.settings.enableTopP should be always a boolean value. + const enableTopP = assistantSettings.enableTopP ?? DEFAULT_ASSISTANT_SETTINGS.enableTopP + return enableTopP ? assistantSettings?.topP : undefined } /** diff --git a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts index cba7fcdb10..52234c5f1f 100644 --- a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts +++ b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts @@ -21,6 +21,7 @@ import { isGrokModel, isOpenAIModel, isOpenRouterBuiltInWebSearchModel, + isPureGenerateImageModel, isSupportedReasoningEffortModel, isSupportedThinkingTokenModel, isWebSearchModel @@ -33,7 +34,7 @@ import { type Assistant, type MCPTool, type Provider, SystemProviderIds } from ' import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern' import { replacePromptVariables } from '@renderer/utils/prompt' -import { isAIGatewayProvider, isAwsBedrockProvider } from '@renderer/utils/provider' +import { isAIGatewayProvider, isAwsBedrockProvider, isSupportUrlContextProvider } from '@renderer/utils/provider' import type { ModelMessage, Tool } from 'ai' import { stepCountIs } from 'ai' @@ -118,7 +119,13 @@ export async function buildStreamTextParams( isOpenRouterBuiltInWebSearchModel(model) || model.id.includes('sonar')) - const enableUrlContext = assistant.enableUrlContext || false + // Validate provider and model support to prevent stale state from triggering urlContext + const enableUrlContext = !!( + assistant.enableUrlContext && + isSupportUrlContextProvider(provider) && + !isPureGenerateImageModel(model) && + (isGeminiModel(model) || isAnthropicModel(model)) + ) const enableGenerateImage = !!(isGenerateImageModel(model) && assistant.enableGenerateImage) diff --git a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts index 43d3cc52b8..b1d8e34fcd 100644 --- a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts +++ b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts @@ -42,7 +42,8 @@ vi.mock('@renderer/utils/api', () => ({ routeToEndpoint: vi.fn((host) => ({ baseURL: host, endpoint: '/chat/completions' - })) + })), + isWithTrailingSharp: vi.fn((host) => host?.endsWith('#') || false) })) vi.mock('@renderer/utils/provider', async (importOriginal) => { @@ -78,7 +79,7 @@ vi.mock('@renderer/services/AssistantService', () => ({ import { getProviderByModel } from '@renderer/services/AssistantService' import type { Model, Provider } from '@renderer/types' import { formatApiHost } from '@renderer/utils/api' -import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' +import { isAzureOpenAIProvider, isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants' import { getActualProvider, providerToAiSdkConfig } from '../providerConfig' @@ -132,6 +133,17 @@ const createPerplexityProvider = (): Provider => ({ isSystem: false }) +const createAzureProvider = (apiVersion: string): Provider => ({ + id: 'azure-openai', + type: 'azure-openai', + name: 'Azure OpenAI', + apiKey: 'test-key', + apiHost: 'https://example.openai.azure.com/openai', + apiVersion, + models: [], + isSystem: true +}) + describe('Copilot responses routing', () => { beforeEach(() => { ;(globalThis as any).window = { @@ -227,12 +239,19 @@ describe('CherryAI provider configuration', () => { // Mock the functions to simulate non-CherryAI provider vi.mocked(isCherryAIProvider).mockReturnValue(false) vi.mocked(getProviderByModel).mockReturnValue(provider) + // Mock isWithTrailingSharp to return false for this test + vi.mocked(formatApiHost as any).mockImplementation((host, isSupportedAPIVersion = true) => { + if (isSupportedAPIVersion === false) { + return host + } + return `${host}/v1` + }) // Call getActualProvider const actualProvider = getActualProvider(model) - // Verify that formatApiHost was called with default parameters (true) - expect(formatApiHost).toHaveBeenCalledWith('https://api.openai.com') + // Verify that formatApiHost was called with appendApiVersion parameter + expect(formatApiHost).toHaveBeenCalledWith('https://api.openai.com', true) expect(actualProvider.apiHost).toBe('https://api.openai.com/v1') }) @@ -303,12 +322,19 @@ describe('Perplexity provider configuration', () => { vi.mocked(isCherryAIProvider).mockReturnValue(false) vi.mocked(isPerplexityProvider).mockReturnValue(false) vi.mocked(getProviderByModel).mockReturnValue(provider) + // Mock isWithTrailingSharp to return false for this test + vi.mocked(formatApiHost as any).mockImplementation((host, isSupportedAPIVersion = true) => { + if (isSupportedAPIVersion === false) { + return host + } + return `${host}/v1` + }) // Call getActualProvider const actualProvider = getActualProvider(model) - // Verify that formatApiHost was called with default parameters (true) - expect(formatApiHost).toHaveBeenCalledWith('https://api.openai.com') + // Verify that formatApiHost was called with appendApiVersion parameter + expect(formatApiHost).toHaveBeenCalledWith('https://api.openai.com', true) expect(actualProvider.apiHost).toBe('https://api.openai.com/v1') }) @@ -489,3 +515,46 @@ describe('Stream options includeUsage configuration', () => { expect(config.providerId).toBe('github-copilot-openai-compatible') }) }) + +describe('Azure OpenAI traditional API routing', () => { + beforeEach(() => { + ;(globalThis as any).window = { + ...(globalThis as any).window, + keyv: createWindowKeyv() + } + mockGetState.mockReturnValue({ + settings: { + openAI: { + streamOptions: { + includeUsage: undefined + } + } + } + }) + + vi.mocked(isAzureOpenAIProvider).mockImplementation((provider) => provider.type === 'azure-openai') + }) + + it('uses deployment-based URLs when apiVersion is a date version', () => { + const provider = createAzureProvider('2024-02-15-preview') + const config = providerToAiSdkConfig(provider, createModel('gpt-4o', 'GPT-4o', provider.id)) + + expect(config.providerId).toBe('azure') + expect(config.options.apiVersion).toBe('2024-02-15-preview') + expect(config.options.useDeploymentBasedUrls).toBe(true) + }) + + it('does not force deployment-based URLs for apiVersion v1/preview', () => { + const v1Provider = createAzureProvider('v1') + const v1Config = providerToAiSdkConfig(v1Provider, createModel('gpt-4o', 'GPT-4o', v1Provider.id)) + expect(v1Config.providerId).toBe('azure-responses') + expect(v1Config.options.apiVersion).toBe('v1') + expect(v1Config.options.useDeploymentBasedUrls).toBeUndefined() + + const previewProvider = createAzureProvider('preview') + const previewConfig = providerToAiSdkConfig(previewProvider, createModel('gpt-4o', 'GPT-4o', previewProvider.id)) + expect(previewConfig.providerId).toBe('azure-responses') + expect(previewConfig.options.apiVersion).toBe('preview') + expect(previewConfig.options.useDeploymentBasedUrls).toBeUndefined() + }) +}) diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts index ff100051b7..d18aa02eeb 100644 --- a/src/renderer/src/aiCore/provider/factory.ts +++ b/src/renderer/src/aiCore/provider/factory.ts @@ -31,7 +31,8 @@ const STATIC_PROVIDER_MAPPING: Record = { 'azure-openai': 'azure', // Azure OpenAI -> azure 'openai-response': 'openai', // OpenAI Responses -> openai grok: 'xai', // Grok -> xai - copilot: 'github-copilot-openai-compatible' + copilot: 'github-copilot-openai-compatible', + tokenflux: 'openrouter' // TokenFlux -> openrouter (fully compatible) } /** diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index 99e4fbd1c9..0ad15ea895 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -9,6 +9,7 @@ import { } from '@renderer/hooks/useAwsBedrock' import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI' import { getProviderByModel } from '@renderer/services/AssistantService' +import { getProviderById } from '@renderer/services/ProviderService' import store from '@renderer/store' import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types' import type { OpenAICompletionsStreamOptions } from '@renderer/types/aiCoreTypes' @@ -17,6 +18,7 @@ import { formatAzureOpenAIApiHost, formatOllamaApiHost, formatVertexApiHost, + isWithTrailingSharp, routeToEndpoint } from '@renderer/utils/api' import { @@ -30,6 +32,7 @@ import { isSupportStreamOptionsProvider, isVertexProvider } from '@renderer/utils/provider' +import { defaultAppHeaders } from '@shared/utils' import { cloneDeep, isEmpty } from 'lodash' import type { AiSdkConfig } from '../types' @@ -69,14 +72,15 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider { */ export function formatProviderApiHost(provider: Provider): Provider { const formatted = { ...provider } + const appendApiVersion = !isWithTrailingSharp(provider.apiHost) if (formatted.anthropicApiHost) { - formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost) + formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost, appendApiVersion) } if (isAnthropicProvider(provider)) { const baseHost = formatted.anthropicApiHost || formatted.apiHost // AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient - formatted.apiHost = formatApiHost(baseHost) + formatted.apiHost = formatApiHost(baseHost, appendApiVersion) if (!formatted.anthropicApiHost) { formatted.anthropicApiHost = formatted.apiHost } @@ -85,7 +89,7 @@ export function formatProviderApiHost(provider: Provider): Provider { } else if (isOllamaProvider(formatted)) { formatted.apiHost = formatOllamaApiHost(formatted.apiHost) } else if (isGeminiProvider(formatted)) { - formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta') + formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion, 'v1beta') } else if (isAzureOpenAIProvider(formatted)) { formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost) } else if (isVertexProvider(formatted)) { @@ -95,7 +99,7 @@ export function formatProviderApiHost(provider: Provider): Provider { } else if (isPerplexityProvider(formatted)) { formatted.apiHost = formatApiHost(formatted.apiHost, false) } else { - formatted.apiHost = formatApiHost(formatted.apiHost) + formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion) } return formatted } @@ -194,18 +198,13 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A extraOptions.mode = 'chat' } - // 添加额外headers - if (actualProvider.extra_headers) { - extraOptions.headers = actualProvider.extra_headers - // copy from openaiBaseClient/openaiResponseApiClient - if (aiSdkProviderId === 'openai') { - extraOptions.headers = { - ...extraOptions.headers, - 'HTTP-Referer': 'https://cherry-ai.com', - 'X-Title': 'Cherry Studio', - 'X-Api-Key': baseConfig.apiKey - } - } + extraOptions.headers = { + ...defaultAppHeaders(), + ...actualProvider.extra_headers + } + + if (aiSdkProviderId === 'openai') { + extraOptions.headers['X-Api-Key'] = baseConfig.apiKey } // azure // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest @@ -215,6 +214,15 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A } else if (aiSdkProviderId === 'azure') { extraOptions.mode = 'chat' } + if (isAzureOpenAIProvider(actualProvider)) { + const apiVersion = actualProvider.apiVersion?.trim() + if (apiVersion) { + extraOptions.apiVersion = apiVersion + if (!['preview', 'v1'].includes(apiVersion)) { + extraOptions.useDeploymentBasedUrls = true + } + } + } // bedrock if (aiSdkProviderId === 'bedrock') { @@ -248,6 +256,12 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A if (model.endpoint_type) { extraOptions.endpointType = model.endpoint_type } + // CherryIN API Host + const cherryinProvider = getProviderById(SystemProviderIds.cherryin) + if (cherryinProvider) { + extraOptions.anthropicBaseURL = cherryinProvider.anthropicApiHost + '/v1' + extraOptions.geminiBaseURL = cherryinProvider.apiHost + '/v1beta/models' + } } if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { diff --git a/src/renderer/src/aiCore/tools/MemorySearchTool.ts b/src/renderer/src/aiCore/tools/MemorySearchTool.ts index 20064dd1b2..5028f2eb4d 100644 --- a/src/renderer/src/aiCore/tools/MemorySearchTool.ts +++ b/src/renderer/src/aiCore/tools/MemorySearchTool.ts @@ -24,7 +24,8 @@ export const memorySearchTool = () => { } const memoryConfig = selectMemoryConfig(store.getState()) - if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) { + + if (!memoryConfig.llmModel || !memoryConfig.embeddingModel) { return [] } diff --git a/src/renderer/src/aiCore/utils/__tests__/options.test.ts b/src/renderer/src/aiCore/utils/__tests__/options.test.ts index 9eeeac725b..a6c9a6c95c 100644 --- a/src/renderer/src/aiCore/utils/__tests__/options.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/options.test.ts @@ -464,7 +464,8 @@ describe('options utils', () => { custom_param: 'custom_value', another_param: 123, serviceTier: undefined, - textVerbosity: undefined + textVerbosity: undefined, + store: false } }) }) diff --git a/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts b/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts index 36253e5c1d..df7d69d0c2 100644 --- a/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts @@ -11,6 +11,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { getAnthropicReasoningParams, + getAnthropicThinkingBudget, getBedrockReasoningParams, getCustomParameters, getGeminiReasoningParams, @@ -89,7 +90,8 @@ vi.mock('@renderer/config/models', async (importOriginal) => { isQwenAlwaysThinkModel: vi.fn(() => false), isSupportedThinkingTokenHunyuanModel: vi.fn(() => false), isSupportedThinkingTokenModel: vi.fn(() => false), - isGPT51SeriesModel: vi.fn(() => false) + isGPT51SeriesModel: vi.fn(() => false), + findTokenLimit: vi.fn(actual.findTokenLimit) } }) @@ -596,7 +598,7 @@ describe('reasoning utils', () => { expect(result).toEqual({}) }) - it('should return disabled thinking when no reasoning effort', async () => { + it('should return disabled thinking when reasoning effort is none', async () => { const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models') vi.mocked(isReasoningModel).mockReturnValue(true) @@ -611,7 +613,9 @@ describe('reasoning utils', () => { const assistant: Assistant = { id: 'test', name: 'Test', - settings: {} + settings: { + reasoning_effort: 'none' + } } as Assistant const result = getAnthropicReasoningParams(assistant, model) @@ -647,7 +651,7 @@ describe('reasoning utils', () => { expect(result).toEqual({ thinking: { type: 'enabled', - budgetTokens: 2048 + budgetTokens: 4096 } }) }) @@ -675,7 +679,7 @@ describe('reasoning utils', () => { expect(result).toEqual({}) }) - it('should disable thinking for Flash models without reasoning effort', async () => { + it('should disable thinking for Flash models when reasoning effort is none', async () => { const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models') vi.mocked(isReasoningModel).mockReturnValue(true) @@ -690,7 +694,9 @@ describe('reasoning utils', () => { const assistant: Assistant = { id: 'test', name: 'Test', - settings: {} + settings: { + reasoning_effort: 'none' + } } as Assistant const result = getGeminiReasoningParams(assistant, model) @@ -725,7 +731,7 @@ describe('reasoning utils', () => { const result = getGeminiReasoningParams(assistant, model) expect(result).toEqual({ thinkingConfig: { - thinkingBudget: 16448, + thinkingBudget: expect.any(Number), includeThoughts: true } }) @@ -754,7 +760,8 @@ describe('reasoning utils', () => { const result = getGeminiReasoningParams(assistant, model) expect(result).toEqual({ thinkingConfig: { - includeThoughts: true + includeThoughts: true, + thinkingBudget: -1 } }) }) @@ -888,7 +895,7 @@ describe('reasoning utils', () => { expect(result).toEqual({ reasoningConfig: { type: 'enabled', - budgetTokens: 2048 + budgetTokens: 4096 } }) }) @@ -989,4 +996,89 @@ describe('reasoning utils', () => { }) }) }) + + describe('getAnthropicThinkingBudget', () => { + it('should return undefined when reasoningEffort is undefined', async () => { + const result = getAnthropicThinkingBudget(4096, undefined, 'claude-3-7-sonnet') + expect(result).toBeUndefined() + }) + + it('should return undefined when reasoningEffort is none', async () => { + const result = getAnthropicThinkingBudget(4096, 'none', 'claude-3-7-sonnet') + expect(result).toBeUndefined() + }) + + it('should return undefined when tokenLimit is not found', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue(undefined) + + const result = getAnthropicThinkingBudget(4096, 'medium', 'unknown-model') + expect(result).toBeUndefined() + }) + + it('should calculate budget correctly when maxTokens is provided', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(4096, 'medium', 'claude-3-7-sonnet') + // EFFORT_RATIO['medium'] = 0.5 + // budget = Math.floor((32768 - 1024) * 0.5 + 1024) + // = Math.floor(31744 * 0.5 + 1024) = Math.floor(15872 + 1024) = 16896 + // budgetTokens = Math.min(16896, 4096) = 4096 + // result = Math.max(1024, 4096) = 4096 + expect(result).toBe(4096) + }) + + it('should use tokenLimit.max when maxTokens is undefined', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(undefined, 'medium', 'claude-3-7-sonnet') + // When maxTokens is undefined, budget is not constrained by maxTokens + // EFFORT_RATIO['medium'] = 0.5 + // budget = Math.floor((32768 - 1024) * 0.5 + 1024) + // = Math.floor(31744 * 0.5 + 1024) = Math.floor(15872 + 1024) = 16896 + // result = Math.max(1024, 16896) = 16896 + expect(result).toBe(16896) + }) + + it('should enforce minimum budget of 1024', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 100, max: 1000 }) + + const result = getAnthropicThinkingBudget(500, 'low', 'claude-3-7-sonnet') + // EFFORT_RATIO['low'] = 0.05 + // budget = Math.floor((1000 - 100) * 0.05 + 100) + // = Math.floor(900 * 0.05 + 100) = Math.floor(45 + 100) = 145 + // budgetTokens = Math.min(145, 500) = 145 + // result = Math.max(1024, 145) = 1024 + expect(result).toBe(1024) + }) + + it('should respect effort ratio for high reasoning effort', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(8192, 'high', 'claude-3-7-sonnet') + // EFFORT_RATIO['high'] = 0.8 + // budget = Math.floor((32768 - 1024) * 0.8 + 1024) + // = Math.floor(31744 * 0.8 + 1024) = Math.floor(25395.2 + 1024) = 26419 + // budgetTokens = Math.min(26419, 8192) = 8192 + // result = Math.max(1024, 8192) = 8192 + expect(result).toBe(8192) + }) + + it('should use full token limit when maxTokens is undefined and reasoning effort is high', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(undefined, 'high', 'claude-3-7-sonnet') + // When maxTokens is undefined, budget is not constrained by maxTokens + // EFFORT_RATIO['high'] = 0.8 + // budget = Math.floor((32768 - 1024) * 0.8 + 1024) + // = Math.floor(31744 * 0.8 + 1024) = Math.floor(25395.2 + 1024) = 26419 + // result = Math.max(1024, 26419) = 26419 + expect(result).toBe(26419) + }) + }) }) diff --git a/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts b/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts index fa5e3c3b36..5c95a664aa 100644 --- a/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts @@ -259,7 +259,7 @@ describe('websearch utils', () => { expect(result).toEqual({ xai: { - maxSearchResults: 50, + maxSearchResults: 30, returnCitations: true, sources: [{ type: 'web', excludedWebsites: [] }, { type: 'news' }, { type: 'x' }], mode: 'on' diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index 8ec46c9df2..8dc7a10af9 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -10,7 +10,9 @@ import { isAnthropicModel, isGeminiModel, isGrokModel, + isInterleavedThinkingModel, isOpenAIModel, + isOpenAIOpenWeightModel, isQwenMTModel, isSupportFlexServiceTierModel, isSupportVerbosityModel @@ -244,7 +246,7 @@ export function buildProviderOptions( providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier) break case SystemProviderIds.ollama: - providerSpecificOptions = buildOllamaProviderOptions(assistant, capabilities) + providerSpecificOptions = buildOllamaProviderOptions(assistant, model, capabilities) break case SystemProviderIds.gateway: providerSpecificOptions = buildAIGatewayOptions(assistant, model, capabilities, serviceTier, textVerbosity) @@ -395,10 +397,12 @@ function buildOpenAIProviderOptions( } } + // TODO: 支持配置是否在服务端持久化 providerOptions = { ...providerOptions, serviceTier, - textVerbosity + textVerbosity, + store: false } return { @@ -564,6 +568,7 @@ function buildBedrockProviderOptions( function buildOllamaProviderOptions( assistant: Assistant, + model: Model, capabilities: { enableReasoning: boolean enableWebSearch: boolean @@ -574,7 +579,14 @@ function buildOllamaProviderOptions( const providerOptions: OllamaCompletionProviderOptions = {} const reasoningEffort = assistant.settings?.reasoning_effort if (enableReasoning) { - providerOptions.think = !['none', undefined].includes(reasoningEffort) + if (isOpenAIOpenWeightModel(model)) { + // For gpt-oss models, Ollama accepts: 'low' | 'medium' | 'high' + if (reasoningEffort === 'low' || reasoningEffort === 'medium' || reasoningEffort === 'high') { + providerOptions.think = reasoningEffort + } + } else { + providerOptions.think = !['none', undefined].includes(reasoningEffort) + } } return { ollama: providerOptions @@ -594,7 +606,7 @@ function buildGenericProviderOptions( enableGenerateImage: boolean } ): Record { - const { enableWebSearch } = capabilities + const { enableWebSearch, enableReasoning } = capabilities let providerOptions: Record = {} const reasoningParams = getReasoningEffort(assistant, model) @@ -602,6 +614,14 @@ function buildGenericProviderOptions( ...providerOptions, ...reasoningParams } + if (enableReasoning) { + if (isInterleavedThinkingModel(model)) { + providerOptions = { + ...providerOptions, + sendReasoning: true + } + } + } if (enableWebSearch) { const webSearchParams = getWebSearchParams(model) diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts index 1e74db24df..ab8a0b7983 100644 --- a/src/renderer/src/aiCore/utils/reasoning.ts +++ b/src/renderer/src/aiCore/utils/reasoning.ts @@ -8,16 +8,16 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { findTokenLimit, GEMINI_FLASH_MODEL_REGEX, - getThinkModelType, + getModelSupportedReasoningEffortOptions, isDeepSeekHybridInferenceModel, + isDoubaoSeed18Model, isDoubaoSeedAfter251015, isDoubaoThinkingAutoModel, isGemini3ThinkingTokenModel, - isGPT5SeriesModel, - isGPT51SeriesModel, isGrok4FastReasoningModel, isOpenAIDeepResearchModel, isOpenAIModel, + isOpenAIReasoningModel, isQwenAlwaysThinkModel, isQwenReasoningModel, isReasoningModel, @@ -28,14 +28,15 @@ import { isSupportedThinkingTokenDoubaoModel, isSupportedThinkingTokenGeminiModel, isSupportedThinkingTokenHunyuanModel, + isSupportedThinkingTokenMiMoModel, isSupportedThinkingTokenModel, isSupportedThinkingTokenQwenModel, isSupportedThinkingTokenZhipuModel, - MODEL_SUPPORTED_REASONING_EFFORT + isSupportNoneReasoningEffortModel } from '@renderer/config/models' import { getStoreSetting } from '@renderer/hooks/useSettings' import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService' -import type { Assistant, Model } from '@renderer/types' +import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types' import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types' import type { OpenAIReasoningSummary } from '@renderer/types/aiCoreTypes' import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk' @@ -65,7 +66,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin // reasoningEffort is not set, no extra reasoning setting // Generally, for every model which supports reasoning control, the reasoning effort won't be undefined. // It's for some reasoning models that don't support reasoning control, such as deepseek reasoner. - if (!reasoningEffort) { + if (!reasoningEffort || reasoningEffort === 'default') { return {} } @@ -73,9 +74,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin if (reasoningEffort === 'none') { // openrouter: use reasoning if (model.provider === SystemProviderIds.openrouter) { - // 'none' is not an available value for effort for now. - // I think they should resolve this issue soon, so I'll just go ahead and use this value. - if (isGPT51SeriesModel(model) && reasoningEffort === 'none') { + if (isSupportNoneReasoningEffortModel(model) && reasoningEffort === 'none') { return { reasoning: { effort: 'none' } } } return { reasoning: { enabled: false, exclude: true } } @@ -119,8 +118,8 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin return { thinking: { type: 'disabled' } } } - // Specially for GPT-5.1. Suppose this is a OpenAI Compatible provider - if (isGPT51SeriesModel(model)) { + // GPT 5.1, GPT 5.2, or newer + if (isSupportNoneReasoningEffortModel(model)) { return { reasoningEffort: 'none' } @@ -134,8 +133,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin // https://creator.poe.com/docs/external-applications/openai-compatible-api#additional-considerations // Poe provider - supports custom bot parameters via extra_body if (provider.id === SystemProviderIds.poe) { - // GPT-5 series models use reasoning_effort parameter in extra_body - if (isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) { + if (isOpenAIReasoningModel(model)) { return { extra_body: { reasoning_effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort @@ -331,16 +329,15 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin // Grok models/Perplexity models/OpenAI models, use reasoning_effort if (isSupportedReasoningEffortModel(model)) { // 检查模型是否支持所选选项 - const modelType = getThinkModelType(model) - const supportedOptions = MODEL_SUPPORTED_REASONING_EFFORT[modelType] - if (supportedOptions.includes(reasoningEffort)) { + const supportedOptions = getModelSupportedReasoningEffortOptions(model)?.filter((option) => option !== 'default') + if (supportedOptions?.includes(reasoningEffort)) { return { reasoningEffort } } else { // 如果不支持,fallback到第一个支持的值 return { - reasoningEffort: supportedOptions[0] + reasoningEffort: supportedOptions?.[0] } } } @@ -392,7 +389,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin // Use thinking, doubao, zhipu, etc. if (isSupportedThinkingTokenDoubaoModel(model)) { - if (isDoubaoSeedAfter251015(model)) { + if (isDoubaoSeedAfter251015(model) || isDoubaoSeed18Model(model)) { return { reasoningEffort } } if (reasoningEffort === 'high') { @@ -411,6 +408,12 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin return { thinking: { type: 'enabled' } } } + if (isSupportedThinkingTokenMiMoModel(model)) { + return { + thinking: { type: 'enabled' } + } + } + // Default case: no special thinking settings return {} } @@ -430,7 +433,7 @@ export function getOpenAIReasoningParams( let reasoningEffort = assistant?.settings?.reasoning_effort - if (!reasoningEffort) { + if (!reasoningEffort || reasoningEffort === 'default') { return {} } @@ -482,16 +485,14 @@ export function getAnthropicThinkingBudget( return undefined } - const budgetTokens = Math.max( - 1024, - Math.floor( - Math.min( - (tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min, - (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio - ) - ) - ) - return budgetTokens + const budget = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min) + + let budgetTokens = budget + if (maxTokens !== undefined) { + budgetTokens = Math.min(budget, maxTokens) + } + + return Math.max(1024, budgetTokens) } /** @@ -508,7 +509,11 @@ export function getAnthropicReasoningParams( const reasoningEffort = assistant?.settings?.reasoning_effort - if (reasoningEffort === undefined || reasoningEffort === 'none') { + if (!reasoningEffort || reasoningEffort === 'default') { + return {} + } + + if (reasoningEffort === 'none') { return { thinking: { type: 'disabled' @@ -532,20 +537,25 @@ export function getAnthropicReasoningParams( return {} } -// type GoogleThinkingLevel = NonNullable['thinkingLevel'] +type GoogleThinkingLevel = NonNullable['thinkingLevel'] -// function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogelThinkingLevel { -// switch (reasoningEffort) { -// case 'low': -// return 'low' -// case 'medium': -// return 'medium' -// case 'high': -// return 'high' -// default: -// return 'medium' -// } -// } +function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogleThinkingLevel { + switch (reasoningEffort) { + case 'default': + return undefined + case 'minimal': + return 'minimal' + case 'low': + return 'low' + case 'medium': + return 'medium' + case 'high': + return 'high' + default: + logger.warn('Unknown thinking level for Gemini. Fallback to medium instead.', { reasoningEffort }) + return 'medium' + } +} /** * 获取 Gemini 推理参数 @@ -563,6 +573,10 @@ export function getGeminiReasoningParams( const reasoningEffort = assistant?.settings?.reasoning_effort + if (!reasoningEffort || reasoningEffort === 'default') { + return {} + } + // Gemini 推理参数 if (isSupportedThinkingTokenGeminiModel(model)) { if (reasoningEffort === undefined || reasoningEffort === 'none') { @@ -574,21 +588,22 @@ export function getGeminiReasoningParams( } } - // TODO: 很多中转还不支持 // https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3 - // if (isGemini3ThinkingTokenModel(model)) { - // return { - // thinkingConfig: { - // thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort) - // } - // } - // } + if (isGemini3ThinkingTokenModel(model)) { + return { + thinkingConfig: { + includeThoughts: true, + thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort) + } + } + } const effortRatio = EFFORT_RATIO[reasoningEffort] if (effortRatio > 1) { return { thinkingConfig: { + thinkingBudget: -1, includeThoughts: true } } @@ -622,10 +637,6 @@ export function getXAIReasoningParams(assistant: Assistant, model: Model): Pick< const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant) - if (!reasoningEffort || reasoningEffort === 'none') { - return {} - } - switch (reasoningEffort) { case 'auto': case 'minimal': @@ -634,6 +645,12 @@ export function getXAIReasoningParams(assistant: Assistant, model: Model): Pick< case 'low': case 'high': return { reasoningEffort } + case 'xhigh': + return { reasoningEffort: 'high' } + case 'default': + case 'none': + default: + return {} } } @@ -650,7 +667,7 @@ export function getBedrockReasoningParams( const reasoningEffort = assistant?.settings?.reasoning_effort - if (reasoningEffort === undefined) { + if (reasoningEffort === undefined || reasoningEffort === 'default') { return {} } diff --git a/src/renderer/src/aiCore/utils/websearch.ts b/src/renderer/src/aiCore/utils/websearch.ts index 127636a50b..14a99139be 100644 --- a/src/renderer/src/aiCore/utils/websearch.ts +++ b/src/renderer/src/aiCore/utils/websearch.ts @@ -9,6 +9,8 @@ import type { CherryWebSearchConfig } from '@renderer/store/websearch' import type { Model } from '@renderer/types' import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern' +const X_AI_MAX_SEARCH_RESULT = 30 + export function getWebSearchParams(model: Model): Record { if (model.provider === 'hunyuan') { return { enable_enhancement: true, citation: true, search_info: true } @@ -82,7 +84,7 @@ export function buildProviderBuiltinWebSearchConfig( const excludeDomains = mapRegexToPatterns(webSearchConfig.excludeDomains) return { xai: { - maxSearchResults: webSearchConfig.maxResults, + maxSearchResults: Math.min(webSearchConfig.maxResults, X_AI_MAX_SEARCH_RESULT), returnCitations: true, sources: [ { diff --git a/src/renderer/src/assets/images/apps/aistudio.png b/src/renderer/src/assets/images/apps/aistudio.png new file mode 100644 index 0000000000..c7cb2adebe Binary files /dev/null and b/src/renderer/src/assets/images/apps/aistudio.png differ diff --git a/src/renderer/src/assets/images/apps/aistudio.svg b/src/renderer/src/assets/images/apps/aistudio.svg deleted file mode 100644 index 2c08015593..0000000000 --- a/src/renderer/src/assets/images/apps/aistudio.svg +++ /dev/null @@ -1,27 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/src/renderer/src/assets/images/models/mimo.svg b/src/renderer/src/assets/images/models/mimo.svg new file mode 100644 index 0000000000..82370fece3 --- /dev/null +++ b/src/renderer/src/assets/images/models/mimo.svg @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + diff --git a/src/renderer/src/assets/images/providers/mimo.svg b/src/renderer/src/assets/images/providers/mimo.svg new file mode 100644 index 0000000000..82370fece3 --- /dev/null +++ b/src/renderer/src/assets/images/providers/mimo.svg @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + diff --git a/src/renderer/src/assets/images/search/baidu.svg b/src/renderer/src/assets/images/search/baidu.svg new file mode 100644 index 0000000000..ead7f89822 --- /dev/null +++ b/src/renderer/src/assets/images/search/baidu.svg @@ -0,0 +1 @@ +Baidu \ No newline at end of file diff --git a/src/renderer/src/assets/images/search/bing.svg b/src/renderer/src/assets/images/search/bing.svg new file mode 100644 index 0000000000..b411a4f068 --- /dev/null +++ b/src/renderer/src/assets/images/search/bing.svg @@ -0,0 +1 @@ +Bing \ No newline at end of file diff --git a/src/renderer/src/assets/images/search/google.svg b/src/renderer/src/assets/images/search/google.svg new file mode 100644 index 0000000000..e8e0f867bd --- /dev/null +++ b/src/renderer/src/assets/images/search/google.svg @@ -0,0 +1 @@ +Google \ No newline at end of file diff --git a/src/renderer/src/components/Avatar/AssistantAvatar.tsx b/src/renderer/src/components/Avatar/AssistantAvatar.tsx new file mode 100644 index 0000000000..7ab87e633e --- /dev/null +++ b/src/renderer/src/components/Avatar/AssistantAvatar.tsx @@ -0,0 +1,34 @@ +import EmojiIcon from '@renderer/components/EmojiIcon' +import { useSettings } from '@renderer/hooks/useSettings' +import { getDefaultModel } from '@renderer/services/AssistantService' +import type { Assistant } from '@renderer/types' +import { getLeadingEmoji } from '@renderer/utils' +import type { FC } from 'react' +import { useMemo } from 'react' + +import ModelAvatar from './ModelAvatar' + +interface AssistantAvatarProps { + assistant: Assistant + size?: number + className?: string +} + +const AssistantAvatar: FC = ({ assistant, size = 24, className }) => { + const { assistantIconType } = useSettings() + const defaultModel = getDefaultModel() + + const assistantName = useMemo(() => assistant.name || '', [assistant.name]) + + if (assistantIconType === 'model') { + return + } + + if (assistantIconType === 'emoji') { + return + } + + return null +} + +export default AssistantAvatar diff --git a/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx b/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx index 2cd8171d08..a51b718ecf 100644 --- a/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx +++ b/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx @@ -25,7 +25,7 @@ type ViewMode = 'split' | 'code' | 'preview' const HtmlArtifactsPopup: React.FC = ({ open, title, html, onSave, onClose }) => { const { t } = useTranslation() const [viewMode, setViewMode] = useState('split') - const [isFullscreen, setIsFullscreen] = useState(false) + const [isFullscreen, setIsFullscreen] = useState(true) const [saved, setSaved] = useTemporaryValue(false, 2000) const codeEditorRef = useRef(null) const previewFrameRef = useRef(null) @@ -78,7 +78,7 @@ const HtmlArtifactsPopup: React.FC = ({ open, title, ht - e.stopPropagation()}> + e.stopPropagation()} className="nodrag"> = ({ open, title, ht afterClose={onClose} centered={!isFullscreen} destroyOnHidden + forceRender={isFullscreen} mask={!isFullscreen} maskClosable={false} width={isFullscreen ? '100vw' : '90vw'} diff --git a/src/renderer/src/components/ContextMenu/index.tsx b/src/renderer/src/components/ContextMenu/index.tsx index 719cd14133..8db88955bd 100644 --- a/src/renderer/src/components/ContextMenu/index.tsx +++ b/src/renderer/src/components/ContextMenu/index.tsx @@ -6,6 +6,61 @@ interface ContextMenuProps { children: React.ReactNode } +/** + * Extract text content from selection, filtering out line numbers in code viewers. + * Preserves all content including plain text and code blocks, only removing line numbers. + * This ensures right-click copy in code blocks doesn't include line numbers while preserving indentation. + */ +function extractSelectedText(selection: Selection): string { + // Validate selection + if (selection.rangeCount === 0 || selection.isCollapsed) { + return '' + } + + const range = selection.getRangeAt(0) + const fragment = range.cloneContents() + + // Check if the selection contains code viewer elements + const hasLineNumbers = fragment.querySelectorAll('.line-number').length > 0 + + // If no line numbers, return the original text (preserves formatting) + if (!hasLineNumbers) { + return selection.toString() + } + + // Remove all line number elements + fragment.querySelectorAll('.line-number').forEach((el) => el.remove()) + + // Handle all content using optimized TreeWalker with precise node filtering + // This approach handles mixed content correctly while improving performance + const walker = document.createTreeWalker(fragment, NodeFilter.SHOW_TEXT | NodeFilter.SHOW_ELEMENT, null) + + let result = '' + let node = walker.nextNode() + + while (node) { + if (node.nodeType === Node.TEXT_NODE) { + // Preserve text content including whitespace + result += node.textContent + } else if (node.nodeType === Node.ELEMENT_NODE) { + const element = node as Element + + // Add newline after block elements and code lines to preserve structure + if (['H1', 'H2', 'H3', 'H4', 'H5', 'H6'].includes(element.tagName)) { + result += '\n' + } else if (element.classList.contains('line')) { + // Add newline after code lines to preserve code structure + result += '\n' + } + } + + node = walker.nextNode() + } + + // Clean up excessive newlines but preserve code structure + return result.trim() +} + // FIXME: Why does this component name look like a generic component but is not customizable at all? const ContextMenu: React.FC = ({ children }) => { const { t } = useTranslation() @@ -45,8 +100,12 @@ const ContextMenu: React.FC = ({ children }) => { const onOpenChange = (open: boolean) => { if (open) { - const selectedText = window.getSelection()?.toString() - setSelectedText(selectedText) + const selection = window.getSelection() + if (!selection || selection.rangeCount === 0 || selection.isCollapsed) { + setSelectedText(undefined) + return + } + setSelectedText(extractSelectedText(selection) || undefined) } } diff --git a/src/renderer/src/components/EmojiPicker/index.tsx b/src/renderer/src/components/EmojiPicker/index.tsx index 9a4158d469..c0a21e7c3c 100644 --- a/src/renderer/src/components/EmojiPicker/index.tsx +++ b/src/renderer/src/components/EmojiPicker/index.tsx @@ -45,6 +45,7 @@ const i18nMap: Record = { 'fr-FR': fr, 'ja-JP': ja, 'pt-PT': pt_PT, + 'ro-RO': en, // No Romanian available, fallback to English 'ru-RU': ru_RU } @@ -60,6 +61,7 @@ const dataSourceMap: Record = { 'fr-FR': dataFR, 'ja-JP': dataJA, 'pt-PT': dataPT, + 'ro-RO': dataEN, // No Romanian CLDR available, fallback to English 'ru-RU': dataRU } @@ -75,6 +77,7 @@ const localeMap: Record = { 'fr-FR': 'fr', 'ja-JP': 'ja', 'pt-PT': 'pt', + 'ro-RO': 'en', 'ru-RU': 'ru' } diff --git a/src/renderer/src/components/Icons/SVGIcon.tsx b/src/renderer/src/components/Icons/SVGIcon.tsx index ad503f0e38..3f4e98705c 100644 --- a/src/renderer/src/components/Icons/SVGIcon.tsx +++ b/src/renderer/src/components/Icons/SVGIcon.tsx @@ -113,6 +113,18 @@ export function MdiLightbulbOn(props: SVGProps) { ) } +export function MdiLightbulbQuestion(props: SVGProps) { + // {/* Icon from Material Design Icons by Pictogrammers - https://github.com/Templarian/MaterialDesign/blob/master/LICENSE */} + return ( + + + + ) +} + export function BingLogo(props: SVGProps) { return ( ) { ) } +export function McpLogo(props: SVGProps) { + return ( + + ModelContextProtocol + + + + ) +} + export function PoeLogo(props: SVGProps) { return ( = ({ src, style, ...props }) => { const getContextMenuItems = (src: string, size: number = 14) => { return [ { - key: 'copy-url', + key: 'copy-image', label: t('common.copy'), icon: , + onClick: () => handleCopyImage(src) + }, + { + key: 'copy-url', + label: t('preview.copy.src'), + icon: , onClick: () => { navigator.clipboard.writeText(src) window.toast.success(t('message.copy.success')) @@ -86,12 +92,6 @@ const ImageViewer: React.FC = ({ src, style, ...props }) => { label: t('common.download'), icon: , onClick: () => download(src) - }, - { - key: 'copy-image', - label: t('preview.copy.image'), - icon: , - onClick: () => handleCopyImage(src) } ] } diff --git a/src/renderer/src/components/InputEmbeddingDimension.tsx b/src/renderer/src/components/InputEmbeddingDimension.tsx index 8e6357a91d..baa316324b 100644 --- a/src/renderer/src/components/InputEmbeddingDimension.tsx +++ b/src/renderer/src/components/InputEmbeddingDimension.tsx @@ -1,5 +1,4 @@ import { loggerService } from '@logger' -import AiProvider from '@renderer/aiCore' import { RefreshIcon } from '@renderer/components/Icons' import { useProvider } from '@renderer/hooks/useProvider' import type { Model } from '@renderer/types' @@ -8,6 +7,8 @@ import { Button, InputNumber, Space, Tooltip } from 'antd' import { memo, useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' +import AiProviderNew from '../aiCore/index_new' + const logger = loggerService.withContext('DimensionsInput') interface InputEmbeddingDimensionProps { @@ -47,7 +48,7 @@ const InputEmbeddingDimension = ({ setLoading(true) try { - const aiProvider = new AiProvider(provider) + const aiProvider = new AiProviderNew(provider) const dimension = await aiProvider.getEmbeddingDimensions(model) // for controlled input if (ref?.current) { diff --git a/src/renderer/src/components/MinApp/WebviewContainer.tsx b/src/renderer/src/components/MinApp/WebviewContainer.tsx index 66bb9e554d..dea8243e42 100644 --- a/src/renderer/src/components/MinApp/WebviewContainer.tsx +++ b/src/renderer/src/components/MinApp/WebviewContainer.tsx @@ -106,6 +106,51 @@ const WebviewContainer = memo( // eslint-disable-next-line react-hooks/exhaustive-deps }, [appid, url]) + // Setup keyboard shortcuts handler for print and save + useEffect(() => { + if (!webviewRef.current) return + + const unsubscribe = window.api?.webview?.onFindShortcut?.(async (payload) => { + // Get webviewId when event is triggered + const webviewId = webviewRef.current?.getWebContentsId() + + // Only handle events for this webview + if (!webviewId || payload.webviewId !== webviewId) return + + const key = payload.key?.toLowerCase() + const isModifier = payload.control || payload.meta + + if (!isModifier || !key) return + + try { + if (key === 'p') { + // Print to PDF + logger.info(`Printing webview ${appid} to PDF`) + const filePath = await window.api.webview.printToPDF(webviewId) + if (filePath) { + window.toast?.success?.(`PDF saved to: ${filePath}`) + logger.info(`PDF saved to: ${filePath}`) + } + } else if (key === 's') { + // Save as HTML + logger.info(`Saving webview ${appid} as HTML`) + const filePath = await window.api.webview.saveAsHTML(webviewId) + if (filePath) { + window.toast?.success?.(`HTML saved to: ${filePath}`) + logger.info(`HTML saved to: ${filePath}`) + } + } + } catch (error) { + logger.error(`Failed to handle shortcut for webview ${appid}:`, error as Error) + window.toast?.error?.(`Failed: ${(error as Error).message}`) + } + }) + + return () => { + unsubscribe?.() + } + }, [appid]) + // Update webview settings when they change useEffect(() => { if (!webviewRef.current) return diff --git a/src/renderer/src/components/Popups/ExportToPhoneLanPopup.tsx b/src/renderer/src/components/Popups/ExportToPhoneLanPopup.tsx deleted file mode 100644 index cbe51ac614..0000000000 --- a/src/renderer/src/components/Popups/ExportToPhoneLanPopup.tsx +++ /dev/null @@ -1,553 +0,0 @@ -import { loggerService } from '@logger' -import { AppLogo } from '@renderer/config/env' -import { SettingHelpText, SettingRow } from '@renderer/pages/settings' -import type { WebSocketCandidatesResponse } from '@shared/config/types' -import { Alert, Button, Modal, Progress, Spin } from 'antd' -import { QRCodeSVG } from 'qrcode.react' -import { useCallback, useEffect, useMemo, useState } from 'react' -import { useTranslation } from 'react-i18next' - -import { TopView } from '../TopView' - -const logger = loggerService.withContext('ExportToPhoneLanPopup') - -interface Props { - resolve: (data: any) => void -} - -type ConnectionPhase = 'initializing' | 'waiting_qr_scan' | 'connecting' | 'connected' | 'disconnected' | 'error' -type TransferPhase = 'idle' | 'preparing' | 'sending' | 'completed' | 'error' - -const LoadingQRCode: React.FC = () => { - const { t } = useTranslation() - return ( -
- - - {t('settings.data.export_to_phone.lan.generating_qr')} - -
- ) -} - -const ScanQRCode: React.FC<{ qrCodeValue: string }> = ({ qrCodeValue }) => { - const { t } = useTranslation() - return ( -
- - - {t('settings.data.export_to_phone.lan.scan_qr')} - -
- ) -} - -const ConnectingAnimation: React.FC = () => { - const { t } = useTranslation() - return ( -
-
- - - {t('settings.data.export_to_phone.lan.status.connecting')} - -
-
- ) -} - -const ConnectedDisplay: React.FC = () => { - const { t } = useTranslation() - return ( -
-
- 📱 - - {t('settings.data.export_to_phone.lan.connected')} - -
-
- ) -} - -const ErrorQRCode: React.FC<{ error: string | null }> = ({ error }) => { - const { t } = useTranslation() - return ( -
- ⚠️ - - {t('settings.data.export_to_phone.lan.connection_failed')} - - {error && {error}} -
- ) -} - -const PopupContainer: React.FC = ({ resolve }) => { - const [isOpen, setIsOpen] = useState(true) - const [connectionPhase, setConnectionPhase] = useState('initializing') - const [transferPhase, setTransferPhase] = useState('idle') - const [qrCodeValue, setQrCodeValue] = useState('') - const [selectedFolderPath, setSelectedFolderPath] = useState(null) - const [sendProgress, setSendProgress] = useState(0) - const [error, setError] = useState(null) - const [autoCloseCountdown, setAutoCloseCountdown] = useState(null) - - const { t } = useTranslation() - - // 派生状态 - const isConnected = connectionPhase === 'connected' - const canSend = isConnected && selectedFolderPath && transferPhase === 'idle' - const isSending = transferPhase === 'preparing' || transferPhase === 'sending' - - // 状态文本映射 - const connectionStatusText = useMemo(() => { - const statusMap = { - initializing: t('settings.data.export_to_phone.lan.status.initializing'), - waiting_qr_scan: t('settings.data.export_to_phone.lan.status.waiting_qr_scan'), - connecting: t('settings.data.export_to_phone.lan.status.connecting'), - connected: t('settings.data.export_to_phone.lan.status.connected'), - disconnected: t('settings.data.export_to_phone.lan.status.disconnected'), - error: t('settings.data.export_to_phone.lan.status.error') - } - return statusMap[connectionPhase] - }, [connectionPhase, t]) - - const transferStatusText = useMemo(() => { - const statusMap = { - idle: '', - preparing: t('settings.data.export_to_phone.lan.status.preparing'), - sending: t('settings.data.export_to_phone.lan.status.sending'), - completed: t('settings.data.export_to_phone.lan.status.completed'), - error: t('settings.data.export_to_phone.lan.status.error') - } - return statusMap[transferPhase] - }, [transferPhase, t]) - - // 状态样式映射 - const connectionStatusStyles = useMemo(() => { - const styleMap = { - initializing: { - bg: 'var(--color-background-mute)', - border: 'var(--color-border-mute)' - }, - waiting_qr_scan: { - bg: 'var(--color-primary-mute)', - border: 'var(--color-primary-soft)' - }, - connecting: { bg: 'var(--color-status-warning)', border: 'var(--color-status-warning)' }, - connected: { - bg: 'var(--color-status-success)', - border: 'var(--color-status-success)' - }, - disconnected: { bg: 'var(--color-error)', border: 'var(--color-error)' }, - error: { bg: 'var(--color-error)', border: 'var(--color-error)' } - } - return styleMap[connectionPhase] - }, [connectionPhase]) - - const initWebSocket = useCallback(async () => { - try { - setConnectionPhase('initializing') - await window.api.webSocket.start() - const { port, ip } = await window.api.webSocket.status() - - if (ip && port) { - const candidatesData = await window.api.webSocket.getAllCandidates() - - const optimizeConnectionInfo = () => { - const ipToNumber = (ip: string) => { - return ip.split('.').reduce((acc, octet) => (acc << 8) + parseInt(octet), 0) - } - - const compressedData = [ - 'CSA', - ipToNumber(ip), - candidatesData.map((candidate: WebSocketCandidatesResponse) => ipToNumber(candidate.host)), - port, // 端口号 - Date.now() % 86400000 - ] - - return compressedData - } - - const compressedData = optimizeConnectionInfo() - const qrCodeValue = JSON.stringify(compressedData) - setQrCodeValue(qrCodeValue) - setConnectionPhase('waiting_qr_scan') - } else { - setError(t('settings.data.export_to_phone.lan.error.no_ip')) - setConnectionPhase('error') - } - } catch (error) { - setError( - `${t('settings.data.export_to_phone.lan.error.init_failed')}: ${error instanceof Error ? error.message : ''}` - ) - setConnectionPhase('error') - logger.error('Failed to initialize WebSocket:', error as Error) - } - }, [t]) - - const handleClientConnected = useCallback((_event: any, data: { connected: boolean }) => { - logger.info(`Client connection status: ${data.connected ? 'connected' : 'disconnected'}`) - if (data.connected) { - setConnectionPhase('connected') - setError(null) - } else { - setConnectionPhase('disconnected') - } - }, []) - - const handleMessageReceived = useCallback((_event: any, data: any) => { - logger.info(`Received message from mobile: ${JSON.stringify(data)}`) - }, []) - - const handleSendProgress = useCallback( - (_event: any, data: { progress: number }) => { - const progress = data.progress - setSendProgress(progress) - - if (transferPhase === 'preparing' && progress > 0) { - setTransferPhase('sending') - } - - if (progress >= 100) { - setTransferPhase('completed') - // 启动 3 秒倒计时自动关闭 - setAutoCloseCountdown(3) - } - }, - [transferPhase] - ) - - const handleSelectZip = useCallback(async () => { - const result = await window.api.file.select() - if (result) { - setSelectedFolderPath(result[0].path) - } - }, []) - - const handleSendZip = useCallback(async () => { - if (!selectedFolderPath) { - setError(t('settings.data.export_to_phone.lan.error.no_file')) - return - } - - setTransferPhase('preparing') - setError(null) - setSendProgress(0) - - try { - logger.info(`Starting file transfer: ${selectedFolderPath}`) - await window.api.webSocket.sendFile(selectedFolderPath) - } catch (error) { - setError( - `${t('settings.data.export_to_phone.lan.error.send_failed')}: ${error instanceof Error ? error.message : ''}` - ) - setTransferPhase('error') - logger.error('Failed to send file:', error as Error) - } - }, [selectedFolderPath, t]) - - // 尝试关闭弹窗 - 如果正在传输则显示确认 - const handleCancel = useCallback(() => { - if (isSending) { - window.modal.confirm({ - title: t('settings.data.export_to_phone.lan.confirm_close_title'), - content: t('settings.data.export_to_phone.lan.confirm_close_message'), - centered: true, - okButtonProps: { - danger: true - }, - okText: t('settings.data.export_to_phone.lan.force_close'), - onOk: () => setIsOpen(false) - }) - } else { - setIsOpen(false) - } - }, [isSending, t]) - - // 清理并关闭 - const handleClose = useCallback(async () => { - try { - // 主动断开 WebSocket 连接 - if (isConnected || connectionPhase !== 'disconnected') { - logger.info('Closing popup, stopping WebSocket') - await window.api.webSocket.stop() - } - } catch (error) { - logger.error('Failed to stop WebSocket on close:', error as Error) - } - resolve({}) - }, [resolve, isConnected, connectionPhase]) - - useEffect(() => { - initWebSocket() - - const removeClientConnectedListener = window.electron.ipcRenderer.on( - 'websocket-client-connected', - handleClientConnected - ) - const removeMessageReceivedListener = window.electron.ipcRenderer.on( - 'websocket-message-received', - handleMessageReceived - ) - const removeSendProgressListener = window.electron.ipcRenderer.on('file-send-progress', handleSendProgress) - - return () => { - removeClientConnectedListener() - removeMessageReceivedListener() - removeSendProgressListener() - window.api.webSocket.stop() - } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) - - // 自动关闭倒计时 - useEffect(() => { - if (autoCloseCountdown === null) return - - if (autoCloseCountdown <= 0) { - logger.debug('Auto-closing popup after transfer completion') - setIsOpen(false) - return - } - - const timer = setTimeout(() => { - setAutoCloseCountdown(autoCloseCountdown - 1) - }, 1000) - - return () => clearTimeout(timer) - }, [autoCloseCountdown]) - - // 状态指示器组件 - const StatusIndicator = useCallback( - () => ( -
- {connectionStatusText} -
- ), - [connectionStatusStyles, connectionStatusText] - ) - - // 二维码显示组件 - 使用显式条件渲染以避免类型不匹配 - const QRCodeDisplay = useCallback(() => { - switch (connectionPhase) { - case 'waiting_qr_scan': - case 'disconnected': - return - case 'initializing': - return - case 'connecting': - return - case 'connected': - return - case 'error': - return - default: - return null - } - }, [connectionPhase, qrCodeValue, error]) - - // 传输进度组件 - const TransferProgress = useCallback(() => { - if (!isSending && transferPhase !== 'completed') return null - - return ( -
-
-
- - {t('settings.data.export_to_phone.lan.transfer_progress')} - - - {transferPhase === 'completed' ? '✅ ' + t('common.completed') : `${Math.round(sendProgress)}%`} - -
- - -
-
- ) - }, [isSending, transferPhase, sendProgress, t]) - - const AutoCloseCountdown = useCallback(() => { - if (transferPhase !== 'completed' || autoCloseCountdown === null || autoCloseCountdown <= 0) return null - - return ( -
- {t('settings.data.export_to_phone.lan.auto_close_tip', { seconds: autoCloseCountdown })} -
- ) - }, [transferPhase, autoCloseCountdown, t]) - - // 错误显示组件 - const ErrorDisplay = useCallback(() => { - if (!error || transferPhase !== 'error') return null - - return ( -
- ❌ {error} -
- ) - }, [error, transferPhase]) - - return ( - - - - - - - - - - - - -
- - -
-
- - - {selectedFolderPath || t('settings.data.export_to_phone.lan.noZipSelected')} - - - - - -
- ) -} - -const TopViewKey = 'ExportToPhoneLanPopup' - -export default class ExportToPhoneLanPopup { - static topviewId = 0 - static hide() { - TopView.hide(TopViewKey) - } - static show() { - return new Promise((resolve) => { - TopView.show( - { - resolve(v) - TopView.hide(TopViewKey) - }} - />, - TopViewKey - ) - }) - } -} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/LanDeviceCard.tsx b/src/renderer/src/components/Popups/LanTransferPopup/LanDeviceCard.tsx new file mode 100644 index 0000000000..db16112e04 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/LanDeviceCard.tsx @@ -0,0 +1,97 @@ +import { cn } from '@renderer/utils' +import type { FC, KeyboardEventHandler } from 'react' +import { useTranslation } from 'react-i18next' + +import { ProgressIndicator } from './ProgressIndicator' +import type { LanDeviceCardProps } from './types' + +export const LanDeviceCard: FC = ({ + service, + transferState, + isConnected, + handshakeInProgress, + isDisabled, + onSendFile +}) => { + const { t } = useTranslation() + + // Device info + const deviceName = service.txt?.modelName || t('common.unknown') + const platform = service.txt?.platform + const appVersion = service.txt?.appVersion + const platformInfo = [platform, appVersion].filter(Boolean).join(' ') + const displayTitle = platformInfo ? `${deviceName} (${platformInfo})` : deviceName + + // Address info + const primaryAddress = service.addresses?.[0] + const addressesWithPort = primaryAddress ? (service.port ? `${primaryAddress}:${service.port}` : primaryAddress) : '' + + // Progress visibility + const shouldShowProgress = + transferState && ['selecting', 'transferring', 'completed', 'failed'].includes(transferState.status) + + // Status text + const statusText = handshakeInProgress + ? t('settings.data.export_to_phone.lan.handshake.in_progress') + : isConnected + ? t('settings.data.export_to_phone.lan.connected') + : t('settings.data.export_to_phone.lan.send_file') + + // Event handlers + const handleClick = () => { + if (isDisabled) return + onSendFile(service.id) + } + + const handleKeyDown: KeyboardEventHandler = (event) => { + if (event.key === 'Enter' || event.key === ' ') { + event.preventDefault() + handleClick() + } + } + + return ( +
+ {/* Header */} +
+
+
{displayTitle}
+ {statusText} +
+
+ + {/* Meta Row - IP Address */} +
+ + {t('settings.data.export_to_phone.lan.ip_addresses')} + + {addressesWithPort || t('common.unknown')} +
+ + {/* Footer with Progress */} +
+ {shouldShowProgress && transferState && ( + + )} +
+
+ ) +} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/ProgressIndicator.tsx b/src/renderer/src/components/Popups/LanTransferPopup/ProgressIndicator.tsx new file mode 100644 index 0000000000..b9707b4485 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/ProgressIndicator.tsx @@ -0,0 +1,55 @@ +import { cn } from '@renderer/utils' +import type { FC } from 'react' +import { useTranslation } from 'react-i18next' + +import type { ProgressIndicatorProps } from './types' + +export const ProgressIndicator: FC = ({ transferState, handshakeInProgress }) => { + const { t } = useTranslation() + + const progressPercent = Math.min(100, Math.max(0, transferState.progress ?? 0)) + + const progressLabel = (() => { + if (transferState.status === 'failed') { + return transferState.error || t('common.unknown_error') + } + if (transferState.status === 'selecting') { + return handshakeInProgress + ? t('settings.data.export_to_phone.lan.handshake.in_progress') + : t('settings.data.export_to_phone.lan.status.preparing') + } + return `${Math.round(progressPercent)}%` + })() + + const isFailed = transferState.status === 'failed' + const isCompleted = transferState.status === 'completed' + + return ( +
+ {/* Label Row */} +
+ {transferState.fileName} + {progressLabel} +
+ + {/* Progress Track */} +
+
+
+
+ ) +} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/hook.ts b/src/renderer/src/components/Popups/LanTransferPopup/hook.ts new file mode 100644 index 0000000000..6d2ea77527 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/hook.ts @@ -0,0 +1,397 @@ +import { loggerService } from '@logger' +import { getBackupData } from '@renderer/services/BackupService' +import type { LocalTransferPeer } from '@shared/config/types' +import { useCallback, useEffect, useMemo, useReducer, useRef } from 'react' +import { useTranslation } from 'react-i18next' + +import type { LanPeerTransferState, LanTransferAction, LanTransferReducerState } from './types' + +const logger = loggerService.withContext('useLanTransfer') + +// ========================================== +// Initial State +// ========================================== + +export const initialState: LanTransferReducerState = { + open: true, + lanState: null, + lanHandshakePeerId: null, + lastHandshakeResult: null, + fileTransferState: {}, + tempBackupPath: null +} + +// ========================================== +// Reducer +// ========================================== + +export function lanTransferReducer(state: LanTransferReducerState, action: LanTransferAction): LanTransferReducerState { + switch (action.type) { + case 'SET_OPEN': + return { ...state, open: action.payload } + + case 'SET_LAN_STATE': + return { ...state, lanState: action.payload } + + case 'SET_HANDSHAKE_PEER_ID': + return { ...state, lanHandshakePeerId: action.payload } + + case 'SET_HANDSHAKE_RESULT': + return { ...state, lastHandshakeResult: action.payload } + + case 'SET_TEMP_BACKUP_PATH': + return { ...state, tempBackupPath: action.payload } + + case 'UPDATE_TRANSFER_STATE': { + const { peerId, state: transferState } = action.payload + return { + ...state, + fileTransferState: { + ...state.fileTransferState, + [peerId]: { + ...(state.fileTransferState[peerId] ?? { progress: 0, status: 'idle' as const }), + ...transferState + } + } + } + } + + case 'SET_TRANSFER_STATE': { + const { peerId, state: transferState } = action.payload + return { + ...state, + fileTransferState: { + ...state.fileTransferState, + [peerId]: transferState + } + } + } + + case 'CLEANUP_STALE_PEERS': { + const activeIds = action.payload + const newFileTransferState: Record = {} + for (const id of Object.keys(state.fileTransferState)) { + if (activeIds.has(id)) { + newFileTransferState[id] = state.fileTransferState[id] + } + } + return { + ...state, + fileTransferState: newFileTransferState, + lastHandshakeResult: + state.lastHandshakeResult && activeIds.has(state.lastHandshakeResult.peerId) + ? state.lastHandshakeResult + : null, + lanHandshakePeerId: + state.lanHandshakePeerId && activeIds.has(state.lanHandshakePeerId) ? state.lanHandshakePeerId : null + } + } + + case 'RESET_CONNECTION_STATE': + return { + ...state, + fileTransferState: {}, + lastHandshakeResult: null, + lanHandshakePeerId: null, + tempBackupPath: null + } + + default: + return state + } +} + +// ========================================== +// Hook Return Type +// ========================================== + +export interface UseLanTransferReturn { + // State + state: LanTransferReducerState + + // Derived values + lanDevices: LocalTransferPeer[] + isAnyTransferring: boolean + lastError: string | undefined + + // Actions + handleSendFile: (peerId: string) => Promise + handleModalCancel: () => void + getTransferState: (peerId: string) => LanPeerTransferState | undefined + isConnected: (peerId: string) => boolean + isHandshakeInProgress: (peerId: string) => boolean + + // Dispatch (for advanced use) + dispatch: React.Dispatch +} + +// ========================================== +// Hook +// ========================================== + +export function useLanTransfer(): UseLanTransferReturn { + const { t } = useTranslation() + const [state, dispatch] = useReducer(lanTransferReducer, initialState) + const isSendingRef = useRef(false) + + // ========================================== + // Derived Values + // ========================================== + + const lanDevices = useMemo(() => state.lanState?.services ?? [], [state.lanState]) + + const isAnyTransferring = useMemo( + () => Object.values(state.fileTransferState).some((s) => s.status === 'transferring' || s.status === 'selecting'), + [state.fileTransferState] + ) + + const lastError = state.lanState?.lastError + + // ========================================== + // LAN State Sync + // ========================================== + + const syncLanState = useCallback(async () => { + if (!window.api?.localTransfer) { + logger.warn('Local transfer bridge is unavailable') + return + } + try { + const nextState = await window.api.localTransfer.getState() + dispatch({ type: 'SET_LAN_STATE', payload: nextState }) + } catch (error) { + logger.error('Failed to sync LAN state', error as Error) + } + }, []) + + // ========================================== + // Send File Handler + // ========================================== + + const handleSendFile = useCallback( + async (peerId: string) => { + if (!window.api?.localTransfer || isSendingRef.current) { + return + } + isSendingRef.current = true + + dispatch({ + type: 'SET_TRANSFER_STATE', + payload: { peerId, state: { progress: 0, status: 'selecting' } } + }) + + let backupPath: string | null = null + + try { + // Step 0: Ensure handshake (connect if needed) + if (!state.lastHandshakeResult?.ack.accepted || state.lastHandshakeResult.peerId !== peerId) { + dispatch({ type: 'SET_HANDSHAKE_PEER_ID', payload: peerId }) + try { + const ack = await window.api.localTransfer.connect({ peerId }) + dispatch({ + type: 'SET_HANDSHAKE_RESULT', + payload: { peerId, ack, timestamp: Date.now() } + }) + if (!ack.accepted) { + throw new Error(ack.message || t('settings.data.export_to_phone.lan.connection_failed')) + } + } finally { + dispatch({ type: 'SET_HANDSHAKE_PEER_ID', payload: null }) + } + } + + // Step 1: Create temporary backup + logger.info('Creating temporary backup for LAN transfer...') + const backupData = await getBackupData() + backupPath = await window.api.backup.createLanTransferBackup(backupData) + dispatch({ type: 'SET_TEMP_BACKUP_PATH', payload: backupPath }) + + // Extract filename from path + const fileName = backupPath.split(/[/\\]/).pop() || 'backup.zip' + + // Step 2: Set transferring state + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { peerId, state: { fileName, progress: 0, status: 'transferring' } } + }) + + // Step 3: Send file + logger.info(`Sending backup file: ${backupPath}`) + const result = await window.api.localTransfer.sendFile(backupPath) + + if (result.success) { + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { peerId, state: { progress: 100, status: 'completed' } } + }) + } else { + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { peerId, state: { status: 'failed', error: result.error } } + }) + } + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { peerId, state: { status: 'failed', error: message } } + }) + logger.error('Failed to send file', error as Error) + } finally { + // Step 4: Clean up temp file + if (backupPath) { + try { + await window.api.backup.deleteTempBackup(backupPath) + logger.info('Cleaned up temporary backup file') + } catch (cleanupError) { + logger.warn('Failed to clean up temp backup', cleanupError as Error) + } + dispatch({ type: 'SET_TEMP_BACKUP_PATH', payload: null }) + } + isSendingRef.current = false + } + }, + [state.lastHandshakeResult, t] + ) + + // ========================================== + // Teardown + // ========================================== + + // Use ref to track temp backup path for cleanup without causing effect re-runs + const tempBackupPathRef = useRef(null) + tempBackupPathRef.current = state.tempBackupPath + + const teardownLan = useCallback(async () => { + if (!window.api?.localTransfer) { + return + } + try { + await window.api.localTransfer.cancelTransfer?.() + } catch (error) { + logger.warn('Failed to cancel LAN transfer on close', error as Error) + } + try { + await window.api.localTransfer.disconnect?.() + } catch (error) { + logger.warn('Failed to disconnect LAN on close', error as Error) + } + // Clean up temp backup if exists (use ref to get current value) + if (tempBackupPathRef.current) { + try { + await window.api.backup.deleteTempBackup(tempBackupPathRef.current) + } catch (error) { + logger.warn('Failed to cleanup temp backup on close', error as Error) + } + } + dispatch({ type: 'RESET_CONNECTION_STATE' }) + }, []) // No dependencies - uses ref for current value + + const handleModalCancel = useCallback(() => { + void teardownLan() + dispatch({ type: 'SET_OPEN', payload: false }) + }, [teardownLan]) + + // ========================================== + // Effects + // ========================================== + + // Initial sync and service listener + useEffect(() => { + if (!window.api?.localTransfer) { + return + } + syncLanState() + const removeListener = window.api.localTransfer.onServicesUpdated((lanState) => { + dispatch({ type: 'SET_LAN_STATE', payload: lanState }) + }) + return () => { + removeListener?.() + } + }, [syncLanState]) + + // Client events listener (progress, completion) + useEffect(() => { + if (!window.api?.localTransfer) { + return + } + const removeListener = window.api.localTransfer.onClientEvent((event) => { + const key = event.peerId ?? 'global' + + if (event.type === 'file_transfer_progress') { + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { + peerId: key, + state: { + transferId: event.transferId, + fileName: event.fileName, + progress: event.progress, + speed: event.speed, + status: 'transferring' + } + } + }) + } else if (event.type === 'file_transfer_complete') { + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { + peerId: key, + state: { + progress: event.success ? 100 : undefined, + status: event.success ? 'completed' : 'failed', + error: event.error + } + } + }) + } + }) + return () => { + removeListener?.() + } + }, []) + + // Cleanup stale peers when services change + useEffect(() => { + const activeIds = new Set(lanDevices.map((s) => s.id)) + dispatch({ type: 'CLEANUP_STALE_PEERS', payload: activeIds }) + }, [lanDevices]) + + // Cleanup on unmount only (teardownLan is stable with no deps) + useEffect(() => { + return () => { + void teardownLan() + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) + + // ========================================== + // Helper Functions + // ========================================== + + const getTransferState = useCallback((peerId: string) => state.fileTransferState[peerId], [state.fileTransferState]) + + const isConnected = useCallback( + (peerId: string) => + state.lastHandshakeResult?.peerId === peerId && state.lastHandshakeResult?.ack.accepted === true, + [state.lastHandshakeResult] + ) + + const isHandshakeInProgress = useCallback( + (peerId: string) => state.lanHandshakePeerId === peerId, + [state.lanHandshakePeerId] + ) + + return { + state, + lanDevices, + isAnyTransferring, + lastError, + handleSendFile, + handleModalCancel, + getTransferState, + isConnected, + isHandshakeInProgress, + dispatch + } +} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/index.tsx b/src/renderer/src/components/Popups/LanTransferPopup/index.tsx new file mode 100644 index 0000000000..66455f12a1 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/index.tsx @@ -0,0 +1,37 @@ +import { TopView } from '@renderer/components/TopView' + +import { getHideCallback, PopupContainer } from './popup' +import type { PopupResolveData } from './types' + +// Re-export types for external use +export type { LanPeerTransferState } from './types' + +const TopViewKey = 'LanTransferPopup' + +export default class LanTransferPopup { + static topviewId = 0 + + static hide() { + // Try to use the registered callback for proper cleanup, fallback to TopView.hide + const callback = getHideCallback() + if (callback) { + callback() + } else { + TopView.hide(TopViewKey) + } + } + + static show() { + return new Promise((resolve) => { + TopView.show( + { + resolve(v) + TopView.hide(TopViewKey) + }} + />, + TopViewKey + ) + }) + } +} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/popup.tsx b/src/renderer/src/components/Popups/LanTransferPopup/popup.tsx new file mode 100644 index 0000000000..34c53a6ad6 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/popup.tsx @@ -0,0 +1,88 @@ +import { Modal } from 'antd' +import { TriangleAlert } from 'lucide-react' +import type { FC } from 'react' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' + +import { useLanTransfer } from './hook' +import { LanDeviceCard } from './LanDeviceCard' +import type { PopupContainerProps } from './types' + +// Module-level callback for external hide access +let hideCallback: (() => void) | null = null +export const setHideCallback = (cb: () => void) => { + hideCallback = cb +} +export const getHideCallback = () => hideCallback + +export const PopupContainer: FC = ({ resolve }) => { + const { t } = useTranslation() + + const { + state, + lanDevices, + isAnyTransferring, + lastError, + handleSendFile, + handleModalCancel, + getTransferState, + isConnected, + isHandshakeInProgress + } = useLanTransfer() + + const contentTitle = useMemo(() => t('settings.data.export_to_phone.lan.title'), [t]) + + const onClose = () => resolve({}) + + // Register hide callback for external access + setHideCallback(handleModalCancel) + + return ( + +
+ {/* Error Display */} + {lastError &&
{lastError}
} + + {/* Device List */} +
+ {lanDevices.length === 0 ? ( + // Warning when no devices +
+ + + {t('settings.data.export_to_phone.lan.no_connection_warning')} + +
+ ) : ( + // Device cards + lanDevices.map((service) => { + const transferState = getTransferState(service.id) + const connected = isConnected(service.id) + const handshakeInProgress = isHandshakeInProgress(service.id) + const isCardDisabled = isAnyTransferring || handshakeInProgress + + return ( + + ) + }) + )} +
+
+
+ ) +} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/types.ts b/src/renderer/src/components/Popups/LanTransferPopup/types.ts new file mode 100644 index 0000000000..644541bc27 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/types.ts @@ -0,0 +1,84 @@ +import type { LanHandshakeAckMessage, LocalTransferPeer, LocalTransferState } from '@shared/config/types' + +// ========================================== +// Transfer Status +// ========================================== + +export type TransferStatus = 'idle' | 'selecting' | 'transferring' | 'completed' | 'failed' + +// ========================================== +// Per-Peer Transfer State +// ========================================== + +export interface LanPeerTransferState { + transferId?: string + fileName?: string + progress: number + speed?: number + status: TransferStatus + error?: string +} + +// ========================================== +// Handshake Result +// ========================================== + +export type HandshakeResult = { + peerId: string + ack: LanHandshakeAckMessage + timestamp: number +} | null + +// ========================================== +// Reducer State +// ========================================== + +export interface LanTransferReducerState { + open: boolean + lanState: LocalTransferState | null + lanHandshakePeerId: string | null + lastHandshakeResult: HandshakeResult + fileTransferState: Record + tempBackupPath: string | null +} + +// ========================================== +// Reducer Actions +// ========================================== + +export type LanTransferAction = + | { type: 'SET_OPEN'; payload: boolean } + | { type: 'SET_LAN_STATE'; payload: LocalTransferState | null } + | { type: 'SET_HANDSHAKE_PEER_ID'; payload: string | null } + | { type: 'SET_HANDSHAKE_RESULT'; payload: HandshakeResult } + | { type: 'SET_TEMP_BACKUP_PATH'; payload: string | null } + | { type: 'UPDATE_TRANSFER_STATE'; payload: { peerId: string; state: Partial } } + | { type: 'SET_TRANSFER_STATE'; payload: { peerId: string; state: LanPeerTransferState } } + | { type: 'CLEANUP_STALE_PEERS'; payload: Set } + | { type: 'RESET_CONNECTION_STATE' } + +// ========================================== +// Component Props +// ========================================== + +export interface LanDeviceCardProps { + service: LocalTransferPeer + transferState?: LanPeerTransferState + isConnected: boolean + handshakeInProgress: boolean + isDisabled: boolean + onSendFile: (peerId: string) => void +} + +export interface ProgressIndicatorProps { + transferState: LanPeerTransferState + handshakeInProgress: boolean +} + +export interface PopupResolveData { + // Empty for now, can be extended +} + +export interface PopupContainerProps { + resolve: (data: PopupResolveData) => void +} diff --git a/src/renderer/src/components/Popups/UpdateDialogPopup.tsx b/src/renderer/src/components/Popups/UpdateDialogPopup.tsx index 29afcc0d24..593c882bf5 100644 --- a/src/renderer/src/components/Popups/UpdateDialogPopup.tsx +++ b/src/renderer/src/components/Popups/UpdateDialogPopup.tsx @@ -1,6 +1,7 @@ import { loggerService } from '@logger' import { TopView } from '@renderer/components/TopView' -import { handleSaveData } from '@renderer/store' +import { handleSaveData, useAppDispatch } from '@renderer/store' +import { setUpdateState } from '@renderer/store/runtime' import { Button, Modal } from 'antd' import type { ReleaseNoteInfo, UpdateInfo } from 'builder-util-runtime' import { useEffect, useState } from 'react' @@ -22,6 +23,7 @@ const PopupContainer: React.FC = ({ releaseInfo, resolve }) => { const { t } = useTranslation() const [open, setOpen] = useState(true) const [isInstalling, setIsInstalling] = useState(false) + const dispatch = useAppDispatch() useEffect(() => { if (releaseInfo) { @@ -50,6 +52,11 @@ const PopupContainer: React.FC = ({ releaseInfo, resolve }) => { resolve({}) } + const onIgnore = () => { + dispatch(setUpdateState({ ignore: true })) + setOpen(false) + } + UpdateDialogPopup.hide = onCancel const releaseNotes = releaseInfo?.releaseNotes @@ -69,7 +76,7 @@ const PopupContainer: React.FC = ({ releaseInfo, resolve }) => { centered width={720} footer={[ - , -
- } - type="error" - showIcon - style={{ marginBottom: 16 }} - /> - )} - +
+ + +
= ({ agent, afterSubmit, resolve }) => { />
+ {isWin && ( + +
+ + +
+ + + + {gitBashPathInfo.source === 'manual' && ( + + )} + + {gitBashPathInfo.path && gitBashPathInfo.source === 'auto' && ( + {t('agent.gitBash.autoDiscoveredHint', 'Auto-discovered')} + )} +
+ )} +