diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index b178b306bf..b1b052f90c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,6 +9,8 @@ /src/main/data/ @0xfullex /src/renderer/src/data/ @0xfullex /v2-refactor-temp/ @0xfullex +/docs/en/references/data/ @0xfullex +/docs/zh/references/data/ @0xfullex /packages/ui/ @MyPrototypeWhat diff --git a/.github/workflows/auto-i18n.yml b/.github/workflows/auto-i18n.yml index 6141c061fa..7537c4d4a3 100644 --- a/.github/workflows/auto-i18n.yml +++ b/.github/workflows/auto-i18n.yml @@ -54,7 +54,7 @@ jobs: yarn install - name: 🏃‍♀️ Translate - run: yarn sync:i18n && yarn auto:i18n + run: yarn i18n:sync && yarn i18n:translate - name: 🔍 Format run: yarn format @@ -73,7 +73,7 @@ jobs: - name: 🚀 Create Pull Request if changes exist if: steps.git_status.outputs.has_changes == 'true' - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v8 with: token: ${{ secrets.GITHUB_TOKEN }} # Use the built-in GITHUB_TOKEN for bot actions commit-message: "feat(bot): Weekly automated script run" diff --git a/.github/workflows/pr-ci.yml b/.github/workflows/pr-ci.yml index 49a37c9f1e..d0d29ba777 100644 --- a/.github/workflows/pr-ci.yml +++ b/.github/workflows/pr-ci.yml @@ -58,7 +58,7 @@ jobs: run: yarn typecheck - name: i18n Check - run: yarn check:i18n + run: yarn i18n:check - name: Test run: yarn test diff --git a/.github/workflows/sync-to-gitcode.yml b/.github/workflows/sync-to-gitcode.yml index 4462ff6375..53ecae445b 100644 --- a/.github/workflows/sync-to-gitcode.yml +++ b/.github/workflows/sync-to-gitcode.yml @@ -216,6 +216,7 @@ jobs: local filename=$(basename "$file") local max_retries=3 local retry=0 + local curl_status=0 echo "Uploading: $filename" @@ -224,34 +225,45 @@ jobs: while [ $retry -lt $max_retries ]; do # Get upload URL + curl_status=0 UPLOAD_INFO=$(curl -s --connect-timeout 30 --max-time 60 \ -H "Authorization: Bearer ${GITCODE_TOKEN}" \ - "${API_URL}/repos/${GITCODE_OWNER}/${GITCODE_REPO}/releases/${TAG_NAME}/upload_url?file_name=${encoded_filename}") + "${API_URL}/repos/${GITCODE_OWNER}/${GITCODE_REPO}/releases/${TAG_NAME}/upload_url?file_name=${encoded_filename}") || curl_status=$? - UPLOAD_URL=$(echo "$UPLOAD_INFO" | jq -r '.url // empty') + if [ $curl_status -eq 0 ]; then + UPLOAD_URL=$(echo "$UPLOAD_INFO" | jq -r '.url // empty') - if [ -n "$UPLOAD_URL" ]; then - # Write headers to temp file to avoid shell escaping issues - echo "$UPLOAD_INFO" | jq -r '.headers | to_entries[] | "header = \"" + .key + ": " + .value + "\""' > /tmp/upload_headers.txt + if [ -n "$UPLOAD_URL" ]; then + # Write headers to temp file to avoid shell escaping issues + echo "$UPLOAD_INFO" | jq -r '.headers | to_entries[] | "header = \"" + .key + ": " + .value + "\""' > /tmp/upload_headers.txt - # Upload file using PUT with headers from file - UPLOAD_RESPONSE=$(curl -s -w "\n%{http_code}" -X PUT \ - -K /tmp/upload_headers.txt \ - --data-binary "@${file}" \ - "$UPLOAD_URL") + # Upload file using PUT with headers from file + curl_status=0 + UPLOAD_RESPONSE=$(curl -s -w "\n%{http_code}" -X PUT \ + -K /tmp/upload_headers.txt \ + --data-binary "@${file}" \ + "$UPLOAD_URL") || curl_status=$? 
- HTTP_CODE=$(echo "$UPLOAD_RESPONSE" | tail -n1) - RESPONSE_BODY=$(echo "$UPLOAD_RESPONSE" | sed '$d') + if [ $curl_status -eq 0 ]; then + HTTP_CODE=$(echo "$UPLOAD_RESPONSE" | tail -n1) + RESPONSE_BODY=$(echo "$UPLOAD_RESPONSE" | sed '$d') - if [ "$HTTP_CODE" -ge 200 ] && [ "$HTTP_CODE" -lt 300 ]; then - echo " Uploaded: $filename" - return 0 + if [ "$HTTP_CODE" -ge 200 ] && [ "$HTTP_CODE" -lt 300 ]; then + echo " Uploaded: $filename" + return 0 + else + echo " Failed (HTTP $HTTP_CODE), retry $((retry + 1))/$max_retries" + echo " Response: $RESPONSE_BODY" + fi + else + echo " Upload request failed (curl exit $curl_status), retry $((retry + 1))/$max_retries" + fi else - echo " Failed (HTTP $HTTP_CODE), retry $((retry + 1))/$max_retries" - echo " Response: $RESPONSE_BODY" + echo " Failed to get upload URL, retry $((retry + 1))/$max_retries" + echo " Response: $UPLOAD_INFO" fi else - echo " Failed to get upload URL, retry $((retry + 1))/$max_retries" + echo " Failed to get upload URL (curl exit $curl_status), retry $((retry + 1))/$max_retries" echo " Response: $UPLOAD_INFO" fi diff --git a/.yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch b/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch similarity index 64% rename from .yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch rename to .yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch index 3015e702ed..67403a6575 100644 --- a/.yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch +++ b/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch @@ -1,5 +1,5 @@ diff --git a/dist/index.js b/dist/index.js -index 51ce7e423934fb717cb90245cdfcdb3dae6780e6..0f7f7009e2f41a79a8669d38c8a44867bbff5e1f 100644 +index d004b415c5841a1969705823614f395265ea5a8a..6b1e0dad4610b0424393ecc12e9114723bbe316b 100644 --- a/dist/index.js +++ b/dist/index.js @@ -474,7 +474,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) { @@ -12,7 +12,7 @@ index 51ce7e423934fb717cb90245cdfcdb3dae6780e6..0f7f7009e2f41a79a8669d38c8a44867 // src/google-generative-ai-options.ts diff --git a/dist/index.mjs b/dist/index.mjs -index f4b77e35c0cbfece85a3ef0d4f4e67aa6dde6271..8d2fecf8155a226006a0bde72b00b6036d4014b6 100644 +index 1780dd2391b7f42224a0b8048c723d2f81222c44..1f12ed14399d6902107ce9b435d7d8e6cc61e06b 100644 --- a/dist/index.mjs +++ b/dist/index.mjs @@ -480,7 +480,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) { @@ -24,3 +24,14 @@ index f4b77e35c0cbfece85a3ef0d4f4e67aa6dde6271..8d2fecf8155a226006a0bde72b00b603 } // src/google-generative-ai-options.ts +@@ -1909,8 +1909,7 @@ function createGoogleGenerativeAI(options = {}) { + } + var google = createGoogleGenerativeAI(); + export { +- VERSION, + createGoogleGenerativeAI, +- google ++ google, VERSION + }; + //# sourceMappingURL=index.mjs.map +\ No newline at end of file diff --git a/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch b/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch deleted file mode 100644 index 2a13c33a78..0000000000 --- a/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch +++ /dev/null @@ -1,140 +0,0 @@ -diff --git a/dist/index.js b/dist/index.js -index 73045a7d38faafdc7f7d2cd79d7ff0e2b031056b..8d948c9ac4ea4b474db9ef3c5491961e7fcf9a07 100644 ---- a/dist/index.js -+++ b/dist/index.js -@@ -421,6 +421,17 @@ var OpenAICompatibleChatLanguageModel = class { - text: reasoning - }); - } -+ if (choice.message.images) { -+ for (const image of choice.message.images) { -+ const match1 = 
image.image_url.url.match(/^data:([^;]+)/) -+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); -+ content.push({ -+ type: 'file', -+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', -+ data: match2 ? match2[1] : image.image_url.url, -+ }); -+ } -+ } - if (choice.message.tool_calls != null) { - for (const toolCall of choice.message.tool_calls) { - content.push({ -@@ -598,6 +609,17 @@ var OpenAICompatibleChatLanguageModel = class { - delta: delta.content - }); - } -+ if (delta.images) { -+ for (const image of delta.images) { -+ const match1 = image.image_url.url.match(/^data:([^;]+)/) -+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); -+ controller.enqueue({ -+ type: 'file', -+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', -+ data: match2 ? match2[1] : image.image_url.url, -+ }); -+ } -+ } - if (delta.tool_calls != null) { - for (const toolCallDelta of delta.tool_calls) { - const index = toolCallDelta.index; -@@ -765,6 +787,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({ - arguments: import_v43.z.string() - }) - }) -+ ).nullish(), -+ images: import_v43.z.array( -+ import_v43.z.object({ -+ type: import_v43.z.literal('image_url'), -+ image_url: import_v43.z.object({ -+ url: import_v43.z.string(), -+ }) -+ }) - ).nullish() - }), - finish_reason: import_v43.z.string().nullish() -@@ -795,6 +825,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union( - arguments: import_v43.z.string().nullish() - }) - }) -+ ).nullish(), -+ images: import_v43.z.array( -+ import_v43.z.object({ -+ type: import_v43.z.literal('image_url'), -+ image_url: import_v43.z.object({ -+ url: import_v43.z.string(), -+ }) -+ }) - ).nullish() - }).nullish(), - finish_reason: import_v43.z.string().nullish() -diff --git a/dist/index.mjs b/dist/index.mjs -index 1c2b9560bbfbfe10cb01af080aeeed4ff59db29c..2c8ddc4fc9bfc5e7e06cfca105d197a08864c427 100644 ---- a/dist/index.mjs -+++ b/dist/index.mjs -@@ -405,6 +405,17 @@ var OpenAICompatibleChatLanguageModel = class { - text: reasoning - }); - } -+ if (choice.message.images) { -+ for (const image of choice.message.images) { -+ const match1 = image.image_url.url.match(/^data:([^;]+)/) -+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); -+ content.push({ -+ type: 'file', -+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', -+ data: match2 ? match2[1] : image.image_url.url, -+ }); -+ } -+ } - if (choice.message.tool_calls != null) { - for (const toolCall of choice.message.tool_calls) { - content.push({ -@@ -582,6 +593,17 @@ var OpenAICompatibleChatLanguageModel = class { - delta: delta.content - }); - } -+ if (delta.images) { -+ for (const image of delta.images) { -+ const match1 = image.image_url.url.match(/^data:([^;]+)/) -+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); -+ controller.enqueue({ -+ type: 'file', -+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', -+ data: match2 ? 
match2[1] : image.image_url.url, -+ }); -+ } -+ } - if (delta.tool_calls != null) { - for (const toolCallDelta of delta.tool_calls) { - const index = toolCallDelta.index; -@@ -749,6 +771,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({ - arguments: z3.string() - }) - }) -+ ).nullish(), -+ images: z3.array( -+ z3.object({ -+ type: z3.literal('image_url'), -+ image_url: z3.object({ -+ url: z3.string(), -+ }) -+ }) - ).nullish() - }), - finish_reason: z3.string().nullish() -@@ -779,6 +809,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([ - arguments: z3.string().nullish() - }) - }) -+ ).nullish(), -+ images: z3.array( -+ z3.object({ -+ type: z3.literal('image_url'), -+ image_url: z3.object({ -+ url: z3.string(), -+ }) -+ }) - ).nullish() - }).nullish(), - finish_reason: z3.string().nullish() diff --git a/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch b/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch new file mode 100644 index 0000000000..c17729ef93 --- /dev/null +++ b/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch @@ -0,0 +1,266 @@ +diff --git a/dist/index.d.ts b/dist/index.d.ts +index 48e2f6263c6ee4c75d7e5c28733e64f6ebe92200..00d0729c4a3cbf9a48e8e1e962c7e2b256b75eba 100644 +--- a/dist/index.d.ts ++++ b/dist/index.d.ts +@@ -7,6 +7,7 @@ declare const openaiCompatibleProviderOptions: z.ZodObject<{ + user: z.ZodOptional; + reasoningEffort: z.ZodOptional; + textVerbosity: z.ZodOptional; ++ sendReasoning: z.ZodOptional; + }, z.core.$strip>; + type OpenAICompatibleProviderOptions = z.infer; + +diff --git a/dist/index.js b/dist/index.js +index da237bb35b7fa8e24b37cd861ee73dfc51cdfc72..b3060fbaf010e30b64df55302807828e5bfe0f9a 100644 +--- a/dist/index.js ++++ b/dist/index.js +@@ -41,7 +41,7 @@ function getOpenAIMetadata(message) { + var _a, _b; + return (_b = (_a = message == null ? void 0 : message.providerOptions) == null ? void 0 : _a.openaiCompatible) != null ? _b : {}; + } +-function convertToOpenAICompatibleChatMessages(prompt) { ++function convertToOpenAICompatibleChatMessages({prompt, options}) { + const messages = []; + for (const { role, content, ...message } of prompt) { + const metadata = getOpenAIMetadata({ ...message }); +@@ -91,6 +91,7 @@ function convertToOpenAICompatibleChatMessages(prompt) { + } + case "assistant": { + let text = ""; ++ let reasoning_text = ""; + const toolCalls = []; + for (const part of content) { + const partMetadata = getOpenAIMetadata(part); +@@ -99,6 +100,12 @@ function convertToOpenAICompatibleChatMessages(prompt) { + text += part.text; + break; + } ++ case "reasoning": { ++ if (options.sendReasoning) { ++ reasoning_text += part.text; ++ } ++ break; ++ } + case "tool-call": { + toolCalls.push({ + id: part.toolCallId, +@@ -116,6 +123,7 @@ function convertToOpenAICompatibleChatMessages(prompt) { + messages.push({ + role: "assistant", + content: text, ++ reasoning_content: reasoning_text ?? undefined, + tool_calls: toolCalls.length > 0 ? toolCalls : void 0, + ...metadata + }); +@@ -200,7 +208,8 @@ var openaiCompatibleProviderOptions = import_v4.z.object({ + /** + * Controls the verbosity of the generated text. Defaults to `medium`. 
+ */ +- textVerbosity: import_v4.z.string().optional() ++ textVerbosity: import_v4.z.string().optional(), ++ sendReasoning: import_v4.z.boolean().optional() + }); + + // src/openai-compatible-error.ts +@@ -378,7 +387,7 @@ var OpenAICompatibleChatLanguageModel = class { + reasoning_effort: compatibleOptions.reasoningEffort, + verbosity: compatibleOptions.textVerbosity, + // messages: +- messages: convertToOpenAICompatibleChatMessages(prompt), ++ messages: convertToOpenAICompatibleChatMessages({prompt, options: compatibleOptions}), + // tools: + tools: openaiTools, + tool_choice: openaiToolChoice +@@ -421,6 +430,17 @@ var OpenAICompatibleChatLanguageModel = class { + text: reasoning + }); + } ++ if (choice.message.images) { ++ for (const image of choice.message.images) { ++ const match1 = image.image_url.url.match(/^data:([^;]+)/) ++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); ++ content.push({ ++ type: 'file', ++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', ++ data: match2 ? match2[1] : image.image_url.url, ++ }); ++ } ++ } + if (choice.message.tool_calls != null) { + for (const toolCall of choice.message.tool_calls) { + content.push({ +@@ -598,6 +618,17 @@ var OpenAICompatibleChatLanguageModel = class { + delta: delta.content + }); + } ++ if (delta.images) { ++ for (const image of delta.images) { ++ const match1 = image.image_url.url.match(/^data:([^;]+)/) ++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); ++ controller.enqueue({ ++ type: 'file', ++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', ++ data: match2 ? match2[1] : image.image_url.url, ++ }); ++ } ++ } + if (delta.tool_calls != null) { + for (const toolCallDelta of delta.tool_calls) { + const index = toolCallDelta.index; +@@ -765,6 +796,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({ + arguments: import_v43.z.string() + }) + }) ++ ).nullish(), ++ images: import_v43.z.array( ++ import_v43.z.object({ ++ type: import_v43.z.literal('image_url'), ++ image_url: import_v43.z.object({ ++ url: import_v43.z.string(), ++ }) ++ }) + ).nullish() + }), + finish_reason: import_v43.z.string().nullish() +@@ -795,6 +834,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union( + arguments: import_v43.z.string().nullish() + }) + }) ++ ).nullish(), ++ images: import_v43.z.array( ++ import_v43.z.object({ ++ type: import_v43.z.literal('image_url'), ++ image_url: import_v43.z.object({ ++ url: import_v43.z.string(), ++ }) ++ }) + ).nullish() + }).nullish(), + finish_reason: import_v43.z.string().nullish() +diff --git a/dist/index.mjs b/dist/index.mjs +index a809a7aa0e148bfd43e01dd7b018568b151c8ad5..565b605eeacd9830b2b0e817e58ad0c5700264de 100644 +--- a/dist/index.mjs ++++ b/dist/index.mjs +@@ -23,7 +23,7 @@ function getOpenAIMetadata(message) { + var _a, _b; + return (_b = (_a = message == null ? void 0 : message.providerOptions) == null ? void 0 : _a.openaiCompatible) != null ? 
_b : {}; + } +-function convertToOpenAICompatibleChatMessages(prompt) { ++function convertToOpenAICompatibleChatMessages({prompt, options}) { + const messages = []; + for (const { role, content, ...message } of prompt) { + const metadata = getOpenAIMetadata({ ...message }); +@@ -73,6 +73,7 @@ function convertToOpenAICompatibleChatMessages(prompt) { + } + case "assistant": { + let text = ""; ++ let reasoning_text = ""; + const toolCalls = []; + for (const part of content) { + const partMetadata = getOpenAIMetadata(part); +@@ -81,6 +82,12 @@ function convertToOpenAICompatibleChatMessages(prompt) { + text += part.text; + break; + } ++ case "reasoning": { ++ if (options.sendReasoning) { ++ reasoning_text += part.text; ++ } ++ break; ++ } + case "tool-call": { + toolCalls.push({ + id: part.toolCallId, +@@ -98,6 +105,7 @@ function convertToOpenAICompatibleChatMessages(prompt) { + messages.push({ + role: "assistant", + content: text, ++ reasoning_content: reasoning_text ?? undefined, + tool_calls: toolCalls.length > 0 ? toolCalls : void 0, + ...metadata + }); +@@ -182,7 +190,8 @@ var openaiCompatibleProviderOptions = z.object({ + /** + * Controls the verbosity of the generated text. Defaults to `medium`. + */ +- textVerbosity: z.string().optional() ++ textVerbosity: z.string().optional(), ++ sendReasoning: z.boolean().optional() + }); + + // src/openai-compatible-error.ts +@@ -362,7 +371,7 @@ var OpenAICompatibleChatLanguageModel = class { + reasoning_effort: compatibleOptions.reasoningEffort, + verbosity: compatibleOptions.textVerbosity, + // messages: +- messages: convertToOpenAICompatibleChatMessages(prompt), ++ messages: convertToOpenAICompatibleChatMessages({prompt, options: compatibleOptions}), + // tools: + tools: openaiTools, + tool_choice: openaiToolChoice +@@ -405,6 +414,17 @@ var OpenAICompatibleChatLanguageModel = class { + text: reasoning + }); + } ++ if (choice.message.images) { ++ for (const image of choice.message.images) { ++ const match1 = image.image_url.url.match(/^data:([^;]+)/) ++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); ++ content.push({ ++ type: 'file', ++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', ++ data: match2 ? match2[1] : image.image_url.url, ++ }); ++ } ++ } + if (choice.message.tool_calls != null) { + for (const toolCall of choice.message.tool_calls) { + content.push({ +@@ -582,6 +602,17 @@ var OpenAICompatibleChatLanguageModel = class { + delta: delta.content + }); + } ++ if (delta.images) { ++ for (const image of delta.images) { ++ const match1 = image.image_url.url.match(/^data:([^;]+)/) ++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/); ++ controller.enqueue({ ++ type: 'file', ++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg', ++ data: match2 ? 
match2[1] : image.image_url.url, ++ }); ++ } ++ } + if (delta.tool_calls != null) { + for (const toolCallDelta of delta.tool_calls) { + const index = toolCallDelta.index; +@@ -749,6 +780,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({ + arguments: z3.string() + }) + }) ++ ).nullish(), ++ images: z3.array( ++ z3.object({ ++ type: z3.literal('image_url'), ++ image_url: z3.object({ ++ url: z3.string(), ++ }) ++ }) + ).nullish() + }), + finish_reason: z3.string().nullish() +@@ -779,6 +818,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([ + arguments: z3.string().nullish() + }) + }) ++ ).nullish(), ++ images: z3.array( ++ z3.object({ ++ type: z3.literal('image_url'), ++ image_url: z3.object({ ++ url: z3.string(), ++ }) ++ }) + ).nullish() + }).nullish(), + finish_reason: z3.string().nullish() diff --git a/.yarn/patches/ollama-ai-provider-v2-npm-1.5.5-8bef249af9.patch b/.yarn/patches/ollama-ai-provider-v2-npm-1.5.5-8bef249af9.patch index ea14381539..c306bef6e5 100644 --- a/.yarn/patches/ollama-ai-provider-v2-npm-1.5.5-8bef249af9.patch +++ b/.yarn/patches/ollama-ai-provider-v2-npm-1.5.5-8bef249af9.patch @@ -7,7 +7,7 @@ index 8dd9b498050dbecd8dd6b901acf1aa8ca38a49af..ed644349c9d38fe2a66b2fb44214f7c1 type OllamaChatModelId = "athene-v2" | "athene-v2:72b" | "aya-expanse" | "aya-expanse:8b" | "aya-expanse:32b" | "codegemma" | "codegemma:2b" | "codegemma:7b" | "codellama" | "codellama:7b" | "codellama:13b" | "codellama:34b" | "codellama:70b" | "codellama:code" | "codellama:python" | "command-r" | "command-r:35b" | "command-r-plus" | "command-r-plus:104b" | "command-r7b" | "command-r7b:7b" | "deepseek-r1" | "deepseek-r1:1.5b" | "deepseek-r1:7b" | "deepseek-r1:8b" | "deepseek-r1:14b" | "deepseek-r1:32b" | "deepseek-r1:70b" | "deepseek-r1:671b" | "deepseek-coder-v2" | "deepseek-coder-v2:16b" | "deepseek-coder-v2:236b" | "deepseek-v3" | "deepseek-v3:671b" | "devstral" | "devstral:24b" | "dolphin3" | "dolphin3:8b" | "exaone3.5" | "exaone3.5:2.4b" | "exaone3.5:7.8b" | "exaone3.5:32b" | "falcon2" | "falcon2:11b" | "falcon3" | "falcon3:1b" | "falcon3:3b" | "falcon3:7b" | "falcon3:10b" | "firefunction-v2" | "firefunction-v2:70b" | "gemma" | "gemma:2b" | "gemma:7b" | "gemma2" | "gemma2:2b" | "gemma2:9b" | "gemma2:27b" | "gemma3" | "gemma3:1b" | "gemma3:4b" | "gemma3:12b" | "gemma3:27b" | "granite3-dense" | "granite3-dense:2b" | "granite3-dense:8b" | "granite3-guardian" | "granite3-guardian:2b" | "granite3-guardian:8b" | "granite3-moe" | "granite3-moe:1b" | "granite3-moe:3b" | "granite3.1-dense" | "granite3.1-dense:2b" | "granite3.1-dense:8b" | "granite3.1-moe" | "granite3.1-moe:1b" | "granite3.1-moe:3b" | "llama2" | "llama2:7b" | "llama2:13b" | "llama2:70b" | "llama3" | "llama3:8b" | "llama3:70b" | "llama3-chatqa" | "llama3-chatqa:8b" | "llama3-chatqa:70b" | "llama3-gradient" | "llama3-gradient:8b" | "llama3-gradient:70b" | "llama3.1" | "llama3.1:8b" | "llama3.1:70b" | "llama3.1:405b" | "llama3.2" | "llama3.2:1b" | "llama3.2:3b" | "llama3.2-vision" | "llama3.2-vision:11b" | "llama3.2-vision:90b" | "llama3.3" | "llama3.3:70b" | "llama4" | "llama4:16x17b" | "llama4:128x17b" | "llama-guard3" | "llama-guard3:1b" | "llama-guard3:8b" | "llava" | "llava:7b" | "llava:13b" | "llava:34b" | "llava-llama3" | "llava-llama3:8b" | "llava-phi3" | "llava-phi3:3.8b" | "marco-o1" | "marco-o1:7b" | "mistral" | "mistral:7b" | "mistral-large" | "mistral-large:123b" | "mistral-nemo" | "mistral-nemo:12b" | "mistral-small" | "mistral-small:22b" | "mixtral" | "mixtral:8x7b" | "mixtral:8x22b" | 
"moondream" | "moondream:1.8b" | "openhermes" | "openhermes:v2.5" | "nemotron" | "nemotron:70b" | "nemotron-mini" | "nemotron-mini:4b" | "olmo" | "olmo:7b" | "olmo:13b" | "opencoder" | "opencoder:1.5b" | "opencoder:8b" | "phi3" | "phi3:3.8b" | "phi3:14b" | "phi3.5" | "phi3.5:3.8b" | "phi4" | "phi4:14b" | "qwen" | "qwen:7b" | "qwen:14b" | "qwen:32b" | "qwen:72b" | "qwen:110b" | "qwen2" | "qwen2:0.5b" | "qwen2:1.5b" | "qwen2:7b" | "qwen2:72b" | "qwen2.5" | "qwen2.5:0.5b" | "qwen2.5:1.5b" | "qwen2.5:3b" | "qwen2.5:7b" | "qwen2.5:14b" | "qwen2.5:32b" | "qwen2.5:72b" | "qwen2.5-coder" | "qwen2.5-coder:0.5b" | "qwen2.5-coder:1.5b" | "qwen2.5-coder:3b" | "qwen2.5-coder:7b" | "qwen2.5-coder:14b" | "qwen2.5-coder:32b" | "qwen3" | "qwen3:0.6b" | "qwen3:1.7b" | "qwen3:4b" | "qwen3:8b" | "qwen3:14b" | "qwen3:30b" | "qwen3:32b" | "qwen3:235b" | "qwq" | "qwq:32b" | "sailor2" | "sailor2:1b" | "sailor2:8b" | "sailor2:20b" | "shieldgemma" | "shieldgemma:2b" | "shieldgemma:9b" | "shieldgemma:27b" | "smallthinker" | "smallthinker:3b" | "smollm" | "smollm:135m" | "smollm:360m" | "smollm:1.7b" | "tinyllama" | "tinyllama:1.1b" | "tulu3" | "tulu3:8b" | "tulu3:70b" | (string & {}); declare const ollamaProviderOptions: z.ZodObject<{ - think: z.ZodOptional; -+ think: z.ZodOptional]>>; ++ think: z.ZodOptional, z.ZodLiteral<"medium">, z.ZodLiteral<"high">]>>; options: z.ZodOptional; repeat_last_n: z.ZodOptional; @@ -29,7 +29,7 @@ index 8dd9b498050dbecd8dd6b901acf1aa8ca38a49af..ed644349c9d38fe2a66b2fb44214f7c1 declare const ollamaCompletionProviderOptions: z.ZodObject<{ - think: z.ZodOptional; -+ think: z.ZodOptional]>>; ++ think: z.ZodOptional, z.ZodLiteral<"medium">, z.ZodLiteral<"high">]>>; user: z.ZodOptional; suffix: z.ZodOptional; echo: z.ZodOptional; @@ -42,7 +42,7 @@ index 35b5142ce8476ce2549ed7c2ec48e7d8c46c90d9..2ef64dc9a4c2be043e6af608241a6a83 // src/completion/ollama-completion-language-model.ts var ollamaCompletionProviderOptions = import_v42.z.object({ - think: import_v42.z.boolean().optional(), -+ think: import_v42.z.union([import_v42.z.boolean(), import_v42.z.enum(['low', 'medium', 'high'])]).optional(), ++ think: import_v42.z.union([import_v42.z.boolean(), import_v42.z.literal('low'), import_v42.z.literal('medium'), import_v42.z.literal('high')]).optional(), user: import_v42.z.string().optional(), suffix: import_v42.z.string().optional(), echo: import_v42.z.boolean().optional() @@ -64,7 +64,7 @@ index 35b5142ce8476ce2549ed7c2ec48e7d8c46c90d9..2ef64dc9a4c2be043e6af608241a6a83 * Only supported by certain models like DeepSeek R1 and Qwen 3. 
*/ - think: import_v44.z.boolean().optional(), -+ think: import_v44.z.union([import_v44.z.boolean(), import_v44.z.enum(['low', 'medium', 'high'])]).optional(), ++ think: import_v44.z.union([import_v44.z.boolean(), import_v44.z.literal('low'), import_v44.z.literal('medium'), import_v44.z.literal('high')]).optional(), options: import_v44.z.object({ num_ctx: import_v44.z.number().optional(), repeat_last_n: import_v44.z.number().optional(), @@ -97,7 +97,7 @@ index e2a634a78d80ac9542f2cc4f96cf2291094b10cf..67b23efce3c1cf4f026693d3ff924698 // src/completion/ollama-completion-language-model.ts var ollamaCompletionProviderOptions = z2.object({ - think: z2.boolean().optional(), -+ think: z2.union([z2.boolean(), z2.enum(['low', 'medium', 'high'])]).optional(), ++ think: z2.union([z2.boolean(), z2.literal('low'), z2.literal('medium'), z2.literal('high')]).optional(), user: z2.string().optional(), suffix: z2.string().optional(), echo: z2.boolean().optional() @@ -119,7 +119,7 @@ index e2a634a78d80ac9542f2cc4f96cf2291094b10cf..67b23efce3c1cf4f026693d3ff924698 * Only supported by certain models like DeepSeek R1 and Qwen 3. */ - think: z4.boolean().optional(), -+ think: z4.union([z4.boolean(), z4.enum(['low', 'medium', 'high'])]).optional(), ++ think: z4.union([z4.boolean(), z4.literal('low'), z4.literal('medium'), z4.literal('high')]).optional(), options: z4.object({ num_ctx: z4.number().optional(), repeat_last_n: z4.number().optional(), diff --git a/CLAUDE.md b/CLAUDE.md index 8ee235d75a..e448c2b487 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -29,7 +29,7 @@ When creating a Pull Request, you MUST: - **Development**: `yarn dev` - Runs Electron app in development mode with hot reload - **Debug**: `yarn debug` - Starts with debugging enabled, use `chrome://inspect` to attach debugger - **Build Check**: `yarn build:check` - **REQUIRED** before commits (lint + test + typecheck) - - If having i18n sort issues, run `yarn sync:i18n` first to sync template + - If having i18n sort issues, run `yarn i18n:sync` first to sync template - If having formatting issues, run `yarn format` first - **Test**: `yarn test` - Run all tests (Vitest) across main and renderer processes - **Single Test**: @@ -41,39 +41,24 @@ When creating a Pull Request, you MUST: ## Project Architecture ### Electron Structure + - **Main Process** (`src/main/`): Node.js backend with services (MCP, Knowledge, Storage, etc.) -- **Renderer Process** (`src/renderer/`): React UI with Redux state management +- **Renderer Process** (`src/renderer/`): React UI - **Preload Scripts** (`src/preload/`): Secure IPC bridge ### Key Architectural Components -#### Main Process Services (`src/main/services/`) - -- **MCPService**: Model Context Protocol server management -- **KnowledgeService**: Document processing and knowledge base management -- **FileStorage/S3Storage/WebDav**: Multiple storage backends -- **WindowService**: Multi-window management (main, mini, selection windows) -- **ProxyManager**: Network proxy handling -- **SearchService**: Full-text search capabilities - -#### AI Core (`src/renderer/src/aiCore/`) - -- **Middleware System**: Composable pipeline for AI request processing -- **Client Factory**: Supports multiple AI providers (OpenAI, Anthropic, Gemini, etc.) 
-- **Stream Processing**: Real-time response handling - #### Data Management -- **Cache System**: Three-layer caching (memory/shared/persist) with React hooks integration -- **Preferences**: Type-safe configuration management with multi-window synchronization -- **User Data**: SQLite-based storage with Drizzle ORM for business data +**MUST READ**: [docs/en/references/data/README.md](docs/en/references/data/README.md) for system selection, architecture, and patterns. -#### Knowledge Management +| System | Use Case | APIs | +| ---------- | ---------------------------- | ----------------------------------------------- | +| Cache | Temp data (can lose) | `useCache`, `useSharedCache`, `usePersistCache` | +| Preference | User settings | `usePreference` | +| DataApi | Business data (**critical**) | `useQuery`, `useMutation` | -- **Embeddings**: Vector search with multiple providers (OpenAI, Voyage, etc.) -- **OCR**: Document text extraction (system OCR, Doc2x, Mineru) -- **Preprocessing**: Document preparation pipeline -- **Loaders**: Support for various file formats (PDF, DOCX, EPUB, etc.) +Database: SQLite + Drizzle ORM, schemas in `src/main/data/db/schemas/`, migrations via `yarn db:migrations:generate` ### Build System @@ -98,63 +83,36 @@ When creating a Pull Request, you MUST: - **Multi-language Support**: i18n with dynamic loading - **Theme System**: Light/dark themes with custom CSS variables -### UI Design +## v2 Refactoring (In Progress) -The project is in the process of migrating from antd & styled-components to Tailwind CSS and Shadcn UI. Please use components from `@packages/ui` to build UI components. The use of antd and styled-components is prohibited. +The v2 branch is undergoing a major refactoring effort: -UI Library: `@packages/ui` +### Data Layer -### Database Architecture +- **Removing**: Redux, Dexie +- **Adopting**: Cache / Preference / DataApi architecture (see [Data Management](#data-management)) -- **Database**: SQLite (`cherrystudio.sqlite`) + libsql driver -- **ORM**: Drizzle ORM with comprehensive migration system -- **Schemas**: Located in `src/main/data/db/schemas/` directory +### UI Layer -#### Database Standards +- **Removing**: antd, HeroUI, styled-components +- **Adopting**: `@cherrystudio/ui` (located in `packages/ui`, Tailwind CSS + Shadcn UI) +- **Prohibited**: antd, HeroUI, styled-components -- **Table Naming**: Use singular form with snake_case (e.g., `topic`, `message`, `app_state`) -- **Schema Exports**: Export using `xxxTable` pattern (e.g., `topicTable`, `appStateTable`) -- **Field Definition**: Drizzle auto-infers field names, no need to add default field names -- **JSON Fields**: For JSON support, add `{ mode: 'json' }`, refer to `preference.ts` table definition -- **JSON Serialization**: For JSON fields, no need to manually serialize/deserialize when reading/writing to database, Drizzle handles this automatically -- **Timestamps**: Use existing `crudTimestamps` utility -- **Migrations**: Generate via `yarn run migrations:generate` +### File Naming Convention -## Data Access Patterns +During migration, use `*.v2.ts` suffix for files not yet fully migrated: -The application uses three distinct data management systems. 
Choose the appropriate system based on data characteristics: - -### Cache System -- **Purpose**: Temporary data that can be regenerated -- **Lifecycle**: Component-level (memory), window-level (shared), or persistent (survives restart) -- **Use Cases**: API response caching, computed results, temporary UI state -- **APIs**: `useCache`, `useSharedCache`, `usePersistCache` hooks, or `cacheService` - -### Preference System -- **Purpose**: User configuration and application settings -- **Lifecycle**: Permanent until user changes -- **Use Cases**: Theme, language, editor settings, user preferences -- **APIs**: `usePreference`, `usePreferences` hooks, or `preferenceService` - -### User Data API -- **Purpose**: Core business data (conversations, files, notes, etc.) -- **Lifecycle**: Permanent business records -- **Use Cases**: Topics, messages, files, knowledge base, user-generated content -- **APIs**: `useDataApi` hook or `dataApiService` for direct calls - -### Selection Guidelines - -- **Use Cache** for data that can be lost without impact (computed values, API responses) -- **Use Preferences** for user settings that affect app behavior (UI configuration, feature flags) -- **Use User Data API** for irreplaceable business data (conversations, documents, user content) +- Indicates work-in-progress refactoring +- Avoids conflicts with existing code +- **Post-completion**: These files will be renamed or merged into their final locations ## Logging Standards ### Usage ```typescript -import { loggerService } from '@logger' -const logger = loggerService.withContext('moduleName') +import { loggerService } from "@logger"; +const logger = loggerService.withContext("moduleName"); // Renderer: loggerService.initWindowSource('windowName') first -logger.info('message', CONTEXT) +logger.info("message", CONTEXT); ``` diff --git a/README.md b/README.md index f790c10cbd..781e9299e5 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ -

English | 中文 | Official Site | Documents | Development | Feedback

+

English | 中文 | Official Site | Documents | Development | Feedback

@@ -242,12 +242,12 @@ The Enterprise Edition addresses core challenges in team collaboration by centra ## Version Comparison -| Feature | Community Edition | Enterprise Edition | -| :---------------- | :----------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------- | -| **Open Source** | ✅ Yes | ⭕️ Partially released to customers | +| Feature | Community Edition | Enterprise Edition | +| :---------------- | :----------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------- | +| **Open Source** | ✅ Yes | ⭕️ Partially released to customers | | **Cost** | [AGPL-3.0 License](https://github.com/CherryHQ/cherry-studio?tab=AGPL-3.0-1-ov-file) | Buyout / Subscription Fee | -| **Admin Backend** | — | ● Centralized **Model** Access
● **Employee** Management
● Shared **Knowledge Base**
● **Access** Control
● **Data** Backup | -| **Server** | — | ✅ Dedicated Private Deployment | +| **Admin Backend** | — | ● Centralized **Model** Access
● **Employee** Management
● Shared **Knowledge Base**
● **Access** Control
● **Data** Backup | +| **Server** | — | ✅ Dedicated Private Deployment | ## Get the Enterprise Edition @@ -275,7 +275,7 @@ We believe the Enterprise Edition will become your team's AI productivity engine # 📊 GitHub Stats -![Stats](https://repobeats.axiom.co/api/embed/a693f2e5f773eed620f70031e974552156c7f397.svg 'Repobeats analytics image') +![Stats](https://repobeats.axiom.co/api/embed/a693f2e5f773eed620f70031e974552156c7f397.svg "Repobeats analytics image") # ⭐️ Star History diff --git a/biome.jsonc b/biome.jsonc index c4b777d6a5..90da1a238d 100644 --- a/biome.jsonc +++ b/biome.jsonc @@ -23,7 +23,7 @@ }, "files": { "ignoreUnknown": false, - "includes": ["**", "!**/.claude/**", "!**/.vscode/**"], + "includes": ["**", "!**/.claude/**", "!**/.vscode/**", "!**/.conductor/**"], "maxSize": 2097152 }, "formatter": { diff --git a/build/nsis-installer.nsh b/build/nsis-installer.nsh index 769ccaaa19..e644e18f3d 100644 --- a/build/nsis-installer.nsh +++ b/build/nsis-installer.nsh @@ -12,8 +12,13 @@ ; https://github.com/electron-userland/electron-builder/issues/1122 !ifndef BUILD_UNINSTALLER + ; Check VC++ Redistributable based on architecture stored in $1 Function checkVCRedist - ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\x64" "Installed" + ${If} $1 == "arm64" + ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\ARM64" "Installed" + ${Else} + ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\x64" "Installed" + ${EndIf} FunctionEnd Function checkArchitectureCompatibility @@ -97,29 +102,47 @@ Call checkVCRedist ${If} $0 != "1" - MessageBox MB_YESNO "\ - NOTE: ${PRODUCT_NAME} requires $\r$\n\ - 'Microsoft Visual C++ Redistributable'$\r$\n\ - to function properly.$\r$\n$\r$\n\ - Download and install now?" /SD IDYES IDYES InstallVCRedist IDNO DontInstall - InstallVCRedist: - inetc::get /CAPTION " " /BANNER "Downloading Microsoft Visual C++ Redistributable..." "https://aka.ms/vs/17/release/vc_redist.x64.exe" "$TEMP\vc_redist.x64.exe" - ExecWait "$TEMP\vc_redist.x64.exe /install /norestart" - ;IfErrors InstallError ContinueInstall ; vc_redist exit code is unreliable :( - Call checkVCRedist - ${If} $0 == "1" - Goto ContinueInstall - ${EndIf} + ; VC++ is required - install automatically since declining would abort anyway + ; Select download URL based on system architecture (stored in $1) + ${If} $1 == "arm64" + StrCpy $2 "https://aka.ms/vs/17/release/vc_redist.arm64.exe" + StrCpy $3 "$TEMP\vc_redist.arm64.exe" + ${Else} + StrCpy $2 "https://aka.ms/vs/17/release/vc_redist.x64.exe" + StrCpy $3 "$TEMP\vc_redist.x64.exe" + ${EndIf} - ;InstallError: - MessageBox MB_ICONSTOP "\ - There was an unexpected error installing$\r$\n\ - Microsoft Visual C++ Redistributable.$\r$\n\ - The installation of ${PRODUCT_NAME} cannot continue." - DontInstall: + inetc::get /CAPTION " " /BANNER "Downloading Microsoft Visual C++ Redistributable..." 
\ + $2 $3 /END + Pop $0 ; Get download status from inetc::get + ${If} $0 != "OK" + MessageBox MB_ICONSTOP|MB_YESNO "\ + Failed to download Microsoft Visual C++ Redistributable.$\r$\n$\r$\n\ + Error: $0$\r$\n$\r$\n\ + Would you like to open the download page in your browser?$\r$\n\ + $2" IDYES openDownloadUrl IDNO skipDownloadUrl + openDownloadUrl: + ExecShell "open" $2 + skipDownloadUrl: Abort + ${EndIf} + + ExecWait "$3 /install /quiet /norestart" + ; Note: vc_redist exit code is unreliable, verify via registry check instead + + Call checkVCRedist + ${If} $0 != "1" + MessageBox MB_ICONSTOP|MB_YESNO "\ + Microsoft Visual C++ Redistributable installation failed.$\r$\n$\r$\n\ + Would you like to open the download page in your browser?$\r$\n\ + $2$\r$\n$\r$\n\ + The installation of ${PRODUCT_NAME} cannot continue." IDYES openInstallUrl IDNO skipInstallUrl + openInstallUrl: + ExecShell "open" $2 + skipInstallUrl: + Abort + ${EndIf} ${EndIf} - ContinueInstall: Pop $4 Pop $3 Pop $2 diff --git a/docs/en/guides/development.md b/docs/en/guides/development.md index fe67742768..032a515f61 100644 --- a/docs/en/guides/development.md +++ b/docs/en/guides/development.md @@ -36,7 +36,7 @@ yarn install ### ENV ```bash -copy .env.example .env +cp .env.example .env ``` ### Start diff --git a/docs/en/guides/i18n.md b/docs/en/guides/i18n.md index a3284e3ab9..7fccfde695 100644 --- a/docs/en/guides/i18n.md +++ b/docs/en/guides/i18n.md @@ -71,7 +71,7 @@ Tools like i18n Ally cannot parse dynamic content within template strings, resul ```javascript // Not recommended - Plugin cannot resolve -const message = t(`fruits.${fruit}`) +const message = t(`fruits.${fruit}`); ``` #### 2. **No Real-time Rendering in Editor** @@ -91,14 +91,14 @@ For example: ```ts // src/renderer/src/i18n/label.ts const themeModeKeyMap = { - dark: 'settings.theme.dark', - light: 'settings.theme.light', - system: 'settings.theme.system' -} as const + dark: "settings.theme.dark", + light: "settings.theme.light", + system: "settings.theme.system", +} as const; export const getThemeModeLabel = (key: string): string => { - return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key -} + return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key; +}; ``` By avoiding template strings, you gain better developer experience, more reliable translation checks, and a more maintainable codebase. @@ -107,7 +107,7 @@ By avoiding template strings, you gain better developer experience, more reliabl The project includes several scripts to automate i18n-related tasks: -### `check:i18n` - Validate i18n Structure +### `i18n:check` - Validate i18n Structure This script checks: @@ -116,10 +116,10 @@ This script checks: - Whether keys are properly sorted ```bash -yarn check:i18n +yarn i18n:check ``` -### `sync:i18n` - Synchronize JSON Structure and Sort Order +### `i18n:sync` - Synchronize JSON Structure and Sort Order This script uses `zh-cn.json` as the source of truth to sync structure across all language files, including: @@ -128,14 +128,14 @@ This script uses `zh-cn.json` as the source of truth to sync structure across al 3. Sorting keys automatically ```bash -yarn sync:i18n +yarn i18n:sync ``` -### `auto:i18n` - Automatically Translate Pending Texts +### `i18n:translate` - Automatically Translate Pending Texts This script fills in texts marked as `[to be translated]` using machine translation. -Typically, after adding new texts in `zh-cn.json`, run `sync:i18n`, then `auto:i18n` to complete translations. 
+Typically, after adding new texts in `zh-cn.json`, run `i18n:sync`, then `i18n:translate` to complete translations. Before using this script, set the required environment variables: @@ -148,30 +148,20 @@ MODEL="qwen-plus-latest" Alternatively, add these variables directly to your `.env` file. ```bash -yarn auto:i18n -``` - -### `update:i18n` - Object-level Translation Update - -Updates translations in language files under `src/renderer/src/i18n/translate` at the object level, preserving existing translations and only updating new content. - -**Not recommended** — prefer `auto:i18n` for translation tasks. - -```bash -yarn update:i18n +yarn i18n:translate ``` ### Workflow 1. During development, first add the required text in `zh-cn.json` 2. Confirm it displays correctly in the Chinese environment -3. Run `yarn sync:i18n` to propagate the keys to other language files -4. Run `yarn auto:i18n` to perform machine translation +3. Run `yarn i18n:sync` to propagate the keys to other language files +4. Run `yarn i18n:translate` to perform machine translation 5. Grab a coffee and let the magic happen! ## Best Practices 1. **Use Chinese as Source Language**: All development starts in Chinese, then translates to other languages. -2. **Run Check Script Before Commit**: Use `yarn check:i18n` to catch i18n issues early. +2. **Run Check Script Before Commit**: Use `yarn i18n:check` to catch i18n issues early. 3. **Translate in Small Increments**: Avoid accumulating a large backlog of untranslated content. 4. **Keep Keys Semantically Clear**: Keys should clearly express their purpose, e.g., `user.profile.avatar.upload.error` diff --git a/docs/en/references/data/README.md b/docs/en/references/data/README.md new file mode 100644 index 0000000000..abccd93508 --- /dev/null +++ b/docs/en/references/data/README.md @@ -0,0 +1,197 @@ +# Data System Reference + +This is the main entry point for Cherry Studio's data management documentation. The application uses three distinct data systems based on data characteristics. 
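+
+As a quick taste of how the three systems differ in practice, here is a minimal sketch (the keys are illustrative; the hooks and endpoint are the ones documented in the usage guides below):
+
+```typescript
+// Disposable UI state: safe to lose on restart
+const [draft, setDraft] = useSharedCache('chat.input.draft', '')
+
+// User setting: survives until the user changes it
+const [theme, setTheme] = usePreference('app.theme.mode')
+
+// Business data: irreplaceable, fetched on demand
+const { data: topics } = useQuery('/topics')
+```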
+ +## Quick Navigation + +### System Overview (Architecture) +- [Cache Overview](./cache-overview.md) - Three-tier caching architecture +- [Preference Overview](./preference-overview.md) - User settings management +- [DataApi Overview](./data-api-overview.md) - Business data API architecture + +### Usage Guides (Code Examples) +- [Cache Usage](./cache-usage.md) - useCache hooks, CacheService examples +- [Preference Usage](./preference-usage.md) - usePreference hook, PreferenceService examples +- [DataApi in Renderer](./data-api-in-renderer.md) - useQuery/useMutation, DataApiService +- [DataApi in Main](./data-api-in-main.md) - Handlers, Services, Repositories patterns + +### Reference Guides (Coding Standards) +- [API Design Guidelines](./api-design-guidelines.md) - RESTful design rules +- [Database Patterns](./database-patterns.md) - DB naming, schema patterns +- [API Types](./api-types.md) - API type system, schemas, error handling +- [V2 Migration Guide](./v2-migration-guide.md) - Migration system + +### Testing +- [Test Mocks](../../../../tests/__mocks__/README.md) - Unified mocks for Cache, Preference, and DataApi + +--- + +## Choosing the Right System + +### Quick Decision Table + +| Service | Data Characteristics | Lifecycle | Data Loss Impact | Examples | +|---------|---------------------|-----------|------------------|----------| +| **CacheService** | Regenerable, temporary | ≤ App process or survives restart | None to minimal | API responses, computed results, UI state | +| **PreferenceService** | User settings, key-value | Permanent until changed | Low (can rebuild) | Theme, language, font size, shortcuts | +| **DataApiService** | Business data, structured | Permanent | **Severe** (irreplaceable) | Topics, messages, files, knowledge base | + +### Decision Flowchart + +Ask these questions in order: + +1. **Can this data be regenerated or lost without affecting the user?** + - Yes → **CacheService** + - No → Continue to #2 + +2. **Is this a user-configurable setting that affects app behavior?** + - Yes → Does it have a fixed key and stable value structure? + - Yes → **PreferenceService** + - No (structure changes often) → **DataApiService** + - No → Continue to #3 + +3. **Is this business data created/accumulated through user activity?** + - Yes → **DataApiService** + - No → Reconsider #1 (most data falls into one of these categories) + +--- + +## System Characteristics + +### CacheService - Runtime & Cache Data + +Use CacheService when: +- Data can be **regenerated or lost without user impact** +- No backup or cross-device synchronization needed +- Lifecycle is tied to component, window, or app session + +**Two sub-categories**: +1. **Performance cache**: Computed results, API responses, expensive calculations +2. 
**UI state cache**: Temporary settings, scroll positions, panel states + +**Three tiers based on persistence needs**: +- `useCache` (memory): Lost on app restart, component-level sharing +- `useSharedCache` (shared): Cross-window sharing, lost on restart +- `usePersistCache` (persist): Survives app restarts via localStorage + +```typescript +// Good: Temporary computed results +const [searchResults, setSearchResults] = useCache('search.results', []) + +// Good: UI state that can be lost +const [sidebarCollapsed, setSidebarCollapsed] = useSharedCache('ui.sidebar.collapsed', false) + +// Good: Recent items (nice to have, not critical) +const [recentSearches, setRecentSearches] = usePersistCache('search.recent', []) +``` + +### PreferenceService - User Preferences + +Use PreferenceService when: +- Data is a **user-modifiable setting that affects app behavior** +- Structure is key-value with **predefined keys** (users modify values, not keys) +- **Value structure is stable** (won't change frequently) +- Data loss has **low impact** (user can reconfigure) + +**Key characteristics**: +- Auto-syncs across all windows +- Each preference item should be **atomic** (one setting = one key) +- Values are typically: boolean, string, number, or simple array/object + +```typescript +// Good: App behavior settings +const [theme, setTheme] = usePreference('app.theme.mode') +const [language, setLanguage] = usePreference('app.language') +const [fontSize, setFontSize] = usePreference('chat.message.font_size') + +// Good: Feature toggles +const [showTimestamp, setShowTimestamp] = usePreference('chat.display.show_timestamp') +``` + +### DataApiService - User Data + +Use DataApiService when: +- Data is **business data accumulated through user activity** +- Data is **structured with dedicated schemas/tables** +- Users can **create, delete, modify records** (no fixed limit) +- Data loss would be **severe and irreplaceable** +- Data volume can be **large** (potentially GBs) + +**Key characteristics**: +- No automatic window sync (fetch on demand for fresh data) +- May contain sensitive data (encryption consideration) +- Requires proper CRUD operations and transactions + +```typescript +// Good: User-generated business data +const { data: topics } = useQuery('/topics') +const { trigger: createTopic } = useMutation('/topics', 'POST') + +// Good: Conversation history (irreplaceable) +const { data: messages } = useQuery('/messages', { query: { topicId } }) + +// Good: User files and knowledge base +const { data: files } = useQuery('/files') +``` + +--- + +## Common Anti-patterns + +| Wrong Choice | Why It's Wrong | Correct Choice | +|--------------|----------------|----------------| +| Storing AI provider configs in Cache | User loses configured providers on restart | **PreferenceService** | +| Storing conversation history in Preferences | Unbounded growth, complex structure | **DataApiService** | +| Storing topic list in Preferences | User-created records, can grow large | **DataApiService** | +| Storing theme/language in DataApi | Overkill for simple key-value settings | **PreferenceService** | +| Storing API responses in DataApi | Regenerable data, doesn't need persistence | **CacheService** | +| Storing window positions in Preferences | Can be lost without impact | **CacheService** (persist tier) | + +## Edge Cases + +- **Recently used items** (e.g., recent files, recent searches): Use `usePersistCache` - nice to have but not critical if lost +- **Draft content** (e.g., unsaved message): Use `useSharedCache` for 
cross-window, consider auto-save to DataApi for recovery +- **Computed statistics**: Use `useCache` with TTL - regenerate when expired +- **User-created templates/presets**: Use **DataApiService** - user-generated content that can grow + +--- + +## Architecture Overview + +``` +┌─────────────────┐ +│ React Components│ +└─────────┬───────┘ + │ +┌─────────▼───────┐ +│ React Hooks │ ← useDataApi, usePreference, useCache +└─────────┬───────┘ + │ +┌─────────▼───────┐ +│ Services │ ← DataApiService, PreferenceService, CacheService +└─────────┬───────┘ + │ +┌─────────▼───────┐ +│ IPC Layer │ ← Main Process Communication +└─────────────────┘ +``` + +## Related Source Code + +### Type Definitions +- `packages/shared/data/api/` - API type system +- `packages/shared/data/cache/` - Cache type definitions +- `packages/shared/data/preference/` - Preference type definitions + +### Main Process Implementation +- `src/main/data/api/` - API server and handlers +- `src/main/data/CacheService.ts` - Cache service +- `src/main/data/PreferenceService.ts` - Preference service +- `src/main/data/db/` - Database schemas + +### Renderer Process Implementation +- `src/renderer/src/data/DataApiService.ts` - API client +- `src/renderer/src/data/CacheService.ts` - Cache service +- `src/renderer/src/data/PreferenceService.ts` - Preference service +- `src/renderer/src/data/hooks/` - React hooks + diff --git a/docs/en/references/data/api-design-guidelines.md b/docs/en/references/data/api-design-guidelines.md new file mode 100644 index 0000000000..6c6da6eba0 --- /dev/null +++ b/docs/en/references/data/api-design-guidelines.md @@ -0,0 +1,250 @@ +# API Design Guidelines + +Guidelines for designing RESTful APIs in the Cherry Studio Data API system. + +## Path Naming + +| Rule | Example | Notes | +|------|---------|-------| +| Use plural nouns for collections | `/topics`, `/messages` | Resources are collections | +| Use kebab-case for multi-word paths | `/user-settings` | Not camelCase or snake_case | +| Express hierarchy via nesting | `/topics/:topicId/messages` | Parent-child relationships | +| Avoid verbs for CRUD operations | `/topics` not `/getTopics` | HTTP methods express action | + +## HTTP Method Semantics + +| Method | Purpose | Idempotent | Typical Response | +|--------|---------|------------|------------------| +| GET | Retrieve resource(s) | Yes | 200 + data | +| POST | Create resource | No | 201 + created entity | +| PUT | Replace entire resource | Yes | 200 + updated entity | +| PATCH | Partial update | Yes | 200 + updated entity | +| DELETE | Remove resource | Yes | 204 / void | + +## Standard Endpoint Patterns + +```typescript +// Collection operations +'/topics': { + GET: { ... } // List with pagination/filtering + POST: { ... } // Create new resource +} + +// Individual resource operations +'/topics/:id': { + GET: { ... } // Get single resource + PUT: { ... } // Replace resource + PATCH: { ... } // Partial update + DELETE: { ... } // Remove resource +} + +// Nested resources (use for parent-child relationships) +'/topics/:topicId/messages': { + GET: { ... } // List messages under topic + POST: { ... } // Create message in topic +} +``` + +## PATCH vs Dedicated Endpoints + +### Decision Criteria + +Use this decision tree to determine the appropriate approach: + +``` +Operation characteristics: +├── Simple field update with no side effects? +│ └── Yes → Use PATCH +├── High-frequency operation with clear business meaning? 
+│ └── Yes → Use dedicated endpoint (noun-based sub-resource) +├── Operation triggers complex side effects or validation? +│ └── Yes → Use dedicated endpoint +├── Operation creates new resources? +│ └── Yes → Use POST to dedicated endpoint +└── Default → Use PATCH +``` + +### Guidelines + +| Scenario | Approach | Example | +|----------|----------|---------| +| Simple field update | PATCH | `PATCH /messages/:id { data: {...} }` | +| High-frequency + business meaning | Dedicated sub-resource | `PUT /topics/:id/active-node { nodeId }` | +| Complex validation/side effects | Dedicated endpoint | `POST /messages/:id/move { newParentId }` | +| Creates new resources | POST dedicated | `POST /messages/:id/duplicate` | + +### Naming for Dedicated Endpoints + +- **Prefer noun-based paths** over verb-based when possible +- Treat the operation target as a sub-resource: `/topics/:id/active-node` not `/topics/:id/switch-branch` +- Use POST for actions that create resources or have non-idempotent side effects +- Use PUT for setting/replacing a sub-resource value + +### Examples + +```typescript +// ✅ Good: Noun-based sub-resource for high-frequency operation +PUT /topics/:id/active-node +{ nodeId: string } + +// ✅ Good: Simple field update via PATCH +PATCH /messages/:id +{ data: MessageData } + +// ✅ Good: POST for resource creation +POST /messages/:id/duplicate +{ includeDescendants?: boolean } + +// ❌ Avoid: Verb in path when noun works +POST /topics/:id/switch-branch // Use PUT /topics/:id/active-node instead + +// ❌ Avoid: Dedicated endpoint for simple updates +POST /messages/:id/update-content // Use PATCH /messages/:id instead +``` + +## Non-CRUD Operations + +Use verb-based paths for operations that don't fit CRUD semantics: + +```typescript +// Search +'/topics/search': { + GET: { query: { q: string } } +} + +// Statistics / Aggregations +'/topics/stats': { + GET: { response: { total: number, ... 
} } +} + +// Resource actions (state changes, triggers) +'/topics/:id/archive': { + POST: { response: { archived: boolean } } +} + +'/topics/:id/duplicate': { + POST: { response: Topic } +} +``` + +## Query Parameters + +| Purpose | Pattern | Example | +|---------|---------|---------| +| Pagination | `page` + `limit` | `?page=1&limit=20` | +| Sorting | `orderBy` + `order` | `?orderBy=createdAt&order=desc` | +| Filtering | direct field names | `?status=active&type=chat` | +| Search | `q` or `search` | `?q=keyword` | + +## Response Status Codes + +Use standard HTTP status codes consistently: + +| Status | Usage | Example | +|--------|-------|---------| +| 200 OK | Successful GET/PUT/PATCH | Return updated resource | +| 201 Created | Successful POST | Return created resource | +| 204 No Content | Successful DELETE | No body | +| 400 Bad Request | Invalid request format | Malformed JSON | +| 400 Invalid Operation | Business rule violation | Delete root without cascade, cycle creation | +| 401 Unauthorized | Authentication required | Missing/invalid token | +| 403 Permission Denied | Insufficient permissions | Access denied to resource | +| 404 Not Found | Resource not found | Invalid ID | +| 409 Conflict | Concurrent modification or data inconsistency | Version conflict, data corruption | +| 422 Unprocessable | Validation failed | Invalid field values | +| 423 Locked | Resource temporarily locked | File being exported | +| 429 Too Many Requests | Rate limit exceeded | Throttling | +| 500 Internal Error | Server error | Unexpected failure | +| 503 Service Unavailable | Service temporarily down | Maintenance mode | +| 504 Timeout | Request timed out | Long-running operation | + +## Error Response Format + +All error responses follow the `SerializedDataApiError` structure (transmitted via IPC): + +```typescript +interface SerializedDataApiError { + code: ErrorCode | string // ErrorCode enum value (e.g., 'NOT_FOUND') + message: string // Human-readable error message + status: number // HTTP status code + details?: Record // Additional context (e.g., field errors) + requestContext?: { // Request context for debugging + requestId: string + path: string + method: HttpMethod + timestamp?: number + } + // Note: stack trace is NOT transmitted via IPC - rely on Main process logs +} +``` + +**Examples:** + +```typescript +// 404 Not Found +{ + code: 'NOT_FOUND', + message: "Topic with id 'abc123' not found", + status: 404, + details: { resource: 'Topic', id: 'abc123' }, + requestContext: { requestId: 'req_123', path: '/topics/abc123', method: 'GET' } +} + +// 422 Validation Error +{ + code: 'VALIDATION_ERROR', + message: 'Request validation failed', + status: 422, + details: { + fieldErrors: { + name: ['Name is required', 'Name must be at least 3 characters'], + email: ['Invalid email format'] + } + } +} + +// 504 Timeout +{ + code: 'TIMEOUT', + message: 'Request timeout: fetch topics (3000ms)', + status: 504, + details: { operation: 'fetch topics', timeoutMs: 3000 } +} + +// 400 Invalid Operation +{ + code: 'INVALID_OPERATION', + message: 'Invalid operation: delete root message - cascade=true required', + status: 400, + details: { operation: 'delete root message', reason: 'cascade=true required' } +} +``` + +Use `DataApiErrorFactory` utilities to create consistent errors: + +```typescript +import { DataApiErrorFactory, DataApiError } from '@shared/data/api' + +// Using factory methods (recommended) +throw DataApiErrorFactory.notFound('Topic', id) +throw DataApiErrorFactory.validation({ name: ['Required'] 
}) +throw DataApiErrorFactory.database(error, 'insert topic') +throw DataApiErrorFactory.timeout('fetch topics', 3000) +throw DataApiErrorFactory.dataInconsistent('Topic', 'parent reference broken') +throw DataApiErrorFactory.invalidOperation('delete root message', 'cascade=true required') + +// Check if error is retryable +if (error instanceof DataApiError && error.isRetryable) { + await retry(operation) +} +``` + +## Naming Conventions Summary + +| Element | Case | Example | +|---------|------|---------| +| Paths | kebab-case, plural | `/user-settings`, `/topics` | +| Path params | camelCase | `:topicId`, `:messageId` | +| Query params | camelCase | `orderBy`, `pageSize` | +| Body fields | camelCase | `createdAt`, `userName` | +| Error codes | SCREAMING_SNAKE | `NOT_FOUND`, `VALIDATION_ERROR` | diff --git a/docs/en/references/data/api-types.md b/docs/en/references/data/api-types.md new file mode 100644 index 0000000000..a475b4d0c4 --- /dev/null +++ b/docs/en/references/data/api-types.md @@ -0,0 +1,338 @@ +# Data API Type System + +This directory contains the type definitions and utilities for Cherry Studio's Data API system, which provides type-safe IPC communication between renderer and main processes. + +## Directory Structure + +``` +packages/shared/data/api/ +├── index.ts # Barrel export for infrastructure types +├── apiTypes.ts # Core request/response types and API utilities +├── apiPaths.ts # Path template literal type utilities +├── apiErrors.ts # Error handling: ErrorCode, DataApiError class, factory +└── schemas/ + ├── index.ts # Schema composition (merges all domain schemas) + └── test.ts # Test API schema and DTOs +``` + +## File Responsibilities + +| File | Purpose | +|------|---------| +| `apiTypes.ts` | Core types (`DataRequest`, `DataResponse`, `ApiClient`) and schema utilities | +| `apiPaths.ts` | Template literal types for path resolution (`/items/:id` → `/items/${string}`) | +| `apiErrors.ts` | `ErrorCode` enum, `DataApiError` class, `DataApiErrorFactory`, retryability config | +| `index.ts` | Unified export of infrastructure types (not domain DTOs) | +| `schemas/index.ts` | Composes all domain schemas into `ApiSchemas` using intersection types | +| `schemas/*.ts` | Domain-specific API definitions and DTOs | + +## Import Conventions + +### Infrastructure Types (via barrel export) + +Use the barrel export for common API infrastructure: + +```typescript +import type { + DataRequest, + DataResponse, + ApiClient, + // Pagination types + OffsetPaginationParams, + OffsetPaginationResponse, + CursorPaginationParams, + CursorPaginationResponse, + PaginationResponse, + // Query parameter types + SortParams, + SearchParams +} from '@shared/data/api' + +import { + ErrorCode, + DataApiError, + DataApiErrorFactory, + isDataApiError, + toDataApiError, + // Pagination type guards + isOffsetPaginationResponse, + isCursorPaginationResponse +} from '@shared/data/api' +``` + +### Domain DTOs (directly from schema files) + +Import domain-specific types directly from their schema files: + +```typescript +// Topic domain +import type { Topic, CreateTopicDto, UpdateTopicDto } from '@shared/data/api/schemas/topic' + +// Message domain +import type { Message, CreateMessageDto } from '@shared/data/api/schemas/message' + +// Test domain (development) +import type { TestItem, CreateTestItemDto } from '@shared/data/api/schemas/test' +``` + +## Pagination Types + +The API system supports two pagination modes with composable query parameters. 
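+
+To see what such a composition flattens to, here is a rough sketch (the authoritative definitions live in `apiTypes.ts`; field names follow the tables below, and the exact `sortBy`/`sortOrder` value types are assumptions):
+
+```typescript
+// Offset-paginated, sortable, searchable listing:
+type ListTopicsQuery = OffsetPaginationParams & SortParams & SearchParams
+// => { page?: number; limit?: number; sortBy?: string; sortOrder?: 'asc' | 'desc'; search?: string }
+
+// Cursor-paginated feed with an endpoint-specific filter:
+type ListMessagesQuery = CursorPaginationParams & { topicId: string }
+// => { cursor?: string; limit?: number; topicId: string }
+```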
+ +### Request Parameters + +| Type | Fields | Use Case | +|------|--------|----------| +| `OffsetPaginationParams` | `page?`, `limit?` | Traditional page-based navigation | +| `CursorPaginationParams` | `cursor?`, `limit?` | Infinite scroll, real-time feeds | +| `SortParams` | `sortBy?`, `sortOrder?` | Sorting (combine as needed) | +| `SearchParams` | `search?` | Text search (combine as needed) | + +### Cursor Semantics + +The `cursor` in `CursorPaginationParams` marks an **exclusive boundary** - the cursor item itself is never included in the response. + +**Common patterns:** + +| Pattern | Use Case | Behavior | +|---------|----------|----------| +| "after cursor" | Forward pagination, new items | Returns items AFTER cursor | +| "before cursor" | Backward/historical loading | Returns items BEFORE cursor | + +The specific semantic depends on the API endpoint. For example: +- `GET /topics/:id/messages` uses "before cursor" for loading historical messages +- Other endpoints may use "after cursor" for forward pagination + +**Example: Loading historical messages** + +```typescript +// First request - get most recent messages +const res1 = await api.get('/topics/123/messages', { query: { limit: 20 } }) +// res1: { items: [msg80...msg99], nextCursor: 'msg80-id', activeNodeId: '...' } + +// Load more - get older messages before the cursor +const res2 = await api.get('/topics/123/messages', { + query: { cursor: res1.nextCursor, limit: 20 } +}) +// res2: { items: [msg60...msg79], nextCursor: 'msg60-id', activeNodeId: '...' } +// Note: msg80 is NOT in res2 (cursor is exclusive) +``` + +### Response Types + +| Type | Fields | Description | +|------|--------|-------------| +| `OffsetPaginationResponse` | `items`, `total`, `page` | Page-based results | +| `CursorPaginationResponse` | `items`, `nextCursor?` | Cursor-based results | +| `PaginationResponse` | Union of both | When either mode is acceptable | + +### Usage Examples + +```typescript +// Offset pagination with sort and search +query?: OffsetPaginationParams & SortParams & SearchParams & { + type?: string +} +response: OffsetPaginationResponse + +// Cursor pagination for infinite scroll +query?: CursorPaginationParams & { + userId: string +} +response: CursorPaginationResponse +``` + +### Client-side Calculations + +For `OffsetPaginationResponse`, clients can calculate: +```typescript +const pageCount = Math.ceil(total / limit) +const hasNext = page * limit < total +const hasPrev = page > 1 +``` + +For `CursorPaginationResponse`: +```typescript +const hasNext = nextCursor !== undefined +``` + +## Adding a New Domain Schema + +1. 
Create the schema file (e.g., `schemas/topic.ts`): + +```typescript +import type { + OffsetPaginationParams, + OffsetPaginationResponse, + SearchParams, + SortParams +} from '../apiTypes' + +// Domain models +export interface Topic { + id: string + name: string + createdAt: string +} + +export interface CreateTopicDto { + name: string +} + +// API Schema - validation happens via AssertValidSchemas in index.ts +export interface TopicSchemas { + '/topics': { + GET: { + query?: OffsetPaginationParams & SortParams & SearchParams + response: OffsetPaginationResponse // response is required + } + POST: { + body: CreateTopicDto + response: Topic + } + } + '/topics/:id': { + GET: { + params: { id: string } + response: Topic + } + } +} +``` + +**Validation**: Schemas are validated at composition level via `AssertValidSchemas` in `schemas/index.ts`: +- Ensures only valid HTTP methods (GET, POST, PUT, DELETE, PATCH) +- Requires `response` field for each endpoint +- Invalid schemas cause TypeScript errors at the composition point + +> **Design Guidelines**: Before creating new schemas, review the [API Design Guidelines](./api-design-guidelines.md) for path naming, HTTP methods, and error handling conventions. + +2. Register in `schemas/index.ts`: + +```typescript +import type { TopicSchemas } from './topic' + +// AssertValidSchemas provides fallback validation even if ValidateSchema is forgotten +export type ApiSchemas = AssertValidSchemas +``` + +3. Implement handlers in `src/main/data/api/handlers/` + +## Type Safety Features + +### Path Resolution + +The system uses template literal types to map concrete paths to schema paths: + +```typescript +// Concrete path '/topics/abc123' maps to schema path '/topics/:id' +api.get('/topics/abc123') // TypeScript knows this returns Topic +``` + +### Exhaustive Handler Checking + +`ApiImplementation` type ensures all schema endpoints have handlers: + +```typescript +// TypeScript will error if any endpoint is missing +const handlers: ApiImplementation = { + '/topics': { + GET: async () => { /* ... */ }, + POST: async ({ body }) => { /* ... 
*/ } + } + // Missing '/topics/:id' would cause compile error +} +``` + +### Type-Safe Client + +`ApiClient` provides fully typed methods: + +```typescript +const topic = await api.get('/topics/123') // Returns Topic +const topics = await api.get('/topics', { + query: { page: 1, limit: 20, search: 'hello' } +}) // Returns OffsetPaginationResponse +await api.post('/topics', { body: { name: 'New' } }) // Body is typed as CreateTopicDto +``` + +## Error Handling + +The error system provides type-safe error handling with automatic retryability detection: + +```typescript +import { + DataApiError, + DataApiErrorFactory, + ErrorCode, + isDataApiError, + toDataApiError +} from '@shared/data/api' + +// Create errors using the factory (recommended) +throw DataApiErrorFactory.notFound('Topic', id) +throw DataApiErrorFactory.validation({ name: ['Name is required'] }) +throw DataApiErrorFactory.timeout('fetch topics', 3000) +throw DataApiErrorFactory.database(originalError, 'insert topic') + +// Or create directly with the class +throw new DataApiError( + ErrorCode.NOT_FOUND, + 'Topic not found', + 404, + { resource: 'Topic', id: 'abc123' } +) + +// Check if error is retryable (for automatic retry logic) +if (error instanceof DataApiError && error.isRetryable) { + await retry(operation) +} + +// Check error type +if (error instanceof DataApiError) { + if (error.isClientError) { + // 4xx - issue with the request + } else if (error.isServerError) { + // 5xx - server-side issue + } +} + +// Convert any error to DataApiError +const apiError = toDataApiError(unknownError, 'context') + +// Serialize for IPC (Main → Renderer) +const serialized = apiError.toJSON() + +// Reconstruct from IPC response (Renderer) +const reconstructed = DataApiError.fromJSON(response.error) +``` + +### Retryable Error Codes + +The following errors are automatically considered retryable: +- `SERVICE_UNAVAILABLE` (503) +- `TIMEOUT` (504) +- `RATE_LIMIT_EXCEEDED` (429) +- `DATABASE_ERROR` (500) +- `INTERNAL_SERVER_ERROR` (500) +- `RESOURCE_LOCKED` (423) + +## Architecture Overview + +``` +Renderer Main +──────────────────────────────────────────────────── +DataApiService ──IPC──► IpcAdapter ──► ApiServer + │ │ + │ ▼ + ApiClient MiddlewareEngine + (typed) │ + ▼ + Handlers + (typed) +``` + +- **Renderer**: Uses `DataApiService` with type-safe `ApiClient` interface +- **IPC**: Requests serialized via `IpcAdapter` +- **Main**: `ApiServer` routes to handlers through `MiddlewareEngine` +- **Type Safety**: End-to-end types from client call to handler implementation diff --git a/docs/en/references/data/cache-overview.md b/docs/en/references/data/cache-overview.md new file mode 100644 index 0000000000..ab0f366b27 --- /dev/null +++ b/docs/en/references/data/cache-overview.md @@ -0,0 +1,142 @@ +# Cache System Overview + +The Cache system provides a three-tier caching architecture for temporary and regenerable data across the Cherry Studio application. 
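+
+As a quick taste of how the tiers surface in components (the hooks and keys below are covered in the [Cache Usage Guide](./cache-usage.md)):
+
+```typescript
+// One hook per tier; choose by lifetime, not by data shape.
+const [results, setResults] = useCache('search.results', []) // Memory: per-window, gone on restart
+const [collapsed, setCollapsed] = useSharedCache('ui.sidebar.collapsed', false) // Shared: synced across windows
+const [recent, setRecent] = usePersistCache('app.recent_files', []) // Persist: survives restarts
+```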
+ +## Purpose + +CacheService handles data that: +- Can be **regenerated or lost without user impact** +- Requires no backup or cross-device synchronization +- Has lifecycle tied to component, window, or app session + +## Three-Tier Architecture + +| Tier | Scope | Persistence | Use Case | +|------|-------|-------------|----------| +| **Memory Cache** | Component-level | Lost on app restart | API responses, computed results | +| **Shared Cache** | Cross-window | Lost on app restart | Window state, cross-window coordination | +| **Persist Cache** | Cross-window + localStorage | Survives app restarts | Recent items, non-critical preferences | + +### Memory Cache +- Fastest access, in-process memory +- Isolated per renderer process +- Best for: expensive computations, API response caching + +### Shared Cache +- Synchronized bidirectionally between Main and all Renderer windows via IPC +- Main process maintains authoritative copy and provides initialization sync for new windows +- New windows fetch complete shared cache state from Main on startup +- Best for: window layouts, shared UI state + +### Persist Cache +- Backed by localStorage in renderer +- Main process maintains authoritative copy +- Best for: recent files, search history, non-critical state + +## Key Features + +### TTL (Time To Live) Support +```typescript +// Cache with 30-second expiration +cacheService.set('temp.calculation', result, 30000) +``` + +### Hook Reference Tracking +- Prevents deletion of cache entries while React hooks are subscribed +- Automatic cleanup when components unmount + +### Cross-Window Synchronization +- Shared and Persist caches sync across all windows +- Uses IPC broadcast for real-time updates +- Main process resolves conflicts + +### Type Safety +- **Fixed keys**: Schema-based keys for compile-time checking (e.g., `'app.user.avatar'`) +- **Template keys**: Dynamic patterns with automatic type inference (e.g., `'scroll.position.${id}'` matches `'scroll.position.topic123'`) +- **Casual methods**: For completely dynamic keys with manual typing (blocked from using schema-defined keys) + +Note: Template keys follow the same dot-separated naming pattern as fixed keys. 
When `${xxx}` is treated as a literal string, the key must match the format: `xxx.yyy.zzz_www` + +## Data Categories + +### Performance Cache (Memory tier) +- Computed results from expensive operations +- API response caching +- Parsed/transformed data + +### UI State Cache (Shared tier) +- Sidebar collapsed state +- Panel dimensions +- Scroll positions + +### Non-Critical Persistence (Persist tier) +- Recently used items +- Search history +- User-customized but regenerable data + +## Architecture Diagram + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Renderer Process │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ useCache │ │useSharedCache│ │usePersistCache│ │ +│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │ +│ │ │ │ │ +│ └────────────────┼────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────┐ │ +│ │ CacheService │ │ +│ │ (Renderer) │ │ +│ └──────────┬──────────┘ │ +└─────────────────────────┼────────────────────────────────────┘ + │ IPC (shared/persist only) +┌─────────────────────────┼────────────────────────────────────┐ +│ Main Process ▼ │ +│ ┌─────────────────────┐ │ +│ │ CacheService │ │ +│ │ (Main) │ │ +│ └─────────────────────┘ │ +│ - Source of truth for shared/persist │ +│ - Broadcasts updates to all windows │ +└──────────────────────────────────────────────────────────────┘ +``` + +## Main vs Renderer Responsibilities + +### Main Process CacheService +- Manages internal cache for Main process services +- Maintains authoritative SharedCache with type-safe access (`getShared`, `setShared`, `hasShared`, `deleteShared`) +- Provides `getAllShared()` for new window initialization sync +- Handles IPC requests from renderers and broadcasts updates to all windows +- Manages TTL expiration using absolute timestamps (`expireAt`) for precise cross-window sync + +### Renderer Process CacheService +- Manages local memory cache and SharedCache local copy +- Syncs SharedCache from Main on window initialization (async, non-blocking) +- Provides ready state tracking via `isSharedCacheReady()` and `onSharedCacheReady()` +- Broadcasts cache updates to Main for cross-window sync +- Handles hook subscriptions and updates +- Local TTL management for memory cache + +## Usage Summary + +For detailed code examples and API usage, see [Cache Usage Guide](./cache-usage.md). + +### Key Types + +| Type | Example Schema | Example Usage | Type Inference | +|------|----------------|---------------|----------------| +| Fixed key | `'app.user.avatar': string` | `get('app.user.avatar')` | Automatic | +| Template key | `'scroll.position.${id}': number` | `get('scroll.position.topic123')` | Automatic | +| Casual key | N/A | `getCasual('my.custom.key')` | Manual | + +### API Reference + +| Method | Tier | Key Type | +|--------|------|----------| +| `useCache` / `get` / `set` | Memory | Fixed + Template keys | +| `getCasual` / `setCasual` | Memory | Dynamic keys only (schema keys blocked) | +| `useSharedCache` / `getShared` / `setShared` | Shared | Fixed keys only | +| `getSharedCasual` / `setSharedCasual` | Shared | Dynamic keys only (schema keys blocked) | +| `usePersistCache` / `getPersist` / `setPersist` | Persist | Fixed keys only | diff --git a/docs/en/references/data/cache-usage.md b/docs/en/references/data/cache-usage.md new file mode 100644 index 0000000000..54066e700f --- /dev/null +++ b/docs/en/references/data/cache-usage.md @@ -0,0 +1,424 @@ +# Cache Usage Guide + +This guide covers how to use the Cache system in React components and services. 
+ +## React Hooks + +### useCache (Memory Cache) + +Memory cache is lost on app restart. Best for temporary computed results. + +```typescript +import { useCache } from '@data/hooks/useCache' + +// Basic usage with default value +const [counter, setCounter] = useCache('ui.counter', 0) + +// Update the value +setCounter(counter + 1) + +// With TTL (30 seconds) +const [searchResults, setSearchResults] = useCache('search.results', [], { ttl: 30000 }) +``` + +### useSharedCache (Cross-Window Cache) + +Shared cache syncs across all windows, lost on app restart. + +```typescript +import { useSharedCache } from '@data/hooks/useCache' + +// Cross-window state +const [layout, setLayout] = useSharedCache('window.layout', defaultLayout) + +// Sidebar state shared between windows +const [sidebarCollapsed, setSidebarCollapsed] = useSharedCache('ui.sidebar.collapsed', false) +``` + +### usePersistCache (Persistent Cache) + +Persist cache survives app restarts via localStorage. + +```typescript +import { usePersistCache } from '@data/hooks/useCache' + +// Recent files list (survives restart) +const [recentFiles, setRecentFiles] = usePersistCache('app.recent_files', []) + +// Search history +const [searchHistory, setSearchHistory] = usePersistCache('search.history', []) +``` + +## CacheService Direct Usage + +For non-React code or more control, use CacheService directly. + +### Memory Cache + +```typescript +import { cacheService } from '@data/CacheService' + +// Type-safe (schema key) +cacheService.set('temp.calculation', result) +const result = cacheService.get('temp.calculation') + +// With TTL (30 seconds) +cacheService.set('temp.calculation', result, 30000) + +// Casual (dynamic key, manual type) +cacheService.setCasual(`topic:${id}`, topicData) +const topic = cacheService.getCasual(`topic:${id}`) + +// Check existence +if (cacheService.has('temp.calculation')) { + // ... +} + +// Delete +cacheService.delete('temp.calculation') +cacheService.deleteCasual(`topic:${id}`) +``` + +### Shared Cache + +```typescript +// Type-safe (schema key) +cacheService.setShared('window.layout', layoutConfig) +const layout = cacheService.getShared('window.layout') + +// Casual (dynamic key) +cacheService.setSharedCasual(`window:${windowId}`, state) +const state = cacheService.getSharedCasual(`window:${windowId}`) + +// Delete +cacheService.deleteShared('window.layout') +cacheService.deleteSharedCasual(`window:${windowId}`) +``` + +### Persist Cache + +```typescript +// Schema keys only (no Casual methods for persist) +cacheService.setPersist('app.recent_files', recentFiles) +const files = cacheService.getPersist('app.recent_files') + +// Delete +cacheService.deletePersist('app.recent_files') +``` + +## Main Process Usage + +Main process CacheService provides SharedCache for cross-window state management. + +### SharedCache in Main Process + +```typescript +import { cacheService } from '@main/data/CacheService' + +// Type-safe (schema key) - matches Renderer's type system +cacheService.setShared('window.layout', layoutConfig) +const layout = cacheService.getShared('window.layout') + +// With TTL (30 seconds) +cacheService.setShared('temp.state', state, 30000) + +// Check existence +if (cacheService.hasShared('window.layout')) { + // ... +} + +// Delete +cacheService.deleteShared('window.layout') +``` + +**Note**: Main CacheService does NOT support Casual methods (`getSharedCasual`, etc.). Only schema-based type-safe access is available in Main process. 
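+
+Because a TTL must expire at the same moment in every window, cross-window sync carries an absolute `expireAt` deadline rather than a relative TTL. A minimal sketch of that normalization (the message shape here is an assumption for illustration; see `CacheSyncMessage` in the shared cache types):
+
+```typescript
+// Sketch only: convert the caller's relative TTL into an absolute deadline once,
+// so every window expires the entry at the same wall-clock time.
+interface SyncMessageSketch {
+  key: string
+  value: unknown
+  expireAt?: number // absolute Unix timestamp in ms
+}
+
+function toSyncMessage(key: string, value: unknown, ttlMs?: number): SyncMessageSketch {
+  return { key, value, expireAt: ttlMs === undefined ? undefined : Date.now() + ttlMs }
+}
+```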
+
+### Sync Strategy
+
+- **Renderer → Main**: When Renderer calls `setShared()`, it broadcasts to Main via IPC. Main updates its SharedCache and relays the change to other windows.
+- **Main → Renderer**: When Main calls `setShared()`, it broadcasts to all Renderer windows.
+- **New Window Initialization**: New windows fetch the complete SharedCache state from Main via `getAllShared()`, with a Main-priority override strategy for conflicts.
+
+## Type-Safe vs Casual Methods
+
+### Type-Safe Methods
+- Use predefined keys from cache schema
+- Full auto-completion and type inference
+- Compile-time key validation
+
+```typescript
+// Key 'ui.counter' must exist in schema
+const [counter, setCounter] = useCache('ui.counter', 0)
+```
+
+### Casual Methods
+- Use dynamically constructed keys
+- Require manual type specification via generics
+- No compile-time key validation
+- **Cannot use keys that match schema patterns** (including template keys)
+
+```typescript
+// Dynamic key, must specify the type via the generic parameter
+const topic = cacheService.getCasual<Topic>(`my.custom.key`)
+
+// Compile error: cannot use schema keys with Casual methods
+cacheService.getCasual('app.user.avatar') // Error: matches fixed key
+cacheService.getCasual('scroll.position.topic123') // Error: matches template key
+```
+
+### Template Keys
+
+Template keys provide type-safe caching for dynamic key patterns. Define a template in the schema using `${variable}` syntax, and TypeScript will automatically match and infer types for concrete keys.
+
+**Important**: Template keys follow the same dot-separated naming pattern as fixed keys. When `${xxx}` is treated as a literal string, the key must match the format: `xxx.yyy.zzz_www`
+
+#### Defining Template Keys
+
+```typescript
+// packages/shared/data/cache/cacheSchemas.ts
+export type UseCacheSchema = {
+  // Fixed key
+  'app.user.avatar': string
+
+  // Template keys - use ${variable} for dynamic segments
+  // Must follow dot-separated pattern like fixed keys
+  'scroll.position.${topicId}': number
+  'entity.cache.${type}_${id}': EntityData
+}
+
+// Default values for templates (shared by all instances)
+export const DefaultUseCache: UseCacheSchema = {
+  'app.user.avatar': '',
+  'scroll.position.${topicId}': 0,
+  'entity.cache.${type}_${id}': { loaded: false }
+}
+```
+
+#### Using Template Keys
+
+```typescript
+// TypeScript infers the value type from schema
+const [scrollPos, setScrollPos] = useCache('scroll.position.topic123')
+// scrollPos is inferred as `number`
+
+const [entity, setEntity] = useCache('entity.cache.user_456')
+// entity is inferred as `EntityData`
+
+// Direct CacheService usage
+cacheService.set('scroll.position.mytopic', 150) // OK: value must be number
+cacheService.set('scroll.position.mytopic', 'hi') // Error: type mismatch
+```
+
+#### Template Key Benefits
+
+| Feature | Fixed Keys | Template Keys | Casual Methods |
+|---------|-----------|---------------|----------------|
+| Type inference | ✅ Automatic | ✅ Automatic | ❌ Manual |
+| Auto-completion | ✅ Full | ✅ Partial (prefix) | ❌ None |
+| Compile-time validation | ✅ Yes | ✅ Yes | ❌ No |
+| Dynamic IDs | ❌ No | ✅ Yes | ✅ Yes |
+| Default values | ✅ Yes | ✅ Shared per template | ❌ No |
+
+### When to Use Which
+
+| Scenario | Method | Example |
+|----------|--------|---------|
+| Fixed cache keys | Type-safe | `useCache('ui.counter')` |
+| Dynamic keys with known pattern | Template key | `useCache('scroll.position.topic123')` |
+| Entity caching by ID | Template key | `get('entity.cache.user_456')` |
+| Completely dynamic keys | Casual | `getCasual(\`custom.dynamic.${x}\`)` |
+| UI state | Type-safe | `useSharedCache('window.layout')` |
+
+## Common Patterns
+
+### Caching Expensive Computations
+
+```typescript
+// Assumes a template key 'computed.${input}' is defined in the cache schema
+// (schema keys must be dot-separated; colons are not allowed).
+function useExpensiveData(input: string) {
+  const [cached, setCached] = useCache(`computed.${input}`, null)
+
+  useEffect(() => {
+    if (cached === null) {
+      const result = expensiveComputation(input)
+      setCached(result)
+    }
+  }, [input, cached, setCached])
+
+  return cached
+}
+```
+
+### Cross-Window Coordination
+
+```typescript
+// Window A: Update shared state
+const [activeFile, setActiveFile] = useSharedCache('editor.active_file', null)
+setActiveFile(selectedFile)
+
+// Window B: Reacts to change automatically
+const [activeFile] = useSharedCache('editor.active_file', null)
+// activeFile updates when Window A changes it
+```
+
+### Recent Items with Limit
+
+```typescript
+const [recentItems, setRecentItems] = usePersistCache('app.recent_items', [])
+
+const addRecentItem = (item: Item) => {
+  setRecentItems(prev => {
+    const filtered = prev.filter(i => i.id !== item.id)
+    return [item, ...filtered].slice(0, 10) // Keep last 10
+  })
+}
+```
+
+### Cache with Expiration Check
+
+```typescript
+interface CachedData<T> {
+  data: T
+  timestamp: number
+}
+
+// Assumes `key` matches a template key defined in the cache schema
+function useCachedWithExpiry<T>(key: string, fetcher: () => Promise<T>, maxAge: number) {
+  const [cached, setCached] = useCache<CachedData<T> | null>(key, null)
+  const [data, setData] = useState<T | null>(cached?.data ?? null)
+
+  useEffect(() => {
+    const isExpired = !cached || Date.now() - cached.timestamp > maxAge
+
+    if (isExpired) {
+      fetcher().then(result => {
+        setCached({ data: result, timestamp: Date.now() })
+        setData(result)
+      })
+    }
+  }, [key, maxAge])
+
+  return data
+}
+```
+
+## Adding New Cache Keys
+
+### Adding Fixed Keys
+
+#### 1. Add to Cache Schema
+
+```typescript
+// packages/shared/data/cache/cacheSchemas.ts
+export type UseCacheSchema = {
+  // Existing keys...
+  'my_feature.data': MyDataType
+}
+
+export const DefaultUseCache: UseCacheSchema = {
+  // Existing defaults...
+  'my_feature.data': { items: [], lastUpdated: 0 }
+}
+```
+
+#### 2. Define Value Type (if complex)
+
+```typescript
+// packages/shared/data/cache/cacheValueTypes.ts
+export interface MyDataType {
+  items: string[]
+  lastUpdated: number
+}
+```
+
+#### 3. Use in Code
+
+```typescript
+// Now type-safe
+const [data, setData] = useCache('my_feature.data')
+```
+
+### Adding Template Keys
+
+#### 1. Add Template to Schema
+
+```typescript
+// packages/shared/data/cache/cacheSchemas.ts
+export type UseCacheSchema = {
+  // Existing keys...
+  // Template key with dynamic segment
+  'scroll.position.${topicId}': number
+}
+
+export const DefaultUseCache: UseCacheSchema = {
+  // Existing defaults...
+  // Default shared by all instances of this template
+  'scroll.position.${topicId}': 0
+}
+```
+
+#### 2.
Use in Code + +```typescript +// TypeScript infers number from template pattern +const [scrollPos, setScrollPos] = useCache(`scroll.position.${topicId}`) + +// Works with any string in the dynamic segment +const [pos1, setPos1] = useCache('scroll.position.topic123') +const [pos2, setPos2] = useCache('scroll.position.conversationabc') +``` + +### Key Naming Convention + +All keys (fixed and template) must follow the same naming convention: + +- **Format**: `namespace.sub.key_name` (template `${xxx}` treated as a literal string segment) +- **Rules**: + - Start with lowercase letter + - Use lowercase letters, numbers, and underscores + - Separate segments with dots (`.`) + - Template placeholders `${xxx}` are treated as literal string segments +- **Examples**: + - ✅ `app.user.avatar` + - ✅ `scroll.position.${id}` + - ✅ `entity.cache.${type}_${id}` + - ❌ `scroll.position:${id}` (colon not allowed) + - ❌ `UserAvatar` (no dots) + - ❌ `App.User` (uppercase) + +## Shared Cache Ready State + +Renderer CacheService provides ready state tracking for SharedCache initialization sync. + +```typescript +import { cacheService } from '@data/CacheService' + +// Check if shared cache is ready +if (cacheService.isSharedCacheReady()) { + // SharedCache has been synced from Main +} + +// Register callback when ready +const unsubscribe = cacheService.onSharedCacheReady(() => { + // Called immediately if already ready, or when sync completes + console.log('SharedCache ready!') +}) + +// Cleanup +unsubscribe() +``` + +**Behavior notes**: +- `getShared()` returns `undefined` before ready (expected behavior) +- `setShared()` works immediately and broadcasts to Main (Main updates its cache) +- Hooks like `useSharedCache` work normally - they set initial values and update when sync completes +- Main-priority override: when sync completes, Main's values override local values + +## Best Practices + +1. **Choose the right tier**: Memory for temp, Shared for cross-window, Persist for survival +2. **Use TTL for stale data**: Prevent serving outdated cached values +3. **Prefer type-safe keys**: Add to schema when possible +4. **Use template keys for patterns**: When you have a recurring pattern (e.g., caching by ID), define a template key instead of using casual methods +5. **Reserve casual for truly dynamic keys**: Only use casual methods when the key pattern is completely unknown at development time +6. **Clean up dynamic keys**: Remove casual cache entries when no longer needed +7. **Consider data size**: Persist cache uses localStorage (limited to ~5MB) +8. **Use absolute timestamps for sync**: CacheSyncMessage uses `expireAt` (absolute Unix timestamp) for precise cross-window TTL sync diff --git a/docs/en/references/data/data-api-in-main.md b/docs/en/references/data/data-api-in-main.md new file mode 100644 index 0000000000..90ca93ad0d --- /dev/null +++ b/docs/en/references/data/data-api-in-main.md @@ -0,0 +1,360 @@ +# DataApi in Main Process + +This guide covers how to implement API handlers, services, and repositories in the Main process. 
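+
+At a glance, a single read crosses every layer described below. A sketch (names match the examples later in this guide):
+
+```typescript
+// GET /topics/:id traced through the stack:
+'/topics/:id': {
+  GET: async ({ params }) => {
+    // Handler stays thin and delegates...
+    return await TopicService.getInstance().getById(params.id)
+    // ...the service enforces existence (DataApiErrorFactory.notFound),
+    // and the repository runs the actual Drizzle query against SQLite.
+  }
+}
+```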
+ +## Architecture Layers + +``` +Handlers → Services → Repositories → Database +``` + +- **Handlers**: Thin layer, extract params, call service, transform response +- **Services**: Business logic, validation, transaction coordination +- **Repositories**: Data access (for complex domains) +- **Database**: Drizzle ORM + SQLite + +## Implementing Handlers + +### Location +`src/main/data/api/handlers/` + +### Handler Responsibilities +- Extract parameters from request +- Delegate to business service +- Transform response for IPC +- **NO business logic here** + +### Example Handler + +```typescript +// handlers/topic.ts +import type { ApiImplementation } from '@shared/data/api' +import { TopicService } from '@data/services/TopicService' + +export const topicHandlers: Partial = { + '/topics': { + GET: async ({ query }) => { + const { page = 1, limit = 20 } = query ?? {} + return await TopicService.getInstance().list({ page, limit }) + }, + POST: async ({ body }) => { + return await TopicService.getInstance().create(body) + } + }, + '/topics/:id': { + GET: async ({ params }) => { + return await TopicService.getInstance().getById(params.id) + }, + PUT: async ({ params, body }) => { + return await TopicService.getInstance().replace(params.id, body) + }, + PATCH: async ({ params, body }) => { + return await TopicService.getInstance().update(params.id, body) + }, + DELETE: async ({ params }) => { + await TopicService.getInstance().delete(params.id) + } + } +} +``` + +### Register Handlers + +```typescript +// handlers/index.ts +import { topicHandlers } from './topic' +import { messageHandlers } from './message' + +export const allHandlers: ApiImplementation = { + ...topicHandlers, + ...messageHandlers +} +``` + +## Implementing Services + +### Location +`src/main/data/services/` + +### Service Responsibilities +- Business validation +- Transaction coordination +- Domain workflows +- Call repositories or direct Drizzle + +### Example Service + +```typescript +// services/TopicService.ts +import { DbService } from '@data/db/DbService' +import { TopicRepository } from '@data/repositories/TopicRepository' +import { DataApiErrorFactory } from '@shared/data/api' + +export class TopicService { + private static instance: TopicService + private topicRepo: TopicRepository + + private constructor() { + this.topicRepo = new TopicRepository() + } + + static getInstance(): TopicService { + if (!this.instance) { + this.instance = new TopicService() + } + return this.instance + } + + async list(options: { page: number; limit: number }) { + return await this.topicRepo.findAll(options) + } + + async getById(id: string) { + const topic = await this.topicRepo.findById(id) + if (!topic) { + throw DataApiErrorFactory.notFound('Topic', id) + } + return topic + } + + async create(data: CreateTopicDto) { + // Business validation + this.validateTopicData(data) + + return await this.topicRepo.create(data) + } + + async update(id: string, data: Partial) { + const existing = await this.getById(id) // Throws if not found + + return await this.topicRepo.update(id, data) + } + + async delete(id: string) { + await this.getById(id) // Throws if not found + await this.topicRepo.delete(id) + } + + private validateTopicData(data: CreateTopicDto) { + if (!data.name?.trim()) { + throw DataApiErrorFactory.validation({ name: ['Name is required'] }) + } + } +} +``` + +### Service with Transaction + +```typescript +async createTopicWithMessage(data: CreateTopicWithMessageDto) { + return await DbService.transaction(async (tx) => { + // Create 
topic + const topic = await this.topicRepo.create(data.topic, tx) + + // Create initial message + const message = await this.messageRepo.create({ + ...data.message, + topicId: topic.id + }, tx) + + return { topic, message } + }) +} +``` + +## Implementing Repositories + +### When to Use Repository Pattern + +Use repositories for **complex domains**: +- ✅ Complex queries (joins, subqueries, aggregations) +- ✅ GB-scale data requiring pagination +- ✅ Complex transactions involving multiple tables +- ✅ Reusable data access patterns +- ✅ High testing requirements + +### When to Use Direct Drizzle + +Use direct Drizzle for **simple domains**: +- ✅ Simple CRUD operations +- ✅ Small datasets (< 100MB) +- ✅ Domain-specific queries with no reuse +- ✅ Fast development is priority + +### Example Repository + +```typescript +// repositories/TopicRepository.ts +import { eq, desc, sql } from 'drizzle-orm' +import { DbService } from '@data/db/DbService' +import { topicTable } from '@data/db/schemas/topic' + +export class TopicRepository { + async findAll(options: { page: number; limit: number }) { + const { page, limit } = options + const offset = (page - 1) * limit + + const [items, countResult] = await Promise.all([ + DbService.db + .select() + .from(topicTable) + .orderBy(desc(topicTable.updatedAt)) + .limit(limit) + .offset(offset), + DbService.db + .select({ count: sql`count(*)` }) + .from(topicTable) + ]) + + return { + items, + total: countResult[0].count, + page, + limit + } + } + + async findById(id: string, tx?: Transaction) { + const db = tx || DbService.db + const [topic] = await db + .select() + .from(topicTable) + .where(eq(topicTable.id, id)) + .limit(1) + return topic ?? null + } + + async create(data: CreateTopicDto, tx?: Transaction) { + const db = tx || DbService.db + const [topic] = await db + .insert(topicTable) + .values(data) + .returning() + return topic + } + + async update(id: string, data: Partial, tx?: Transaction) { + const db = tx || DbService.db + const [topic] = await db + .update(topicTable) + .set(data) + .where(eq(topicTable.id, id)) + .returning() + return topic + } + + async delete(id: string, tx?: Transaction) { + const db = tx || DbService.db + await db + .delete(topicTable) + .where(eq(topicTable.id, id)) + } +} +``` + +### Example: Direct Drizzle in Service + +For simple domains, skip the repository: + +```typescript +// services/TagService.ts +import { eq } from 'drizzle-orm' +import { DbService } from '@data/db/DbService' +import { tagTable } from '@data/db/schemas/tag' + +export class TagService { + async getAll() { + return await DbService.db.select().from(tagTable) + } + + async create(name: string) { + const [tag] = await DbService.db + .insert(tagTable) + .values({ name }) + .returning() + return tag + } + + async delete(id: string) { + await DbService.db + .delete(tagTable) + .where(eq(tagTable.id, id)) + } +} +``` + +## Error Handling + +### Using DataApiErrorFactory + +```typescript +import { DataApiErrorFactory } from '@shared/data/api' + +// Not found +throw DataApiErrorFactory.notFound('Topic', id) + +// Validation error +throw DataApiErrorFactory.validation({ + name: ['Name is required', 'Name must be at least 3 characters'], + email: ['Invalid email format'] +}) + +// Database error +try { + await db.insert(table).values(data) +} catch (error) { + throw DataApiErrorFactory.database(error, 'insert topic') +} + +// Invalid operation +throw DataApiErrorFactory.invalidOperation( + 'delete root message', + 'cascade=true required' +) + +// Conflict +throw 
DataApiErrorFactory.conflict('Topic name already exists') + +// Timeout +throw DataApiErrorFactory.timeout('fetch topics', 3000) +``` + +## Adding New Endpoints + +### Step-by-Step + +1. **Define schema** in `packages/shared/data/api/schemas/` + +```typescript +// schemas/topic.ts +export interface TopicSchemas { + '/topics': { + GET: { response: PaginatedResponse } + POST: { body: CreateTopicDto; response: Topic } + } +} +``` + +2. **Register schema** in `schemas/index.ts` + +```typescript +export type ApiSchemas = AssertValidSchemas +``` + +3. **Create service** in `services/` + +4. **Create repository** (if complex) in `repositories/` + +5. **Implement handler** in `handlers/` + +6. **Register handler** in `handlers/index.ts` + +## Best Practices + +1. **Keep handlers thin**: Only extract params and call services +2. **Put logic in services**: All business rules belong in services +3. **Use repositories selectively**: Simple CRUD doesn't need a repository +4. **Always use `.returning()`**: Get inserted/updated data without re-querying +5. **Support transactions**: Accept optional `tx` parameter in repositories +6. **Validate in services**: Business validation belongs in the service layer +7. **Use error factory**: Consistent error creation with `DataApiErrorFactory` diff --git a/docs/en/references/data/data-api-in-renderer.md b/docs/en/references/data/data-api-in-renderer.md new file mode 100644 index 0000000000..8d4b9f07f6 --- /dev/null +++ b/docs/en/references/data/data-api-in-renderer.md @@ -0,0 +1,314 @@ +# DataApi in Renderer + +This guide covers how to use the DataApi system in React components and the renderer process. + +## React Hooks + +### useQuery (GET Requests) + +Fetch data with automatic caching and revalidation via SWR. + +```typescript +import { useQuery } from '@data/hooks/useDataApi' + +// Basic usage +const { data, isLoading, error } = useQuery('/topics') + +// With query parameters +const { data: messages } = useQuery('/messages', { + query: { topicId: 'abc123', page: 1, limit: 20 } +}) + +// With path parameters (inferred from path) +const { data: topic } = useQuery('/topics/abc123') + +// Conditional fetching +const { data } = useQuery('/topics', { enabled: !!topicId }) + +// With refresh callback +const { data, mutate, refetch } = useQuery('/topics') +// Refresh data +refetch() // or await mutate() +``` + +### useMutation (POST/PUT/PATCH/DELETE) + +Perform data modifications with loading states. + +```typescript +import { useMutation } from '@data/hooks/useDataApi' + +// Create (POST) +const { trigger: createTopic, isLoading } = useMutation('POST', '/topics') +const newTopic = await createTopic({ body: { name: 'New Topic' } }) + +// Update (PUT - full replacement) +const { trigger: replaceTopic } = useMutation('PUT', '/topics/abc123') +await replaceTopic({ body: { name: 'Updated Name', description: '...' } }) + +// Partial Update (PATCH) +const { trigger: updateTopic } = useMutation('PATCH', '/topics/abc123') +await updateTopic({ body: { name: 'New Name' } }) + +// Delete +const { trigger: deleteTopic } = useMutation('DELETE', '/topics/abc123') +await deleteTopic() + +// With auto-refresh of other queries +const { trigger } = useMutation('POST', '/topics', { + refresh: ['/topics'], // Refresh these keys on success + onSuccess: (data) => console.log('Created:', data) +}) +``` + +### useInfiniteQuery (Cursor-based Infinite Scroll) + +For infinite scroll UIs with "Load More" pattern. 
+
+```typescript
+import { useInfiniteQuery } from '@data/hooks/useDataApi'
+
+const { items, isLoading, hasNext, loadNext } = useInfiniteQuery('/messages', {
+  query: { topicId: 'abc123' },
+  limit: 20
+})
+
+// items: all loaded items flattened
+// loadNext(): load next page
+// hasNext: true if more pages available
+```
+
+### usePaginatedQuery (Offset-based Pagination)
+
+For page-by-page navigation with previous/next controls.
+
+```typescript
+import { usePaginatedQuery } from '@data/hooks/useDataApi'
+
+const { items, page, total, hasNext, hasPrev, nextPage, prevPage } =
+  usePaginatedQuery('/topics', { limit: 10 })
+
+// items: current page items
+// page/total: current page number and total count
+// nextPage()/prevPage(): navigate between pages
+```
+
+### Choosing Pagination Hooks
+
+| Use Case | Hook |
+|----------|------|
+| Infinite scroll, chat, feeds | `useInfiniteQuery` |
+| Page navigation, tables | `usePaginatedQuery` |
+| Manual control | `useQuery` |
+
+## DataApiService Direct Usage
+
+For non-React code or when you need more control.
+
+```typescript
+import { dataApiService } from '@data/DataApiService'
+
+// GET request
+const topics = await dataApiService.get('/topics')
+const topic = await dataApiService.get('/topics/abc123')
+const messages = await dataApiService.get('/topics/abc123/messages', {
+  query: { page: 1, limit: 20 }
+})
+
+// POST request
+const newTopic = await dataApiService.post('/topics', {
+  body: { name: 'New Topic' }
+})
+
+// PUT request (full replacement)
+const updatedTopic = await dataApiService.put('/topics/abc123', {
+  body: { name: 'Updated', description: 'Full update' }
+})
+
+// PATCH request (partial update)
+const patchedTopic = await dataApiService.patch('/topics/abc123', {
+  body: { name: 'Just update name' }
+})
+
+// DELETE request
+await dataApiService.delete('/topics/abc123')
+```
+
+## Error Handling
+
+### With Hooks
+
+```typescript
+import { ErrorCode } from '@shared/data/api'
+
+// Spinner, EmptyState, ErrorBanner, and TopicListView are placeholder
+// components; substitute your own UI.
+function TopicList() {
+  const { data, isLoading, error } = useQuery('/topics')
+
+  if (isLoading) return <Spinner />
+  if (error) {
+    if (error.code === ErrorCode.NOT_FOUND) {
+      return <EmptyState />
+    }
+    return <ErrorBanner error={error} />
+  }
+
+  return <TopicListView topics={data.items} />
+}
+```
+
+### With Try-Catch
+
+```typescript
+import { DataApiError, ErrorCode } from '@shared/data/api'
+
+try {
+  await dataApiService.post('/topics', { body: data })
+} catch (error) {
+  if (error instanceof DataApiError) {
+    switch (error.code) {
+      case ErrorCode.VALIDATION_ERROR: {
+        // Handle validation errors
+        const fieldErrors = error.details?.fieldErrors
+        break
+      }
+      case ErrorCode.NOT_FOUND:
+        // Handle not found
+        break
+      case ErrorCode.CONFLICT:
+        // Handle conflict
+        break
+      default:
+        // Handle other errors
+    }
+  }
+}
+```
+
+### Retryable Errors
+
+```typescript
+if (error instanceof DataApiError && error.isRetryable) {
+  // Safe to retry: SERVICE_UNAVAILABLE, TIMEOUT, etc.
+  await retry(operation)
+}
+```
+
+## Common Patterns
+
+### Create Form
+
+```typescript
+function CreateTopicForm() {
+  // Use refresh option to auto-refresh /topics after creation
+  const { trigger: createTopic, isLoading } = useMutation('POST', '/topics', {
+    refresh: ['/topics']
+  })
+
+  const handleSubmit = async (data: CreateTopicDto) => {
+    try {
+      await createTopic({ body: data })
+      toast.success('Topic created')
+    } catch (error) {
+      toast.error('Failed to create topic')
+    }
+  }
+
+  return (
+    <form onSubmit={handleSubmit}>
+      {/* form fields */}
+      <button type="submit" disabled={isLoading}>
+        Create
+      </button>
+    </form>
+  )
+}
+```
+
+### Optimistic Updates
+
+```typescript
+function TopicItem({ topic }: { topic: Topic }) {
+  // Use optimisticData for automatic optimistic updates with rollback
+  const { trigger: updateTopic } = useMutation('PATCH', `/topics/${topic.id}`, {
+    optimisticData: { ...topic, starred: !topic.starred }
+  })
+
+  const handleToggleStar = async () => {
+    try {
+      await updateTopic({ body: { starred: !topic.starred } })
+    } catch (error) {
+      // Rollback happens automatically when optimisticData is set
+      toast.error('Failed to update')
+    }
+  }
+
+  return (
+    <div>
+      <span>{topic.name}</span>
+      <button onClick={handleToggleStar}>{topic.starred ? 'Unstar' : 'Star'}</button>
+    </div>
+  )
+}
+```
+
+### Dependent Queries
+
+```typescript
+function MessageList({ topicId }: { topicId: string }) {
+  // First query: get topic
+  const { data: topic } = useQuery(`/topics/${topicId}`)
+
+  // Second query: depends on first (only runs when topic exists)
+  const { data: messages } = useQuery(
+    topic ? `/topics/${topicId}/messages` : null
+  )
+
+  if (!topic) return null
+
+  return (
+    <div>
+      <h2>{topic.name}</h2>
+      {/* Placeholder rendering; swap in your message item component */}
+      {messages?.items.map((m) => (
+        <div key={m.id}>{m.content}</div>
+      ))}
+    </div>
+ ) +} +``` + +### Polling for Updates + +```typescript +function LiveTopicList() { + const { data } = useQuery('/topics', { + refreshInterval: 5000 // Poll every 5 seconds + }) + + return +} +``` + +## Type Safety + +The API is fully typed based on schema definitions: + +```typescript +// Types are inferred from schema +const { data } = useQuery('/topics') +// data is typed as PaginatedResponse + +const { trigger } = useMutation('POST', '/topics') +// trigger expects { body: CreateTopicDto } +// returns Topic + +// Path parameters are type-checked +const { data: topic } = useQuery('/topics/abc123') +// TypeScript knows this returns Topic +``` + +## Best Practices + +1. **Use hooks for components**: `useQuery` and `useMutation` handle loading/error states +2. **Choose the right pagination hook**: Use `useInfiniteQuery` for infinite scroll, `usePaginatedQuery` for page navigation +3. **Handle loading states**: Always show feedback while data is loading +4. **Handle errors gracefully**: Provide meaningful error messages to users +5. **Revalidate after mutations**: Use `refresh` option to keep the UI in sync +6. **Use conditional fetching**: Set `enabled: false` to skip queries when dependencies aren't ready +7. **Batch related operations**: Consider using transactions for multiple updates diff --git a/docs/en/references/data/data-api-overview.md b/docs/en/references/data/data-api-overview.md new file mode 100644 index 0000000000..883bd57f49 --- /dev/null +++ b/docs/en/references/data/data-api-overview.md @@ -0,0 +1,158 @@ +# DataApi System Overview + +The DataApi system provides type-safe IPC communication for business data operations between the Renderer and Main processes. + +## Purpose + +DataApiService handles data that: +- Is **business data accumulated through user activity** +- Has **dedicated database schemas/tables** +- Users can **create, delete, modify records** without fixed limits +- Would be **severe and irreplaceable** if lost +- Can grow to **large volumes** (potentially GBs) + +## Key Characteristics + +### Type-Safe Communication +- End-to-end TypeScript types from client call to handler +- Path parameter inference from route definitions +- Compile-time validation of request/response shapes + +### RESTful-Style API +- Familiar HTTP semantics (GET, POST, PUT, PATCH, DELETE) +- Resource-based URL patterns (`/topics/:id/messages`) +- Standard status codes and error responses + +### On-Demand Data Access +- No automatic caching (fetch fresh data when needed) +- Explicit cache control via query options +- Supports large datasets with pagination + +## Architecture Diagram + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Renderer Process │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ React Components │ │ +│ │ - useQuery('/topics') │ │ +│ │ - useMutation('/topics', 'POST') │ │ +│ └──────────────────────────┬──────────────────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ DataApiService (Renderer) │ │ +│ │ - Type-safe ApiClient interface │ │ +│ │ - Request serialization │ │ +│ │ - Automatic retry with exponential backoff │ │ +│ │ - Error handling and transformation │ │ +│ └──────────────────────────┬──────────────────────────────┘ │ +└──────────────────────────────┼───────────────────────────────────┘ + │ IPC +┌──────────────────────────────┼───────────────────────────────────┐ +│ Main Process ▼ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ IpcAdapter │ │ +│ 
│ - Receives IPC requests │ │ +│ │ - Routes to ApiServer │ │ +│ └──────────────────────────┬──────────────────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ ApiServer │ │ +│ │ - Request routing by path and method │ │ +│ │ - Middleware pipeline processing │ │ +│ └──────────────────────────┬──────────────────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ Handlers (api/handlers/) │ │ +│ │ - Thin layer: extract params, call service, transform │ │ +│ │ - NO business logic here │ │ +│ └──────────────────────────┬──────────────────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ Services (services/) │ │ +│ │ - Business logic and validation │ │ +│ │ - Transaction coordination │ │ +│ │ - Domain workflows │ │ +│ └──────────────────────────┬──────────────────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────┴─────────────────────┐ │ +│ ▼ ▼ │ +│ ┌───────────────┐ ┌─────────────────────┐ │ +│ │ Repositories │ │ Direct Drizzle │ │ +│ │ (Complex) │ │ (Simple domains) │ │ +│ │ - Query logic │ │ - Inline queries │ │ +│ └───────┬───────┘ └──────────┬──────────┘ │ +│ │ │ │ +│ └────────────────────┬───────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ SQLite Database (via Drizzle ORM) │ │ +│ │ - topic, message, file tables │ │ +│ │ - Full-text search indexes │ │ +│ └─────────────────────────────────────────────────────────┘ │ +└──────────────────────────────────────────────────────────────────┘ +``` + +## Four-Layer Architecture + +### 1. API Layer (Handlers) +- **Location**: `src/main/data/api/handlers/` +- **Responsibility**: HTTP-like interface layer +- **Does**: Extract parameters, call services, transform responses +- **Does NOT**: Contain business logic + +### 2. Business Logic Layer (Services) +- **Location**: `src/main/data/services/` +- **Responsibility**: Domain logic and workflows +- **Does**: Validation, transaction coordination, orchestration +- **Uses**: Repositories or direct Drizzle queries + +### 3. Data Access Layer (Repositories) +- **Location**: `src/main/data/repositories/` +- **Responsibility**: Complex data operations +- **When to use**: Complex queries, large datasets, reusable patterns +- **Alternative**: Direct Drizzle for simple CRUD + +### 4. 
Database Layer +- **Location**: `src/main/data/db/` +- **Technology**: SQLite + Drizzle ORM +- **Schemas**: `db/schemas/` directory + +## Data Access Pattern Decision + +### Use Repository Pattern When: +- ✅ Complex queries (joins, subqueries, aggregations) +- ✅ GB-scale data requiring optimization and pagination +- ✅ Complex transactions involving multiple tables +- ✅ Reusable data access patterns across services +- ✅ High testing requirements (mock data access) + +### Use Direct Drizzle When: +- ✅ Simple CRUD operations +- ✅ Small datasets (< 100MB) +- ✅ Domain-specific queries with no reuse potential +- ✅ Fast development is priority + +## Key Features + +### Automatic Retry +- Exponential backoff for transient failures +- Configurable retry count and delays +- Skips retry for client errors (4xx) + +### Error Handling +- Typed error codes (`ErrorCode` enum) +- `DataApiError` class with retryability detection +- Factory methods for consistent error creation + +### Request Timeout +- Configurable per-request timeouts +- Automatic cancellation of stale requests + +## Usage Summary + +For detailed code examples, see: +- [DataApi in Renderer](./data-api-in-renderer.md) - Client-side usage +- [DataApi in Main](./data-api-in-main.md) - Server-side implementation +- [API Design Guidelines](./api-design-guidelines.md) - RESTful conventions +- [API Types](./api-types.md) - Type system details diff --git a/docs/en/references/data/database-patterns.md b/docs/en/references/data/database-patterns.md new file mode 100644 index 0000000000..c4745832b9 --- /dev/null +++ b/docs/en/references/data/database-patterns.md @@ -0,0 +1,207 @@ +# Database Schema Guidelines + +## Naming Conventions + +- **Table names**: Use **singular** form with snake_case (e.g., `topic`, `message`, `app_state`) +- **Export names**: Use `xxxTable` pattern (e.g., `topicTable`, `messageTable`) +- **Column names**: Drizzle auto-infers from property names, no need to specify explicitly + +## Column Helpers + +All helpers are exported from `./schemas/columnHelpers.ts`. + +### Primary Keys + +| Helper | UUID Version | Use Case | +|--------|--------------|----------| +| `uuidPrimaryKey()` | v4 (random) | General purpose tables | +| `uuidPrimaryKeyOrdered()` | v7 (time-ordered) | Large tables with time-based queries | + +**Usage:** + +```typescript +import { uuidPrimaryKey, uuidPrimaryKeyOrdered } from './columnHelpers' + +// General purpose table +export const topicTable = sqliteTable('topic', { + id: uuidPrimaryKey(), + name: text(), + ... +}) + +// Large table with time-ordered data +export const messageTable = sqliteTable('message', { + id: uuidPrimaryKeyOrdered(), + content: text(), + ... 
+}) +``` + +**Behavior:** + +- ID is auto-generated if not provided during insert +- Can be manually specified for migration scenarios +- Use `.returning()` to get the generated ID after insert + +### Timestamps + +| Helper | Fields | Use Case | +|--------|--------|----------| +| `createUpdateTimestamps` | `createdAt`, `updatedAt` | Tables without soft delete | +| `createUpdateDeleteTimestamps` | `createdAt`, `updatedAt`, `deletedAt` | Tables with soft delete | + +**Usage:** + +```typescript +import { createUpdateTimestamps, createUpdateDeleteTimestamps } from './columnHelpers' + +// Without soft delete +export const tagTable = sqliteTable('tag', { + id: uuidPrimaryKey(), + name: text(), + ...createUpdateTimestamps +}) + +// With soft delete +export const topicTable = sqliteTable('topic', { + id: uuidPrimaryKey(), + name: text(), + ...createUpdateDeleteTimestamps +}) +``` + +**Behavior:** + +- `createdAt`: Auto-set to `Date.now()` on insert +- `updatedAt`: Auto-set on insert, auto-updated on update +- `deletedAt`: `null` by default, set to timestamp for soft delete + +## JSON Fields + +For JSON column support, use `{ mode: 'json' }`: + +```typescript +data: text({ mode: 'json' }).$type() +``` + +Drizzle handles JSON serialization/deserialization automatically. + +## Foreign Keys + +### Basic Usage + +```typescript +// SET NULL: preserve record when referenced record is deleted +groupId: text().references(() => groupTable.id, { onDelete: 'set null' }) + +// CASCADE: delete record when referenced record is deleted +topicId: text().references(() => topicTable.id, { onDelete: 'cascade' }) +``` + +### Self-Referencing Foreign Keys + +For self-referencing foreign keys (e.g., tree structures with parentId), **always use the `foreignKey` operator** in the table's third parameter: + +```typescript +import { foreignKey, sqliteTable, text } from 'drizzle-orm/sqlite-core' + +export const messageTable = sqliteTable( + 'message', + { + id: uuidPrimaryKeyOrdered(), + parentId: text(), // Do NOT use .references() here + // ...other fields + }, + (t) => [ + // Use foreignKey operator for self-referencing + foreignKey({ columns: [t.parentId], foreignColumns: [t.id] }).onDelete('set null') + ] +) +``` + +**Why this approach:** +- Avoids TypeScript circular reference issues (no need for `AnySQLiteColumn` type annotation) +- More explicit and readable +- Allows chaining `.onDelete()` / `.onUpdate()` actions + +### Circular Foreign Key References + +**Avoid circular foreign key references between tables.** For example: + +```typescript +// ❌ BAD: Circular FK between tables +// tableA.currentItemId -> tableB.id +// tableB.ownerId -> tableA.id +``` + +If you encounter a scenario that seems to require circular references: + +1. **Identify which relationship is "weaker"** - typically the one that can be null or is less critical for data integrity +2. **Remove the FK constraint from the weaker side** - let the application layer handle validation and consistency (this is known as "soft references" pattern) +3. 
**Document the application-layer constraint** in code comments + +```typescript +// ✅ GOOD: Break the cycle by handling one side at application layer +export const topicTable = sqliteTable('topic', { + id: uuidPrimaryKey(), + // Application-managed reference (no FK constraint) + // Validated by TopicService.setCurrentMessage() + currentMessageId: text(), +}) + +export const messageTable = sqliteTable('message', { + id: uuidPrimaryKeyOrdered(), + // Database-enforced FK + topicId: text().references(() => topicTable.id, { onDelete: 'cascade' }), +}) +``` + +**Why soft references for SQLite:** +- SQLite does not support `DEFERRABLE` constraints (unlike PostgreSQL/Oracle) +- Application-layer validation provides equivalent data integrity +- Simplifies insert/update operations without transaction ordering concerns + +## Migrations + +Generate migrations after schema changes: + +```bash +yarn db:migrations:generate +``` + +## Field Generation Rules + +The schema uses Drizzle's auto-generation features. Follow these rules: + +### Auto-generated fields (NEVER set manually) + +- `id`: Uses `$defaultFn()` with UUID v4/v7, auto-generated on insert +- `createdAt`: Uses `$defaultFn()` with `Date.now()`, auto-generated on insert +- `updatedAt`: Uses `$defaultFn()` and `$onUpdateFn()`, auto-updated on every update + +### Using `.returning()` pattern + +Always use `.returning()` to get inserted/updated data instead of re-querying: + +```typescript +// Good: Use returning() +const [row] = await db.insert(table).values(data).returning() +return rowToEntity(row) + +// Avoid: Re-query after insert (unnecessary database round-trip) +await db.insert(table).values({ id, ...data }) +return this.getById(id) +``` + +### Soft delete support + +The schema supports soft delete via `deletedAt` field (see `createUpdateDeleteTimestamps`). +Business logic can choose to use soft delete or hard delete based on requirements. + +## Custom SQL + +Drizzle cannot manage triggers and virtual tables (e.g., FTS5). These are defined in `customSql.ts` and run automatically after every migration. + +**Why**: SQLite's `DROP TABLE` removes associated triggers. When Drizzle modifies a table schema, it drops and recreates the table, losing triggers in the process. + +**Adding new custom SQL**: Define statements as `string[]` in the relevant schema file, then spread into `CUSTOM_SQL_STATEMENTS` in `customSql.ts`. All statements must use `IF NOT EXISTS` to be idempotent. diff --git a/docs/en/references/data/preference-overview.md b/docs/en/references/data/preference-overview.md new file mode 100644 index 0000000000..755571c659 --- /dev/null +++ b/docs/en/references/data/preference-overview.md @@ -0,0 +1,144 @@ +# Preference System Overview + +The Preference system provides centralized management for user configuration and application settings with cross-window synchronization. 
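+
+In application code the system surfaces as a small read/write API. A sketch (the `set()` call is shown under Update Strategies below; the import path and `get()` signature are assumptions):
+
+```typescript
+import { preferenceService } from '@data/PreferenceService' // path assumed
+
+// Write: optimistic by default, persisted and broadcast to every window
+await preferenceService.set('app.theme.mode', 'dark')
+
+// Read: served from the renderer's local cache
+const mode = preferenceService.get('app.theme.mode')
+```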
+ +## Purpose + +PreferenceService handles data that: +- Is a **user-modifiable setting that affects app behavior** +- Has a **fixed key structure** with stable value types +- Needs to **persist permanently** until explicitly changed +- Should **sync automatically** across all application windows + +## Key Characteristics + +### Fixed Key Structure +- Predefined keys in the schema (users modify values, not keys) +- Supports 158 configuration items +- Nested key paths supported (e.g., `app.theme.mode`) + +### Atomic Values +- Each preference item represents one logical setting +- Values are typically: boolean, string, number, or simple array/object +- Changes are independent (updating one doesn't affect others) + +### Cross-Window Synchronization +- Changes automatically broadcast to all windows +- Consistent state across main window, mini window, etc. +- Conflict resolution handled by Main process + +## Update Strategies + +### Optimistic Updates (Default) +```typescript +// UI updates immediately, then syncs to database +await preferenceService.set('app.theme.mode', 'dark') +``` +- Best for: frequent, non-critical settings +- Behavior: Local state updates first, then persists +- Rollback: Automatic revert if persistence fails + +### Pessimistic Updates +```typescript +// Waits for database confirmation before updating UI +await preferenceService.set('api.key', 'secret', { optimistic: false }) +``` +- Best for: critical settings (API keys, security options) +- Behavior: Persists first, then updates local state +- No rollback needed: UI only updates on success + +## Architecture Diagram + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Renderer Process │ +│ ┌─────────────────────────────────────────────────┐ │ +│ │ usePreference Hook │ │ +│ │ - Subscribe to preference changes │ │ +│ │ - Optimistic/pessimistic update support │ │ +│ └──────────────────────┬──────────────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────┐ │ +│ │ PreferenceService (Renderer) │ │ +│ │ - Local cache for fast reads │ │ +│ │ - IPC proxy to Main process │ │ +│ │ - Subscription management │ │ +│ └──────────────────────┬──────────────────────────┘ │ +└─────────────────────────┼────────────────────────────────────┘ + │ IPC +┌─────────────────────────┼────────────────────────────────────┐ +│ Main Process ▼ │ +│ ┌─────────────────────────────────────────────────┐ │ +│ │ PreferenceService (Main) │ │ +│ │ - Full memory cache of all preferences │ │ +│ │ - SQLite persistence via Drizzle ORM │ │ +│ │ - Cross-window broadcast │ │ +│ └──────────────────────┬──────────────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────┐ │ +│ │ SQLite Database (preference table) │ │ +│ │ - scope + key structure │ │ +│ │ - JSON value storage │ │ +│ └─────────────────────────────────────────────────┘ │ +└──────────────────────────────────────────────────────────────┘ +``` + +## Main vs Renderer Responsibilities + +### Main Process PreferenceService +- **Source of truth** for all preferences +- Full memory cache for fast access +- SQLite persistence via preference table +- Broadcasts changes to all renderer windows +- Handles batch operations and transactions + +### Renderer Process PreferenceService +- Local cache for read performance +- Proxies write operations to Main +- Manages React hook subscriptions +- Handles optimistic update rollbacks +- Listens for cross-window updates + +## Database Schema + +Preferences are stored in the `preference` table: + +```typescript 
+// Simplified schema +{ + scope: string // e.g., 'default', 'user' + key: string // e.g., 'app.theme.mode' + value: json // The preference value + createdAt: number + updatedAt: number +} +``` + +## Preference Categories + +### Application Settings +- Theme mode, language, font sizes +- Window behavior, startup options + +### Feature Toggles +- Show/hide UI elements +- Enable/disable features + +### User Customization +- Keyboard shortcuts +- Default values for operations + +### Provider Configuration +- AI provider settings +- API endpoints and tokens + +## Usage Summary + +For detailed code examples and API usage, see [Preference Usage Guide](./preference-usage.md). + +| Operation | Hook | Service Method | +|-----------|------|----------------| +| Read single | `usePreference(key)` | `preferenceService.get(key)` | +| Write single | `setPreference(value)` | `preferenceService.set(key, value)` | +| Read multiple | `usePreferences([...keys])` | `preferenceService.getMultiple([...keys])` | +| Write multiple | - | `preferenceService.setMultiple({...})` | diff --git a/docs/en/references/data/preference-usage.md b/docs/en/references/data/preference-usage.md new file mode 100644 index 0000000000..70a0586724 --- /dev/null +++ b/docs/en/references/data/preference-usage.md @@ -0,0 +1,260 @@ +# Preference Usage Guide + +This guide covers how to use the Preference system in React components and services. + +## React Hooks + +### usePreference (Single Preference) + +```typescript +import { usePreference } from '@data/hooks/usePreference' + +// Basic usage - optimistic updates (default) +const [theme, setTheme] = usePreference('app.theme.mode') + +// Update the value +await setTheme('dark') + +// With pessimistic updates (wait for confirmation) +const [apiKey, setApiKey] = usePreference('api.key', { optimistic: false }) +``` + +### usePreferences (Multiple Preferences) + +```typescript +import { usePreferences } from '@data/hooks/usePreference' + +// Read multiple preferences at once +const { theme, language, fontSize } = usePreferences([ + 'app.theme.mode', + 'app.language', + 'chat.message.font_size' +]) +``` + +## Update Strategies + +### Optimistic Updates (Default) + +UI updates immediately, then syncs to database. Automatic rollback on failure. + +```typescript +const [theme, setTheme] = usePreference('app.theme.mode') + +const handleThemeChange = async (newTheme: string) => { + try { + await setTheme(newTheme) // UI updates immediately + } catch (error) { + // UI automatically rolls back + console.error('Theme update failed:', error) + } +} +``` + +**Best for:** +- Frequent changes (theme, font size) +- Non-critical settings +- Better perceived performance + +### Pessimistic Updates + +Waits for database confirmation before updating UI. + +```typescript +const [apiKey, setApiKey] = usePreference('api.key', { optimistic: false }) + +const handleApiKeyChange = async (newKey: string) => { + try { + await setApiKey(newKey) // Waits for DB confirmation + toast.success('API key saved') + } catch (error) { + toast.error('Failed to save API key') + } +} +``` + +**Best for:** +- Security-sensitive settings (API keys, passwords) +- Settings that affect external services +- When confirmation feedback is important + +## PreferenceService Direct Usage + +For non-React code or batch operations. 
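+
+For instance, a minimal sketch of a service module kept in sync outside React (using the `get`/`subscribe` methods documented below; the key `notifications.enabled` is illustrative, not a real schema key):
+
+```typescript
+import { preferenceService } from '@data/PreferenceService'
+
+class NotificationSettings {
+  private enabled = true
+
+  async init() {
+    // Seed from the persisted value, falling back to a default
+    this.enabled = (await preferenceService.get('notifications.enabled')) ?? true
+    // Stay in sync with changes made in any window
+    preferenceService.subscribe('notifications.enabled', (value) => {
+      this.enabled = value
+    })
+  }
+}
+```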
+ +### Get Preferences + +```typescript +import { preferenceService } from '@data/PreferenceService' + +// Get single preference +const theme = await preferenceService.get('app.theme.mode') + +// Get multiple preferences +const settings = await preferenceService.getMultiple([ + 'app.theme.mode', + 'app.language' +]) +// Returns: { 'app.theme.mode': 'dark', 'app.language': 'en' } + +// Get with default value +const fontSize = await preferenceService.get('chat.message.font_size') ?? 14 +``` + +### Set Preferences + +```typescript +// Set single preference (optimistic by default) +await preferenceService.set('app.theme.mode', 'dark') + +// Set with pessimistic update +await preferenceService.set('api.key', 'secret', { optimistic: false }) + +// Set multiple preferences at once +await preferenceService.setMultiple({ + 'app.theme.mode': 'dark', + 'app.language': 'en', + 'chat.message.font_size': 16 +}) +``` + +### Subscribe to Changes + +```typescript +// Subscribe to preference changes (useful in services) +const unsubscribe = preferenceService.subscribe('app.theme.mode', (newValue) => { + console.log('Theme changed to:', newValue) +}) + +// Cleanup when done +unsubscribe() +``` + +## Common Patterns + +### Settings Form + +```typescript +function SettingsForm() { + const [theme, setTheme] = usePreference('app.theme.mode') + const [language, setLanguage] = usePreference('app.language') + const [fontSize, setFontSize] = usePreference('chat.message.font_size') + + return ( +
+    <div>
+      {/* option lists below are illustrative */}
+      <select value={theme} onChange={(e) => setTheme(e.target.value)}>
+        <option value="light">light</option>
+        <option value="dark">dark</option>
+        <option value="system">system</option>
+      </select>
+      <select value={language} onChange={(e) => setLanguage(e.target.value)}>
+        <option value="en">English</option>
+        <option value="zh-cn">中文</option>
+      </select>
+      <input
+        type="number"
+        value={fontSize}
+        onChange={(e) => setFontSize(Number(e.target.value))}
+        min={12}
+        max={24}
+      />
+    </div>
+ ) +} +``` + +### Feature Toggle + +```typescript +function ChatMessage({ message }) { + const [showTimestamp] = usePreference('chat.display.show_timestamp') + + return ( +
+    <div>
+      <div>{message.content}</div>
+      {showTimestamp && <span>{message.createdAt}</span>}
+    </div>
+ ) +} +``` + +### Conditional Rendering Based on Settings + +```typescript +function App() { + const [theme] = usePreference('app.theme.mode') + const [sidebarPosition] = usePreference('app.sidebar.position') + + return ( +
+    <div className={theme}>
+      {/* Sidebar / MainContent are illustrative component names */}
+      {sidebarPosition === 'left' && <Sidebar />}
+      <MainContent />
+      {sidebarPosition === 'right' && <Sidebar />}
+    </div>
+  )
+}
+```
+
+### Batch Settings Update
+
+```typescript
+async function resetToDefaults() {
+  await preferenceService.setMultiple({
+    'app.theme.mode': 'system',
+    'app.language': 'en',
+    'chat.message.font_size': 14,
+    'chat.display.show_timestamp': true
+  })
+}
+```
+
+## Adding New Preference Keys
+
+### 1. Add to Preference Schema
+
+```typescript
+// packages/shared/data/preference/preferenceSchemas.ts
+export interface PreferenceSchema {
+  // Existing keys...
+  'myFeature.enabled': boolean
+  'myFeature.options': MyFeatureOptions
+}
+```
+
+### 2. Set Default Value
+
+```typescript
+// Same file or separate defaults file
+export const preferenceDefaults: Partial<PreferenceSchema> = {
+  // Existing defaults...
+  'myFeature.enabled': true,
+  'myFeature.options': { mode: 'auto', limit: 100 }
+}
+```
+
+### 3. Use in Code
+
+```typescript
+// Now type-safe with auto-completion
+const [enabled, setEnabled] = usePreference('myFeature.enabled')
+```
+
+## Best Practices
+
+1. **Choose update strategy wisely**: Optimistic for UX, pessimistic for critical settings
+2. **Batch related updates**: Use `setMultiple` when changing multiple related settings
+3. **Provide sensible defaults**: All preferences should have default values
+4. **Keep values atomic**: One preference = one logical setting
+5. **Use consistent naming**: Follow `domain.feature.setting` pattern
+
+## Preference vs Other Storage
+
+| Scenario | Use |
+|----------|-----|
+| User theme preference | `usePreference('app.theme.mode')` |
+| Window position | `usePersistCache` (can be lost without impact) |
+| API key | `usePreference` with pessimistic updates |
+| Search history | `usePersistCache` (nice to have) |
+| Conversation history | `DataApiService` (business data) |
diff --git a/docs/en/references/data/v2-migration-guide.md b/docs/en/references/data/v2-migration-guide.md
new file mode 100644
index 0000000000..8d08dd8d3a
--- /dev/null
+++ b/docs/en/references/data/v2-migration-guide.md
@@ -0,0 +1,72 @@
+# Migration V2 (Main Process)
+
+Architecture for the new one-shot migration from the legacy Dexie + Redux Persist stores into the SQLite schema. This module owns orchestration, data access helpers, migrator plugins, and IPC entry points used by the renderer migration window.
+
+## Directory Layout
+
+```
+src/main/data/migration/v2/
+├── core/       # Engine + shared context
+├── migrators/  # Domain-specific migrators and mappings
+├── utils/      # Data source readers (Redux, Dexie, streaming JSON)
+├── window/     # IPC handlers + migration window manager
+└── index.ts    # Public exports for main process
+```
+
+## Core Contracts
+
+- `core/MigrationEngine.ts` coordinates all migrators in order, surfaces progress to the UI, and marks status in `app_state.key = 'migration_v2_status'`. It will clear new-schema tables before running and abort on any validation failure.
+- `core/MigrationContext.ts` builds the shared context passed to every migrator:
+  - `sources`: `ConfigManager` (ElectronStore), `ReduxStateReader` (parsed Redux Persist data), `DexieFileReader` (JSON exports)
+  - `db`: current SQLite connection
+  - `sharedData`: `Map` for passing cross-cutting info between migrators
+  - `logger`: `loggerService` scoped to migration
+- `@shared/data/migration/v2/types` defines stages, results, and validation stats used across main and renderer.
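+
+A rough sketch of the context shape these contracts imply (the member names and helper types below are assumptions for illustration; the real definitions live under `src/main/data/migration/v2/core/`):
+
+```typescript
+// Hypothetical reader shapes, named after the utilities described in this guide
+type ConfigManager = { get(key: string): unknown }
+type ReduxStateReader = { read(dotPath: string): unknown }
+type DexieFileReader = { readTable(name: string): AsyncIterable<unknown> }
+
+interface MigrationContext {
+  sources: {
+    config: ConfigManager // ElectronStore-backed settings
+    redux: ReduxStateReader // parsed Redux Persist data
+    dexie: DexieFileReader // exported Dexie JSON tables
+  }
+  db: unknown // current SQLite connection (Drizzle)
+  sharedData: Map<string, unknown> // cross-cutting info shared between migrators
+  logger: { info(msg: string): void; error(msg: string): void } // loggerService scoped to migration
+}
+```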
+ +## Migrators + +- Base contract: extend `migrators/BaseMigrator.ts` and implement: + - `id`, `name`, `description`, `order` (lower runs first) + - `prepare(ctx)`: dry-run checks, counts, and staging data; return `PrepareResult` + - `execute(ctx)`: perform inserts/updates; manage your own transactions; report progress via `reportProgress` + - `validate(ctx)`: verify counts and integrity; return `ValidateResult` with stats (`sourceCount`, `targetCount`, `skippedCount`) and any `errors` +- Registration: list migrators (in order) in `migrators/index.ts` so the engine can sort and run them. +- Current migrators (see `migrators/README-.md` for detailed documentation): + - `PreferencesMigrator` (implemented): maps ElectronStore + Redux settings to the `preference` table using `mappings/PreferencesMappings.ts`. + - `ChatMigrator` (implemented): migrates topics and messages from Dexie to SQLite. See [`README-ChatMigrator.md`](../../../src/main/data/migration/v2/migrators/README-ChatMigrator.md). + - `AssistantMigrator`, `KnowledgeMigrator` (placeholders): scaffolding and TODO notes for future tables. +- Conventions: + - All logging goes through `loggerService` with a migrator-specific context. + - Use `MigrationContext.sources` instead of accessing raw files/stores directly. + - Use `sharedData` to pass IDs or lookup tables between migrators (e.g., assistant -> chat references) instead of re-reading sources. + - Stream large Dexie exports (`JSONStreamReader`) and batch inserts to avoid memory spikes. + - Count validation is mandatory; engine will fail the run if `targetCount < sourceCount - skippedCount` or if `ValidateResult.errors` is non-empty. + - Keep migrations idempotent per run—engine clears target tables before it starts, but each migrator should tolerate retries within the same run. + +## Utilities + +- `utils/ReduxStateReader.ts`: safe accessor for categorized Redux Persist data with dot-path lookup. +- `utils/DexieFileReader.ts`: reads exported Dexie JSON tables; can stream large tables. +- `utils/JSONStreamReader.ts`: streaming reader with batching, counting, and sampling helpers for very large arrays. + +## Window & IPC Integration + +- `window/MigrationIpcHandler.ts` exposes IPC channels for the migration UI: + - Receives Redux data and Dexie export path, starts the engine, and streams progress back to renderer. + - Manages backup flow (dialogs via `BackupManager`) and retry/cancel/restart actions. +- `window/MigrationWindowManager.ts` creates the frameless migration window, handles lifecycle, and relaunch instructions after completion in production. + +## Implementation Checklist for New Migrators + +- [ ] Add mapping definitions (if needed) under `migrators/mappings/`. +- [ ] Implement `prepare/execute/validate` with explicit counts, batch inserts, and integrity checks. +- [ ] Wire progress updates through `reportProgress` so UI shows per-migrator progress. +- [ ] Register the migrator in `migrators/index.ts` with the correct `order`. +- [ ] Add any new target tables to `MigrationEngine.verifyAndClearNewTables` once those tables exist. +- [ ] Include detailed comments for maintainability (file-level, function-level, logic blocks). 
+- [ ] **Create/update `migrators/README-.md`** with detailed documentation including: + - Data sources and target tables + - Key transformations + - Field mappings (source → target) + - Dropped fields and rationale + - Code quality notes diff --git a/docs/en/references/fuzzy-search.md b/docs/en/references/fuzzy-search.md new file mode 100644 index 0000000000..11c2002cb9 --- /dev/null +++ b/docs/en/references/fuzzy-search.md @@ -0,0 +1,129 @@ +# Fuzzy Search for File List + +This document describes the fuzzy search implementation for file listing in Cherry Studio. + +## Overview + +The fuzzy search feature allows users to find files by typing partial or approximate file names/paths. It uses a two-tier file filtering strategy (ripgrep glob pre-filtering with greedy substring fallback) combined with subsequence-based scoring for optimal performance and flexibility. + +## Features + +- **Ripgrep Glob Pre-filtering**: Primary filtering using glob patterns for fast native-level filtering +- **Greedy Substring Matching**: Fallback file filtering strategy when ripgrep glob pre-filtering returns no results +- **Subsequence-based Segment Scoring**: During scoring, path segments gain additional weight when query characters appear in order +- **Relevance Scoring**: Results are sorted by a relevance score derived from multiple factors + +## Matching Strategies + +### 1. Ripgrep Glob Pre-filtering (Primary) + +The query is converted to a glob pattern for ripgrep to do initial filtering: + +``` +Query: "updater" +Glob: "*u*p*d*a*t*e*r*" +``` + +This leverages ripgrep's native performance for the initial file filtering. + +### 2. Greedy Substring Matching (Fallback) + +When the glob pre-filter returns no results, the system falls back to greedy substring matching. This allows more flexible matching: + +``` +Query: "updatercontroller" +File: "packages/update/src/node/updateController.ts" + +Matching process: +1. Find "update" (longest match from start) +2. Remaining "rcontroller" → find "r" then "controller" +3. All parts matched → Success +``` + +## Scoring Algorithm + +Results are ranked by a relevance score based on named constants defined in `FileStorage.ts`: + +| Constant | Value | Description | +|----------|-------|-------------| +| `SCORE_FILENAME_STARTS` | 100 | Filename starts with query (highest priority) | +| `SCORE_FILENAME_CONTAINS` | 80 | Filename contains exact query substring | +| `SCORE_SEGMENT_MATCH` | 60 | Per path segment that matches query | +| `SCORE_WORD_BOUNDARY` | 20 | Query matches start of a word | +| `SCORE_CONSECUTIVE_CHAR` | 15 | Per consecutive character match | +| `PATH_LENGTH_PENALTY_FACTOR` | 4 | Logarithmic penalty for longer paths | + +### Scoring Strategy + +The scoring prioritizes: +1. **Filename matches** (highest): Files where the query appears in the filename are most relevant +2. **Path segment matches**: Multiple matching segments indicate stronger relevance +3. **Word boundaries**: Matching at word starts (e.g., "upd" matching "update") is preferred +4. **Consecutive matches**: Longer consecutive character sequences score higher +5. 
**Path length**: Shorter paths are preferred (logarithmic penalty prevents long paths from dominating) + +### Example Scoring + +For query `updater`: + +| File | Score Factors | +|------|---------------| +| `RCUpdater.js` | Short path + filename contains "updater" | +| `updateController.ts` | Multiple segment matches | +| `UpdaterHelper.plist` | Long path penalty | + +## Configuration + +### DirectoryListOptions + +```typescript +interface DirectoryListOptions { + recursive?: boolean // Default: true + maxDepth?: number // Default: 10 + includeHidden?: boolean // Default: false + includeFiles?: boolean // Default: true + includeDirectories?: boolean // Default: true + maxEntries?: number // Default: 20 + searchPattern?: string // Default: '.' + fuzzy?: boolean // Default: true +} +``` + +## Usage + +```typescript +// Basic fuzzy search +const files = await window.api.file.listDirectory(dirPath, { + searchPattern: 'updater', + fuzzy: true, + maxEntries: 20 +}) + +// Disable fuzzy search (exact glob matching) +const files = await window.api.file.listDirectory(dirPath, { + searchPattern: 'update', + fuzzy: false +}) +``` + +## Performance Considerations + +1. **Ripgrep Pre-filtering**: Most queries are handled by ripgrep's native glob matching, which is extremely fast +2. **Fallback Only When Needed**: Greedy substring matching (which loads all files) only runs when glob matching returns empty results +3. **Result Limiting**: Only top 20 results are returned by default +4. **Excluded Directories**: Common large directories are automatically excluded: + - `node_modules` + - `.git` + - `dist`, `build` + - `.next`, `.nuxt` + - `coverage`, `.cache` + +## Implementation Details + +The implementation is located in `src/main/services/FileStorage.ts`: + +- `queryToGlobPattern()`: Converts query to ripgrep glob pattern +- `isFuzzyMatch()`: Subsequence matching algorithm +- `isGreedySubstringMatch()`: Greedy substring matching fallback +- `getFuzzyMatchScore()`: Calculates relevance score +- `listDirectoryWithRipgrep()`: Main search orchestration diff --git a/docs/zh/README.md b/docs/zh/README.md index f8a1f1ab8c..c4adeb4901 100644 --- a/docs/zh/README.md +++ b/docs/zh/README.md @@ -34,7 +34,7 @@

- English | 中文 | 官方网站 | 文档 | 开发 | 反馈
+ English | 中文 | 官方网站 | 文档 | 开发 | 反馈

@@ -281,7 +281,7 @@ https://docs.cherry-ai.com # 📊 GitHub 统计 -![Stats](https://repobeats.axiom.co/api/embed/a693f2e5f773eed620f70031e974552156c7f397.svg 'Repobeats analytics image') +![Stats](https://repobeats.axiom.co/api/embed/a693f2e5f773eed620f70031e974552156c7f397.svg "Repobeats analytics image") # ⭐️ Star 记录 diff --git a/docs/zh/guides/development.md b/docs/zh/guides/development.md index fe67742768..032a515f61 100644 --- a/docs/zh/guides/development.md +++ b/docs/zh/guides/development.md @@ -36,7 +36,7 @@ yarn install ### ENV ```bash -copy .env.example .env +cp .env.example .env ``` ### Start diff --git a/docs/zh/guides/i18n.md b/docs/zh/guides/i18n.md index 82624d35c8..c8a8ccc66b 100644 --- a/docs/zh/guides/i18n.md +++ b/docs/zh/guides/i18n.md @@ -1,17 +1,17 @@ # 如何优雅地做好 i18n -## 使用i18n ally插件提升开发体验 +## 使用 i18n ally 插件提升开发体验 -i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反馈,帮助开发者更早发现文案缺失和错译问题。 +i18n ally 是一个强大的 VSCode 插件,它能在开发阶段提供实时反馈,帮助开发者更早发现文案缺失和错译问题。 项目中已经配置好了插件设置,直接安装即可。 ### 开发时优势 - **实时预览**:翻译文案会直接显示在编辑器中 -- **错误检测**:自动追踪标记出缺失的翻译或未使用的key -- **快速跳转**:可通过key直接跳转到定义处(Ctrl/Cmd + click) -- **自动补全**:输入i18n key时提供自动补全建议 +- **错误检测**:自动追踪标记出缺失的翻译或未使用的 key +- **快速跳转**:可通过 key 直接跳转到定义处(Ctrl/Cmd + click) +- **自动补全**:输入 i18n key 时提供自动补全建议 ### 效果展示 @@ -23,9 +23,9 @@ i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反 ## i18n 约定 -### **绝对避免使用flat格式** +### **绝对避免使用 flat 格式** -绝对避免使用flat格式,如`"add.button.tip": "添加"`。应采用清晰的嵌套结构: +绝对避免使用 flat 格式,如`"add.button.tip": "添加"`。应采用清晰的嵌套结构: ```json // 错误示例 - flat结构 @@ -52,14 +52,14 @@ i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反 #### 为什么要使用嵌套结构 1. **自然分组**:通过对象结构天然能将相关上下文的文案分到一个组别中 -2. **插件要求**:i18n ally 插件需要嵌套或flat格式其一的文件才能正常分析 +2. **插件要求**:i18n ally 插件需要嵌套或 flat 格式其一的文件才能正常分析 ### **避免在`t()`中使用模板字符串** -**强烈建议避免使用模板字符串**进行动态插值。虽然模板字符串在JavaScript开发中非常方便,但在国际化场景下会带来一系列问题。 +**强烈建议避免使用模板字符串**进行动态插值。虽然模板字符串在 JavaScript 开发中非常方便,但在国际化场景下会带来一系列问题。 1. **插件无法跟踪** - i18n ally等工具无法解析模板字符串中的动态内容,导致: + i18n ally 等工具无法解析模板字符串中的动态内容,导致: - 无法正确显示实时预览 - 无法检测翻译缺失 @@ -67,11 +67,11 @@ i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反 ```javascript // 不推荐 - 插件无法解析 - const message = t(`fruits.${fruit}`) + const message = t(`fruits.${fruit}`); ``` 2. **编辑器无法实时渲染** - 在IDE中,模板字符串会显示为原始代码而非最终翻译结果,降低了开发体验。 + 在 IDE 中,模板字符串会显示为原始代码而非最终翻译结果,降低了开发体验。 3. **更难以维护** 由于插件无法跟踪这样的文案,编辑器中也无法渲染,开发者必须人工确认语言文件中是否存在相应的文案。 @@ -85,36 +85,36 @@ i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反 ```ts // src/renderer/src/i18n/label.ts const themeModeKeyMap = { - dark: 'settings.theme.dark', - light: 'settings.theme.light', - system: 'settings.theme.system' -} as const + dark: "settings.theme.dark", + light: "settings.theme.light", + system: "settings.theme.system", +} as const; export const getThemeModeLabel = (key: string): string => { - return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key -} + return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key; +}; ``` 通过避免模板字符串,可以获得更好的开发体验、更可靠的翻译检查以及更易维护的代码库。 ## 自动化脚本 -项目中有一系列脚本来自动化i18n相关任务: +项目中有一系列脚本来自动化 i18n 相关任务: -### `check:i18n` - 检查i18n结构 +### `i18n:check` - 检查 i18n 结构 此脚本会检查: - 所有语言文件是否为嵌套结构 -- 是否存在缺失的key -- 是否存在多余的key +- 是否存在缺失的 key +- 是否存在多余的 key - 是否已经有序 ```bash -yarn check:i18n +yarn i18n:check ``` -### `sync:i18n` - 同步json结构与排序 +### `i18n:sync` - 同步 json 结构与排序 此脚本以`zh-cn.json`文件为基准,将结构同步到其他语言文件,包括: @@ -123,14 +123,14 @@ yarn check:i18n 3. 
自动排序 ```bash -yarn sync:i18n +yarn i18n:sync ``` -### `auto:i18n` - 自动翻译待翻译文本 +### `i18n:translate` - 自动翻译待翻译文本 次脚本自动将标记为待翻译的文本通过机器翻译填充。 -通常,在`zh-cn.json`中添加所需文案后,执行`sync:i18n`即可自动完成翻译。 +通常,在`zh-cn.json`中添加所需文案后,执行`i18n:sync`即可自动完成翻译。 使用该脚本前,需要配置环境变量,例如: @@ -143,29 +143,19 @@ MODEL="qwen-plus-latest" 你也可以通过直接编辑`.env`文件来添加环境变量。 ```bash -yarn auto:i18n -``` - -### `update:i18n` - 对象级别翻译更新 - -对`src/renderer/src/i18n/translate`中的语言文件进行对象级别的翻译更新,保留已有翻译,只更新新增内容。 - -**不建议**使用该脚本,更推荐使用`auto:i18n`进行翻译。 - -```bash -yarn update:i18n +yarn i18n:translate ``` ### 工作流 1. 开发阶段,先在`zh-cn.json`中添加所需文案 -2. 确认在中文环境下显示无误后,使用`yarn sync:i18n`将文案同步到其他语言文件 -3. 使用`yarn auto:i18n`进行自动翻译 +2. 确认在中文环境下显示无误后,使用`yarn i18n:sync`将文案同步到其他语言文件 +3. 使用`yarn i18n:translate`进行自动翻译 4. 喝杯咖啡,等翻译完成吧! ## 最佳实践 1. **以中文为源语言**:所有开发首先使用中文,再翻译为其他语言 -2. **提交前运行检查脚本**:使用`yarn check:i18n`检查i18n是否有问题 +2. **提交前运行检查脚本**:使用`yarn i18n:check`检查 i18n 是否有问题 3. **小步提交翻译**:避免积累大量未翻译文本 -4. **保持key语义明确**:key应能清晰表达其用途,如`user.profile.avatar.upload.error` +4. **保持 key 语义明确**:key 应能清晰表达其用途,如`user.profile.avatar.upload.error` diff --git a/docs/zh/references/fuzzy-search.md b/docs/zh/references/fuzzy-search.md new file mode 100644 index 0000000000..d28d189928 --- /dev/null +++ b/docs/zh/references/fuzzy-search.md @@ -0,0 +1,129 @@ +# 文件列表模糊搜索 + +本文档描述了 Cherry Studio 中文件列表的模糊搜索实现。 + +## 概述 + +模糊搜索功能允许用户通过输入部分或近似的文件名/路径来查找文件。它使用两层文件过滤策略(ripgrep glob 预过滤 + 贪婪子串匹配回退),结合基于子序列的评分,以获得最佳性能和灵活性。 + +## 功能特性 + +- **Ripgrep Glob 预过滤**:使用 glob 模式进行快速原生级过滤的主要过滤策略 +- **贪婪子串匹配**:当 ripgrep glob 预过滤无结果时的回退文件过滤策略 +- **基于子序列的段评分**:评分时,当查询字符按顺序出现时,路径段获得额外权重 +- **相关性评分**:结果按多因素相关性分数排序 + +## 匹配策略 + +### 1. Ripgrep Glob 预过滤(主要) + +查询被转换为 glob 模式供 ripgrep 进行初始过滤: + +``` +查询: "updater" +Glob: "*u*p*d*a*t*e*r*" +``` + +这利用了 ripgrep 的原生性能进行初始文件过滤。 + +### 2. 贪婪子串匹配(回退) + +当 glob 预过滤无结果时,系统回退到贪婪子串匹配。这允许更灵活的匹配: + +``` +查询: "updatercontroller" +文件: "packages/update/src/node/updateController.ts" + +匹配过程: +1. 找到 "update"(从开头的最长匹配) +2. 剩余 "rcontroller" → 找到 "r" 然后 "controller" +3. 所有部分都匹配 → 成功 +``` + +## 评分算法 + +结果根据 `FileStorage.ts` 中定义的命名常量进行相关性分数排名: + +| 常量 | 值 | 描述 | +|------|-----|------| +| `SCORE_FILENAME_STARTS` | 100 | 文件名以查询开头(最高优先级)| +| `SCORE_FILENAME_CONTAINS` | 80 | 文件名包含精确查询子串 | +| `SCORE_SEGMENT_MATCH` | 60 | 每个匹配查询的路径段 | +| `SCORE_WORD_BOUNDARY` | 20 | 查询匹配单词开头 | +| `SCORE_CONSECUTIVE_CHAR` | 15 | 每个连续字符匹配 | +| `PATH_LENGTH_PENALTY_FACTOR` | 4 | 较长路径的对数惩罚 | + +### 评分策略 + +评分优先级: +1. **文件名匹配**(最高):查询出现在文件名中的文件最相关 +2. **路径段匹配**:多个匹配段表示更强的相关性 +3. **词边界**:在单词开头匹配(如 "upd" 匹配 "update")更优先 +4. **连续匹配**:更长的连续字符序列得分更高 +5. **路径长度**:较短路径更优先(对数惩罚防止长路径主导评分) + +### 评分示例 + +对于查询 `updater`: + +| 文件 | 评分因素 | +|------|----------| +| `RCUpdater.js` | 短路径 + 文件名包含 "updater" | +| `updateController.ts` | 多个路径段匹配 | +| `UpdaterHelper.plist` | 长路径惩罚 | + +## 配置 + +### DirectoryListOptions + +```typescript +interface DirectoryListOptions { + recursive?: boolean // 默认: true + maxDepth?: number // 默认: 10 + includeHidden?: boolean // 默认: false + includeFiles?: boolean // 默认: true + includeDirectories?: boolean // 默认: true + maxEntries?: number // 默认: 20 + searchPattern?: string // 默认: '.' + fuzzy?: boolean // 默认: true +} +``` + +## 使用方法 + +```typescript +// 基本模糊搜索 +const files = await window.api.file.listDirectory(dirPath, { + searchPattern: 'updater', + fuzzy: true, + maxEntries: 20 +}) + +// 禁用模糊搜索(精确 glob 匹配) +const files = await window.api.file.listDirectory(dirPath, { + searchPattern: 'update', + fuzzy: false +}) +``` + +## 性能考虑 + +1. **Ripgrep 预过滤**:大多数查询由 ripgrep 的原生 glob 匹配处理,速度极快 +2. 
**仅在需要时回退**:贪婪子串匹配(加载所有文件)仅在 glob 匹配返回空结果时运行
+3. **结果限制**:默认只返回前 20 个结果
+4. **排除目录**:自动排除常见的大型目录:
+   - `node_modules`
+   - `.git`
+   - `dist`、`build`
+   - `.next`、`.nuxt`
+   - `coverage`、`.cache`
+
+## 实现细节
+
+实现位于 `src/main/services/FileStorage.ts`:
+
+- `queryToGlobPattern()`:将查询转换为 ripgrep glob 模式
+- `isFuzzyMatch()`:子序列匹配算法
+- `isGreedySubstringMatch()`:贪婪子串匹配回退
+- `getFuzzyMatchScore()`:计算相关性分数
+- `listDirectoryWithRipgrep()`:主搜索协调
diff --git a/docs/zh/references/lan-transfer-protocol.md b/docs/zh/references/lan-transfer-protocol.md
new file mode 100644
index 0000000000..a4c01a23c5
--- /dev/null
+++ b/docs/zh/references/lan-transfer-protocol.md
@@ -0,0 +1,850 @@
+# Cherry Studio 局域网传输协议规范
+
+> 版本: 1.0
+> 最后更新: 2025-12
+
+本文档定义了 Cherry Studio 桌面客户端(Electron)与移动端(Expo)之间的局域网文件传输协议。
+
+---
+
+## 目录
+
+1. [协议概述](#1-协议概述)
+2. [服务发现(Bonjour/mDNS)](#2-服务发现bonjourmdns)
+3. [TCP 连接与握手](#3-tcp-连接与握手)
+4. [消息格式规范](#4-消息格式规范)
+5. [文件传输协议](#5-文件传输协议)
+6. [心跳与连接保活](#6-心跳与连接保活)
+7. [错误处理](#7-错误处理)
+8. [常量与配置](#8-常量与配置)
+9. [完整时序图](#9-完整时序图)
+10. [移动端实现指南](#10-移动端实现指南)
+
+---
+
+## 1. 协议概述
+
+### 1.1 架构角色
+
+| 角色                 | 平台            | 职责                         |
+| -------------------- | --------------- | ---------------------------- |
+| **Client(客户端)** | Electron 桌面端 | 扫描服务、发起连接、发送文件 |
+| **Server(服务端)** | Expo 移动端     | 发布服务、接受连接、接收文件 |
+
+### 1.2 协议栈(v1)
+
+```
+┌─────────────────────────────────────┐
+│        应用层(文件传输)           │
+├─────────────────────────────────────┤
+│    消息层(控制: JSON \n)          │
+│          (数据: 二进制帧)          │
+├─────────────────────────────────────┤
+│          传输层(TCP)              │
+├─────────────────────────────────────┤
+│      发现层(Bonjour/mDNS)         │
+└─────────────────────────────────────┘
+```
+
+### 1.3 通信流程概览
+
+```
+1. 服务发现 → 移动端发布 mDNS 服务,桌面端扫描发现
+2. TCP 握手 → 建立连接,交换设备信息(`version=1`)
+3. 文件传输 → 控制消息使用 JSON,`file_chunk` 使用二进制帧分块传输
+4. 连接保活 → ping/pong 心跳
+```
+
+---
+
+## 2. 服务发现(Bonjour/mDNS)
+
+### 2.1 服务类型
+
+| 属性         | 值                   |
+| ------------ | -------------------- |
+| 服务类型     | `cherrystudio`       |
+| 协议         | `tcp`                |
+| 完整服务标识 | `_cherrystudio._tcp` |
+
+### 2.2 服务发布(移动端)
+
+移动端需要通过 mDNS/Bonjour 发布服务:
+
+```typescript
+// 服务发布参数
+{
+  name: "Cherry Studio Mobile", // 设备名称
+  type: "cherrystudio", // 服务类型
+  protocol: "tcp", // 协议
+  port: 53317, // TCP 监听端口
+  txt: { // TXT 记录(可选)
+    version: "1",
+    platform: "ios" // 或 "android"
+  }
+}
+```
+
+### 2.3 服务发现(桌面端)
+
+桌面端扫描并解析服务信息:
+
+```typescript
+// 发现的服务信息结构
+type LocalTransferPeer = {
+  id: string; // 唯一标识符
+  name: string; // 设备名称
+  host?: string; // 主机名
+  fqdn?: string; // 完全限定域名
+  port?: number; // TCP 端口
+  type?: string; // 服务类型
+  protocol?: "tcp" | "udp"; // 协议
+  addresses: string[]; // IP 地址列表
+  txt?: Record<string, string>; // TXT 记录
+  updatedAt: number; // 发现时间戳
+};
+```
+
+### 2.4 IP 地址选择策略
+
+当服务有多个 IP 地址时,优先选择 IPv4:
+
+```typescript
+// 优先选择 IPv4 地址
+const preferredAddress = addresses.find((addr) => isIPv4(addr)) || addresses[0];
+```
+
+---
+
+## 3. TCP 连接与握手
+
+### 3.1 连接建立
+
+1. 客户端使用发现的 `host:port` 建立 TCP 连接
+2. 连接成功后立即发送握手消息
+3. 等待服务端响应握手确认
+
+### 3.2 握手消息(协议版本 v1)
+
+#### Client → Server: `handshake`
+
+```typescript
+type LanTransferHandshakeMessage = {
+  type: "handshake";
+  deviceName: string; // 设备名称
+  version: string; // 协议版本,当前为 "1"
+  platform?: string; // 平台:'darwin' | 'win32' | 'linux'
+  appVersion?: string; // 应用版本
+};
+```
+
+**示例:**
+
+```json
+{
+  "type": "handshake",
+  "deviceName": "Cherry Studio 1.7.2",
+  "version": "1",
+  "platform": "darwin",
+  "appVersion": "1.7.2"
+}
+```
+
+## 4. 
消息格式规范(混合协议) + +v1 使用"控制 JSON + 二进制数据帧"的混合协议(流式传输模式,无 per-chunk ACK): + +- **控制消息**(握手、心跳、file_start/ack、file_end、file_complete):UTF-8 JSON,`\n` 分隔 +- **数据消息**(`file_chunk`):二进制帧,使用 Magic + 总长度做分帧,不经 Base64 + +### 4.1 控制消息编码(JSON + `\n`) + +| 属性 | 规范 | +| ---------- | ------------ | +| 编码格式 | UTF-8 | +| 序列化格式 | JSON | +| 消息分隔符 | `\n`(0x0A) | + +```typescript +function sendControlMessage(socket: Socket, message: object): void { + socket.write(`${JSON.stringify(message)}\n`); +} +``` + +### 4.2 `file_chunk` 二进制帧格式 + +为解决 TCP 分包/粘包并消除 Base64 开销,`file_chunk` 采用带总长度的二进制帧: + +``` +┌──────────┬──────────┬────────┬───────────────┬──────────────┬────────────┬───────────┐ +│ Magic │ TotalLen │ Type │ TransferId Len│ TransferId │ ChunkIdx │ Data │ +│ 0x43 0x53│ (4B BE) │ 0x01 │ (2B BE) │ (UTF-8) │ (4B BE) │ (raw) │ +└──────────┴──────────┴────────┴───────────────┴──────────────┴────────────┴───────────┘ +``` + +| 字段 | 大小 | 说明 | +| -------------- | ---- | ------------------------------------------- | +| Magic | 2B | 常量 `0x43 0x53` ("CS"), 用于区分 JSON 消息 | +| TotalLen | 4B | Big-endian,帧总长度(不含 Magic/TotalLen) | +| Type | 1B | `0x01` 代表 `file_chunk` | +| TransferId Len | 2B | Big-endian,transferId 字符串长度 | +| TransferId | nB | UTF-8 transferId(长度由上一字段给出) | +| ChunkIdx | 4B | Big-endian,块索引,从 0 开始 | +| Data | mB | 原始文件二进制数据(未编码) | + +> 计算帧总长度:`TotalLen = 1 + 2 + transferIdLen + 4 + dataLen`(即 Type~Data 的长度和)。 + +### 4.3 消息解析策略 + +1. 读取 socket 数据到缓冲区; +2. 若前两字节为 `0x43 0x53` → 按二进制帧解析: + - 至少需要 6 字节头(Magic + TotalLen),不足则等待更多数据 + - 读取 `TotalLen` 判断帧整体长度,缓冲区不足则继续等待 + - 解析 Type/TransferId/ChunkIdx/Data,并传入文件接收逻辑 +3. 否则若首字节为 `{` → 按 JSON + `\n` 解析控制消息 +4. 其它数据丢弃 1 字节并继续循环,避免阻塞。 + +### 4.4 消息类型汇总(v1) + +| 类型 | 方向 | 编码 | 用途 | +| ---------------- | --------------- | -------- | ----------------------- | +| `handshake` | Client → Server | JSON+\n | 握手请求(version=1) | +| `handshake_ack` | Server → Client | JSON+\n | 握手响应 | +| `ping` | Client → Server | JSON+\n | 心跳请求 | +| `pong` | Server → Client | JSON+\n | 心跳响应 | +| `file_start` | Client → Server | JSON+\n | 开始文件传输 | +| `file_start_ack` | Server → Client | JSON+\n | 文件传输确认 | +| `file_chunk` | Client → Server | 二进制帧 | 文件数据块(无 Base64,流式无 per-chunk ACK) | +| `file_end` | Client → Server | JSON+\n | 文件传输结束 | +| `file_complete` | Server → Client | JSON+\n | 传输完成结果 | + +``` +{"type":"message_type",...其他字段...}\n +``` + +--- + +## 5. 文件传输协议 + +### 5.1 传输流程 + +``` +Client (Sender) Server (Receiver) + | | + |──── 1. file_start ────────────────>| + | (文件元数据) | + | | + |<─── 2. file_start_ack ─────────────| + | (接受/拒绝) | + | | + |══════ 循环发送数据块(流式,无 ACK) ═════| + | | + |──── 3. file_chunk [0] ────────────>| + | | + |──── 3. file_chunk [1] ────────────>| + | | + | ... 重复直到所有块发送完成 ... | + | | + |══════════════════════════════════════ + | | + |──── 5. file_end ──────────────────>| + | (所有块已发送) | + | | + |<─── 6. 
file_complete ──────────────|
+  |               (最终结果)              |
+```
+
+### 5.2 消息定义
+
+#### 5.2.1 `file_start` - 开始传输
+
+**方向:** Client → Server
+
+```typescript
+type LanTransferFileStartMessage = {
+  type: "file_start";
+  transferId: string; // UUID,唯一传输标识
+  fileName: string; // 文件名(含扩展名)
+  fileSize: number; // 文件总字节数
+  mimeType: string; // MIME 类型
+  checksum: string; // 整个文件的 SHA-256 哈希(hex)
+  totalChunks: number; // 总数据块数
+  chunkSize: number; // 每块大小(字节)
+};
+```
+
+**示例:**
+
+```json
+{
+  "type": "file_start",
+  "transferId": "550e8400-e29b-41d4-a716-446655440000",
+  "fileName": "backup.zip",
+  "fileSize": 524288000,
+  "mimeType": "application/zip",
+  "checksum": "a1b2c3d4e5f6789012345678901234567890abcdef1234567890abcdef123456",
+  "totalChunks": 1000,
+  "chunkSize": 524288
+}
+```
+
+#### 5.2.2 `file_start_ack` - 传输确认
+
+**方向:** Server → Client
+
+```typescript
+type LanTransferFileStartAckMessage = {
+  type: "file_start_ack";
+  transferId: string; // 对应的传输 ID
+  accepted: boolean; // 是否接受传输
+  message?: string; // 拒绝原因
+};
+```
+
+**接受示例:**
+
+```json
+{
+  "type": "file_start_ack",
+  "transferId": "550e8400-e29b-41d4-a716-446655440000",
+  "accepted": true
+}
+```
+
+**拒绝示例:**
+
+```json
+{
+  "type": "file_start_ack",
+  "transferId": "550e8400-e29b-41d4-a716-446655440000",
+  "accepted": false,
+  "message": "Insufficient storage space"
+}
+```
+
+#### 5.2.3 `file_chunk` - 数据块
+
+**方向:** Client → Server(**二进制帧**,见 4.2)
+
+- 不再使用 JSON/`\n`,也不再使用 Base64
+- 帧结构:`Magic` + `TotalLen` + `Type` + `TransferId` + `ChunkIdx` + `Data`
+- `Type` 固定 `0x01`,`Data` 为原始文件二进制数据
+- 传输完整性依赖 `file_start.checksum`(全文件 SHA-256);分块校验和可选,不在帧中发送
+
+#### 5.2.4 `file_chunk_ack` - 数据块确认(v1 流式不使用)
+
+v1 采用流式传输,不发送 per-chunk ACK。本节类型仅保留作为向后兼容参考,实际不会发送。
+
+#### 5.2.5 `file_end` - 传输结束
+
+**方向:** Client → Server
+
+```typescript
+type LanTransferFileEndMessage = {
+  type: "file_end";
+  transferId: string; // 传输 ID
+};
+```
+
+**示例:**
+
+```json
+{
+  "type": "file_end",
+  "transferId": "550e8400-e29b-41d4-a716-446655440000"
+}
+```
+
+#### 5.2.6 `file_complete` - 传输完成
+
+**方向:** Server → Client
+
+```typescript
+type LanTransferFileCompleteMessage = {
+  type: "file_complete";
+  transferId: string; // 传输 ID
+  success: boolean; // 是否成功
+  filePath?: string; // 保存路径(成功时)
+  error?: string; // 错误信息(失败时)
+};
+```
+
+**成功示例:**
+
+```json
+{
+  "type": "file_complete",
+  "transferId": "550e8400-e29b-41d4-a716-446655440000",
+  "success": true,
+  "filePath": "/storage/emulated/0/Documents/backup.zip"
+}
+```
+
+**失败示例:**
+
+```json
+{
+  "type": "file_complete",
+  "transferId": "550e8400-e29b-41d4-a716-446655440000",
+  "success": false,
+  "error": "File checksum verification failed"
+}
+```
+
+### 5.3 校验和算法
+
+#### 整个文件校验和(保持不变)
+
+```typescript
+async function calculateFileChecksum(filePath: string): Promise<string> {
+  const hash = crypto.createHash("sha256");
+  const stream = fs.createReadStream(filePath);
+
+  for await (const chunk of stream) {
+    hash.update(chunk);
+  }
+
+  return hash.digest("hex");
+}
+```
+
+#### 数据块校验和
+
+v1 默认 **不传输分块校验和**,依赖最终文件 checksum。若需要,可在应用层自定义(非协议字段)。
+
+### 5.4 校验流程
+
+**发送端(Client):**
+
+1. 发送前计算整个文件的 SHA-256 → `file_start.checksum`
+2. 分块直接发送原始二进制(无 Base64)
+
+**接收端(Server):**
+
+1. 收到 `file_chunk` 后直接使用二进制数据
+2. 边收边落盘并增量计算 SHA-256(推荐)
+3. 所有块接收完成后,计算/完成增量哈希,得到最终 SHA-256
+4. 
与 `file_start.checksum` 比对,结果写入 `file_complete` + +### 5.5 数据块大小计算 + +```typescript +const CHUNK_SIZE = 512 * 1024; // 512KB + +const totalChunks = Math.ceil(fileSize / CHUNK_SIZE); + +// 最后一个块可能小于 CHUNK_SIZE +const lastChunkSize = fileSize % CHUNK_SIZE || CHUNK_SIZE; +``` + +--- + +## 6. 心跳与连接保活 + +### 6.1 心跳消息 + +#### `ping` + +**方向:** Client → Server + +```typescript +type LanTransferPingMessage = { + type: "ping"; + payload?: string; // 可选载荷 +}; +``` + +```json +{ + "type": "ping", + "payload": "heartbeat" +} +``` + +#### `pong` + +**方向:** Server → Client + +```typescript +type LanTransferPongMessage = { + type: "pong"; + received: boolean; // 确认收到 + payload?: string; // 回传 ping 的载荷 +}; +``` + +```json +{ + "type": "pong", + "received": true, + "payload": "heartbeat" +} +``` + +### 6.2 心跳策略 + +- 握手成功后立即发送一次 `ping` 验证连接 +- 可选:定期发送心跳保持连接活跃 +- `pong` 应返回 `ping` 中的 `payload`(可选) + +--- + +## 7. 错误处理 + +### 7.1 超时配置 + +| 操作 | 超时时间 | 说明 | +| ---------- | -------- | --------------------- | +| TCP 连接 | 10 秒 | 连接建立超时 | +| 握手等待 | 10 秒 | 等待 `handshake_ack` | +| 传输完成 | 60 秒 | 等待 `file_complete` | + +### 7.2 错误场景处理 + +| 场景 | Client 处理 | Server 处理 | +| --------------- | ------------------ | ---------------------- | +| TCP 连接失败 | 通知 UI,允许重试 | - | +| 握手超时 | 断开连接,通知 UI | 关闭 socket | +| 握手被拒绝 | 显示拒绝原因 | - | +| 数据块处理失败 | 中止传输,清理状态 | 清理临时文件 | +| 连接意外断开 | 清理状态,通知 UI | 清理临时文件 | +| 存储空间不足 | - | 发送 `accepted: false` | + +### 7.3 资源清理 + +**Client 端:** + +```typescript +function cleanup(): void { + // 1. 销毁文件读取流 + if (readStream) { + readStream.destroy(); + } + // 2. 清理传输状态 + activeTransfer = undefined; + // 3. 关闭 socket(如需要) + socket?.destroy(); +} +``` + +**Server 端:** + +```typescript +function cleanup(): void { + // 1. 关闭文件写入流 + if (writeStream) { + writeStream.end(); + } + // 2. 删除未完成的临时文件 + if (tempFilePath) { + fs.unlinkSync(tempFilePath); + } + // 3. 清理传输状态 + activeTransfer = undefined; +} +``` + +--- + +## 8. 常量与配置 + +### 8.1 协议常量 + +```typescript +// 协议版本(v1 = 控制 JSON + 二进制 chunk + 流式传输) +export const LAN_TRANSFER_PROTOCOL_VERSION = "1"; + +// 服务发现 +export const LAN_TRANSFER_SERVICE_TYPE = "cherrystudio"; +export const LAN_TRANSFER_SERVICE_FULL_NAME = "_cherrystudio._tcp"; + +// TCP 端口 +export const LAN_TRANSFER_TCP_PORT = 53317; + +// 文件传输(与二进制帧一致) +export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024; // 512KB +export const LAN_TRANSFER_GLOBAL_TIMEOUT_MS = 10 * 60 * 1000; // 10 分钟 + +// 超时设置 +export const LAN_TRANSFER_HANDSHAKE_TIMEOUT_MS = 10_000; // 10秒 +export const LAN_TRANSFER_CHUNK_TIMEOUT_MS = 30_000; // 30秒 +export const LAN_TRANSFER_COMPLETE_TIMEOUT_MS = 60_000; // 60秒 +``` + +### 8.2 支持的文件类型 + +当前仅支持 ZIP 文件: + +```typescript +export const LAN_TRANSFER_ALLOWED_EXTENSIONS = [".zip"]; +export const LAN_TRANSFER_ALLOWED_MIME_TYPES = [ + "application/zip", + "application/x-zip-compressed", +]; +``` + +--- + +## 9. 
完整时序图 + +### 9.1 完整传输流程(v1,流式传输) + +``` +┌─────────┐ ┌─────────┐ ┌─────────┐ +│ Renderer│ │ Main │ │ Mobile │ +│ (UI) │ │ Process │ │ Server │ +└────┬────┘ └────┬────┘ └────┬────┘ + │ │ │ + │ ════════════ 服务发现阶段 ════════════ │ + │ │ │ + │ startScan() │ │ + │────────────────────────────────────>│ │ + │ │ mDNS browse │ + │ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─>│ + │ │ │ + │ │<─ ─ ─ service discovered ─ ─ ─ ─ ─ ─│ + │ │ │ + │<────── onServicesUpdated ───────────│ │ + │ │ │ + │ ════════════ 握手连接阶段 ════════════ │ + │ │ │ + │ connect(peer) │ │ + │────────────────────────────────────>│ │ + │ │──────── TCP Connect ───────────────>│ + │ │ │ + │ │──────── handshake ─────────────────>│ + │ │ │ + │ │<─────── handshake_ack ──────────────│ + │ │ │ + │ │──────── ping ──────────────────────>│ + │ │<─────── pong ───────────────────────│ + │ │ │ + │<────── connect result ──────────────│ │ + │ │ │ + │ ════════════ 文件传输阶段 ════════════ │ + │ │ │ + │ sendFile(path) │ │ + │────────────────────────────────────>│ │ + │ │──────── file_start ────────────────>│ + │ │ │ + │ │<─────── file_start_ack ─────────────│ + │ │ │ + │ │ │ + │ │══════ 循环发送数据块 ═══════════════│ + │ │ │ + │ │──────── file_chunk[0] (binary) ────>│ + │<────── progress event ──────────────│ │ + │ │ │ + │ │──────── file_chunk[1] (binary) ────>│ + │<────── progress event ──────────────│ │ + │ │ │ + │ │ ... 重复 ... │ + │ │ │ + │ │══════════════════════════════════════│ + │ │ │ + │ │──────── file_end ──────────────────>│ + │ │ │ + │ │<─────── file_complete ──────────────│ + │ │ │ + │<────── complete event ──────────────│ │ + │<────── sendFile result ─────────────│ │ + │ │ │ +``` + +--- + +## 10. 移动端实现指南(v1 要点) + +### 10.1 必须实现的功能 + +1. **mDNS 服务发布** + + - 发布 `_cherrystudio._tcp` 服务 + - 提供 TCP 端口号 `53317` + - 可选:TXT 记录(版本、平台信息) + +2. **TCP 服务端** + + - 监听指定端口 + - 支持单连接或多连接 + +3. **消息解析** + + - 控制消息:UTF-8 + `\n` JSON + - 数据消息:二进制帧(Magic+TotalLen 分帧) + +4. **握手处理** + + - 验证 `handshake` 消息 + - 发送 `handshake_ack` 响应 + - 响应 `ping` 消息 + +5. **文件接收(流式模式)** + - 解析 `file_start`,准备接收 + - 接收 `file_chunk` 二进制帧,直接写入文件/缓冲并增量哈希 + - v1 不发送 per-chunk ACK(流式传输) + - 处理 `file_end`,完成增量哈希并校验 checksum + - 发送 `file_complete` 结果 + +### 10.2 推荐的库 + +**React Native / Expo:** + +- mDNS: `react-native-zeroconf` 或 `@homielab/react-native-bonjour` +- TCP: `react-native-tcp-socket` +- Crypto: `expo-crypto` 或 `react-native-quick-crypto` + +### 10.3 接收端伪代码 + +```typescript +class FileReceiver { + private transfer?: { + id: string; + fileName: string; + fileSize: number; + checksum: string; + totalChunks: number; + receivedChunks: number; + tempPath: string; + // v1: 边收边写文件,避免大文件 OOM + // stream: FileSystem writable stream (平台相关封装) + }; + + handleMessage(message: any) { + switch (message.type) { + case "handshake": + this.handleHandshake(message); + break; + case "ping": + this.sendPong(message); + break; + case "file_start": + this.handleFileStart(message); + break; + // v1: file_chunk 为二进制帧,不再走 JSON 分支 + case "file_end": + this.handleFileEnd(message); + break; + } + } + + handleFileStart(msg: LanTransferFileStartMessage) { + // 1. 检查存储空间 + // 2. 创建临时文件 + // 3. 初始化传输状态 + // 4. 发送 file_start_ack + } + + // v1: 二进制帧处理在 socket data 流中解析,随后调用 handleBinaryFileChunk + handleBinaryFileChunk(transferId: string, chunkIndex: number, data: Buffer) { + // 直接使用二进制数据,按 chunkSize/lastChunk 计算长度 + // 写入文件流并更新增量 SHA-256 + this.transfer.receivedChunks++; + // v1: 流式传输,不发送 per-chunk ACK + } + + handleFileEnd(msg: LanTransferFileEndMessage) { + // 1. 合并所有数据块 + // 2. 验证完整文件 checksum + // 3. 写入最终位置 + // 4. 
发送 file_complete + } +} +``` + +--- + +## 附录 A:TypeScript 类型定义 + +完整的类型定义位于 `packages/shared/config/types.ts`: + +```typescript +// 握手消息 +export interface LanTransferHandshakeMessage { + type: "handshake"; + deviceName: string; + version: string; + platform?: string; + appVersion?: string; +} + +export interface LanTransferHandshakeAckMessage { + type: "handshake_ack"; + accepted: boolean; + message?: string; +} + +// 心跳消息 +export interface LanTransferPingMessage { + type: "ping"; + payload?: string; +} + +export interface LanTransferPongMessage { + type: "pong"; + received: boolean; + payload?: string; +} + +// 文件传输消息 (Client -> Server) +export interface LanTransferFileStartMessage { + type: "file_start"; + transferId: string; + fileName: string; + fileSize: number; + mimeType: string; + checksum: string; + totalChunks: number; + chunkSize: number; +} + +export interface LanTransferFileChunkMessage { + type: "file_chunk"; + transferId: string; + chunkIndex: number; + data: string; // Base64 encoded (v1: 二进制帧模式下不使用) +} + +export interface LanTransferFileEndMessage { + type: "file_end"; + transferId: string; +} + +// 文件传输响应消息 (Server -> Client) +export interface LanTransferFileStartAckMessage { + type: "file_start_ack"; + transferId: string; + accepted: boolean; + message?: string; +} + +// v1 流式不发送 per-chunk ACK,以下类型仅用于向后兼容参考 +export interface LanTransferFileChunkAckMessage { + type: "file_chunk_ack"; + transferId: string; + chunkIndex: number; + received: boolean; + error?: string; +} + +export interface LanTransferFileCompleteMessage { + type: "file_complete"; + transferId: string; + success: boolean; + filePath?: string; + error?: string; +} + +// 常量 +export const LAN_TRANSFER_TCP_PORT = 53317; +export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024; +export const LAN_TRANSFER_CHUNK_TIMEOUT_MS = 30_000; +``` + +--- + +## 附录 B:版本历史 + +| 版本 | 日期 | 变更 | +| ---- | ------- | ---------------------------------------- | +| 1.0 | 2025-12 | 初始发布版本,支持二进制帧格式与流式传输 | diff --git a/electron-builder.yml b/electron-builder.yml index 7d1b76ade5..bb5a4f3954 100644 --- a/electron-builder.yml +++ b/electron-builder.yml @@ -135,60 +135,44 @@ artifactBuildCompleted: scripts/artifact-build-completed.js releaseInfo: releaseNotes: | - Cherry Studio 1.7.3 - Feature & Stability Update - - This release brings new features, UI improvements, and important bug fixes. 
+ Cherry Studio 1.7.9 - New Features & Bug Fixes ✨ New Features - - Add MCP server log viewer for better debugging - - Support custom Git Bash path configuration - - Add print to PDF and save as HTML for mini program webviews - - Add CherryIN API host selection settings - - Enhance assistant presets with sort and batch delete modes - - Open URL directly for SelectionAssistant search action - - Enhance web search tool switching with provider-specific context - - 🔧 Improvements - - Remove Intel Ultra limit for OVMS - - Improve settings tab and assistant item UI + - [Agent] Add 302.AI provider support + - [Browser] Browser data now persists and supports multiple tabs + - [Language] Add Romanian language support + - [Search] Add fuzzy search for file list + - [Models] Add latest Zhipu models + - [Image] Improve text-to-image functionality 🐛 Bug Fixes - - Fix stack overflow with base64 images - - Fix infinite loop in knowledge queue processing - - Fix quick panel closing in multiple selection mode - - Fix thinking timer not stopping when reply is aborted - - Fix ThinkingButton icon display for fixed reasoning mode - - Fix knowledge query prioritization and intent prompt - - Fix OpenRouter embeddings support - - Fix SelectionAction window resize on Windows - - Add gpustack provider support for qwen3 thinking mode + - [Mac] Fix mini window unexpected closing issue + - [Preview] Fix HTML preview controls not working in fullscreen + - [Translate] Fix translation duplicate execution issue + - [Zoom] Fix page zoom reset issue during navigation + - [Agent] Fix crash when switching between agent and assistant + - [Agent] Fix navigation in agent mode + - [Copy] Fix markdown copy button issue + - [Windows] Fix compatibility issues on non-Windows systems - Cherry Studio 1.7.3 - 功能与稳定性更新 - - 本次更新带来新功能、界面改进和重要的问题修复。 + Cherry Studio 1.7.9 - 新功能与问题修复 ✨ 新功能 - - 新增 MCP 服务器日志查看器,便于调试 - - 支持自定义 Git Bash 路径配置 - - 小程序 webview 支持打印 PDF 和保存为 HTML - - 新增 CherryIN API 主机选择设置 - - 助手预设增强:支持排序和批量删除模式 - - 划词助手搜索操作直接打开 URL - - 增强网页搜索工具切换逻辑,支持服务商特定上下文 - - 🔧 功能改进 - - 移除 OVMS 的 Intel Ultra 限制 - - 优化设置标签页和助手项目 UI + - [Agent] 新增 302.AI 服务商支持 + - [浏览器] 浏览器数据现在可以保存,支持多标签页 + - [语言] 新增罗马尼亚语支持 + - [搜索] 文件列表新增模糊搜索功能 + - [模型] 新增最新智谱模型 + - [图片] 优化文生图功能 🐛 问题修复 - - 修复 base64 图片导致的栈溢出问题 - - 修复知识库队列处理的无限循环问题 - - 修复多选模式下快捷面板意外关闭的问题 - - 修复回复中止时思考计时器未停止的问题 - - 修复固定推理模式下思考按钮图标显示问题 - - 修复知识库查询优先级和意图提示 - - 修复 OpenRouter 嵌入模型支持 - - 修复 Windows 上划词助手窗口大小调整问题 - - 为 gpustack 服务商添加 qwen3 思考模式支持 + - [Mac] 修复迷你窗口意外关闭的问题 + - [预览] 修复全屏模式下 HTML 预览控件无法使用的问题 + - [翻译] 修复翻译重复执行的问题 + - [缩放] 修复页面导航时缩放被重置的问题 + - [智能体] 修复在智能体和助手间切换时崩溃的问题 + - [智能体] 修复智能体模式下的导航问题 + - [复制] 修复 Markdown 复制按钮问题 + - [兼容性] 修复非 Windows 系统的兼容性问题 diff --git a/electron.vite.config.ts b/electron.vite.config.ts index bbb8e2ecf8..ccb020ce7d 100644 --- a/electron.vite.config.ts +++ b/electron.vite.config.ts @@ -1,6 +1,6 @@ import react from '@vitejs/plugin-react-swc' import { CodeInspectorPlugin } from 'code-inspector-plugin' -import { defineConfig, externalizeDepsPlugin } from 'electron-vite' +import { defineConfig } from 'electron-vite' import { resolve } from 'path' import { visualizer } from 'rollup-plugin-visualizer' @@ -17,7 +17,7 @@ const isProd = process.env.NODE_ENV === 'production' export default defineConfig({ main: { - plugins: [externalizeDepsPlugin(), ...visualizerPlugin('main')], + plugins: [...visualizerPlugin('main')], resolve: { alias: { '@main': resolve('src/main'), @@ -26,7 +26,8 @@ export default defineConfig({ '@shared': resolve('packages/shared'), '@logger': 
resolve('src/main/services/LoggerService'), '@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'), - '@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node') + '@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'), + '@test-mocks': resolve('tests/__mocks__') } }, build: { @@ -52,8 +53,7 @@ export default defineConfig({ plugins: [ react({ tsDecorators: true - }), - externalizeDepsPlugin() + }) ], resolve: { alias: { @@ -113,7 +113,8 @@ export default defineConfig({ '@cherrystudio/extension-table-plus': resolve('packages/extension-table-plus/src'), '@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src'), '@cherrystudio/ui/icons': resolve('packages/ui/src/components/icons'), - '@cherrystudio/ui': resolve('packages/ui/src') + '@cherrystudio/ui': resolve('packages/ui/src'), + '@test-mocks': resolve('tests/__mocks__') } }, optimizeDeps: { diff --git a/eslint.config.mjs b/eslint.config.mjs index 2ff7fd9f41..2e6a2cc311 100644 --- a/eslint.config.mjs +++ b/eslint.config.mjs @@ -61,6 +61,7 @@ export default defineConfig([ 'tests/**', '.yarn/**', '.gitignore', + '.conductor/**', 'scripts/cloudflare-worker.js', 'src/main/integration/nutstore/sso/lib/**', 'src/main/integration/cherryai/index.js', @@ -171,6 +172,11 @@ export default defineConfig([ } }, // Schema key naming convention (cache & preferences) + // Supports both fixed keys and template keys: + // - Fixed: 'app.user.avatar', 'chat.multi_select_mode' + // - Template: 'scroll.position.${topicId}', 'entity.cache.${type}_${id}' + // Template keys must follow the same dot-separated pattern as fixed keys. + // When ${xxx} placeholders are treated as literal strings, the key must match: xxx.yyy.zzz_www { files: ['packages/shared/data/cache/cacheSchemas.ts', 'packages/shared/data/preference/preferenceSchemas.ts'], plugins: { @@ -180,25 +186,80 @@ export default defineConfig([ meta: { type: 'problem', docs: { - description: 'Enforce schema key naming convention: namespace.sub.key_name', + description: + 'Enforce schema key naming convention: namespace.sub.key_name (template placeholders treated as literal strings)', recommended: true }, messages: { invalidKey: - 'Schema key "{{key}}" must follow format: namespace.sub.key_name (e.g., app.user.avatar).' + 'Schema key "{{key}}" must follow format: namespace.sub.key_name (e.g., app.user.avatar, scroll.position.${id}). Template ${xxx} is treated as a literal string segment.', + invalidTemplateVar: + 'Template variable in "{{key}}" must be a valid identifier (e.g., ${id}, ${topicId}).' } }, create(context) { - const VALID_KEY_PATTERN = /^[a-z][a-z0-9_]*(\.[a-z][a-z0-9_]*)+$/ + /** + * Validates a schema key for correct naming convention. + * + * Both fixed keys and template keys must follow the same pattern: + * - Lowercase segments separated by dots + * - Each segment: starts with letter, contains letters/numbers/underscores + * - At least two segments (must have at least one dot) + * + * Template keys: ${xxx} placeholders are treated as literal string segments. 
+ * Example valid: 'scroll.position.${id}', 'entity.cache.${type}_${id}' + * Example invalid: 'cache:${type}' (colon not allowed), '${id}' (no dot) + * + * @param {string} key - The schema key to validate + * @returns {{ valid: boolean, error?: 'invalidKey' | 'invalidTemplateVar' }} + */ + function validateKey(key) { + // Check if key contains template placeholders + const hasTemplate = key.includes('${') + + if (hasTemplate) { + // Validate template variable names first + const templateVarPattern = /\$\{([^}]*)\}/g + let match + while ((match = templateVarPattern.exec(key)) !== null) { + const varName = match[1] + // Variable must be a valid identifier: start with letter, contain only alphanumeric and underscore + if (!varName || !/^[a-zA-Z][a-zA-Z0-9_]*$/.test(varName)) { + return { valid: false, error: 'invalidTemplateVar' } + } + } + + // Replace template placeholders with a valid segment marker + // Use 'x' as placeholder since it's a valid segment character + const keyWithoutTemplates = key.replace(/\$\{[^}]+\}/g, 'x') + + // Template key must follow the same pattern as fixed keys + // when ${xxx} is treated as a literal string + const fixedKeyPattern = /^[a-z][a-z0-9_]*(\.[a-z][a-z0-9_]*)+$/ + if (!fixedKeyPattern.test(keyWithoutTemplates)) { + return { valid: false, error: 'invalidKey' } + } + + return { valid: true } + } else { + // Fixed key validation: standard dot-separated format + const fixedKeyPattern = /^[a-z][a-z0-9_]*(\.[a-z][a-z0-9_]*)+$/ + if (!fixedKeyPattern.test(key)) { + return { valid: false, error: 'invalidKey' } + } + return { valid: true } + } + } return { TSPropertySignature(node) { if (node.key.type === 'Literal' && typeof node.key.value === 'string') { const key = node.key.value - if (!VALID_KEY_PATTERN.test(key)) { + const result = validateKey(key) + if (!result.valid) { context.report({ node: node.key, - messageId: 'invalidKey', + messageId: result.error, data: { key } }) } @@ -207,10 +268,11 @@ export default defineConfig([ Property(node) { if (node.key.type === 'Literal' && typeof node.key.value === 'string') { const key = node.key.value - if (!VALID_KEY_PATTERN.test(key)) { + const result = validateKey(key) + if (!result.valid) { context.report({ node: node.key, - messageId: 'invalidKey', + messageId: result.error, data: { key } }) } diff --git a/migrations/README.md b/migrations/README.md index fc11adc188..5ade119ec8 100644 --- a/migrations/README.md +++ b/migrations/README.md @@ -1,6 +1,10 @@ **THIS DIRECTORY IS NOT FOR RUNTIME USE** +**v2 Data Refactoring Notice** +Before the official release of the alpha version, the database structure may change at any time. To maintain simplicity, the database migration files will be periodically reinitialized, which may cause the application to fail. If this occurs, please delete the `cherrystudio.sqlite` file located in the user data directory. + - Using `libsql` as the `sqlite3` driver, and `drizzle` as the ORM and database migration tool +- Table schemas are defined in `src\main\data\db\schemas` - `migrations/sqlite-drizzle` contains auto-generated migration data. Please **DO NOT** modify it. - If table structure changes, we should run migrations. 
-- To generate migrations, use the command `yarn run migrations:generate` +- To generate migrations, use the command `yarn run db:migrations:generate` diff --git a/migrations/sqlite-drizzle/0000_init.sql b/migrations/sqlite-drizzle/0000_init.sql new file mode 100644 index 0000000000..1b49b5e7ad --- /dev/null +++ b/migrations/sqlite-drizzle/0000_init.sql @@ -0,0 +1,145 @@ +CREATE TABLE `app_state` ( + `key` text PRIMARY KEY NOT NULL, + `value` text NOT NULL, + `description` text, + `created_at` integer, + `updated_at` integer +); +--> statement-breakpoint +CREATE TABLE `entity_tag` ( + `entity_type` text NOT NULL, + `entity_id` text NOT NULL, + `tag_id` text NOT NULL, + `created_at` integer, + `updated_at` integer, + PRIMARY KEY(`entity_type`, `entity_id`, `tag_id`), + FOREIGN KEY (`tag_id`) REFERENCES `tag`(`id`) ON UPDATE no action ON DELETE cascade +); +--> statement-breakpoint +CREATE INDEX `entity_tag_tag_id_idx` ON `entity_tag` (`tag_id`);--> statement-breakpoint +CREATE TABLE `group` ( + `id` text PRIMARY KEY NOT NULL, + `entity_type` text NOT NULL, + `name` text NOT NULL, + `sort_order` integer DEFAULT 0, + `created_at` integer, + `updated_at` integer +); +--> statement-breakpoint +CREATE INDEX `group_entity_sort_idx` ON `group` (`entity_type`,`sort_order`);--> statement-breakpoint +CREATE TABLE `message` ( + `id` text PRIMARY KEY NOT NULL, + `parent_id` text, + `topic_id` text NOT NULL, + `role` text NOT NULL, + `data` text NOT NULL, + `searchable_text` text, + `status` text NOT NULL, + `siblings_group_id` integer DEFAULT 0, + `assistant_id` text, + `assistant_meta` text, + `model_id` text, + `model_meta` text, + `trace_id` text, + `stats` text, + `created_at` integer, + `updated_at` integer, + `deleted_at` integer, + FOREIGN KEY (`topic_id`) REFERENCES `topic`(`id`) ON UPDATE no action ON DELETE cascade, + FOREIGN KEY (`parent_id`) REFERENCES `message`(`id`) ON UPDATE no action ON DELETE set null, + CONSTRAINT "message_role_check" CHECK("message"."role" IN ('user', 'assistant', 'system')), + CONSTRAINT "message_status_check" CHECK("message"."status" IN ('success', 'error', 'paused')) +); +--> statement-breakpoint +CREATE INDEX `message_parent_id_idx` ON `message` (`parent_id`);--> statement-breakpoint +CREATE INDEX `message_topic_created_idx` ON `message` (`topic_id`,`created_at`);--> statement-breakpoint +CREATE INDEX `message_trace_id_idx` ON `message` (`trace_id`);--> statement-breakpoint +CREATE TABLE `preference` ( + `scope` text DEFAULT 'default' NOT NULL, + `key` text NOT NULL, + `value` text, + `created_at` integer, + `updated_at` integer, + PRIMARY KEY(`scope`, `key`) +); +--> statement-breakpoint +CREATE TABLE `tag` ( + `id` text PRIMARY KEY NOT NULL, + `name` text NOT NULL, + `color` text, + `created_at` integer, + `updated_at` integer +); +--> statement-breakpoint +CREATE UNIQUE INDEX `tag_name_unique` ON `tag` (`name`);--> statement-breakpoint +CREATE TABLE `topic` ( + `id` text PRIMARY KEY NOT NULL, + `name` text, + `is_name_manually_edited` integer DEFAULT false, + `assistant_id` text, + `assistant_meta` text, + `prompt` text, + `active_node_id` text, + `group_id` text, + `sort_order` integer DEFAULT 0, + `is_pinned` integer DEFAULT false, + `pinned_order` integer DEFAULT 0, + `created_at` integer, + `updated_at` integer, + `deleted_at` integer, + FOREIGN KEY (`group_id`) REFERENCES `group`(`id`) ON UPDATE no action ON DELETE set null +); +--> statement-breakpoint +CREATE INDEX `topic_group_updated_idx` ON `topic` (`group_id`,`updated_at`);--> statement-breakpoint 
+CREATE INDEX `topic_group_sort_idx` ON `topic` (`group_id`,`sort_order`);--> statement-breakpoint +CREATE INDEX `topic_updated_at_idx` ON `topic` (`updated_at`);--> statement-breakpoint +CREATE INDEX `topic_is_pinned_idx` ON `topic` (`is_pinned`,`pinned_order`);--> statement-breakpoint +CREATE INDEX `topic_assistant_id_idx` ON `topic` (`assistant_id`); +--> statement-breakpoint +-- ============================================================ +-- FTS5 Virtual Table and Triggers for Message Full-Text Search +-- ============================================================ + +-- 1. Create FTS5 virtual table with external content +-- Links to message table's searchable_text column +CREATE VIRTUAL TABLE IF NOT EXISTS message_fts USING fts5( + searchable_text, + content='message', + content_rowid='rowid', + tokenize='trigram' +);--> statement-breakpoint + +-- 2. Trigger: populate searchable_text and sync FTS on INSERT +CREATE TRIGGER IF NOT EXISTS message_ai AFTER INSERT ON message BEGIN + -- Extract searchable text from data.blocks + UPDATE message SET searchable_text = ( + SELECT group_concat(json_extract(value, '$.content'), ' ') + FROM json_each(json_extract(NEW.data, '$.blocks')) + WHERE json_extract(value, '$.type') = 'main_text' + ) WHERE id = NEW.id; + -- Sync to FTS5 + INSERT INTO message_fts(rowid, searchable_text) + SELECT rowid, searchable_text FROM message WHERE id = NEW.id; +END;--> statement-breakpoint + +-- 3. Trigger: sync FTS on DELETE +CREATE TRIGGER IF NOT EXISTS message_ad AFTER DELETE ON message BEGIN + INSERT INTO message_fts(message_fts, rowid, searchable_text) + VALUES ('delete', OLD.rowid, OLD.searchable_text); +END;--> statement-breakpoint + +-- 4. Trigger: update searchable_text and sync FTS on UPDATE OF data +CREATE TRIGGER IF NOT EXISTS message_au AFTER UPDATE OF data ON message BEGIN + -- Remove old FTS entry + INSERT INTO message_fts(message_fts, rowid, searchable_text) + VALUES ('delete', OLD.rowid, OLD.searchable_text); + -- Update searchable_text + UPDATE message SET searchable_text = ( + SELECT group_concat(json_extract(value, '$.content'), ' ') + FROM json_each(json_extract(NEW.data, '$.blocks')) + WHERE json_extract(value, '$.type') = 'main_text' + ) WHERE id = NEW.id; + -- Add new FTS entry + INSERT INTO message_fts(rowid, searchable_text) + SELECT rowid, searchable_text FROM message WHERE id = NEW.id; +END; \ No newline at end of file diff --git a/migrations/sqlite-drizzle/0000_solid_lord_hawal.sql b/migrations/sqlite-drizzle/0000_solid_lord_hawal.sql deleted file mode 100644 index 9e52692966..0000000000 --- a/migrations/sqlite-drizzle/0000_solid_lord_hawal.sql +++ /dev/null @@ -1,17 +0,0 @@ -CREATE TABLE `app_state` ( - `key` text PRIMARY KEY NOT NULL, - `value` text NOT NULL, - `description` text, - `created_at` integer, - `updated_at` integer -); ---> statement-breakpoint -CREATE TABLE `preference` ( - `scope` text NOT NULL, - `key` text NOT NULL, - `value` text, - `created_at` integer, - `updated_at` integer -); ---> statement-breakpoint -CREATE INDEX `scope_name_idx` ON `preference` (`scope`,`key`); \ No newline at end of file diff --git a/migrations/sqlite-drizzle/0001_futuristic_human_fly.sql b/migrations/sqlite-drizzle/0001_futuristic_human_fly.sql new file mode 100644 index 0000000000..e4683658be --- /dev/null +++ b/migrations/sqlite-drizzle/0001_futuristic_human_fly.sql @@ -0,0 +1,32 @@ +PRAGMA foreign_keys=OFF;--> statement-breakpoint +CREATE TABLE `__new_message` ( + `id` text PRIMARY KEY NOT NULL, + `parent_id` text, + `topic_id` text NOT 
NULL, + `role` text NOT NULL, + `data` text NOT NULL, + `searchable_text` text, + `status` text NOT NULL, + `siblings_group_id` integer DEFAULT 0, + `assistant_id` text, + `assistant_meta` text, + `model_id` text, + `model_meta` text, + `trace_id` text, + `stats` text, + `created_at` integer, + `updated_at` integer, + `deleted_at` integer, + FOREIGN KEY (`topic_id`) REFERENCES `topic`(`id`) ON UPDATE no action ON DELETE cascade, + FOREIGN KEY (`parent_id`) REFERENCES `message`(`id`) ON UPDATE no action ON DELETE set null, + CONSTRAINT "message_role_check" CHECK("__new_message"."role" IN ('user', 'assistant', 'system')), + CONSTRAINT "message_status_check" CHECK("__new_message"."status" IN ('pending', 'success', 'error', 'paused')) +); +--> statement-breakpoint +INSERT INTO `__new_message`("id", "parent_id", "topic_id", "role", "data", "searchable_text", "status", "siblings_group_id", "assistant_id", "assistant_meta", "model_id", "model_meta", "trace_id", "stats", "created_at", "updated_at", "deleted_at") SELECT "id", "parent_id", "topic_id", "role", "data", "searchable_text", "status", "siblings_group_id", "assistant_id", "assistant_meta", "model_id", "model_meta", "trace_id", "stats", "created_at", "updated_at", "deleted_at" FROM `message`;--> statement-breakpoint +DROP TABLE `message`;--> statement-breakpoint +ALTER TABLE `__new_message` RENAME TO `message`;--> statement-breakpoint +PRAGMA foreign_keys=ON;--> statement-breakpoint +CREATE INDEX `message_parent_id_idx` ON `message` (`parent_id`);--> statement-breakpoint +CREATE INDEX `message_topic_created_idx` ON `message` (`topic_id`,`created_at`);--> statement-breakpoint +CREATE INDEX `message_trace_id_idx` ON `message` (`trace_id`); \ No newline at end of file diff --git a/migrations/sqlite-drizzle/meta/0000_snapshot.json b/migrations/sqlite-drizzle/meta/0000_snapshot.json index 51c5ed6cba..2fd34856f7 100644 --- a/migrations/sqlite-drizzle/meta/0000_snapshot.json +++ b/migrations/sqlite-drizzle/meta/0000_snapshot.json @@ -6,7 +6,7 @@ }, "dialect": "sqlite", "enums": {}, - "id": "de8009d7-95b9-4f99-99fa-4b8795708f21", + "id": "2ee6f7b2-99da-4de1-b895-48866855b7c6", "internal": { "indexes": {} }, @@ -57,6 +57,305 @@ "name": "app_state", "uniqueConstraints": {} }, + "entity_tag": { + "checkConstraints": {}, + "columns": { + "created_at": { + "autoincrement": false, + "name": "created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "entity_id": { + "autoincrement": false, + "name": "entity_id", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "entity_type": { + "autoincrement": false, + "name": "entity_type", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "tag_id": { + "autoincrement": false, + "name": "tag_id", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + } + }, + "compositePrimaryKeys": { + "entity_tag_entity_type_entity_id_tag_id_pk": { + "columns": ["entity_type", "entity_id", "tag_id"], + "name": "entity_tag_entity_type_entity_id_tag_id_pk" + } + }, + "foreignKeys": { + "entity_tag_tag_id_tag_id_fk": { + "columnsFrom": ["tag_id"], + "columnsTo": ["id"], + "name": "entity_tag_tag_id_tag_id_fk", + "onDelete": "cascade", + "onUpdate": "no action", + "tableFrom": "entity_tag", + "tableTo": "tag" + } + }, + "indexes": { + "entity_tag_tag_id_idx": { + "columns": ["tag_id"], + "isUnique": false, + "name": "entity_tag_tag_id_idx" + 
} + }, + "name": "entity_tag", + "uniqueConstraints": {} + }, + "group": { + "checkConstraints": {}, + "columns": { + "created_at": { + "autoincrement": false, + "name": "created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "entity_type": { + "autoincrement": false, + "name": "entity_type", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "id": { + "autoincrement": false, + "name": "id", + "notNull": true, + "primaryKey": true, + "type": "text" + }, + "name": { + "autoincrement": false, + "name": "name", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "sort_order": { + "autoincrement": false, + "default": 0, + "name": "sort_order", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + } + }, + "compositePrimaryKeys": {}, + "foreignKeys": {}, + "indexes": { + "group_entity_sort_idx": { + "columns": ["entity_type", "sort_order"], + "isUnique": false, + "name": "group_entity_sort_idx" + } + }, + "name": "group", + "uniqueConstraints": {} + }, + "message": { + "checkConstraints": { + "message_role_check": { + "name": "message_role_check", + "value": "\"message\".\"role\" IN ('user', 'assistant', 'system')" + }, + "message_status_check": { + "name": "message_status_check", + "value": "\"message\".\"status\" IN ('success', 'error', 'paused')" + } + }, + "columns": { + "assistant_id": { + "autoincrement": false, + "name": "assistant_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "assistant_meta": { + "autoincrement": false, + "name": "assistant_meta", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "created_at": { + "autoincrement": false, + "name": "created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "data": { + "autoincrement": false, + "name": "data", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "deleted_at": { + "autoincrement": false, + "name": "deleted_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "id": { + "autoincrement": false, + "name": "id", + "notNull": true, + "primaryKey": true, + "type": "text" + }, + "model_id": { + "autoincrement": false, + "name": "model_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "model_meta": { + "autoincrement": false, + "name": "model_meta", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "parent_id": { + "autoincrement": false, + "name": "parent_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "role": { + "autoincrement": false, + "name": "role", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "searchable_text": { + "autoincrement": false, + "name": "searchable_text", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "siblings_group_id": { + "autoincrement": false, + "default": 0, + "name": "siblings_group_id", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "stats": { + "autoincrement": false, + "name": "stats", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "status": { + "autoincrement": false, + "name": "status", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "topic_id": { + "autoincrement": false, + "name": "topic_id", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "trace_id": { + "autoincrement": false, + "name": "trace_id", + "notNull": 
false, + "primaryKey": false, + "type": "text" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + } + }, + "compositePrimaryKeys": {}, + "foreignKeys": { + "message_parent_id_message_id_fk": { + "columnsFrom": ["parent_id"], + "columnsTo": ["id"], + "name": "message_parent_id_message_id_fk", + "onDelete": "set null", + "onUpdate": "no action", + "tableFrom": "message", + "tableTo": "message" + }, + "message_topic_id_topic_id_fk": { + "columnsFrom": ["topic_id"], + "columnsTo": ["id"], + "name": "message_topic_id_topic_id_fk", + "onDelete": "cascade", + "onUpdate": "no action", + "tableFrom": "message", + "tableTo": "topic" + } + }, + "indexes": { + "message_parent_id_idx": { + "columns": ["parent_id"], + "isUnique": false, + "name": "message_parent_id_idx" + }, + "message_topic_created_idx": { + "columns": ["topic_id", "created_at"], + "isUnique": false, + "name": "message_topic_created_idx" + }, + "message_trace_id_idx": { + "columns": ["trace_id"], + "isUnique": false, + "name": "message_trace_id_idx" + } + }, + "name": "message", + "uniqueConstraints": {} + }, "preference": { "checkConstraints": {}, "columns": { @@ -76,6 +375,7 @@ }, "scope": { "autoincrement": false, + "default": "'default'", "name": "scope", "notNull": true, "primaryKey": false, @@ -96,16 +396,214 @@ "type": "text" } }, + "compositePrimaryKeys": { + "preference_scope_key_pk": { + "columns": ["scope", "key"], + "name": "preference_scope_key_pk" + } + }, + "foreignKeys": {}, + "indexes": {}, + "name": "preference", + "uniqueConstraints": {} + }, + "tag": { + "checkConstraints": {}, + "columns": { + "color": { + "autoincrement": false, + "name": "color", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "created_at": { + "autoincrement": false, + "name": "created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "id": { + "autoincrement": false, + "name": "id", + "notNull": true, + "primaryKey": true, + "type": "text" + }, + "name": { + "autoincrement": false, + "name": "name", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + } + }, "compositePrimaryKeys": {}, "foreignKeys": {}, "indexes": { - "scope_name_idx": { - "columns": ["scope", "key"], - "isUnique": false, - "name": "scope_name_idx" + "tag_name_unique": { + "columns": ["name"], + "isUnique": true, + "name": "tag_name_unique" } }, - "name": "preference", + "name": "tag", + "uniqueConstraints": {} + }, + "topic": { + "checkConstraints": {}, + "columns": { + "active_node_id": { + "autoincrement": false, + "name": "active_node_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "assistant_id": { + "autoincrement": false, + "name": "assistant_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "assistant_meta": { + "autoincrement": false, + "name": "assistant_meta", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "created_at": { + "autoincrement": false, + "name": "created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "deleted_at": { + "autoincrement": false, + "name": "deleted_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "group_id": { + "autoincrement": false, + "name": "group_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "id": { + "autoincrement": 
false, + "name": "id", + "notNull": true, + "primaryKey": true, + "type": "text" + }, + "is_name_manually_edited": { + "autoincrement": false, + "default": false, + "name": "is_name_manually_edited", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "is_pinned": { + "autoincrement": false, + "default": false, + "name": "is_pinned", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "name": { + "autoincrement": false, + "name": "name", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "pinned_order": { + "autoincrement": false, + "default": 0, + "name": "pinned_order", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "prompt": { + "autoincrement": false, + "name": "prompt", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "sort_order": { + "autoincrement": false, + "default": 0, + "name": "sort_order", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + } + }, + "compositePrimaryKeys": {}, + "foreignKeys": { + "topic_group_id_group_id_fk": { + "columnsFrom": ["group_id"], + "columnsTo": ["id"], + "name": "topic_group_id_group_id_fk", + "onDelete": "set null", + "onUpdate": "no action", + "tableFrom": "topic", + "tableTo": "group" + } + }, + "indexes": { + "topic_assistant_id_idx": { + "columns": ["assistant_id"], + "isUnique": false, + "name": "topic_assistant_id_idx" + }, + "topic_group_sort_idx": { + "columns": ["group_id", "sort_order"], + "isUnique": false, + "name": "topic_group_sort_idx" + }, + "topic_group_updated_idx": { + "columns": ["group_id", "updated_at"], + "isUnique": false, + "name": "topic_group_updated_idx" + }, + "topic_is_pinned_idx": { + "columns": ["is_pinned", "pinned_order"], + "isUnique": false, + "name": "topic_is_pinned_idx" + }, + "topic_updated_at_idx": { + "columns": ["updated_at"], + "isUnique": false, + "name": "topic_updated_at_idx" + } + }, + "name": "topic", "uniqueConstraints": {} } }, diff --git a/migrations/sqlite-drizzle/meta/0001_snapshot.json b/migrations/sqlite-drizzle/meta/0001_snapshot.json new file mode 100644 index 0000000000..7560d37a6c --- /dev/null +++ b/migrations/sqlite-drizzle/meta/0001_snapshot.json @@ -0,0 +1,612 @@ +{ + "_meta": { + "columns": {}, + "schemas": {}, + "tables": {} + }, + "dialect": "sqlite", + "enums": {}, + "id": "a433b120-0ab8-4f3f-9d1d-766b48c216c8", + "internal": { + "indexes": {} + }, + "prevId": "2ee6f7b2-99da-4de1-b895-48866855b7c6", + "tables": { + "app_state": { + "checkConstraints": {}, + "columns": { + "created_at": { + "autoincrement": false, + "name": "created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "description": { + "autoincrement": false, + "name": "description", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "key": { + "autoincrement": false, + "name": "key", + "notNull": true, + "primaryKey": true, + "type": "text" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "value": { + "autoincrement": false, + "name": "value", + "notNull": true, + "primaryKey": false, + "type": "text" + } + }, + "compositePrimaryKeys": {}, + "foreignKeys": {}, + "indexes": {}, + "name": "app_state", + "uniqueConstraints": {} + }, + "entity_tag": { + "checkConstraints": {}, + "columns": { + "created_at": { + "autoincrement": false, + "name": 
"created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "entity_id": { + "autoincrement": false, + "name": "entity_id", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "entity_type": { + "autoincrement": false, + "name": "entity_type", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "tag_id": { + "autoincrement": false, + "name": "tag_id", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + } + }, + "compositePrimaryKeys": { + "entity_tag_entity_type_entity_id_tag_id_pk": { + "columns": ["entity_type", "entity_id", "tag_id"], + "name": "entity_tag_entity_type_entity_id_tag_id_pk" + } + }, + "foreignKeys": { + "entity_tag_tag_id_tag_id_fk": { + "columnsFrom": ["tag_id"], + "columnsTo": ["id"], + "name": "entity_tag_tag_id_tag_id_fk", + "onDelete": "cascade", + "onUpdate": "no action", + "tableFrom": "entity_tag", + "tableTo": "tag" + } + }, + "indexes": { + "entity_tag_tag_id_idx": { + "columns": ["tag_id"], + "isUnique": false, + "name": "entity_tag_tag_id_idx" + } + }, + "name": "entity_tag", + "uniqueConstraints": {} + }, + "group": { + "checkConstraints": {}, + "columns": { + "created_at": { + "autoincrement": false, + "name": "created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "entity_type": { + "autoincrement": false, + "name": "entity_type", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "id": { + "autoincrement": false, + "name": "id", + "notNull": true, + "primaryKey": true, + "type": "text" + }, + "name": { + "autoincrement": false, + "name": "name", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "sort_order": { + "autoincrement": false, + "default": 0, + "name": "sort_order", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + } + }, + "compositePrimaryKeys": {}, + "foreignKeys": {}, + "indexes": { + "group_entity_sort_idx": { + "columns": ["entity_type", "sort_order"], + "isUnique": false, + "name": "group_entity_sort_idx" + } + }, + "name": "group", + "uniqueConstraints": {} + }, + "message": { + "checkConstraints": { + "message_role_check": { + "name": "message_role_check", + "value": "\"message\".\"role\" IN ('user', 'assistant', 'system')" + }, + "message_status_check": { + "name": "message_status_check", + "value": "\"message\".\"status\" IN ('pending', 'success', 'error', 'paused')" + } + }, + "columns": { + "assistant_id": { + "autoincrement": false, + "name": "assistant_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "assistant_meta": { + "autoincrement": false, + "name": "assistant_meta", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "created_at": { + "autoincrement": false, + "name": "created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "data": { + "autoincrement": false, + "name": "data", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "deleted_at": { + "autoincrement": false, + "name": "deleted_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "id": { + "autoincrement": false, + "name": "id", + "notNull": true, + "primaryKey": true, + "type": "text" + }, + "model_id": { + "autoincrement": false, + "name": "model_id", + "notNull": 
false, + "primaryKey": false, + "type": "text" + }, + "model_meta": { + "autoincrement": false, + "name": "model_meta", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "parent_id": { + "autoincrement": false, + "name": "parent_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "role": { + "autoincrement": false, + "name": "role", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "searchable_text": { + "autoincrement": false, + "name": "searchable_text", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "siblings_group_id": { + "autoincrement": false, + "default": 0, + "name": "siblings_group_id", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "stats": { + "autoincrement": false, + "name": "stats", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "status": { + "autoincrement": false, + "name": "status", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "topic_id": { + "autoincrement": false, + "name": "topic_id", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "trace_id": { + "autoincrement": false, + "name": "trace_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + } + }, + "compositePrimaryKeys": {}, + "foreignKeys": { + "message_parent_id_message_id_fk": { + "columnsFrom": ["parent_id"], + "columnsTo": ["id"], + "name": "message_parent_id_message_id_fk", + "onDelete": "set null", + "onUpdate": "no action", + "tableFrom": "message", + "tableTo": "message" + }, + "message_topic_id_topic_id_fk": { + "columnsFrom": ["topic_id"], + "columnsTo": ["id"], + "name": "message_topic_id_topic_id_fk", + "onDelete": "cascade", + "onUpdate": "no action", + "tableFrom": "message", + "tableTo": "topic" + } + }, + "indexes": { + "message_parent_id_idx": { + "columns": ["parent_id"], + "isUnique": false, + "name": "message_parent_id_idx" + }, + "message_topic_created_idx": { + "columns": ["topic_id", "created_at"], + "isUnique": false, + "name": "message_topic_created_idx" + }, + "message_trace_id_idx": { + "columns": ["trace_id"], + "isUnique": false, + "name": "message_trace_id_idx" + } + }, + "name": "message", + "uniqueConstraints": {} + }, + "preference": { + "checkConstraints": {}, + "columns": { + "created_at": { + "autoincrement": false, + "name": "created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "key": { + "autoincrement": false, + "name": "key", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "scope": { + "autoincrement": false, + "default": "'default'", + "name": "scope", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "value": { + "autoincrement": false, + "name": "value", + "notNull": false, + "primaryKey": false, + "type": "text" + } + }, + "compositePrimaryKeys": { + "preference_scope_key_pk": { + "columns": ["scope", "key"], + "name": "preference_scope_key_pk" + } + }, + "foreignKeys": {}, + "indexes": {}, + "name": "preference", + "uniqueConstraints": {} + }, + "tag": { + "checkConstraints": {}, + "columns": { + "color": { + "autoincrement": false, + "name": "color", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "created_at": { + "autoincrement": false, + "name": 
"created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "id": { + "autoincrement": false, + "name": "id", + "notNull": true, + "primaryKey": true, + "type": "text" + }, + "name": { + "autoincrement": false, + "name": "name", + "notNull": true, + "primaryKey": false, + "type": "text" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + } + }, + "compositePrimaryKeys": {}, + "foreignKeys": {}, + "indexes": { + "tag_name_unique": { + "columns": ["name"], + "isUnique": true, + "name": "tag_name_unique" + } + }, + "name": "tag", + "uniqueConstraints": {} + }, + "topic": { + "checkConstraints": {}, + "columns": { + "active_node_id": { + "autoincrement": false, + "name": "active_node_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "assistant_id": { + "autoincrement": false, + "name": "assistant_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "assistant_meta": { + "autoincrement": false, + "name": "assistant_meta", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "created_at": { + "autoincrement": false, + "name": "created_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "deleted_at": { + "autoincrement": false, + "name": "deleted_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "group_id": { + "autoincrement": false, + "name": "group_id", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "id": { + "autoincrement": false, + "name": "id", + "notNull": true, + "primaryKey": true, + "type": "text" + }, + "is_name_manually_edited": { + "autoincrement": false, + "default": false, + "name": "is_name_manually_edited", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "is_pinned": { + "autoincrement": false, + "default": false, + "name": "is_pinned", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "name": { + "autoincrement": false, + "name": "name", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "pinned_order": { + "autoincrement": false, + "default": 0, + "name": "pinned_order", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "prompt": { + "autoincrement": false, + "name": "prompt", + "notNull": false, + "primaryKey": false, + "type": "text" + }, + "sort_order": { + "autoincrement": false, + "default": 0, + "name": "sort_order", + "notNull": false, + "primaryKey": false, + "type": "integer" + }, + "updated_at": { + "autoincrement": false, + "name": "updated_at", + "notNull": false, + "primaryKey": false, + "type": "integer" + } + }, + "compositePrimaryKeys": {}, + "foreignKeys": { + "topic_group_id_group_id_fk": { + "columnsFrom": ["group_id"], + "columnsTo": ["id"], + "name": "topic_group_id_group_id_fk", + "onDelete": "set null", + "onUpdate": "no action", + "tableFrom": "topic", + "tableTo": "group" + } + }, + "indexes": { + "topic_assistant_id_idx": { + "columns": ["assistant_id"], + "isUnique": false, + "name": "topic_assistant_id_idx" + }, + "topic_group_sort_idx": { + "columns": ["group_id", "sort_order"], + "isUnique": false, + "name": "topic_group_sort_idx" + }, + "topic_group_updated_idx": { + "columns": ["group_id", "updated_at"], + "isUnique": false, + "name": "topic_group_updated_idx" + }, + "topic_is_pinned_idx": { + "columns": ["is_pinned", "pinned_order"], + "isUnique": false, + "name": "topic_is_pinned_idx" + }, + "topic_updated_at_idx": { + "columns": 
["updated_at"], + "isUnique": false, + "name": "topic_updated_at_idx" + } + }, + "name": "topic", + "uniqueConstraints": {} + } + }, + "version": "6", + "views": {} +} diff --git a/migrations/sqlite-drizzle/meta/_journal.json b/migrations/sqlite-drizzle/meta/_journal.json index db2791fd7f..c2bac3b325 100644 --- a/migrations/sqlite-drizzle/meta/_journal.json +++ b/migrations/sqlite-drizzle/meta/_journal.json @@ -4,9 +4,16 @@ { "breakpoints": true, "idx": 0, - "tag": "0000_solid_lord_hawal", + "tag": "0000_init", "version": "6", - "when": 1754745234572 + "when": 1767272575118 + }, + { + "breakpoints": true, + "idx": 1, + "tag": "0001_futuristic_human_fly", + "version": "6", + "when": 1767455592181 } ], "version": "7" diff --git a/package.json b/package.json index 281e72435b..6220968c5b 100644 --- a/package.json +++ b/package.json @@ -27,6 +27,7 @@ "scripts": { "start": "electron-vite preview", "dev": "dotenv electron-vite dev", + "dev:watch": "dotenv electron-vite dev -- -w", "debug": "electron-vite -- --inspect --sourcemap --remote-debugging-port=9222", "build": "npm run typecheck && electron-vite build", "build:check": "yarn lint && yarn test", @@ -53,10 +54,10 @@ "typecheck": "concurrently -n \"node,web\" -c \"cyan,magenta\" \"npm run typecheck:node\" \"npm run typecheck:web\"", "typecheck:node": "tsgo --noEmit -p tsconfig.node.json --composite false", "typecheck:web": "tsgo --noEmit -p tsconfig.web.json --composite false", - "check:i18n": "dotenv -e .env -- tsx scripts/check-i18n.ts", - "sync:i18n": "dotenv -e .env -- tsx scripts/sync-i18n.ts", - "update:i18n": "dotenv -e .env -- tsx scripts/update-i18n.ts", - "auto:i18n": "dotenv -e .env -- tsx scripts/auto-translate-i18n.ts", + "i18n:check": "dotenv -e .env -- tsx scripts/check-i18n.ts", + "i18n:sync": "dotenv -e .env -- tsx scripts/sync-i18n.ts", + "i18n:translate": "dotenv -e .env -- tsx scripts/auto-translate-i18n.ts", + "i18n:all": "yarn i18n:check && yarn i18n:sync && yarn i18n:translate", "update:languages": "tsx scripts/update-languages.ts", "update:upgrade-config": "tsx scripts/update-app-upgrade-config.ts", "test": "vitest run --silent", @@ -70,13 +71,12 @@ "test:e2e": "yarn playwright test", "test:lint": "oxlint --deny-warnings && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --cache", "test:scripts": "vitest scripts", - "lint": "oxlint --fix && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --cache && biome lint --write && biome format --write && yarn typecheck && yarn check:i18n && yarn format:check", - "lint:ox": "oxlint --fix && biome lint --write && biome format --write", + "lint": "oxlint --fix && eslint . 
--ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --cache && biome lint --write && biome format --write && yarn typecheck && yarn i18n:check && yarn format:check", "format": "biome format --write && biome lint --write", "format:check": "biome format && biome lint", "prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky", "claude": "dotenv -e .env -- claude", - "migrations:generate": "drizzle-kit generate --config ./migrations/sqlite-drizzle.config.ts", + "db:migrations:generate": "drizzle-kit generate --config ./migrations/sqlite-drizzle.config.ts", "release:aicore:alpha": "yarn workspace @cherrystudio/ai-core version prerelease --preid alpha --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --tag alpha --access public", "release:aicore:beta": "yarn workspace @cherrystudio/ai-core version prerelease --preid beta --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --tag beta --access public", "release:aicore": "yarn workspace @cherrystudio/ai-core version patch --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --access public", @@ -89,6 +89,7 @@ "@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch", "@paymoapp/electron-shutdown-handler": "^1.1.2", "@strongtz/win32-arm64-msvc": "^0.4.7", + "bonjour-service": "^1.3.0", "emoji-picker-element-data": "^1", "express": "^5.1.0", "font-list": "^2.0.0", @@ -99,10 +100,8 @@ "node-stream-zip": "^1.15.0", "officeparser": "^4.2.0", "os-proxy-config": "^1.1.2", - "qrcode.react": "^4.2.0", "selection-hook": "^1.0.12", "sharp": "^0.34.3", - "socket.io": "^4.8.1", "stream-json": "^1.9.1", "swagger-jsdoc": "^6.2.8", "swagger-ui-express": "^5.0.1", @@ -117,8 +116,8 @@ "@ai-sdk/anthropic": "^2.0.49", "@ai-sdk/cerebras": "^1.0.31", "@ai-sdk/gateway": "^2.0.15", - "@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.43#~/.yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch", - "@ai-sdk/google-vertex": "^3.0.79", + "@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.49#~/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch", + "@ai-sdk/google-vertex": "^3.0.94", "@ai-sdk/huggingface": "^0.0.10", "@ai-sdk/mistral": "^2.0.24", "@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.85#~/.yarn/patches/@ai-sdk-openai-npm-2.0.85-27483d1d6a.patch", @@ -279,7 +278,7 @@ "electron-reload": "^2.0.0-alpha.1", "electron-store": "^8.2.0", "electron-updater": "patch:electron-updater@npm%3A6.7.0#~/.yarn/patches/electron-updater-npm-6.7.0-47b11bb0d4.patch", - "electron-vite": "4.0.1", + "electron-vite": "5.0.0", "electron-window-state": "^5.0.3", "emittery": "^1.0.3", "emoji-picker-element": "^1.22.1", @@ -376,7 +375,7 @@ "undici": "6.21.2", "unified": "^11.0.5", "uuid": "^13.0.0", - "vite": "npm:rolldown-vite@7.1.5", + "vite": "npm:rolldown-vite@7.3.0", "vitest": "^3.2.4", "webdav": "^5.8.0", "winston": "^3.17.0", @@ -390,6 +389,8 @@ "zod": "^4.1.5" }, "resolutions": { + "react": "^19.2.0", + "react-dom": "^19.2.0", "@smithy/types": "4.7.1", "@codemirror/language": "6.11.3", "@codemirror/lint": "6.8.5", @@ -406,7 +407,7 @@ "pkce-challenge@npm:^4.1.0": "patch:pkce-challenge@npm%3A4.1.0#~/.yarn/patches/pkce-challenge-npm-4.1.0-fbc51695a3.patch", "tar-fs": "^2.1.4", "undici": "6.21.2", - "vite": "npm:rolldown-vite@7.1.5", + "vite": "npm:rolldown-vite@7.3.0", "tesseract.js@npm:*": 
"patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch", "@ai-sdk/openai@npm:^2.0.52": "patch:@ai-sdk/openai@npm%3A2.0.52#~/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch", "@img/sharp-darwin-arm64": "0.34.3", @@ -421,7 +422,10 @@ "@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch", "@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.85#~/.yarn/patches/@ai-sdk-openai-npm-2.0.85-27483d1d6a.patch", "@ai-sdk/google@npm:^2.0.40": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch", - "@ai-sdk/openai-compatible@npm:^1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch" + "@ai-sdk/openai-compatible@npm:^1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch", + "@ai-sdk/google@npm:2.0.49": "patch:@ai-sdk/google@npm%3A2.0.49#~/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch", + "@ai-sdk/openai-compatible@npm:1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch", + "@ai-sdk/openai-compatible@npm:^1.0.19": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch" }, "packageManager": "yarn@4.9.1", "lint-staged": { diff --git a/packages/ai-sdk-provider/package.json b/packages/ai-sdk-provider/package.json index 25864f3b1f..e635f93aeb 100644 --- a/packages/ai-sdk-provider/package.json +++ b/packages/ai-sdk-provider/package.json @@ -41,7 +41,7 @@ "ai": "^5.0.26" }, "dependencies": { - "@ai-sdk/openai-compatible": "^1.0.28", + "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch", "@ai-sdk/provider": "^2.0.0", "@ai-sdk/provider-utils": "^3.0.17" }, diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json index 6fc0f53344..e73a843b1d 100644 --- a/packages/aiCore/package.json +++ b/packages/aiCore/package.json @@ -42,7 +42,7 @@ "@ai-sdk/anthropic": "^2.0.49", "@ai-sdk/azure": "^2.0.87", "@ai-sdk/deepseek": "^1.0.31", - "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch", + "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch", "@ai-sdk/provider": "^2.0.0", "@ai-sdk/provider-utils": "^3.0.17", "@ai-sdk/xai": "^2.0.36", diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts index 22e8b5a605..224cee05ae 100644 --- a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts @@ -22,10 +22,10 @@ const TOOL_USE_TAG_CONFIG: TagConfig = { } /** - * 默认系统提示符模板(提取自 Cherry Studio) + * 默认系统提示符模板 */ -const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \\ -You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use. 
+export const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \
+You can use one or more tools per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
 
 ## Tool Use Formatting
 
@@ -74,10 +74,13 @@ Here are the rules you should always follow to solve your task:
 4. Never re-do a tool call that you previously did with the exact same parameters.
 5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format.
 
+## Response rules
+
+Respond in the language of the user's query, unless the user instructions specify additional requirements for the language to be used.
+
 # User Instructions
 
 {{ USER_SYSTEM_PROMPT }}
-
-Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.`
+`
 
 /**
  * Default tool-use examples (extracted from Cherry Studio)
diff --git a/packages/shared/IpcChannel.ts b/packages/shared/IpcChannel.ts
index 75a52534b7..c611cd3151 100644
--- a/packages/shared/IpcChannel.ts
+++ b/packages/shared/IpcChannel.ts
@@ -233,6 +233,8 @@ export enum IpcChannel {
   Backup_ListS3Files = 'backup:listS3Files',
   Backup_DeleteS3File = 'backup:deleteS3File',
   Backup_CheckS3Connection = 'backup:checkS3Connection',
+  Backup_CreateLanTransferBackup = 'backup:createLanTransferBackup',
+  Backup_DeleteTempBackup = 'backup:deleteTempBackup',
 
   // data migration
   DataMigrate_CheckNeeded = 'data-migrate:check-needed',
@@ -260,6 +262,7 @@ export enum IpcChannel {
   System_GetCpuName = 'system:getCpuName',
   System_CheckGitBash = 'system:checkGitBash',
   System_GetGitBashPath = 'system:getGitBashPath',
+  System_GetGitBashPathInfo = 'system:getGitBashPathInfo',
   System_SetGitBashPath = 'system:setGitBashPath',
 
   // DevTools
@@ -326,6 +329,7 @@ export enum IpcChannel {
   Memory_DeleteUser = 'memory:delete-user',
   Memory_DeleteAllMemoriesForUser = 'memory:delete-all-memories-for-user',
   Memory_GetUsersList = 'memory:get-users-list',
+  Memory_MigrateMemoryDb = 'memory:migrate-memory-db',
 
   // Data: Preference
   Preference_Get = 'preference:get',
@@ -339,11 +343,10 @@ export enum IpcChannel {
   // Data: Cache
   Cache_Sync = 'cache:sync',
   Cache_SyncBatch = 'cache:sync-batch',
+  Cache_GetAllShared = 'cache:get-all-shared',
 
   // Data: API Channels
   DataApi_Request = 'data-api:request',
-  DataApi_Batch = 'data-api:batch',
-  DataApi_Transaction = 'data-api:transaction',
   DataApi_Subscribe = 'data-api:subscribe',
   DataApi_Unsubscribe = 'data-api:unsubscribe',
   DataApi_Stream = 'data-api:stream',
@@ -392,6 +395,7 @@ export enum IpcChannel {
   OCR_ListProviders = 'ocr:list-providers',
 
   // OVMS
+  Ovms_IsSupported = 'ovms:is-supported',
   Ovms_AddModel = 'ovms:add-model',
   Ovms_StopAddModel = 'ovms:stop-addmodel',
   Ovms_GetModels = 'ovms:get-models',
@@ -412,10 +416,14 @@ export enum IpcChannel {
   ClaudeCodePlugin_ReadContent = 'claudeCodePlugin:read-content',
   ClaudeCodePlugin_WriteContent = 'claudeCodePlugin:write-content',
 
-  // WebSocket
-  WebSocket_Start = 'webSocket:start',
-  WebSocket_Stop = 'webSocket:stop',
-  WebSocket_Status = 'webSocket:status',
-  WebSocket_SendFile = 'webSocket:send-file',
-  WebSocket_GetAllCandidates = 'webSocket:get-all-candidates'
+  // Local Transfer
+  LocalTransfer_ListServices = 'local-transfer:list',
+  LocalTransfer_StartScan = 'local-transfer:start-scan',
+  LocalTransfer_StopScan = 'local-transfer:stop-scan',
+  LocalTransfer_ServicesUpdated = 'local-transfer:services-updated',
+  LocalTransfer_Connect = 'local-transfer:connect',
+  LocalTransfer_Disconnect = 'local-transfer:disconnect',
+  LocalTransfer_ClientEvent = 'local-transfer:client-event',
+  LocalTransfer_SendFile = 'local-transfer:send-file',
+  LocalTransfer_CancelTransfer = 'local-transfer:cancel-transfer'
 }
diff --git a/packages/shared/config/constant.ts b/packages/shared/config/constant.ts
index 235250adb2..379f80dc96 100644
--- a/packages/shared/config/constant.ts
+++ b/packages/shared/config/constant.ts
@@ -488,3 +488,11 @@ export const MACOS_TERMINALS_WITH_COMMANDS: TerminalConfigWithCommand[] = [
 // resources/scripts should be maintained manually
 export const HOME_CHERRY_DIR = '.cherrystudio'
+
+// Git Bash path configuration types
+export type GitBashPathSource = 'manual' | 'auto'
+
+export interface GitBashPathInfo {
+  path: string | null
+  source: GitBashPathSource | null
+}
diff --git a/packages/shared/config/types.ts b/packages/shared/config/types.ts
index 7dff53c753..56f746b0d5 100644
--- a/packages/shared/config/types.ts
+++ b/packages/shared/config/types.ts
@@ -52,3 +52,196 @@ export interface WebSocketCandidatesResponse {
   interface: string
   priority: number
 }
+
+export type LocalTransferPeer = {
+  id: string
+  name: string
+  host?: string
+  fqdn?: string
+  port?: number
+  type?: string
+  protocol?: 'tcp' | 'udp'
+  addresses: string[]
+  txt?: Record<string, string>
+  updatedAt: number
+}
+
+export type LocalTransferState = {
+  services: LocalTransferPeer[]
+  isScanning: boolean
+  lastScanStartedAt?: number
+  lastUpdatedAt: number
+  lastError?: string
+}
+
+export type LanHandshakeRequestMessage = {
+  type: 'handshake'
+  deviceName: string
+  version: string
+  platform?: string
+  appVersion?: string
+}
+
+export type LanHandshakeAckMessage = {
+  type: 'handshake_ack'
+  accepted: boolean
+  message?: string
+}
+
+export type LocalTransferConnectPayload = {
+  peerId: string
+  metadata?: Record<string, unknown>
+  timeoutMs?: number
+}
+
+export type LanClientEvent =
+  | {
+      type: 'ping_sent'
+      payload: string
+      timestamp: number
+      peerId?: string
+      peerName?: string
+    }
+  | {
+      type: 'pong'
+      payload?: string
+      received?: boolean
+      timestamp: number
+      peerId?: string
+      peerName?: string
+    }
+  | {
+      type: 'socket_closed'
+      reason?: string
+      timestamp: number
+      peerId?: string
+      peerName?: string
+    }
+  | {
+      type: 'error'
+      message: string
+      timestamp: number
+      peerId?: string
+      peerName?: string
+    }
+  | {
+      type: 'file_transfer_progress'
+      transferId: string
+      fileName: string
+      bytesSent: number
+      totalBytes: number
+      chunkIndex: number
+      totalChunks: number
+      progress: number // 0-100
+      speed: number // bytes/sec
+      timestamp: number
+      peerId?: string
+      peerName?: string
+    }
+  | {
+      type: 'file_transfer_complete'
+      transferId: string
+      fileName: string
+      success: boolean
+      filePath?: string
+      error?: string
+      timestamp: number
+      peerId?: string
+      peerName?: string
+    }
+
+// =============================================================================
+// LAN File Transfer Protocol Types
+// =============================================================================
+
+// Constants for file transfer
+export const LAN_TRANSFER_TCP_PORT = 53317
+export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024 // 512KB
+export const LAN_TRANSFER_MAX_FILE_SIZE = 500 * 1024 * 1024 // 500MB
+export const LAN_TRANSFER_COMPLETE_TIMEOUT_MS = 60_000 // 60s - wait for file_complete after file_end
+export const LAN_TRANSFER_GLOBAL_TIMEOUT_MS = 10 * 60 * 1000 // 10 minutes - global transfer timeout
+
+// Binary protocol constants (v1)
+export const
LAN_TRANSFER_PROTOCOL_VERSION = '1' +export const LAN_BINARY_FRAME_MAGIC = 0x4353 // "CS" as uint16 +export const LAN_BINARY_TYPE_FILE_CHUNK = 0x01 + +// Messages from Electron (Client/Sender) to Mobile (Server/Receiver) + +/** Request to start file transfer */ +export type LanFileStartMessage = { + type: 'file_start' + transferId: string + fileName: string + fileSize: number + mimeType: string // 'application/zip' + checksum: string // SHA-256 of entire file + totalChunks: number + chunkSize: number +} + +/** + * File chunk data (JSON format) + * @deprecated Use binary frame format in protocol v1. This type is kept for reference only. + */ +export type LanFileChunkMessage = { + type: 'file_chunk' + transferId: string + chunkIndex: number + data: string // Base64 encoded + chunkChecksum: string // SHA-256 of this chunk +} + +/** Notification that all chunks have been sent */ +export type LanFileEndMessage = { + type: 'file_end' + transferId: string +} + +/** Request to cancel file transfer */ +export type LanFileCancelMessage = { + type: 'file_cancel' + transferId: string + reason?: string +} + +// Messages from Mobile (Server/Receiver) to Electron (Client/Sender) + +/** Acknowledgment of file transfer request */ +export type LanFileStartAckMessage = { + type: 'file_start_ack' + transferId: string + accepted: boolean + message?: string // Rejection reason +} + +/** + * Acknowledgment of file chunk received + * @deprecated Protocol v1 uses streaming mode without per-chunk acknowledgment. + * This type is kept for backward compatibility reference only. + */ +export type LanFileChunkAckMessage = { + type: 'file_chunk_ack' + transferId: string + chunkIndex: number + received: boolean + message?: string +} + +/** Final result of file transfer */ +export type LanFileCompleteMessage = { + type: 'file_complete' + transferId: string + success: boolean + filePath?: string // Path where file was saved on mobile + error?: string + // Enhanced error diagnostics + errorCode?: 'CHECKSUM_MISMATCH' | 'INCOMPLETE_TRANSFER' | 'DISK_ERROR' | 'CANCELLED' + receivedChunks?: number + receivedBytes?: number +} + +/** Payload for sending a file via IPC */ +export type LanFileSendPayload = { + filePath: string +} diff --git a/packages/shared/data/README.md b/packages/shared/data/README.md index b65af18e33..30d30ff54d 100644 --- a/packages/shared/data/README.md +++ b/packages/shared/data/README.md @@ -1,106 +1,50 @@ -# Cherry Studio Shared Data +# Shared Data Types -This directory contains shared type definitions and schemas for the Cherry Studio data management systems. These files provide type safety and consistency across the entire application. +This directory contains shared type definitions for Cherry Studio's data layer. 
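Read together, the LAN transfer types above imply a sender-side sequence: `file_start`, wait for `file_start_ack`, stream binary chunk frames (protocol v1 has no per-chunk ack), send `file_end`, then wait up to `LAN_TRANSFER_COMPLETE_TIMEOUT_MS` for `file_complete`. A sketch of that flow follows; the `conn` transport helpers are hypothetical, and only the message shapes come from the types above:

```typescript
import { createHash, randomUUID } from 'node:crypto'
import { readFile } from 'node:fs/promises'

// Hypothetical transport wrapper around a TCP socket on LAN_TRANSFER_TCP_PORT.
type Conn = {
  sendJson: (msg: object) => void
  sendBinary: (frame: Buffer) => void
  waitFor: <T>(type: string, timeoutMs: number) => Promise<T>
}

async function sendBackupFile(conn: Conn, filePath: string, fileName: string) {
  const data = await readFile(filePath)
  const chunkSize = 512 * 1024 // LAN_TRANSFER_CHUNK_SIZE
  const totalChunks = Math.ceil(data.length / chunkSize)
  const transferId = randomUUID()

  // 1. Announce the transfer (LanFileStartMessage shape).
  conn.sendJson({
    type: 'file_start',
    transferId,
    fileName,
    fileSize: data.length,
    mimeType: 'application/zip',
    checksum: createHash('sha256').update(data).digest('hex'),
    totalChunks,
    chunkSize
  })
  const ack = await conn.waitFor<{ accepted: boolean; message?: string }>('file_start_ack', 10_000)
  if (!ack.accepted) throw new Error(ack.message ?? 'transfer rejected')

  // 2. Stream chunks as binary frames; v1 is streaming mode, no per-chunk ack.
  //    (The exact frame layout beyond the magic/type bytes is not shown here.)
  for (let i = 0; i < totalChunks; i++) {
    conn.sendBinary(data.subarray(i * chunkSize, (i + 1) * chunkSize))
  }

  // 3. Signal the end and wait for the receiver's verdict (checksum verified there).
  conn.sendJson({ type: 'file_end', transferId })
  return conn.waitFor<{ success: boolean; error?: string }>('file_complete', 60_000)
}
```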
-## 📁 Directory Structure +## Documentation + +For comprehensive documentation, see: +- **Overview**: [docs/en/references/data/README.md](../../../docs/en/references/data/README.md) +- **Cache Types**: [cache-overview.md](../../../docs/en/references/data/cache-overview.md) +- **Preference Types**: [preference-overview.md](../../../docs/en/references/data/preference-overview.md) +- **API Types**: [api-types.md](../../../docs/en/references/data/api-types.md) + +## Directory Structure ``` packages/shared/data/ ├── api/ # Data API type system -│ ├── index.ts # Barrel exports for clean imports -│ ├── apiSchemas.ts # API endpoint definitions and mappings -│ ├── apiTypes.ts # Core request/response infrastructure types -│ ├── apiModels.ts # Business entity types and DTOs -│ ├── apiPaths.ts # API path definitions and utilities -│ └── errorCodes.ts # Standardized error handling +│ ├── index.ts # Barrel exports +│ ├── apiTypes.ts # Core request/response types +│ ├── apiPaths.ts # Path template utilities +│ ├── apiErrors.ts # Error handling +│ └── schemas/ # Domain-specific API schemas ├── cache/ # Cache system type definitions -│ ├── cacheTypes.ts # Core cache infrastructure types -│ ├── cacheSchemas.ts # Cache key schemas and type mappings -│ └── cacheValueTypes.ts # Cache value type definitions +│ ├── cacheTypes.ts # Core cache types +│ ├── cacheSchemas.ts # Cache key schemas +│ └── cacheValueTypes.ts # Cache value types ├── preference/ # Preference system type definitions -│ ├── preferenceTypes.ts # Core preference system types -│ └── preferenceSchemas.ts # Preference schemas and default values -└── README.md # This file +│ ├── preferenceTypes.ts # Core preference types +│ └── preferenceSchemas.ts # Preference schemas +└── types/ # Shared data types ``` -## 🏗️ System Overview +## Quick Reference -This directory provides type definitions for three main data management systems: +### Import Conventions -### API System (`api/`) -- **Purpose**: Type-safe IPC communication between Main and Renderer processes -- **Features**: RESTful patterns, error handling, business entity definitions -- **Usage**: Ensures type safety for all data API operations - -### Cache System (`cache/`) -- **Purpose**: Type definitions for three-layer caching architecture -- **Features**: Memory/shared/persist cache schemas, TTL support, hook integration -- **Usage**: Type-safe caching operations across the application - -### Preference System (`preference/`) -- **Purpose**: User configuration and settings management -- **Features**: 158 configuration items, default values, nested key support -- **Usage**: Type-safe preference access and synchronization - -## 📋 File Categories - -**Framework Infrastructure** - These are TypeScript type definitions that: -- ✅ Exist only at compile time -- ✅ Provide type safety and IntelliSense support -- ✅ Define contracts between application layers -- ✅ Enable static analysis and error detection - -## 📖 Usage Examples - -### API Types ```typescript -// Import API types -import type { DataRequest, DataResponse, ApiSchemas } from '@shared/data/api' -``` +// API infrastructure types (from barrel) +import type { DataRequest, DataResponse, ApiClient } from '@shared/data/api' +import { ErrorCode, DataApiError, DataApiErrorFactory } from '@shared/data/api' -### Cache Types -```typescript -// Import cache types +// Domain DTOs (from schema files) +import type { Topic, CreateTopicDto } from '@shared/data/api/schemas/topic' + +// Cache types import type { UseCacheKey, UseSharedCacheKey } from 
'@shared/data/cache' + +// Preference types +import type { PreferenceKeyType } from '@shared/data/preference' ``` - -### Preference Types -```typescript -// Import preference types -import type { PreferenceKeyType, PreferenceDefaultScopeType } from '@shared/data/preference' -``` - -## 🔧 Development Guidelines - -### Adding Cache Types -1. Add cache key to `cache/cacheSchemas.ts` -2. Define value type in `cache/cacheValueTypes.ts` -3. Update type mappings for type safety - -### Adding Preference Types -1. Add preference key to `preference/preferenceSchemas.ts` -2. Define default value and type -3. Preference system automatically picks up new keys - -### Adding API Types -1. Define business entities in `api/apiModels.ts` -2. Add endpoint definitions to `api/apiSchemas.ts` -3. Export types from `api/index.ts` - -### Best Practices -- Use `import type` for type-only imports -- Follow existing naming conventions -- Document complex types with JSDoc -- Maintain type safety across all imports - -## 🔗 Related Implementation - -### Main Process Services -- `src/main/data/CacheService.ts` - Main process cache management -- `src/main/data/PreferenceService.ts` - Preference management service -- `src/main/data/DataApiService.ts` - Data API coordination service - -### Renderer Process Services -- `src/renderer/src/data/CacheService.ts` - Renderer cache service -- `src/renderer/src/data/PreferenceService.ts` - Renderer preference service -- `src/renderer/src/data/DataApiService.ts` - Renderer API client \ No newline at end of file diff --git a/packages/shared/data/api/README.md b/packages/shared/data/api/README.md new file mode 100644 index 0000000000..eb06824d87 --- /dev/null +++ b/packages/shared/data/api/README.md @@ -0,0 +1,42 @@ +# Data API Type System + +This directory contains type definitions for the DataApi system. + +## Documentation + +- **DataApi Overview**: [docs/en/references/data/data-api-overview.md](../../../../docs/en/references/data/data-api-overview.md) +- **API Types**: [api-types.md](../../../../docs/en/references/data/api-types.md) +- **API Design Guidelines**: [api-design-guidelines.md](../../../../docs/en/references/data/api-design-guidelines.md) + +## Directory Structure + +``` +packages/shared/data/api/ +├── index.ts # Barrel exports +├── apiTypes.ts # Core request/response types +├── apiPaths.ts # Path template utilities +├── apiErrors.ts # Error handling +└── schemas/ + ├── index.ts # Schema composition + └── *.ts # Domain-specific schemas +``` + +## Quick Reference + +### Import Conventions + +```typescript +// Infrastructure types (via barrel) +import type { DataRequest, DataResponse, ApiClient } from '@shared/data/api' +import { ErrorCode, DataApiError, DataApiErrorFactory } from '@shared/data/api' + +// Domain DTOs (directly from schema files) +import type { Topic, CreateTopicDto } from '@shared/data/api/schemas/topic' +import type { Message, CreateMessageDto } from '@shared/data/api/schemas/message' +``` + +### Adding New Schemas + +1. Create schema file in `schemas/` (e.g., `topic.ts`) +2. Register in `schemas/index.ts` using intersection type +3. 
Implement handlers in `src/main/data/api/handlers/` diff --git a/packages/shared/data/api/apiErrors.ts b/packages/shared/data/api/apiErrors.ts new file mode 100644 index 0000000000..4819ac32e2 --- /dev/null +++ b/packages/shared/data/api/apiErrors.ts @@ -0,0 +1,817 @@ +/** + * @fileoverview Centralized error handling for the Data API system + * + * This module provides comprehensive error management including: + * - ErrorCode enum with HTTP status mapping + * - Type-safe error details for each error type + * - DataApiError class for structured error handling + * - DataApiErrorFactory for convenient error creation + * - Retryability configuration for automatic retry logic + * + * @example + * ```typescript + * import { DataApiError, DataApiErrorFactory, ErrorCode } from '@shared/data/api' + * + * // Create and throw an error + * throw DataApiErrorFactory.notFound('Topic', 'abc123') + * + * // Check if error is retryable + * if (error instanceof DataApiError && error.isRetryable) { + * await retry(operation) + * } + * ``` + */ + +import type { HttpMethod } from './apiTypes' + +// ============================================================================ +// Error Code Enum +// ============================================================================ + +/** + * Standard error codes for the Data API system. + * Maps to HTTP status codes via ERROR_STATUS_MAP. + */ +export enum ErrorCode { + // ───────────────────────────────────────────────────────────────── + // Client errors (4xx) - Issues with the request itself + // ───────────────────────────────────────────────────────────────── + + /** 400 - Malformed request syntax or invalid parameters */ + BAD_REQUEST = 'BAD_REQUEST', + + /** 401 - Authentication required or credentials invalid */ + UNAUTHORIZED = 'UNAUTHORIZED', + + /** 404 - Requested resource does not exist */ + NOT_FOUND = 'NOT_FOUND', + + /** 405 - HTTP method not supported for this endpoint */ + METHOD_NOT_ALLOWED = 'METHOD_NOT_ALLOWED', + + /** 422 - Request body fails validation rules */ + VALIDATION_ERROR = 'VALIDATION_ERROR', + + /** 429 - Too many requests, retry after delay */ + RATE_LIMIT_EXCEEDED = 'RATE_LIMIT_EXCEEDED', + + /** 403 - Authenticated but lacks required permissions */ + PERMISSION_DENIED = 'PERMISSION_DENIED', + + /** + * 400 - Operation is not valid in current state. + * Use when: deleting root message without cascade, moving node would create cycle, + * or any operation that violates business rules but isn't a validation error. + */ + INVALID_OPERATION = 'INVALID_OPERATION', + + // ───────────────────────────────────────────────────────────────── + // Server errors (5xx) - Issues on the server side + // ───────────────────────────────────────────────────────────────── + + /** 500 - Unexpected server error */ + INTERNAL_SERVER_ERROR = 'INTERNAL_SERVER_ERROR', + + /** 500 - Database operation failed (connection, query, constraint) */ + DATABASE_ERROR = 'DATABASE_ERROR', + + /** 503 - Service temporarily unavailable, retry later */ + SERVICE_UNAVAILABLE = 'SERVICE_UNAVAILABLE', + + /** 504 - Request timed out waiting for response */ + TIMEOUT = 'TIMEOUT', + + // ───────────────────────────────────────────────────────────────── + // Application-specific errors + // ───────────────────────────────────────────────────────────────── + + /** 500 - Data migration process failed */ + MIGRATION_ERROR = 'MIGRATION_ERROR', + + /** + * 423 - Resource is temporarily locked by another operation. 
+   * Use when: file being exported, data migration in progress,
+   * or resource held by background process.
+   * Retryable: Yes (resource may be released)
+   */
+  RESOURCE_LOCKED = 'RESOURCE_LOCKED',
+
+  /**
+   * 409 - Optimistic lock conflict, resource was modified after read.
+   * Use when: multi-window editing same topic, version mismatch
+   * on update, or stale data detected during save.
+   * Client should: refresh data and retry or notify user.
+   */
+  CONCURRENT_MODIFICATION = 'CONCURRENT_MODIFICATION',
+
+  /**
+   * 409 - Data integrity violation or inconsistent state detected.
+   * Use when: referential integrity broken, computed values mismatch,
+   * or data corruption found during validation.
+   * Not retryable: requires investigation or data repair.
+   */
+  DATA_INCONSISTENT = 'DATA_INCONSISTENT'
+}
+
+// ============================================================================
+// Error Code Mappings
+// ============================================================================
+
+/**
+ * Maps error codes to HTTP status codes.
+ * Used by DataApiError and DataApiErrorFactory.
+ */
+export const ERROR_STATUS_MAP: Record<ErrorCode, number> = {
+  // Client errors (4xx)
+  [ErrorCode.BAD_REQUEST]: 400,
+  [ErrorCode.UNAUTHORIZED]: 401,
+  [ErrorCode.NOT_FOUND]: 404,
+  [ErrorCode.METHOD_NOT_ALLOWED]: 405,
+  [ErrorCode.VALIDATION_ERROR]: 422,
+  [ErrorCode.RATE_LIMIT_EXCEEDED]: 429,
+  [ErrorCode.PERMISSION_DENIED]: 403,
+  [ErrorCode.INVALID_OPERATION]: 400,
+
+  // Server errors (5xx)
+  [ErrorCode.INTERNAL_SERVER_ERROR]: 500,
+  [ErrorCode.DATABASE_ERROR]: 500,
+  [ErrorCode.SERVICE_UNAVAILABLE]: 503,
+  [ErrorCode.TIMEOUT]: 504,
+
+  // Application-specific errors
+  [ErrorCode.MIGRATION_ERROR]: 500,
+  [ErrorCode.RESOURCE_LOCKED]: 423,
+  [ErrorCode.CONCURRENT_MODIFICATION]: 409,
+  [ErrorCode.DATA_INCONSISTENT]: 409
+}
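Taken together with the message table that follows, this map lets any layer derive a consistent HTTP status and default text from a code alone. A minimal sketch (the `buildError` helper is hypothetical, not part of this module):

```typescript
import { ERROR_MESSAGES, ERROR_STATUS_MAP, ErrorCode } from '@shared/data/api'

// Hypothetical helper: derive a consistent error payload from a code alone
function buildError(code: ErrorCode, message?: string) {
  return {
    code,
    status: ERROR_STATUS_MAP[code], // e.g. RESOURCE_LOCKED -> 423
    message: message ?? ERROR_MESSAGES[code] // fall back to the default text
  }
}

// buildError(ErrorCode.NOT_FOUND)
// -> { code: 'NOT_FOUND', status: 404, message: 'Not found: Requested resource does not exist' }
```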
+
+/**
+ * Default error messages for each error code.
+ * Used when no custom message is provided.
+ */
+export const ERROR_MESSAGES: Record<ErrorCode, string> = {
+  [ErrorCode.BAD_REQUEST]: 'Bad request: Invalid request format or parameters',
+  [ErrorCode.UNAUTHORIZED]: 'Unauthorized: Authentication required',
+  [ErrorCode.NOT_FOUND]: 'Not found: Requested resource does not exist',
+  [ErrorCode.METHOD_NOT_ALLOWED]: 'Method not allowed: HTTP method not supported for this endpoint',
+  [ErrorCode.VALIDATION_ERROR]: 'Validation error: Request data does not meet requirements',
+  [ErrorCode.RATE_LIMIT_EXCEEDED]: 'Rate limit exceeded: Too many requests',
+  [ErrorCode.PERMISSION_DENIED]: 'Permission denied: Insufficient permissions for this operation',
+  [ErrorCode.INVALID_OPERATION]: 'Invalid operation: Operation not allowed in current state',
+
+  [ErrorCode.INTERNAL_SERVER_ERROR]: 'Internal server error: An unexpected error occurred',
+  [ErrorCode.DATABASE_ERROR]: 'Database error: Failed to access or modify data',
+  [ErrorCode.SERVICE_UNAVAILABLE]: 'Service unavailable: The service is temporarily unavailable',
+  [ErrorCode.TIMEOUT]: 'Timeout: Request timed out waiting for response',
+
+  [ErrorCode.MIGRATION_ERROR]: 'Migration error: Failed to migrate data',
+  [ErrorCode.RESOURCE_LOCKED]: 'Resource locked: Resource is currently locked by another operation',
+  [ErrorCode.CONCURRENT_MODIFICATION]: 'Concurrent modification: Resource was modified by another user',
+  [ErrorCode.DATA_INCONSISTENT]: 'Data inconsistent: Data integrity violation detected'
+}
+
+// ============================================================================
+// Request Context
+// ============================================================================
+
+/**
+ * Request context attached to errors for debugging and logging.
+ * Always transmitted via IPC for frontend display.
+ */
+export interface RequestContext {
+  /** Unique identifier for request correlation */
+  requestId: string
+  /** API path that was called */
+  path: string
+  /** HTTP method used */
+  method: HttpMethod
+  /** Timestamp when request was initiated */
+  timestamp?: number
+}
+
+// ============================================================================
+// Error-specific Detail Types
+// ============================================================================
+
+/**
+ * Details for VALIDATION_ERROR - field-level validation failures.
+ * Maps field names to arrays of error messages.
+ */
+export interface ValidationErrorDetails {
+  fieldErrors: Record<string, string[]>
+}
+
+/**
+ * Details for NOT_FOUND - which resource was not found.
+ */
+export interface NotFoundErrorDetails {
+  resource: string
+  id?: string
+}
+
+/**
+ * Details for DATABASE_ERROR - underlying database failure info.
+ */
+export interface DatabaseErrorDetails {
+  originalError: string
+  operation?: string
+}
+
+/**
+ * Details for TIMEOUT - what operation timed out.
+ */
+export interface TimeoutErrorDetails {
+  operation?: string
+  timeoutMs?: number
+}
+
+/**
+ * Details for DATA_INCONSISTENT - what data is inconsistent.
+ */
+export interface DataInconsistentErrorDetails {
+  resource: string
+  description?: string
+}
+
+/**
+ * Details for PERMISSION_DENIED - what action was denied.
+ */
+export interface PermissionDeniedErrorDetails {
+  action: string
+  resource?: string
+}
+
+/**
+ * Details for INVALID_OPERATION - what operation was invalid.
+ */
+export interface InvalidOperationErrorDetails {
+  operation: string
+  reason?: string
+}
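The `fieldErrors` shape above is the one later consumed by `DataApiErrorFactory.validation()`. A brief sketch with hypothetical field names:

```typescript
import type { ValidationErrorDetails } from '@shared/data/api'

// Hypothetical field names; each field maps to one or more messages
const details: ValidationErrorDetails = {
  fieldErrors: {
    title: ['Title is required'],
    tags: ['Too many tags', 'Duplicate tag']
  }
}

// Flatten for display in a UI layer
const lines = Object.entries(details.fieldErrors).flatMap(([field, errors]) =>
  errors.map((message) => `${field}: ${message}`)
)
```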
+
+/**
+ * Details for RESOURCE_LOCKED - which resource is locked.
+ */
+export interface ResourceLockedErrorDetails {
+  resource: string
+  id: string
+  lockedBy?: string
+}
+
+/**
+ * Details for CONCURRENT_MODIFICATION - which resource was concurrently modified.
+ */
+export interface ConcurrentModificationErrorDetails {
+  resource: string
+  id: string
+}
+
+/**
+ * Details for INTERNAL_SERVER_ERROR - context about the failure.
+ */
+export interface InternalErrorDetails {
+  originalError: string
+  context?: string
+}
+
+// ============================================================================
+// Type Mapping for Error Details
+// ============================================================================
+
+/**
+ * Maps error codes to their specific detail types.
+ * Only define for error codes that have structured details.
+ */
+export type ErrorDetailsMap = {
+  [ErrorCode.VALIDATION_ERROR]: ValidationErrorDetails
+  [ErrorCode.NOT_FOUND]: NotFoundErrorDetails
+  [ErrorCode.DATABASE_ERROR]: DatabaseErrorDetails
+  [ErrorCode.TIMEOUT]: TimeoutErrorDetails
+  [ErrorCode.DATA_INCONSISTENT]: DataInconsistentErrorDetails
+  [ErrorCode.PERMISSION_DENIED]: PermissionDeniedErrorDetails
+  [ErrorCode.INVALID_OPERATION]: InvalidOperationErrorDetails
+  [ErrorCode.RESOURCE_LOCKED]: ResourceLockedErrorDetails
+  [ErrorCode.CONCURRENT_MODIFICATION]: ConcurrentModificationErrorDetails
+  [ErrorCode.INTERNAL_SERVER_ERROR]: InternalErrorDetails
+}
+
+/**
+ * Get the detail type for a specific error code.
+ * Falls back to a generic Record<string, unknown> for unmapped codes.
+ */
+export type DetailsForCode<T extends ErrorCode> = T extends keyof ErrorDetailsMap
+  ? ErrorDetailsMap[T]
+  : Record<string, unknown> | undefined
+
+// ============================================================================
+// Retryability Configuration
+// ============================================================================
+
+/**
+ * Set of error codes that are safe to retry automatically.
+ * These represent temporary failures that may succeed on retry.
+ */
+export const RETRYABLE_ERROR_CODES: ReadonlySet<ErrorCode> = new Set([
+  ErrorCode.SERVICE_UNAVAILABLE, // 503 - Service temporarily down
+  ErrorCode.TIMEOUT, // 504 - Request timed out
+  ErrorCode.RATE_LIMIT_EXCEEDED, // 429 - Can retry after delay
+  ErrorCode.DATABASE_ERROR, // 500 - Temporary DB issues
+  ErrorCode.INTERNAL_SERVER_ERROR, // 500 - May be transient
+  ErrorCode.RESOURCE_LOCKED // 423 - Lock may be released
+])
+
+/**
+ * Check if an error code represents a retryable condition.
+ * @param code - The error code to check
+ * @returns true if the error is safe to retry
+ */
+export function isRetryableErrorCode(code: ErrorCode): boolean {
+  return RETRYABLE_ERROR_CODES.has(code)
+}
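This predicate is what makes generic retry loops safe to write. A minimal sketch, assuming a hypothetical `withRetry` helper with linear backoff (neither the helper nor the backoff policy is part of this module):

```typescript
import { isDataApiError } from '@shared/data/api'

// Hypothetical helper: retry only while the failure is marked retryable
async function withRetry<T>(operation: () => Promise<T>, maxAttempts = 3): Promise<T> {
  for (let attempt = 1; ; attempt++) {
    try {
      return await operation()
    } catch (error) {
      // Give up on non-retryable codes or once attempts are exhausted
      if (!isDataApiError(error) || !error.isRetryable || attempt >= maxAttempts) {
        throw error
      }
      // Simple linear backoff between attempts (an assumption, not project policy)
      await new Promise((resolve) => setTimeout(resolve, 200 * attempt))
    }
  }
}
```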
+
+// ============================================================================
+// Serialized Error Interface (for IPC transmission)
+// ============================================================================
+
+/**
+ * Serialized error structure for IPC transmission.
+ * Used in DataResponse.error field.
+ * Note: Does not include stack trace - rely on Main process logs.
+ */
+export interface SerializedDataApiError {
+  /** Error code from ErrorCode enum */
+  code: ErrorCode | string
+  /** Human-readable error message */
+  message: string
+  /** HTTP status code */
+  status: number
+  /** Structured error details */
+  details?: Record<string, unknown>
+  /** Request context for debugging */
+  requestContext?: RequestContext
+}
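For orientation, a hedged sketch of the round trip this interface enables between the main and renderer processes; it leans on the `DataApiError` class defined immediately below, and the flow shown is illustrative rather than the service's actual transport code:

```typescript
import { DataApiError, DataApiErrorFactory, isSerializedDataApiError } from '@shared/data/api'

// Main process side: a thrown DataApiError is serialized for the response
const payload = DataApiErrorFactory.notFound('Topic', 'abc123').toJSON()

// Renderer side: rebuild a typed error from the serialized form
if (isSerializedDataApiError(payload)) {
  const err = DataApiError.fromJSON(payload)
  console.log(err.code, err.status, err.isRetryable) // 'NOT_FOUND' 404 false
}
```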
+
+// ============================================================================
+// DataApiError Class
+// ============================================================================
+
+/**
+ * Custom error class for Data API errors.
+ *
+ * Provides type-safe error handling with:
+ * - Typed error codes and details
+ * - Retryability checking via `isRetryable` getter
+ * - IPC serialization via `toJSON()` / `fromJSON()`
+ * - Request context for debugging
+ *
+ * @example
+ * ```typescript
+ * // Throw a typed error
+ * throw new DataApiError(
+ *   ErrorCode.NOT_FOUND,
+ *   'Topic not found',
+ *   404,
+ *   { resource: 'Topic', id: 'abc123' }
+ * )
+ *
+ * // Check if error is retryable
+ * if (error.isRetryable) {
+ *   await retry(operation)
+ * }
+ * ```
+ */
+export class DataApiError<T extends ErrorCode = ErrorCode> extends Error {
+  /** Error code from ErrorCode enum */
+  public readonly code: T
+  /** HTTP status code */
+  public readonly status: number
+  /** Structured error details (type depends on error code) */
+  public readonly details?: DetailsForCode<T>
+  /** Request context for debugging */
+  public readonly requestContext?: RequestContext
+
+  constructor(code: T, message: string, status: number, details?: DetailsForCode<T>, requestContext?: RequestContext) {
+    super(message)
+    this.name = 'DataApiError'
+    this.code = code
+    this.status = status
+    this.details = details
+    this.requestContext = requestContext
+
+    // Maintains proper stack trace for where error was thrown
+    if (Error.captureStackTrace) {
+      Error.captureStackTrace(this, DataApiError)
+    }
+  }
+
+  /**
+   * Whether this error is safe to retry automatically.
+   * Based on the RETRYABLE_ERROR_CODES configuration.
+   */
+  get isRetryable(): boolean {
+    return isRetryableErrorCode(this.code)
+  }
+
+  /**
+   * Whether this is a client error (4xx status).
+   * Client errors typically indicate issues with the request itself.
+   */
+  get isClientError(): boolean {
+    return this.status >= 400 && this.status < 500
+  }
+
+  /**
+   * Whether this is a server error (5xx status).
+   * Server errors typically indicate issues on the server side.
+   */
+  get isServerError(): boolean {
+    return this.status >= 500 && this.status < 600
+  }
+
+  /**
+   * Serialize for IPC transmission.
+   * Note: Stack trace is NOT included - rely on Main process logs.
+   * @returns Serialized error object for IPC
+   */
+  toJSON(): SerializedDataApiError {
+    return {
+      code: this.code,
+      message: this.message,
+      status: this.status,
+      details: this.details as Record<string, unknown> | undefined,
+      requestContext: this.requestContext
+    }
+  }
+
+  /**
+   * Reconstruct DataApiError from IPC response.
+   * @param error - Serialized error from IPC
+   * @returns DataApiError instance
+   */
+  static fromJSON(error: SerializedDataApiError): DataApiError {
+    return new DataApiError(error.code as ErrorCode, error.message, error.status, error.details, error.requestContext)
+  }
+
+  /**
+   * Create DataApiError from a generic Error.
+   * @param error - Original error
+   * @param code - Error code to use (defaults to INTERNAL_SERVER_ERROR)
+   * @param requestContext - Optional request context
+   * @returns DataApiError instance
+   */
+  static fromError(
+    error: Error,
+    code: ErrorCode = ErrorCode.INTERNAL_SERVER_ERROR,
+    requestContext?: RequestContext
+  ): DataApiError {
+    return new DataApiError(
+      code,
+      error.message,
+      ERROR_STATUS_MAP[code],
+      { originalError: error.message, context: error.name } as DetailsForCode<ErrorCode>,
+      requestContext
+    )
+  }
+}
+
+// ============================================================================
+// DataApiErrorFactory
+// ============================================================================
+
+/**
+ * Factory for creating standardized DataApiError instances.
+ * Provides convenience methods for common error types with proper typing.
+ *
+ * @example
+ * ```typescript
+ * // Create a not found error
+ * throw DataApiErrorFactory.notFound('Topic', 'abc123')
+ *
+ * // Create a validation error
+ * throw DataApiErrorFactory.validation({
+ *   name: ['Name is required'],
+ *   email: ['Invalid email format']
+ * })
+ * ```
+ */
+export class DataApiErrorFactory {
+  /**
+   * Create a DataApiError with any error code.
+   * Use specialized methods when available for better type safety.
+   * @param code - Error code from ErrorCode enum
+   * @param customMessage - Optional custom error message
+   * @param details - Optional structured error details
+   * @param requestContext - Optional request context
+   * @returns DataApiError instance
+   */
+  static create<T extends ErrorCode>(
+    code: T,
+    customMessage?: string,
+    details?: DetailsForCode<T>,
+    requestContext?: RequestContext
+  ): DataApiError<T> {
+    return new DataApiError(
+      code,
+      customMessage || ERROR_MESSAGES[code],
+      ERROR_STATUS_MAP[code],
+      details,
+      requestContext
+    )
+  }
+
+  /**
+   * Create a validation error with field-specific error messages.
+   * @param fieldErrors - Map of field names to error messages
+   * @param message - Optional custom message
+   * @param requestContext - Optional request context
+   * @returns DataApiError with VALIDATION_ERROR code
+   */
+  static validation(
+    fieldErrors: Record<string, string[]>,
+    message?: string,
+    requestContext?: RequestContext
+  ): DataApiError {
+    return new DataApiError(
+      ErrorCode.VALIDATION_ERROR,
+      message || 'Request validation failed',
+      ERROR_STATUS_MAP[ErrorCode.VALIDATION_ERROR],
+      { fieldErrors },
+      requestContext
+    )
+  }
+
+  /**
+   * Create a not found error for a specific resource.
+   * @param resource - Resource type name (e.g., 'Topic', 'Message')
+   * @param id - Optional resource identifier
+   * @param requestContext - Optional request context
+   * @returns DataApiError with NOT_FOUND code
+   */
+  static notFound(resource: string, id?: string, requestContext?: RequestContext): DataApiError {
+    const message = id ? `${resource} with id '${id}' not found` : `${resource} not found`
+    return new DataApiError(
+      ErrorCode.NOT_FOUND,
+      message,
+      ERROR_STATUS_MAP[ErrorCode.NOT_FOUND],
+      { resource, id },
+      requestContext
+    )
+  }
+
+  /**
+   * Create a database error from an original error.
+   * @param originalError - The underlying database error
+   * @param operation - Description of the failed operation
+   * @param requestContext - Optional request context
+   * @returns DataApiError with DATABASE_ERROR code
+   */
+  static database(
+    originalError: Error,
+    operation?: string,
+    requestContext?: RequestContext
+  ): DataApiError {
+    return new DataApiError(
+      ErrorCode.DATABASE_ERROR,
+      `Database operation failed${operation ?
`: ${operation}` : ''}`, + ERROR_STATUS_MAP[ErrorCode.DATABASE_ERROR], + { originalError: originalError.message, operation }, + requestContext + ) + } + + /** + * Create an internal server error from an unexpected error. + * @param originalError - The underlying error + * @param context - Additional context about where the error occurred + * @param requestContext - Optional request context + * @returns DataApiError with INTERNAL_SERVER_ERROR code + */ + static internal( + originalError: Error, + context?: string, + requestContext?: RequestContext + ): DataApiError { + const message = context + ? `Internal error in ${context}: ${originalError.message}` + : `Internal error: ${originalError.message}` + return new DataApiError( + ErrorCode.INTERNAL_SERVER_ERROR, + message, + ERROR_STATUS_MAP[ErrorCode.INTERNAL_SERVER_ERROR], + { originalError: originalError.message, context }, + requestContext + ) + } + + /** + * Create a permission denied error. + * @param action - The action that was denied + * @param resource - Optional resource that access was denied to + * @param requestContext - Optional request context + * @returns DataApiError with PERMISSION_DENIED code + */ + static permissionDenied( + action: string, + resource?: string, + requestContext?: RequestContext + ): DataApiError { + const message = resource ? `Permission denied: Cannot ${action} ${resource}` : `Permission denied: Cannot ${action}` + return new DataApiError( + ErrorCode.PERMISSION_DENIED, + message, + ERROR_STATUS_MAP[ErrorCode.PERMISSION_DENIED], + { action, resource }, + requestContext + ) + } + + /** + * Create a timeout error. + * @param operation - Description of the operation that timed out + * @param timeoutMs - The timeout duration in milliseconds + * @param requestContext - Optional request context + * @returns DataApiError with TIMEOUT code + */ + static timeout( + operation?: string, + timeoutMs?: number, + requestContext?: RequestContext + ): DataApiError { + const message = operation + ? `Request timeout: ${operation}${timeoutMs ? ` (${timeoutMs}ms)` : ''}` + : `Request timeout${timeoutMs ? ` (${timeoutMs}ms)` : ''}` + return new DataApiError( + ErrorCode.TIMEOUT, + message, + ERROR_STATUS_MAP[ErrorCode.TIMEOUT], + { operation, timeoutMs }, + requestContext + ) + } + + /** + * Create an invalid operation error. + * Use when an operation violates business rules but isn't a validation error. + * @param operation - Description of the invalid operation + * @param reason - Optional reason why the operation is invalid + * @param requestContext - Optional request context + * @returns DataApiError with INVALID_OPERATION code + */ + static invalidOperation( + operation: string, + reason?: string, + requestContext?: RequestContext + ): DataApiError { + const message = reason ? `Invalid operation: ${operation} - ${reason}` : `Invalid operation: ${operation}` + return new DataApiError( + ErrorCode.INVALID_OPERATION, + message, + ERROR_STATUS_MAP[ErrorCode.INVALID_OPERATION], + { operation, reason }, + requestContext + ) + } + + /** + * Create a data inconsistency error. + * @param resource - The resource with inconsistent data + * @param description - Description of the inconsistency + * @param requestContext - Optional request context + * @returns DataApiError with DATA_INCONSISTENT code + */ + static dataInconsistent( + resource: string, + description?: string, + requestContext?: RequestContext + ): DataApiError { + const message = description + ? 
`Data inconsistent in ${resource}: ${description}` + : `Data inconsistent in ${resource}` + return new DataApiError( + ErrorCode.DATA_INCONSISTENT, + message, + ERROR_STATUS_MAP[ErrorCode.DATA_INCONSISTENT], + { resource, description }, + requestContext + ) + } + + /** + * Create a resource locked error. + * Use when a resource is temporarily unavailable due to: + * - File being exported + * - Data migration in progress + * - Resource held by background process + * + * @param resource - Resource type name + * @param id - Resource identifier + * @param lockedBy - Optional description of what's holding the lock + * @param requestContext - Optional request context + * @returns DataApiError with RESOURCE_LOCKED code + */ + static resourceLocked( + resource: string, + id: string, + lockedBy?: string, + requestContext?: RequestContext + ): DataApiError { + const message = lockedBy + ? `${resource} '${id}' is locked by ${lockedBy}` + : `${resource} '${id}' is currently locked` + return new DataApiError( + ErrorCode.RESOURCE_LOCKED, + message, + ERROR_STATUS_MAP[ErrorCode.RESOURCE_LOCKED], + { resource, id, lockedBy }, + requestContext + ) + } + + /** + * Create a concurrent modification error. + * Use when an optimistic lock conflict occurs: + * - Multi-window editing same topic + * - Version mismatch on update + * - Stale data detected during save + * + * Client should refresh data and retry or notify user. + * + * @param resource - Resource type name + * @param id - Resource identifier + * @param requestContext - Optional request context + * @returns DataApiError with CONCURRENT_MODIFICATION code + */ + static concurrentModification( + resource: string, + id: string, + requestContext?: RequestContext + ): DataApiError { + return new DataApiError( + ErrorCode.CONCURRENT_MODIFICATION, + `${resource} '${id}' was modified by another user`, + ERROR_STATUS_MAP[ErrorCode.CONCURRENT_MODIFICATION], + { resource, id }, + requestContext + ) + } +} + +// ============================================================================ +// Utility Functions +// ============================================================================ + +/** + * Check if an error is a DataApiError instance. + * @param error - Any error object + * @returns true if the error is a DataApiError + */ +export function isDataApiError(error: unknown): error is DataApiError { + return error instanceof DataApiError +} + +/** + * Check if an object is a serialized DataApiError. + * @param error - Any object + * @returns true if the object has DataApiError structure + */ +export function isSerializedDataApiError(error: unknown): error is SerializedDataApiError { + return ( + typeof error === 'object' && + error !== null && + 'code' in error && + 'message' in error && + 'status' in error && + typeof (error as SerializedDataApiError).code === 'string' && + typeof (error as SerializedDataApiError).message === 'string' && + typeof (error as SerializedDataApiError).status === 'number' + ) +} + +/** + * Convert any error to a DataApiError. + * If already a DataApiError, returns as-is. + * Otherwise, wraps in an INTERNAL_SERVER_ERROR. 
+ * + * @param error - Any error + * @param context - Optional context description + * @returns DataApiError instance + */ +export function toDataApiError(error: unknown, context?: string): DataApiError { + if (isDataApiError(error)) { + return error + } + + if (isSerializedDataApiError(error)) { + return DataApiError.fromJSON(error) + } + + if (error instanceof Error) { + return DataApiErrorFactory.internal(error, context) + } + + return DataApiErrorFactory.create( + ErrorCode.INTERNAL_SERVER_ERROR, + `Unknown error${context ? ` in ${context}` : ''}: ${String(error)}`, + { originalError: String(error), context } as DetailsForCode + ) +} diff --git a/packages/shared/data/api/apiModels.ts b/packages/shared/data/api/apiModels.ts deleted file mode 100644 index 08107a9729..0000000000 --- a/packages/shared/data/api/apiModels.ts +++ /dev/null @@ -1,107 +0,0 @@ -/** - * Generic test model definitions - * Contains flexible types for comprehensive API testing - */ - -/** - * Generic test item entity - flexible structure for testing various scenarios - */ -export interface TestItem { - /** Unique identifier */ - id: string - /** Item title */ - title: string - /** Optional description */ - description?: string - /** Type category */ - type: string - /** Current status */ - status: string - /** Priority level */ - priority: string - /** Associated tags */ - tags: string[] - /** Creation timestamp */ - createdAt: string - /** Last update timestamp */ - updatedAt: string - /** Additional metadata */ - metadata: Record -} - -/** - * Data Transfer Objects (DTOs) for test operations - */ - -/** - * DTO for creating a new test item - */ -export interface CreateTestItemDto { - /** Item title */ - title: string - /** Optional description */ - description?: string - /** Type category */ - type?: string - /** Current status */ - status?: string - /** Priority level */ - priority?: string - /** Associated tags */ - tags?: string[] - /** Additional metadata */ - metadata?: Record -} - -/** - * DTO for updating an existing test item - */ -export interface UpdateTestItemDto { - /** Updated title */ - title?: string - /** Updated description */ - description?: string - /** Updated type */ - type?: string - /** Updated status */ - status?: string - /** Updated priority */ - priority?: string - /** Updated tags */ - tags?: string[] - /** Updated metadata */ - metadata?: Record -} - -/** - * Bulk operation types for batch processing - */ - -/** - * Request for bulk operations on multiple items - */ -export interface BulkOperationRequest { - /** Type of bulk operation to perform */ - operation: 'create' | 'update' | 'delete' | 'archive' | 'restore' - /** Array of data items to process */ - data: TData[] -} - -/** - * Response from a bulk operation - */ -export interface BulkOperationResponse { - /** Number of successfully processed items */ - successful: number - /** Number of items that failed processing */ - failed: number - /** Array of errors that occurred during processing */ - errors: Array<{ - /** Index of the item that failed */ - index: number - /** Error message */ - error: string - /** Optional additional error data */ - data?: any - }> -} diff --git a/packages/shared/data/api/apiPaths.ts b/packages/shared/data/api/apiPaths.ts index a947157869..7cd5397e02 100644 --- a/packages/shared/data/api/apiPaths.ts +++ b/packages/shared/data/api/apiPaths.ts @@ -1,4 +1,4 @@ -import type { ApiSchemas } from './apiSchemas' +import type { ApiSchemas } from './schemas' /** * Template literal type utilities for converting 
parameterized paths to concrete paths diff --git a/packages/shared/data/api/apiSchemas.ts b/packages/shared/data/api/apiSchemas.ts deleted file mode 100644 index e405af806e..0000000000 --- a/packages/shared/data/api/apiSchemas.ts +++ /dev/null @@ -1,487 +0,0 @@ -// NOTE: Types are defined inline in the schema for simplicity -// If needed, specific types can be imported from './apiModels' -import type { BodyForPath, ConcreteApiPaths, QueryParamsForPath, ResponseForPath } from './apiPaths' -import type { HttpMethod, PaginatedResponse, PaginationParams } from './apiTypes' - -// Re-export for external use -export type { ConcreteApiPaths } from './apiPaths' - -/** - * Complete API Schema definitions for Test API - * - * Each path defines the supported HTTP methods with their: - * - Request parameters (params, query, body) - * - Response types - * - Type safety guarantees - * - * This schema serves as the contract between renderer and main processes, - * enabling full TypeScript type checking across IPC boundaries. - */ -export interface ApiSchemas { - /** - * Test items collection endpoint - * @example GET /test/items?page=1&limit=10&search=hello - * @example POST /test/items { "title": "New Test Item" } - */ - '/test/items': { - /** List all test items with optional filtering and pagination */ - GET: { - query?: PaginationParams & { - /** Search items by title or description */ - search?: string - /** Filter by item type */ - type?: string - /** Filter by status */ - status?: string - } - response: PaginatedResponse - } - /** Create a new test item */ - POST: { - body: { - title: string - description?: string - type?: string - status?: string - priority?: string - tags?: string[] - metadata?: Record - } - response: any - } - } - - /** - * Individual test item endpoint - * @example GET /test/items/123 - * @example PUT /test/items/123 { "title": "Updated Title" } - * @example DELETE /test/items/123 - */ - '/test/items/:id': { - /** Get a specific test item by ID */ - GET: { - params: { id: string } - response: any - } - /** Update a specific test item */ - PUT: { - params: { id: string } - body: { - title?: string - description?: string - type?: string - status?: string - priority?: string - tags?: string[] - metadata?: Record - } - response: any - } - /** Delete a specific test item */ - DELETE: { - params: { id: string } - response: void - } - } - - /** - * Test search endpoint - * @example GET /test/search?query=hello&page=1&limit=20 - */ - '/test/search': { - /** Search test items */ - GET: { - query: { - /** Search query string */ - query: string - /** Page number for pagination */ - page?: number - /** Number of results per page */ - limit?: number - /** Additional filters */ - type?: string - status?: string - } - response: PaginatedResponse - } - } - - /** - * Test statistics endpoint - * @example GET /test/stats - */ - '/test/stats': { - /** Get comprehensive test statistics */ - GET: { - response: { - /** Total number of items */ - total: number - /** Item count grouped by type */ - byType: Record - /** Item count grouped by status */ - byStatus: Record - /** Item count grouped by priority */ - byPriority: Record - /** Recent activity timeline */ - recentActivity: Array<{ - /** Date of activity */ - date: string - /** Number of items on that date */ - count: number - }> - } - } - } - - /** - * Test bulk operations endpoint - * @example POST /test/bulk { "operation": "create", "data": [...] 
} - */ - '/test/bulk': { - /** Perform bulk operations on test items */ - POST: { - body: { - /** Operation type */ - operation: 'create' | 'update' | 'delete' - /** Array of data items to process */ - data: any[] - } - response: { - successful: number - failed: number - errors: string[] - } - } - } - - /** - * Test error simulation endpoint - * @example POST /test/error { "errorType": "timeout" } - */ - '/test/error': { - /** Simulate various error scenarios for testing */ - POST: { - body: { - /** Type of error to simulate */ - errorType: - | 'timeout' - | 'network' - | 'server' - | 'notfound' - | 'validation' - | 'unauthorized' - | 'ratelimit' - | 'generic' - } - response: never - } - } - - /** - * Test slow response endpoint - * @example POST /test/slow { "delay": 2000 } - */ - '/test/slow': { - /** Test slow response for performance testing */ - POST: { - body: { - /** Delay in milliseconds */ - delay: number - } - response: { - message: string - delay: number - timestamp: string - } - } - } - - /** - * Test data reset endpoint - * @example POST /test/reset - */ - '/test/reset': { - /** Reset all test data to initial state */ - POST: { - response: { - message: string - timestamp: string - } - } - } - - /** - * Test config endpoint - * @example GET /test/config - * @example PUT /test/config { "setting": "value" } - */ - '/test/config': { - /** Get test configuration */ - GET: { - response: Record - } - /** Update test configuration */ - PUT: { - body: Record - response: Record - } - } - - /** - * Test status endpoint - * @example GET /test/status - */ - '/test/status': { - /** Get system test status */ - GET: { - response: { - status: string - timestamp: string - version: string - uptime: number - environment: string - } - } - } - - /** - * Test performance endpoint - * @example GET /test/performance - */ - '/test/performance': { - /** Get performance metrics */ - GET: { - response: { - requestsPerSecond: number - averageLatency: number - memoryUsage: number - cpuUsage: number - uptime: number - } - } - } - - /** - * Batch execution of multiple requests - * @example POST /batch { "requests": [...], "parallel": true } - */ - '/batch': { - /** Execute multiple API requests in a single call */ - POST: { - body: { - /** Array of requests to execute */ - requests: Array<{ - /** HTTP method for the request */ - method: HttpMethod - /** API path for the request */ - path: string - /** URL parameters */ - params?: any - /** Request body */ - body?: any - }> - /** Execute requests in parallel vs sequential */ - parallel?: boolean - } - response: { - /** Results array matching input order */ - results: Array<{ - /** HTTP status code */ - status: number - /** Response data if successful */ - data?: any - /** Error information if failed */ - error?: any - }> - /** Batch execution metadata */ - metadata: { - /** Total execution duration in ms */ - duration: number - /** Number of successful requests */ - successCount: number - /** Number of failed requests */ - errorCount: number - } - } - } - } - - /** - * Atomic transaction of multiple operations - * @example POST /transaction { "operations": [...], "options": { "rollbackOnError": true } } - */ - '/transaction': { - /** Execute multiple operations in a database transaction */ - POST: { - body: { - /** Array of operations to execute atomically */ - operations: Array<{ - /** HTTP method for the operation */ - method: HttpMethod - /** API path for the operation */ - path: string - /** URL parameters */ - params?: any - /** Request body */ - body?: 
any - }> - /** Transaction configuration options */ - options?: { - /** Database isolation level */ - isolation?: 'read-uncommitted' | 'read-committed' | 'repeatable-read' | 'serializable' - /** Rollback all operations on any error */ - rollbackOnError?: boolean - /** Transaction timeout in milliseconds */ - timeout?: number - } - } - response: Array<{ - /** HTTP status code */ - status: number - /** Response data if successful */ - data?: any - /** Error information if failed */ - error?: any - }> - } - } -} - -/** - * Simplified type extraction helpers - */ -export type ApiPaths = keyof ApiSchemas -export type ApiMethods = keyof ApiSchemas[TPath] & HttpMethod -export type ApiResponse = TPath extends keyof ApiSchemas - ? TMethod extends keyof ApiSchemas[TPath] - ? ApiSchemas[TPath][TMethod] extends { response: infer R } - ? R - : never - : never - : never - -export type ApiParams = TPath extends keyof ApiSchemas - ? TMethod extends keyof ApiSchemas[TPath] - ? ApiSchemas[TPath][TMethod] extends { params: infer P } - ? P - : never - : never - : never - -export type ApiQuery = TPath extends keyof ApiSchemas - ? TMethod extends keyof ApiSchemas[TPath] - ? ApiSchemas[TPath][TMethod] extends { query: infer Q } - ? Q - : never - : never - : never - -export type ApiBody = TPath extends keyof ApiSchemas - ? TMethod extends keyof ApiSchemas[TPath] - ? ApiSchemas[TPath][TMethod] extends { body: infer B } - ? B - : never - : never - : never - -/** - * Type-safe API client interface using concrete paths - * Accepts actual paths like '/test/items/123' instead of '/test/items/:id' - * Automatically infers query, body, and response types from ApiSchemas - */ -export interface ApiClient { - get( - path: TPath, - options?: { - query?: QueryParamsForPath - headers?: Record - } - ): Promise> - - post( - path: TPath, - options: { - body?: BodyForPath - query?: Record - headers?: Record - } - ): Promise> - - put( - path: TPath, - options: { - body: BodyForPath - query?: Record - headers?: Record - } - ): Promise> - - delete( - path: TPath, - options?: { - query?: Record - headers?: Record - } - ): Promise> - - patch( - path: TPath, - options: { - body?: BodyForPath - query?: Record - headers?: Record - } - ): Promise> -} - -/** - * Helper types to determine if parameters are required based on schema - */ -type HasRequiredQuery> = Path extends keyof ApiSchemas - ? Method extends keyof ApiSchemas[Path] - ? ApiSchemas[Path][Method] extends { query: any } - ? true - : false - : false - : false - -type HasRequiredBody> = Path extends keyof ApiSchemas - ? Method extends keyof ApiSchemas[Path] - ? ApiSchemas[Path][Method] extends { body: any } - ? true - : false - : false - : false - -type HasRequiredParams> = Path extends keyof ApiSchemas - ? Method extends keyof ApiSchemas[Path] - ? ApiSchemas[Path][Method] extends { params: any } - ? true - : false - : false - : false - -/** - * Handler function for a specific API endpoint - * Provides type-safe parameter extraction based on ApiSchemas - * Parameters are required or optional based on the schema definition - */ -export type ApiHandler> = ( - params: (HasRequiredParams extends true - ? { params: ApiParams } - : { params?: ApiParams }) & - (HasRequiredQuery extends true - ? { query: ApiQuery } - : { query?: ApiQuery }) & - (HasRequiredBody extends true ? 
{ body: ApiBody } : { body?: ApiBody }) -) => Promise> - -/** - * Complete API implementation that must match ApiSchemas structure - * TypeScript will error if any endpoint is missing - this ensures exhaustive coverage - */ -export type ApiImplementation = { - [Path in ApiPaths]: { - [Method in ApiMethods]: ApiHandler - } -} diff --git a/packages/shared/data/api/apiTypes.ts b/packages/shared/data/api/apiTypes.ts index e45c45603c..4cdfc08761 100644 --- a/packages/shared/data/api/apiTypes.ts +++ b/packages/shared/data/api/apiTypes.ts @@ -8,6 +8,75 @@ */ export type HttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' +// ============================================================================ +// Schema Constraint Types +// ============================================================================ + +/** + * Constraint for a single endpoint method definition. + * Requires `response` field, allows optional `params`, `query`, and `body`. + */ +export type EndpointMethodConstraint = { + params?: Record + query?: Record + body?: any + response: any // response is required +} + +/** + * Constraint for a single API path - only allows valid HTTP methods. + */ +export type EndpointConstraint = { + [Method in HttpMethod]?: EndpointMethodConstraint +} + +/** + * Validates that a schema only contains valid HTTP methods. + * Used in AssertValidSchemas for compile-time validation. + */ +type ValidateMethods = { + [Path in keyof T]: { + [Method in keyof T[Path]]: Method extends HttpMethod ? T[Path][Method] : never + } +} + +/** + * Validates that all endpoints have a `response` field. + * Returns the original type if valid, or `never` if any endpoint lacks response. + */ +type ValidateResponses = { + [Path in keyof T]: { + [Method in keyof T[Path]]: T[Path][Method] extends { response: any } + ? T[Path][Method] + : { error: `Endpoint ${Path & string}.${Method & string} is missing 'response' field` } + } +} + +/** + * Validates that a schema conforms to expected structure: + * 1. All methods must be valid HTTP methods (GET, POST, PUT, DELETE, PATCH) + * 2. All endpoints must have a `response` field + * + * This is applied at the composition level (schemas/index.ts) to catch + * invalid schemas even if individual schema files don't use validation. + * + * @example + * ```typescript + * // In schemas/index.ts + * export type ApiSchemas = AssertValidSchemas + * + * // Invalid method will cause error: + * // Type 'never' is not assignable to type... + * ``` + */ +export type AssertValidSchemas = ValidateMethods & ValidateResponses extends infer R + ? 
{ [K in keyof R]: R[K] } + : never + +// ============================================================================ +// Core Request/Response Types +// ============================================================================ + /** * Request object structure for Data API calls */ @@ -30,8 +99,6 @@ export interface DataRequest { timestamp: number /** OpenTelemetry span context for tracing */ spanContext?: any - /** Cache options for this specific request */ - cache?: CacheOptions } } @@ -46,7 +113,7 @@ export interface DataResponse { /** Response data if successful */ data?: T /** Error information if request failed */ - error?: DataApiError + error?: SerializedDataApiError /** Response metadata */ metadata?: { /** Request processing duration in milliseconds */ @@ -60,146 +127,131 @@ export interface DataResponse { } } -/** - * Standardized error structure for Data API - */ -export interface DataApiError { - /** Error code for programmatic handling */ - code: string - /** Human-readable error message */ - message: string - /** HTTP status code */ - status: number - /** Additional error details */ - details?: any - /** Error stack trace (development mode only) */ - stack?: string -} +// Note: Error types have been moved to apiErrors.ts +// Import from there: ErrorCode, DataApiError, SerializedDataApiError, DataApiErrorFactory +import type { SerializedDataApiError } from './apiErrors' + +// Re-export for backwards compatibility in DataResponse +export type { SerializedDataApiError } from './apiErrors' + +// ============================================================================ +// Pagination Types +// ============================================================================ + +// ----- Request Parameters ----- /** - * Standard error codes for Data API + * Offset-based pagination parameters (page + limit) */ -export enum ErrorCode { - // Client errors (4xx) - BAD_REQUEST = 'BAD_REQUEST', - UNAUTHORIZED = 'UNAUTHORIZED', - FORBIDDEN = 'FORBIDDEN', - NOT_FOUND = 'NOT_FOUND', - METHOD_NOT_ALLOWED = 'METHOD_NOT_ALLOWED', - VALIDATION_ERROR = 'VALIDATION_ERROR', - RATE_LIMIT_EXCEEDED = 'RATE_LIMIT_EXCEEDED', - - // Server errors (5xx) - INTERNAL_SERVER_ERROR = 'INTERNAL_SERVER_ERROR', - DATABASE_ERROR = 'DATABASE_ERROR', - SERVICE_UNAVAILABLE = 'SERVICE_UNAVAILABLE', - - // Custom application errors - MIGRATION_ERROR = 'MIGRATION_ERROR', - PERMISSION_DENIED = 'PERMISSION_DENIED', - RESOURCE_LOCKED = 'RESOURCE_LOCKED', - CONCURRENT_MODIFICATION = 'CONCURRENT_MODIFICATION' -} - -/** - * Cache configuration options - */ -export interface CacheOptions { - /** Cache TTL in seconds */ - ttl?: number - /** Return stale data while revalidating in background */ - staleWhileRevalidate?: boolean - /** Custom cache key override */ - cacheKey?: string - /** Operations that should invalidate this cache entry */ - invalidateOn?: string[] - /** Whether to bypass cache entirely */ - noCache?: boolean -} - -/** - * Transaction request wrapper for atomic operations - */ -export interface TransactionRequest { - /** List of operations to execute in transaction */ - operations: DataRequest[] - /** Transaction options */ - options?: { - /** Database isolation level */ - isolation?: 'read-uncommitted' | 'read-committed' | 'repeatable-read' | 'serializable' - /** Whether to rollback entire transaction on any error */ - rollbackOnError?: boolean - /** Transaction timeout in milliseconds */ - timeout?: number - } -} - -/** - * Batch request for multiple operations - */ -export interface BatchRequest { - /** 
List of requests to execute */ - requests: DataRequest[] - /** Whether to execute requests in parallel */ - parallel?: boolean - /** Stop on first error */ - stopOnError?: boolean -} - -/** - * Batch response containing results for all requests - */ -export interface BatchResponse { - /** Individual response for each request */ - results: DataResponse[] - /** Overall batch execution metadata */ - metadata: { - /** Total execution time */ - duration: number - /** Number of successful operations */ - successCount: number - /** Number of failed operations */ - errorCount: number - } -} - -/** - * Pagination parameters for list operations - */ -export interface PaginationParams { +export interface OffsetPaginationParams { /** Page number (1-based) */ page?: number /** Items per page */ limit?: number - /** Cursor for cursor-based pagination */ - cursor?: string - /** Sort field and direction */ - sort?: { - field: string - order: 'asc' | 'desc' - } } /** - * Paginated response wrapper + * Cursor-based pagination parameters (cursor + limit) + * + * The cursor is typically an opaque reference to a record in the dataset. + * The cursor itself is NEVER included in the response - it marks an exclusive boundary. + * + * Common semantics: + * - "after cursor": Returns items AFTER the cursor (forward pagination) + * - "before cursor": Returns items BEFORE the cursor (backward/historical pagination) + * + * The specific semantic depends on the API endpoint. Check endpoint documentation. */ -export interface PaginatedResponse { +export interface CursorPaginationParams { + /** Cursor for pagination boundary (exclusive - cursor item not included in response) */ + cursor?: string + /** Items per page */ + limit?: number +} + +/** + * Sort parameters (independent, combine as needed) + */ +export interface SortParams { + /** Field to sort by */ + sortBy?: string + /** Sort direction */ + sortOrder?: 'asc' | 'desc' +} + +/** + * Search parameters (independent, combine as needed) + */ +export interface SearchParams { + /** Search query string */ + search?: string +} + +// ----- Response Types ----- + +/** + * Offset-based pagination response + */ +export interface OffsetPaginationResponse { /** Items for current page */ items: T[] /** Total number of items */ total: number - /** Current page number */ + /** Current page number (1-based) */ page: number - /** Total number of pages */ - pageCount: number - /** Whether there are more pages */ - hasNext: boolean - /** Whether there are previous pages */ - hasPrev: boolean - /** Next cursor for cursor-based pagination */ +} + +/** + * Cursor-based pagination response + */ +export interface CursorPaginationResponse { + /** Items for current page */ + items: T[] + /** Next cursor (undefined means no more data) */ nextCursor?: string - /** Previous cursor for cursor-based pagination */ - prevCursor?: string +} + +// ----- Type Utilities ----- + +/** + * Infer pagination mode from response type + */ +export type InferPaginationMode = R extends OffsetPaginationResponse + ? 'offset' + : R extends CursorPaginationResponse + ? 'cursor' + : never + +/** + * Infer item type from pagination response + */ +export type InferPaginationItem = R extends OffsetPaginationResponse + ? T + : R extends CursorPaginationResponse + ? 
T + : never + +/** + * Union type for both pagination responses + */ +export type PaginationResponse = OffsetPaginationResponse | CursorPaginationResponse + +/** + * Type guard: check if response is offset-based + */ +export function isOffsetPaginationResponse( + response: PaginationResponse +): response is OffsetPaginationResponse { + return 'page' in response && 'total' in response +} + +/** + * Type guard: check if response is cursor-based + */ +export function isCursorPaginationResponse( + response: PaginationResponse +): response is CursorPaginationResponse { + return !('page' in response) } /** @@ -274,16 +326,169 @@ export interface ServiceOptions { metadata?: Record } +// ============================================================================ +// API Schema Type Utilities +// ============================================================================ + +import type { BodyForPath, ConcreteApiPaths, QueryParamsForPath, ResponseForPath } from './apiPaths' +import type { ApiSchemas } from './schemas' + +// Re-export for external use +export type { ConcreteApiPaths } from './apiPaths' +export type { ApiSchemas } from './schemas' + /** - * Standard service response wrapper + * All available API paths */ -export interface ServiceResult { - /** Whether operation was successful */ - success: boolean - /** Result data if successful */ - data?: T - /** Error information if failed */ - error?: DataApiError - /** Additional metadata */ - metadata?: Record +export type ApiPaths = keyof ApiSchemas + +/** + * Available HTTP methods for a specific path + */ +export type ApiMethods = keyof ApiSchemas[TPath] & HttpMethod + +/** + * Response type for a specific path and method + */ +export type ApiResponse = TPath extends keyof ApiSchemas + ? TMethod extends keyof ApiSchemas[TPath] + ? ApiSchemas[TPath][TMethod] extends { response: infer R } + ? R + : never + : never + : never + +/** + * URL params type for a specific path and method + */ +export type ApiParams = TPath extends keyof ApiSchemas + ? TMethod extends keyof ApiSchemas[TPath] + ? ApiSchemas[TPath][TMethod] extends { params: infer P } + ? P + : never + : never + : never + +/** + * Query params type for a specific path and method + */ +export type ApiQuery = TPath extends keyof ApiSchemas + ? TMethod extends keyof ApiSchemas[TPath] + ? ApiSchemas[TPath][TMethod] extends { query: infer Q } + ? Q + : never + : never + : never + +/** + * Request body type for a specific path and method + */ +export type ApiBody = TPath extends keyof ApiSchemas + ? TMethod extends keyof ApiSchemas[TPath] + ? ApiSchemas[TPath][TMethod] extends { body: infer B } + ? 
B + : never + : never + : never + +/** + * Type-safe API client interface using concrete paths + * Accepts actual paths like '/test/items/123' instead of '/test/items/:id' + * Automatically infers query, body, and response types from ApiSchemas + */ +export interface ApiClient { + get( + path: TPath, + options?: { + query?: QueryParamsForPath + headers?: Record + } + ): Promise> + + post( + path: TPath, + options: { + body?: BodyForPath + query?: Record + headers?: Record + } + ): Promise> + + put( + path: TPath, + options: { + body: BodyForPath + query?: Record + headers?: Record + } + ): Promise> + + delete( + path: TPath, + options?: { + query?: Record + headers?: Record + } + ): Promise> + + patch( + path: TPath, + options: { + body?: BodyForPath + query?: Record + headers?: Record + } + ): Promise> +} + +/** + * Helper types to determine if parameters are required based on schema + */ +type HasRequiredQuery> = Path extends keyof ApiSchemas + ? Method extends keyof ApiSchemas[Path] + ? ApiSchemas[Path][Method] extends { query: any } + ? true + : false + : false + : false + +type HasRequiredBody> = Path extends keyof ApiSchemas + ? Method extends keyof ApiSchemas[Path] + ? ApiSchemas[Path][Method] extends { body: any } + ? true + : false + : false + : false + +type HasRequiredParams> = Path extends keyof ApiSchemas + ? Method extends keyof ApiSchemas[Path] + ? ApiSchemas[Path][Method] extends { params: any } + ? true + : false + : false + : false + +/** + * Handler function for a specific API endpoint + * Provides type-safe parameter extraction based on ApiSchemas + * Parameters are required or optional based on the schema definition + */ +export type ApiHandler> = ( + params: (HasRequiredParams extends true + ? { params: ApiParams } + : { params?: ApiParams }) & + (HasRequiredQuery extends true + ? { query: ApiQuery } + : { query?: ApiQuery }) & + (HasRequiredBody extends true ? 
{ body: ApiBody } : { body?: ApiBody }) +) => Promise> + +/** + * Complete API implementation that must match ApiSchemas structure + * TypeScript will error if any endpoint is missing - this ensures exhaustive coverage + */ +export type ApiImplementation = { + [Path in ApiPaths]: { + [Method in ApiMethods]: ApiHandler + } } diff --git a/packages/shared/data/api/errorCodes.ts b/packages/shared/data/api/errorCodes.ts deleted file mode 100644 index 7ccb96c8c9..0000000000 --- a/packages/shared/data/api/errorCodes.ts +++ /dev/null @@ -1,194 +0,0 @@ -/** - * Centralized error code definitions for the Data API system - * Provides consistent error handling across renderer and main processes - */ - -import type { DataApiError } from './apiTypes' -import { ErrorCode } from './apiTypes' - -// Re-export ErrorCode for convenience -export { ErrorCode } from './apiTypes' - -/** - * Error code to HTTP status mapping - */ -export const ERROR_STATUS_MAP: Record = { - // Client errors (4xx) - [ErrorCode.BAD_REQUEST]: 400, - [ErrorCode.UNAUTHORIZED]: 401, - [ErrorCode.FORBIDDEN]: 403, - [ErrorCode.NOT_FOUND]: 404, - [ErrorCode.METHOD_NOT_ALLOWED]: 405, - [ErrorCode.VALIDATION_ERROR]: 422, - [ErrorCode.RATE_LIMIT_EXCEEDED]: 429, - - // Server errors (5xx) - [ErrorCode.INTERNAL_SERVER_ERROR]: 500, - [ErrorCode.DATABASE_ERROR]: 500, - [ErrorCode.SERVICE_UNAVAILABLE]: 503, - - // Custom application errors (5xx) - [ErrorCode.MIGRATION_ERROR]: 500, - [ErrorCode.PERMISSION_DENIED]: 403, - [ErrorCode.RESOURCE_LOCKED]: 423, - [ErrorCode.CONCURRENT_MODIFICATION]: 409 -} - -/** - * Default error messages for each error code - */ -export const ERROR_MESSAGES: Record = { - [ErrorCode.BAD_REQUEST]: 'Bad request: Invalid request format or parameters', - [ErrorCode.UNAUTHORIZED]: 'Unauthorized: Authentication required', - [ErrorCode.FORBIDDEN]: 'Forbidden: Insufficient permissions', - [ErrorCode.NOT_FOUND]: 'Not found: Requested resource does not exist', - [ErrorCode.METHOD_NOT_ALLOWED]: 'Method not allowed: HTTP method not supported for this endpoint', - [ErrorCode.VALIDATION_ERROR]: 'Validation error: Request data does not meet requirements', - [ErrorCode.RATE_LIMIT_EXCEEDED]: 'Rate limit exceeded: Too many requests', - - [ErrorCode.INTERNAL_SERVER_ERROR]: 'Internal server error: An unexpected error occurred', - [ErrorCode.DATABASE_ERROR]: 'Database error: Failed to access or modify data', - [ErrorCode.SERVICE_UNAVAILABLE]: 'Service unavailable: The service is temporarily unavailable', - - [ErrorCode.MIGRATION_ERROR]: 'Migration error: Failed to migrate data', - [ErrorCode.PERMISSION_DENIED]: 'Permission denied: Operation not allowed for current user', - [ErrorCode.RESOURCE_LOCKED]: 'Resource locked: Resource is currently locked by another operation', - [ErrorCode.CONCURRENT_MODIFICATION]: 'Concurrent modification: Resource was modified by another user' -} - -/** - * Utility class for creating standardized Data API errors - */ -export class DataApiErrorFactory { - /** - * Create a DataApiError with standard properties - */ - static create(code: ErrorCode, customMessage?: string, details?: any, stack?: string): DataApiError { - return { - code, - message: customMessage || ERROR_MESSAGES[code], - status: ERROR_STATUS_MAP[code], - details, - stack: stack || undefined - } - } - - /** - * Create a validation error with field-specific details - */ - static validation(fieldErrors: Record, message?: string): DataApiError { - return this.create(ErrorCode.VALIDATION_ERROR, message || 'Request validation failed', { fieldErrors }) - } 
- - /** - * Create a not found error for specific resource - */ - static notFound(resource: string, id?: string): DataApiError { - const message = id ? `${resource} with id '${id}' not found` : `${resource} not found` - - return this.create(ErrorCode.NOT_FOUND, message, { resource, id }) - } - - /** - * Create a database error with query details - */ - static database(originalError: Error, operation?: string): DataApiError { - return this.create( - ErrorCode.DATABASE_ERROR, - `Database operation failed${operation ? `: ${operation}` : ''}`, - { - originalError: originalError.message, - operation - }, - originalError.stack - ) - } - - /** - * Create a permission denied error - */ - static permissionDenied(action: string, resource?: string): DataApiError { - const message = resource ? `Permission denied: Cannot ${action} ${resource}` : `Permission denied: Cannot ${action}` - - return this.create(ErrorCode.PERMISSION_DENIED, message, { action, resource }) - } - - /** - * Create an internal server error from an unexpected error - */ - static internal(originalError: Error, context?: string): DataApiError { - const message = context - ? `Internal error in ${context}: ${originalError.message}` - : `Internal error: ${originalError.message}` - - return this.create( - ErrorCode.INTERNAL_SERVER_ERROR, - message, - { originalError: originalError.message, context }, - originalError.stack - ) - } - - /** - * Create a rate limit exceeded error - */ - static rateLimit(limit: number, windowMs: number): DataApiError { - return this.create(ErrorCode.RATE_LIMIT_EXCEEDED, `Rate limit exceeded: ${limit} requests per ${windowMs}ms`, { - limit, - windowMs - }) - } - - /** - * Create a resource locked error - */ - static resourceLocked(resource: string, id: string, lockedBy?: string): DataApiError { - const message = lockedBy - ? `${resource} '${id}' is locked by ${lockedBy}` - : `${resource} '${id}' is currently locked` - - return this.create(ErrorCode.RESOURCE_LOCKED, message, { resource, id, lockedBy }) - } - - /** - * Create a concurrent modification error - */ - static concurrentModification(resource: string, id: string): DataApiError { - return this.create(ErrorCode.CONCURRENT_MODIFICATION, `${resource} '${id}' was modified by another user`, { - resource, - id - }) - } -} - -/** - * Check if an error is a Data API error - */ -export function isDataApiError(error: any): error is DataApiError { - return ( - error && - typeof error === 'object' && - typeof error.code === 'string' && - typeof error.message === 'string' && - typeof error.status === 'number' - ) -} - -/** - * Convert a generic error to a DataApiError - */ -export function toDataApiError(error: unknown, context?: string): DataApiError { - if (isDataApiError(error)) { - return error - } - - if (error instanceof Error) { - return DataApiErrorFactory.internal(error, context) - } - - return DataApiErrorFactory.create( - ErrorCode.INTERNAL_SERVER_ERROR, - `Unknown error${context ? ` in ${context}` : ''}: ${String(error)}`, - { originalError: error, context } - ) -} diff --git a/packages/shared/data/api/index.ts b/packages/shared/data/api/index.ts index 3b00e37473..a90c4e2c51 100644 --- a/packages/shared/data/api/index.ts +++ b/packages/shared/data/api/index.ts @@ -1,121 +1,110 @@ /** * Cherry Studio Data API - Barrel Exports * - * This file provides a centralized entry point for all data API types, - * schemas, and utilities. Import everything you need from this single location. + * Exports common infrastructure types for the Data API system. 
+ * Domain-specific DTOs should be imported directly from their schema files. * * @example * ```typescript - * import { Topic, CreateTopicDto, ApiSchemas, DataRequest, ErrorCode } from '@/shared/data' + * // Infrastructure types from barrel export + * import { DataRequest, DataResponse, ErrorCode, DataApiError } from '@shared/data/api' + * + * // Domain DTOs from schema files directly + * import type { Topic, CreateTopicDto } from '@shared/data/api/schemas/topic' * ``` */ -// Core data API types and infrastructure +// ============================================================================ +// Core Request/Response Types +// ============================================================================ + export type { - BatchRequest, - BatchResponse, - CacheOptions, - DataApiError, + CursorPaginationParams, + CursorPaginationResponse, DataRequest, DataResponse, HttpMethod, - Middleware, - PaginatedResponse, - PaginationParams, - RequestContext, - ServiceOptions, - ServiceResult, - SubscriptionCallback, - SubscriptionOptions, - TransactionRequest + OffsetPaginationParams, + OffsetPaginationResponse, + PaginationResponse, + SearchParams, + SortParams } from './apiTypes' -export { ErrorCode, SubscriptionEvent } from './apiTypes' +export { isCursorPaginationResponse, isOffsetPaginationResponse } from './apiTypes' -// Domain models and DTOs -export type { - BulkOperationRequest, - BulkOperationResponse, - CreateTestItemDto, - TestItem, - UpdateTestItemDto -} from './apiModels' +// ============================================================================ +// API Schema Type Utilities +// ============================================================================ -// API schema definitions and type helpers export type { ApiBody, ApiClient, + ApiHandler, + ApiImplementation, ApiMethods, ApiParams, ApiPaths, ApiQuery, ApiResponse, - ApiSchemas -} from './apiSchemas' + ApiSchemas, + ConcreteApiPaths +} from './apiTypes' + +// ============================================================================ +// Path Resolution Utilities +// ============================================================================ -// Path type utilities for template literal types export type { BodyForPath, - ConcreteApiPaths, MatchApiPath, QueryParamsForPath, ResolvedPath, ResponseForPath } from './apiPaths' -// Error handling utilities +// ============================================================================ +// Error Handling (from apiErrors.ts) +// ============================================================================ + +// Error code enum and mappings export { - ErrorCode as DataApiErrorCode, - DataApiErrorFactory, ERROR_MESSAGES, ERROR_STATUS_MAP, - isDataApiError, - toDataApiError -} from './errorCodes' - -/** - * Re-export commonly used type combinations for convenience - */ - -// Import types for re-export convenience types -import type { CreateTestItemDto, TestItem, UpdateTestItemDto } from './apiModels' -import type { - BatchRequest, - BatchResponse, - DataApiError, - DataRequest, - DataResponse, ErrorCode, - PaginatedResponse, - PaginationParams, - TransactionRequest -} from './apiTypes' -import type { DataApiErrorFactory } from './errorCodes' + isRetryableErrorCode, + RETRYABLE_ERROR_CODES +} from './apiErrors' -/** All test item-related types */ -export type TestItemTypes = { - TestItem: TestItem - CreateTestItemDto: CreateTestItemDto - UpdateTestItemDto: UpdateTestItemDto -} +// DataApiError class and factory +export { + DataApiError, + DataApiErrorFactory, + isDataApiError, + 
isSerializedDataApiError, + toDataApiError +} from './apiErrors' -/** All error-related types and utilities */ -export type ErrorTypes = { - DataApiError: DataApiError - ErrorCode: ErrorCode - ErrorFactory: typeof DataApiErrorFactory -} +// Error-related types +export type { + ConcurrentModificationErrorDetails, + DatabaseErrorDetails, + DataInconsistentErrorDetails, + DetailsForCode, + ErrorDetailsMap, + InternalErrorDetails, + InvalidOperationErrorDetails, + NotFoundErrorDetails, + PermissionDeniedErrorDetails, + RequestContext, + ResourceLockedErrorDetails, + SerializedDataApiError, + TimeoutErrorDetails, + ValidationErrorDetails +} from './apiErrors' -/** All request/response types */ -export type RequestTypes = { - DataRequest: DataRequest - DataResponse: DataResponse - BatchRequest: BatchRequest - BatchResponse: BatchResponse - TransactionRequest: TransactionRequest -} +// ============================================================================ +// Subscription & Middleware (for advanced usage) +// ============================================================================ -/** All pagination-related types */ -export type PaginationTypes = { - PaginationParams: PaginationParams - PaginatedResponse: PaginatedResponse -} +export type { Middleware, ServiceOptions, SubscriptionCallback, SubscriptionOptions } from './apiTypes' +export { SubscriptionEvent } from './apiTypes' diff --git a/packages/shared/data/api/schemas/index.ts b/packages/shared/data/api/schemas/index.ts new file mode 100644 index 0000000000..703b92ff24 --- /dev/null +++ b/packages/shared/data/api/schemas/index.ts @@ -0,0 +1,39 @@ +/** + * Schema Index - Composes all domain schemas into unified ApiSchemas + * + * This file has ONE responsibility: compose domain schemas into ApiSchemas. + * + * Import conventions (see api/README.md for details): + * - Infrastructure types: import from '@shared/data/api' + * - Domain DTOs: import directly from schema files (e.g., '@shared/data/api/schemas/topic') + * + * @example + * ```typescript + * // Infrastructure types via barrel export + * import type { ApiSchemas, DataRequest } from '@shared/data/api' + * + * // Domain DTOs directly from schema files + * import type { TestItem, CreateTestItemDto } from '@shared/data/api/schemas/test' + * import type { Topic, CreateTopicDto } from '@shared/data/api/schemas/topics' + * import type { Message, CreateMessageDto } from '@shared/data/api/schemas/messages' + * ``` + */ + +import type { AssertValidSchemas } from '../apiTypes' +import type { MessageSchemas } from './messages' +import type { TestSchemas } from './test' +import type { TopicSchemas } from './topics' + +/** + * Merged API Schemas - single source of truth for all API endpoints + * + * All domain schemas are composed here using intersection types. + * AssertValidSchemas provides compile-time validation: + * - Invalid HTTP methods become `never` type + * - Missing `response` field causes type errors + * + * When adding a new domain: + * 1. Create the schema file (e.g., topic.ts) + * 2. Import and add to intersection below + */ +export type ApiSchemas = AssertValidSchemas diff --git a/packages/shared/data/api/schemas/messages.ts b/packages/shared/data/api/schemas/messages.ts new file mode 100644 index 0000000000..4cc0a54998 --- /dev/null +++ b/packages/shared/data/api/schemas/messages.ts @@ -0,0 +1,213 @@ +/** + * Message API Schema definitions + * + * Contains all message-related endpoints for tree operations and message management. 
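`AssertValidSchemas` itself lives in apiTypes.ts and is not shown in this diff. As a sketch of the idea only (the body below is an assumption, not the real implementation), a guard with the two documented behaviors could look like this:

```typescript
type ValidMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH'

// Sketch: degrade invalid entries to `never` so mistakes surface as
// type errors right at the composition site in schemas/index.ts.
type AssertValidSchemasSketch<T> = {
  [Path in keyof T]: {
    [M in keyof T[Path]]: M extends ValidMethod
      ? T[Path][M] extends { response: unknown } // every endpoint must declare `response`
        ? T[Path][M]
        : never
      : never // invalid HTTP method
  }
}

// Intended usage mirrors the composition in schemas/index.ts:
// type ApiSchemas = AssertValidSchemasSketch<MessageSchemas & TestSchemas & TopicSchemas>
```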
+ * Includes endpoints for tree visualization and conversation view. + */ + +import type { CursorPaginationParams } from '@shared/data/api/apiTypes' +import type { + BranchMessagesResponse, + Message, + MessageData, + MessageRole, + MessageStats, + MessageStatus, + TreeResponse +} from '@shared/data/types/message' +import type { AssistantMeta, ModelMeta } from '@shared/data/types/meta' + +// ============================================================================ +// DTOs +// ============================================================================ + +/** + * DTO for creating a new message + */ +export interface CreateMessageDto { + /** + * Parent message ID for positioning this message in the conversation tree. + * + * Behavior: + * - `undefined` (omitted): Auto-resolve parent based on topic state: + * - If topic has no messages: create as root (parentId = null) + * - If topic has messages and activeNodeId is set: attach to activeNodeId + * - If topic has messages but no activeNodeId: throw INVALID_OPERATION error + * - `null` (explicit): Create as root message. Throws INVALID_OPERATION if + * topic already has a root message (only one root allowed per topic). + * - `string` (message ID): Attach to specified parent. Throws NOT_FOUND if + * parent doesn't exist, or INVALID_OPERATION if parent belongs to different topic. + */ + parentId?: string | null + /** Message role */ + role: MessageRole + /** Message content */ + data: MessageData + /** Message status */ + status?: MessageStatus + /** Siblings group ID (0 = normal, >0 = multi-model group) */ + siblingsGroupId?: number + /** Assistant ID */ + assistantId?: string + /** Preserved assistant info */ + assistantMeta?: AssistantMeta + /** Model identifier */ + modelId?: string + /** Preserved model info */ + modelMeta?: ModelMeta + /** Trace ID */ + traceId?: string + /** Statistics */ + stats?: MessageStats + /** Set this message as the active node in the topic (default: true) */ + setAsActive?: boolean +} + +/** + * DTO for updating an existing message + */ +export interface UpdateMessageDto { + /** Updated message content */ + data?: MessageData + /** Move message to new parent */ + parentId?: string | null + /** Change siblings group */ + siblingsGroupId?: number + /** Update status */ + status?: MessageStatus + /** Update trace ID */ + traceId?: string | null + /** Update statistics */ + stats?: MessageStats | null +} + +/** + * Strategy for updating activeNodeId when the active message is deleted + */ +export type ActiveNodeStrategy = 'parent' | 'clear' + +/** + * Response for delete operation + */ +export interface DeleteMessageResponse { + /** IDs of deleted messages */ + deletedIds: string[] + /** IDs of reparented children (only when cascade=false) */ + reparentedIds?: string[] + /** New activeNodeId for the topic (only if activeNodeId was affected by deletion) */ + newActiveNodeId?: string | null +} + +// ============================================================================ +// Query Parameters +// ============================================================================ + +/** + * Query parameters for GET /topics/:id/tree + */ +export interface TreeQueryParams { + /** Root node ID (defaults to tree root) */ + rootId?: string + /** End node ID (defaults to topic.activeNodeId) */ + nodeId?: string + /** Depth to expand beyond active path (-1 = all, 0 = path only, 1+ = layers) */ + depth?: number +} + +/** + * Query parameters for GET /topics/:id/messages + * + * Uses "before cursor" semantics for loading 
historical messages: + * - First request (no cursor): Returns the most recent `limit` messages + * - Subsequent requests: Pass `nextCursor` from previous response as `cursor` + * to load older messages towards root + * - The cursor message itself is NOT included in the response + */ +export interface BranchMessagesQueryParams extends CursorPaginationParams { + /** End node ID (defaults to topic.activeNodeId) */ + nodeId?: string + /** Whether to include siblingsGroup in response */ + includeSiblings?: boolean +} + +// ============================================================================ +// API Schema Definitions +// ============================================================================ + +/** + * Message API Schema definitions + * + * Organized by domain responsibility: + * - /topics/:id/tree - Tree visualization + * - /topics/:id/messages - Branch messages for conversation + * - /messages/:id - Individual message operations + */ +export interface MessageSchemas { + /** + * Tree query endpoint for visualization + * @example GET /topics/abc123/tree?depth=1 + */ + '/topics/:topicId/tree': { + /** Get tree structure for visualization */ + GET: { + params: { topicId: string } + query?: TreeQueryParams + response: TreeResponse + } + } + + /** + * Branch messages endpoint for conversation view + * @example GET /topics/abc123/messages?limit=20 + * @example POST /topics/abc123/messages { "parentId": "msg1", "role": "user", "data": {...} } + */ + '/topics/:topicId/messages': { + /** Get messages along active branch with pagination */ + GET: { + params: { topicId: string } + query?: BranchMessagesQueryParams + response: BranchMessagesResponse + } + /** Create a new message in the topic */ + POST: { + params: { topicId: string } + body: CreateMessageDto + response: Message + } + } + + /** + * Individual message endpoint + * @example GET /messages/msg123 + * @example PATCH /messages/msg123 { "data": {...} } + * @example DELETE /messages/msg123?cascade=true + */ + '/messages/:id': { + /** Get a single message by ID */ + GET: { + params: { id: string } + response: Message + } + /** Update a message (content, move to new parent, etc.) */ + PATCH: { + params: { id: string } + body: UpdateMessageDto + response: Message + } + /** + * Delete a message + * - cascade=true: deletes message and all descendants + * - cascade=false: reparents children to grandparent + * - activeNodeStrategy='parent' (default): sets activeNodeId to parent if affected + * - activeNodeStrategy='clear': sets activeNodeId to null if affected + */ + DELETE: { + params: { id: string } + query?: { + cascade?: boolean + activeNodeStrategy?: ActiveNodeStrategy + } + response: DeleteMessageResponse + } + } +} diff --git a/packages/shared/data/api/schemas/test.ts b/packages/shared/data/api/schemas/test.ts new file mode 100644 index 0000000000..b1627c1ed8 --- /dev/null +++ b/packages/shared/data/api/schemas/test.ts @@ -0,0 +1,318 @@ +/** + * Test API Schema definitions + * + * Contains all test-related endpoints for development and testing purposes. + * These endpoints demonstrate the API patterns and provide testing utilities. 
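Stepping back to the message endpoints for a moment: the three `parentId` modes in `CreateMessageDto` are easiest to compare side by side. A hedged sketch, assuming a typed client whose `post` follows the `params`/`body` shape of the schema (the `client` object itself is an assumption):

```typescript
import type { CreateMessageDto } from '@shared/data/api/schemas/messages'

declare const client: {
  post(
    path: '/topics/:topicId/messages',
    req: { params: { topicId: string }; body: CreateMessageDto }
  ): Promise<unknown>
}

// Omitted parentId: auto-resolve — root if the topic is empty, otherwise
// attach to topic.activeNodeId (INVALID_OPERATION if no active node is set).
await client.post('/topics/:topicId/messages', {
  params: { topicId: 'abc123' },
  body: { role: 'user', data: { blocks: [] } }
})

// parentId: null — force a root; INVALID_OPERATION if a root already exists.
await client.post('/topics/:topicId/messages', {
  params: { topicId: 'abc123' },
  body: { parentId: null, role: 'system', data: { blocks: [] } }
})

// parentId: 'msg_42' — attach under an explicit parent; NOT_FOUND or
// INVALID_OPERATION if the parent is missing or belongs to another topic.
await client.post('/topics/:topicId/messages', {
  params: { topicId: 'abc123' },
  body: { parentId: 'msg_42', role: 'assistant', data: { blocks: [] } }
})
```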
+ */ + +import type { OffsetPaginationParams, OffsetPaginationResponse, SearchParams, SortParams } from '../apiTypes' + +// ============================================================================ +// Domain Models & DTOs +// ============================================================================ + +/** + * Generic test item entity - flexible structure for testing various scenarios + */ +export interface TestItem { + /** Unique identifier */ + id: string + /** Item title */ + title: string + /** Optional description */ + description?: string + /** Type category */ + type: string + /** Current status */ + status: string + /** Priority level */ + priority: string + /** Associated tags */ + tags: string[] + /** Creation timestamp */ + createdAt: string + /** Last update timestamp */ + updatedAt: string + /** Additional metadata */ + metadata: Record +} + +/** + * DTO for creating a new test item + */ +export interface CreateTestItemDto { + /** Item title */ + title: string + /** Optional description */ + description?: string + /** Type category */ + type?: string + /** Current status */ + status?: string + /** Priority level */ + priority?: string + /** Associated tags */ + tags?: string[] + /** Additional metadata */ + metadata?: Record +} + +/** + * DTO for updating an existing test item + */ +export interface UpdateTestItemDto { + /** Updated title */ + title?: string + /** Updated description */ + description?: string + /** Updated type */ + type?: string + /** Updated status */ + status?: string + /** Updated priority */ + priority?: string + /** Updated tags */ + tags?: string[] + /** Updated metadata */ + metadata?: Record +} + +// ============================================================================ +// API Schema Definitions +// ============================================================================ + +/** + * Test API Schema definitions + * + * Validation is performed at composition level via AssertValidSchemas + * in schemas/index.ts, which ensures: + * - All methods are valid HTTP methods (GET, POST, PUT, DELETE, PATCH) + * - All endpoints have a `response` field + */ +export interface TestSchemas { + /** + * Test items collection endpoint + * @example GET /test/items?page=1&limit=10&search=hello + * @example POST /test/items { "title": "New Test Item" } + */ + '/test/items': { + /** List all test items with optional filtering and pagination */ + GET: { + query?: OffsetPaginationParams & + SortParams & + SearchParams & { + /** Filter by item type */ + type?: string + /** Filter by status */ + status?: string + } + response: OffsetPaginationResponse + } + /** Create a new test item */ + POST: { + body: CreateTestItemDto + response: TestItem + } + } + + /** + * Individual test item endpoint + * @example GET /test/items/123 + * @example PUT /test/items/123 { "title": "Updated Title" } + * @example DELETE /test/items/123 + */ + '/test/items/:id': { + /** Get a specific test item by ID */ + GET: { + params: { id: string } + response: TestItem + } + /** Update a specific test item */ + PUT: { + params: { id: string } + body: UpdateTestItemDto + response: TestItem + } + /** Delete a specific test item */ + DELETE: { + params: { id: string } + response: void + } + } + + /** + * Test search endpoint + * @example GET /test/search?query=hello&page=1&limit=20 + */ + '/test/search': { + /** Search test items */ + GET: { + query: OffsetPaginationParams & { + /** Search query string */ + query: string + /** Additional filters */ + type?: string + status?: string + } + 
response: OffsetPaginationResponse + } + } + + /** + * Test statistics endpoint + * @example GET /test/stats + */ + '/test/stats': { + /** Get comprehensive test statistics */ + GET: { + response: { + /** Total number of items */ + total: number + /** Item count grouped by type */ + byType: Record + /** Item count grouped by status */ + byStatus: Record + /** Item count grouped by priority */ + byPriority: Record + /** Recent activity timeline */ + recentActivity: Array<{ + /** Date of activity */ + date: string + /** Number of items on that date */ + count: number + }> + } + } + } + + /** + * Test bulk operations endpoint + * @example POST /test/bulk { "operation": "create", "data": [...] } + */ + '/test/bulk': { + /** Perform bulk operations on test items */ + POST: { + body: { + /** Operation type */ + operation: 'create' | 'update' | 'delete' + /** Array of data items to process */ + data: Array + } + response: { + /** Number of successfully processed items */ + successful: number + /** Number of items that failed processing */ + failed: number + /** Array of error messages */ + errors: string[] + } + } + } + + /** + * Test error simulation endpoint + * @example POST /test/error { "errorType": "timeout" } + */ + '/test/error': { + /** Simulate various error scenarios for testing */ + POST: { + body: { + /** Type of error to simulate */ + errorType: + | 'timeout' + | 'network' + | 'server' + | 'notfound' + | 'validation' + | 'unauthorized' + | 'ratelimit' + | 'generic' + } + response: never + } + } + + /** + * Test slow response endpoint + * @example POST /test/slow { "delay": 2000 } + */ + '/test/slow': { + /** Test slow response for performance testing */ + POST: { + body: { + /** Delay in milliseconds */ + delay: number + } + response: { + message: string + delay: number + timestamp: string + } + } + } + + /** + * Test data reset endpoint + * @example POST /test/reset + */ + '/test/reset': { + /** Reset all test data to initial state */ + POST: { + response: { + message: string + timestamp: string + } + } + } + + /** + * Test config endpoint + * @example GET /test/config + * @example PUT /test/config { "setting": "value" } + */ + '/test/config': { + /** Get test configuration */ + GET: { + response: Record + } + /** Update test configuration */ + PUT: { + body: Record + response: Record + } + } + + /** + * Test status endpoint + * @example GET /test/status + */ + '/test/status': { + /** Get system test status */ + GET: { + response: { + status: string + timestamp: string + version: string + uptime: number + environment: string + } + } + } + + /** + * Test performance endpoint + * @example GET /test/performance + */ + '/test/performance': { + /** Get performance metrics */ + GET: { + response: { + requestsPerSecond: number + averageLatency: number + memoryUsage: number + cpuUsage: number + uptime: number + } + } + } +} diff --git a/packages/shared/data/api/schemas/topics.ts b/packages/shared/data/api/schemas/topics.ts new file mode 100644 index 0000000000..3a4d82b5ec --- /dev/null +++ b/packages/shared/data/api/schemas/topics.ts @@ -0,0 +1,133 @@ +/** + * Topic API Schema definitions + * + * Contains all topic-related endpoints for CRUD operations and branch switching. 
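Before moving on to topics: the paginated test endpoints above all follow the same query pattern. A sketch, where `page`, `limit`, and `search` come from the @example lines, while the `client` shape and the response field names are assumptions:

```typescript
import type { TestItem } from '@shared/data/api/schemas/test'

declare const client: { get(path: string, req?: { query?: object }): Promise<unknown> }

// First page of matching items, mirroring GET /test/items?page=1&limit=10&search=hello
const res = await client.get('/test/items', {
  query: { page: 1, limit: 10, search: 'hello', status: 'open' }
})

// A conforming result is an OffsetPaginationResponse<TestItem>; its exact
// field names (e.g., items/total) live in apiTypes.ts and are assumed here.
const { items, total } = res as { items: TestItem[]; total: number }
```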
+ */ + +import type { AssistantMeta } from '@shared/data/types/meta' +import type { Topic } from '@shared/data/types/topic' + +// ============================================================================ +// DTOs +// ============================================================================ + +/** + * DTO for creating a new topic + */ +export interface CreateTopicDto { + /** Topic name */ + name?: string + /** Associated assistant ID */ + assistantId?: string + /** Preserved assistant info */ + assistantMeta?: AssistantMeta + /** Topic-specific prompt */ + prompt?: string + /** Group ID for organization */ + groupId?: string + /** + * Source node ID for fork operation. + * When provided, copies the path from root to this node into the new topic. + */ + sourceNodeId?: string +} + +/** + * DTO for updating an existing topic + */ +export interface UpdateTopicDto { + /** Updated topic name */ + name?: string + /** Mark name as manually edited */ + isNameManuallyEdited?: boolean + /** Updated assistant ID */ + assistantId?: string + /** Updated assistant meta */ + assistantMeta?: AssistantMeta + /** Updated prompt */ + prompt?: string + /** Updated group ID */ + groupId?: string + /** Updated sort order */ + sortOrder?: number + /** Updated pin state */ + isPinned?: boolean + /** Updated pin order */ + pinnedOrder?: number +} + +/** + * DTO for setting active node + */ +export interface SetActiveNodeDto { + /** Node ID to set as active */ + nodeId: string +} + +/** + * Response for active node update + */ +export interface ActiveNodeResponse { + /** The new active node ID */ + activeNodeId: string +} + +// ============================================================================ +// API Schema Definitions +// ============================================================================ + +/** + * Topic API Schema definitions + */ +export interface TopicSchemas { + /** + * Topics collection endpoint + * @example POST /topics { "name": "New Topic", "assistantId": "asst_123" } + */ + '/topics': { + /** Create a new topic (optionally fork from existing node) */ + POST: { + body: CreateTopicDto + response: Topic + } + } + + /** + * Individual topic endpoint + * @example GET /topics/abc123 + * @example PATCH /topics/abc123 { "name": "Updated Name" } + * @example DELETE /topics/abc123 + */ + '/topics/:id': { + /** Get a topic by ID */ + GET: { + params: { id: string } + response: Topic + } + /** Update a topic */ + PATCH: { + params: { id: string } + body: UpdateTopicDto + response: Topic + } + /** Delete a topic and all its messages */ + DELETE: { + params: { id: string } + response: void + } + } + + /** + * Active node sub-resource endpoint + * High-frequency operation for branch switching + * @example PUT /topics/abc123/active-node { "nodeId": "msg456" } + */ + '/topics/:id/active-node': { + /** Set the active node for a topic */ + PUT: { + params: { id: string } + body: SetActiveNodeDto + response: ActiveNodeResponse + } + } +} diff --git a/packages/shared/data/cache/cacheSchemas.ts b/packages/shared/data/cache/cacheSchemas.ts index 0c467d7682..dc182a9c7e 100644 --- a/packages/shared/data/cache/cacheSchemas.ts +++ b/packages/shared/data/cache/cacheSchemas.ts @@ -5,23 +5,104 @@ import type * as CacheValueTypes from './cacheValueTypes' * * ## Key Naming Convention * - * All cache keys MUST follow the format: `namespace.sub.key_name` + * All cache keys (fixed and template) MUST follow the format: `namespace.sub.key_name` * * Rules: * - At least 2 segments separated by dots (.) 
 * - Each segment uses lowercase letters, numbers, and underscores only
 * - Pattern: /^[a-z][a-z0-9_]*(\.[a-z][a-z0-9_]*)+$/
+ * - Template placeholders `${xxx}` are treated as literal string segments
 *
 * Examples:
 * - 'app.user.avatar' (valid)
 * - 'chat.multi_select_mode' (valid)
- * - 'minapp.opened_keep_alive' (valid)
+ * - 'scroll.position.${topicId}' (valid template key)
 * - 'userAvatar' (invalid - missing dot separator)
 * - 'App.user' (invalid - uppercase not allowed)
+ * - 'scroll.position:${id}' (invalid - colon not allowed)
+ *
+ * ## Template Key Support
+ *
+ * Template keys allow type-safe dynamic keys using template literal syntax.
+ * Define in schema with `${variable}` placeholder, use with actual values.
+ * Template keys follow the same dot-separated pattern as fixed keys.
+ *
+ * Examples:
+ * - Schema: `'scroll.position.${topicId}': number`
+ * - Usage: `useCache('scroll.position.topic123')` -> infers `number` type
+ *
+ * Multiple placeholders are supported:
+ * - Schema: `'entity.cache.${type}_${id}': CacheData`
+ * - Usage: `useCache('entity.cache.user_456')` -> infers `CacheData` type
 *
 * This convention is enforced by ESLint rule: data-schema-key/valid-key
 */
+// ============================================================================
+// Template Key Type Utilities
+// ============================================================================
+
+/**
+ * Detects whether a key string contains template placeholder syntax.
+ *
+ * Template keys use `${variable}` syntax to define dynamic segments.
+ * This type returns `true` if the key contains at least one `${...}` placeholder.
+ *
+ * @template K - The key string to check
+ * @returns `true` if K contains `${...}`, `false` otherwise
+ *
+ * @example
+ * ```typescript
+ * type Test1 = IsTemplateKey<'scroll.position.${id}'> // true
+ * type Test2 = IsTemplateKey<'entity.cache.${a}_${b}'> // true
+ * type Test3 = IsTemplateKey<'app.user.avatar'> // false
+ * ```
+ */
+export type IsTemplateKey<K extends string> = K extends `${string}\${${string}}${string}` ? true : false
+
+/**
+ * Expands a template key pattern into a matching literal type.
+ *
+ * Replaces each `${variable}` placeholder with `string`, allowing
+ * TypeScript to match concrete keys against the template pattern.
+ * Recursively processes multiple placeholders.
+ *
+ * @template T - The template key pattern to expand
+ * @returns A template literal type that matches all valid concrete keys
+ *
+ * @example
+ * ```typescript
+ * type Test1 = ExpandTemplateKey<'scroll.position.${id}'>
+ * // Result: `scroll.position.${string}` (matches 'scroll.position.123', etc.)
+ *
+ * type Test2 = ExpandTemplateKey<'entity.cache.${type}_${id}'>
+ * // Result: `entity.cache.${string}_${string}` (matches 'entity.cache.user_123', etc.)
+ *
+ * type Test3 = ExpandTemplateKey<'app.user.avatar'>
+ * // Result: 'app.user.avatar' (unchanged for non-template keys)
+ * ```
+ */
+export type ExpandTemplateKey<T extends string> = T extends `${infer Prefix}\${${string}}${infer Suffix}`
+  ? `${Prefix}${string}${ExpandTemplateKey<Suffix>}`
+  : T
+
+/**
+ * Processes a cache key, expanding template patterns if present.
+ *
+ * For template keys (containing `${...}`), returns the expanded pattern.
+ * For fixed keys, returns the key unchanged.
+ *
+ * @template K - The key to process
+ * @returns The processed key type (expanded if template, unchanged if fixed)
+ *
+ * @example
+ * ```typescript
+ * type Test1 = ProcessKey<'scroll.position.${id}'> // `scroll.position.${string}`
+ * type Test2 = ProcessKey<'app.user.avatar'> // 'app.user.avatar'
+ * ```
+ */
+export type ProcessKey<K extends string> = IsTemplateKey<K> extends true ? ExpandTemplateKey<K> : K
 
 /**
  * Use cache schema for renderer hook
  */
@@ -57,6 +138,25 @@ export type UseCacheSchema = {
   'agent.active_id': string | null
   'agent.session.active_id_map': Record
   'agent.session.waiting_id_map': Record
+
+  // Template key examples (for testing and demonstration)
+  'scroll.position.${topicId}': number
+  'entity.cache.${type}_${id}': { loaded: boolean; data: unknown }
+
+  // ============================================================================
+  // Message Streaming Cache (Temporary)
+  // ============================================================================
+  // TODO [v2]: Replace `any` with proper types after newMessage.ts types are
+  // migrated to packages/shared/data/types/message.ts
+  // Current types:
+  // - StreamingSession: defined locally in StreamingService.ts
+  // - Message: src/renderer/src/types/newMessage.ts (renderer format, not shared/Message)
+  // - MessageBlock: src/renderer/src/types/newMessage.ts
+  'message.streaming.session.${messageId}': any // StreamingSession
+  'message.streaming.topic_sessions.${topicId}': string[]
+  'message.streaming.content.${messageId}': any // Message (renderer format)
+  'message.streaming.block.${blockId}': any // MessageBlock
+  'message.streaming.siblings_counter.${topicId}': number
 }
 
 export const DefaultUseCache: UseCacheSchema = {
@@ -95,17 +195,28 @@
   // Agent management
   'agent.active_id': null,
   'agent.session.active_id_map': {},
-  'agent.session.waiting_id_map': {}
+  'agent.session.waiting_id_map': {},
+
+  // Template key examples (for testing and demonstration)
+  'scroll.position.${topicId}': 0,
+  'entity.cache.${type}_${id}': { loaded: false, data: null },
+
+  // Message Streaming Cache
+  'message.streaming.session.${messageId}': null,
+  'message.streaming.topic_sessions.${topicId}': [],
+  'message.streaming.content.${messageId}': null,
+  'message.streaming.block.${blockId}': null,
+  'message.streaming.siblings_counter.${topicId}': 0
 }
 
 /**
  * Use shared cache schema for renderer hook
  */
-export type UseSharedCacheSchema = {
+export type SharedCacheSchema = {
   'example_scope.example_key': string
 }
 
-export const DefaultUseSharedCache: UseSharedCacheSchema = {
+export const DefaultSharedCache: SharedCacheSchema = {
   'example_scope.example_key': 'example default value'
 }
 
@@ -121,9 +232,107 @@ export const DefaultRendererPersistCache: RendererPersistCacheSchema = {
   'example_scope.example_key': 'example default value'
 }
 
+// ============================================================================
+// Cache Key Types
+// ============================================================================
+
 /**
- * Type-safe cache key
+ * Key type for renderer persist cache (fixed keys only)
  */
 export type RendererPersistCacheKey = keyof RendererPersistCacheSchema
-export type UseCacheKey = keyof UseCacheSchema
-export type UseSharedCacheKey = keyof UseSharedCacheSchema
+
+/**
+ * Key type for shared cache (fixed keys only)
+ */
+export type SharedCacheKey = keyof SharedCacheSchema
+
+/**
+ * Key type for memory cache (supports both fixed and template keys).
+ *
+ * This type expands all schema keys using ProcessKey, which:
+ * - Keeps fixed keys unchanged (e.g., 'app.user.avatar')
+ * - Expands template keys to match patterns (e.g., 'scroll.position.${id}' -> `scroll.position.${string}`)
+ *
+ * The resulting union type allows TypeScript to accept any concrete key
+ * that matches either a fixed key or an expanded template pattern.
+ *
+ * @example
+ * ```typescript
+ * // Given schema:
+ * //   'app.user.avatar': string
+ * //   'scroll.position.${topicId}': number
+ *
+ * // UseCacheKey becomes: 'app.user.avatar' | `scroll.position.${string}`
+ *
+ * // Valid keys:
+ * const k1: UseCacheKey = 'app.user.avatar' // fixed key
+ * const k2: UseCacheKey = 'scroll.position.123' // matches template
+ * const k3: UseCacheKey = 'scroll.position.abc' // matches template
+ *
+ * // Invalid keys:
+ * const k4: UseCacheKey = 'unknown.key' // error: not in schema
+ * ```
+ */
+export type UseCacheKey = {
+  [K in keyof UseCacheSchema]: ProcessKey<K>
+}[keyof UseCacheSchema]
+
+// ============================================================================
+// UseCache Specialized Types
+// ============================================================================
+
+/**
+ * Infers the value type for a given cache key from UseCacheSchema.
+ *
+ * Works with both fixed keys and template keys:
+ * - For fixed keys, returns the exact value type from schema
+ * - For template keys, matches the key against expanded patterns and returns the value type
+ *
+ * If the key doesn't match any schema entry, returns `never`.
+ *
+ * @template K - The cache key to infer value type for
+ * @returns The value type associated with the key, or `never` if not found
+ *
+ * @example
+ * ```typescript
+ * // Given schema:
+ * //   'app.user.avatar': string
+ * //   'scroll.position.${topicId}': number
+ *
+ * type T1 = InferUseCacheValue<'app.user.avatar'> // string
+ * type T2 = InferUseCacheValue<'scroll.position.123'> // number
+ * type T3 = InferUseCacheValue<'scroll.position.abc'> // number
+ * type T4 = InferUseCacheValue<'unknown.key'> // never
+ * ```
+ */
+export type InferUseCacheValue<K extends string> = {
+  [S in keyof UseCacheSchema]: K extends ProcessKey<S> ? UseCacheSchema[S] : never
+}[keyof UseCacheSchema]
+
+/**
+ * Type guard for casual cache keys that blocks schema-defined keys.
+ *
+ * Used to ensure casual API methods (getCasual, setCasual, etc.) cannot
+ * be called with keys that are defined in the schema (including template patterns).
+ * This enforces proper API usage: use type-safe methods for schema keys,
+ * use casual methods only for truly dynamic/unknown keys.
+ *
+ * @template K - The key to check
+ * @returns `K` if the key doesn't match any schema pattern, `never` if it does
+ *
+ * @example
+ * ```typescript
+ * // Given schema:
+ * //   'app.user.avatar': string
+ * //   'scroll.position.${topicId}': number
+ *
+ * // These cause compile-time errors (key matches schema):
+ * getCasual('app.user.avatar') // Error: never
+ * getCasual('scroll.position.123') // Error: never (matches template)
+ *
+ * // These are allowed (key doesn't match any schema pattern):
+ * getCasual('my.custom.key') // OK
+ * getCasual('other.dynamic.key') // OK
+ * ```
+ */
+export type UseCacheCasualKey<K extends string> = K extends UseCacheKey ? never : K
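A usage sketch for the template keys defined above. The `useCache` hook signature and the import path are assumptions; the key matching and value inference are exactly what `UseCacheKey` and `InferUseCacheValue` provide:

```typescript
import type { InferUseCacheValue, UseCacheKey } from '@shared/data/cache/cacheSchemas'

// Assumed hook shape, built on the schema types.
declare function useCache<K extends UseCacheKey>(
  key: K
): [InferUseCacheValue<K>, (value: InferUseCacheValue<K>) => void]

const [scrollPos, setScrollPos] = useCache('scroll.position.topic123') // number
const [entry] = useCache('entity.cache.user_456') // { loaded: boolean; data: unknown }

setScrollPos(480) // OK
// setScrollPos('480') // type error: number expected
// useCache('unknown.key') // type error: matches no fixed key or template pattern
```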
diff --git a/packages/shared/data/cache/cacheTypes.ts b/packages/shared/data/cache/cacheTypes.ts
index 1ae71919bc..e39dd2877c 100644
--- a/packages/shared/data/cache/cacheTypes.ts
+++ b/packages/shared/data/cache/cacheTypes.ts
@@ -22,7 +22,7 @@ export interface CacheSyncMessage {
   type: 'shared' | 'persist'
   key: string
   value: any
-  ttl?: number
+  expireAt?: number // Absolute Unix timestamp for precise cross-window sync
 }
 
 /**
@@ -33,7 +33,7 @@ export interface CacheSyncBatchMessage {
   entries: Array<{
     key: string
     value: any
-    ttl?: number
+    expireAt?: number // Absolute Unix timestamp for precise cross-window sync
   }>
 }
diff --git a/packages/shared/data/preference/preferenceTypes.ts b/packages/shared/data/preference/preferenceTypes.ts
index 1937266c47..0b6a7cc27f 100644
--- a/packages/shared/data/preference/preferenceTypes.ts
+++ b/packages/shared/data/preference/preferenceTypes.ts
@@ -55,14 +55,15 @@ export enum ThemeMode {
 export type LanguageVarious =
   | 'zh-CN'
   | 'zh-TW'
+  | 'de-DE'
   | 'el-GR'
   | 'en-US'
   | 'es-ES'
   | 'fr-FR'
   | 'ja-JP'
   | 'pt-PT'
+  | 'ro-RO'
   | 'ru-RU'
-  | 'de-DE'
 
 export type WindowStyle = 'transparent' | 'opaque'
diff --git a/packages/shared/data/types/message.ts b/packages/shared/data/types/message.ts
new file mode 100644
index 0000000000..3542c30a57
--- /dev/null
+++ b/packages/shared/data/types/message.ts
@@ -0,0 +1,481 @@
+import type { CursorPaginationResponse } from '@shared/data/api/apiTypes'
+/**
+ * Message Statistics - combines token usage and performance metrics
+ * Replaces the separate `usage` and `metrics` fields
+ */
+export interface MessageStats {
+  // Token consumption (from API response)
+  promptTokens?: number
+  completionTokens?: number
+  totalTokens?: number
+  thoughtsTokens?: number
+
+  // Cost (calculated at message completion time)
+  cost?: number
+
+  // Performance metrics (measured locally)
+  timeFirstTokenMs?: number
+  timeCompletionMs?: number
+  timeThinkingMs?: number
+}
+
+// ============================================================================
+// Message Data
+// ============================================================================
+
+/**
+ * Message data field structure
+ * This is the type for the `data` column in the message table
+ */
+export interface MessageData {
+  blocks: MessageDataBlock[]
+}
+
+//FIXME [v2] Note: the types below are placeholders only; the interfaces are not stable and may change at any time
+
+// ============================================================================
+// Content Reference Types
+// ============================================================================
+
+/**
+ * Reference category for content references
+ */
+export enum ReferenceCategory {
+  CITATION = 'citation',
+  MENTION = 'mention'
+}
+
+/**
+ * Citation source type
+ */
+export enum CitationType {
+  WEB = 'web',
+  KNOWLEDGE = 'knowledge',
+  MEMORY = 'memory'
+}
+
+/**
+ * Base reference structure for inline content references
+ */
+export interface BaseReference {
+  category: ReferenceCategory
+  /** Text marker in content, e.g., "[1]", "@user" */
+  marker?: string
+  /** Position range in content */
+  range?: { start: number; end: number }
+}
+
+/**
+ * Base citation reference
+ */
+interface BaseCitationReference extends BaseReference {
+  category: ReferenceCategory.CITATION
+  citationType: CitationType
+}
+
+/**
+ * Web search citation reference
+ * Data structure compatible with WebSearchResponse from renderer
+ */
+export interface WebCitationReference extends BaseCitationReference {
+  citationType: CitationType.WEB
+  content: {
+    results?: unknown // types need to be migrated from renderer (newMessage.ts)
+    source: unknown // types need to be migrated from renderer (newMessage.ts)
+  }
+}
+
+/**
+ * Knowledge base citation reference
+ * Data structure compatible with KnowledgeReference[] from renderer
+ */
+export interface KnowledgeCitationReference extends BaseCitationReference {
+  citationType: CitationType.KNOWLEDGE
+
+  // types need to be migrated from renderer (newMessage.ts)
+  content: {
+    id: number
+    content: string
+    sourceUrl: string
+    type: string
+    file?: unknown
+    metadata?: Record<string, unknown>
+  }[]
+}
+
+/**
+ * Memory citation reference
+ * Data structure compatible with MemoryItem[] from renderer
+ */
+export interface MemoryCitationReference extends BaseCitationReference {
+  citationType: CitationType.MEMORY
+  // types need to be migrated from renderer (newMessage.ts)
+  content: {
+    id: string
+    memory: string
+    hash?: string
+    createdAt?: string
+    updatedAt?: string
+    score?: number
+    metadata?: Record<string, unknown>
+  }[]
+}
+
+/**
+ * Union type of all citation references
+ */
+export type CitationReference = WebCitationReference | KnowledgeCitationReference | MemoryCitationReference
+
+/**
+ * Mention reference for @mentions in content
+ * References a Model entity
+ */
+export interface MentionReference extends BaseReference {
+  category: ReferenceCategory.MENTION
+  /** Model ID being mentioned */
+  modelId: string //FIXME Interface not finalized: the model data structure is still undecided, placeholder for now
+  /** Display name for the mention */
+  displayName?: string
+}
+
+/**
+ * Union type of all content references
+ */
+export type ContentReference = CitationReference | MentionReference
+
+/**
+ * Type guard: check if reference is a citation
+ */
+export function isCitation(ref: ContentReference): ref is CitationReference {
+  return ref.category === ReferenceCategory.CITATION
+}
+
+/**
+ * Type guard: check if reference is a mention
+ */
+export function isMention(ref: ContentReference): ref is MentionReference {
+  return ref.category === ReferenceCategory.MENTION
+}
+
+/**
+ * Type guard: check if reference is a web citation
+ */
+export function isWebCitation(ref: ContentReference): ref is WebCitationReference {
+  return isCitation(ref) && ref.citationType === CitationType.WEB
+}
+
+/**
+ * Type guard: check if reference is a knowledge citation
+ */
+export function isKnowledgeCitation(ref: ContentReference): ref is KnowledgeCitationReference {
+  return isCitation(ref) && ref.citationType === CitationType.KNOWLEDGE
+}
+
+/**
+ * Type guard: check if reference is a memory citation
+ */
+export function isMemoryCitation(ref: ContentReference): ref is MemoryCitationReference {
+  return isCitation(ref) && ref.citationType === CitationType.MEMORY
+}
+
+// ============================================================================
+// Message Block
+// ============================================================================
+
+export enum BlockType {
+  UNKNOWN = 'unknown',
+  MAIN_TEXT = 'main_text',
+  THINKING = 'thinking',
+  TRANSLATION = 'translation',
+  IMAGE = 'image',
+  CODE = 'code',
+  TOOL = 'tool',
+  FILE = 'file',
+  ERROR = 'error',
+  CITATION = 'citation',
+  VIDEO = 'video',
+  COMPACT = 'compact'
+}
+
+/**
+ * Base message block data structure
+ */
+export interface BaseBlock {
+  type: BlockType
+  createdAt: number // timestamp
+  updatedAt?: number
+  // modelId?: string // v1's dead code, will be removed in v2
+  metadata?: Record<string, unknown>
+  error?: SerializedErrorData
+}
+
+/**
+ * Serialized error for storage
+ */
+export interface SerializedErrorData {
+  name?: string
+  message: string
+  code?: string
+  stack?: string
+  cause?: unknown
+}
+
+// Block type specific interfaces
+
+export interface UnknownBlock extends BaseBlock {
+  type: BlockType.UNKNOWN
+  content?: string
+}
+
+/**
+ * Main text block containing the primary message content.
+ *
+ * ## Migration Notes (v2.0)
+ *
+ * ### Added
+ * - `references`: Unified inline references replacing the old citation system.
+ *   Supports multiple reference types (citations, mentions) with position tracking.
+ *
+ * ### Removed
+ * - `citationReferences`: Use `references` with `ReferenceCategory.CITATION` instead.
+ * - `CitationBlock`: Citation data is now embedded in `MainTextBlock.references`.
+ *   The standalone CitationBlock type is no longer used.
+ */
+export interface MainTextBlock extends BaseBlock {
+  type: BlockType.MAIN_TEXT
+  content: string
+  //knowledgeBaseIds?: string[] // v1's dead code, will be removed in v2
+
+  /**
+   * Inline references embedded in the content (citations, mentions, etc.)
+   * Replaces the old CitationBlock + citationReferences pattern.
+   * @since v2.0
+   */
+  references?: ContentReference[]
+
+  /**
+   * @deprecated Use `references` with `ReferenceCategory.CITATION` instead.
+   */
+  // citationReferences?: {
+  //   citationBlockId?: string
+  //   citationBlockSource?: string
+  // }[]
+}
+
+export interface ThinkingBlock extends BaseBlock {
+  type: BlockType.THINKING
+  content: string
+  thinkingMs: number
+}
+
+export interface TranslationBlock extends BaseBlock {
+  type: BlockType.TRANSLATION
+  content: string
+  sourceBlockId?: string
+  sourceLanguage?: string
+  targetLanguage: string
+}
+
+export interface CodeBlock extends BaseBlock {
+  type: BlockType.CODE
+  content: string
+  language: string
+}
+
+export interface ImageBlock extends BaseBlock {
+  type: BlockType.IMAGE
+  url?: string
+  fileId?: string
+}
+
+export interface ToolBlock extends BaseBlock {
+  type: BlockType.TOOL
+  toolId: string
+  toolName?: string
+  arguments?: Record<string, unknown>
+  content?: string | object
+}
+
+/**
+ * @deprecated Citation data is now embedded in MainTextBlock.references.
+ * Use ContentReference types instead. Will be removed in v3.0.
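Concretely, a citation that would previously have lived in a standalone CitationBlock now rides along on the text block itself. A minimal sketch using only types defined in this file (the content values and offsets are illustrative):

```typescript
const block: MainTextBlock = {
  type: BlockType.MAIN_TEXT,
  createdAt: Date.now(),
  content: 'Const type parameters shipped in TypeScript 5.0 [1].',
  references: [
    {
      category: ReferenceCategory.CITATION,
      citationType: CitationType.WEB,
      marker: '[1]',
      range: { start: 48, end: 51 }, // illustrative offsets for "[1]"
      content: { source: 'web-search' } // results/source stay `unknown` until renderer types migrate
    }
  ]
}
```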
+ */ +export interface CitationBlock extends BaseBlock { + type: BlockType.CITATION + responseData?: unknown + knowledgeData?: unknown + memoriesData?: unknown +} + +export interface FileBlock extends BaseBlock { + type: BlockType.FILE + fileId: string +} + +export interface VideoBlock extends BaseBlock { + type: BlockType.VIDEO + url?: string + filePath?: string +} + +export interface ErrorBlock extends BaseBlock { + type: BlockType.ERROR +} + +export interface CompactBlock extends BaseBlock { + type: BlockType.COMPACT + content: string + compactedContent: string +} + +/** + * Union type of all message block data types + */ +export type MessageDataBlock = + | UnknownBlock + | MainTextBlock + | ThinkingBlock + | TranslationBlock + | CodeBlock + | ImageBlock + | ToolBlock + | CitationBlock + | FileBlock + | VideoBlock + | ErrorBlock + | CompactBlock + +// ============================================================================ +// Message Entity Types +// ============================================================================ + +import type { AssistantMeta, ModelMeta } from './meta' + +/** + * Message role - user, assistant, or system + */ +export type MessageRole = 'user' | 'assistant' | 'system' + +/** + * Message status + * - pending: Placeholder created, streaming in progress + * - success: Completed successfully + * - error: Failed with error + * - paused: User stopped generation + */ +export type MessageStatus = 'pending' | 'success' | 'error' | 'paused' + +/** + * Complete message entity as stored in database + */ +export interface Message { + /** Message ID (UUIDv7) */ + id: string + /** Topic ID this message belongs to */ + topicId: string + /** Parent message ID (null for root) */ + parentId: string | null + /** Message role */ + role: MessageRole + /** Message content (blocks, mentions, etc.) 
*/ + data: MessageData + /** Searchable text extracted from data.blocks */ + searchableText?: string | null + /** Message status */ + status: MessageStatus + /** Siblings group ID (0 = normal branch, >0 = multi-model response group) */ + siblingsGroupId: number + /** Assistant ID */ + assistantId?: string | null + /** Preserved assistant info for display */ + assistantMeta?: AssistantMeta | null + /** Model identifier */ + modelId?: string | null + /** Preserved model info (provider, name) */ + modelMeta?: ModelMeta | null + /** Trace ID for tracking */ + traceId?: string | null + /** Statistics: token usage, performance metrics */ + stats?: MessageStats | null + /** Creation timestamp (ISO string) */ + createdAt: string + /** Last update timestamp (ISO string) */ + updatedAt: string +} + +// ============================================================================ +// Tree Structure Types +// ============================================================================ + +/** + * Lightweight tree node for tree visualization (ReactFlow) + * Contains only essential display info, not full message content + */ +export interface TreeNode { + /** Message ID */ + id: string + /** Parent message ID (null for root, omitted in SiblingsGroup.nodes) */ + parentId?: string | null + /** Message role */ + role: MessageRole + /** Content preview (first 50 characters) */ + preview: string + /** Model identifier */ + modelId?: string | null + /** Model display info */ + modelMeta?: ModelMeta | null + /** Message status */ + status: MessageStatus + /** Creation timestamp (ISO string) */ + createdAt: string + /** Whether this node has children (for expand indicator) */ + hasChildren: boolean +} + +/** + * Group of sibling nodes with same parentId and siblingsGroupId + * Used for multi-model responses in tree view + */ +export interface SiblingsGroup { + /** Parent message ID */ + parentId: string + /** Siblings group ID (non-zero) */ + siblingsGroupId: number + /** Nodes in this group (parentId omitted to avoid redundancy) */ + nodes: Omit[] +} + +/** + * Tree query response structure + */ +export interface TreeResponse { + /** Regular nodes (siblingsGroupId = 0) */ + nodes: TreeNode[] + /** Multi-model response groups (siblingsGroupId != 0) */ + siblingsGroups: SiblingsGroup[] + /** Current active node ID */ + activeNodeId: string | null +} + +// ============================================================================ +// Branch Message Types +// ============================================================================ + +/** + * Message with optional siblings group for conversation view + * Used in GET /topics/:id/messages response + */ +export interface BranchMessage { + /** The message itself */ + message: Message + /** Other messages in the same siblings group (only when siblingsGroupId != 0 and includeSiblings=true) */ + siblingsGroup?: Message[] +} + +/** + * Branch messages response structure + */ +export interface BranchMessagesResponse extends CursorPaginationResponse { + /** Current active node ID */ + activeNodeId: string | null +} diff --git a/packages/shared/data/types/meta.ts b/packages/shared/data/types/meta.ts new file mode 100644 index 0000000000..2bba74d700 --- /dev/null +++ b/packages/shared/data/types/meta.ts @@ -0,0 +1,36 @@ +/** + * Soft reference metadata types + * + * These types store snapshots of referenced entities at creation time, + * preserving display information even if the original entity is deleted. 
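A small sketch of the snapshot idea: the meta object is captured when the message or topic is written, so the UI keeps a usable display name even after the source row disappears (all values here are illustrative):

```typescript
const assistantMeta: AssistantMeta = {
  id: 'asst_123', // kept so the reference can be re-resolved if the assistant returns
  name: 'Research Helper',
  emoji: '🔬',
  type: 'custom'
}

const modelMeta: ModelMeta = {
  id: 'claude-3-5-sonnet-latest',
  name: 'Claude 3.5 Sonnet',
  provider: 'anthropic',
  group: 'claude-3'
}
```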
+ */ + +/** + * Preserved assistant info for display when assistant is deleted + * Used in: message.assistantMeta, topic.assistantMeta + */ +export interface AssistantMeta { + /** Original assistant ID, used to attempt reference recovery */ + id: string + /** Assistant display name shown in UI */ + name: string + /** Assistant icon emoji for visual identification */ + emoji?: string + /** Assistant type, e.g., 'default', 'custom', 'agent' */ + type?: string +} + +/** + * Preserved model info for display when model is unavailable + * Used in: message.modelMeta + */ +export interface ModelMeta { + /** Original model ID, used to attempt reference recovery */ + id: string + /** Model display name, e.g., "GPT-4o", "Claude 3.5 Sonnet" */ + name: string + /** Provider identifier, e.g., "openai", "anthropic", "google" */ + provider: string + /** Model family/group, e.g., "gpt-4", "claude-3", useful for grouping in UI */ + group?: string +} diff --git a/packages/shared/data/types/topic.ts b/packages/shared/data/types/topic.ts new file mode 100644 index 0000000000..f03981f771 --- /dev/null +++ b/packages/shared/data/types/topic.ts @@ -0,0 +1,40 @@ +/** + * Topic entity types + * + * Topics are containers for messages and belong to assistants. + * They can be organized into groups and have tags for categorization. + */ + +import type { AssistantMeta } from './meta' + +/** + * Complete topic entity as stored in database + */ +export interface Topic { + /** Topic ID */ + id: string + /** Topic name */ + name?: string | null + /** Whether the name was manually edited by user */ + isNameManuallyEdited: boolean + /** Associated assistant ID */ + assistantId?: string | null + /** Preserved assistant info for display when assistant is deleted */ + assistantMeta?: AssistantMeta | null + /** Topic-specific prompt override */ + prompt?: string | null + /** Active node ID in the message tree */ + activeNodeId?: string | null + /** Group ID for organization */ + groupId?: string | null + /** Sort order within group */ + sortOrder: number + /** Whether topic is pinned */ + isPinned: boolean + /** Pinned order */ + pinnedOrder: number + /** Creation timestamp (ISO string) */ + createdAt: string + /** Last update timestamp (ISO string) */ + updatedAt: string +} diff --git a/packages/shared/utils.ts b/packages/shared/utils.ts index a14f78958d..7e90624aba 100644 --- a/packages/shared/utils.ts +++ b/packages/shared/utils.ts @@ -35,3 +35,56 @@ export const defaultAppHeaders = () => { // return value // } // } + +/** + * Extracts the trailing API version segment from a URL path. + * + * This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL. + * Only versions at the end of the path are extracted, not versions in the middle. + * The returned version string does not include leading or trailing slashes. + * + * @param {string} url - The URL string to parse. + * @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found. 
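The documented behavior doubles as a test fixture. A vitest-style sketch (the harness and the `@shared/utils` import path are assumptions; the expectations mirror the @example lines below):

```typescript
import { describe, expect, it } from 'vitest'
import { getTrailingApiVersion, withoutTrailingApiVersion } from '@shared/utils'

describe('trailing API version helpers', () => {
  it('extracts only a trailing version segment', () => {
    expect(getTrailingApiVersion('https://api.example.com/v1')).toBe('v1')
    expect(getTrailingApiVersion('https://api.example.com/v2beta/')).toBe('v2beta')
    expect(getTrailingApiVersion('https://api.example.com/v1/chat')).toBeUndefined()
  })

  it('strips the trailing version but leaves mid-path versions alone', () => {
    expect(withoutTrailingApiVersion('https://api.example.com/v1')).toBe('https://api.example.com')
    expect(withoutTrailingApiVersion('https://api.example.com/v1/chat')).toBe('https://api.example.com/v1/chat')
  })
})
```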
+ * + * @example + * getTrailingApiVersion('https://api.example.com/v1') // 'v1' + * getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta' + * getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end) + * getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta' + * getTrailingApiVersion('https://api.example.com') // undefined + */ +export function getTrailingApiVersion(url: string): string | undefined { + const match = url.match(TRAILING_VERSION_REGEX) + + if (match) { + // Extract version without leading slash and trailing slash + return match[0].replace(/^\//, '').replace(/\/$/, '') + } + + return undefined +} + +/** + * Matches an API version at the end of a URL (with optional trailing slash). + * Used to detect and extract versions only from the trailing position. + */ +const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i + +/** + * Removes the trailing API version segment from a URL path. + * + * This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL. + * Only versions at the end of the path are removed, not versions in the middle. + * + * @param {string} url - The URL string to process. + * @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found. + * + * @example + * withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com' + * withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com' + * withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change) + * withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com' + */ +export function withoutTrailingApiVersion(url: string): string { + return url.replace(TRAILING_VERSION_REGEX, '') +} diff --git a/packages/ui/components.json b/packages/ui/components.json index b5c2f24eff..a6c7c26b0c 100644 --- a/packages/ui/components.json +++ b/packages/ui/components.json @@ -5,7 +5,7 @@ "hooks": "@cherrystudio/ui/hooks", "lib": "@cherrystudio/ui/lib", "ui": "@cherrystudio/ui/components/primitives", - "utils": "@cherrystudio/ui/utils" + "utils": "@cherrystudio/ui/lib/utils" }, "iconLibrary": "lucide", "rsc": false, diff --git a/packages/ui/src/components/composites/Ellipsis/index.tsx b/packages/ui/src/components/composites/Ellipsis/index.tsx index c4c296079c..c5c3a5fd72 100644 --- a/packages/ui/src/components/composites/Ellipsis/index.tsx +++ b/packages/ui/src/components/composites/Ellipsis/index.tsx @@ -1,8 +1,7 @@ // Original: src/renderer/src/components/Ellipsis/index.tsx +import { cn } from '@cherrystudio/ui/lib/utils' import type { HTMLAttributes } from 'react' -import { cn } from '../../../utils' - type Props = { maxLine?: number className?: string diff --git a/packages/ui/src/components/composites/Flex/index.tsx b/packages/ui/src/components/composites/Flex/index.tsx index 522a5574d7..6aa34293e1 100644 --- a/packages/ui/src/components/composites/Flex/index.tsx +++ b/packages/ui/src/components/composites/Flex/index.tsx @@ -1,7 +1,6 @@ +import { cn } from '@cherrystudio/ui/lib/utils' import React from 'react' -import { cn } from '../../../utils' - export interface BoxProps extends React.ComponentProps<'div'> {} export const Box = ({ children, className, ...props }: BoxProps & { children?: React.ReactNode }) => { diff --git a/packages/ui/src/components/composites/Input/input.tsx b/packages/ui/src/components/composites/Input/input.tsx index 
80c83fe15e..1d792f2039 100644 --- a/packages/ui/src/components/composites/Input/input.tsx +++ b/packages/ui/src/components/composites/Input/input.tsx @@ -1,4 +1,5 @@ -import { cn, toUndefinedIfNull } from '@cherrystudio/ui/utils' +import { cn } from '@cherrystudio/ui/lib/utils' +import { toUndefinedIfNull } from '@cherrystudio/ui/utils/index' import type { VariantProps } from 'class-variance-authority' import { cva } from 'class-variance-authority' import { Edit2Icon, EyeIcon, EyeOffIcon } from 'lucide-react' diff --git a/packages/ui/src/components/composites/ListItem/index.tsx b/packages/ui/src/components/composites/ListItem/index.tsx index 196fdb2949..327dda27e0 100644 --- a/packages/ui/src/components/composites/ListItem/index.tsx +++ b/packages/ui/src/components/composites/ListItem/index.tsx @@ -1,9 +1,8 @@ // Original path: src/renderer/src/components/ListItem/index.tsx +import { cn } from '@cherrystudio/ui/lib/utils' import { Tooltip } from '@heroui/react' import type { ReactNode } from 'react' -import { cn } from '../../../utils' - interface ListItemProps { active?: boolean icon?: ReactNode diff --git a/packages/ui/src/components/composites/Sortable/ItemRenderer.tsx b/packages/ui/src/components/composites/Sortable/ItemRenderer.tsx index e9e048fd6a..396af9b11b 100644 --- a/packages/ui/src/components/composites/Sortable/ItemRenderer.tsx +++ b/packages/ui/src/components/composites/Sortable/ItemRenderer.tsx @@ -1,10 +1,10 @@ +import { cn } from '@cherrystudio/ui/lib/utils' import type { DraggableSyntheticListeners } from '@dnd-kit/core' import type { Transform } from '@dnd-kit/utilities' import { CSS } from '@dnd-kit/utilities' import React, { useEffect } from 'react' import styled from 'styled-components' -import { cn } from '../../../utils' import type { RenderItemType } from './types' interface ItemRendererProps { diff --git a/packages/ui/src/components/composites/ThinkingEffect/index.tsx b/packages/ui/src/components/composites/ThinkingEffect/index.tsx index 6c542bf3d4..86baad1470 100644 --- a/packages/ui/src/components/composites/ThinkingEffect/index.tsx +++ b/packages/ui/src/components/composites/ThinkingEffect/index.tsx @@ -7,12 +7,12 @@ */ // Original path: src/renderer/src/components/ThinkingEffect.tsx +import { cn } from '@cherrystudio/ui/lib/utils' import { isEqual } from 'lodash' import { ChevronRight, Lightbulb } from 'lucide-react' import { motion } from 'motion/react' import React, { useEffect, useMemo, useState } from 'react' -import { cn } from '../../../utils' import { lightbulbVariants } from './defaultVariants' interface ThinkingEffectProps { diff --git a/packages/ui/src/components/index.ts b/packages/ui/src/components/index.ts index 1cede94148..3c7051f02c 100644 --- a/packages/ui/src/components/index.ts +++ b/packages/ui/src/components/index.ts @@ -59,6 +59,7 @@ export { export { Sortable } from './composites/Sortable' /* Shadcn Primitive Components */ +export * from './primitives/badge' export * from './primitives/breadcrumb' export * from './primitives/button' export * from './primitives/checkbox' diff --git a/packages/ui/src/components/primitives/Avatar/EmojiAvatar.tsx b/packages/ui/src/components/primitives/Avatar/EmojiAvatar.tsx index 7a9ce03e24..e6fef89703 100644 --- a/packages/ui/src/components/primitives/Avatar/EmojiAvatar.tsx +++ b/packages/ui/src/components/primitives/Avatar/EmojiAvatar.tsx @@ -1,7 +1,6 @@ +import { cn } from '@cherrystudio/ui/lib/utils' import React, { memo } from 'react' -import { cn } from '../../../utils' - interface EmojiAvatarProps { 
children: string size?: number diff --git a/packages/ui/src/components/primitives/Avatar/index.tsx b/packages/ui/src/components/primitives/Avatar/index.tsx index a2ad31bd73..1c5aff9658 100644 --- a/packages/ui/src/components/primitives/Avatar/index.tsx +++ b/packages/ui/src/components/primitives/Avatar/index.tsx @@ -1,7 +1,7 @@ +import { cn } from '@cherrystudio/ui/lib/utils' import type { AvatarProps as HeroUIAvatarProps } from '@heroui/react' import { Avatar as HeroUIAvatar, AvatarGroup as HeroUIAvatarGroup } from '@heroui/react' -import { cn } from '../../../utils' import EmojiAvatar from './EmojiAvatar' export interface AvatarProps extends Omit { diff --git a/packages/ui/src/components/primitives/badge.tsx b/packages/ui/src/components/primitives/badge.tsx new file mode 100644 index 0000000000..5cb3c8cefe --- /dev/null +++ b/packages/ui/src/components/primitives/badge.tsx @@ -0,0 +1,35 @@ +import { cn } from '@cherrystudio/ui/lib/utils' +import { Slot } from '@radix-ui/react-slot' +import { cva, type VariantProps } from 'class-variance-authority' +import * as React from 'react' + +const badgeVariants = cva( + 'inline-flex items-center justify-center rounded-full border px-2 py-0.5 text-xs font-medium w-fit whitespace-nowrap shrink-0 [&>svg]:size-3 gap-1 [&>svg]:pointer-events-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive transition-[color,box-shadow] overflow-hidden', + { + variants: { + variant: { + default: 'border-transparent bg-background-subtle text-secondary-foreground [a&]:hover:bg-primary/90', + secondary: 'border-transparent bg-secondary text-secondary-foreground [a&]:hover:bg-secondary/90', + destructive: + 'border-transparent text-destructive bg-[red]/10 [a&]:hover:bg-destructive/90 focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 dark:bg-destructive/60', + outline: 'text-foreground [a&]:hover:bg-accent [a&]:hover:text-accent-foreground' + } + }, + defaultVariants: { + variant: 'default' + } + } +) + +function Badge({ + className, + variant, + asChild = false, + ...props +}: React.ComponentProps<'span'> & VariantProps & { asChild?: boolean }) { + const Comp = asChild ? 
diff --git a/packages/ui/src/components/primitives/breadcrumb.tsx b/packages/ui/src/components/primitives/breadcrumb.tsx
index 6f9d871409..11c3527eeb 100644
--- a/packages/ui/src/components/primitives/breadcrumb.tsx
+++ b/packages/ui/src/components/primitives/breadcrumb.tsx
@@ -1,4 +1,4 @@
-import { cn } from '@cherrystudio/ui/utils/index'
+import { cn } from '@cherrystudio/ui/lib/utils'
 import { Slot } from '@radix-ui/react-slot'
 import { ChevronRight, MoreHorizontal } from 'lucide-react'
 import * as React from 'react'
diff --git a/packages/ui/src/components/primitives/button.tsx b/packages/ui/src/components/primitives/button.tsx
index 092d55dd1c..8fb96c9903 100644
--- a/packages/ui/src/components/primitives/button.tsx
+++ b/packages/ui/src/components/primitives/button.tsx
@@ -1,4 +1,4 @@
-import { cn } from '@cherrystudio/ui/utils/index'
+import { cn } from '@cherrystudio/ui/lib/utils'
 import { Slot } from '@radix-ui/react-slot'
 import { cva, type VariantProps } from 'class-variance-authority'
 import { Loader } from 'lucide-react'
diff --git a/packages/ui/src/components/primitives/checkbox.tsx b/packages/ui/src/components/primitives/checkbox.tsx
index 34f374fec4..dff1f928c2 100644
--- a/packages/ui/src/components/primitives/checkbox.tsx
+++ b/packages/ui/src/components/primitives/checkbox.tsx
@@ -1,4 +1,4 @@
-import { cn } from '@cherrystudio/ui/utils/index'
+import { cn } from '@cherrystudio/ui/lib/utils'
 import * as CheckboxPrimitive from '@radix-ui/react-checkbox'
 import { cva, type VariantProps } from 'class-variance-authority'
 import { CheckIcon } from 'lucide-react'
diff --git a/packages/ui/src/components/primitives/combobox.tsx b/packages/ui/src/components/primitives/combobox.tsx
index 15afa8c0a8..2e11809351 100644
--- a/packages/ui/src/components/primitives/combobox.tsx
+++ b/packages/ui/src/components/primitives/combobox.tsx
@@ -10,7 +10,7 @@ import {
   CommandList
 } from '@cherrystudio/ui/components/primitives/command'
 import { Popover, PopoverContent, PopoverTrigger } from '@cherrystudio/ui/components/primitives/popover'
-import { cn } from '@cherrystudio/ui/utils/index'
+import { cn } from '@cherrystudio/ui/lib/utils'
 import { cva, type VariantProps } from 'class-variance-authority'
 import { Check, ChevronDown, X } from 'lucide-react'
 import * as React from 'react'
diff --git a/packages/ui/src/components/primitives/command.tsx b/packages/ui/src/components/primitives/command.tsx
index 2d0515d272..76ecf7a1c1 100644
--- a/packages/ui/src/components/primitives/command.tsx
+++ b/packages/ui/src/components/primitives/command.tsx
@@ -5,7 +5,7 @@ import {
   DialogHeader,
   DialogTitle
 } from '@cherrystudio/ui/components/primitives/dialog'
-import { cn } from '@cherrystudio/ui/utils'
+import { cn } from '@cherrystudio/ui/lib/utils'
 import { Command as CommandPrimitive } from 'cmdk'
 import { SearchIcon } from 'lucide-react'
 import * as React from 'react'
diff --git a/packages/ui/src/components/primitives/dialog.tsx b/packages/ui/src/components/primitives/dialog.tsx
index 6b36644bc7..62a063eea4 100644
--- a/packages/ui/src/components/primitives/dialog.tsx
+++ b/packages/ui/src/components/primitives/dialog.tsx
@@ -1,4 +1,4 @@
-import { cn } from '@cherrystudio/ui/utils/index'
+import { cn } from '@cherrystudio/ui/lib/utils'
 import * as DialogPrimitive from '@radix-ui/react-dialog'
 import { XIcon } from 'lucide-react'
 import * as React from 'react'
diff --git a/packages/ui/src/components/primitives/input-group.tsx
b/packages/ui/src/components/primitives/input-group.tsx index 9c27456b34..0bb9253001 100644 --- a/packages/ui/src/components/primitives/input-group.tsx +++ b/packages/ui/src/components/primitives/input-group.tsx @@ -3,7 +3,7 @@ import type { InputProps } from '@cherrystudio/ui/components/primitives/input' import { Input } from '@cherrystudio/ui/components/primitives/input' import type { TextareaInputProps } from '@cherrystudio/ui/components/primitives/textarea' import * as Textarea from '@cherrystudio/ui/components/primitives/textarea' -import { cn } from '@cherrystudio/ui/utils/index' +import { cn } from '@cherrystudio/ui/lib/utils' import { cva, type VariantProps } from 'class-variance-authority' import * as React from 'react' diff --git a/packages/ui/src/components/primitives/input.tsx b/packages/ui/src/components/primitives/input.tsx index 5a5e29cd5a..cffad36b44 100644 --- a/packages/ui/src/components/primitives/input.tsx +++ b/packages/ui/src/components/primitives/input.tsx @@ -1,4 +1,4 @@ -import { cn } from '@cherrystudio/ui/utils' +import { cn } from '@cherrystudio/ui/lib/utils' import * as React from 'react' interface InputProps extends React.ComponentProps<'input'> {} diff --git a/packages/ui/src/components/primitives/kbd.tsx b/packages/ui/src/components/primitives/kbd.tsx index d1a2268e75..21c4b06b7b 100644 --- a/packages/ui/src/components/primitives/kbd.tsx +++ b/packages/ui/src/components/primitives/kbd.tsx @@ -1,4 +1,4 @@ -import { cn } from '@cherrystudio/ui/utils/index' +import { cn } from '@cherrystudio/ui/lib/utils' function Kbd({ className, ...props }: React.ComponentProps<'kbd'>) { return ( diff --git a/packages/ui/src/components/primitives/pagination.tsx b/packages/ui/src/components/primitives/pagination.tsx index eb675e8bb0..4e5a407c07 100644 --- a/packages/ui/src/components/primitives/pagination.tsx +++ b/packages/ui/src/components/primitives/pagination.tsx @@ -1,6 +1,6 @@ import type { Button } from '@cherrystudio/ui/components/primitives/button' import { buttonVariants } from '@cherrystudio/ui/components/primitives/button' -import { cn } from '@cherrystudio/ui/utils/index' +import { cn } from '@cherrystudio/ui/lib/utils' import { ChevronLeftIcon, ChevronRightIcon, MoreHorizontalIcon } from 'lucide-react' import * as React from 'react' diff --git a/packages/ui/src/components/primitives/popover.tsx b/packages/ui/src/components/primitives/popover.tsx index b52cc7aa4a..805d952b07 100644 --- a/packages/ui/src/components/primitives/popover.tsx +++ b/packages/ui/src/components/primitives/popover.tsx @@ -1,6 +1,6 @@ 'use client' -import { cn } from '@cherrystudio/ui/utils' +import { cn } from '@cherrystudio/ui/lib/utils' import * as PopoverPrimitive from '@radix-ui/react-popover' import * as React from 'react' diff --git a/packages/ui/src/components/primitives/radioGroup.tsx b/packages/ui/src/components/primitives/radioGroup.tsx index 0d4b95b6c9..2dd4ece391 100644 --- a/packages/ui/src/components/primitives/radioGroup.tsx +++ b/packages/ui/src/components/primitives/radioGroup.tsx @@ -1,4 +1,4 @@ -import { cn } from '@cherrystudio/ui/utils/index' +import { cn } from '@cherrystudio/ui/lib/utils' import * as RadioGroupPrimitive from '@radix-ui/react-radio-group' import { cva, type VariantProps } from 'class-variance-authority' import { CircleIcon } from 'lucide-react' diff --git a/packages/ui/src/components/primitives/select.tsx b/packages/ui/src/components/primitives/select.tsx index ec2bac4cba..9b1fa1bba5 100644 --- a/packages/ui/src/components/primitives/select.tsx +++ 
b/packages/ui/src/components/primitives/select.tsx @@ -1,4 +1,4 @@ -import { cn } from '@cherrystudio/ui/utils/index' +import { cn } from '@cherrystudio/ui/lib/utils' import * as SelectPrimitive from '@radix-ui/react-select' import { cva, type VariantProps } from 'class-variance-authority' import { CheckIcon, ChevronDownIcon, ChevronUpIcon } from 'lucide-react' diff --git a/packages/ui/src/components/primitives/shadcn-io/dropzone/index.tsx b/packages/ui/src/components/primitives/shadcn-io/dropzone/index.tsx index 4892a94244..14ba16a6d0 100644 --- a/packages/ui/src/components/primitives/shadcn-io/dropzone/index.tsx +++ b/packages/ui/src/components/primitives/shadcn-io/dropzone/index.tsx @@ -1,7 +1,7 @@ 'use client' import { Button } from '@cherrystudio/ui/components/primitives/button' -import { cn } from '@cherrystudio/ui/utils/index' +import { cn } from '@cherrystudio/ui/lib/utils' import { UploadIcon } from 'lucide-react' import type { ReactNode } from 'react' import { createContext, use } from 'react' diff --git a/packages/ui/src/components/primitives/switch.tsx b/packages/ui/src/components/primitives/switch.tsx index 7ce2ac5d3c..2e9b2c12eb 100644 --- a/packages/ui/src/components/primitives/switch.tsx +++ b/packages/ui/src/components/primitives/switch.tsx @@ -1,4 +1,4 @@ -import { cn } from '@cherrystudio/ui/utils' +import { cn } from '@cherrystudio/ui/lib/utils' import * as SwitchPrimitive from '@radix-ui/react-switch' import { cva } from 'class-variance-authority' import * as React from 'react' diff --git a/packages/ui/src/components/primitives/tabs.tsx b/packages/ui/src/components/primitives/tabs.tsx index 051de1dfb2..95c8ec90e2 100644 --- a/packages/ui/src/components/primitives/tabs.tsx +++ b/packages/ui/src/components/primitives/tabs.tsx @@ -1,4 +1,4 @@ -import { cn } from '@cherrystudio/ui/utils/index' +import { cn } from '@cherrystudio/ui/lib/utils' import * as TabsPrimitive from '@radix-ui/react-tabs' import { cva } from 'class-variance-authority' import * as React from 'react' diff --git a/packages/ui/src/components/primitives/textarea.tsx b/packages/ui/src/components/primitives/textarea.tsx index 5bc749d7ac..1444e8c9ed 100644 --- a/packages/ui/src/components/primitives/textarea.tsx +++ b/packages/ui/src/components/primitives/textarea.tsx @@ -1,4 +1,4 @@ -import { cn } from '@cherrystudio/ui/utils/index' +import { cn } from '@cherrystudio/ui/lib/utils' import { composeEventHandlers } from '@radix-ui/primitive' import { useCallbackRef } from '@radix-ui/react-use-callback-ref' import { useControllableState } from '@radix-ui/react-use-controllable-state' diff --git a/packages/ui/src/components/primitives/tooltip_new.tsx b/packages/ui/src/components/primitives/tooltip_new.tsx index 430ac262f4..9b1db13e1e 100644 --- a/packages/ui/src/components/primitives/tooltip_new.tsx +++ b/packages/ui/src/components/primitives/tooltip_new.tsx @@ -1,4 +1,4 @@ -import { cn } from '@cherrystudio/ui/utils/index' +import { cn } from '@cherrystudio/ui/lib/utils' import * as TooltipPrimitive from '@radix-ui/react-tooltip' import * as React from 'react' diff --git a/packages/ui/src/lib/utils.ts b/packages/ui/src/lib/utils.ts new file mode 100644 index 0000000000..d477ffd44a --- /dev/null +++ b/packages/ui/src/lib/utils.ts @@ -0,0 +1,23 @@ +/** + * Internal utilities for UI components. + * + * This module is for INTERNAL use only and should NOT be exposed to external consumers. + * External utilities should be placed in `utils/` instead. 
+ *
+ * @internal
+ */
+
+import { type ClassValue, clsx } from 'clsx'
+import { twMerge } from 'tailwind-merge'
+
+/**
+ * Merges Tailwind CSS class names with conflict resolution.
+ * Combines clsx for conditional classes and tailwind-merge for deduplication.
+ *
+ * @example
+ * cn('px-2 py-1', 'px-4') // => 'py-1 px-4'
+ * cn('text-red-500', isActive && 'text-blue-500')
+ */
+export function cn(...inputs: ClassValue[]) {
+  return twMerge(clsx(inputs))
+}
diff --git a/packages/ui/src/utils/index.ts b/packages/ui/src/utils/index.ts
index 7f0275f99d..573e97be73 100644
--- a/packages/ui/src/utils/index.ts
+++ b/packages/ui/src/utils/index.ts
@@ -1,13 +1,11 @@
-import { type ClassValue, clsx } from 'clsx'
-import { twMerge } from 'tailwind-merge'
-
 /**
- * Merge class names with tailwind-merge
- * This utility combines clsx and tailwind-merge for optimal class name handling
+ * Public utility functions for external consumers.
+ *
+ * This module is part of the PUBLIC API and can be imported via `@cherrystudio/ui/utils`.
+ * For internal-only utilities (e.g., Tailwind class merging), use `lib/` instead.
+ *
+ * @module utils
  */
-export function cn(...inputs: ClassValue[]) {
-  return twMerge(clsx(inputs))
-}

 /**
  * Converts `null` to `undefined`, otherwise returns the input value.
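The practical effect of the split, shown as a short sketch (import specifiers and the `cn` behavior are taken from the hunks above; the argument values are illustrative):

```typescript
// Inside packages/ui components: the internal-only class merger.
import { cn } from '@cherrystudio/ui/lib/utils'
// From consuming code: only the public helpers remain under utils/.
import { toUndefinedIfNull } from '@cherrystudio/ui/utils'

// tailwind-merge lets the later, conflicting utility win while
// keeping non-conflicting classes intact.
cn('px-2 py-1', 'px-4') // => 'py-1 px-4'

// Public helper, unchanged by this refactor.
toUndefinedIfNull(null) // => undefined
```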
diff --git a/packages/ui/stories/components/primitives/Badge.stories.tsx b/packages/ui/stories/components/primitives/Badge.stories.tsx
new file mode 100644
index 0000000000..cf02c202bb
--- /dev/null
+++ b/packages/ui/stories/components/primitives/Badge.stories.tsx
@@ -0,0 +1,207 @@
+import { Badge } from '@cherrystudio/ui'
+import type { Meta, StoryObj } from '@storybook/react'
+import { Check, X } from 'lucide-react'
+
+const meta: Meta<typeof Badge> = {
+  title: 'Components/Primitives/Badge',
+  component: Badge,
+  parameters: {
+    layout: 'centered',
+    docs: {
+      description: {
+        component: 'Displays a badge or a component that looks like a badge. Based on shadcn/ui.'
+      }
+    }
+  },
+  tags: ['autodocs'],
+  argTypes: {
+    variant: {
+      control: { type: 'select' },
+      options: ['default', 'secondary', 'destructive', 'outline'],
+      description: 'The visual style variant of the badge'
+    },
+    asChild: {
+      control: { type: 'boolean' },
+      description: 'Render as a child element'
+    },
+    className: {
+      control: { type: 'text' },
+      description: 'Additional CSS classes'
+    }
+  }
+}
+
+export default meta
+type Story = StoryObj<typeof meta>
+
+// Default
+export const Default: Story = {
+  args: {
+    children: 'Badge'
+  }
+}
+
+// Variants
+export const Secondary: Story = {
+  args: {
+    variant: 'secondary',
+    children: 'Secondary'
+  }
+}
+
+export const Destructive: Story = {
+  args: {
+    variant: 'destructive',
+    children: 'Destructive'
+  }
+}
+
+export const Outline: Story = {
+  args: {
+    variant: 'outline',
+    children: 'Outline'
+  }
+}
+
+// All Variants
+export const AllVariants: Story = {
+  render: () => (
+    <div className="flex flex-wrap gap-2">
+      <Badge>Default</Badge>
+      <Badge variant="secondary">Secondary</Badge>
+      <Badge variant="destructive">Destructive</Badge>
+      <Badge variant="outline">Outline</Badge>
+    </div>
+  )
+}
+
+// With Icons
+export const WithIcon: Story = {
+  render: () => (
+    <div className="flex flex-wrap gap-2">
+      <Badge>
+        <Check />
+        Success
+      </Badge>
+      <Badge variant="destructive">
+        <X />
+        Error
+      </Badge>
+      <Badge variant="secondary">
+        <Check />
+        Completed
+      </Badge>
+      <Badge variant="outline">
+        <Check />
+        Verified
+      </Badge>
+    </div>
+  )
+}
+
+// As Link
+export const AsLink: Story = {
+  render: () => (
+    <div className="flex flex-col gap-4">
+      <div>
+        <p className="mb-2 text-sm">
+          Using asChild to render as an anchor tag:
+        </p>
+        <Badge asChild>
+          <a href="https://github.com" target="_blank" rel="noreferrer">
+            GitHub
+          </a>
+        </Badge>
+      </div>
+      <div>
+        <p className="mb-2 text-sm">
+          All variants as links (hover to see effect):
+        </p>
+        <div className="flex flex-wrap gap-2">
+          <Badge asChild>
+            <a href="#">Default</a>
+          </Badge>
+          <Badge asChild variant="secondary">
+            <a href="#">Secondary</a>
+          </Badge>
+          <Badge asChild variant="destructive">
+            <a href="#">Destructive</a>
+          </Badge>
+          <Badge asChild variant="outline">
+            <a href="#">Outline</a>
+          </Badge>
+        </div>
+      </div>
+    </div>
+  )
+}
+
+// Status Badges
+export const StatusBadges: Story = {
+  render: () => (
+    <div className="flex flex-wrap gap-2">
+      <Badge>Active</Badge>
+      <Badge variant="secondary">Pending</Badge>
+      <Badge variant="destructive">Failed</Badge>
+      <Badge variant="outline">Draft</Badge>
+    </div>
+  )
+}
+
+// Real World Examples
+export const RealWorldExamples: Story = {
+  render: () => (
+    <div className="flex flex-col gap-6">
+      {/* Status Indicators */}
+      <div>
+        <p className="mb-2 text-sm font-medium">
+          Status Indicators
+        </p>
+        <div className="flex flex-wrap gap-2">
+          <Badge>Online</Badge>
+          <Badge variant="secondary">Away</Badge>
+          <Badge variant="destructive">Offline</Badge>
+          <Badge variant="outline">Unknown</Badge>
+        </div>
+      </div>
+
+      {/* Labels */}
+      <div>
+        <p className="mb-2 text-sm font-medium">
+          Labels
+        </p>
+        <div className="flex flex-wrap gap-2">
+          <Badge>New</Badge>
+          <Badge variant="secondary">Featured</Badge>
+          <Badge variant="destructive">Hot</Badge>
+          <Badge variant="outline">Beta</Badge>
+        </div>
+      </div>
+
+      {/* Tags */}
+      <div>
+        <p className="mb-2 text-sm font-medium">
+          Tags
+        </p>
+        <div className="flex flex-wrap gap-2">
+          <Badge variant="secondary">React</Badge>
+          <Badge variant="secondary">TypeScript</Badge>
+          <Badge variant="secondary">Tailwind</Badge>
+          <Badge variant="secondary">Shadcn</Badge>
+        </div>
+      </div>
+
+      {/* Notification Counts */}
+      <div>
+        <p className="mb-2 text-sm font-medium">
+          Notification Counts
+        </p>
+        <div className="flex flex-wrap gap-2">
+          <Badge>3</Badge>
+          <Badge variant="destructive">99+</Badge>
+          <Badge variant="secondary">12</Badge>
+        </div>
+      </div>
+
+      {/* With Icons */}
+      <div>
+        <p className="mb-2 text-sm font-medium">
+          With Icons
+        </p>
+        <div className="flex flex-wrap gap-2">
+          <Badge>
+            <Check />
+            Verified
+          </Badge>
+          <Badge variant="destructive">
+            <X />
+            Rejected
+          </Badge>
+        </div>
+      </div>
+    </div>
+ ) +} diff --git a/resources/scripts/install-ovms.js b/resources/scripts/install-ovms.js index f2be80bffe..8ccd522b01 100644 --- a/resources/scripts/install-ovms.js +++ b/resources/scripts/install-ovms.js @@ -6,8 +6,8 @@ const { downloadWithPowerShell } = require('./download') // Base URL for downloading OVMS binaries const OVMS_RELEASE_BASE_URL = - 'https://storage.openvinotoolkit.org/repositories/openvino_model_server/packages/2025.3.0/ovms_windows_python_on.zip' -const OVMS_EX_URL = 'https://gitcode.com/gcw_ggDjjkY3/kjfile/releases/download/download/ovms_25.3_ex.zip' + 'https://storage.openvinotoolkit.org/repositories/openvino_model_server/packages/2025.4.1/ovms_windows_python_on.zip' +const OVMS_EX_URL = 'https://gitcode.com/gcw_ggDjjkY3/kjfile/releases/download/download/ovms_25.4_ex.zip' /** * error code: diff --git a/scripts/auto-translate-i18n.ts b/scripts/auto-translate-i18n.ts index 7a1bea6f35..41bb14a0a1 100644 --- a/scripts/auto-translate-i18n.ts +++ b/scripts/auto-translate-i18n.ts @@ -50,7 +50,7 @@ Usage Instructions: - pt-pt (Portuguese) Run Command: -yarn auto:i18n +yarn i18n:translate Performance Optimization Recommendations: - For stable API services: MAX_CONCURRENT_TRANSLATIONS=8, TRANSLATION_DELAY_MS=50 @@ -152,7 +152,8 @@ const languageMap = { 'es-es': 'Spanish', 'fr-fr': 'French', 'pt-pt': 'Portuguese', - 'de-de': 'German' + 'de-de': 'German', + 'ro-ro': 'Romanian' } const PROMPT = ` diff --git a/scripts/check-i18n.ts b/scripts/check-i18n.ts index 5735474106..ac1adc3de8 100644 --- a/scripts/check-i18n.ts +++ b/scripts/check-i18n.ts @@ -145,7 +145,7 @@ export function main() { console.log('i18n 检查已通过') } catch (e) { console.error(e) - throw new Error(`检查未通过。尝试运行 yarn sync:i18n 以解决问题。`) + throw new Error(`检查未通过。尝试运行 yarn i18n:sync 以解决问题。`) } } diff --git a/src/main/data/CacheService.ts b/src/main/data/CacheService.ts index 79e8104999..2d260e14d2 100644 --- a/src/main/data/CacheService.ts +++ b/src/main/data/CacheService.ts @@ -18,6 +18,7 @@ */ import { loggerService } from '@logger' +import type { SharedCacheKey, SharedCacheSchema } from '@shared/data/cache/cacheSchemas' import type { CacheEntry, CacheSyncMessage } from '@shared/data/cache/cacheTypes' import { IpcChannel } from '@shared/IpcChannel' import { BrowserWindow, ipcMain } from 'electron' @@ -42,9 +43,12 @@ export class CacheService { private static instance: CacheService private initialized = false - // Main process cache + // Main process internal cache private cache = new Map() + // Shared cache (synchronized with renderer windows) + private sharedCache = new Map() + // GC timer reference and interval time (e.g., every 10 minutes) private gcInterval: NodeJS.Timeout | null = null private readonly GC_INTERVAL_MS = 10 * 60 * 1000 @@ -79,7 +83,7 @@ export class CacheService { // ============ Main Process Cache (Internal) ============ /** - * Garbage collection logic + * Garbage collection logic for both internal and shared cache */ private startGarbageCollection() { if (this.gcInterval) return @@ -88,6 +92,7 @@ export class CacheService { const now = Date.now() let removedCount = 0 + // Clean internal cache for (const [key, entry] of this.cache.entries()) { if (entry.expireAt && now > entry.expireAt) { this.cache.delete(key) @@ -95,6 +100,14 @@ export class CacheService { } } + // Clean shared cache + for (const [key, entry] of this.sharedCache.entries()) { + if (entry.expireAt && now > entry.expireAt) { + this.sharedCache.delete(key) + removedCount++ + } + } + if (removedCount > 0) { logger.debug(`Garbage 
collection removed ${removedCount} expired items`)
       }
@@ -155,6 +168,110 @@
     return this.cache.delete(key)
   }

+  // ============ Shared Cache (Cross-window via IPC) ============
+
+  /**
+   * Get value from shared cache with TTL validation (type-safe)
+   * @param key - Schema-defined shared cache key
+   * @returns Cached value or undefined if not found or expired
+   */
+  getShared<K extends SharedCacheKey>(key: K): SharedCacheSchema[K] | undefined {
+    const entry = this.sharedCache.get(key)
+    if (!entry) return undefined
+
+    // Check TTL (lazy cleanup)
+    if (entry.expireAt && Date.now() > entry.expireAt) {
+      this.sharedCache.delete(key)
+      return undefined
+    }
+
+    return entry.value as SharedCacheSchema[K]
+  }
+
+  /**
+   * Set value in shared cache with cross-window broadcast (type-safe)
+   * @param key - Schema-defined shared cache key
+   * @param value - Value to cache (type inferred from schema)
+   * @param ttl - Time to live in milliseconds (optional)
+   */
+  setShared<K extends SharedCacheKey>(key: K, value: SharedCacheSchema[K], ttl?: number): void {
+    const expireAt = ttl ? Date.now() + ttl : undefined
+    const entry: CacheEntry = { value, expireAt }
+
+    this.sharedCache.set(key, entry)
+
+    // Broadcast to all renderer windows
+    this.broadcastSync({
+      type: 'shared',
+      key,
+      value,
+      expireAt
+    })
+
+    logger.verbose(`Set shared cache key "${key}"`)
+  }
+
+  /**
+   * Check if key exists in shared cache and is not expired (type-safe)
+   * @param key - Schema-defined shared cache key
+   * @returns True if key exists and is valid, false otherwise
+   */
+  hasShared<K extends SharedCacheKey>(key: K): boolean {
+    const entry = this.sharedCache.get(key)
+    if (!entry) return false
+
+    // Check TTL
+    if (entry.expireAt && Date.now() > entry.expireAt) {
+      this.sharedCache.delete(key)
+      return false
+    }
+
+    return true
+  }
+
+  /**
+   * Delete from shared cache with cross-window broadcast (type-safe)
+   * @param key - Schema-defined shared cache key
+   * @returns True if deletion succeeded
+   */
+  deleteShared<K extends SharedCacheKey>(key: K): boolean {
+    if (!this.sharedCache.has(key)) {
+      return true
+    }
+
+    this.sharedCache.delete(key)
+
+    // Broadcast deletion to all renderer windows
+    this.broadcastSync({
+      type: 'shared',
+      key,
+      value: undefined // undefined means deletion
+    })
+
+    logger.verbose(`Deleted shared cache key "${key}"`)
+    return true
+  }
+
+  /**
+   * Get all shared cache entries (for renderer initialization sync)
+   * @returns Record of all shared cache entries with their metadata
+   */
+  private getAllShared(): Record<string, CacheEntry> {
+    const now = Date.now()
+    const result: Record<string, CacheEntry> = {}
+
+    for (const [key, entry] of this.sharedCache.entries()) {
+      // Skip expired entries
+      if (entry.expireAt && now > entry.expireAt) {
+        this.sharedCache.delete(key)
+        continue
+      }
+      result[key] = entry
+    }
+
+    return result
+  }
+
   // ============ Persist Cache Interface (Reserved) ============

   // TODO: Implement persist cache in future
@@ -180,10 +297,32 @@
     // Handle cache sync broadcast from renderer
     ipcMain.on(IpcChannel.Cache_Sync, (event, message: CacheSyncMessage) => {
       const senderWindowId = BrowserWindow.fromWebContents(event.sender)?.id
+
+      // Update Main's sharedCache when receiving shared type sync
+      if (message.type === 'shared') {
+        if (message.value === undefined) {
+          // Handle deletion
+          this.sharedCache.delete(message.key)
+        } else {
+          // Handle set - use expireAt directly (absolute timestamp)
+          const entry: CacheEntry = {
+            value: message.value,
+            expireAt: message.expireAt
+          }
+          this.sharedCache.set(message.key, entry)
+        }
+      }
+
+      // Broadcast to other windows
      this.broadcastSync(message, senderWindowId)
       logger.verbose(`Broadcasted cache sync: ${message.type}:${message.key}`)
     })

+    // Handle getAllShared request for renderer initialization
+    ipcMain.handle(IpcChannel.Cache_GetAllShared, () => {
+      return this.getAllShared()
+    })
+
     logger.debug('Cache sync IPC handlers registered')
   }
@@ -197,11 +336,13 @@
       this.gcInterval = null
     }

-    // Clear cache
+    // Clear caches
     this.cache.clear()
+    this.sharedCache.clear()

     // Remove IPC handlers
     ipcMain.removeAllListeners(IpcChannel.Cache_Sync)
+    ipcMain.removeHandler(IpcChannel.Cache_GetAllShared)

     logger.debug('CacheService cleanup completed')
   }
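To make the new shared-cache surface concrete, a minimal usage sketch from main-process code. The `cacheService` singleton export and the `'app:theme'` key are assumptions for illustration; real keys are defined in `@shared/data/cache/cacheSchemas`:

```typescript
import { cacheService } from '@data/CacheService' // assumed export name

// Write-through with a 5-minute TTL; the set is broadcast to every
// renderer window over IpcChannel.Cache_Sync.
cacheService.setShared('app:theme', 'dark', 5 * 60 * 1000)

// Typed read: the return type is SharedCacheSchema['app:theme'] | undefined,
// and an expired entry is lazily evicted on access.
const theme = cacheService.getShared('app:theme')

// Deletion broadcasts value: undefined, which peers treat as a tombstone.
if (cacheService.hasShared('app:theme')) {
  cacheService.deleteShared('app:theme')
}
```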
diff --git a/src/main/data/README.md b/src/main/data/README.md
index 7efff10113..e596b87434 100644
--- a/src/main/data/README.md
+++ b/src/main/data/README.md
@@ -1,386 +1,44 @@
 # Main Data Layer

-This directory contains the main process data management system, providing unified data access for the entire application.
+This directory contains the main process data management implementation.
+
+## Documentation
+
+- **Overview**: [docs/en/references/data/README.md](../../../docs/en/references/data/README.md)
+- **DataApi in Main**: [data-api-in-main.md](../../../docs/en/references/data/data-api-in-main.md)
+- **Database Patterns**: [database-patterns.md](../../../docs/en/references/data/database-patterns.md)

 ## Directory Structure

 ```
 src/main/data/
-├── api/                      # Data API framework (interface layer)
-│   ├── core/                 # Core API infrastructure
-│   │   ├── ApiServer.ts      # Request routing and handler execution
-│   │   ├── MiddlewareEngine.ts  # Request/response middleware
-│   │   └── adapters/         # Communication adapters (IPC)
-│   ├── handlers/             # API endpoint implementations
-│   │   └── index.ts          # Thin handlers: param extraction, DTO conversion
-│   └── index.ts              # API framework exports
-│
+├── api/                # Data API framework
+│   ├── core/           # ApiServer, MiddlewareEngine, adapters
+│   └── handlers/       # API endpoint implementations
 ├── services/           # Business logic layer
-│   ├── base/                 # Service base classes and interfaces
-│   │   └── IBaseService.ts   # Service interface definitions
-│   └── TestService.ts        # Test service (placeholder for real services)
-│                             # Future business services:
-│                             # - TopicService.ts    # Topic business logic
-│                             # - MessageService.ts  # Message business logic
-│                             # - FileService.ts     # File business logic
-│
 ├── repositories/       # Data access layer (selective usage)
-│                             # Repository pattern used selectively for complex domains
-│                             # Future repositories:
-│                             # - TopicRepository.ts    # Complex: Topic data access
-│                             # - MessageRepository.ts  # Complex: Message data access
-│
-├── db/                       # Database layer
-│   ├── schemas/              # Drizzle table definitions
-│   │   ├── preference.ts     # Preference configuration table
-│   │   ├── appState.ts       # Application state table
-│   │   └── columnHelpers.ts  # Reusable column definitions
-│   ├── seeding/              # Database initialization
-│   └── DbService.ts          # Database connection and management
-│
-├── migration/                # Data migration system
-│   └── v2/                   # v2 data refactoring migration tools
-│
-├── CacheService.ts           # Infrastructure: Cache management
-├── DataApiService.ts         # Infrastructure: API coordination
-└── PreferenceService.ts      # System service: User preferences
+├── db/                 # Database layer
+│   ├── schemas/        # Drizzle table definitions
+│   ├── seeding/        # Database initialization
+│   └── DbService.ts    # Database connection management
+├── migration/          # Data migration system
+├── CacheService.ts     # Cache management
+├── DataApiService.ts   # API coordination
+└── PreferenceService.ts  # User preferences
 ```

-## 
Core Components - -### Naming Note - -Three components at the root of `data/` use the "Service" suffix but serve different purposes: - -#### CacheService (Infrastructure Component) -- **True Nature**: Cache Manager / Infrastructure Utility -- **Purpose**: Multi-tier caching system (memory/shared/persist) -- **Features**: TTL support, IPC synchronization, cross-window broadcasting -- **Characteristics**: Zero business logic, purely technical functionality -- **Note**: Named "Service" for management consistency, but is actually infrastructure - -#### DataApiService (Coordinator Component) -- **True Nature**: API Coordinator (Main) / API Client (Renderer) -- **Main Process Purpose**: Coordinates ApiServer and IpcAdapter initialization -- **Renderer Purpose**: HTTP-like client for IPC communication -- **Characteristics**: Zero business logic, purely coordination/communication plumbing -- **Note**: Named "Service" for management consistency, but is actually coordinator/client - -#### PreferenceService (System Service) -- **True Nature**: System-level Data Access Service -- **Purpose**: User configuration management with caching and multi-window sync -- **Features**: SQLite persistence, full memory cache, cross-window synchronization -- **Characteristics**: Minimal business logic (validation, defaults), primarily data access -- **Note**: Hybrid between data access and infrastructure, "Service" naming is acceptable - -**Key Takeaway**: Despite all being named "Service", these are infrastructure/coordination components, not business services. The "Service" suffix is kept for consistency with existing codebase conventions. - -## Architecture Layers - -### API Framework Layer (`api/`) - -The API framework provides the interface layer for data access: - -#### API Server (`api/core/ApiServer.ts`) -- Request routing and handler execution -- Middleware pipeline processing -- Type-safe endpoint definitions - -#### Handlers (`api/handlers/`) -- **Purpose**: Thin API endpoint implementations -- **Responsibilities**: - - HTTP-like parameter extraction from requests - - DTO/domain model conversion - - Delegating to business services - - Transforming responses for IPC -- **Anti-pattern**: Do NOT put business logic in handlers -- **Currently**: Contains test handlers (production handlers pending) -- **Type Safety**: Must implement all endpoints defined in `@shared/data/api` - -### Business Logic Layer (`services/`) - -Business services implement domain logic and workflows: - -#### When to Create a Service -- Contains business rules and validation -- Orchestrates multiple repositories or data sources -- Implements complex workflows -- Manages transactions across multiple operations - -#### Service Pattern - -Just an example for understanding. 
- -```typescript -// services/TopicService.ts -export class TopicService { - constructor( - private topicRepo: TopicRepository, // Use repository for complex data access - private cacheService: CacheService // Use infrastructure utilities - ) {} - - async createTopicWithMessage(data: CreateTopicDto) { - // Business validation - this.validateTopicData(data) - - // Transaction coordination - return await DbService.transaction(async (tx) => { - const topic = await this.topicRepo.create(data.topic, tx) - const message = await this.messageRepo.create(data.message, tx) - return { topic, message } - }) - } -} -``` - -#### Current Services -- `TestService`: Placeholder service for testing API framework -- More business services will be added as needed (TopicService, MessageService, etc.) - -### Data Access Layer (`repositories/`) - -Repositories handle database operations with a **selective usage pattern**: - -#### When to Use Repository Pattern -Use repositories for **complex domains** that meet multiple criteria: -- ✅ Complex queries (joins, subqueries, aggregations) -- ✅ GB-scale data requiring optimization and pagination -- ✅ Complex transactions involving multiple tables -- ✅ Reusable data access patterns across services -- ✅ High testing requirements (mock data access in tests) - -#### When to Use Direct Drizzle in Services -Skip repository layer for **simple domains**: -- ✅ Simple CRUD operations -- ✅ Small datasets (< 100MB) -- ✅ Domain-specific queries with no reuse potential -- ✅ Fast development is priority - -#### Repository Pattern - -Just an example for understanding. - -```typescript -// repositories/TopicRepository.ts -export class TopicRepository { - async findById(id: string, tx?: Transaction): Promise { - const db = tx || DbService.db - return await db.select() - .from(topicTable) - .where(eq(topicTable.id, id)) - .limit(1) - } - - async findByIdWithMessages( - topicId: string, - pagination: PaginationOptions - ): Promise { - // Complex join query with pagination - // Handles GB-scale data efficiently - } -} -``` - -#### Direct Drizzle Pattern (Simple Services) -```typescript -// services/SimpleService.ts -export class SimpleService extends BaseService { - async getItem(id: string) { - // Direct Drizzle query for simple operations - return await this.database - .select() - .from(itemTable) - .where(eq(itemTable.id, id)) - } -} -``` - -#### Planned Repositories -- **TopicRepository**: Complex topic data access with message relationships -- **MessageRepository**: GB-scale message queries with pagination -- **FileRepository**: File reference counting and cleanup logic - -**Decision Principle**: Use the simplest approach that solves the problem. Add repository abstraction only when complexity demands it. 
- -## Database Layer - -### DbService -- SQLite database connection management -- Automatic migrations and seeding -- Drizzle ORM integration - -### Schemas (`db/schemas/`) -- Table definitions using Drizzle ORM -- Follow naming convention: `{entity}Table` exports -- Use `crudTimestamps` helper for timestamp fields - -### Current Tables -- `preference`: User configuration storage -- `appState`: Application state persistence - -## Usage Examples - -### Accessing Services -```typescript -// Get service instances -import { cacheService } from '@/data/CacheService' -import { preferenceService } from '@/data/PreferenceService' -import { dataApiService } from '@/data/DataApiService' - -// Services are singletons, initialized at app startup -``` +## Quick Reference ### Adding New API Endpoints -1. Define endpoint in `@shared/data/api/apiSchemas.ts` -2. Implement handler in `api/handlers/index.ts` (thin layer, delegate to service) -3. Create business service in `services/` for domain logic -4. Create repository in `repositories/` if domain is complex (optional) -5. Add database schema in `db/schemas/` if required -### Adding Database Tables -1. Create schema in `db/schemas/{tableName}.ts` -2. Generate migration: `yarn run migrations:generate` -3. Add seeding data in `db/seeding/` if needed -4. Decide: Repository pattern or direct Drizzle? - - Complex domain → Create repository in `repositories/` - - Simple domain → Use direct Drizzle in service -5. Create business service in `services/` -6. Implement API handler in `api/handlers/` +1. Define schema in `@shared/data/api/schemas/` +2. Implement handler in `api/handlers/` +3. Create business service in `services/` +4. Create repository in `repositories/` (if complex domain) -### Creating a New Business Service +### Database Commands -**For complex domains (with repository)**: -```typescript -// 1. Create repository: repositories/ExampleRepository.ts -export class ExampleRepository { - async findById(id: string, tx?: Transaction) { /* ... */ } - async create(data: CreateDto, tx?: Transaction) { /* ... */ } -} - -// 2. Create service: services/ExampleService.ts -export class ExampleService { - constructor(private exampleRepo: ExampleRepository) {} - - async createExample(data: CreateDto) { - // Business validation - this.validate(data) - - // Use repository - return await this.exampleRepo.create(data) - } -} - -// 3. Create handler: api/handlers/example.ts -import { ExampleService } from '../../services/ExampleService' - -export const exampleHandlers = { - 'POST /examples': async ({ body }) => { - return await ExampleService.getInstance().createExample(body) - } -} +```bash +# Generate migrations +yarn db:migrations:generate ``` - -**For simple domains (direct Drizzle)**: -```typescript -// 1. Create service: services/SimpleService.ts -export class SimpleService extends BaseService { - async getItem(id: string) { - // Direct database access - return await this.database - .select() - .from(itemTable) - .where(eq(itemTable.id, id)) - } -} - -// 2. 
Create handler: api/handlers/simple.ts -export const simpleHandlers = { - 'GET /items/:id': async ({ params }) => { - return await SimpleService.getInstance().getItem(params.id) - } -} -``` - -## Data Flow - -### Complete Request Flow - -``` -┌─────────────────────────────────────────────────────┐ -│ Renderer Process │ -│ React Component → useDataApi Hook │ -└────────────────┬────────────────────────────────────┘ - │ IPC Request -┌────────────────▼────────────────────────────────────┐ -│ Infrastructure Layer │ -│ DataApiService (coordinator) │ -│ ↓ │ -│ ApiServer (routing) → MiddlewareEngine │ -└────────────────┬────────────────────────────────────┘ - │ -┌────────────────▼────────────────────────────────────┐ -│ API Layer (api/handlers/) │ -│ Handler: Thin layer │ -│ - Extract parameters │ -│ - Call business service │ -│ - Transform response │ -└────────────────┬────────────────────────────────────┘ - │ -┌────────────────▼────────────────────────────────────┐ -│ Business Logic Layer (services/) │ -│ Service: Domain logic │ -│ - Business validation │ -│ - Transaction coordination │ -│ - Call repository or direct DB │ -└────────────────┬────────────────────────────────────┘ - │ - ┌──────────┴──────────┐ - │ │ -┌─────▼─────────┐ ┌──────▼──────────────────────────┐ -│ repositories/ │ │ Direct Drizzle │ -│ (Complex) │ │ (Simple domains) │ -│ - Repository │ │ - Service uses DbService.db │ -│ - Query logic │ │ - Inline queries │ -└─────┬─────────┘ └──────┬──────────────────────────┘ - │ │ - └──────────┬─────────┘ - │ -┌────────────────▼────────────────────────────────────┐ -│ Database Layer (db/) │ -│ DbService → SQLite (Drizzle ORM) │ -└─────────────────────────────────────────────────────┘ -``` - -### Architecture Principles - -1. **Separation of Concerns** - - Handlers: Request/response transformation only - - Services: Business logic and orchestration - - Repositories: Data access (when complexity demands it) - -2. **Dependency Flow** (top to bottom only) - - Handlers depend on Services - - Services depend on Repositories (or DbService directly) - - Repositories depend on DbService - - **Never**: Services depend on Handlers - - **Never**: Repositories contain business logic - -3. 
**Selective Repository Usage** - - Use Repository: Complex domains (Topic, Message, File) - - Direct Drizzle: Simple domains (Agent, Session, Translate) - - Decision based on: query complexity, data volume, testing needs - -## Development Guidelines - -- All services use singleton pattern -- Database operations must be type-safe (Drizzle) -- API endpoints require complete type definitions -- Services should handle errors gracefully -- Use existing logging system (`@logger`) - -## Integration Points - -- **IPC Communication**: All services expose IPC handlers for renderer communication -- **Type Safety**: Shared types in `@shared/data` ensure end-to-end type safety -- **Error Handling**: Standardized error codes and handling across all services -- **Logging**: Comprehensive logging for debugging and monitoring \ No newline at end of file diff --git a/src/main/data/api/core/ApiServer.ts b/src/main/data/api/core/ApiServer.ts index 038f7cc1b8..50ee6cd02e 100644 --- a/src/main/data/api/core/ApiServer.ts +++ b/src/main/data/api/core/ApiServer.ts @@ -1,14 +1,15 @@ import { loggerService } from '@logger' -import type { ApiImplementation } from '@shared/data/api/apiSchemas' +import type { RequestContext as ErrorRequestContext } from '@shared/data/api/apiErrors' +import { DataApiError, DataApiErrorFactory, toDataApiError } from '@shared/data/api/apiErrors' +import type { ApiImplementation } from '@shared/data/api/apiTypes' import type { DataRequest, DataResponse, HttpMethod, RequestContext } from '@shared/data/api/apiTypes' -import { DataApiErrorFactory, ErrorCode } from '@shared/data/api/errorCodes' import { MiddlewareEngine } from './MiddlewareEngine' // Handler function type type HandlerFunction = (params: { params?: Record; query?: any; body?: any }) => Promise -const logger = loggerService.withContext('DataApiServer') +const logger = loggerService.withContext('DataApi:Server') /** * Core API Server - Transport agnostic request processor @@ -59,6 +60,14 @@ export class ApiServer { const { method, path } = request const startTime = Date.now() + // Build error request context for tracking + const errorContext: ErrorRequestContext = { + requestId: request.id, + path, + method: method as HttpMethod, + timestamp: startTime + } + logger.debug(`Processing request: ${method} ${path}`) try { @@ -66,7 +75,7 @@ export class ApiServer { const handlerMatch = this.findHandler(path, method as HttpMethod) if (!handlerMatch) { - throw DataApiErrorFactory.create(ErrorCode.NOT_FOUND, `Handler not found: ${method} ${path}`) + throw DataApiErrorFactory.notFound('Handler', `${method} ${path}`, errorContext) } // Create request context @@ -91,12 +100,13 @@ export class ApiServer { } catch (error) { logger.error(`Request handling failed: ${method} ${path}`, error as Error) - const apiError = DataApiErrorFactory.create(ErrorCode.INTERNAL_SERVER_ERROR, (error as Error).message) + // Convert to DataApiError and serialize for IPC + const apiError = error instanceof DataApiError ? 
error : toDataApiError(error, `${method} ${path}`) return { id: request.id, status: apiError.status, - error: apiError, + error: apiError.toJSON(), // Serialize for IPC transmission metadata: { duration: Date.now() - startTime, timestamp: Date.now() @@ -105,37 +115,6 @@ export class ApiServer { } } - /** - * Handle batch requests - */ - async handleBatchRequest(batchRequest: DataRequest): Promise { - const requests = batchRequest.body?.requests || [] - - if (!Array.isArray(requests)) { - throw DataApiErrorFactory.create(ErrorCode.VALIDATION_ERROR, 'Batch request body must contain requests array') - } - - logger.debug(`Processing batch request with ${requests.length} requests`) - - // Use the batch handler from our handlers - const batchHandler = this.handlers['/batch']?.POST - if (!batchHandler) { - throw DataApiErrorFactory.create(ErrorCode.NOT_FOUND, 'Batch handler not found') - } - - const result = await batchHandler({ body: batchRequest.body }) - - return { - id: batchRequest.id, - status: 200, - data: result, - metadata: { - duration: 0, - timestamp: Date.now() - } - } - } - /** * Find handler for given path and method */ diff --git a/src/main/data/api/core/MiddlewareEngine.ts b/src/main/data/api/core/MiddlewareEngine.ts index 1f6bf1915d..f1bf3c90b7 100644 --- a/src/main/data/api/core/MiddlewareEngine.ts +++ b/src/main/data/api/core/MiddlewareEngine.ts @@ -1,8 +1,8 @@ import { loggerService } from '@logger' +import { toDataApiError } from '@shared/data/api/apiErrors' import type { DataRequest, DataResponse, Middleware, RequestContext } from '@shared/data/api/apiTypes' -import { toDataApiError } from '@shared/data/api/errorCodes' -const logger = loggerService.withContext('MiddlewareEngine') +const logger = loggerService.withContext('DataApi:MiddlewareEngine') /** * Middleware engine for executing middleware chains @@ -82,7 +82,7 @@ export class MiddlewareEngine { logger.error(`Request error: ${req.method} ${req.path}`, error as Error) const apiError = toDataApiError(error, `${req.method} ${req.path}`) - res.error = apiError + res.error = apiError.toJSON() // Serialize for IPC transmission res.status = apiError.status } } diff --git a/src/main/data/api/core/adapters/IpcAdapter.ts b/src/main/data/api/core/adapters/IpcAdapter.ts index 7d17264388..d7d08fa6e1 100644 --- a/src/main/data/api/core/adapters/IpcAdapter.ts +++ b/src/main/data/api/core/adapters/IpcAdapter.ts @@ -1,12 +1,12 @@ import { loggerService } from '@logger' +import { toDataApiError } from '@shared/data/api/apiErrors' import type { DataRequest, DataResponse } from '@shared/data/api/apiTypes' -import { toDataApiError } from '@shared/data/api/errorCodes' import { IpcChannel } from '@shared/IpcChannel' import { ipcMain } from 'electron' import type { ApiServer } from '../ApiServer' -const logger = loggerService.withContext('DataApiIpcAdapter') +const logger = loggerService.withContext('DataApi:IpcAdapter') /** * IPC Adapter for Electron environment @@ -46,7 +46,7 @@ export class IpcAdapter { const errorResponse: DataResponse = { id: request.id, status: apiError.status, - error: apiError, + error: apiError.toJSON(), // Serialize for IPC transmission metadata: { duration: 0, timestamp: Date.now() @@ -57,55 +57,6 @@ export class IpcAdapter { } }) - // Batch request handler - ipcMain.handle(IpcChannel.DataApi_Batch, async (_event, batchRequest: DataRequest): Promise => { - try { - logger.debug('Handling batch request', { requestCount: batchRequest.body?.requests?.length }) - - const response = await 
this.apiServer.handleBatchRequest(batchRequest) - return response - } catch (error) { - logger.error('Batch request failed', error as Error) - - const apiError = toDataApiError(error, 'batch request') - return { - id: batchRequest.id, - status: apiError.status, - error: apiError, - metadata: { - duration: 0, - timestamp: Date.now() - } - } - } - }) - - // Transaction handler (placeholder) - ipcMain.handle( - IpcChannel.DataApi_Transaction, - async (_event, transactionRequest: DataRequest): Promise => { - try { - logger.debug('Handling transaction request') - - // TODO: Implement transaction support - throw new Error('Transaction support not yet implemented') - } catch (error) { - logger.error('Transaction request failed', error as Error) - - const apiError = toDataApiError(error, 'transaction request') - return { - id: transactionRequest.id, - status: apiError.status, - error: apiError, - metadata: { - duration: 0, - timestamp: Date.now() - } - } - } - } - ) - // Subscription handlers (placeholder for future real-time features) ipcMain.handle(IpcChannel.DataApi_Subscribe, async (_event, path: string) => { logger.debug(`Data subscription request: ${path}`) @@ -134,8 +85,6 @@ export class IpcAdapter { logger.debug('Removing IPC handlers...') ipcMain.removeHandler(IpcChannel.DataApi_Request) - ipcMain.removeHandler(IpcChannel.DataApi_Batch) - ipcMain.removeHandler(IpcChannel.DataApi_Transaction) ipcMain.removeHandler(IpcChannel.DataApi_Subscribe) ipcMain.removeHandler(IpcChannel.DataApi_Unsubscribe) diff --git a/src/main/data/api/handlers/index.ts b/src/main/data/api/handlers/index.ts index 817a882be8..87072fdfc0 100644 --- a/src/main/data/api/handlers/index.ts +++ b/src/main/data/api/handlers/index.ts @@ -1,210 +1,30 @@ /** - * Complete API handler implementation + * API Handlers Index * - * This file implements ALL endpoints defined in ApiSchemas. - * TypeScript will error if any endpoint is missing. + * Combines all domain-specific handlers into a unified apiHandlers object. + * TypeScript will error if any endpoint from ApiSchemas is missing. + * + * Handler files are organized by domain: + * - test.ts - Test API handlers + * - topics.ts - Topic API handlers + * - messages.ts - Message API handlers */ -import { TestService } from '@data/services/TestService' -import type { ApiImplementation } from '@shared/data/api/apiSchemas' +import type { ApiImplementation } from '@shared/data/api/apiTypes' -// Service instances -const testService = TestService.getInstance() +import { messageHandlers } from './messages' +import { testHandlers } from './test' +import { topicHandlers } from './topics' /** * Complete API handlers implementation * Must implement every path+method combination from ApiSchemas + * + * Handlers are spread from individual domain modules for maintainability. + * TypeScript ensures exhaustive coverage - missing handlers cause compile errors. 
*/ export const apiHandlers: ApiImplementation = { - '/test/items': { - GET: async ({ query }) => { - return await testService.getItems({ - page: (query as any)?.page, - limit: (query as any)?.limit, - search: (query as any)?.search, - type: (query as any)?.type, - status: (query as any)?.status - }) - }, - - POST: async ({ body }) => { - return await testService.createItem({ - title: body.title, - description: body.description, - type: body.type, - status: body.status, - priority: body.priority, - tags: body.tags, - metadata: body.metadata - }) - } - }, - - '/test/items/:id': { - GET: async ({ params }) => { - const item = await testService.getItemById(params.id) - if (!item) { - throw new Error(`Test item not found: ${params.id}`) - } - return item - }, - - PUT: async ({ params, body }) => { - const item = await testService.updateItem(params.id, { - title: body.title, - description: body.description, - type: body.type, - status: body.status, - priority: body.priority, - tags: body.tags, - metadata: body.metadata - }) - if (!item) { - throw new Error(`Test item not found: ${params.id}`) - } - return item - }, - - DELETE: async ({ params }) => { - const deleted = await testService.deleteItem(params.id) - if (!deleted) { - throw new Error(`Test item not found: ${params.id}`) - } - return undefined - } - }, - - '/test/search': { - GET: async ({ query }) => { - return await testService.searchItems(query.query, { - page: query.page, - limit: query.limit, - filters: { - type: query.type, - status: query.status - } - }) - } - }, - - '/test/stats': { - GET: async () => { - return await testService.getStats() - } - }, - - '/test/bulk': { - POST: async ({ body }) => { - return await testService.bulkOperation(body.operation, body.data) - } - }, - - '/test/error': { - POST: async ({ body }) => { - return await testService.simulateError(body.errorType) - } - }, - - '/test/slow': { - POST: async ({ body }) => { - const delay = body.delay - await new Promise((resolve) => setTimeout(resolve, delay)) - return { - message: `Slow response completed after ${delay}ms`, - delay, - timestamp: new Date().toISOString() - } - } - }, - - '/test/reset': { - POST: async () => { - await testService.resetData() - return { - message: 'Test data reset successfully', - timestamp: new Date().toISOString() - } - } - }, - - '/test/config': { - GET: async () => { - return { - environment: 'test', - version: '1.0.0', - debug: true, - features: { - bulkOperations: true, - search: true, - statistics: true - } - } - }, - - PUT: async ({ body }) => { - return { - ...body, - updated: true, - timestamp: new Date().toISOString() - } - } - }, - - '/test/status': { - GET: async () => { - return { - status: 'healthy', - timestamp: new Date().toISOString(), - version: '1.0.0', - uptime: Math.floor(process.uptime()), - environment: 'test' - } - } - }, - - '/test/performance': { - GET: async () => { - const memUsage = process.memoryUsage() - return { - requestsPerSecond: Math.floor(Math.random() * 100) + 50, - averageLatency: Math.floor(Math.random() * 200) + 50, - memoryUsage: memUsage.heapUsed / 1024 / 1024, // MB - cpuUsage: Math.random() * 100, - uptime: Math.floor(process.uptime()) - } - } - }, - - '/batch': { - POST: async ({ body }) => { - // Mock batch implementation - can be enhanced with actual batch processing - const { requests } = body - - const results = requests.map(() => ({ - status: 200, - data: { processed: true, timestamp: new Date().toISOString() } - })) - - return { - results, - metadata: { - duration: 
Math.floor(Math.random() * 500) + 100, - successCount: requests.length, - errorCount: 0 - } - } - } - }, - - '/transaction': { - POST: async ({ body }) => { - // Mock transaction implementation - can be enhanced with actual transaction support - const { operations } = body - - return operations.map(() => ({ - status: 200, - data: { executed: true, timestamp: new Date().toISOString() } - })) - } - } + ...testHandlers, + ...topicHandlers, + ...messageHandlers } diff --git a/src/main/data/api/handlers/messages.ts b/src/main/data/api/handlers/messages.ts new file mode 100644 index 0000000000..d89457ab5e --- /dev/null +++ b/src/main/data/api/handlers/messages.ts @@ -0,0 +1,75 @@ +/** + * Message API Handlers + * + * Implements all message-related API endpoints including: + * - Tree visualization queries + * - Branch message queries with pagination + * - Message CRUD operations + */ + +import { messageService } from '@data/services/MessageService' +import type { ApiHandler, ApiMethods } from '@shared/data/api/apiTypes' +import type { + ActiveNodeStrategy, + BranchMessagesQueryParams, + MessageSchemas, + TreeQueryParams +} from '@shared/data/api/schemas/messages' + +/** + * Handler type for a specific message endpoint + */ +type MessageHandler> = ApiHandler + +/** + * Message API handlers implementation + */ +export const messageHandlers: { + [Path in keyof MessageSchemas]: { + [Method in keyof MessageSchemas[Path]]: MessageHandler> + } +} = { + '/topics/:topicId/tree': { + GET: async ({ params, query }) => { + const q = (query || {}) as TreeQueryParams + return await messageService.getTree(params.topicId, { + rootId: q.rootId, + nodeId: q.nodeId, + depth: q.depth + }) + } + }, + + '/topics/:topicId/messages': { + GET: async ({ params, query }) => { + const q = (query || {}) as BranchMessagesQueryParams + return await messageService.getBranchMessages(params.topicId, { + nodeId: q.nodeId, + cursor: q.cursor, + limit: q.limit, + includeSiblings: q.includeSiblings + }) + }, + + POST: async ({ params, body }) => { + return await messageService.create(params.topicId, body) + } + }, + + '/messages/:id': { + GET: async ({ params }) => { + return await messageService.getById(params.id) + }, + + PATCH: async ({ params, body }) => { + return await messageService.update(params.id, body) + }, + + DELETE: async ({ params, query }) => { + const q = (query || {}) as { cascade?: boolean; activeNodeStrategy?: ActiveNodeStrategy } + const cascade = q.cascade ?? false + const activeNodeStrategy = q.activeNodeStrategy ?? 'parent' + return await messageService.delete(params.id, cascade, activeNodeStrategy) + } + } +} diff --git a/src/main/data/api/handlers/test.ts b/src/main/data/api/handlers/test.ts new file mode 100644 index 0000000000..a9522cf2a9 --- /dev/null +++ b/src/main/data/api/handlers/test.ts @@ -0,0 +1,185 @@ +/** + * Test API Handlers + * + * Implements all test-related API endpoints for development and testing purposes. 
+ */ + +import { TestService } from '@data/services/TestService' +import type { ApiHandler, ApiMethods } from '@shared/data/api/apiTypes' +import type { TestSchemas } from '@shared/data/api/schemas/test' + +// Service instance +const testService = TestService.getInstance() + +/** + * Handler type for a specific test endpoint + */ +type TestHandler> = ApiHandler + +/** + * Test API handlers implementation + */ +export const testHandlers: { + [Path in keyof TestSchemas]: { + [Method in keyof TestSchemas[Path]]: TestHandler> + } +} = { + '/test/items': { + GET: async ({ query }) => { + return await testService.getItems({ + page: (query as any)?.page, + limit: (query as any)?.limit, + search: (query as any)?.search, + type: (query as any)?.type, + status: (query as any)?.status + }) + }, + + POST: async ({ body }) => { + return await testService.createItem({ + title: body.title, + description: body.description, + type: body.type, + status: body.status, + priority: body.priority, + tags: body.tags, + metadata: body.metadata + }) + } + }, + + '/test/items/:id': { + GET: async ({ params }) => { + const item = await testService.getItemById(params.id) + if (!item) { + throw new Error(`Test item not found: ${params.id}`) + } + return item + }, + + PUT: async ({ params, body }) => { + const item = await testService.updateItem(params.id, { + title: body.title, + description: body.description, + type: body.type, + status: body.status, + priority: body.priority, + tags: body.tags, + metadata: body.metadata + }) + if (!item) { + throw new Error(`Test item not found: ${params.id}`) + } + return item + }, + + DELETE: async ({ params }) => { + const deleted = await testService.deleteItem(params.id) + if (!deleted) { + throw new Error(`Test item not found: ${params.id}`) + } + return undefined + } + }, + + '/test/search': { + GET: async ({ query }) => { + return await testService.searchItems(query.query, { + page: query.page, + limit: query.limit, + filters: { + type: query.type, + status: query.status + } + }) + } + }, + + '/test/stats': { + GET: async () => { + return await testService.getStats() + } + }, + + '/test/bulk': { + POST: async ({ body }) => { + return await testService.bulkOperation(body.operation, body.data) + } + }, + + '/test/error': { + POST: async ({ body }) => { + return await testService.simulateError(body.errorType) + } + }, + + '/test/slow': { + POST: async ({ body }) => { + const delay = body.delay + await new Promise((resolve) => setTimeout(resolve, delay)) + return { + message: `Slow response completed after ${delay}ms`, + delay, + timestamp: new Date().toISOString() + } + } + }, + + '/test/reset': { + POST: async () => { + await testService.resetData() + return { + message: 'Test data reset successfully', + timestamp: new Date().toISOString() + } + } + }, + + '/test/config': { + GET: async () => { + return { + environment: 'test', + version: '1.0.0', + debug: true, + features: { + bulkOperations: true, + search: true, + statistics: true + } + } + }, + + PUT: async ({ body }) => { + return { + ...body, + updated: true, + timestamp: new Date().toISOString() + } + } + }, + + '/test/status': { + GET: async () => { + return { + status: 'healthy', + timestamp: new Date().toISOString(), + version: '1.0.0', + uptime: Math.floor(process.uptime()), + environment: 'test' + } + } + }, + + '/test/performance': { + GET: async () => { + const memUsage = process.memoryUsage() + return { + requestsPerSecond: Math.floor(Math.random() * 100) + 50, + averageLatency: Math.floor(Math.random() * 200) + 
50, + memoryUsage: memUsage.heapUsed / 1024 / 1024, // MB + cpuUsage: Math.random() * 100, + uptime: Math.floor(process.uptime()) + } + } + } +} diff --git a/src/main/data/api/handlers/topics.ts b/src/main/data/api/handlers/topics.ts new file mode 100644 index 0000000000..45fbabac1b --- /dev/null +++ b/src/main/data/api/handlers/topics.ts @@ -0,0 +1,52 @@ +/** + * Topic API Handlers + * + * Implements all topic-related API endpoints including: + * - Topic CRUD operations + * - Active node switching for branch navigation + */ + +import { topicService } from '@data/services/TopicService' +import type { ApiHandler, ApiMethods } from '@shared/data/api/apiTypes' +import type { TopicSchemas } from '@shared/data/api/schemas/topics' + +/** + * Handler type for a specific topic endpoint + */ +type TopicHandler> = ApiHandler + +/** + * Topic API handlers implementation + */ +export const topicHandlers: { + [Path in keyof TopicSchemas]: { + [Method in keyof TopicSchemas[Path]]: TopicHandler> + } +} = { + '/topics': { + POST: async ({ body }) => { + return await topicService.create(body) + } + }, + + '/topics/:id': { + GET: async ({ params }) => { + return await topicService.getById(params.id) + }, + + PATCH: async ({ params, body }) => { + return await topicService.update(params.id, body) + }, + + DELETE: async ({ params }) => { + await topicService.delete(params.id) + return undefined + } + }, + + '/topics/:id/active-node': { + PUT: async ({ params, body }) => { + return await topicService.setActiveNode(params.id, body.nodeId) + } + } +} diff --git a/src/main/data/api/index.ts b/src/main/data/api/index.ts index 1bd4d3b7a5..d2cf988727 100644 --- a/src/main/data/api/index.ts +++ b/src/main/data/api/index.ts @@ -20,13 +20,18 @@ export { apiHandlers } from './handlers' export { TestService } from '@data/services/TestService' // Re-export types for convenience -export type { CreateTestItemDto, TestItem, UpdateTestItemDto } from '@shared/data/api' export type { + CursorPaginationParams, + CursorPaginationResponse, DataRequest, DataResponse, Middleware, - PaginatedResponse, - PaginationParams, + OffsetPaginationParams, + OffsetPaginationResponse, + PaginationResponse, RequestContext, - ServiceOptions + SearchParams, + ServiceOptions, + SortParams } from '@shared/data/api/apiTypes' +export type { CreateTestItemDto, TestItem, UpdateTestItemDto } from '@shared/data/api/schemas/test' diff --git a/src/main/data/db/DbService.ts b/src/main/data/db/DbService.ts index 8a7edb6f33..de72be03dd 100644 --- a/src/main/data/db/DbService.ts +++ b/src/main/data/db/DbService.ts @@ -6,6 +6,7 @@ import { app } from 'electron' import path from 'path' import { pathToFileURL } from 'url' +import { CUSTOM_SQL_STATEMENTS } from './customSql' import Seeding from './seeding' import type { DbType } from './types' @@ -120,6 +121,9 @@ class DbService { const migrationsFolder = this.getMigrationsFolder() await migrate(this.db, { migrationsFolder }) + // Run custom SQL that Drizzle cannot manage (triggers, virtual tables, etc.) + await this.runCustomMigrations() + logger.info('Database migration completed successfully') } catch (error) { logger.error('Database migration failed', error as Error) @@ -127,6 +131,27 @@ class DbService { } } + /** + * Run custom SQL statements that Drizzle cannot manage + * + * This includes triggers, virtual tables, and other SQL objects. + * Called after every migration because: + * 1. Drizzle doesn't track these in schema + * 2. DROP TABLE removes associated triggers + * 3. 
All statements use IF NOT EXISTS, so they're idempotent
+   */
+  private async runCustomMigrations(): Promise<void> {
+    try {
+      for (const statement of CUSTOM_SQL_STATEMENTS) {
+        await this.db.run(sql.raw(statement))
+      }
+      logger.debug('Custom migrations completed', { count: CUSTOM_SQL_STATEMENTS.length })
+    } catch (error) {
+      logger.error('Custom migrations failed', error as Error)
+      throw error
+    }
+  }
+
   /**
    * Get the database instance
    * @throws {Error} If database is not initialized
diff --git a/src/main/data/db/README.md b/src/main/data/db/README.md
index 8bc38b01c4..2a07bd5d43 100644
--- a/src/main/data/db/README.md
+++ b/src/main/data/db/README.md
@@ -1,2 +1,52 @@
-- All the database table names use **singular** form, snake_casing
-- Export table names use `xxxxTable`
+# Database Layer
+
+This directory contains database schemas and configuration.
+
+## Documentation
+
+- **Database Patterns**: [docs/en/references/data/database-patterns.md](../../../../docs/en/references/data/database-patterns.md)
+
+## Directory Structure
+
+```
+src/main/data/db/
+├── schemas/            # Drizzle table definitions
+│   ├── columnHelpers.ts   # Reusable column definitions
+│   ├── topic.ts           # Topic table
+│   ├── message.ts         # Message table
+│   ├── messageFts.ts      # FTS5 virtual table & triggers
+│   └── ...                # Other tables
+├── seeding/            # Database initialization
+├── customSql.ts        # Custom SQL (triggers, virtual tables, etc.)
+└── DbService.ts        # Database connection management
+```
+
+## Quick Reference
+
+### Naming Conventions
+
+- **Table names**: Singular snake_case (`topic`, `message`, `app_state`)
+- **Export names**: `xxxTable` pattern (`topicTable`, `messageTable`)
+
+### Common Commands
+
+```bash
+# Generate migrations after schema changes
+yarn db:migrations:generate
+```
+
+### Custom SQL (Triggers, Virtual Tables)
+
+Drizzle cannot manage triggers and virtual tables. See `customSql.ts` for how these are handled.
+
+### Column Helpers
+
+```typescript
+import { sqliteTable, text } from 'drizzle-orm/sqlite-core'
+
+import { uuidPrimaryKey, createUpdateTimestamps } from './columnHelpers'
+
+export const myTable = sqliteTable('my_table', {
+  id: uuidPrimaryKey(),
+  name: text(),
+  ...createUpdateTimestamps
+})
+```
diff --git a/src/main/data/db/customSql.ts b/src/main/data/db/customSql.ts
new file mode 100644
index 0000000000..eaeea28db2
--- /dev/null
+++ b/src/main/data/db/customSql.ts
@@ -0,0 +1,25 @@
+/**
+ * Custom SQL statements that Drizzle cannot manage
+ *
+ * Drizzle ORM doesn't track:
+ * - Virtual tables (FTS5)
+ * - Triggers
+ * - Custom indexes with expressions
+ *
+ * These are executed after every migration via DbService.runCustomMigrations()
+ * All statements must be idempotent (use IF NOT EXISTS, etc.)
+ *
+ * To add new custom SQL:
+ * 1. Create statements in the relevant schema file (e.g., messageFts.ts)
+ * 2.
Import and spread them into CUSTOM_SQL_STATEMENTS below + */ + +import { MESSAGE_FTS_STATEMENTS } from './schemas/messageFts' + +/** + * All custom SQL statements to run after migrations + */ +export const CUSTOM_SQL_STATEMENTS: string[] = [ + ...MESSAGE_FTS_STATEMENTS + // Add more custom SQL arrays here as needed +] diff --git a/src/main/data/db/schemas/columnHelpers.ts b/src/main/data/db/schemas/columnHelpers.ts index 7623afd0ed..61a596602d 100644 --- a/src/main/data/db/schemas/columnHelpers.ts +++ b/src/main/data/db/schemas/columnHelpers.ts @@ -1,4 +1,32 @@ -import { integer } from 'drizzle-orm/sqlite-core' +/** + * Column helper utilities for Drizzle schemas + * + * USAGE RULES: + * - DO NOT manually set id, createdAt, or updatedAt - they are auto-generated + * - Use .returning() to get inserted/updated rows instead of re-querying + * - See db/README.md for detailed field generation rules + */ + +import { integer, text } from 'drizzle-orm/sqlite-core' +import { v4 as uuidv4, v7 as uuidv7 } from 'uuid' + +/** + * UUID v4 primary key with auto-generation + * Use for general purpose tables + */ +export const uuidPrimaryKey = () => + text() + .primaryKey() + .$defaultFn(() => uuidv4()) + +/** + * UUID v7 primary key with auto-generation (time-ordered) + * Use for tables with large datasets that benefit from sequential inserts + */ +export const uuidPrimaryKeyOrdered = () => + text() + .primaryKey() + .$defaultFn(() => uuidv7()) const createTimestamp = () => { return Date.now() diff --git a/src/main/data/db/schemas/entityTag.ts b/src/main/data/db/schemas/entityTag.ts new file mode 100644 index 0000000000..e041d771db --- /dev/null +++ b/src/main/data/db/schemas/entityTag.ts @@ -0,0 +1,26 @@ +import { index, primaryKey, sqliteTable, text } from 'drizzle-orm/sqlite-core' + +import { createUpdateTimestamps } from './columnHelpers' +import { tagTable } from './tag' + +/** + * Entity-Tag join table - associates tags with entities + * + * Supports many-to-many relationship between tags and + * various entity types (topic, session, assistant). + */ +export const entityTagTable = sqliteTable( + 'entity_tag', + { + // Entity type: topic, session, assistant + entityType: text().notNull(), + // FK to the entity + entityId: text().notNull(), + // FK to tag table - CASCADE: delete association when tag is deleted + tagId: text() + .notNull() + .references(() => tagTable.id, { onDelete: 'cascade' }), + ...createUpdateTimestamps + }, + (t) => [primaryKey({ columns: [t.entityType, t.entityId, t.tagId] }), index('entity_tag_tag_id_idx').on(t.tagId)] +) diff --git a/src/main/data/db/schemas/group.ts b/src/main/data/db/schemas/group.ts new file mode 100644 index 0000000000..6ef06c522f --- /dev/null +++ b/src/main/data/db/schemas/group.ts @@ -0,0 +1,24 @@ +import { index, integer, sqliteTable, text } from 'drizzle-orm/sqlite-core' + +import { createUpdateTimestamps, uuidPrimaryKey } from './columnHelpers' + +/** + * Group table - general-purpose grouping for entities + * + * Supports grouping of topics, sessions, and assistants. + * Each group belongs to a specific entity type. 
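+ *
+ * Hypothetical usage sketch (assumes eq from drizzle-orm): listing the topic
+ * groups in display order:
+ *
+ *   db.select().from(groupTable)
+ *     .where(eq(groupTable.entityType, 'topic'))
+ *     .orderBy(groupTable.sortOrder)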
+ */
+export const groupTable = sqliteTable(
+  'group',
+  {
+    id: uuidPrimaryKey(),
+    // Entity type this group belongs to: topic, session, assistant
+    entityType: text().notNull(),
+    // Display name of the group
+    name: text().notNull(),
+    // Sort order for display
+    sortOrder: integer().default(0),
+    ...createUpdateTimestamps
+  },
+  (t) => [index('group_entity_sort_idx').on(t.entityType, t.sortOrder)]
+)
diff --git a/src/main/data/db/schemas/message.ts b/src/main/data/db/schemas/message.ts
new file mode 100644
index 0000000000..3118433f6c
--- /dev/null
+++ b/src/main/data/db/schemas/message.ts
@@ -0,0 +1,65 @@
+import type { MessageData, MessageStats } from '@shared/data/types/message'
+import type { AssistantMeta, ModelMeta } from '@shared/data/types/meta'
+import { sql } from 'drizzle-orm'
+import { check, foreignKey, index, integer, sqliteTable, text } from 'drizzle-orm/sqlite-core'
+
+import { createUpdateDeleteTimestamps, uuidPrimaryKeyOrdered } from './columnHelpers'
+import { topicTable } from './topic'
+
+/**
+ * Message table - stores chat messages with tree structure
+ *
+ * Uses adjacency list pattern (parentId) for tree navigation.
+ * Block content is stored as JSON in the data field.
+ * searchableText is populated by triggers and drives the FTS5 index.
+ */
+export const messageTable = sqliteTable(
+  'message',
+  {
+    id: uuidPrimaryKeyOrdered(),
+    // Adjacency list parent reference for tree structure
+    parentId: text(),
+    // FK to topic - CASCADE: delete messages when topic is deleted
+    topicId: text()
+      .notNull()
+      .references(() => topicTable.id, { onDelete: 'cascade' }),
+    // Message role: user, assistant, system
+    role: text().notNull(),
+    // Main content - contains blocks[], mentions, etc.
+    data: text({ mode: 'json' }).$type<MessageData>().notNull(),
+    // Searchable text extracted from data.blocks (populated by trigger, used for FTS5)
+    searchableText: text(),
+
+    // Message status: pending, success, error, paused (matches the check constraint below)
+    status: text().notNull(),
+
+    // Group ID for siblings (0 = normal branch)
+    siblingsGroupId: integer().default(0),
+    // FK to assistant
+    assistantId: text(),
+    // Preserved assistant info for display
+    assistantMeta: text({ mode: 'json' }).$type<AssistantMeta>(),
+    // Model identifier
+    modelId: text(),
+    // Preserved model info (provider, name)
+    modelMeta: text({ mode: 'json' }).$type<ModelMeta>(),
+    // Trace ID for tracking
+    traceId: text(),
+    // Statistics: token usage, performance metrics, etc.
+    stats: text({ mode: 'json' }).$type<MessageStats>(),
+
+    ...createUpdateDeleteTimestamps
+  },
+  (t) => [
+    // Foreign keys
+    foreignKey({ columns: [t.parentId], foreignColumns: [t.id] }).onDelete('set null'),
+    // Indexes
+    index('message_parent_id_idx').on(t.parentId),
+    index('message_topic_created_idx').on(t.topicId, t.createdAt),
+    index('message_trace_id_idx').on(t.traceId),
+    // Check constraints for enum fields
+    check('message_role_check', sql`${t.role} IN ('user', 'assistant', 'system')`),
+    check('message_status_check', sql`${t.status} IN ('pending', 'success', 'error', 'paused')`)
+  ]
+)
diff --git a/src/main/data/db/schemas/messageFts.ts b/src/main/data/db/schemas/messageFts.ts
new file mode 100644
index 0000000000..ccffbb5eaf
--- /dev/null
+++ b/src/main/data/db/schemas/messageFts.ts
@@ -0,0 +1,86 @@
+/**
+ * FTS5 SQL statements for message full-text search
+ *
+ * Drizzle does not auto-generate virtual tables or triggers, so the statements
+ * in this file are registered in customSql.ts and executed automatically after
+ * every migration via DbService.runCustomMigrations().
+ *
+ * Architecture:
+ * 1. message.searchable_text - regular column populated by trigger
+ * 2. message_fts - FTS5 virtual table with external content
+ * 3. Triggers sync both searchable_text and FTS5 index
+ *
+ * Usage:
+ * - Add new statement arrays to CUSTOM_SQL_STATEMENTS in customSql.ts;
+ *   no manual editing of migration files is required
+ */
+
+/**
+ * SQL expression to extract searchable text from data.blocks
+ * Concatenates content from all main_text type blocks
+ */
+export const SEARCHABLE_TEXT_EXPRESSION = `
+  (SELECT group_concat(json_extract(value, '$.content'), ' ')
+   FROM json_each(json_extract(NEW.data, '$.blocks'))
+   WHERE json_extract(value, '$.type') = 'main_text')
+`
+
+/**
+ * Custom SQL statements that Drizzle cannot manage
+ * These are executed after every migration via DbService.runCustomMigrations()
+ *
+ * All statements should use IF NOT EXISTS to be idempotent.
+ */
+export const MESSAGE_FTS_STATEMENTS: string[] = [
+  // FTS5 virtual table, linked to the message table's searchable_text column
+  `CREATE VIRTUAL TABLE IF NOT EXISTS message_fts USING fts5(
+    searchable_text,
+    content='message',
+    content_rowid='rowid',
+    tokenize='trigram'
+  )`,
+
+  // Trigger: populate searchable_text and sync FTS on INSERT
+  `CREATE TRIGGER IF NOT EXISTS message_ai AFTER INSERT ON message BEGIN
+    UPDATE message SET searchable_text = (
+      SELECT group_concat(json_extract(value, '$.content'), ' ')
+      FROM json_each(json_extract(NEW.data, '$.blocks'))
+      WHERE json_extract(value, '$.type') = 'main_text'
+    ) WHERE id = NEW.id;
+    INSERT INTO message_fts(rowid, searchable_text)
+    SELECT rowid, searchable_text FROM message WHERE id = NEW.id;
+  END`,
+
+  // Trigger: sync FTS on DELETE
+  `CREATE TRIGGER IF NOT EXISTS message_ad AFTER DELETE ON message BEGIN
+    INSERT INTO message_fts(message_fts, rowid, searchable_text)
+    VALUES ('delete', OLD.rowid, OLD.searchable_text);
+  END`,
+
+  // Trigger: update searchable_text and sync FTS on UPDATE OF data
+  `CREATE TRIGGER IF NOT EXISTS message_au AFTER UPDATE OF data ON message BEGIN
+    INSERT INTO message_fts(message_fts, rowid, searchable_text)
+    VALUES ('delete', OLD.rowid, OLD.searchable_text);
+    UPDATE message SET searchable_text = (
+      SELECT group_concat(json_extract(value, '$.content'), ' ')
+      FROM json_each(json_extract(NEW.data, '$.blocks'))
+      WHERE json_extract(value, '$.type') = 'main_text'
+    ) WHERE id = NEW.id;
+    INSERT INTO message_fts(rowid, searchable_text)
+    SELECT rowid, searchable_text FROM message WHERE id = NEW.id;
+  END`
+]
+
+/**
+ * Rebuild FTS index (run manually if needed)
+ */
+export const REBUILD_FTS_SQL = `INSERT INTO message_fts(message_fts) VALUES ('rebuild')`
+
+/**
+ * Example search query
+ */
+export const EXAMPLE_SEARCH_SQL = `
+SELECT m.*
+FROM message m
+JOIN message_fts fts ON m.rowid = fts.rowid
+WHERE message_fts MATCH ?
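+-- note: '?' is bound to the search string; with the trigram tokenizer above,
+-- matching is substring-based and the query needs at least 3 characters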
+ORDER BY rank +` diff --git a/src/main/data/db/schemas/preference.ts b/src/main/data/db/schemas/preference.ts index f41cf175c4..5ca9b2f14a 100644 --- a/src/main/data/db/schemas/preference.ts +++ b/src/main/data/db/schemas/preference.ts @@ -1,14 +1,14 @@ -import { index, sqliteTable, text } from 'drizzle-orm/sqlite-core' +import { primaryKey, sqliteTable, text } from 'drizzle-orm/sqlite-core' import { createUpdateTimestamps } from './columnHelpers' export const preferenceTable = sqliteTable( 'preference', { - scope: text().notNull(), // scope is reserved for future use, now only 'default' is supported + scope: text().notNull().default('default'), // scope is reserved for future use, now only 'default' is supported key: text().notNull(), value: text({ mode: 'json' }), ...createUpdateTimestamps }, - (t) => [index('scope_name_idx').on(t.scope, t.key)] + (t) => [primaryKey({ columns: [t.scope, t.key] })] ) diff --git a/src/main/data/db/schemas/tag.ts b/src/main/data/db/schemas/tag.ts new file mode 100644 index 0000000000..87820fadf9 --- /dev/null +++ b/src/main/data/db/schemas/tag.ts @@ -0,0 +1,18 @@ +import { sqliteTable, text } from 'drizzle-orm/sqlite-core' + +import { createUpdateTimestamps, uuidPrimaryKey } from './columnHelpers' + +/** + * Tag table - general-purpose tags for entities + * + * Tags can be applied to topics, sessions, and assistants + * via the entity_tag join table. + */ +export const tagTable = sqliteTable('tag', { + id: uuidPrimaryKey(), + // Unique tag name + name: text().notNull().unique(), + // Display color (hex code) + color: text(), + ...createUpdateTimestamps +}) diff --git a/src/main/data/db/schemas/topic.ts b/src/main/data/db/schemas/topic.ts new file mode 100644 index 0000000000..68078d8f86 --- /dev/null +++ b/src/main/data/db/schemas/topic.ts @@ -0,0 +1,47 @@ +import type { AssistantMeta } from '@shared/data/types/meta' +import { index, integer, sqliteTable, text } from 'drizzle-orm/sqlite-core' + +import { createUpdateDeleteTimestamps, uuidPrimaryKey } from './columnHelpers' +import { groupTable } from './group' + +/** + * Topic table - stores conversation topics/threads + * + * Topics are containers for messages and belong to assistants. + * They can be organized into groups and have tags for categorization. 
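+ *
+ * Hypothetical query sketch (assumes eq from drizzle-orm): pinned topics for
+ * the sidebar, in pin order:
+ *
+ *   db.select().from(topicTable)
+ *     .where(eq(topicTable.isPinned, true))
+ *     .orderBy(topicTable.pinnedOrder)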
+ */
+export const topicTable = sqliteTable(
+  'topic',
+  {
+    id: uuidPrimaryKey(),
+    name: text(),
+    // Whether the name was manually edited by user
+    isNameManuallyEdited: integer({ mode: 'boolean' }).default(false),
+    // FK to assistant table
+    assistantId: text(),
+    // Preserved assistant info for display when assistant is deleted
+    assistantMeta: text({ mode: 'json' }).$type<AssistantMeta>(),
+    // Topic-specific prompt override
+    prompt: text(),
+    // Active node ID in the message tree
+    activeNodeId: text(),
+
+    // FK to group table for organization
+    // SET NULL: preserve topic when group is deleted
+    groupId: text().references(() => groupTable.id, { onDelete: 'set null' }),
+    // Sort order within group
+    sortOrder: integer().default(0),
+    // Pinning state and order
+    isPinned: integer({ mode: 'boolean' }).default(false),
+    pinnedOrder: integer().default(0),
+
+    ...createUpdateDeleteTimestamps
+  },
+  (t) => [
+    index('topic_group_updated_idx').on(t.groupId, t.updatedAt),
+    index('topic_group_sort_idx').on(t.groupId, t.sortOrder),
+    index('topic_updated_at_idx').on(t.updatedAt),
+    index('topic_is_pinned_idx').on(t.isPinned, t.pinnedOrder),
+    index('topic_assistant_id_idx').on(t.assistantId)
+  ]
+)
diff --git a/src/main/data/migration/v2/README.md b/src/main/data/migration/v2/README.md
index 86d597223e..6e5e071f7d 100644
--- a/src/main/data/migration/v2/README.md
+++ b/src/main/data/migration/v2/README.md
@@ -1,64 +1,33 @@
-# Migration V2 (Main Process)
+# Data Migration System
 
-Architecture for the new one-shot migration from the legacy Dexie + Redux Persist stores into the SQLite schema. This module owns orchestration, data access helpers, migrator plugins, and IPC entry points used by the renderer migration window.
+This directory contains the v2 data migration implementation.
 
-## Directory Layout
+## Documentation
+
+- **Migration Guide**: [docs/en/references/data/v2-migration-guide.md](../../../../../docs/en/references/data/v2-migration-guide.md)
+
+## Directory Structure
 
 ```
 src/main/data/migration/v2/
-├── core/       # Engine + shared context
-├── migrators/  # Domain-specific migrators and mappings
-├── utils/      # Data source readers (Redux, Dexie, streaming JSON)
-├── window/     # IPC handlers + migration window manager
-└── index.ts    # Public exports for main process
+├── core/       # MigrationEngine, MigrationContext
+├── migrators/  # Domain-specific migrators
+│   └── mappings/  # Mapping definitions
+├── utils/      # ReduxStateReader, DexieFileReader, JSONStreamReader
+├── window/     # IPC handlers, window manager
+└── index.ts    # Public exports
```
-## Core Contracts
+## Quick Reference
 
-- `core/MigrationEngine.ts` coordinates all migrators in order, surfaces progress to the UI, and marks status in `app_state.key = 'migration_v2_status'`. It will clear new-schema tables before running and abort on any validation failure.
-- `core/MigrationContext.ts` builds the shared context passed to every migrator:
-  - `sources`: `ConfigManager` (ElectronStore), `ReduxStateReader` (parsed Redux Persist data), `DexieFileReader` (JSON exports)
-  - `db`: current SQLite connection
-  - `sharedData`: `Map` for passing cross-cutting info between migrators
-  - `logger`: `loggerService` scoped to migration
-- `@shared/data/migration/v2/types` defines stages, results, and validation stats used across main and renderer.
+### Creating a New Migrator
 
-## Migrators
+1. Extend `BaseMigrator` in `migrators/`
+2. Implement `prepare`, `execute`, `validate` methods
+3.
Register in `migrators/index.ts` -- Base contract: extend `migrators/BaseMigrator.ts` and implement: - - `id`, `name`, `description`, `order` (lower runs first) - - `prepare(ctx)`: dry-run checks, counts, and staging data; return `PrepareResult` - - `execute(ctx)`: perform inserts/updates; manage your own transactions; report progress via `reportProgress` - - `validate(ctx)`: verify counts and integrity; return `ValidateResult` with stats (`sourceCount`, `targetCount`, `skippedCount`) and any `errors` -- Registration: list migrators (in order) in `migrators/index.ts` so the engine can sort and run them. -- Current migrators: - - `PreferencesMigrator` (implemented): maps ElectronStore + Redux settings to the `preference` table using `mappings/PreferencesMappings.ts`. - - `AssistantMigrator`, `KnowledgeMigrator`, `ChatMigrator` (placeholders): scaffolding and TODO notes for future tables. -- Conventions: - - All logging goes through `loggerService` with a migrator-specific context. - - Use `MigrationContext.sources` instead of accessing raw files/stores directly. - - Use `sharedData` to pass IDs or lookup tables between migrators (e.g., assistant -> chat references) instead of re-reading sources. - - Stream large Dexie exports (`JSONStreamReader`) and batch inserts to avoid memory spikes. - - Count validation is mandatory; engine will fail the run if `targetCount < sourceCount - skippedCount` or if `ValidateResult.errors` is non-empty. - - Keep migrations idempotent per run—engine clears target tables before it starts, but each migrator should tolerate retries within the same run. +### Key Contracts -## Utilities - -- `utils/ReduxStateReader.ts`: safe accessor for categorized Redux Persist data with dot-path lookup. -- `utils/DexieFileReader.ts`: reads exported Dexie JSON tables; can stream large tables. -- `utils/JSONStreamReader.ts`: streaming reader with batching, counting, and sampling helpers for very large arrays. - -## Window & IPC Integration - -- `window/MigrationIpcHandler.ts` exposes IPC channels for the migration UI: - - Receives Redux data and Dexie export path, starts the engine, and streams progress back to renderer. - - Manages backup flow (dialogs via `BackupManager`) and retry/cancel/restart actions. -- `window/MigrationWindowManager.ts` creates the frameless migration window, handles lifecycle, and relaunch instructions after completion in production. - -## Implementation Checklist for New Migrators - -- [ ] Add mapping definitions (if needed) under `migrators/mappings/`. -- [ ] Implement `prepare/execute/validate` with explicit counts, batch inserts, and integrity checks. -- [ ] Wire progress updates through `reportProgress` so UI shows per-migrator progress. -- [ ] Register the migrator in `migrators/index.ts` with the correct `order`. -- [ ] Add any new target tables to `MigrationEngine.verifyAndClearNewTables` once those tables exist. 
+- `prepare(ctx)`: Dry-run checks, return counts +- `execute(ctx)`: Perform inserts, report progress +- `validate(ctx)`: Verify counts and integrity diff --git a/src/main/data/migration/v2/core/MigrationEngine.ts b/src/main/data/migration/v2/core/MigrationEngine.ts index 1b004d38e7..77bc4afd92 100644 --- a/src/main/data/migration/v2/core/MigrationEngine.ts +++ b/src/main/data/migration/v2/core/MigrationEngine.ts @@ -5,7 +5,9 @@ import { dbService } from '@data/db/DbService' import { appStateTable } from '@data/db/schemas/appState' +import { messageTable } from '@data/db/schemas/message' import { preferenceTable } from '@data/db/schemas/preference' +import { topicTable } from '@data/db/schemas/topic' import { loggerService } from '@logger' import type { MigrationProgress, @@ -24,8 +26,6 @@ import { createMigrationContext } from './MigrationContext' // TODO: Import these tables when they are created in user data schema // import { assistantTable } from '../../db/schemas/assistant' -// import { topicTable } from '../../db/schemas/topic' -// import { messageTable } from '../../db/schemas/message' // import { fileTable } from '../../db/schemas/file' // import { knowledgeBaseTable } from '../../db/schemas/knowledgeBase' @@ -197,12 +197,13 @@ export class MigrationEngine { const db = dbService.getDb() // Tables to clear - add more as they are created + // Order matters: child tables must be cleared before parent tables const tables = [ + { table: messageTable, name: 'message' }, // Must clear before topic (FK reference) + { table: topicTable, name: 'topic' }, { table: preferenceTable, name: 'preference' } // TODO: Add these when tables are created // { table: assistantTable, name: 'assistant' }, - // { table: topicTable, name: 'topic' }, - // { table: messageTable, name: 'message' }, // { table: fileTable, name: 'file' }, // { table: knowledgeBaseTable, name: 'knowledge_base' } ] @@ -216,14 +217,15 @@ export class MigrationEngine { } } - // Clear tables in reverse dependency order + // Clear tables in dependency order (children before parents) + // Messages reference topics, so delete messages first + await db.delete(messageTable) + await db.delete(topicTable) + await db.delete(preferenceTable) // TODO: Add these when tables are created (in correct order) - // await db.delete(messageTable) - // await db.delete(topicTable) // await db.delete(fileTable) // await db.delete(knowledgeBaseTable) // await db.delete(assistantTable) - await db.delete(preferenceTable) logger.info('All new architecture tables cleared successfully') } diff --git a/src/main/data/migration/v2/migrators/ChatMigrator.ts b/src/main/data/migration/v2/migrators/ChatMigrator.ts index 5a9b845a00..077ad2179d 100644 --- a/src/main/data/migration/v2/migrators/ChatMigrator.ts +++ b/src/main/data/migration/v2/migrators/ChatMigrator.ts @@ -1,81 +1,659 @@ /** - * Chat migrator - migrates topics and messages from Dexie to SQLite + * Chat Migrator - Migrates topics and messages from Dexie to SQLite * - * TODO: Implement when chat tables are created - * Data source: Dexie topics table (messages are embedded in topics) - * Target tables: topic, message + * ## Overview * - * Note: This migrator handles the largest amount of data (potentially millions of messages) - * and uses streaming JSON reading with batch inserts for memory efficiency. + * This migrator handles the largest data migration task: transferring all chat topics + * and their messages from the old Dexie/IndexedDB storage to the new SQLite database. 
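+ *
+ * At a glance, the three phases below amount to:
+ * - prepare():  load block/assistant lookup maps, count topics, sample-validate
+ * - execute():  stream topics in batches, transform, insert inside transactions
+ * - validate(): compare source/target counts and check for orphan messages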
+ * + * ## Data Sources + * + * | Data | Source | File/Path | + * |------|--------|-----------| + * | Topics with messages | Dexie `topics` table | `topics.json` → `{ id, messages[] }` | + * | Message blocks | Dexie `message_blocks` table | `message_blocks.json` | + * | Assistants (for meta) | Redux `assistants` slice | `ReduxStateReader.getCategory('assistants')` | + * + * ## Target Tables + * + * - `topicTable` - Stores conversation topics/threads + * - `messageTable` - Stores chat messages with tree structure + * + * ## Key Transformations + * + * 1. **Linear → Tree Structure** + * - Old: Messages stored as linear array in `topic.messages[]` + * - New: Tree via `parentId` + `siblingsGroupId` + * + * 2. **Multi-model Responses** + * - Old: `askId` links responses to user message, `foldSelected` marks active + * - New: Shared `parentId` + non-zero `siblingsGroupId` groups siblings + * + * 3. **Block Inlining** + * - Old: `message.blocks: string[]` (IDs) + separate `message_blocks` table + * - New: `message.data.blocks: MessageDataBlock[]` (inline JSON) + * + * 4. **Citation Migration** + * - Old: Separate `CitationMessageBlock` + * - New: Merged into `MainTextBlock.references` as ContentReference[] + * + * 5. **Mention Migration** + * - Old: `message.mentions: Model[]` + * - New: `MentionReference[]` in `MainTextBlock.references` + * + * ## Performance Considerations + * + * - Uses streaming JSON reader for large data sets (potentially millions of messages) + * - Processes topics in batches to control memory usage + * - Pre-loads all blocks into memory map for O(1) lookup (blocks table is smaller) + * - Uses database transactions for atomicity and performance + * + * @since v2.0.0 */ +import { messageTable } from '@data/db/schemas/message' +import { topicTable } from '@data/db/schemas/topic' import { loggerService } from '@logger' -import type { ExecuteResult, PrepareResult, ValidateResult } from '@shared/data/migration/v2/types' +import type { ExecuteResult, PrepareResult, ValidateResult, ValidationError } from '@shared/data/migration/v2/types' +import { eq, sql } from 'drizzle-orm' +import { v4 as uuidv4 } from 'uuid' +import type { MigrationContext } from '../core/MigrationContext' import { BaseMigrator } from './BaseMigrator' +import { + buildBlockLookup, + buildMessageTree, + findActiveNodeId, + type NewMessage, + type NewTopic, + type OldAssistant, + type OldBlock, + type OldTopic, + type OldTopicMeta, + resolveBlocks, + transformMessage, + transformTopic +} from './mappings/ChatMappings' const logger = loggerService.withContext('ChatMigrator') +/** + * Batch size for processing topics + * Chosen to balance memory usage and transaction overhead + */ +const TOPIC_BATCH_SIZE = 50 + +/** + * Batch size for inserting messages + * SQLite has limits on the number of parameters per statement + */ +const MESSAGE_INSERT_BATCH_SIZE = 100 + +/** + * Assistant data from Redux for generating AssistantMeta + */ +interface AssistantState { + assistants: OldAssistant[] +} + +/** + * Prepared data for execution phase + */ +interface PreparedTopicData { + topic: NewTopic + messages: NewMessage[] +} + export class ChatMigrator extends BaseMigrator { readonly id = 'chat' readonly name = 'ChatData' - readonly description = 'Migrate chat data' + readonly description = 'Migrate chat topics and messages' readonly order = 4 - async prepare(): Promise { - logger.info('ChatMigrator.prepare - placeholder implementation') + // Prepared data for execution + private topicCount = 0 + private messageCount = 0 + 
private blockLookup: Map<string, OldBlock> = new Map()
+  private assistantLookup: Map<string, OldAssistant> = new Map()
+  // Topic metadata from Redux (name, pinned, etc.) - Dexie only has messages
+  private topicMetaLookup: Map<string, OldTopicMeta> = new Map()
+  // Topic → AssistantId mapping from Redux (Dexie topics don't store assistantId)
+  private topicAssistantLookup: Map<string, string> = new Map()
+  private skippedTopics = 0
+  private skippedMessages = 0
+  // Track seen message IDs to handle duplicates across topics
+  private seenMessageIds = new Set<string>()
+  // Block statistics for diagnostics
+  private blockStats = { requested: 0, resolved: 0, messagesWithMissingBlocks: 0, messagesWithEmptyBlocks: 0 }

-  async prepare(): Promise<PrepareResult> {
-    logger.info('ChatMigrator.prepare - placeholder implementation')
+  /**
+   * Prepare phase - validate source data and count items
+   *
+   * Steps:
+   * 1. Check if topics.json and message_blocks.json exist
+   * 2. Load all blocks into memory for fast lookup
+   * 3. Load assistant data for generating meta
+   * 4. Count topics and estimate message count
+   * 5. Validate sample data for integrity
+   */
+  async prepare(ctx: MigrationContext): Promise<PrepareResult> {
+    const warnings: string[] = []

-    // TODO: Implement when chat tables are created
-    // 1. Check if topics.json export file exists
-    // 2. Validate JSON format with sample read
-    // 3. Count total topics and estimate message count
-    // 4. Check for data integrity (e.g., messages have valid topic references)
+    try {
+      // Step 1: Verify export files exist
+      const topicsExist = await ctx.sources.dexieExport.tableExists('topics')
+      if (!topicsExist) {
+        logger.warn('topics.json not found, skipping chat migration')
+        return {
+          success: true,
+          itemCount: 0,
+          warnings: ['topics.json not found - no chat data to migrate']
+        }
+      }

-    return {
-      success: true,
-      itemCount: 0,
-      warnings: ['ChatMigrator not yet implemented - waiting for chat tables']
-    }
-  }
+      const blocksExist = await ctx.sources.dexieExport.tableExists('message_blocks')
+      if (!blocksExist) {
+        warnings.push('message_blocks.json not found - messages will have empty blocks')
+      }

-  async execute(): Promise<ExecuteResult> {
-    logger.info('ChatMigrator.execute - placeholder implementation')
+      // Step 2: Load all blocks into lookup map
+      // Blocks table is typically smaller than messages, safe to load entirely
+      if (blocksExist) {
+        logger.info('Loading message blocks into memory...')
+        const blocks = await ctx.sources.dexieExport.readTable<OldBlock>('message_blocks')
+        this.blockLookup = buildBlockLookup(blocks)
+        logger.info(`Loaded ${this.blockLookup.size} blocks into lookup map`)
+      }

-    // TODO: Implement when chat tables are created
-    // Use streaming JSON reader for large message files:
-    //
-    // const streamReader = _ctx.sources.dexieExport.createStreamReader('topics')
-    // await streamReader.readInBatches(
-    //   BATCH_SIZE,
-    //   async (topics, batchIndex) => {
-    //     // 1. Insert topics
-    //     // 2. Extract and insert messages from each topic
-    //     // 3. Report progress
-    //   }
-    // )
+      // Step 3: Load assistant data for generating AssistantMeta
+      // Also extract topic metadata from assistants (Redux stores topic metadata in assistants.topics[])
+      const assistantState = ctx.sources.reduxState.getCategory<AssistantState>('assistants')
+      if (assistantState?.assistants) {
+        for (const assistant of assistantState.assistants) {
+          this.assistantLookup.set(assistant.id, assistant)

-  async validate(): Promise<ValidateResult> {
-    logger.info('ChatMigrator.validate - placeholder implementation')
+          // Extract topic metadata from this assistant's topics array
+          // Redux stores topic metadata (name, pinned, etc.)
but with messages: [] + // Also track topic → assistantId mapping (Dexie doesn't store assistantId) + if (assistant.topics && Array.isArray(assistant.topics)) { + for (const topic of assistant.topics) { + if (topic.id) { + this.topicMetaLookup.set(topic.id, topic) + this.topicAssistantLookup.set(topic.id, assistant.id) + } + } + } + } + logger.info( + `Loaded ${this.assistantLookup.size} assistants and ${this.topicMetaLookup.size} topic metadata entries` + ) + } else { + warnings.push('No assistant data found - topics will have null assistantMeta and missing names') + } - // TODO: Implement when chat tables are created - // 1. Count validation for topics and messages - // 2. Sample validation (check a few topics have correct message counts) - // 3. Reference integrity validation + // Step 4: Count topics and estimate messages + const topicReader = ctx.sources.dexieExport.createStreamReader('topics') + this.topicCount = await topicReader.count() + logger.info(`Found ${this.topicCount} topics to migrate`) - return { - success: true, - errors: [], - stats: { - sourceCount: 0, - targetCount: 0, - skippedCount: 0 + // Estimate message count from sample + if (this.topicCount > 0) { + const sampleTopics = await topicReader.readSample(10) + const avgMessagesPerTopic = + sampleTopics.reduce((sum, t) => sum + (t.messages?.length || 0), 0) / sampleTopics.length + this.messageCount = Math.round(this.topicCount * avgMessagesPerTopic) + logger.info(`Estimated ${this.messageCount} messages based on sample`) + } + + // Step 5: Validate sample data + if (this.topicCount > 0) { + const sampleTopics = await topicReader.readSample(5) + for (const topic of sampleTopics) { + if (!topic.id) { + warnings.push(`Found topic without id - will be skipped`) + } + if (!topic.messages || !Array.isArray(topic.messages)) { + warnings.push(`Topic ${topic.id} has invalid messages array`) + } + } + } + + logger.info('Prepare phase completed', { + topics: this.topicCount, + estimatedMessages: this.messageCount, + blocks: this.blockLookup.size, + assistants: this.assistantLookup.size + }) + + return { + success: true, + itemCount: this.topicCount, + warnings: warnings.length > 0 ? warnings : undefined + } + } catch (error) { + logger.error('Prepare failed', error as Error) + return { + success: false, + itemCount: 0, + warnings: [error instanceof Error ? error.message : String(error)] } } } + + /** + * Execute phase - perform the actual data migration + * + * Processing strategy: + * 1. Stream topics in batches to control memory + * 2. For each topic batch: + * a. Transform topics and their messages + * b. Build message tree structure + * c. Insert topics in single transaction + * d. Insert messages in batched transactions + * 3. 
Report progress throughout
+   */
+  async execute(ctx: MigrationContext): Promise<ExecuteResult> {
+    if (this.topicCount === 0) {
+      logger.info('No topics to migrate')
+      return { success: true, processedCount: 0 }
+    }
+
+    let processedTopics = 0
+    let processedMessages = 0
+
+    try {
+      const db = ctx.db
+      const topicReader = ctx.sources.dexieExport.createStreamReader<OldTopic>('topics')
+
+      // Process topics in batches
+      await topicReader.readInBatches(TOPIC_BATCH_SIZE, async (topics, batchIndex) => {
+        logger.debug(`Processing topic batch ${batchIndex + 1}`, { count: topics.length })
+
+        // Transform all topics and messages in this batch
+        const preparedData: PreparedTopicData[] = []
+
+        for (const oldTopic of topics) {
+          try {
+            const prepared = this.prepareTopicData(oldTopic)
+            if (prepared) {
+              preparedData.push(prepared)
+            } else {
+              this.skippedTopics++
+            }
+          } catch (error) {
+            logger.warn(`Failed to transform topic ${oldTopic.id}`, { error })
+            this.skippedTopics++
+          }
+        }
+
+        // Insert topics in a transaction
+        if (preparedData.length > 0) {
+          // Collect all messages and handle duplicates BEFORE transaction
+          // This ensures parentId references are updated correctly
+          const allMessages: NewMessage[] = []
+          const idRemapping = new Map<string, string>() // oldId → newId for duplicates
+          const batchMessageIds = new Set<string>() // IDs added in this batch (for transaction safety)
+
+          for (const data of preparedData) {
+            for (const msg of data.messages) {
+              if (this.seenMessageIds.has(msg.id) || batchMessageIds.has(msg.id)) {
+                const newId = uuidv4()
+                logger.warn(`Duplicate message ID found: ${msg.id}, assigning new ID: ${newId}`)
+                idRemapping.set(msg.id, newId)
+                msg.id = newId
+              }
+              batchMessageIds.add(msg.id)
+              allMessages.push(msg)
+            }
+          }
+
+          // Update parentId references for any remapped IDs
+          if (idRemapping.size > 0) {
+            for (const msg of allMessages) {
+              if (msg.parentId && idRemapping.has(msg.parentId)) {
+                msg.parentId = idRemapping.get(msg.parentId)!
+              }
+            }
+          }
+
+          // Execute transaction
+          await db.transaction(async (tx) => {
+            // Insert topics
+            const topicValues = preparedData.map((d) => d.topic)
+            await tx.insert(topicTable).values(topicValues)
+
+            // Insert messages in batches (SQLite parameter limit)
+            for (let i = 0; i < allMessages.length; i += MESSAGE_INSERT_BATCH_SIZE) {
+              const batch = allMessages.slice(i, i + MESSAGE_INSERT_BATCH_SIZE)
+              await tx.insert(messageTable).values(batch)
+            }
+          })
+
+          // Update state ONLY after transaction succeeds (transaction safety)
+          for (const id of batchMessageIds) {
+            this.seenMessageIds.add(id)
+          }
+          processedMessages += allMessages.length
+          processedTopics += preparedData.length
+        }
+
+        // Report progress
+        const progress = Math.round((processedTopics / this.topicCount) * 100)
+        this.reportProgress(
+          progress,
+          `Migrated ${processedTopics}/${this.topicCount} conversations, ${processedMessages} messages`
+        )
+      })
+
+      logger.info('Execute completed', {
+        processedTopics,
+        processedMessages,
+        skippedTopics: this.skippedTopics,
+        skippedMessages: this.skippedMessages
+      })
+
+      // Log block statistics for diagnostics
+      logger.info('Block migration statistics', {
+        blocksRequested: this.blockStats.requested,
+        blocksResolved: this.blockStats.resolved,
+        blocksMissing: this.blockStats.requested - this.blockStats.resolved,
+        messagesWithEmptyBlocks: this.blockStats.messagesWithEmptyBlocks,
+        messagesWithMissingBlocks: this.blockStats.messagesWithMissingBlocks
+      })
+
+      return {
+        success: true,
+        processedCount: processedTopics
+      }
+    } catch (error) {
+      logger.error('Execute failed', error as Error)
+      return {
+        success: false,
+        processedCount: processedTopics,
+        error: error instanceof Error ? error.message : String(error)
+      }
+    }
+  }
+
+  /**
+   * Validate phase - verify migrated data integrity
+   *
+   * Validation checks:
+   * 1. Topic count matches source (minus skipped)
+   * 2. Message count is within expected range
+   * 3. Sample topics have correct structure
+   * 4. Foreign key integrity (messages belong to existing topics)
+   */
+  async validate(ctx: MigrationContext): Promise<ValidateResult> {
+    const errors: ValidationError[] = []
+    const db = ctx.db
+
+    try {
+      // Count topics in target
+      const topicResult = await db.select({ count: sql<number>`count(*)` }).from(topicTable).get()
+      const targetTopicCount = topicResult?.count ?? 0
+
+      // Count messages in target
+      const messageResult = await db.select({ count: sql<number>`count(*)` }).from(messageTable).get()
+      const targetMessageCount = messageResult?.count ?? 0
+
+      logger.info('Validation counts', {
+        sourceTopics: this.topicCount,
+        targetTopics: targetTopicCount,
+        skippedTopics: this.skippedTopics,
+        targetMessages: targetMessageCount
+      })
+
+      // Validate topic count
+      const expectedTopics = this.topicCount - this.skippedTopics
+      if (targetTopicCount < expectedTopics) {
+        errors.push({
+          key: 'topic_count_low',
+          message: `Topic count too low: expected ${expectedTopics}, got ${targetTopicCount}`
+        })
+      } else if (targetTopicCount > expectedTopics) {
+        // More topics than expected could indicate duplicate insertions or data corruption
+        logger.warn(`Topic count higher than expected: expected ${expectedTopics}, got ${targetTopicCount}`)
+      }
+
+      // Sample validation: check a few topics have messages
+      const sampleTopics = await db.select().from(topicTable).limit(5).all()
+      for (const topic of sampleTopics) {
+        const msgCount = await db
+          .select({ count: sql<number>`count(*)` })
+          .from(messageTable)
+          .where(eq(messageTable.topicId, topic.id))
+          .get()
+
+        if (msgCount?.count === 0) {
+          // This is a warning, not an error - some topics may legitimately have no messages
+          logger.warn(`Topic ${topic.id} has no messages after migration`)
+        }
+      }
+
+      // Check for orphan messages (messages without valid topic)
+      // This shouldn't happen due to foreign key constraints, but verify anyway
+      const orphanCheck = await db
+        .select({ count: sql<number>`count(*)` })
+        .from(messageTable)
+        .where(sql`${messageTable.topicId} NOT IN (SELECT id FROM ${topicTable})`)
+        .get()
+
+      if (orphanCheck && orphanCheck.count > 0) {
+        errors.push({
+          key: 'orphan_messages',
+          message: `Found ${orphanCheck.count} orphan messages without valid topics`
+        })
+      }
+
+      return {
+        success: errors.length === 0,
+        errors,
+        stats: {
+          sourceCount: this.topicCount,
+          targetCount: targetTopicCount,
+          skippedCount: this.skippedTopics
+        }
+      }
+    } catch (error) {
+      logger.error('Validation failed', error as Error)
+      return {
+        success: false,
+        errors: [
+          {
+            key: 'validation',
+            message: error instanceof Error ? error.message : String(error)
+          }
+        ],
+        stats: {
+          sourceCount: this.topicCount,
+          targetCount: 0,
+          skippedCount: this.skippedTopics
+        }
+      }
+    }
+  }
+
+  /**
+   * Prepare a single topic and its messages for migration
+   *
+   * @param oldTopic - Source topic from Dexie (has messages, may lack metadata)
+   * @returns Prepared data or null if topic should be skipped
+   *
+   * ## Data Merging
+   *
+   * Topic data comes from two sources:
+   * - Dexie `topics` table: has `id` and `messages[]` (a usable `assistantId` is usually absent)
+   * - Redux `assistants[].topics[]`: has metadata (`name`, `pinned`, `prompt`, etc.)
+   *
+   * We merge Redux metadata into the Dexie topic before transformation.
+   */
+  private prepareTopicData(oldTopic: OldTopic): PreparedTopicData | null {
+    // Validate required fields
+    if (!oldTopic.id) {
+      logger.warn('Topic missing id, skipping')
+      return null
+    }
+
+    // Merge topic metadata from Redux (name, pinned, etc.)
+    // Dexie topics may have stale or missing metadata; Redux is authoritative for these fields
+    const topicMeta = this.topicMetaLookup.get(oldTopic.id)
+    if (topicMeta) {
+      // Merge Redux metadata into Dexie topic
+      // Note: Redux topic.name can also be empty from ancient version migrations (see store/migrate.ts:303-305)
+      oldTopic.name = topicMeta.name || oldTopic.name
+      oldTopic.pinned = topicMeta.pinned ?? oldTopic.pinned
+      oldTopic.prompt = topicMeta.prompt ?? oldTopic.prompt
+      oldTopic.isNameManuallyEdited = topicMeta.isNameManuallyEdited ??
oldTopic.isNameManuallyEdited
+      // Use Redux timestamps if available and Dexie lacks them
+      if (topicMeta.createdAt && !oldTopic.createdAt) {
+        oldTopic.createdAt = topicMeta.createdAt
+      }
+      if (topicMeta.updatedAt && !oldTopic.updatedAt) {
+        oldTopic.updatedAt = topicMeta.updatedAt
+      }
+    }
+
+    // Fallback: If name is still empty after merge, use a default name
+    // This handles cases where both Dexie and Redux have empty names (ancient version bug)
+    if (!oldTopic.name) {
+      oldTopic.name = 'Unnamed Topic' // Default fallback for topics with no name
+    }
+
+    // Get assistantId from Redux mapping (Dexie topics don't store assistantId)
+    // Fall back to oldTopic.assistantId in case Dexie did store it (defensive)
+    const assistantId = this.topicAssistantLookup.get(oldTopic.id) || oldTopic.assistantId
+    if (assistantId && !oldTopic.assistantId) {
+      oldTopic.assistantId = assistantId
+    }
+
+    // Get assistant for meta generation
+    const assistant = this.assistantLookup.get(assistantId) || null
+
+    // Get messages array (may be empty or undefined)
+    const oldMessages = oldTopic.messages || []
+
+    // Build message tree structure
+    const messageTree = buildMessageTree(oldMessages)
+
+    // === First pass: identify messages to skip (no blocks) ===
+    const skippedMessageIds = new Set<string>()
+    const messageParentMap = new Map<string, string | null>() // messageId -> parentId
+
+    for (const oldMsg of oldMessages) {
+      const blockIds = oldMsg.blocks || []
+      const blocks = resolveBlocks(blockIds, this.blockLookup)
+
+      // Track block statistics for diagnostics
+      this.blockStats.requested += blockIds.length
+      this.blockStats.resolved += blocks.length
+      if (blockIds.length === 0) {
+        this.blockStats.messagesWithEmptyBlocks++
+      } else if (blocks.length < blockIds.length) {
+        this.blockStats.messagesWithMissingBlocks++
+        if (blocks.length === 0) {
+          logger.warn(`Message ${oldMsg.id} has ${blockIds.length} block IDs but none found in message_blocks`)
+        }
+      }
+
+      // Store parent info from tree
+      const treeInfo = messageTree.get(oldMsg.id)
+      messageParentMap.set(oldMsg.id, treeInfo?.parentId ?? null)
+
+      // Mark for skipping if no blocks
+      if (blocks.length === 0) {
+        skippedMessageIds.add(oldMsg.id)
+        this.skippedMessages++
+      }
+    }
+
+    // === Helper: resolve parent through skipped messages ===
+    // If parentId points to a skipped message, follow the chain to find a non-skipped ancestor
+    const resolveParentId = (parentId: string | null): string | null => {
+      let currentParent = parentId
+      const visited = new Set<string>() // Prevent infinite loops
+
+      while (currentParent && skippedMessageIds.has(currentParent)) {
+        if (visited.has(currentParent)) {
+          // Circular reference, break out
+          return null
+        }
+        visited.add(currentParent)
+        currentParent = messageParentMap.get(currentParent) ??
null + } + + return currentParent + } + + // === Second pass: transform messages that have blocks === + const newMessages: NewMessage[] = [] + for (const oldMsg of oldMessages) { + // Skip messages marked for skipping + if (skippedMessageIds.has(oldMsg.id)) { + continue + } + + try { + const treeInfo = messageTree.get(oldMsg.id) + if (!treeInfo) { + logger.warn(`Message ${oldMsg.id} not found in tree, using defaults`) + continue + } + + // Resolve blocks for this message (we know it has blocks from first pass) + const blockIds = oldMsg.blocks || [] + const blocks = resolveBlocks(blockIds, this.blockLookup) + + // Resolve parentId through any skipped messages + const resolvedParentId = resolveParentId(treeInfo.parentId) + + // Get assistant for this message (may differ from topic's assistant) + const msgAssistant = this.assistantLookup.get(oldMsg.assistantId) || assistant + + const newMsg = transformMessage( + oldMsg, + resolvedParentId, // Use resolved parent instead of original + treeInfo.siblingsGroupId, + blocks, + msgAssistant, + oldTopic.id + ) + + newMessages.push(newMsg) + } catch (error) { + logger.warn(`Failed to transform message ${oldMsg.id}`, { error }) + this.skippedMessages++ + } + } + + // Calculate activeNodeId using smart selection logic + // Priority: 1) Original activeNode if migrated, 2) foldSelected if migrated, 3) last migrated + let activeNodeId: string | null = null + if (newMessages.length > 0) { + const migratedIds = new Set(newMessages.map((m) => m.id)) + + // Try to use the original active node (handles foldSelected for multi-model) + const originalActiveId = findActiveNodeId(oldMessages) + if (originalActiveId && migratedIds.has(originalActiveId)) { + activeNodeId = originalActiveId + } else { + // Original active was skipped; find a foldSelected among migrated messages + const foldSelectedMsg = oldMessages.find((m) => m.foldSelected && migratedIds.has(m.id)) + if (foldSelectedMsg) { + activeNodeId = foldSelectedMsg.id + } else { + // Fallback to last migrated message + activeNodeId = newMessages[newMessages.length - 1].id + } + } + } + + // Transform topic with correct activeNodeId + const newTopic = transformTopic(oldTopic, assistant, activeNodeId) + + return { + topic: newTopic, + messages: newMessages + } + } } diff --git a/src/main/data/migration/v2/migrators/README-ChatMigrator.md b/src/main/data/migration/v2/migrators/README-ChatMigrator.md new file mode 100644 index 0000000000..63e2053e73 --- /dev/null +++ b/src/main/data/migration/v2/migrators/README-ChatMigrator.md @@ -0,0 +1,138 @@ +# ChatMigrator + +The `ChatMigrator` handles the largest data migration task: topics and messages from Dexie/IndexedDB to SQLite. + +## Data Sources + +| Data | Source | File/Path | +|------|--------|-----------| +| Topics with messages | Dexie `topics` table | `topics.json` | +| Topic metadata (name, pinned, etc.) | Redux `assistants[].topics[]` | `ReduxStateReader.getCategory('assistants')` | +| Message blocks | Dexie `message_blocks` table | `message_blocks.json` | +| Assistants (for meta) | Redux `assistants` slice | `ReduxStateReader.getCategory('assistants')` | + +### Topic Data Split (Important!) + +The old system stores topic data in **two separate locations**: + +1. **Dexie `topics` table**: Contains only `id` and `messages[]` array (NO `assistantId`!) +2. 
**Redux `assistants[].topics[]`**: Contains metadata (`name`, `pinned`, `prompt`, `isNameManuallyEdited`) and implicitly the `assistantId` (from parent assistant) + +Redux deliberately clears `messages[]` to reduce storage size. The migrator merges these sources: +- Messages come from Dexie +- Metadata (name, pinned, etc.) comes from Redux +- `assistantId` comes from Redux structure (each assistant owns its topics) + +## Key Transformations + +1. **Linear → Tree Structure** + - Old: Messages stored as linear array in `topic.messages[]` + - New: Tree via `parentId` + `siblingsGroupId` + +2. **Multi-model Responses** + - Old: `askId` links responses to user message, `foldSelected` marks active + - New: Shared `parentId` + non-zero `siblingsGroupId` groups siblings + +3. **Block Inlining** + - Old: `message.blocks: string[]` (IDs) + separate `message_blocks` table + - New: `message.data.blocks: MessageDataBlock[]` (inline JSON) + +4. **Citation Migration** + - Old: Separate `CitationMessageBlock` with `response`, `knowledge`, `memories` + - New: Merged into `MainTextBlock.references` as `ContentReference[]` + +5. **Mention Migration** + - Old: `message.mentions: Model[]` + - New: `MentionReference[]` in `MainTextBlock.references` + +## Data Quality Handling + +The migrator handles potential data inconsistencies from the old system: + +| Issue | Detection | Handling | +|-------|-----------|----------| +| **Duplicate message ID** | Same ID appears in multiple topics | Generate new UUID, update parentId refs, log warning | +| **TopicId mismatch** | `message.topicId` ≠ parent `topic.id` | Use correct parent topic.id (silent fix) | +| **Missing blocks** | Block ID not found in `message_blocks` | Skip missing block (silent) | +| **Invalid topic** | Topic missing required `id` field | Skip entire topic | +| **Missing topic metadata** | Topic not found in Redux `assistants[].topics[]` | Use Dexie values, fallback name if empty | +| **Missing assistantId** | Topic not in any `assistant.topics[]` | `assistantId` and `assistantMeta` will be null | +| **Empty topic name** | Both Dexie and Redux have empty `name` (ancient bug) | Use fallback "Unnamed Topic" | +| **Message with no blocks** | `blocks` array is empty after resolution | Skip message, re-link children to parent's parent | +| **Topic with no messages** | All messages skipped (no blocks) | Keep topic, set `activeNodeId` to null | + +## Field Mappings + +### Topic Mapping + +Topic data is merged from Dexie + Redux before transformation: + +| Source | Target (topicTable) | Notes | +|--------|---------------------|-------| +| Dexie: `id` | `id` | Direct copy | +| Redux: `name` | `name` | Merged from Redux `assistants[].topics[]` | +| Redux: `isNameManuallyEdited` | `isNameManuallyEdited` | Merged from Redux | +| Redux: (parent assistant.id) | `assistantId` | From `topicAssistantLookup` mapping | +| (from Assistant) | `assistantMeta` | Generated from assistant entity | +| Redux: `prompt` | `prompt` | Merged from Redux | +| (computed) | `activeNodeId` | Smart selection: original active → foldSelected → last migrated | +| (none) | `groupId` | null (new field) | +| (none) | `sortOrder` | 0 (new field) | +| Redux: `pinned` | `isPinned` | Merged from Redux, renamed | +| (none) | `pinnedOrder` | 0 (new field) | +| `createdAt` | `createdAt` | ISO string → timestamp | +| `updatedAt` | `updatedAt` | ISO string → timestamp | + +**Dropped fields**: `type` ('chat' | 'session') + +### Message Mapping + +| Source (OldMessage) | Target (messageTable) | Notes | 
+|---------------------|----------------------|-------| +| `id` | `id` | Direct copy (new UUID if duplicate) | +| (computed) | `parentId` | From tree building algorithm | +| (from parent topic) | `topicId` | Uses parent topic.id for consistency | +| `role` | `role` | Direct copy | +| `blocks` + `mentions` + citations | `data` | Complex transformation | +| (extracted) | `searchableText` | Extracted from text blocks | +| `status` | `status` | Normalized to success/error/paused | +| (computed) | `siblingsGroupId` | From multi-model detection | +| `assistantId` | `assistantId` | Direct copy | +| `modelId` | `modelId` | Direct copy | +| (from Message.model) | `modelMeta` | Generated from model entity | +| `traceId` | `traceId` | Direct copy | +| `usage` + `metrics` | `stats` | Merged into single stats object | +| `createdAt` | `createdAt` | ISO string → timestamp | +| `updatedAt` | `updatedAt` | ISO string → timestamp | + +**Dropped fields**: `type`, `useful`, `enabledMCPs`, `agentSessionId`, `providerMetadata`, `multiModelMessageStyle`, `askId` (replaced by parentId), `foldSelected` (replaced by siblingsGroupId) + +### Block Type Mapping + +| Old Type | New Type | Notes | +|----------|----------|-------| +| `main_text` | `MainTextBlock` | Direct, references added from citations/mentions | +| `thinking` | `ThinkingBlock` | `thinking_millsec` → `thinkingMs` | +| `translation` | `TranslationBlock` | Direct copy | +| `code` | `CodeBlock` | Direct copy | +| `image` | `ImageBlock` | `file.id` → `fileId` | +| `file` | `FileBlock` | `file.id` → `fileId` | +| `video` | `VideoBlock` | Direct copy | +| `tool` | `ToolBlock` | Direct copy | +| `citation` | (removed) | Converted to `MainTextBlock.references` | +| `error` | `ErrorBlock` | Direct copy | +| `compact` | `CompactBlock` | Direct copy | +| `unknown` | (skipped) | Placeholder blocks are dropped | + +## Implementation Files + +- `ChatMigrator.ts` - Main migrator class with prepare/execute/validate phases +- `mappings/ChatMappings.ts` - Pure transformation functions and type definitions + +## Code Quality + +All implementation code includes detailed comments: +- File-level comments: Describe purpose, data flow, and overview +- Function-level comments: Purpose, parameters, return values, side effects +- Logic block comments: Step-by-step explanations for complex logic +- Data transformation comments: Old field → new field mapping relationships diff --git a/src/main/data/migration/v2/migrators/mappings/ChatMappings.ts b/src/main/data/migration/v2/migrators/mappings/ChatMappings.ts new file mode 100644 index 0000000000..99b4023c08 --- /dev/null +++ b/src/main/data/migration/v2/migrators/mappings/ChatMappings.ts @@ -0,0 +1,1168 @@ +/** + * Chat Mappings - Topic and Message transformation functions for Dexie → SQLite migration + * + * This file contains pure transformation functions that convert old data structures + * to new SQLite-compatible formats. All functions are stateless and side-effect free. + * + * ## Data Flow Overview: + * + * ### Topics: + * - Source: Redux `assistants.topics[]` + Dexie `topics` table (for messages) + * - Target: SQLite `topicTable` + * + * ### Messages: + * - Source: Dexie `topics.messages[]` (embedded in topic) + `message_blocks` table + * - Target: SQLite `messageTable` with inline blocks in `data.blocks` + * + * ## Key Transformations: + * + * 1. **Message Order → Tree Structure** + * - Old: Linear array `topic.messages[]` with array index as order + * - New: Tree via `parentId` + `siblingsGroupId` + * + * 2. 
**Multi-model Responses** + * - Old: Multiple messages share same `askId`, `foldSelected` marks active + * - New: Same `parentId` + non-zero `siblingsGroupId` groups siblings + * + * 3. **Block Storage** + * - Old: `message.blocks: string[]` (IDs) + separate `message_blocks` table + * - New: `message.data.blocks: MessageDataBlock[]` (inline JSON) + * + * 4. **Citations → References** + * - Old: Separate `CitationMessageBlock` with response/knowledge/memories + * - New: Merged into `MainTextBlock.references` as typed ContentReference[] + * + * 5. **Mentions → References** + * - Old: `message.mentions: Model[]` + * - New: `MentionReference[]` in `MainTextBlock.references` + * + * @since v2.0.0 + */ + +import type { + BlockType, + CitationReference, + CitationType, + CodeBlock, + CompactBlock, + ContentReference, + ErrorBlock, + FileBlock, + ImageBlock, + MainTextBlock, + MentionReference, + MessageData, + MessageDataBlock, + MessageStats, + ReferenceCategory, + ThinkingBlock, + ToolBlock, + TranslationBlock, + VideoBlock +} from '@shared/data/types/message' +import type { AssistantMeta, ModelMeta } from '@shared/data/types/meta' + +// ============================================================================ +// Old Type Definitions (Source Data Structures) +// ============================================================================ + +/** + * Old Topic type from Redux assistants slice + * Source: src/renderer/src/types/index.ts + */ +export interface OldTopic { + id: string + type?: 'chat' | 'session' // Dropped in new schema + assistantId: string + name: string + createdAt: string + updatedAt: string + messages: OldMessage[] + pinned?: boolean + prompt?: string + isNameManuallyEdited?: boolean +} + +/** + * Old Assistant type for extracting AssistantMeta + * Note: In Redux state, assistant.topics[] contains topic metadata (but with messages: []) + */ +export interface OldAssistant { + id: string + name: string + emoji?: string + type: string + topics?: OldTopicMeta[] // Topics are nested inside assistants in Redux +} + +/** + * Old Topic metadata from Redux assistants.topics[] + * + * Redux stores topic metadata (name, pinned, etc.) but clears messages[] to reduce storage. + * Dexie stores topics with messages[] but may have stale metadata. + * Migration merges: Redux metadata + Dexie messages. 
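+ *
+ * Hypothetical example of the merge (illustrative values only):
+ *   Dexie topic: { id: 't1', messages: [...] }
+ *   Redux meta:  { id: 't1', name: 'Trip notes', pinned: true }
+ *   Merged:      { id: 't1', name: 'Trip notes', pinned: true, messages: [...] }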
+ */ +export interface OldTopicMeta { + id: string + name: string + pinned?: boolean + prompt?: string + isNameManuallyEdited?: boolean + createdAt?: string + updatedAt?: string +} + +/** + * Old Model type for extracting ModelMeta + */ +export interface OldModel { + id: string + name: string + provider: string + group: string +} + +/** + * Old Message type from Dexie topics table + * Source: src/renderer/src/types/newMessage.ts + */ +export interface OldMessage { + id: string + role: 'user' | 'assistant' | 'system' + assistantId: string + topicId: string + createdAt: string + updatedAt?: string + // Old status includes more values, we normalize to success/error/paused + status: 'sending' | 'pending' | 'searching' | 'processing' | 'success' | 'paused' | 'error' + + // Model info + modelId?: string + model?: OldModel + + // Multi-model response fields + askId?: string // Links to user message ID + foldSelected?: boolean // True if this is the selected response in fold view + multiModelMessageStyle?: string // UI state, dropped + + // Content + blocks: string[] // Block IDs referencing message_blocks table + + // Metadata + usage?: OldUsage + metrics?: OldMetrics + traceId?: string + + // Fields being transformed + mentions?: OldModel[] // → MentionReference in MainTextBlock.references + + // Dropped fields + type?: 'clear' | 'text' | '@' + useful?: boolean + enabledMCPs?: unknown[] + agentSessionId?: string + providerMetadata?: unknown +} + +/** + * Old Usage type for token consumption + */ +export interface OldUsage { + prompt_tokens?: number + completion_tokens?: number + total_tokens?: number + thoughts_tokens?: number + cost?: number +} + +/** + * Old Metrics type for performance measurement + */ +export interface OldMetrics { + completion_tokens?: number + time_completion_millsec?: number + time_first_token_millsec?: number + time_thinking_millsec?: number +} + +/** + * Old MessageBlock base type + */ +export interface OldMessageBlock { + id: string + messageId: string + type: string + createdAt: string + updatedAt?: string + status: string // Dropped in new schema + model?: OldModel // Dropped in new schema + metadata?: Record<string, unknown> + error?: unknown +} + +/** + * Old MainTextMessageBlock + */ +export interface OldMainTextBlock extends OldMessageBlock { + type: 'main_text' + content: string + knowledgeBaseIds?: string[] // Dropped (deprecated) + citationReferences?: Array<{ + citationBlockId?: string + citationBlockSource?: string + }> // Dropped (replaced by references) +} + +/** + * Old ThinkingMessageBlock + */ +export interface OldThinkingBlock extends OldMessageBlock { + type: 'thinking' + content: string + thinking_millsec: number // → thinkingMs +} + +/** + * Old TranslationMessageBlock + */ +export interface OldTranslationBlock extends OldMessageBlock { + type: 'translation' + content: string + sourceBlockId?: string + sourceLanguage?: string + targetLanguage: string +} + +/** + * Old CodeMessageBlock + */ +export interface OldCodeBlock extends OldMessageBlock { + type: 'code' + content: string + language: string +} + +/** + * Old ImageMessageBlock + */ +export interface OldImageBlock extends OldMessageBlock { + type: 'image' + url?: string + file?: { id: string; [key: string]: unknown } // file.id → fileId +} + +/** + * Old FileMessageBlock + */ +export interface OldFileBlock extends OldMessageBlock { + type: 'file' + file: { id: string; [key: string]: unknown } // file.id → fileId +} + +/** + * Old VideoMessageBlock + */ +export interface OldVideoBlock extends OldMessageBlock {
type: 'video' + url?: string + filePath?: string +} + +/** + * Old ToolMessageBlock + */ +export interface OldToolBlock extends OldMessageBlock { + type: 'tool' + toolId: string + toolName?: string + arguments?: Record<string, unknown> + content?: string | object +} + +/** + * Old CitationMessageBlock - contains web search, knowledge, and memory references + * This is the primary source for ContentReference transformation + */ +export interface OldCitationBlock extends OldMessageBlock { + type: 'citation' + response?: { + results?: unknown + source: unknown + } + knowledge?: Array<{ + id: number + content: string + sourceUrl: string + type: string + file?: unknown + metadata?: Record<string, unknown> + }> + memories?: Array<{ + id: string + memory: string + hash?: string + createdAt?: string + updatedAt?: string + score?: number + metadata?: Record<string, unknown> + }> +} + +/** + * Old ErrorMessageBlock + */ +export interface OldErrorBlock extends OldMessageBlock { + type: 'error' +} + +/** + * Old CompactMessageBlock + */ +export interface OldCompactBlock extends OldMessageBlock { + type: 'compact' + content: string + compactedContent: string +} + +/** + * Union of all old block types + */ +export type OldBlock = + | OldMainTextBlock + | OldThinkingBlock + | OldTranslationBlock + | OldCodeBlock + | OldImageBlock + | OldFileBlock + | OldVideoBlock + | OldToolBlock + | OldCitationBlock + | OldErrorBlock + | OldCompactBlock + | OldMessageBlock + +// ============================================================================ +// New Type Definitions (Target Data Structures) +// ============================================================================ + +/** + * New Topic for SQLite insertion + * Matches topicTable schema + */ +export interface NewTopic { + id: string + name: string | null + isNameManuallyEdited: boolean + assistantId: string | null + assistantMeta: AssistantMeta | null + prompt: string | null + activeNodeId: string | null + groupId: string | null + sortOrder: number + isPinned: boolean + pinnedOrder: number + createdAt: number // timestamp + updatedAt: number // timestamp +} + +/** + * New Message for SQLite insertion + * Matches messageTable schema + */ +export interface NewMessage { + id: string + parentId: string | null + topicId: string + role: string + data: MessageData + searchableText: string | null + status: 'success' | 'error' | 'paused' + siblingsGroupId: number + assistantId: string | null + assistantMeta: AssistantMeta | null + modelId: string | null + modelMeta: ModelMeta | null + traceId: string | null + stats: MessageStats | null + createdAt: number // timestamp + updatedAt: number // timestamp +} + +// ============================================================================ +// Topic Transformation Functions +// ============================================================================ + +/** + * Transform old Topic to new Topic format + * + * @param oldTopic - Source topic from Redux/Dexie + * @param assistant - Assistant entity for generating AssistantMeta + * @param activeNodeId - Last message ID to set as active node + * @returns New topic ready for SQLite insertion + * + * ## Field Mapping: + * | Source | Target | Notes | + * |--------|--------|-------| + * | id | id | Direct copy | + * | name | name | Direct copy | + * | isNameManuallyEdited | isNameManuallyEdited | Direct copy | + * | assistantId | assistantId | Direct copy | + * | (from Assistant) | assistantMeta | Generated from assistant entity | + * | prompt | prompt | Direct copy | + * | (computed) | activeNodeId | Last message ID |
| (none) | groupId | null (new field) | + * | (none) | sortOrder | 0 (new field) | + * | pinned | isPinned | Renamed | + * | (none) | pinnedOrder | 0 (new field) | + * | createdAt | createdAt | ISO string → timestamp | + * | updatedAt | updatedAt | ISO string → timestamp | + * + * ## Dropped Fields: + * - type ('chat' | 'session'): No longer needed in new schema + */ +export function transformTopic( + oldTopic: OldTopic, + assistant: OldAssistant | null, + activeNodeId: string | null +): NewTopic { + return { + id: oldTopic.id, + name: oldTopic.name || null, + isNameManuallyEdited: oldTopic.isNameManuallyEdited ?? false, + assistantId: oldTopic.assistantId || null, + assistantMeta: assistant ? extractAssistantMeta(assistant) : null, + prompt: oldTopic.prompt || null, + activeNodeId, + groupId: null, // New field, no migration source + sortOrder: 0, // New field, default value + isPinned: oldTopic.pinned ?? false, + pinnedOrder: 0, // New field, default value + createdAt: parseTimestamp(oldTopic.createdAt), + updatedAt: parseTimestamp(oldTopic.updatedAt) + } +} + +/** + * Extract AssistantMeta from old Assistant entity + * + * AssistantMeta preserves display information when the original + * assistant is deleted, ensuring messages/topics remain readable. + * + * @param assistant - Source assistant entity + * @returns AssistantMeta for storage in topic/message + */ +export function extractAssistantMeta(assistant: OldAssistant): AssistantMeta { + return { + id: assistant.id, + name: assistant.name, + emoji: assistant.emoji, + type: assistant.type + } +} + +// ============================================================================ +// Message Transformation Functions +// ============================================================================ + +/** + * Transform old Message to new Message format + * + * This is the core message transformation function. 
It handles: + * - Status normalization + * - Block transformation (IDs → inline data) + * - Citation merging into references + * - Mention conversion to references + * - Stats merging (usage + metrics) + * + * @param oldMessage - Source message from Dexie + * @param parentId - Computed parent message ID (from tree building) + * @param siblingsGroupId - Computed siblings group ID (from multi-model detection) + * @param blocks - Resolved block data from message_blocks table + * @param assistant - Assistant entity for generating AssistantMeta + * @param correctTopicId - The correct topic ID (from parent topic, not from message) + * @returns New message ready for SQLite insertion + * + * ## Field Mapping: + * | Source | Target | Notes | + * |--------|--------|-------| + * | id | id | Direct copy | + * | (computed) | parentId | From tree building algorithm | + * | (parameter) | topicId | From correctTopicId param (ensures consistency) | + * | role | role | Direct copy | + * | blocks + mentions + citations | data | Complex transformation | + * | (extracted) | searchableText | Extracted from text blocks | + * | status | status | Normalized to success/error/paused | + * | (computed) | siblingsGroupId | From multi-model detection | + * | assistantId | assistantId | Direct copy | + * | (from Assistant) | assistantMeta | Generated if available | + * | modelId | modelId | Direct copy | + * | (from Message.model) | modelMeta | Generated from model entity | + * | traceId | traceId | Direct copy | + * | usage + metrics | stats | Merged into single stats object | + * | createdAt | createdAt | ISO string → timestamp | + * | updatedAt | updatedAt | ISO string → timestamp | + * + * ## Dropped Fields: + * - type ('clear' | 'text' | '@') + * - useful (boolean) + * - enabledMCPs (deprecated) + * - agentSessionId (session identifier) + * - providerMetadata (raw provider data) + * - multiModelMessageStyle (UI state) + * - askId (replaced by parentId) + * - foldSelected (replaced by siblingsGroupId) + */ +export function transformMessage( + oldMessage: OldMessage, + parentId: string | null, + siblingsGroupId: number, + blocks: OldBlock[], + assistant: OldAssistant | null, + correctTopicId: string +): NewMessage { + // Transform blocks and merge citations/mentions into references + const { dataBlocks, citationReferences, searchableText } = transformBlocks(blocks) + + // Convert mentions to MentionReferences + const mentionReferences = transformMentions(oldMessage.mentions) + + // Find the MainTextBlock and add references if any exist + const allReferences = [...citationReferences, ...mentionReferences] + if (allReferences.length > 0) { + const mainTextBlock = dataBlocks.find((b) => b.type === 'main_text') as MainTextBlock | undefined + if (mainTextBlock) { + mainTextBlock.references = allReferences + } + } + + return { + id: oldMessage.id, + parentId, + topicId: correctTopicId, + role: oldMessage.role, + data: { blocks: dataBlocks }, + searchableText: searchableText || null, + status: normalizeStatus(oldMessage.status), + siblingsGroupId, + assistantId: oldMessage.assistantId || null, + assistantMeta: assistant ? extractAssistantMeta(assistant) : null, + modelId: oldMessage.modelId || null, + modelMeta: oldMessage.model ? 
extractModelMeta(oldMessage.model) : null, + traceId: oldMessage.traceId || null, + stats: mergeStats(oldMessage.usage, oldMessage.metrics), + createdAt: parseTimestamp(oldMessage.createdAt), + updatedAt: parseTimestamp(oldMessage.updatedAt || oldMessage.createdAt) + } +} + +/** + * Extract ModelMeta from old Model entity + * + * ModelMeta preserves model display information when the original + * model configuration is removed or unavailable. + * + * @param model - Source model entity + * @returns ModelMeta for storage in message + */ +export function extractModelMeta(model: OldModel): ModelMeta { + return { + id: model.id, + name: model.name, + provider: model.provider, + group: model.group + } +} + +/** + * Normalize old status values to new enum + * + * Old system has multiple transient states that don't apply to stored messages. + * We normalize these to the three final states in the new schema. + * + * @param oldStatus - Status from old message + * @returns Normalized status for new message + * + * ## Mapping: + * - 'success' → 'success' + * - 'error' → 'error' + * - 'paused' → 'paused' + * - 'sending', 'pending', 'searching', 'processing' → 'success' (completed states) + */ +export function normalizeStatus(oldStatus: OldMessage['status']): 'success' | 'error' | 'paused' { + switch (oldStatus) { + case 'error': + return 'error' + case 'paused': + return 'paused' + case 'success': + case 'sending': + case 'pending': + case 'searching': + case 'processing': + default: + // All transient states are treated as success for stored messages + // If a message was in a transient state during export, it completed + return 'success' + } +} + +/** + * Merge old usage and metrics into new MessageStats + * + * The old system stored token usage and performance metrics in separate objects. + * The new schema combines them into a single stats object. 
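+ * + * A worked example with hypothetical values: + * mergeStats({ prompt_tokens: 12, completion_tokens: 30 }, { time_first_token_millsec: 180 }) + * returns { promptTokens: 12, completionTokens: 30, timeFirstTokenMs: 180 }.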
+ * + * @param usage - Token usage data from old message + * @param metrics - Performance metrics from old message + * @returns Combined MessageStats or null if no data + * + * ## Field Mapping: + * | Source | Target | + * |--------|--------| + * | usage.prompt_tokens | promptTokens | + * | usage.completion_tokens | completionTokens | + * | usage.total_tokens | totalTokens | + * | usage.thoughts_tokens | thoughtsTokens | + * | usage.cost | cost | + * | metrics.time_first_token_millsec | timeFirstTokenMs | + * | metrics.time_completion_millsec | timeCompletionMs | + * | metrics.time_thinking_millsec | timeThinkingMs | + */ +export function mergeStats(usage?: OldUsage, metrics?: OldMetrics): MessageStats | null { + if (!usage && !metrics) return null + + const stats: MessageStats = {} + + // Token usage + if (usage) { + if (usage.prompt_tokens !== undefined) stats.promptTokens = usage.prompt_tokens + if (usage.completion_tokens !== undefined) stats.completionTokens = usage.completion_tokens + if (usage.total_tokens !== undefined) stats.totalTokens = usage.total_tokens + if (usage.thoughts_tokens !== undefined) stats.thoughtsTokens = usage.thoughts_tokens + if (usage.cost !== undefined) stats.cost = usage.cost + } + + // Performance metrics + if (metrics) { + if (metrics.time_first_token_millsec !== undefined) stats.timeFirstTokenMs = metrics.time_first_token_millsec + if (metrics.time_completion_millsec !== undefined) stats.timeCompletionMs = metrics.time_completion_millsec + if (metrics.time_thinking_millsec !== undefined) stats.timeThinkingMs = metrics.time_thinking_millsec + } + + // Return null if no data was actually added + return Object.keys(stats).length > 0 ? stats : null +} + +// ============================================================================ +// Block Transformation Functions +// ============================================================================ + +/** + * Transform old blocks to new format and extract citation references + * + * This function: + * 1. Converts each old block to new format (removing id, messageId, status) + * 2. Extracts CitationMessageBlocks and converts to ContentReference[] + * 3. 
Extracts searchable text from text-based blocks + * + * @param oldBlocks - Array of old blocks from message_blocks table + * @returns Object containing: + * - dataBlocks: Transformed blocks (excluding CitationBlocks) + * - citationReferences: Extracted citation references + * - searchableText: Combined searchable text + * + * ## Block Type Mapping: + * | Old Type | New Type | Notes | + * |----------|----------|-------| + * | main_text | MainTextBlock | Direct, references added later | + * | thinking | ThinkingBlock | thinking_millsec → thinkingMs | + * | translation | TranslationBlock | Direct copy | + * | code | CodeBlock | Direct copy | + * | image | ImageBlock | file.id → fileId | + * | file | FileBlock | file.id → fileId | + * | video | VideoBlock | Direct copy | + * | tool | ToolBlock | Direct copy | + * | citation | (removed) | Converted to MainTextBlock.references | + * | error | ErrorBlock | Direct copy | + * | compact | CompactBlock | Direct copy | + * | unknown | (skipped) | Placeholder blocks are dropped | + */ +export function transformBlocks(oldBlocks: OldBlock[]): { + dataBlocks: MessageDataBlock[] + citationReferences: ContentReference[] + searchableText: string +} { + const dataBlocks: MessageDataBlock[] = [] + const citationReferences: ContentReference[] = [] + const searchableTexts: string[] = [] + + for (const oldBlock of oldBlocks) { + const transformed = transformSingleBlock(oldBlock) + + if (transformed.block) { + dataBlocks.push(transformed.block) + } + + if (transformed.citations) { + citationReferences.push(...transformed.citations) + } + + if (transformed.searchableText) { + searchableTexts.push(transformed.searchableText) + } + } + + return { + dataBlocks, + citationReferences, + searchableText: searchableTexts.join('\n') + } +} + +/** + * Transform a single old block to new format + * + * @param oldBlock - Single old block + * @returns Transformed block and extracted data + */ +function transformSingleBlock(oldBlock: OldBlock): { + block: MessageDataBlock | null + citations: ContentReference[] | null + searchableText: string | null +} { + const baseFields = { + createdAt: parseTimestamp(oldBlock.createdAt), + updatedAt: oldBlock.updatedAt ? 
parseTimestamp(oldBlock.updatedAt) : undefined, + metadata: oldBlock.metadata, + error: oldBlock.error as MessageDataBlock['error'] + } + + switch (oldBlock.type) { + case 'main_text': { + const block = oldBlock as OldMainTextBlock + return { + block: { + type: 'main_text' as BlockType.MAIN_TEXT, + content: block.content, + ...baseFields + // knowledgeBaseIds and citationReferences are intentionally dropped + // References will be added from CitationBlocks and mentions + } as MainTextBlock, + citations: null, + searchableText: block.content + } + } + + case 'thinking': { + const block = oldBlock as OldThinkingBlock + return { + block: { + type: 'thinking' as BlockType.THINKING, + content: block.content, + thinkingMs: block.thinking_millsec, // Field rename + ...baseFields + } as ThinkingBlock, + citations: null, + searchableText: block.content + } + } + + case 'translation': { + const block = oldBlock as OldTranslationBlock + return { + block: { + type: 'translation' as BlockType.TRANSLATION, + content: block.content, + sourceBlockId: block.sourceBlockId, + sourceLanguage: block.sourceLanguage, + targetLanguage: block.targetLanguage, + ...baseFields + } as TranslationBlock, + citations: null, + searchableText: block.content + } + } + + case 'code': { + const block = oldBlock as OldCodeBlock + return { + block: { + type: 'code' as BlockType.CODE, + content: block.content, + language: block.language, + ...baseFields + } as CodeBlock, + citations: null, + searchableText: block.content + } + } + + case 'image': { + const block = oldBlock as OldImageBlock + return { + block: { + type: 'image' as BlockType.IMAGE, + url: block.url, + fileId: block.file?.id, // file.id → fileId + ...baseFields + } as ImageBlock, + citations: null, + searchableText: null + } + } + + case 'file': { + const block = oldBlock as OldFileBlock + return { + block: { + type: 'file' as BlockType.FILE, + fileId: block.file.id, // file.id → fileId + ...baseFields + } as FileBlock, + citations: null, + searchableText: null + } + } + + case 'video': { + const block = oldBlock as OldVideoBlock + return { + block: { + type: 'video' as BlockType.VIDEO, + url: block.url, + filePath: block.filePath, + ...baseFields + } as VideoBlock, + citations: null, + searchableText: null + } + } + + case 'tool': { + const block = oldBlock as OldToolBlock + return { + block: { + type: 'tool' as BlockType.TOOL, + toolId: block.toolId, + toolName: block.toolName, + arguments: block.arguments, + content: block.content, + ...baseFields + } as ToolBlock, + citations: null, + searchableText: null + } + } + + case 'citation': { + // CitationBlocks are NOT converted to blocks + // Instead, their content is extracted as ContentReferences + const block = oldBlock as OldCitationBlock + const citations = extractCitationReferences(block) + return { + block: null, // No block output + citations, + searchableText: null + } + } + + case 'error': { + return { + block: { + type: 'error' as BlockType.ERROR, + ...baseFields + } as ErrorBlock, + citations: null, + searchableText: null + } + } + + case 'compact': { + const block = oldBlock as OldCompactBlock + return { + block: { + type: 'compact' as BlockType.COMPACT, + content: block.content, + compactedContent: block.compactedContent, + ...baseFields + } as CompactBlock, + citations: null, + searchableText: block.content + } + } + + case 'unknown': + default: + // Skip unknown/placeholder blocks + return { + block: null, + citations: null, + searchableText: null + } + } +} + +/** + * Extract ContentReferences from 
old CitationMessageBlock + * + * Old CitationBlocks contain three types of citations: + * - response (web search results) → WebCitationReference + * - knowledge (knowledge base refs) → KnowledgeCitationReference + * - memories (memory items) → MemoryCitationReference + * + * @param citationBlock - Old CitationMessageBlock + * @returns Array of ContentReferences + */ +export function extractCitationReferences(citationBlock: OldCitationBlock): ContentReference[] { + const references: ContentReference[] = [] + + // Web search citations + if (citationBlock.response) { + references.push({ + category: 'citation' as ReferenceCategory.CITATION, + citationType: 'web' as CitationType.WEB, + content: { + results: citationBlock.response.results, + source: citationBlock.response.source + } + } as CitationReference) + } + + // Knowledge base citations + if (citationBlock.knowledge && citationBlock.knowledge.length > 0) { + references.push({ + category: 'citation' as ReferenceCategory.CITATION, + citationType: 'knowledge' as CitationType.KNOWLEDGE, + content: citationBlock.knowledge.map((k) => ({ + id: k.id, + content: k.content, + sourceUrl: k.sourceUrl, + type: k.type, + file: k.file, + metadata: k.metadata + })) + } as CitationReference) + } + + // Memory citations + if (citationBlock.memories && citationBlock.memories.length > 0) { + references.push({ + category: 'citation' as ReferenceCategory.CITATION, + citationType: 'memory' as CitationType.MEMORY, + content: citationBlock.memories.map((m) => ({ + id: m.id, + memory: m.memory, + hash: m.hash, + createdAt: m.createdAt, + updatedAt: m.updatedAt, + score: m.score, + metadata: m.metadata + })) + } as CitationReference) + } + + return references +} + +/** + * Transform old mentions to MentionReferences + * + * Old system stored @mentions as a Model[] array on the message. + * New system stores them as MentionReference[] in MainTextBlock.references. + * + * @param mentions - Array of mentioned models from old message + * @returns Array of MentionReferences + * + * ## Transformation: + * | Old Field | New Field | + * |-----------|-----------| + * | model.id | modelId | + * | model.name | displayName | + */ +export function transformMentions(mentions?: OldModel[]): MentionReference[] { + if (!mentions || mentions.length === 0) return [] + + return mentions.map((model) => ({ + category: 'mention' as ReferenceCategory.MENTION, + modelId: model.id, + displayName: model.name + })) +} + +// ============================================================================ +// Tree Building Functions +// ============================================================================ + +/** + * Build message tree structure from linear message array + * + * The old system stores messages in a linear array. The new system uses + * a tree structure with parentId for navigation. + * + * ## Algorithm: + * 1. Process messages in array order (which is the conversation order) + * 2. 
For each message: + * - If it's a user message or first message, parent is the previous message + * - If it's an assistant message with askId, link to that user message + * - If multiple messages share same askId, they form a siblings group + * + * @param messages - Messages in array order from old topic + * @returns Map of messageId → { parentId, siblingsGroupId } + * + * ## Example: + * ``` + * Input: [u1, a1, u2, a2(askId=u2), a3(askId=u2, foldSelected), a4(askId=u2), u3] + * + * Output: + * u1: { parentId: null, siblingsGroupId: 0 } + * a1: { parentId: 'u1', siblingsGroupId: 0 } + * u2: { parentId: 'a1', siblingsGroupId: 0 } + * a2: { parentId: 'u2', siblingsGroupId: 1 } // Multi-model group + * a3: { parentId: 'u2', siblingsGroupId: 1 } // Selected one + * a4: { parentId: 'u2', siblingsGroupId: 1 } + * u3: { parentId: 'a3', siblingsGroupId: 0 } // Links to foldSelected + * ``` + */ +export function buildMessageTree( + messages: OldMessage[] +): Map<string, { parentId: string | null; siblingsGroupId: number }> { + const result = new Map<string, { parentId: string | null; siblingsGroupId: number }>() + + if (messages.length === 0) return result + + // Track askId → siblingsGroupId mapping + // Each unique askId with multiple responses gets a unique siblingsGroupId + const askIdToGroupId = new Map<string, number>() + const askIdCounts = new Map<string, number>() + + // First pass: count messages per askId to identify multi-model responses + for (const msg of messages) { + if (msg.askId) { + askIdCounts.set(msg.askId, (askIdCounts.get(msg.askId) || 0) + 1) + } + } + + // Assign group IDs to askIds with multiple responses + let nextGroupId = 1 + for (const [askId, count] of askIdCounts) { + if (count > 1) { + askIdToGroupId.set(askId, nextGroupId++) + } + } + + // Second pass: build parent/sibling relationships + let previousMessageId: string | null = null + let lastNonGroupMessageId: string | null = null // Last message not in a group, for linking subsequent user messages + + for (let i = 0; i < messages.length; i++) { + const msg = messages[i] + let parentId: string | null = null + let siblingsGroupId = 0 + + if (msg.askId && askIdToGroupId.has(msg.askId)) { + // This is part of a multi-model response group + parentId = msg.askId // Parent is the user message + siblingsGroupId = askIdToGroupId.get(msg.askId)! + + // If this is the selected response, update lastNonGroupMessageId for subsequent user messages + if (msg.foldSelected) { + lastNonGroupMessageId = msg.id + } + } else if (msg.role === 'user' && lastNonGroupMessageId) { + // User message after a multi-model group links to the selected response + parentId = lastNonGroupMessageId + lastNonGroupMessageId = null + } else { + // Normal sequential message - parent is previous message + parentId = previousMessageId + } + + result.set(msg.id, { parentId, siblingsGroupId }) + + // Update tracking for next iteration + previousMessageId = msg.id + + // Update lastNonGroupMessageId for non-group messages + if (siblingsGroupId === 0) { + lastNonGroupMessageId = msg.id + } + } + + return result +} + +/** + * Find the activeNodeId for a topic + * + * The activeNodeId should be the last message in the main conversation thread. + * For multi-model responses, it should be the foldSelected one. 
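+ * + * Hypothetical example: for [u1, a1(askId=u1, foldSelected), a2(askId=u1)], the last message a2 + * carries an askId, and a1 is the foldSelected response for that askId, so 'a1' is returned + * rather than 'a2'.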
+ * + * @param messages - Messages in array order + * @returns The ID of the last message (or foldSelected if applicable) + */ +export function findActiveNodeId(messages: OldMessage[]): string | null { + if (messages.length === 0) return null + + // Find the last message + // If it's part of a multi-model group, find the foldSelected one + const lastMsg = messages[messages.length - 1] + + if (lastMsg.askId) { + // Check if there's a foldSelected message with the same askId + const selectedMsg = messages.find((m) => m.askId === lastMsg.askId && m.foldSelected) + if (selectedMsg) return selectedMsg.id + } + + return lastMsg.id +} + +// ============================================================================ +// Utility Functions +// ============================================================================ + +/** + * Parse ISO timestamp string to Unix timestamp (milliseconds) + * + * @param isoString - ISO 8601 timestamp string or undefined + * @returns Unix timestamp in milliseconds + */ +export function parseTimestamp(isoString: string | undefined): number { + if (!isoString) return Date.now() + + const parsed = new Date(isoString).getTime() + return isNaN(parsed) ? Date.now() : parsed +} + +/** + * Build block lookup map from message_blocks table + * + * Creates a Map of blockId → block for fast lookup during message transformation. + * + * @param blocks - All blocks from message_blocks table + * @returns Map for O(1) block lookup + */ +export function buildBlockLookup(blocks: OldBlock[]): Map<string, OldBlock> { + const lookup = new Map<string, OldBlock>() + for (const block of blocks) { + lookup.set(block.id, block) + } + return lookup +} + +/** + * Resolve block IDs to actual block data + * + * @param blockIds - Array of block IDs from message.blocks + * @param blockLookup - Map of blockId → block + * @returns Array of resolved blocks (missing blocks are skipped) + */ +export function resolveBlocks(blockIds: string[], blockLookup: Map<string, OldBlock>): OldBlock[] { + const resolved: OldBlock[] = [] + for (const id of blockIds) { + const block = blockLookup.get(id) + if (block) { + resolved.push(block) + } + } + return resolved +} diff --git a/src/main/data/services/MessageService.ts b/src/main/data/services/MessageService.ts new file mode 100644 index 0000000000..1dfb7f378e --- /dev/null +++ b/src/main/data/services/MessageService.ts @@ -0,0 +1,815 @@ +/** + * Message Service - handles message CRUD and tree operations + * + * Provides business logic for: + * - Tree visualization queries + * - Branch message queries with pagination + * - Message CRUD with tree structure maintenance + * - Cascade delete and reparenting + */ + +import { dbService } from '@data/db/DbService' +import { messageTable } from '@data/db/schemas/message' +import { topicTable } from '@data/db/schemas/topic' +import { loggerService } from '@logger' +import { DataApiErrorFactory } from '@shared/data/api' +import type { + ActiveNodeStrategy, + CreateMessageDto, + DeleteMessageResponse, + UpdateMessageDto +} from '@shared/data/api/schemas/messages' +import type { + BranchMessage, + BranchMessagesResponse, + Message, + SiblingsGroup, + TreeNode, + TreeResponse +} from '@shared/data/types/message' +import { and, eq, inArray, isNull, or, sql } from 'drizzle-orm' + +const logger = loggerService.withContext('DataApi:MessageService') + +/** + * Preview length for tree nodes + */ +const PREVIEW_LENGTH = 50 + +/** + * Default pagination limit + */ +const DEFAULT_LIMIT = 20 + +/** + * Convert database row to Message entity + */ +function rowToMessage(row: typeof 
messageTable.$inferSelect): Message { + return { + id: row.id, + topicId: row.topicId, + parentId: row.parentId, + role: row.role as Message['role'], + data: row.data, + searchableText: row.searchableText, + status: row.status as Message['status'], + siblingsGroupId: row.siblingsGroupId ?? 0, + assistantId: row.assistantId, + assistantMeta: row.assistantMeta, + modelId: row.modelId, + modelMeta: row.modelMeta, + traceId: row.traceId, + stats: row.stats, + createdAt: row.createdAt ? new Date(row.createdAt).toISOString() : new Date().toISOString(), + updatedAt: row.updatedAt ? new Date(row.updatedAt).toISOString() : new Date().toISOString() + } +} + +/** + * Extract preview text from message data + */ +function extractPreview(message: Message): string { + const blocks = message.data?.blocks || [] + for (const block of blocks) { + if ('content' in block && typeof block.content === 'string') { + const text = block.content.trim() + if (text.length > 0) { + return text.length > PREVIEW_LENGTH ? text.substring(0, PREVIEW_LENGTH) + '...' : text + } + } + } + return '' +} + +/** + * Convert Message to TreeNode + */ +function messageToTreeNode(message: Message, hasChildren: boolean): TreeNode { + return { + id: message.id, + parentId: message.parentId, + role: message.role === 'system' ? 'assistant' : message.role, + preview: extractPreview(message), + modelId: message.modelId, + modelMeta: message.modelMeta, + status: message.status, + createdAt: message.createdAt, + hasChildren + } +} + +export class MessageService { + private static instance: MessageService + + private constructor() {} + + public static getInstance(): MessageService { + if (!MessageService.instance) { + MessageService.instance = new MessageService() + } + return MessageService.instance + } + + /** + * Get tree structure for visualization + * + * Optimized to avoid loading all messages: + * 1. Uses CTE to get active path (single query) + * 2. Uses CTE to get tree nodes within depth limit (single query) + * 3. Fetches additional nodes for active path if beyond depth limit + */ + async getTree( + topicId: string, + options: { rootId?: string; nodeId?: string; depth?: number } = {} + ): Promise<TreeResponse> { + const db = dbService.getDb() + const { depth = 1 } = options + + // Get topic to verify existence and get activeNodeId + const [topic] = await db.select().from(topicTable).where(eq(topicTable.id, topicId)).limit(1) + + if (!topic) { + throw DataApiErrorFactory.notFound('Topic', topicId) + } + + const activeNodeId = options.nodeId || topic.activeNodeId + + // Find root node if not specified + let rootId = options.rootId + if (!rootId) { + const [root] = await db + .select({ id: messageTable.id }) + .from(messageTable) + .where(and(eq(messageTable.topicId, topicId), sql`${messageTable.parentId} IS NULL`)) + .limit(1) + rootId = root?.id + } + + if (!rootId) { + return { nodes: [], siblingsGroups: [], activeNodeId: null } + } + + // Build active path via CTE (single query) + const activePath = new Set<string>() + if (activeNodeId) { + const pathRows = await db.all<{ id: string }>(sql` + WITH RECURSIVE path AS ( + SELECT id, parent_id FROM message WHERE id = ${activeNodeId} + UNION ALL + SELECT m.id, m.parent_id FROM message m + INNER JOIN path p ON m.id = p.parent_id + ) + SELECT id FROM path + `) + pathRows.forEach((r) => activePath.add(r.id)) + } + + // Get tree with depth limit via CTE + // Use a large depth for unlimited (-1) + const maxDepth = depth === -1 ? 
999 : depth + + const treeRows = await db.all<typeof messageTable.$inferSelect & { tree_depth: number }>(sql` + WITH RECURSIVE tree AS ( + SELECT *, 0 as tree_depth FROM message WHERE id = ${rootId} + UNION ALL + SELECT m.*, t.tree_depth + 1 FROM message m + INNER JOIN tree t ON m.parent_id = t.id + WHERE t.tree_depth < ${maxDepth} + ) + SELECT * FROM tree + `) + + // Also fetch active path nodes that might be beyond depth limit + const treeNodeIds = new Set(treeRows.map((r) => r.id)) + const missingActivePathIds = [...activePath].filter((id) => !treeNodeIds.has(id)) + + if (missingActivePathIds.length > 0) { + const additionalRows = await db.select().from(messageTable).where(inArray(messageTable.id, missingActivePathIds)) + treeRows.push(...additionalRows.map((r) => ({ ...r, tree_depth: maxDepth + 1 }))) + } + + // Also need children of active path nodes for proper tree building + // Get all children of active path nodes that we haven't loaded yet + const activePathArray = [...activePath] + if (activePathArray.length > 0 && treeNodeIds.size > 0) { + const childrenRows = await db + .select() + .from(messageTable) + .where( + and( + inArray(messageTable.parentId, activePathArray), + sql`${messageTable.id} NOT IN (${sql.join( + [...treeNodeIds].map((id) => sql`${id}`), + sql`, ` + )})` + ) + ) + + for (const row of childrenRows) { + if (!treeNodeIds.has(row.id)) { + treeRows.push({ ...row, tree_depth: maxDepth + 1 }) + treeNodeIds.add(row.id) + } + } + } else if (activePathArray.length > 0) { + // No tree nodes loaded yet, just get all children of active path + const childrenRows = await db.select().from(messageTable).where(inArray(messageTable.parentId, activePathArray)) + + for (const row of childrenRows) { + if (!treeNodeIds.has(row.id)) { + treeRows.push({ ...row, tree_depth: maxDepth + 1 }) + treeNodeIds.add(row.id) + } + } + } + + if (treeRows.length === 0) { + return { nodes: [], siblingsGroups: [], activeNodeId: null } + } + + // Build maps for tree processing + const messagesById = new Map<string, Message>() + const childrenMap = new Map<string, string[]>() + const depthMap = new Map<string, number>() + + for (const row of treeRows) { + const message = rowToMessage(row) + messagesById.set(message.id, message) + depthMap.set(message.id, row.tree_depth) + + const parentId = message.parentId || 'root' + if (!childrenMap.has(parentId)) { + childrenMap.set(parentId, []) + } + childrenMap.get(parentId)!.push(message.id) + } + + // Collect nodes based on depth + const resultNodes: TreeNode[] = [] + const siblingsGroups: SiblingsGroup[] = [] + const visitedGroups = new Set<string>() + + const collectNodes = (nodeId: string, currentDepth: number, isOnActivePath: boolean) => { + const message = messagesById.get(nodeId) + if (!message) return + + const children = childrenMap.get(nodeId) || [] + const hasChildren = children.length > 0 + + // Check if this message is part of a siblings group + if (message.siblingsGroupId !== 0) { + const groupKey = `${message.parentId}-${message.siblingsGroupId}` + if (!visitedGroups.has(groupKey)) { + visitedGroups.add(groupKey) + + // Find all siblings in this group + const parentChildren = childrenMap.get(message.parentId || 'root') || [] + const groupMembers = parentChildren + .map((id) => messagesById.get(id)!) 
+ .filter((m) => m && m.siblingsGroupId === message.siblingsGroupId) + + if (groupMembers.length > 1) { + siblingsGroups.push({ + parentId: message.parentId!, + siblingsGroupId: message.siblingsGroupId, + nodes: groupMembers.map((m) => { + const memberChildren = childrenMap.get(m.id) || [] + const node = messageToTreeNode(m, memberChildren.length > 0) + const { parentId: _parentId, ...rest } = node + void _parentId // Intentionally unused - removing parentId from TreeNode for SiblingsGroup + return rest + }) + }) + } else { + // Single member, add as regular node + resultNodes.push(messageToTreeNode(message, hasChildren)) + } + } + } else { + resultNodes.push(messageToTreeNode(message, hasChildren)) + } + + // Recurse to children + const shouldExpand = isOnActivePath || (depth === -1 ? true : currentDepth < depth) + if (shouldExpand) { + for (const childId of children) { + const childOnPath = activePath.has(childId) + collectNodes(childId, isOnActivePath ? 0 : currentDepth + 1, childOnPath) + } + } + } + + // Start from root + collectNodes(rootId, 0, activePath.has(rootId)) + + return { + nodes: resultNodes, + siblingsGroups, + activeNodeId + } + } + + /** + * Get branch messages for conversation view + * + * Optimized implementation using recursive CTE to fetch only the path + * from nodeId to root, avoiding loading all messages for large topics. + * Siblings are batch-queried in a single additional query. + * + * Uses "before cursor" pagination semantics: + * - cursor: Message ID marking the pagination boundary (exclusive) + * - Returns messages BEFORE the cursor (towards root) + * - The cursor message itself is NOT included + * - nextCursor points to the oldest message in current batch + * + * Example flow: + * 1. First request (no cursor) → returns msg80-99, nextCursor=msg80.id + * 2. 
Second request (cursor=msg80.id) → returns msg60-79, nextCursor=msg60.id + */ + async getBranchMessages( + topicId: string, + options: { nodeId?: string; cursor?: string; limit?: number; includeSiblings?: boolean } = {} + ): Promise<BranchMessagesResponse> { + const db = dbService.getDb() + const { cursor, limit = DEFAULT_LIMIT, includeSiblings = true } = options + + // Get topic + const [topic] = await db.select().from(topicTable).where(eq(topicTable.id, topicId)).limit(1) + + if (!topic) { + throw DataApiErrorFactory.notFound('Topic', topicId) + } + + const nodeId = options.nodeId || topic.activeNodeId + + // Return empty if no active node + if (!nodeId) { + return { items: [], nextCursor: undefined, activeNodeId: null } + } + + // Use recursive CTE to get path from nodeId to root (single query) + const pathMessages = await db.all<typeof messageTable.$inferSelect>(sql` + WITH RECURSIVE path AS ( + SELECT * FROM message WHERE id = ${nodeId} + UNION ALL + SELECT m.* FROM message m + INNER JOIN path p ON m.id = p.parent_id + ) + SELECT * FROM path + `) + + if (pathMessages.length === 0) { + throw DataApiErrorFactory.notFound('Message', nodeId) + } + + // Reverse to get root->nodeId order + const fullPath = pathMessages.reverse() + + // Apply pagination + let startIndex = 0 + let endIndex = fullPath.length + + if (cursor) { + const cursorIndex = fullPath.findIndex((m) => m.id === cursor) + if (cursorIndex === -1) { + throw DataApiErrorFactory.notFound('Message (cursor)', cursor) + } + startIndex = Math.max(0, cursorIndex - limit) + endIndex = cursorIndex + } else { + startIndex = Math.max(0, fullPath.length - limit) + } + + const paginatedPath = fullPath.slice(startIndex, endIndex) + + // Calculate nextCursor: if there are more historical messages + const nextCursor = startIndex > 0 ? fullPath[startIndex].id : undefined + + // Build result with optional siblings + const result: BranchMessage[] = [] + + if (includeSiblings) { + // Collect unique (parentId, siblingsGroupId) pairs that need siblings + const uniqueGroups = new Set<string>() + const groupsToQuery: Array<{ parentId: string; siblingsGroupId: number }> = [] + + for (const msg of paginatedPath) { + if (msg.siblingsGroupId && msg.siblingsGroupId !== 0 && msg.parentId) { + const key = `${msg.parentId}-${msg.siblingsGroupId}` + if (!uniqueGroups.has(key)) { + uniqueGroups.add(key) + groupsToQuery.push({ parentId: msg.parentId, siblingsGroupId: msg.siblingsGroupId }) + } + } + } + + // Batch query all siblings if needed + const siblingsMap = new Map<string, Message[]>() + + if (groupsToQuery.length > 0) { + // Build OR conditions for batch query + const orConditions = groupsToQuery.map((g) => + and(eq(messageTable.parentId, g.parentId), eq(messageTable.siblingsGroupId, g.siblingsGroupId)) + ) + + const siblingsRows = await db + .select() + .from(messageTable) + .where(or(...orConditions)) + + // Group results by parentId-siblingsGroupId + for (const row of siblingsRows) { + const key = `${row.parentId}-${row.siblingsGroupId}` + if (!siblingsMap.has(key)) siblingsMap.set(key, []) + siblingsMap.get(key)!.push(rowToMessage(row)) + } + } + + // Build result with siblings from map + for (const msg of paginatedPath) { + const message = rowToMessage(msg) + let siblingsGroup: Message[] | undefined + + if (msg.siblingsGroupId !== 0 && msg.parentId) { + const key = `${msg.parentId}-${msg.siblingsGroupId}` + const group = siblingsMap.get(key) + if (group && group.length > 1) { + siblingsGroup = group.filter((m) => m.id !== message.id) + } + } + + result.push({ message, siblingsGroup }) + } + } else { + // No siblings needed, just 
map messages + for (const msg of paginatedPath) { + result.push({ message: rowToMessage(msg) }) + } + } + + return { + items: result, + nextCursor, + activeNodeId: topic.activeNodeId + } + } + + /** + * Get a single message by ID + */ + async getById(id: string): Promise<Message> { + const db = dbService.getDb() + + const [row] = await db.select().from(messageTable).where(eq(messageTable.id, id)).limit(1) + + if (!row) { + throw DataApiErrorFactory.notFound('Message', id) + } + + return rowToMessage(row) + } + + /** + * Create a new message + * + * Uses transaction to ensure atomicity of: + * - Topic existence validation + * - Parent message validation (if specified) + * - Message insertion + * - Topic activeNodeId update + */ + async create(topicId: string, dto: CreateMessageDto): Promise<Message> { + const db = dbService.getDb() + + return await db.transaction(async (tx) => { + // Step 1: Verify topic exists and fetch its current state. + // We need the topic to check activeNodeId for parentId auto-resolution. + const [topic] = await tx.select().from(topicTable).where(eq(topicTable.id, topicId)).limit(1) + + if (!topic) { + throw DataApiErrorFactory.notFound('Topic', topicId) + } + + // Step 2: Resolve parentId based on the three possible input states: + // - undefined: auto-resolve based on topic state + // - null: explicitly create as root (must validate uniqueness) + // - string: use provided ID (must validate existence and ownership) + let resolvedParentId: string | null + + if (dto.parentId === undefined) { + // Auto-resolution mode: Determine parentId based on topic's current state. + // This provides convenience for callers who want to "append" to the conversation + // without needing to know the tree structure. + + // Check if topic has any existing messages by querying for at least one. + const [existingMessage] = await tx + .select({ id: messageTable.id }) + .from(messageTable) + .where(eq(messageTable.topicId, topicId)) + .limit(1) + + if (!existingMessage) { + // Topic is empty: This will be the first message, so it becomes the root. + // Root messages have parentId = null. + resolvedParentId = null + } else if (topic.activeNodeId) { + // Topic has messages and an active node: Attach new message as child of activeNodeId. + // This is the typical case for continuing a conversation. + resolvedParentId = topic.activeNodeId + } else { + // Topic has messages but no activeNodeId: This is an ambiguous state. + // We cannot auto-resolve because we don't know where in the tree to attach. + // Require explicit parentId from caller to resolve the ambiguity. + throw DataApiErrorFactory.invalidOperation( + 'create message', + 'Topic has messages but no activeNodeId. Please specify parentId explicitly.' + ) + } + } else if (dto.parentId === null) { + // Explicit root creation: Caller wants to create a root message. + // Each topic can only have one root message (parentId = null). + // Check if a root already exists to enforce this constraint. + + const [existingRoot] = await tx + .select({ id: messageTable.id }) + .from(messageTable) + .where(and(eq(messageTable.topicId, topicId), isNull(messageTable.parentId))) + .limit(1) + + if (existingRoot) { + // Root already exists: Cannot create another root message. + // This enforces the single-root tree structure constraint. + throw DataApiErrorFactory.invalidOperation('create root message', 'Topic already has a root message') + } + resolvedParentId = null + } else { + // Explicit parent ID provided: Validate the parent exists and belongs to this topic. 
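+ // For example (hypothetical call): create(topicId, { role: 'user', data, parentId: 'm42' }) + // succeeds only when message m42 exists and m42.topicId === topicId.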
+ // This ensures referential integrity within the message tree. + + const [parent] = await tx.select().from(messageTable).where(eq(messageTable.id, dto.parentId)).limit(1) + + if (!parent) { + // Parent message not found: Cannot attach to non-existent message. + throw DataApiErrorFactory.notFound('Message', dto.parentId) + } + if (parent.topicId !== topicId) { + // Parent belongs to different topic: Cross-topic references are not allowed. + // Each topic's message tree must be self-contained. + throw DataApiErrorFactory.invalidOperation('create message', 'Parent message does not belong to this topic') + } + resolvedParentId = dto.parentId + } + + // Step 3: Insert the message using the resolved parentId. + const [row] = await tx + .insert(messageTable) + .values({ + topicId, + parentId: resolvedParentId, + role: dto.role, + data: dto.data, + status: dto.status ?? 'pending', + siblingsGroupId: dto.siblingsGroupId ?? 0, + assistantId: dto.assistantId, + assistantMeta: dto.assistantMeta, + modelId: dto.modelId, + modelMeta: dto.modelMeta, + traceId: dto.traceId, + stats: dto.stats + }) + .returning() + + // Update activeNodeId if setAsActive is not explicitly false + if (dto.setAsActive !== false) { + await tx.update(topicTable).set({ activeNodeId: row.id }).where(eq(topicTable.id, topicId)) + } + + logger.info('Created message', { id: row.id, topicId, role: dto.role, setAsActive: dto.setAsActive !== false }) + + return rowToMessage(row) + }) + } + + /** + * Update a message + * + * Uses transaction to ensure atomicity of validation and update. + * Cycle check is performed outside transaction as a read-only safety check. + */ + async update(id: string, dto: UpdateMessageDto): Promise<Message> { + const db = dbService.getDb() + + // Pre-transaction: Check for cycle if moving to new parent + // This is done outside transaction since getDescendantIds uses its own db context + // and cycle check is a safety check (worst case: reject valid operation) + if (dto.parentId !== undefined && dto.parentId !== null) { + const descendants = await this.getDescendantIds(id) + if (descendants.includes(dto.parentId)) { + throw DataApiErrorFactory.invalidOperation('move message', 'would create cycle') + } + } + + return await db.transaction(async (tx) => { + // Get existing message within transaction + const [existingRow] = await tx.select().from(messageTable).where(eq(messageTable.id, id)).limit(1) + + if (!existingRow) { + throw DataApiErrorFactory.notFound('Message', id) + } + + const existing = rowToMessage(existingRow) + + // Verify new parent exists if changing parent + if (dto.parentId !== undefined && dto.parentId !== existing.parentId && dto.parentId !== null) { + const [parent] = await tx.select().from(messageTable).where(eq(messageTable.id, dto.parentId)).limit(1) + + if (!parent) { + throw DataApiErrorFactory.notFound('Message', dto.parentId) + } + } + + // Build update object + const updates: Partial<typeof messageTable.$inferInsert> = {} + + if (dto.data !== undefined) updates.data = dto.data + if (dto.parentId !== undefined) updates.parentId = dto.parentId + if (dto.siblingsGroupId !== undefined) updates.siblingsGroupId = dto.siblingsGroupId + if (dto.status !== undefined) updates.status = dto.status + if (dto.traceId !== undefined) updates.traceId = dto.traceId + if (dto.stats !== undefined) updates.stats = dto.stats + + const [row] = await tx.update(messageTable).set(updates).where(eq(messageTable.id, id)).returning() + + logger.info('Updated message', { id, changes: Object.keys(dto) }) + + return rowToMessage(row) + }) + } + + /** + * 
Delete a message (hard delete) + * + * Supports two modes: + * - cascade=true: Delete the message and all its descendants + * - cascade=false: Delete only this message, reparent children to grandparent + * + * When the deleted message(s) include the topic's activeNodeId, it will be + * automatically updated based on activeNodeStrategy: + * - 'parent' (default): Sets activeNodeId to the deleted message's parent + * - 'clear': Sets activeNodeId to null + * + * All operations are performed within a transaction for consistency. + * + * @param id - Message ID to delete + * @param cascade - If true, delete descendants; if false, reparent children (default: false) + * @param activeNodeStrategy - Strategy for updating activeNodeId if affected (default: 'parent') + * @returns Deletion result including deletedIds, reparentedIds, and newActiveNodeId + * @throws NOT_FOUND if message doesn't exist + * @throws INVALID_OPERATION if deleting root without cascade=true + */ + async delete( + id: string, + cascade: boolean = false, + activeNodeStrategy: ActiveNodeStrategy = 'parent' + ): Promise<DeleteMessageResponse> { + const db = dbService.getDb() + + // Get the message + const message = await this.getById(id) + + // Get topic to check activeNodeId + const [topic] = await db.select().from(topicTable).where(eq(topicTable.id, message.topicId)).limit(1) + + if (!topic) { + throw DataApiErrorFactory.notFound('Topic', message.topicId) + } + + // Check if it's a root message + const isRoot = message.parentId === null + + if (isRoot && !cascade) { + throw DataApiErrorFactory.invalidOperation('delete root message', 'cascade=true required') + } + + // Get all descendant IDs before transaction (for cascade delete) + let descendantIds: string[] = [] + if (cascade) { + descendantIds = await this.getDescendantIds(id) + } + + // Use transaction for atomic delete + activeNodeId update + return await db.transaction(async (tx) => { + let deletedIds: string[] + let reparentedIds: string[] | undefined + let newActiveNodeId: string | null | undefined + + if (cascade) { + deletedIds = [id, ...descendantIds] + + // Check if activeNodeId is affected + if (topic.activeNodeId && deletedIds.includes(topic.activeNodeId)) { + newActiveNodeId = activeNodeStrategy === 'clear' ? null : message.parentId + } + + // Hard delete all + await tx.delete(messageTable).where(inArray(messageTable.id, deletedIds)) + + logger.info('Cascade deleted messages', { rootId: id, count: deletedIds.length }) + } else { + // Reparent children to this message's parent + const children = await tx + .select({ id: messageTable.id }) + .from(messageTable) + .where(eq(messageTable.parentId, id)) + + reparentedIds = children.map((c) => c.id) + + if (reparentedIds.length > 0) { + await tx + .update(messageTable) + .set({ parentId: message.parentId }) + .where(inArray(messageTable.id, reparentedIds)) + } + + deletedIds = [id] + + // Check if activeNodeId is affected + if (topic.activeNodeId === id) { + newActiveNodeId = activeNodeStrategy === 'clear' ? 
null : message.parentId + } + + // Hard delete this message + await tx.delete(messageTable).where(eq(messageTable.id, id)) + + logger.info('Deleted message with reparenting', { id, reparentedCount: reparentedIds.length }) + } + + // Update topic.activeNodeId if needed + if (newActiveNodeId !== undefined) { + await tx.update(topicTable).set({ activeNodeId: newActiveNodeId }).where(eq(topicTable.id, message.topicId)) + + logger.info('Updated topic activeNodeId after message deletion', { + topicId: message.topicId, + oldActiveNodeId: topic.activeNodeId, + newActiveNodeId + }) + } + + return { + deletedIds, + reparentedIds: reparentedIds?.length ? reparentedIds : undefined, + newActiveNodeId + } + }) + } + + /** + * Get all descendant IDs of a message + */ + private async getDescendantIds(id: string): Promise<string[]> { + const db = dbService.getDb() + + // Use recursive query to get all descendants + const result = await db.all<{ id: string }>(sql` + WITH RECURSIVE descendants AS ( + SELECT id FROM message WHERE parent_id = ${id} + UNION ALL + SELECT m.id FROM message m + INNER JOIN descendants d ON m.parent_id = d.id + ) + SELECT id FROM descendants + `) + + return result.map((r) => r.id) + } + + /** + * Get path from root to a node + * + * Uses recursive CTE to fetch all ancestors in a single query, + * avoiding N+1 query problem for deep message trees. + */ + async getPathToNode(nodeId: string): Promise<Message[]> { + const db = dbService.getDb() + + // Use recursive CTE to get all ancestors in one query + const result = await db.all<typeof messageTable.$inferSelect>(sql` + WITH RECURSIVE ancestors AS ( + SELECT * FROM message WHERE id = ${nodeId} + UNION ALL + SELECT m.* FROM message m + INNER JOIN ancestors a ON m.id = a.parent_id + ) + SELECT * FROM ancestors + `) + + if (result.length === 0) { + throw DataApiErrorFactory.notFound('Message', nodeId) + } + + // Result is from nodeId to root, reverse to get root to nodeId + return result.reverse().map(rowToMessage) + } +} + +export const messageService = MessageService.getInstance() diff --git a/src/main/data/services/TestService.ts b/src/main/data/services/TestService.ts index 1af016cf44..7e7b810eaf 100644 --- a/src/main/data/services/TestService.ts +++ b/src/main/data/services/TestService.ts @@ -1,6 +1,6 @@ import { loggerService } from '@logger' -const logger = loggerService.withContext('TestService') +const logger = loggerService.withContext('DataApi:TestService') /** * Test service for API testing scenarios diff --git a/src/main/data/services/TopicService.ts b/src/main/data/services/TopicService.ts new file mode 100644 index 0000000000..d30215d283 --- /dev/null +++ b/src/main/data/services/TopicService.ts @@ -0,0 +1,233 @@ +/** + * Topic Service - handles topic CRUD and branch switching + * + * Provides business logic for: + * - Topic CRUD operations + * - Fork from existing conversation + * - Active node switching + */ + +import { dbService } from '@data/db/DbService' +import { messageTable } from '@data/db/schemas/message' +import { topicTable } from '@data/db/schemas/topic' +import { loggerService } from '@logger' +import { DataApiErrorFactory } from '@shared/data/api' +import type { CreateTopicDto, UpdateTopicDto } from '@shared/data/api/schemas/topics' +import type { Topic } from '@shared/data/types/topic' +import { eq } from 'drizzle-orm' + +import { messageService } from './MessageService' + +const logger = loggerService.withContext('DataApi:TopicService') + +/** + * Convert database row to Topic entity + */ +function rowToTopic(row: typeof topicTable.$inferSelect): Topic { 
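+ // Timestamps are stored as epoch-millisecond numbers in SQLite; convert them to ISO strings + // for the Topic entity, falling back to the current time when a migrated row lacks them.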
return { + id: row.id, + name: row.name, + isNameManuallyEdited: row.isNameManuallyEdited ?? false, + assistantId: row.assistantId, + assistantMeta: row.assistantMeta, + prompt: row.prompt, + activeNodeId: row.activeNodeId, + groupId: row.groupId, + sortOrder: row.sortOrder ?? 0, + isPinned: row.isPinned ?? false, + pinnedOrder: row.pinnedOrder ?? 0, + createdAt: row.createdAt ? new Date(row.createdAt).toISOString() : new Date().toISOString(), + updatedAt: row.updatedAt ? new Date(row.updatedAt).toISOString() : new Date().toISOString() + } +} + +export class TopicService { + private static instance: TopicService + + private constructor() {} + + public static getInstance(): TopicService { + if (!TopicService.instance) { + TopicService.instance = new TopicService() + } + return TopicService.instance + } + + /** + * Get a topic by ID + */ + async getById(id: string): Promise { + const db = dbService.getDb() + + const [row] = await db.select().from(topicTable).where(eq(topicTable.id, id)).limit(1) + + if (!row) { + throw DataApiErrorFactory.notFound('Topic', id) + } + + return rowToTopic(row) + } + + /** + * Create a new topic + */ + async create(dto: CreateTopicDto): Promise { + const db = dbService.getDb() + + // If forking from existing node, copy the path + if (dto.sourceNodeId) { + // Verify source node exists + try { + await messageService.getById(dto.sourceNodeId) + } catch { + throw DataApiErrorFactory.notFound('Message', dto.sourceNodeId) + } + + // Get path from root to source node + const path = await messageService.getPathToNode(dto.sourceNodeId) + + // Create new topic first using returning() to get the id + const [topicRow] = await db + .insert(topicTable) + .values({ + name: dto.name, + assistantId: dto.assistantId, + assistantMeta: dto.assistantMeta, + prompt: dto.prompt, + groupId: dto.groupId + }) + .returning() + + const topicId = topicRow.id + + // Copy messages with new IDs using returning() + const idMapping = new Map() + let activeNodeId: string | null = null + + for (const message of path) { + const newParentId = message.parentId ? 
idMapping.get(message.parentId) || null : null + + const [messageRow] = await db + .insert(messageTable) + .values({ + topicId, + parentId: newParentId, + role: message.role, + data: message.data, + status: message.status, + siblingsGroupId: 0, // Simplify multi-model to normal node + assistantId: message.assistantId, + assistantMeta: message.assistantMeta, + modelId: message.modelId, + modelMeta: message.modelMeta, + traceId: null, + stats: null + }) + .returning() + + idMapping.set(message.id, messageRow.id) + activeNodeId = messageRow.id + } + + // Update topic with active node + await db.update(topicTable).set({ activeNodeId }).where(eq(topicTable.id, topicId)) + + logger.info('Created topic by forking', { + id: topicId, + sourceNodeId: dto.sourceNodeId, + messageCount: path.length + }) + + return this.getById(topicId) + } else { + // Create empty topic using returning() + const [row] = await db + .insert(topicTable) + .values({ + name: dto.name, + assistantId: dto.assistantId, + assistantMeta: dto.assistantMeta, + prompt: dto.prompt, + groupId: dto.groupId + }) + .returning() + + logger.info('Created empty topic', { id: row.id }) + + return rowToTopic(row) + } + } + + /** + * Update a topic + */ + async update(id: string, dto: UpdateTopicDto): Promise { + const db = dbService.getDb() + + // Verify topic exists + await this.getById(id) + + // Build update object + const updates: Partial = {} + + if (dto.name !== undefined) updates.name = dto.name + if (dto.isNameManuallyEdited !== undefined) updates.isNameManuallyEdited = dto.isNameManuallyEdited + if (dto.assistantId !== undefined) updates.assistantId = dto.assistantId + if (dto.assistantMeta !== undefined) updates.assistantMeta = dto.assistantMeta + if (dto.prompt !== undefined) updates.prompt = dto.prompt + if (dto.groupId !== undefined) updates.groupId = dto.groupId + if (dto.sortOrder !== undefined) updates.sortOrder = dto.sortOrder + if (dto.isPinned !== undefined) updates.isPinned = dto.isPinned + if (dto.pinnedOrder !== undefined) updates.pinnedOrder = dto.pinnedOrder + + const [row] = await db.update(topicTable).set(updates).where(eq(topicTable.id, id)).returning() + + logger.info('Updated topic', { id, changes: Object.keys(dto) }) + + return rowToTopic(row) + } + + /** + * Delete a topic and all its messages (hard delete) + */ + async delete(id: string): Promise { + const db = dbService.getDb() + + // Verify topic exists + await this.getById(id) + + // Hard delete all messages first (due to foreign key) + await db.delete(messageTable).where(eq(messageTable.topicId, id)) + + // Hard delete topic + await db.delete(topicTable).where(eq(topicTable.id, id)) + + logger.info('Deleted topic', { id }) + } + + /** + * Set the active node for a topic + */ + async setActiveNode(topicId: string, nodeId: string): Promise<{ activeNodeId: string }> { + const db = dbService.getDb() + + // Verify topic exists + await this.getById(topicId) + + // Verify node exists and belongs to this topic + const [message] = await db.select().from(messageTable).where(eq(messageTable.id, nodeId)).limit(1) + + if (!message || message.topicId !== topicId) { + throw DataApiErrorFactory.notFound('Message', nodeId) + } + + // Update active node + await db.update(topicTable).set({ activeNodeId: nodeId }).where(eq(topicTable.id, topicId)) + + logger.info('Set active node', { topicId, nodeId }) + + return { activeNodeId: nodeId } + } +} + +export const topicService = TopicService.getInstance() diff --git a/src/main/data/services/base/IBaseService.ts 
b/src/main/data/services/base/IBaseService.ts index 446de55716..d9b3a4b0be 100644 --- a/src/main/data/services/base/IBaseService.ts +++ b/src/main/data/services/base/IBaseService.ts @@ -1,4 +1,9 @@ -import type { PaginationParams, ServiceOptions } from '@shared/data/api/apiTypes' +import type { CursorPaginationParams, OffsetPaginationParams, ServiceOptions } from '@shared/data/api/apiTypes' + +/** + * Base pagination params for service layer (supports both modes) + */ +type BasePaginationParams = (OffsetPaginationParams | CursorPaginationParams) & Record /** * Standard service interface for data operations @@ -14,12 +19,12 @@ export interface IBaseService { * Find multiple entities with pagination */ findMany( - params: PaginationParams & Record, + params: BasePaginationParams, options?: ServiceOptions ): Promise<{ items: T[] - total: number - hasNext?: boolean + total?: number + page?: number nextCursor?: string }> @@ -68,12 +73,12 @@ export interface ISearchableService exten */ search( query: string, - params?: PaginationParams, + params?: BasePaginationParams, options?: ServiceOptions ): Promise<{ items: T[] - total: number - hasNext?: boolean + total?: number + page?: number nextCursor?: string }> } @@ -87,12 +92,12 @@ export interface IHierarchicalService diff --git a/src/main/index.ts b/src/main/index.ts index 0456797abd..6c692ce532 100644 --- a/src/main/index.ts +++ b/src/main/index.ts @@ -20,8 +20,10 @@ import { registerIpc } from './ipc' import { agentService } from './services/agents' import { apiServerService } from './services/ApiServerService' import { appMenuService } from './services/AppMenuService' -import { nodeTraceService } from './services/NodeTraceService' +import { lanTransferClientService } from './services/lanTransfer' import mcpService from './services/MCPService' +import { localTransferService } from './services/LocalTransferService' +import { nodeTraceService } from './services/NodeTraceService' import powerMonitorService from './services/PowerMonitorService' import { CHERRY_STUDIO_PROTOCOL, @@ -45,6 +47,7 @@ import { dataApiService } from '@data/DataApiService' import { cacheService } from '@data/CacheService' import { initWebviewHotkeys } from './services/WebviewService' import { runAsyncFunction } from './utils' +import { isOvmsSupported } from './services/OvmsManager' const logger = loggerService.withContext('MainEntry') @@ -132,10 +135,25 @@ if (!app.requestSingleInstanceLock()) { // initialization and is ready to create browser windows. // Some APIs can only be used after this event occurs. app.whenReady().then(async () => { + //TODO v2 Data Refactor: App Lifecycle Management + // This is the temporary solution for the data migration v2. + // We will refactor the app lifecycle management after the data migration v2 is stable. + // First of all, init & migrate the database - await dbService.init() - await dbService.migrateDb() - await dbService.migrateSeed('preference') + try { + await dbService.init() + await dbService.migrateDb() + await dbService.migrateSeed('preference') + } catch (error) { + logger.error('Failed to initialize database', error as Error) + //TODO for v2 testing only: + await dialog.showErrorBox( + 'Database Initialization Failed', + 'Before the official release of the alpha version, the database structure may change at any time. To maintain simplicity, the database migration files will be periodically reinitialized, which may cause the application to fail. 
If this occurs, please delete the cherrystudio.sqlite file located in the user data directory.' + ) + app.quit() + return + } // Data Migration v2 // Check if data migration is needed BEFORE creating any windows @@ -241,7 +259,8 @@ if (!app.requestSingleInstanceLock()) { }) registerShortcuts(mainWindow) - registerIpc(mainWindow, app) + await registerIpc(mainWindow, app) + localTransferService.startDiscovery({ resetList: true }) replaceDevtoolsFont(mainWindow) @@ -323,16 +342,29 @@ if (!app.requestSingleInstanceLock()) { if (selectionService) { selectionService.quit() } + + lanTransferClientService.dispose() + localTransferService.dispose() }) app.on('will-quit', async () => { // 简单的资源清理,不阻塞退出流程 + if (isOvmsSupported) { + const { ovmsManager } = await import('./services/OvmsManager') + if (ovmsManager) { + await ovmsManager.stopOvms() + } else { + logger.warn('Unexpected behavior: undefined ovmsManager, but OVMS should be supported.') + } + } + try { await mcpService.cleanup() await apiServerService.stop() } catch (error) { logger.warn('Error cleaning up MCP service:', error as Error) } + // finish the logger logger.finish() }) diff --git a/src/main/ipc.ts b/src/main/ipc.ts index 55fa3a17b0..7560d67460 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -8,10 +8,18 @@ import { loggerService } from '@logger' import { isLinux, isMac, isPortable, isWin } from '@main/constant' import { generateSignature } from '@main/integration/cherryai' import anthropicService from '@main/services/AnthropicService' -import { findGitBash, getBinaryPath, isBinaryExists, runInstallScript, validateGitBashPath } from '@main/utils/process' +import { + autoDiscoverGitBash, + getBinaryPath, + getGitBashPathInfo, + isBinaryExists, + runInstallScript, + validateGitBashPath +} from '@main/utils/process' import { handleZoomFactor } from '@main/utils/zoom' import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core' import { MIN_WINDOW_HEIGHT, MIN_WINDOW_WIDTH } from '@shared/config/constant' +import type { LocalTransferConnectPayload } from '@shared/config/types' import type { UpgradeChannel } from '@shared/data/preference/preferenceTypes' import { IpcChannel } from '@shared/IpcChannel' import type { @@ -43,6 +51,8 @@ import { ExportService } from './services/ExportService' import { fileStorage as fileManager } from './services/FileStorage' import FileService from './services/FileSystemService' import KnowledgeService from './services/KnowledgeService' +import { lanTransferClientService } from './services/lanTransfer' +import { localTransferService } from './services/LocalTransferService' import mcpService from './services/MCPService' import MemoryService from './services/memory/MemoryService' import { openTraceWindow, setTraceWindowTitle } from './services/NodeTraceService' @@ -50,7 +60,7 @@ import NotificationService from './services/NotificationService' import * as NutstoreService from './services/NutstoreService' import ObsidianVaultService from './services/ObsidianVaultService' import { ocrService } from './services/ocr/OcrService' -import OvmsManager from './services/OvmsManager' +import { isOvmsSupported } from './services/OvmsManager' import powerMonitorService from './services/PowerMonitorService' import { proxyManager } from './services/ProxyManager' import { pythonService } from './services/PythonService' @@ -73,7 +83,6 @@ import { } from './services/SpanCacheService' import storeSyncService from './services/StoreSyncService' import VertexAIService from './services/VertexAIService' -import 
WebSocketService from './services/WebSocketService' import { setOpenLinkExternal } from './services/WebviewService' import { windowService } from './services/WindowService' import { calculateDirectorySize, getResourcePath } from './utils' @@ -88,6 +97,7 @@ import { untildify } from './utils/file' import { updateAppDataConfig } from './utils/init' +import { getCpuName, getDeviceType, getHostname } from './utils/system' import { compress, decompress } from './utils/zip' const logger = loggerService.withContext('IPC') @@ -98,7 +108,6 @@ const obsidianVaultService = new ObsidianVaultService() const vertexAIService = VertexAIService.getInstance() const memoryService = MemoryService.getInstance() const dxtService = new DxtService() -const ovmsManager = new OvmsManager() const pluginService = PluginService.getInstance() function normalizeError(error: unknown): Error { @@ -112,7 +121,7 @@ function extractPluginError(error: unknown): PluginError | null { return null } -export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { +export async function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { const appUpdater = new AppUpdater() const notificationService = new NotificationService() @@ -491,18 +500,17 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text)) // system - ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux')) - ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname()) - ipcMain.handle(IpcChannel.System_GetCpuName, () => require('os').cpus()[0].model) + ipcMain.handle(IpcChannel.System_GetDeviceType, getDeviceType) + ipcMain.handle(IpcChannel.System_GetHostname, getHostname) + ipcMain.handle(IpcChannel.System_GetCpuName, getCpuName) ipcMain.handle(IpcChannel.System_CheckGitBash, () => { if (!isWin) { return true // Non-Windows systems don't need Git Bash } try { - const customPath = configManager.get(ConfigKeys.GitBashPath) as string | undefined - const bashPath = findGitBash(customPath) - + // Use autoDiscoverGitBash to handle auto-discovery and persistence + const bashPath = autoDiscoverGitBash() if (bashPath) { logger.info('Git Bash is available', { path: bashPath }) return true @@ -525,13 +533,22 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { return customPath ?? 
null }) + // Returns { path, source } where source is 'manual' | 'auto' | null + ipcMain.handle(IpcChannel.System_GetGitBashPathInfo, () => { + return getGitBashPathInfo() + }) + ipcMain.handle(IpcChannel.System_SetGitBashPath, (_, newPath: string | null) => { if (!isWin) { return false } if (!newPath) { + // Clear manual setting and re-run auto-discovery configManager.set(ConfigKeys.GitBashPath, null) + configManager.set(ConfigKeys.GitBashPathSource, null) + // Re-run auto-discovery to restore auto-discovered path if available + autoDiscoverGitBash() return true } @@ -540,7 +557,9 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { return false } + // Set path with 'manual' source configManager.set(ConfigKeys.GitBashPath, validated) + configManager.set(ConfigKeys.GitBashPathSource, 'manual') return true }) @@ -567,6 +586,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.Backup_ListS3Files, backupManager.listS3Files.bind(backupManager)) ipcMain.handle(IpcChannel.Backup_DeleteS3File, backupManager.deleteS3File.bind(backupManager)) ipcMain.handle(IpcChannel.Backup_CheckS3Connection, backupManager.checkS3Connection.bind(backupManager)) + ipcMain.handle(IpcChannel.Backup_CreateLanTransferBackup, backupManager.createLanTransferBackup.bind(backupManager)) + ipcMain.handle(IpcChannel.Backup_DeleteTempBackup, backupManager.deleteTempBackup.bind(backupManager)) // file ipcMain.handle(IpcChannel.File_Open, fileManager.open.bind(fileManager)) @@ -666,36 +687,19 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.KnowledgeBase_Check_Quota, KnowledgeService.checkQuota.bind(KnowledgeService)) // memory - ipcMain.handle(IpcChannel.Memory_Add, async (_, messages, config) => { - return await memoryService.add(messages, config) - }) - ipcMain.handle(IpcChannel.Memory_Search, async (_, query, config) => { - return await memoryService.search(query, config) - }) - ipcMain.handle(IpcChannel.Memory_List, async (_, config) => { - return await memoryService.list(config) - }) - ipcMain.handle(IpcChannel.Memory_Delete, async (_, id) => { - return await memoryService.delete(id) - }) - ipcMain.handle(IpcChannel.Memory_Update, async (_, id, memory, metadata) => { - return await memoryService.update(id, memory, metadata) - }) - ipcMain.handle(IpcChannel.Memory_Get, async (_, memoryId) => { - return await memoryService.get(memoryId) - }) - ipcMain.handle(IpcChannel.Memory_SetConfig, async (_, config) => { - memoryService.setConfig(config) - }) - ipcMain.handle(IpcChannel.Memory_DeleteUser, async (_, userId) => { - return await memoryService.deleteUser(userId) - }) - ipcMain.handle(IpcChannel.Memory_DeleteAllMemoriesForUser, async (_, userId) => { - return await memoryService.deleteAllMemoriesForUser(userId) - }) - ipcMain.handle(IpcChannel.Memory_GetUsersList, async () => { - return await memoryService.getUsersList() - }) + ipcMain.handle(IpcChannel.Memory_Add, (_, messages, config) => memoryService.add(messages, config)) + ipcMain.handle(IpcChannel.Memory_Search, (_, query, config) => memoryService.search(query, config)) + ipcMain.handle(IpcChannel.Memory_List, (_, config) => memoryService.list(config)) + ipcMain.handle(IpcChannel.Memory_Delete, (_, id) => memoryService.delete(id)) + ipcMain.handle(IpcChannel.Memory_Update, (_, id, memory, metadata) => memoryService.update(id, memory, metadata)) + ipcMain.handle(IpcChannel.Memory_Get, (_, memoryId) => memoryService.get(memoryId)) + 
ipcMain.handle(IpcChannel.Memory_SetConfig, (_, config) => memoryService.setConfig(config)) + ipcMain.handle(IpcChannel.Memory_DeleteUser, (_, userId) => memoryService.deleteUser(userId)) + ipcMain.handle(IpcChannel.Memory_DeleteAllMemoriesForUser, (_, userId) => + memoryService.deleteAllMemoriesForUser(userId) + ) + ipcMain.handle(IpcChannel.Memory_GetUsersList, () => memoryService.getUsersList()) + ipcMain.handle(IpcChannel.Memory_MigrateMemoryDb, () => memoryService.migrateMemoryDb()) // window ipcMain.handle(IpcChannel.Windows_SetMinimumSize, (_, width: number, height: number) => { @@ -855,8 +859,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ) // search window - ipcMain.handle(IpcChannel.SearchWindow_Open, async (_, uid: string) => { - await searchService.openSearchWindow(uid) + ipcMain.handle(IpcChannel.SearchWindow_Open, async (_, uid: string, show?: boolean) => { + await searchService.openSearchWindow(uid, show) }) ipcMain.handle(IpcChannel.SearchWindow_Close, async (_, uid: string) => { await searchService.closeSearchWindow(uid) @@ -972,15 +976,36 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.OCR_ListProviders, () => ocrService.listProviderIds()) // OVMS - ipcMain.handle(IpcChannel.Ovms_AddModel, (_, modelName: string, modelId: string, modelSource: string, task: string) => - ovmsManager.addModel(modelName, modelId, modelSource, task) - ) - ipcMain.handle(IpcChannel.Ovms_StopAddModel, () => ovmsManager.stopAddModel()) - ipcMain.handle(IpcChannel.Ovms_GetModels, () => ovmsManager.getModels()) - ipcMain.handle(IpcChannel.Ovms_IsRunning, () => ovmsManager.initializeOvms()) - ipcMain.handle(IpcChannel.Ovms_GetStatus, () => ovmsManager.getOvmsStatus()) - ipcMain.handle(IpcChannel.Ovms_RunOVMS, () => ovmsManager.runOvms()) - ipcMain.handle(IpcChannel.Ovms_StopOVMS, () => ovmsManager.stopOvms()) + ipcMain.handle(IpcChannel.Ovms_IsSupported, () => isOvmsSupported) + if (isOvmsSupported) { + const { ovmsManager } = await import('./services/OvmsManager') + if (ovmsManager) { + ipcMain.handle( + IpcChannel.Ovms_AddModel, + (_, modelName: string, modelId: string, modelSource: string, task: string) => + ovmsManager.addModel(modelName, modelId, modelSource, task) + ) + ipcMain.handle(IpcChannel.Ovms_StopAddModel, () => ovmsManager.stopAddModel()) + ipcMain.handle(IpcChannel.Ovms_GetModels, () => ovmsManager.getModels()) + ipcMain.handle(IpcChannel.Ovms_IsRunning, () => ovmsManager.initializeOvms()) + ipcMain.handle(IpcChannel.Ovms_GetStatus, () => ovmsManager.getOvmsStatus()) + ipcMain.handle(IpcChannel.Ovms_RunOVMS, () => ovmsManager.runOvms()) + ipcMain.handle(IpcChannel.Ovms_StopOVMS, () => ovmsManager.stopOvms()) + } else { + logger.error('Unexpected behavior: undefined ovmsManager, but OVMS should be supported.') + } + } else { + const fallback = () => { + throw new Error('OVMS is only supported on Windows with intel CPU.') + } + ipcMain.handle(IpcChannel.Ovms_AddModel, fallback) + ipcMain.handle(IpcChannel.Ovms_StopAddModel, fallback) + ipcMain.handle(IpcChannel.Ovms_GetModels, fallback) + ipcMain.handle(IpcChannel.Ovms_IsRunning, fallback) + ipcMain.handle(IpcChannel.Ovms_GetStatus, fallback) + ipcMain.handle(IpcChannel.Ovms_RunOVMS, fallback) + ipcMain.handle(IpcChannel.Ovms_StopOVMS, fallback) + } // CherryAI ipcMain.handle(IpcChannel.Cherryai_GetSignature, (_, params) => generateSignature(params)) @@ -1037,12 +1062,18 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) 
{ } catch (error) { const pluginError = extractPluginError(error) if (pluginError) { - logger.error('Failed to list installed plugins', { agentId, error: pluginError }) + logger.error('Failed to list installed plugins', { + agentId, + error: pluginError + }) return { success: false, error: pluginError } } const err = normalizeError(error) - logger.error('Failed to list installed plugins', { agentId, error: err }) + logger.error('Failed to list installed plugins', { + agentId, + error: err + }) return { success: false, error: { @@ -1098,12 +1129,17 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { } }) - // WebSocket - ipcMain.handle(IpcChannel.WebSocket_Start, WebSocketService.start) - ipcMain.handle(IpcChannel.WebSocket_Stop, WebSocketService.stop) - ipcMain.handle(IpcChannel.WebSocket_Status, WebSocketService.getStatus) - ipcMain.handle(IpcChannel.WebSocket_SendFile, WebSocketService.sendFile) - ipcMain.handle(IpcChannel.WebSocket_GetAllCandidates, WebSocketService.getAllCandidates) + ipcMain.handle(IpcChannel.LocalTransfer_ListServices, () => localTransferService.getState()) + ipcMain.handle(IpcChannel.LocalTransfer_StartScan, () => localTransferService.startDiscovery({ resetList: true })) + ipcMain.handle(IpcChannel.LocalTransfer_StopScan, () => localTransferService.stopDiscovery()) + ipcMain.handle(IpcChannel.LocalTransfer_Connect, (_, payload: LocalTransferConnectPayload) => + lanTransferClientService.connectAndHandshake(payload) + ) + ipcMain.handle(IpcChannel.LocalTransfer_Disconnect, () => lanTransferClientService.disconnect()) + ipcMain.handle(IpcChannel.LocalTransfer_SendFile, (_, payload: { filePath: string }) => + lanTransferClientService.sendFile(payload.filePath) + ) + ipcMain.handle(IpcChannel.LocalTransfer_CancelTransfer, () => lanTransferClientService.cancelTransfer()) ipcMain.handle(IpcChannel.APP_CrashRenderProcess, () => { mainWindow.webContents.forcefullyCrashRenderer() diff --git a/src/main/mcpServers/__tests__/browser.test.ts b/src/main/mcpServers/__tests__/browser.test.ts new file mode 100644 index 0000000000..800d03d7c5 --- /dev/null +++ b/src/main/mcpServers/__tests__/browser.test.ts @@ -0,0 +1,372 @@ +import { describe, expect, it, vi } from 'vitest' + +vi.mock('node:fs', () => ({ + default: { + existsSync: vi.fn(() => false), + mkdirSync: vi.fn() + }, + existsSync: vi.fn(() => false), + mkdirSync: vi.fn() +})) + +vi.mock('electron', () => { + const sendCommand = vi.fn(async (command: string, params?: { expression?: string }) => { + if (command === 'Runtime.evaluate') { + if (params?.expression === 'document.documentElement.outerHTML') { + return { result: { value: '
<html><body><h1>Test</h1><p>Content</p></body></html>
' } } + } + if (params?.expression === 'document.body.innerText') { + return { result: { value: 'Test\nContent' } } + } + return { result: { value: 'ok' } } + } + return {} + }) + + const debuggerObj = { + isAttached: vi.fn(() => true), + attach: vi.fn(), + detach: vi.fn(), + sendCommand + } + + const createWebContents = () => ({ + debugger: debuggerObj, + setUserAgent: vi.fn(), + getURL: vi.fn(() => 'https://example.com/'), + getTitle: vi.fn(async () => 'Example Title'), + loadURL: vi.fn(async () => {}), + once: vi.fn(), + removeListener: vi.fn(), + on: vi.fn(), + isDestroyed: vi.fn(() => false), + canGoBack: vi.fn(() => false), + canGoForward: vi.fn(() => false), + goBack: vi.fn(), + goForward: vi.fn(), + reload: vi.fn(), + executeJavaScript: vi.fn(async () => null), + setWindowOpenHandler: vi.fn() + }) + + const windows: any[] = [] + const views: any[] = [] + + class MockBrowserWindow { + private destroyed = false + public webContents = createWebContents() + public isDestroyed = vi.fn(() => this.destroyed) + public close = vi.fn(() => { + this.destroyed = true + }) + public destroy = vi.fn(() => { + this.destroyed = true + }) + public on = vi.fn() + public setBrowserView = vi.fn() + public addBrowserView = vi.fn() + public removeBrowserView = vi.fn() + public getContentSize = vi.fn(() => [1200, 800]) + public show = vi.fn() + + constructor() { + windows.push(this) + } + } + + class MockBrowserView { + public webContents = createWebContents() + public setBounds = vi.fn() + public setAutoResize = vi.fn() + public destroy = vi.fn() + + constructor() { + views.push(this) + } + } + + const app = { + isReady: vi.fn(() => true), + whenReady: vi.fn(async () => {}), + on: vi.fn(), + getPath: vi.fn((key: string) => { + if (key === 'userData') return '/mock/userData' + if (key === 'temp') return '/tmp' + return '/mock/unknown' + }), + getAppPath: vi.fn(() => '/mock/app'), + setPath: vi.fn() + } + + const nativeTheme = { + on: vi.fn(), + shouldUseDarkColors: false + } + + return { + BrowserWindow: MockBrowserWindow as any, + BrowserView: MockBrowserView as any, + app, + nativeTheme, + __mockDebugger: debuggerObj, + __mockSendCommand: sendCommand, + __mockWindows: windows, + __mockViews: views + } +}) + +import { CdpBrowserController } from '../browser' + +describe('CdpBrowserController', () => { + it('executes single-line code via Runtime.evaluate', async () => { + const controller = new CdpBrowserController() + const result = await controller.execute('1+1') + expect(result).toBe('ok') + }) + + it('opens a URL in normal mode and returns current page info', async () => { + const controller = new CdpBrowserController() + const result = await controller.open('https://foo.bar/', 5000, false) + expect(result.currentUrl).toBe('https://example.com/') + expect(result.title).toBe('Example Title') + }) + + it('opens a URL in private mode', async () => { + const controller = new CdpBrowserController() + const result = await controller.open('https://foo.bar/', 5000, true) + expect(result.currentUrl).toBe('https://example.com/') + expect(result.title).toBe('Example Title') + }) + + it('reuses session for execute and supports multiline', async () => { + const controller = new CdpBrowserController() + await controller.open('https://foo.bar/', 5000, false) + const result = await controller.execute('const a=1; const b=2; a+b;', 5000, false) + expect(result).toBe('ok') + }) + + it('normal and private modes are isolated', async () => { + const controller = new CdpBrowserController() + await 
controller.open('https://foo.bar/', 5000, false) + await controller.open('https://foo.bar/', 5000, true) + const normalResult = await controller.execute('1+1', 5000, false) + const privateResult = await controller.execute('1+1', 5000, true) + expect(normalResult).toBe('ok') + expect(privateResult).toBe('ok') + }) + + it('fetches URL and returns html format with tabId', async () => { + const controller = new CdpBrowserController() + const result = await controller.fetch('https://example.com/', 'html') + expect(result.tabId).toBeDefined() + expect(result.content).toBe('
<html><body><h1>Test</h1><p>Content</p></body></html>
') + }) + + it('fetches URL and returns txt format with tabId', async () => { + const controller = new CdpBrowserController() + const result = await controller.fetch('https://example.com/', 'txt') + expect(result.tabId).toBeDefined() + expect(result.content).toBe('Test\nContent') + }) + + it('fetches URL and returns markdown format (default) with tabId', async () => { + const controller = new CdpBrowserController() + const result = await controller.fetch('https://example.com/') + expect(result.tabId).toBeDefined() + expect(typeof result.content).toBe('string') + expect(result.content).toContain('Test') + }) + + it('fetches URL in private mode with tabId', async () => { + const controller = new CdpBrowserController() + const result = await controller.fetch('https://example.com/', 'html', 10000, true) + expect(result.tabId).toBeDefined() + expect(result.content).toBe('
<html><body><h1>Test</h1><p>Content</p></body></html>
') + }) + + describe('Multi-tab support', () => { + it('creates new tab with newTab parameter', async () => { + const controller = new CdpBrowserController() + const result1 = await controller.open('https://site1.com/', 5000, false, true) + const result2 = await controller.open('https://site2.com/', 5000, false, true) + + expect(result1.tabId).toBeDefined() + expect(result2.tabId).toBeDefined() + expect(result1.tabId).not.toBe(result2.tabId) + }) + + it('reuses same tab without newTab parameter', async () => { + const controller = new CdpBrowserController() + const result1 = await controller.open('https://site1.com/', 5000, false) + const result2 = await controller.open('https://site2.com/', 5000, false) + + expect(result1.tabId).toBe(result2.tabId) + }) + + it('fetches in new tab with newTab parameter', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + const tabs = await controller.listTabs(false) + const initialTabCount = tabs.length + + await controller.fetch('https://other.com/', 'html', 10000, false, true) + const tabsAfter = await controller.listTabs(false) + + expect(tabsAfter.length).toBe(initialTabCount + 1) + }) + }) + + describe('Tab management', () => { + it('lists tabs in a window', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + + const tabs = await controller.listTabs(false) + expect(tabs.length).toBeGreaterThan(0) + expect(tabs[0].tabId).toBeDefined() + }) + + it('lists tabs separately for normal and private modes', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + await controller.open('https://example.com/', 5000, true) + + const normalTabs = await controller.listTabs(false) + const privateTabs = await controller.listTabs(true) + + expect(normalTabs.length).toBe(1) + expect(privateTabs.length).toBe(1) + expect(normalTabs[0].tabId).not.toBe(privateTabs[0].tabId) + }) + + it('closes specific tab', async () => { + const controller = new CdpBrowserController() + const result1 = await controller.open('https://site1.com/', 5000, false, true) + await controller.open('https://site2.com/', 5000, false, true) + + const tabsBefore = await controller.listTabs(false) + expect(tabsBefore.length).toBe(2) + + await controller.closeTab(false, result1.tabId) + + const tabsAfter = await controller.listTabs(false) + expect(tabsAfter.length).toBe(1) + expect(tabsAfter.find((t) => t.tabId === result1.tabId)).toBeUndefined() + }) + + it('switches active tab', async () => { + const controller = new CdpBrowserController() + const result1 = await controller.open('https://site1.com/', 5000, false, true) + const result2 = await controller.open('https://site2.com/', 5000, false, true) + + await controller.switchTab(false, result1.tabId) + await controller.switchTab(false, result2.tabId) + }) + + it('throws error when switching to non-existent tab', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + + await expect(controller.switchTab(false, 'non-existent-tab')).rejects.toThrow('Tab non-existent-tab not found') + }) + }) + + describe('Reset behavior', () => { + it('resets specific tab only', async () => { + const controller = new CdpBrowserController() + const result1 = await controller.open('https://site1.com/', 5000, false, true) + await controller.open('https://site2.com/', 5000, false, true) + + await 
controller.reset(false, result1.tabId) + + const tabs = await controller.listTabs(false) + expect(tabs.length).toBe(1) + }) + + it('resets specific window only', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + await controller.open('https://example.com/', 5000, true) + + await controller.reset(false) + + const normalTabs = await controller.listTabs(false) + const privateTabs = await controller.listTabs(true) + + expect(normalTabs.length).toBe(0) + expect(privateTabs.length).toBe(1) + }) + + it('resets all windows', async () => { + const controller = new CdpBrowserController() + await controller.open('https://example.com/', 5000, false) + await controller.open('https://example.com/', 5000, true) + + await controller.reset() + + const normalTabs = await controller.listTabs(false) + const privateTabs = await controller.listTabs(true) + + expect(normalTabs.length).toBe(0) + expect(privateTabs.length).toBe(0) + }) + }) + + describe('showWindow parameter', () => { + it('passes showWindow parameter through open', async () => { + const controller = new CdpBrowserController() + const result = await controller.open('https://example.com/', 5000, false, false, true) + expect(result.currentUrl).toBe('https://example.com/') + expect(result.tabId).toBeDefined() + }) + + it('passes showWindow parameter through fetch', async () => { + const controller = new CdpBrowserController() + const result = await controller.fetch('https://example.com/', 'html', 10000, false, false, true) + expect(result.tabId).toBeDefined() + expect(result.content).toBe('
<html><body><h1>Test</h1><p>Content</p></body></html>
') + }) + + it('passes showWindow parameter through createTab', async () => { + const controller = new CdpBrowserController() + const { tabId, view } = await controller.createTab(false, true) + expect(tabId).toBeDefined() + expect(view).toBeDefined() + }) + + it('shows existing window when showWindow=true on subsequent calls', async () => { + const controller = new CdpBrowserController() + // First call creates window + await controller.open('https://example.com/', 5000, false, false, false) + // Second call with showWindow=true should show existing window + const result = await controller.open('https://example.com/', 5000, false, false, true) + expect(result.currentUrl).toBe('https://example.com/') + }) + }) + + describe('Window limits and eviction', () => { + it('respects maxWindows limit', async () => { + const controller = new CdpBrowserController({ maxWindows: 1 }) + await controller.open('https://example.com/', 5000, false) + await controller.open('https://example.com/', 5000, true) + + const normalTabs = await controller.listTabs(false) + const privateTabs = await controller.listTabs(true) + + expect(privateTabs.length).toBe(1) + expect(normalTabs.length).toBe(0) + }) + + it('cleans up idle windows on next access', async () => { + const controller = new CdpBrowserController({ idleTimeoutMs: 1 }) + await controller.open('https://example.com/', 5000, false) + + await new Promise((r) => setTimeout(r, 10)) + + await controller.open('https://example.com/', 5000, true) + + const normalTabs = await controller.listTabs(false) + expect(normalTabs.length).toBe(0) + }) + }) +}) diff --git a/src/main/mcpServers/browser/README.md b/src/main/mcpServers/browser/README.md new file mode 100644 index 0000000000..27d1307782 --- /dev/null +++ b/src/main/mcpServers/browser/README.md @@ -0,0 +1,177 @@ +# Browser MCP Server + +A Model Context Protocol (MCP) server for controlling browser windows via Chrome DevTools Protocol (CDP). + +## Features + +### ✨ User Data Persistence +- **Normal mode (default)**: Cookies, localStorage, and sessionStorage persist across browser restarts +- **Private mode**: Ephemeral browsing - no data persists (like incognito mode) + +### 🔄 Window Management +- Two browsing modes: normal (persistent) and private (ephemeral) +- Lazy idle timeout cleanup (cleaned on next window access) +- Maximum window limits to prevent resource exhaustion + +> **Note**: Normal mode uses a global `persist:default` partition shared by all clients. This means login sessions and stored data are accessible to any code using the MCP server. + +## Architecture + +### How It Works +``` +Normal Mode (BrowserWindow) +├─ Persistent Storage (partition: persist:default) ← Global, shared across all clients +└─ Tabs (BrowserView) ← created via newTab or automatically + +Private Mode (BrowserWindow) +├─ Ephemeral Storage (partition: private) ← No disk persistence +└─ Tabs (BrowserView) ← created via newTab or automatically +``` + +- **One Window Per Mode**: Normal and private modes each have their own window +- **Multi-Tab Support**: Use `newTab: true` for parallel URL requests +- **Storage Isolation**: Normal and private modes have completely separate storage + +## Available Tools + +### `open` +Open a URL in a browser window. Optionally return page content. 
+```json +{ + "url": "https://example.com", + "format": "markdown", + "timeout": 10000, + "privateMode": false, + "newTab": false, + "showWindow": false +} +``` +- `format`: If set (`html`, `txt`, `markdown`, `json`), returns page content in that format along with tabId. If not set, just opens the page and returns navigation info. +- `newTab`: Set to `true` to open in a new tab (required for parallel requests) +- `showWindow`: Set to `true` to display the browser window (useful for debugging) +- Returns (without format): `{ currentUrl, title, tabId }` +- Returns (with format): `{ tabId, content }` where content is in the specified format + +### `execute` +Execute JavaScript code in the page context. +```json +{ + "code": "document.title", + "timeout": 5000, + "privateMode": false, + "tabId": "optional-tab-id" +} +``` +- `tabId`: Target a specific tab (from `open` response) + +### `reset` +Reset browser windows and tabs. +```json +{ + "privateMode": false, + "tabId": "optional-tab-id" +} +``` +- Omit all parameters to close all windows +- Set `privateMode` to close a specific window +- Set both `privateMode` and `tabId` to close a specific tab only + +## Usage Examples + +### Basic Navigation +```typescript +// Open a URL in normal mode (data persists) +await controller.open('https://example.com') +``` + +### Fetch Page Content +```typescript +// Open URL and get content as markdown +await open({ url: 'https://example.com', format: 'markdown' }) + +// Open URL and get raw HTML +await open({ url: 'https://example.com', format: 'html' }) +``` + +### Multi-Tab / Parallel Requests +```typescript +// Open multiple URLs in parallel using newTab +const [page1, page2] = await Promise.all([ + controller.open('https://site1.com', 10000, false, true), // newTab: true + controller.open('https://site2.com', 10000, false, true) // newTab: true +]) + +// Execute on specific tab +await controller.execute('document.title', 5000, false, page1.tabId) + +// Close specific tab when done +await controller.reset(false, page1.tabId) +``` + +### Private Browsing +```typescript +// Open a URL in private mode (no data persistence) +await controller.open('https://example.com', 10000, true) + +// Cookies and localStorage won't persist after reset +``` + +### Data Persistence (Normal Mode) +```typescript +// Set data +await controller.open('https://example.com', 10000, false) +await controller.execute('localStorage.setItem("key", "value")', 5000, false) + +// Close window +await controller.reset(false) + +// Reopen - data persists! +await controller.open('https://example.com', 10000, false) +const value = await controller.execute('localStorage.getItem("key")', 5000, false) +// Returns: "value" +``` + +### No Persistence (Private Mode) +```typescript +// Set data in private mode +await controller.open('https://example.com', 10000, true) +await controller.execute('localStorage.setItem("key", "value")', 5000, true) + +// Close private window +await controller.reset(true) + +// Reopen - data is gone! +await controller.open('https://example.com', 10000, true) +const value = await controller.execute('localStorage.getItem("key")', 5000, true) +// Returns: null +``` + +## Configuration + +```typescript +const controller = new CdpBrowserController({ + maxWindows: 5, // Maximum concurrent windows + idleTimeoutMs: 5 * 60 * 1000 // 5 minutes idle timeout (lazy cleanup) +}) +``` + +> **Note on Idle Timeout**: Idle windows are cleaned up lazily when the next window is created or accessed, not on a background timer. 
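+
+For illustration, a minimal sketch of the lazy cleanup behavior (timings here are hypothetical; the signatures follow the examples above):
+
+```typescript
+const controller = new CdpBrowserController({ idleTimeoutMs: 60_000 })
+
+// Creates the normal-mode window.
+await controller.open('https://example.com', 10000, false)
+
+// ...more than 60s pass with no browser activity...
+
+// The next window access sweeps idle windows first, so the normal-mode
+// window above is closed before the private-mode window is created.
+await controller.open('https://example.com', 10000, true)
+```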
+ +## Best Practices + +1. **Use Normal Mode for Authentication**: When you need to stay logged in across sessions +2. **Use Private Mode for Sensitive Operations**: When you don't want data to persist +3. **Use `newTab: true` for Parallel Requests**: Avoid race conditions when fetching multiple URLs +4. **Resource Cleanup**: Call `reset()` when done, or `reset(privateMode, tabId)` to close specific tabs +5. **Error Handling**: All tool handlers return error responses on failure +6. **Timeout Configuration**: Adjust timeouts based on page complexity + +## Technical Details + +- **CDP Version**: 1.3 +- **User Agent**: Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:145.0) Gecko/20100101 Firefox/145.0 +- **Storage**: + - Normal mode: `persist:default` (disk-persisted, global) + - Private mode: `private` (memory only) +- **Window Size**: 1200x800 (default) +- **Visibility**: Windows hidden by default (use `showWindow: true` to display) diff --git a/src/main/mcpServers/browser/constants.ts b/src/main/mcpServers/browser/constants.ts new file mode 100644 index 0000000000..2b10943f8e --- /dev/null +++ b/src/main/mcpServers/browser/constants.ts @@ -0,0 +1,3 @@ +export const TAB_BAR_HEIGHT = 92 // Height for Chrome-style tab bar (42px) + address bar (50px) +export const SESSION_KEY_DEFAULT = 'default' +export const SESSION_KEY_PRIVATE = 'private' diff --git a/src/main/mcpServers/browser/controller.ts b/src/main/mcpServers/browser/controller.ts new file mode 100644 index 0000000000..9e0f5220ca --- /dev/null +++ b/src/main/mcpServers/browser/controller.ts @@ -0,0 +1,909 @@ +import { titleBarOverlayDark, titleBarOverlayLight } from '@main/config' +import { isMac } from '@main/constant' +import { randomUUID } from 'crypto' +import { app, BrowserView, BrowserWindow, nativeTheme } from 'electron' +import TurndownService from 'turndown' + +import { SESSION_KEY_DEFAULT, SESSION_KEY_PRIVATE, TAB_BAR_HEIGHT } from './constants' +import { TAB_BAR_HTML } from './tabbar-html' +import { logger, type TabInfo, userAgent, type WindowInfo } from './types' + +/** + * Controller for managing browser windows via Chrome DevTools Protocol (CDP). + * Supports two modes: normal (persistent) and private (ephemeral). + * Normal mode persists user data (cookies, localStorage, etc.) globally across all clients. + * Private mode is ephemeral - data is cleared when the window closes. + */ +export class CdpBrowserController { + private windows: Map = new Map() + private readonly maxWindows: number + private readonly idleTimeoutMs: number + private readonly turndownService: TurndownService + + constructor(options?: { maxWindows?: number; idleTimeoutMs?: number }) { + this.maxWindows = options?.maxWindows ?? 5 + this.idleTimeoutMs = options?.idleTimeoutMs ?? 5 * 60 * 1000 + this.turndownService = new TurndownService() + + // Listen for theme changes and update all tab bars + nativeTheme.on('updated', () => { + const isDark = nativeTheme.shouldUseDarkColors + for (const windowInfo of this.windows.values()) { + if (windowInfo.tabBarView && !windowInfo.tabBarView.webContents.isDestroyed()) { + windowInfo.tabBarView.webContents.executeJavaScript(`window.setTheme(${isDark})`).catch(() => { + // Ignore errors if tab bar is not ready + }) + } + } + }) + } + + private getWindowKey(privateMode: boolean): string { + return privateMode ? SESSION_KEY_PRIVATE : SESSION_KEY_DEFAULT + } + + private getPartition(privateMode: boolean): string { + return privateMode ? 
SESSION_KEY_PRIVATE : `persist:${SESSION_KEY_DEFAULT}` + } + + private async ensureAppReady() { + if (!app.isReady()) { + await app.whenReady() + } + } + + private touchWindow(windowKey: string) { + const windowInfo = this.windows.get(windowKey) + if (windowInfo) windowInfo.lastActive = Date.now() + } + + private touchTab(windowKey: string, tabId: string) { + const windowInfo = this.windows.get(windowKey) + if (windowInfo) { + const tab = windowInfo.tabs.get(tabId) + if (tab) tab.lastActive = Date.now() + windowInfo.lastActive = Date.now() + } + } + + private closeTabInternal(windowInfo: WindowInfo, tabId: string) { + try { + const tab = windowInfo.tabs.get(tabId) + if (!tab) return + + if (!tab.view.webContents.isDestroyed()) { + if (tab.view.webContents.debugger.isAttached()) { + tab.view.webContents.debugger.detach() + } + } + + // Remove view from window + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.removeBrowserView(tab.view) + } + + // Destroy the view using safe cast + const viewWithDestroy = tab.view as BrowserView & { destroy?: () => void } + if (viewWithDestroy.destroy) { + viewWithDestroy.destroy() + } + } catch (error) { + logger.warn('Error closing tab', { error, windowKey: windowInfo.windowKey, tabId }) + } + } + + private async ensureDebuggerAttached(dbg: Electron.Debugger, sessionKey: string) { + if (!dbg.isAttached()) { + try { + logger.info('Attaching debugger', { sessionKey }) + dbg.attach('1.3') + await dbg.sendCommand('Page.enable') + await dbg.sendCommand('Runtime.enable') + logger.info('Debugger attached and domains enabled') + } catch (error) { + logger.error('Failed to attach debugger', { error }) + throw error + } + } + } + + private sweepIdle() { + const now = Date.now() + const windowKeys = Array.from(this.windows.keys()) + for (const windowKey of windowKeys) { + const windowInfo = this.windows.get(windowKey) + if (!windowInfo) continue + if (now - windowInfo.lastActive > this.idleTimeoutMs) { + const tabIds = Array.from(windowInfo.tabs.keys()) + for (const tabId of tabIds) { + this.closeTabInternal(windowInfo, tabId) + } + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.close() + } + this.windows.delete(windowKey) + } + } + } + + private evictIfNeeded(newWindowKey: string) { + if (this.windows.size < this.maxWindows) return + let lruKey: string | null = null + let lruTime = Number.POSITIVE_INFINITY + for (const [key, windowInfo] of this.windows.entries()) { + if (key === newWindowKey) continue + if (windowInfo.lastActive < lruTime) { + lruTime = windowInfo.lastActive + lruKey = key + } + } + if (lruKey) { + const windowInfo = this.windows.get(lruKey) + if (windowInfo) { + for (const [tabId] of windowInfo.tabs.entries()) { + this.closeTabInternal(windowInfo, tabId) + } + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.close() + } + } + this.windows.delete(lruKey) + logger.info('Evicted window to respect maxWindows', { evicted: lruKey }) + } + } + + private sendTabBarUpdate(windowInfo: WindowInfo) { + if (!windowInfo.tabBarView || !windowInfo.tabBarView.webContents || windowInfo.tabBarView.webContents.isDestroyed()) + return + + const tabs = Array.from(windowInfo.tabs.values()).map((tab) => ({ + id: tab.id, + title: tab.title || 'New Tab', + url: tab.url, + isActive: tab.id === windowInfo.activeTabId + })) + + let activeUrl = '' + let canGoBack = false + let canGoForward = false + + if (windowInfo.activeTabId) { + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (activeTab && 
!activeTab.view.webContents.isDestroyed()) { + activeUrl = activeTab.view.webContents.getURL() + canGoBack = activeTab.view.webContents.canGoBack() + canGoForward = activeTab.view.webContents.canGoForward() + } + } + + const script = `window.updateTabs(${JSON.stringify(tabs)}, ${JSON.stringify(activeUrl)}, ${canGoBack}, ${canGoForward})` + windowInfo.tabBarView.webContents.executeJavaScript(script).catch((error) => { + logger.debug('Tab bar update failed', { error, windowKey: windowInfo.windowKey }) + }) + } + + private handleNavigateAction(windowInfo: WindowInfo, url: string) { + if (!windowInfo.activeTabId) return + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (!activeTab || activeTab.view.webContents.isDestroyed()) return + + let finalUrl = url.trim() + if (!/^https?:\/\//i.test(finalUrl)) { + if (/^[a-zA-Z0-9][a-zA-Z0-9-]*\.[a-zA-Z]{2,}/.test(finalUrl) || finalUrl.includes('.')) { + finalUrl = 'https://' + finalUrl + } else { + finalUrl = 'https://www.google.com/search?q=' + encodeURIComponent(finalUrl) + } + } + + activeTab.view.webContents.loadURL(finalUrl).catch((error) => { + logger.warn('Navigation failed in tab bar', { error, url: finalUrl, tabId: windowInfo.activeTabId }) + }) + } + + private handleBackAction(windowInfo: WindowInfo) { + if (!windowInfo.activeTabId) return + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (!activeTab || activeTab.view.webContents.isDestroyed()) return + + if (activeTab.view.webContents.canGoBack()) { + activeTab.view.webContents.goBack() + } + } + + private handleForwardAction(windowInfo: WindowInfo) { + if (!windowInfo.activeTabId) return + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (!activeTab || activeTab.view.webContents.isDestroyed()) return + + if (activeTab.view.webContents.canGoForward()) { + activeTab.view.webContents.goForward() + } + } + + private handleRefreshAction(windowInfo: WindowInfo) { + if (!windowInfo.activeTabId) return + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (!activeTab || activeTab.view.webContents.isDestroyed()) return + + activeTab.view.webContents.reload() + } + + private setupTabBarMessageHandler(windowInfo: WindowInfo) { + if (!windowInfo.tabBarView) return + + windowInfo.tabBarView.webContents.on('console-message', (_event, _level, message) => { + try { + const parsed = JSON.parse(message) + if (parsed?.channel === 'tabbar-action' && parsed?.payload) { + this.handleTabBarAction(windowInfo, parsed.payload) + } + } catch { + // Not a JSON message, ignore + } + }) + + windowInfo.tabBarView.webContents + .executeJavaScript(` + (function() { + window.addEventListener('message', function(e) { + if (e.data && e.data.channel === 'tabbar-action') { + console.log(JSON.stringify(e.data)); + } + }); + })(); + `) + .catch((error) => { + logger.debug('Tab bar message handler setup failed', { error, windowKey: windowInfo.windowKey }) + }) + } + + private handleTabBarAction(windowInfo: WindowInfo, action: { type: string; tabId?: string; url?: string }) { + if (action.type === 'switch' && action.tabId) { + this.switchTab(windowInfo.privateMode, action.tabId).catch((error) => { + logger.warn('Tab switch failed', { error, tabId: action.tabId, windowKey: windowInfo.windowKey }) + }) + } else if (action.type === 'close' && action.tabId) { + this.closeTab(windowInfo.privateMode, action.tabId).catch((error) => { + logger.warn('Tab close failed', { error, tabId: action.tabId, windowKey: windowInfo.windowKey }) + }) + } else if (action.type === 
'new') { + this.createTab(windowInfo.privateMode, true) + .then(({ tabId }) => this.switchTab(windowInfo.privateMode, tabId)) + .catch((error) => { + logger.warn('New tab creation failed', { error, windowKey: windowInfo.windowKey }) + }) + } else if (action.type === 'navigate' && action.url) { + this.handleNavigateAction(windowInfo, action.url) + } else if (action.type === 'back') { + this.handleBackAction(windowInfo) + } else if (action.type === 'forward') { + this.handleForwardAction(windowInfo) + } else if (action.type === 'refresh') { + this.handleRefreshAction(windowInfo) + } else if (action.type === 'window-minimize') { + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.minimize() + } + } else if (action.type === 'window-maximize') { + if (!windowInfo.window.isDestroyed()) { + if (windowInfo.window.isMaximized()) { + windowInfo.window.unmaximize() + } else { + windowInfo.window.maximize() + } + } + } else if (action.type === 'window-close') { + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.close() + } + } + } + + private createTabBarView(windowInfo: WindowInfo): BrowserView { + const tabBarView = new BrowserView({ + webPreferences: { + contextIsolation: false, + sandbox: false, + nodeIntegration: false + } + }) + + windowInfo.window.addBrowserView(tabBarView) + const [width] = windowInfo.window.getContentSize() + tabBarView.setBounds({ x: 0, y: 0, width, height: TAB_BAR_HEIGHT }) + tabBarView.setAutoResize({ width: true, height: false }) + tabBarView.webContents.loadURL(`data:text/html;charset=utf-8,${encodeURIComponent(TAB_BAR_HTML)}`) + + tabBarView.webContents.on('did-finish-load', () => { + // Initialize platform for proper styling + const platform = isMac ? 'mac' : process.platform === 'win32' ? 'win' : 'linux' + tabBarView.webContents.executeJavaScript(`window.initPlatform('${platform}')`).catch((error) => { + logger.debug('Platform init failed', { error, windowKey: windowInfo.windowKey }) + }) + // Initialize theme + const isDark = nativeTheme.shouldUseDarkColors + tabBarView.webContents.executeJavaScript(`window.setTheme(${isDark})`).catch((error) => { + logger.debug('Theme init failed', { error, windowKey: windowInfo.windowKey }) + }) + this.setupTabBarMessageHandler(windowInfo) + this.sendTabBarUpdate(windowInfo) + }) + + return tabBarView + } + + private async createBrowserWindow( + windowKey: string, + privateMode: boolean, + showWindow = false + ): Promise { + await this.ensureAppReady() + + const partition = this.getPartition(privateMode) + + const win = new BrowserWindow({ + show: showWindow, + width: 1200, + height: 800, + ...(isMac + ? { + titleBarStyle: 'hidden', + titleBarOverlay: nativeTheme.shouldUseDarkColors ? 
titleBarOverlayDark : titleBarOverlayLight, + trafficLightPosition: { x: 8, y: 13 } + } + : { + frame: false // Frameless window for Windows and Linux + }), + webPreferences: { + contextIsolation: true, + sandbox: true, + nodeIntegration: false, + devTools: true, + partition + } + }) + + win.on('closed', () => { + const windowInfo = this.windows.get(windowKey) + if (windowInfo) { + const tabIds = Array.from(windowInfo.tabs.keys()) + for (const tabId of tabIds) { + this.closeTabInternal(windowInfo, tabId) + } + this.windows.delete(windowKey) + } + }) + + return win + } + + private async getOrCreateWindow(privateMode: boolean, showWindow = false): Promise { + await this.ensureAppReady() + this.sweepIdle() + + const windowKey = this.getWindowKey(privateMode) + + let windowInfo = this.windows.get(windowKey) + if (!windowInfo) { + this.evictIfNeeded(windowKey) + const window = await this.createBrowserWindow(windowKey, privateMode, showWindow) + windowInfo = { + windowKey, + privateMode, + window, + tabs: new Map(), + activeTabId: null, + lastActive: Date.now(), + tabBarView: undefined + } + this.windows.set(windowKey, windowInfo) + const tabBarView = this.createTabBarView(windowInfo) + windowInfo.tabBarView = tabBarView + + // Register resize listener once per window (not per tab) + // Capture windowKey to look up fresh windowInfo on each resize + windowInfo.window.on('resize', () => { + const info = this.windows.get(windowKey) + if (info) this.updateViewBounds(info) + }) + + logger.info('Created new window', { windowKey, privateMode }) + } else if (showWindow && !windowInfo.window.isDestroyed()) { + windowInfo.window.show() + } + + this.touchWindow(windowKey) + return windowInfo + } + + private updateViewBounds(windowInfo: WindowInfo) { + if (windowInfo.window.isDestroyed()) return + + const [width, height] = windowInfo.window.getContentSize() + + // Update tab bar bounds + if (windowInfo.tabBarView && !windowInfo.tabBarView.webContents.isDestroyed()) { + windowInfo.tabBarView.setBounds({ x: 0, y: 0, width, height: TAB_BAR_HEIGHT }) + } + + // Update active tab view bounds + if (windowInfo.activeTabId) { + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (activeTab && !activeTab.view.webContents.isDestroyed()) { + activeTab.view.setBounds({ + x: 0, + y: TAB_BAR_HEIGHT, + width, + height: Math.max(0, height - TAB_BAR_HEIGHT) + }) + } + } + } + + /** + * Creates a new tab in the window + * @param privateMode - If true, uses private browsing mode (default: false) + * @param showWindow - If true, shows the browser window (default: false) + * @returns Tab ID and view + */ + public async createTab(privateMode = false, showWindow = false): Promise<{ tabId: string; view: BrowserView }> { + const windowInfo = await this.getOrCreateWindow(privateMode, showWindow) + const tabId = randomUUID() + const partition = this.getPartition(privateMode) + + const view = new BrowserView({ + webPreferences: { + contextIsolation: true, + sandbox: true, + nodeIntegration: false, + devTools: true, + partition + } + }) + + view.webContents.setUserAgent(userAgent) + + const windowKey = windowInfo.windowKey + view.webContents.on('did-start-loading', () => logger.info(`did-start-loading`, { windowKey, tabId })) + view.webContents.on('dom-ready', () => logger.info(`dom-ready`, { windowKey, tabId })) + view.webContents.on('did-finish-load', () => logger.info(`did-finish-load`, { windowKey, tabId })) + view.webContents.on('did-fail-load', (_e, code, desc) => logger.warn('Navigation failed', { code, desc })) + 
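+    // Keep the tab registry and the tab bar UI in sync with this view's
+    // lifecycle: drop destroyed tabs (promoting the next tab to active) and
+    // refresh titles/URLs as the page navigates.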
+ view.webContents.on('destroyed', () => { + windowInfo.tabs.delete(tabId) + if (windowInfo.activeTabId === tabId) { + windowInfo.activeTabId = windowInfo.tabs.keys().next().value ?? null + if (windowInfo.activeTabId) { + const newActiveTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (newActiveTab && !windowInfo.window.isDestroyed()) { + windowInfo.window.addBrowserView(newActiveTab.view) + this.updateViewBounds(windowInfo) + } + } + } + this.sendTabBarUpdate(windowInfo) + }) + + view.webContents.on('page-title-updated', (_event, title) => { + tabInfo.title = title + this.sendTabBarUpdate(windowInfo) + }) + + view.webContents.on('did-navigate', (_event, url) => { + tabInfo.url = url + this.sendTabBarUpdate(windowInfo) + }) + + view.webContents.on('did-navigate-in-page', (_event, url) => { + tabInfo.url = url + this.sendTabBarUpdate(windowInfo) + }) + + // Handle new window requests (e.g., target="_blank" links) - open in new tab instead + view.webContents.setWindowOpenHandler(({ url }) => { + // Create a new tab and navigate to the URL + this.createTab(privateMode, true) + .then(({ tabId: newTabId }) => { + return this.switchTab(privateMode, newTabId).then(() => { + const newTab = windowInfo.tabs.get(newTabId) + if (newTab && !newTab.view.webContents.isDestroyed()) { + newTab.view.webContents.loadURL(url) + } + }) + }) + .catch((error) => { + logger.warn('Failed to open link in new tab', { error, url }) + }) + return { action: 'deny' } + }) + + const tabInfo: TabInfo = { + id: tabId, + view, + url: '', + title: '', + lastActive: Date.now() + } + + windowInfo.tabs.set(tabId, tabInfo) + + // Set as active tab and add to window + if (!windowInfo.activeTabId || windowInfo.tabs.size === 1) { + windowInfo.activeTabId = tabId + windowInfo.window.addBrowserView(view) + this.updateViewBounds(windowInfo) + } + + this.sendTabBarUpdate(windowInfo) + logger.info('Created new tab', { windowKey, tabId, privateMode }) + return { tabId, view } + } + + /** + * Gets an existing tab or creates a new one + * @param privateMode - Whether to use private browsing mode + * @param tabId - Optional specific tab ID to use + * @param newTab - If true, always create a new tab (useful for parallel requests) + * @param showWindow - If true, shows the browser window (default: false) + */ + private async getTab( + privateMode: boolean, + tabId?: string, + newTab?: boolean, + showWindow = false + ): Promise<{ tabId: string; tab: TabInfo }> { + const windowInfo = await this.getOrCreateWindow(privateMode, showWindow) + + // If newTab is requested, create a fresh tab + if (newTab) { + const { tabId: freshTabId } = await this.createTab(privateMode, showWindow) + const tab = windowInfo.tabs.get(freshTabId) + if (!tab) { + throw new Error(`Tab ${freshTabId} was created but not found - it may have been closed`) + } + return { tabId: freshTabId, tab } + } + + if (tabId) { + const tab = windowInfo.tabs.get(tabId) + if (tab && !tab.view.webContents.isDestroyed()) { + this.touchTab(windowInfo.windowKey, tabId) + return { tabId, tab } + } + } + + // Use active tab or create new one + if (windowInfo.activeTabId) { + const activeTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (activeTab && !activeTab.view.webContents.isDestroyed()) { + this.touchTab(windowInfo.windowKey, windowInfo.activeTabId) + return { tabId: windowInfo.activeTabId, tab: activeTab } + } + } + + // Create new tab + const { tabId: newTabId } = await this.createTab(privateMode, showWindow) + const tab = windowInfo.tabs.get(newTabId) + if (!tab) { + throw 
new Error(`Tab ${newTabId} was created but not found - it may have been closed`)
+    }
+    return { tabId: newTabId, tab }
+  }
+
+  /**
+   * Opens a URL in a browser window and waits for navigation to complete.
+   * @param url - The URL to navigate to
+   * @param timeout - Navigation timeout in milliseconds (default: 10000)
+   * @param privateMode - If true, uses private browsing mode (default: false)
+   * @param newTab - If true, always creates a new tab (useful for parallel requests)
+   * @param showWindow - If true, shows the browser window (default: false)
+   * @returns Object containing the current URL, page title, and tab ID after navigation
+   */
+  public async open(url: string, timeout = 10000, privateMode = false, newTab = false, showWindow = false) {
+    const { tabId: actualTabId, tab } = await this.getTab(privateMode, undefined, newTab, showWindow)
+    const view = tab.view
+    const windowKey = this.getWindowKey(privateMode)
+
+    logger.info('Loading URL', { url, windowKey, tabId: actualTabId, privateMode })
+    const { webContents } = view
+    this.touchTab(windowKey, actualTabId)
+
+    let resolved = false
+    let timeoutHandle: ReturnType<typeof setTimeout> | undefined
+    let onFinish: () => void
+    let onDomReady: () => void
+    let onFail: (_event: Electron.Event, code: number, desc: string) => void
+
+    const cleanup = () => {
+      if (timeoutHandle) clearTimeout(timeoutHandle)
+      webContents.removeListener('did-finish-load', onFinish)
+      webContents.removeListener('did-fail-load', onFail)
+      webContents.removeListener('dom-ready', onDomReady)
+    }
+
+    const loadPromise = new Promise<void>((resolve, reject) => {
+      onFinish = () => {
+        if (resolved) return
+        resolved = true
+        cleanup()
+        resolve()
+      }
+      onDomReady = () => {
+        if (resolved) return
+        resolved = true
+        cleanup()
+        resolve()
+      }
+      onFail = (_event: Electron.Event, code: number, desc: string) => {
+        if (resolved) return
+        resolved = true
+        cleanup()
+        reject(new Error(`Navigation failed (${code}): ${desc}`))
+      }
+      webContents.once('did-finish-load', onFinish)
+      webContents.once('dom-ready', onDomReady)
+      webContents.once('did-fail-load', onFail)
+    })
+
+    const timeoutPromise = new Promise<never>((_, reject) => {
+      timeoutHandle = setTimeout(() => reject(new Error('Navigation timed out')), timeout)
+    })
+
+    try {
+      await Promise.race([view.webContents.loadURL(url), loadPromise, timeoutPromise])
+    } finally {
+      cleanup()
+    }
+
+    const currentUrl = webContents.getURL()
+    const title = await webContents.getTitle()
+
+    // Update tab info
+    tab.url = currentUrl
+    tab.title = title
+
+    return { currentUrl, title, tabId: actualTabId }
+  }
+
+  /**
+   * Executes JavaScript code in the page context using Chrome DevTools Protocol.
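+   *
+   * @example
+   * // usage sketch, assuming a CdpBrowserController instance named `controller`
+   * const { tabId } = await controller.open('https://example.com')
+   * const title = await controller.execute('document.title', 5000, false, tabId)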
+   * @param code - JavaScript code to evaluate in the page
+   * @param timeout - Execution timeout in milliseconds (default: 5000)
+   * @param privateMode - If true, targets the private browsing window (default: false)
+   * @param tabId - Optional specific tab ID to target; if omitted, uses the active tab
+   * @returns The result value from the evaluated code, or null if no value returned
+   */
+  public async execute(code: string, timeout = 5000, privateMode = false, tabId?: string) {
+    const { tabId: actualTabId, tab } = await this.getTab(privateMode, tabId)
+    const windowKey = this.getWindowKey(privateMode)
+    this.touchTab(windowKey, actualTabId)
+    const dbg = tab.view.webContents.debugger
+
+    await this.ensureDebuggerAttached(dbg, windowKey)
+
+    let timeoutHandle: ReturnType<typeof setTimeout> | undefined
+    const evalPromise = dbg.sendCommand('Runtime.evaluate', {
+      expression: code,
+      awaitPromise: true,
+      returnByValue: true
+    })
+
+    try {
+      const result = await Promise.race([
+        evalPromise,
+        new Promise<never>((_, reject) => {
+          timeoutHandle = setTimeout(() => reject(new Error('Execution timed out')), timeout)
+        })
+      ])
+
+      const evalResult = result as any
+
+      if (evalResult?.exceptionDetails) {
+        const message = evalResult.exceptionDetails.exception?.description || 'Unknown script error'
+        logger.warn('Runtime.evaluate raised exception', { message })
+        throw new Error(message)
+      }
+
+      const value = evalResult?.result?.value ?? evalResult?.result?.description ?? null
+      return value
+    } finally {
+      if (timeoutHandle) clearTimeout(timeoutHandle)
+    }
+  }
+
+  public async reset(privateMode?: boolean, tabId?: string) {
+    if (privateMode !== undefined && tabId) {
+      const windowKey = this.getWindowKey(privateMode)
+      const windowInfo = this.windows.get(windowKey)
+      if (windowInfo) {
+        this.closeTabInternal(windowInfo, tabId)
+        windowInfo.tabs.delete(tabId)
+
+        // If no tabs left, close the window
+        if (windowInfo.tabs.size === 0) {
+          if (!windowInfo.window.isDestroyed()) {
+            windowInfo.window.close()
+          }
+          this.windows.delete(windowKey)
+          logger.info('Browser CDP window closed (last tab closed)', { windowKey, tabId })
+          return
+        }
+
+        if (windowInfo.activeTabId === tabId) {
+          windowInfo.activeTabId = windowInfo.tabs.keys().next().value ??
null + if (windowInfo.activeTabId) { + const newActiveTab = windowInfo.tabs.get(windowInfo.activeTabId) + if (newActiveTab && !windowInfo.window.isDestroyed()) { + windowInfo.window.addBrowserView(newActiveTab.view) + this.updateViewBounds(windowInfo) + } + } + } + this.sendTabBarUpdate(windowInfo) + } + logger.info('Browser CDP tab reset', { windowKey, tabId }) + return + } + + if (privateMode !== undefined) { + const windowKey = this.getWindowKey(privateMode) + const windowInfo = this.windows.get(windowKey) + if (windowInfo) { + const tabIds = Array.from(windowInfo.tabs.keys()) + for (const tid of tabIds) { + this.closeTabInternal(windowInfo, tid) + } + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.close() + } + } + this.windows.delete(windowKey) + logger.info('Browser CDP window reset', { windowKey, privateMode }) + return + } + + const allWindowInfos = Array.from(this.windows.values()) + for (const windowInfo of allWindowInfos) { + const tabIds = Array.from(windowInfo.tabs.keys()) + for (const tid of tabIds) { + this.closeTabInternal(windowInfo, tid) + } + if (!windowInfo.window.isDestroyed()) { + windowInfo.window.close() + } + } + this.windows.clear() + logger.info('Browser CDP context reset (all windows)') + } + + /** + * Fetches a URL and returns content in the specified format. + * @param url - The URL to fetch + * @param format - Output format: 'html', 'txt', 'markdown', or 'json' (default: 'markdown') + * @param timeout - Navigation timeout in milliseconds (default: 10000) + * @param privateMode - If true, uses private browsing mode (default: false) + * @param newTab - If true, always creates a new tab (useful for parallel requests) + * @param showWindow - If true, shows the browser window (default: false) + * @returns Object with tabId and content in the requested format. For 'json', content is parsed object or { data: rawContent } if parsing fails + */ + public async fetch( + url: string, + format: 'html' | 'txt' | 'markdown' | 'json' = 'markdown', + timeout = 10000, + privateMode = false, + newTab = false, + showWindow = false + ): Promise<{ tabId: string; content: string | object }> { + const { tabId } = await this.open(url, timeout, privateMode, newTab, showWindow) + + const { tab } = await this.getTab(privateMode, tabId, false, showWindow) + const dbg = tab.view.webContents.debugger + const windowKey = this.getWindowKey(privateMode) + + await this.ensureDebuggerAttached(dbg, windowKey) + + let expression: string + if (format === 'json' || format === 'txt') { + expression = 'document.body.innerText' + } else { + expression = 'document.documentElement.outerHTML' + } + + let timeoutHandle: ReturnType | undefined + try { + const result = (await Promise.race([ + dbg.sendCommand('Runtime.evaluate', { + expression, + returnByValue: true + }), + new Promise((_, reject) => { + timeoutHandle = setTimeout(() => reject(new Error('Fetch content timed out')), timeout) + }) + ])) as { result?: { value?: string } } + + const rawContent = result?.result?.value ?? 
''
+
+      let content: string | object
+      if (format === 'markdown') {
+        content = this.turndownService.turndown(rawContent)
+      } else if (format === 'json') {
+        try {
+          content = JSON.parse(rawContent)
+        } catch (parseError) {
+          logger.warn('JSON parse failed, returning raw content', {
+            url,
+            contentLength: rawContent.length,
+            error: parseError
+          })
+          content = { data: rawContent }
+        }
+      } else {
+        content = rawContent
+      }
+
+      return { tabId, content }
+    } finally {
+      if (timeoutHandle) clearTimeout(timeoutHandle)
+    }
+  }
+
+  /**
+   * Lists all tabs in a window
+   * @param privateMode - If true, lists tabs from private window (default: false)
+   */
+  public async listTabs(privateMode = false): Promise<Array<{ tabId: string; url: string; title: string }>> {
+    const windowKey = this.getWindowKey(privateMode)
+    const windowInfo = this.windows.get(windowKey)
+    if (!windowInfo) return []
+
+    return Array.from(windowInfo.tabs.values()).map((tab) => ({
+      tabId: tab.id,
+      url: tab.url,
+      title: tab.title
+    }))
+  }
+
+  /**
+   * Closes a specific tab
+   * @param privateMode - If true, closes tab from private window (default: false)
+   * @param tabId - Tab identifier to close
+   */
+  public async closeTab(privateMode: boolean, tabId: string) {
+    await this.reset(privateMode, tabId)
+  }
+
+  /**
+   * Switches the active tab
+   * @param privateMode - If true, switches tab in private window (default: false)
+   * @param tabId - Tab identifier to switch to
+   */
+  public async switchTab(privateMode: boolean, tabId: string) {
+    const windowKey = this.getWindowKey(privateMode)
+    const windowInfo = this.windows.get(windowKey)
+    if (!windowInfo) throw new Error(`Window not found for ${privateMode ? 'private' : 'normal'} mode`)
+
+    const tab = windowInfo.tabs.get(tabId)
+    if (!tab) throw new Error(`Tab ${tabId} not found`)
+
+    // Remove previous active tab view (but NOT the tabBarView)
+    if (windowInfo.activeTabId && windowInfo.activeTabId !== tabId) {
+      const prevTab = windowInfo.tabs.get(windowInfo.activeTabId)
+      if (prevTab && !windowInfo.window.isDestroyed()) {
+        windowInfo.window.removeBrowserView(prevTab.view)
+      }
+    }
+
+    windowInfo.activeTabId = tabId
+
+    // Add the new active tab view
+    if (!windowInfo.window.isDestroyed()) {
+      windowInfo.window.addBrowserView(tab.view)
+      this.updateViewBounds(windowInfo)
+    }
+
+    this.touchTab(windowKey, tabId)
+    this.sendTabBarUpdate(windowInfo)
+    logger.info('Switched active tab', { windowKey, tabId, privateMode })
+  }
+}
diff --git a/src/main/mcpServers/browser/index.ts b/src/main/mcpServers/browser/index.ts
new file mode 100644
index 0000000000..fbdb0a0f6e
--- /dev/null
+++ b/src/main/mcpServers/browser/index.ts
@@ -0,0 +1,3 @@
+export { CdpBrowserController } from './controller'
+export { BrowserServer } from './server'
+export { BrowserServer as default } from './server'
diff --git a/src/main/mcpServers/browser/server.ts b/src/main/mcpServers/browser/server.ts
new file mode 100644
index 0000000000..3e889a7b66
--- /dev/null
+++ b/src/main/mcpServers/browser/server.ts
@@ -0,0 +1,50 @@
+import type { Server } from '@modelcontextprotocol/sdk/server/index.js'
+import { Server as MCServer } from '@modelcontextprotocol/sdk/server/index.js'
+import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
+import { app } from 'electron'
+
+import { CdpBrowserController } from './controller'
+import { toolDefinitions, toolHandlers } from './tools'
+
+export class BrowserServer {
+  public server: Server
+  private controller = new CdpBrowserController()
+
+  constructor() {
+    const server = new MCServer(
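+      // Two request handlers are wired up below: ListTools advertises the static
+      // toolDefinitions, and CallTool dispatches by name through the toolHandlers map.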
{ + name: '@cherry/browser', + version: '0.1.0' + }, + { + capabilities: { + resources: {}, + tools: {} + } + } + ) + + server.setRequestHandler(ListToolsRequestSchema, async () => { + return { + tools: toolDefinitions + } + }) + + server.setRequestHandler(CallToolRequestSchema, async (request) => { + const { name, arguments: args } = request.params + const handler = toolHandlers[name] + if (!handler) { + throw new Error('Tool not found') + } + return handler(this.controller, args) + }) + + app.on('before-quit', () => { + void this.controller.reset() + }) + + this.server = server + } +} + +export default BrowserServer diff --git a/src/main/mcpServers/browser/tabbar-html.ts b/src/main/mcpServers/browser/tabbar-html.ts new file mode 100644 index 0000000000..4a1bec0e0d --- /dev/null +++ b/src/main/mcpServers/browser/tabbar-html.ts @@ -0,0 +1,567 @@ +export const TAB_BAR_HTML = ` + + + + + + +
+<!-- tab bar markup, styles, and script elided (567 lines in the original file) -->
+ + +` diff --git a/src/main/mcpServers/browser/tools/execute.ts b/src/main/mcpServers/browser/tools/execute.ts new file mode 100644 index 0000000000..09cd79f2d1 --- /dev/null +++ b/src/main/mcpServers/browser/tools/execute.ts @@ -0,0 +1,52 @@ +import * as z from 'zod' + +import type { CdpBrowserController } from '../controller' +import { logger } from '../types' +import { errorResponse, successResponse } from './utils' + +export const ExecuteSchema = z.object({ + code: z.string().describe('JavaScript code to run in page context'), + timeout: z.number().default(5000).describe('Execution timeout in ms (default: 5000)'), + privateMode: z.boolean().optional().describe('Target private session (default: false)'), + tabId: z.string().optional().describe('Target specific tab by ID') +}) + +export const executeToolDefinition = { + name: 'execute', + description: + 'Run JavaScript in the currently open page. Use after open to: click elements, fill forms, extract content (document.body.innerText), or interact with the page. The page must be opened first with open or fetch.', + inputSchema: { + type: 'object', + properties: { + code: { + type: 'string', + description: + 'JavaScript to evaluate. Examples: document.body.innerText (get text), document.querySelector("button").click() (click), document.title (get title)' + }, + timeout: { + type: 'number', + description: 'Execution timeout in ms (default: 5000)' + }, + privateMode: { + type: 'boolean', + description: 'Target private session (default: false)' + }, + tabId: { + type: 'string', + description: 'Target specific tab by ID (from open response)' + } + }, + required: ['code'] + } +} + +export async function handleExecute(controller: CdpBrowserController, args: unknown) { + const { code, timeout, privateMode, tabId } = ExecuteSchema.parse(args) + try { + const value = await controller.execute(code, timeout, privateMode ?? false, tabId) + return successResponse(typeof value === 'string' ? 
value : JSON.stringify(value)) + } catch (error) { + logger.error('Execute failed', { error, code: code.slice(0, 100), privateMode, tabId }) + return errorResponse(error as Error) + } +} diff --git a/src/main/mcpServers/browser/tools/index.ts b/src/main/mcpServers/browser/tools/index.ts new file mode 100644 index 0000000000..5ba6fcae6d --- /dev/null +++ b/src/main/mcpServers/browser/tools/index.ts @@ -0,0 +1,22 @@ +export { ExecuteSchema, executeToolDefinition, handleExecute } from './execute' +export { handleOpen, OpenSchema, openToolDefinition } from './open' +export { handleReset, resetToolDefinition } from './reset' + +import type { CdpBrowserController } from '../controller' +import { executeToolDefinition, handleExecute } from './execute' +import { handleOpen, openToolDefinition } from './open' +import { handleReset, resetToolDefinition } from './reset' + +export const toolDefinitions = [openToolDefinition, executeToolDefinition, resetToolDefinition] + +export const toolHandlers: Record< + string, + ( + controller: CdpBrowserController, + args: unknown + ) => Promise<{ content: { type: string; text: string }[]; isError: boolean }> +> = { + open: handleOpen, + execute: handleExecute, + reset: handleReset +} diff --git a/src/main/mcpServers/browser/tools/open.ts b/src/main/mcpServers/browser/tools/open.ts new file mode 100644 index 0000000000..6ea9ec9e48 --- /dev/null +++ b/src/main/mcpServers/browser/tools/open.ts @@ -0,0 +1,81 @@ +import * as z from 'zod' + +import type { CdpBrowserController } from '../controller' +import { logger } from '../types' +import { errorResponse, successResponse } from './utils' + +export const OpenSchema = z.object({ + url: z.url().describe('URL to navigate to'), + format: z + .enum(['html', 'txt', 'markdown', 'json']) + .optional() + .describe('If set, return page content in this format. If not set, just open the page and return tabId.'), + timeout: z.number().optional().describe('Navigation timeout in ms (default: 10000)'), + privateMode: z.boolean().optional().describe('Use incognito mode, no data persisted (default: false)'), + newTab: z.boolean().optional().describe('Open in new tab, required for parallel requests (default: false)'), + showWindow: z.boolean().optional().default(true).describe('Show browser window (default: true)') +}) + +export const openToolDefinition = { + name: 'open', + description: + 'Navigate to a URL in a browser window. If format is specified, returns { tabId, content } with page content in that format. Otherwise, returns { currentUrl, title, tabId } for subsequent operations with execute tool. Set newTab=true when opening multiple URLs in parallel.', + inputSchema: { + type: 'object', + properties: { + url: { + type: 'string', + description: 'URL to navigate to' + }, + format: { + type: 'string', + enum: ['html', 'txt', 'markdown', 'json'], + description: 'If set, return page content in this format. If not set, just open the page and return tabId.' 
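+          // For example, a call with arguments { "url": "https://example.com", "format": "markdown" }
+          // returns { tabId, content } with the page converted to markdown (see handleOpen below).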
+ }, + timeout: { + type: 'number', + description: 'Navigation timeout in ms (default: 10000)' + }, + privateMode: { + type: 'boolean', + description: 'Use incognito mode, no data persisted (default: false)' + }, + newTab: { + type: 'boolean', + description: 'Open in new tab, required for parallel requests (default: false)' + }, + showWindow: { + type: 'boolean', + description: 'Show browser window (default: true)' + } + }, + required: ['url'] + } +} + +export async function handleOpen(controller: CdpBrowserController, args: unknown) { + try { + const { url, format, timeout, privateMode, newTab, showWindow } = OpenSchema.parse(args) + + if (format) { + const { tabId, content } = await controller.fetch( + url, + format, + timeout ?? 10000, + privateMode ?? false, + newTab ?? false, + showWindow + ) + return successResponse(JSON.stringify({ tabId, content })) + } else { + const res = await controller.open(url, timeout ?? 10000, privateMode ?? false, newTab ?? false, showWindow) + return successResponse(JSON.stringify(res)) + } + } catch (error) { + logger.error('Open failed', { + error, + url: args && typeof args === 'object' && 'url' in args ? args.url : undefined + }) + return errorResponse(error instanceof Error ? error : String(error)) + } +} diff --git a/src/main/mcpServers/browser/tools/reset.ts b/src/main/mcpServers/browser/tools/reset.ts new file mode 100644 index 0000000000..fe67b74b1d --- /dev/null +++ b/src/main/mcpServers/browser/tools/reset.ts @@ -0,0 +1,43 @@ +import * as z from 'zod' + +import type { CdpBrowserController } from '../controller' +import { logger } from '../types' +import { errorResponse, successResponse } from './utils' + +export const ResetSchema = z.object({ + privateMode: z.boolean().optional().describe('true=private window, false=normal window, omit=all windows'), + tabId: z.string().optional().describe('Close specific tab only (requires privateMode)') +}) + +export const resetToolDefinition = { + name: 'reset', + description: + 'Close browser windows and clear state. Call when done browsing to free resources. Omit all parameters to close everything.', + inputSchema: { + type: 'object', + properties: { + privateMode: { + type: 'boolean', + description: 'true=reset private window only, false=reset normal window only, omit=reset all' + }, + tabId: { + type: 'string', + description: 'Close specific tab only (requires privateMode to be set)' + } + } + } +} + +export async function handleReset(controller: CdpBrowserController, args: unknown) { + try { + const { privateMode, tabId } = ResetSchema.parse(args) + await controller.reset(privateMode, tabId) + return successResponse('reset') + } catch (error) { + logger.error('Reset failed', { + error, + privateMode: args && typeof args === 'object' && 'privateMode' in args ? args.privateMode : undefined + }) + return errorResponse(error instanceof Error ? error : String(error)) + } +} diff --git a/src/main/mcpServers/browser/tools/utils.ts b/src/main/mcpServers/browser/tools/utils.ts new file mode 100644 index 0000000000..f5272ac81c --- /dev/null +++ b/src/main/mcpServers/browser/tools/utils.ts @@ -0,0 +1,14 @@ +export function successResponse(text: string) { + return { + content: [{ type: 'text', text }], + isError: false + } +} + +export function errorResponse(error: Error | string) { + const message = error instanceof Error ? 
error.message : error + return { + content: [{ type: 'text', text: message }], + isError: true + } +} diff --git a/src/main/mcpServers/browser/types.ts b/src/main/mcpServers/browser/types.ts new file mode 100644 index 0000000000..a59fe59665 --- /dev/null +++ b/src/main/mcpServers/browser/types.ts @@ -0,0 +1,24 @@ +import { loggerService } from '@logger' +import type { BrowserView, BrowserWindow } from 'electron' + +export const logger = loggerService.withContext('MCPBrowserCDP') +export const userAgent = + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36' + +export interface TabInfo { + id: string + view: BrowserView + url: string + title: string + lastActive: number +} + +export interface WindowInfo { + windowKey: string + privateMode: boolean + window: BrowserWindow + tabs: Map + activeTabId: string | null + lastActive: number + tabBarView?: BrowserView +} diff --git a/src/main/mcpServers/factory.ts b/src/main/mcpServers/factory.ts index 2323701e49..909901c1c8 100644 --- a/src/main/mcpServers/factory.ts +++ b/src/main/mcpServers/factory.ts @@ -4,6 +4,7 @@ import type { BuiltinMCPServerName } from '@types' import { BuiltinMCPServerNames } from '@types' import BraveSearchServer from './brave-search' +import BrowserServer from './browser' import DiDiMcpServer from './didi-mcp' import DifyKnowledgeServer from './dify-knowledge' import FetchServer from './fetch' @@ -35,7 +36,7 @@ export function createInMemoryMCPServer( return new FetchServer().server } case BuiltinMCPServerNames.filesystem: { - return new FileSystemServer(args).server + return new FileSystemServer(envs.WORKSPACE_ROOT).server } case BuiltinMCPServerNames.difyKnowledge: { const difyKey = envs.DIFY_KEY @@ -48,6 +49,9 @@ export function createInMemoryMCPServer( const apiKey = envs.DIDI_API_KEY return new DiDiMcpServer(apiKey).server } + case BuiltinMCPServerNames.browser: { + return new BrowserServer().server + } default: throw new Error(`Unknown in-memory MCP server: ${name}`) } diff --git a/src/main/mcpServers/filesystem.ts b/src/main/mcpServers/filesystem.ts deleted file mode 100644 index ba10783881..0000000000 --- a/src/main/mcpServers/filesystem.ts +++ /dev/null @@ -1,652 +0,0 @@ -// port https://github.com/modelcontextprotocol/servers/blob/main/src/filesystem/index.ts - -import { loggerService } from '@logger' -import { Server } from '@modelcontextprotocol/sdk/server/index.js' -import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js' -import { createTwoFilesPatch } from 'diff' -import fs from 'fs/promises' -import { minimatch } from 'minimatch' -import os from 'os' -import path from 'path' -import * as z from 'zod' - -const logger = loggerService.withContext('MCP:FileSystemServer') - -// Normalize all paths consistently -function normalizePath(p: string): string { - return path.normalize(p) -} - -function expandHome(filepath: string): string { - if (filepath.startsWith('~/') || filepath === '~') { - return path.join(os.homedir(), filepath.slice(1)) - } - return filepath -} - -// Security utilities -async function validatePath(allowedDirectories: string[], requestedPath: string): Promise { - const expandedPath = expandHome(requestedPath) - const absolute = path.isAbsolute(expandedPath) - ? 
path.resolve(expandedPath) - : path.resolve(process.cwd(), expandedPath) - - const normalizedRequested = normalizePath(absolute) - - // Check if path is within allowed directories - const isAllowed = allowedDirectories.some((dir) => normalizedRequested.startsWith(dir)) - if (!isAllowed) { - throw new Error( - `Access denied - path outside allowed directories: ${absolute} not in ${allowedDirectories.join(', ')}` - ) - } - - // Handle symlinks by checking their real path - try { - const realPath = await fs.realpath(absolute) - const normalizedReal = normalizePath(realPath) - const isRealPathAllowed = allowedDirectories.some((dir) => normalizedReal.startsWith(dir)) - if (!isRealPathAllowed) { - throw new Error('Access denied - symlink target outside allowed directories') - } - return realPath - } catch (error) { - // For new files that don't exist yet, verify parent directory - const parentDir = path.dirname(absolute) - try { - const realParentPath = await fs.realpath(parentDir) - const normalizedParent = normalizePath(realParentPath) - const isParentAllowed = allowedDirectories.some((dir) => normalizedParent.startsWith(dir)) - if (!isParentAllowed) { - throw new Error('Access denied - parent directory outside allowed directories') - } - return absolute - } catch { - throw new Error(`Parent directory does not exist: ${parentDir}`) - } - } -} - -// Schema definitions -const ReadFileArgsSchema = z.object({ - path: z.string() -}) - -const ReadMultipleFilesArgsSchema = z.object({ - paths: z.array(z.string()) -}) - -const WriteFileArgsSchema = z.object({ - path: z.string(), - content: z.string() -}) - -const EditOperation = z.object({ - oldText: z.string().describe('Text to search for - must match exactly'), - newText: z.string().describe('Text to replace with') -}) - -const EditFileArgsSchema = z.object({ - path: z.string(), - edits: z.array(EditOperation), - dryRun: z.boolean().default(false).describe('Preview changes using git-style diff format') -}) - -const CreateDirectoryArgsSchema = z.object({ - path: z.string() -}) - -const ListDirectoryArgsSchema = z.object({ - path: z.string() -}) - -const DirectoryTreeArgsSchema = z.object({ - path: z.string() -}) - -const MoveFileArgsSchema = z.object({ - source: z.string(), - destination: z.string() -}) - -const SearchFilesArgsSchema = z.object({ - path: z.string(), - pattern: z.string(), - excludePatterns: z.array(z.string()).optional().default([]) -}) - -const GetFileInfoArgsSchema = z.object({ - path: z.string() -}) - -interface FileInfo { - size: number - created: Date - modified: Date - accessed: Date - isDirectory: boolean - isFile: boolean - permissions: string -} - -// Tool implementations -async function getFileStats(filePath: string): Promise { - const stats = await fs.stat(filePath) - return { - size: stats.size, - created: stats.birthtime, - modified: stats.mtime, - accessed: stats.atime, - isDirectory: stats.isDirectory(), - isFile: stats.isFile(), - permissions: stats.mode.toString(8).slice(-3) - } -} - -async function searchFiles( - allowedDirectories: string[], - rootPath: string, - pattern: string, - excludePatterns: string[] = [] -): Promise { - const results: string[] = [] - - async function search(currentPath: string) { - const entries = await fs.readdir(currentPath, { withFileTypes: true }) - - for (const entry of entries) { - const fullPath = path.join(currentPath, entry.name) - - try { - // Validate each path before processing - await validatePath(allowedDirectories, fullPath) - - // Check if path matches any exclude pattern - 
const relativePath = path.relative(rootPath, fullPath) - const shouldExclude = excludePatterns.some((pattern) => { - const globPattern = pattern.includes('*') ? pattern : `**/${pattern}/**` - return minimatch(relativePath, globPattern, { dot: true }) - }) - - if (shouldExclude) { - continue - } - - if (entry.name.toLowerCase().includes(pattern.toLowerCase())) { - results.push(fullPath) - } - - if (entry.isDirectory()) { - await search(fullPath) - } - } catch (error) { - // Skip invalid paths during search - } - } - } - - await search(rootPath) - return results -} - -// file editing and diffing utilities -function normalizeLineEndings(text: string): string { - return text.replace(/\r\n/g, '\n') -} - -function createUnifiedDiff(originalContent: string, newContent: string, filepath: string = 'file'): string { - // Ensure consistent line endings for diff - const normalizedOriginal = normalizeLineEndings(originalContent) - const normalizedNew = normalizeLineEndings(newContent) - - return createTwoFilesPatch(filepath, filepath, normalizedOriginal, normalizedNew, 'original', 'modified') -} - -async function applyFileEdits( - filePath: string, - edits: Array<{ oldText: string; newText: string }>, - dryRun = false -): Promise { - // Read file content and normalize line endings - const content = normalizeLineEndings(await fs.readFile(filePath, 'utf-8')) - - // Apply edits sequentially - let modifiedContent = content - for (const edit of edits) { - const normalizedOld = normalizeLineEndings(edit.oldText) - const normalizedNew = normalizeLineEndings(edit.newText) - - // If exact match exists, use it - if (modifiedContent.includes(normalizedOld)) { - modifiedContent = modifiedContent.replace(normalizedOld, normalizedNew) - continue - } - - // Otherwise, try line-by-line matching with flexibility for whitespace - const oldLines = normalizedOld.split('\n') - const contentLines = modifiedContent.split('\n') - let matchFound = false - - for (let i = 0; i <= contentLines.length - oldLines.length; i++) { - const potentialMatch = contentLines.slice(i, i + oldLines.length) - - // Compare lines with normalized whitespace - const isMatch = oldLines.every((oldLine, j) => { - const contentLine = potentialMatch[j] - return oldLine.trim() === contentLine.trim() - }) - - if (isMatch) { - // Preserve original indentation of first line - const originalIndent = contentLines[i].match(/^\s*/)?.[0] || '' - const newLines = normalizedNew.split('\n').map((line, j) => { - if (j === 0) return originalIndent + line.trimStart() - // For subsequent lines, try to preserve relative indentation - const oldIndent = oldLines[j]?.match(/^\s*/)?.[0] || '' - const newIndent = line.match(/^\s*/)?.[0] || '' - if (oldIndent && newIndent) { - const relativeIndent = newIndent.length - oldIndent.length - return originalIndent + ' '.repeat(Math.max(0, relativeIndent)) + line.trimStart() - } - return line - }) - - contentLines.splice(i, oldLines.length, ...newLines) - modifiedContent = contentLines.join('\n') - matchFound = true - break - } - } - - if (!matchFound) { - throw new Error(`Could not find exact match for edit:\n${edit.oldText}`) - } - } - - // Create unified diff - const diff = createUnifiedDiff(content, modifiedContent, filePath) - - // Format diff with appropriate number of backticks - let numBackticks = 3 - while (diff.includes('`'.repeat(numBackticks))) { - numBackticks++ - } - const formattedDiff = `${'`'.repeat(numBackticks)}diff\n${diff}${'`'.repeat(numBackticks)}\n\n` - - if (!dryRun) { - await fs.writeFile(filePath, 
modifiedContent, 'utf-8') - } - - return formattedDiff -} - -class FileSystemServer { - public server: Server - private allowedDirectories: string[] - constructor(allowedDirs: string[]) { - if (!Array.isArray(allowedDirs) || allowedDirs.length === 0) { - throw new Error('No allowed directories provided, please specify at least one directory in args') - } - - this.allowedDirectories = allowedDirs.map((dir) => normalizePath(path.resolve(expandHome(dir)))) - - // Validate that all directories exist and are accessible - this.validateDirs().catch((error) => { - logger.error('Error validating allowed directories:', error) - throw new Error(`Error validating allowed directories: ${error}`) - }) - - this.server = new Server( - { - name: 'secure-filesystem-server', - version: '0.2.0' - }, - { - capabilities: { - tools: {} - } - } - ) - this.initialize() - } - - async validateDirs() { - // Validate that all directories exist and are accessible - await Promise.all( - this.allowedDirectories.map(async (dir) => { - try { - const stats = await fs.stat(expandHome(dir)) - if (!stats.isDirectory()) { - logger.error(`Error: ${dir} is not a directory`) - throw new Error(`Error: ${dir} is not a directory`) - } - } catch (error: any) { - logger.error(`Error accessing directory ${dir}:`, error) - throw new Error(`Error accessing directory ${dir}:`, error) - } - }) - ) - } - - initialize() { - // Tool handlers - this.server.setRequestHandler(ListToolsRequestSchema, async () => { - return { - tools: [ - { - name: 'read_file', - description: - 'Read the complete contents of a file from the file system. ' + - 'Handles various text encodings and provides detailed error messages ' + - 'if the file cannot be read. Use this tool when you need to examine ' + - 'the contents of a single file. Only works within allowed directories.', - inputSchema: z.toJSONSchema(ReadFileArgsSchema) - }, - { - name: 'read_multiple_files', - description: - 'Read the contents of multiple files simultaneously. This is more ' + - 'efficient than reading files one by one when you need to analyze ' + - "or compare multiple files. Each file's content is returned with its " + - "path as a reference. Failed reads for individual files won't stop " + - 'the entire operation. Only works within allowed directories.', - inputSchema: z.toJSONSchema(ReadMultipleFilesArgsSchema) - }, - { - name: 'write_file', - description: - 'Create a new file or completely overwrite an existing file with new content. ' + - 'Use with caution as it will overwrite existing files without warning. ' + - 'Handles text content with proper encoding. Only works within allowed directories.', - inputSchema: z.toJSONSchema(WriteFileArgsSchema) - }, - { - name: 'edit_file', - description: - 'Make line-based edits to a text file. Each edit replaces exact line sequences ' + - 'with new content. Returns a git-style diff showing the changes made. ' + - 'Only works within allowed directories.', - inputSchema: z.toJSONSchema(EditFileArgsSchema) - }, - { - name: 'create_directory', - description: - 'Create a new directory or ensure a directory exists. Can create multiple ' + - 'nested directories in one operation. If the directory already exists, ' + - 'this operation will succeed silently. Perfect for setting up directory ' + - 'structures for projects or ensuring required paths exist. 
Only works within allowed directories.', - inputSchema: z.toJSONSchema(CreateDirectoryArgsSchema) - }, - { - name: 'list_directory', - description: - 'Get a detailed listing of all files and directories in a specified path. ' + - 'Results clearly distinguish between files and directories with [FILE] and [DIR] ' + - 'prefixes. This tool is essential for understanding directory structure and ' + - 'finding specific files within a directory. Only works within allowed directories.', - inputSchema: z.toJSONSchema(ListDirectoryArgsSchema) - }, - { - name: 'directory_tree', - description: - 'Get a recursive tree view of files and directories as a JSON structure. ' + - "Each entry includes 'name', 'type' (file/directory), and 'children' for directories. " + - 'Files have no children array, while directories always have a children array (which may be empty). ' + - 'The output is formatted with 2-space indentation for readability. Only works within allowed directories.', - inputSchema: z.toJSONSchema(DirectoryTreeArgsSchema) - }, - { - name: 'move_file', - description: - 'Move or rename files and directories. Can move files between directories ' + - 'and rename them in a single operation. If the destination exists, the ' + - 'operation will fail. Works across different directories and can be used ' + - 'for simple renaming within the same directory. Both source and destination must be within allowed directories.', - inputSchema: z.toJSONSchema(MoveFileArgsSchema) - }, - { - name: 'search_files', - description: - 'Recursively search for files and directories matching a pattern. ' + - 'Searches through all subdirectories from the starting path. The search ' + - 'is case-insensitive and matches partial names. Returns full paths to all ' + - "matching items. Great for finding files when you don't know their exact location. " + - 'Only searches within allowed directories.', - inputSchema: z.toJSONSchema(SearchFilesArgsSchema) - }, - { - name: 'get_file_info', - description: - 'Retrieve detailed metadata about a file or directory. Returns comprehensive ' + - 'information including size, creation time, last modified time, permissions, ' + - 'and type. This tool is perfect for understanding file characteristics ' + - 'without reading the actual content. Only works within allowed directories.', - inputSchema: z.toJSONSchema(GetFileInfoArgsSchema) - }, - { - name: 'list_allowed_directories', - description: - 'Returns the list of directories that this server is allowed to access. 
' + - 'Use this to understand which directories are available before trying to access files.', - inputSchema: { - type: 'object', - properties: {}, - required: [] - } - } - ] - } - }) - - this.server.setRequestHandler(CallToolRequestSchema, async (request) => { - try { - const { name, arguments: args } = request.params - - switch (name) { - case 'read_file': { - const parsed = ReadFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for read_file: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const content = await fs.readFile(validPath, 'utf-8') - return { - content: [{ type: 'text', text: content }] - } - } - - case 'read_multiple_files': { - const parsed = ReadMultipleFilesArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for read_multiple_files: ${parsed.error}`) - } - const results = await Promise.all( - parsed.data.paths.map(async (filePath: string) => { - try { - const validPath = await validatePath(this.allowedDirectories, filePath) - const content = await fs.readFile(validPath, 'utf-8') - return `${filePath}:\n${content}\n` - } catch (error) { - const errorMessage = error instanceof Error ? error.message : String(error) - return `${filePath}: Error - ${errorMessage}` - } - }) - ) - return { - content: [{ type: 'text', text: results.join('\n---\n') }] - } - } - - case 'write_file': { - const parsed = WriteFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for write_file: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - await fs.writeFile(validPath, parsed.data.content, 'utf-8') - return { - content: [{ type: 'text', text: `Successfully wrote to ${parsed.data.path}` }] - } - } - - case 'edit_file': { - const parsed = EditFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for edit_file: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const result = await applyFileEdits(validPath, parsed.data.edits, parsed.data.dryRun) - return { - content: [{ type: 'text', text: result }] - } - } - - case 'create_directory': { - const parsed = CreateDirectoryArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for create_directory: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - await fs.mkdir(validPath, { recursive: true }) - return { - content: [{ type: 'text', text: `Successfully created directory ${parsed.data.path}` }] - } - } - - case 'list_directory': { - const parsed = ListDirectoryArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for list_directory: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const entries = await fs.readdir(validPath, { withFileTypes: true }) - const formatted = entries - .map((entry) => `${entry.isDirectory() ? 
'[DIR]' : '[FILE]'} ${entry.name}`) - .join('\n') - return { - content: [{ type: 'text', text: formatted }] - } - } - - case 'directory_tree': { - const parsed = DirectoryTreeArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for directory_tree: ${parsed.error}`) - } - - interface TreeEntry { - name: string - type: 'file' | 'directory' - children?: TreeEntry[] - } - - async function buildTree(allowedDirectories: string[], currentPath: string): Promise { - const validPath = await validatePath(allowedDirectories, currentPath) - const entries = await fs.readdir(validPath, { withFileTypes: true }) - const result: TreeEntry[] = [] - - for (const entry of entries) { - const entryData: TreeEntry = { - name: entry.name, - type: entry.isDirectory() ? 'directory' : 'file' - } - - if (entry.isDirectory()) { - const subPath = path.join(currentPath, entry.name) - entryData.children = await buildTree(allowedDirectories, subPath) - } - - result.push(entryData) - } - - return result - } - - const treeData = await buildTree(this.allowedDirectories, parsed.data.path) - return { - content: [ - { - type: 'text', - text: JSON.stringify(treeData, null, 2) - } - ] - } - } - - case 'move_file': { - const parsed = MoveFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for move_file: ${parsed.error}`) - } - const validSourcePath = await validatePath(this.allowedDirectories, parsed.data.source) - const validDestPath = await validatePath(this.allowedDirectories, parsed.data.destination) - await fs.rename(validSourcePath, validDestPath) - return { - content: [ - { type: 'text', text: `Successfully moved ${parsed.data.source} to ${parsed.data.destination}` } - ] - } - } - - case 'search_files': { - const parsed = SearchFilesArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for search_files: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const results = await searchFiles( - this.allowedDirectories, - validPath, - parsed.data.pattern, - parsed.data.excludePatterns - ) - return { - content: [{ type: 'text', text: results.length > 0 ? results.join('\n') : 'No matches found' }] - } - } - - case 'get_file_info': { - const parsed = GetFileInfoArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for get_file_info: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const info = await getFileStats(validPath) - return { - content: [ - { - type: 'text', - text: Object.entries(info) - .map(([key, value]) => `${key}: ${value}`) - .join('\n') - } - ] - } - } - - case 'list_allowed_directories': { - return { - content: [ - { - type: 'text', - text: `Allowed directories:\n${this.allowedDirectories.join('\n')}` - } - ] - } - } - - default: - throw new Error(`Unknown tool: ${name}`) - } - } catch (error) { - const errorMessage = error instanceof Error ? 
error.message : String(error) - return { - content: [{ type: 'text', text: `Error: ${errorMessage}` }], - isError: true - } - } - }) - } -} - -export default FileSystemServer diff --git a/src/main/mcpServers/filesystem/index.ts b/src/main/mcpServers/filesystem/index.ts new file mode 100644 index 0000000000..cec4c31cdf --- /dev/null +++ b/src/main/mcpServers/filesystem/index.ts @@ -0,0 +1,2 @@ +// Re-export FileSystemServer to maintain existing import pattern +export { default, FileSystemServer } from './server' diff --git a/src/main/mcpServers/filesystem/server.ts b/src/main/mcpServers/filesystem/server.ts new file mode 100644 index 0000000000..164ba0c9c4 --- /dev/null +++ b/src/main/mcpServers/filesystem/server.ts @@ -0,0 +1,118 @@ +import { Server } from '@modelcontextprotocol/sdk/server/index.js' +import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js' +import { app } from 'electron' +import fs from 'fs/promises' +import path from 'path' + +import { + deleteToolDefinition, + editToolDefinition, + globToolDefinition, + grepToolDefinition, + handleDeleteTool, + handleEditTool, + handleGlobTool, + handleGrepTool, + handleLsTool, + handleReadTool, + handleWriteTool, + lsToolDefinition, + readToolDefinition, + writeToolDefinition +} from './tools' +import { logger } from './types' + +export class FileSystemServer { + public server: Server + private baseDir: string + + constructor(baseDir?: string) { + if (baseDir && path.isAbsolute(baseDir)) { + this.baseDir = baseDir + logger.info(`Using provided baseDir for filesystem MCP: ${baseDir}`) + } else { + const userData = app.getPath('userData') + this.baseDir = path.join(userData, 'Data', 'Workspace') + logger.info(`Using default workspace for filesystem MCP baseDir: ${this.baseDir}`) + } + + this.server = new Server( + { + name: 'filesystem-server', + version: '2.0.0' + }, + { + capabilities: { + tools: {} + } + } + ) + + this.initialize() + } + + async initialize() { + try { + await fs.mkdir(this.baseDir, { recursive: true }) + } catch (error) { + logger.error('Failed to create filesystem MCP baseDir', { error, baseDir: this.baseDir }) + } + + // Register tool list handler + this.server.setRequestHandler(ListToolsRequestSchema, async () => { + return { + tools: [ + globToolDefinition, + lsToolDefinition, + grepToolDefinition, + readToolDefinition, + editToolDefinition, + writeToolDefinition, + deleteToolDefinition + ] + } + }) + + // Register tool call handler + this.server.setRequestHandler(CallToolRequestSchema, async (request) => { + try { + const { name, arguments: args } = request.params + + switch (name) { + case 'glob': + return await handleGlobTool(args, this.baseDir) + + case 'ls': + return await handleLsTool(args, this.baseDir) + + case 'grep': + return await handleGrepTool(args, this.baseDir) + + case 'read': + return await handleReadTool(args, this.baseDir) + + case 'edit': + return await handleEditTool(args, this.baseDir) + + case 'write': + return await handleWriteTool(args, this.baseDir) + + case 'delete': + return await handleDeleteTool(args, this.baseDir) + + default: + throw new Error(`Unknown tool: ${name}`) + } + } catch (error) { + const errorMessage = error instanceof Error ? 
error.message : String(error) + logger.error(`Tool execution error for ${request.params.name}:`, { error }) + return { + content: [{ type: 'text', text: `Error: ${errorMessage}` }], + isError: true + } + } + }) + } +} + +export default FileSystemServer diff --git a/src/main/mcpServers/filesystem/tools/delete.ts b/src/main/mcpServers/filesystem/tools/delete.ts new file mode 100644 index 0000000000..83becc4f17 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/delete.ts @@ -0,0 +1,93 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { logger, validatePath } from '../types' + +// Schema definition +export const DeleteToolSchema = z.object({ + path: z.string().describe('The path to the file or directory to delete'), + recursive: z.boolean().optional().describe('For directories, whether to delete recursively (default: false)') +}) + +// Tool definition with detailed description +export const deleteToolDefinition = { + name: 'delete', + description: `Deletes a file or directory from the filesystem. + +CAUTION: This operation cannot be undone! + +- For files: simply provide the path +- For empty directories: provide the path +- For non-empty directories: set recursive=true +- The path must be an absolute path, not a relative path +- Always verify the path before deleting to avoid data loss`, + inputSchema: z.toJSONSchema(DeleteToolSchema) +} + +// Handler implementation +export async function handleDeleteTool(args: unknown, baseDir: string) { + const parsed = DeleteToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for delete: ${parsed.error}`) + } + + const targetPath = parsed.data.path + const validPath = await validatePath(targetPath, baseDir) + const recursive = parsed.data.recursive || false + + // Check if path exists and get stats + let stats + try { + stats = await fs.stat(validPath) + } catch (error: any) { + if (error.code === 'ENOENT') { + throw new Error(`Path not found: ${targetPath}`) + } + throw error + } + + const isDirectory = stats.isDirectory() + const relativePath = path.relative(baseDir, validPath) + + // Perform deletion + try { + if (isDirectory) { + if (recursive) { + // Delete directory recursively + await fs.rm(validPath, { recursive: true, force: true }) + } else { + // Try to delete empty directory + await fs.rmdir(validPath) + } + } else { + // Delete file + await fs.unlink(validPath) + } + } catch (error: any) { + if (error.code === 'ENOTEMPTY') { + throw new Error(`Directory not empty: ${targetPath}. Use recursive=true to delete non-empty directories.`) + } + throw new Error(`Failed to delete: ${error.message}`) + } + + // Log the operation + logger.info('Path deleted', { + path: validPath, + type: isDirectory ? 'directory' : 'file', + recursive: isDirectory ? recursive : undefined + }) + + // Format output + const itemType = isDirectory ? 'Directory' : 'File' + const recursiveNote = isDirectory && recursive ? 
' (recursive)' : '' + + return { + content: [ + { + type: 'text', + text: `${itemType} deleted${recursiveNote}: ${relativePath}` + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/edit.ts b/src/main/mcpServers/filesystem/tools/edit.ts new file mode 100644 index 0000000000..c1a0e637ce --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/edit.ts @@ -0,0 +1,130 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { logger, replaceWithFuzzyMatch, validatePath } from '../types' + +// Schema definition +export const EditToolSchema = z.object({ + file_path: z.string().describe('The path to the file to modify'), + old_string: z.string().describe('The text to replace'), + new_string: z.string().describe('The text to replace it with'), + replace_all: z.boolean().optional().default(false).describe('Replace all occurrences of old_string (default false)') +}) + +// Tool definition with detailed description +export const editToolDefinition = { + name: 'edit', + description: `Performs exact string replacements in files. + +- You must use the 'read' tool at least once before editing +- The file_path must be an absolute path, not a relative path +- Preserve exact indentation from read output (after the line number prefix) +- Never include line number prefixes in old_string or new_string +- ALWAYS prefer editing existing files over creating new ones +- The edit will FAIL if old_string is not found in the file +- The edit will FAIL if old_string appears multiple times (provide more context or use replace_all) +- The edit will FAIL if old_string equals new_string +- Use replace_all to rename variables or replace all occurrences`, + inputSchema: z.toJSONSchema(EditToolSchema) +} + +// Handler implementation +export async function handleEditTool(args: unknown, baseDir: string) { + const parsed = EditToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for edit: ${parsed.error}`) + } + + const { file_path: filePath, old_string: oldString, new_string: newString, replace_all: replaceAll } = parsed.data + + // Validate path + const validPath = await validatePath(filePath, baseDir) + + // Check if file exists + try { + const stats = await fs.stat(validPath) + if (!stats.isFile()) { + throw new Error(`Path is not a file: ${filePath}`) + } + } catch (error: any) { + if (error.code === 'ENOENT') { + // If old_string is empty, this is a create new file operation + if (oldString === '') { + // Create parent directory if needed + const parentDir = path.dirname(validPath) + await fs.mkdir(parentDir, { recursive: true }) + + // Write the new content + await fs.writeFile(validPath, newString, 'utf-8') + + logger.info('File created', { path: validPath }) + + const relativePath = path.relative(baseDir, validPath) + return { + content: [ + { + type: 'text', + text: `Created new file: ${relativePath}\nLines: ${newString.split('\n').length}` + } + ] + } + } + throw new Error(`File not found: ${filePath}`) + } + throw error + } + + // Read current content + const content = await fs.readFile(validPath, 'utf-8') + + // Handle special case: old_string is empty (create file with content) + if (oldString === '') { + await fs.writeFile(validPath, newString, 'utf-8') + + logger.info('File overwritten', { path: validPath }) + + const relativePath = path.relative(baseDir, validPath) + return { + content: [ + { + type: 'text', + text: `Overwrote file: ${relativePath}\nLines: ${newString.split('\n').length}` + } + ] + } + } + + // Perform the 
replacement with fuzzy matching + const newContent = replaceWithFuzzyMatch(content, oldString, newString, replaceAll) + + // Write the modified content + await fs.writeFile(validPath, newContent, 'utf-8') + + logger.info('File edited', { + path: validPath, + replaceAll + }) + + // Generate a simple diff summary + const oldLines = content.split('\n').length + const newLines = newContent.split('\n').length + const lineDiff = newLines - oldLines + + const relativePath = path.relative(baseDir, validPath) + let diffSummary = `Edited: ${relativePath}` + if (lineDiff > 0) { + diffSummary += `\n+${lineDiff} lines` + } else if (lineDiff < 0) { + diffSummary += `\n${lineDiff} lines` + } + + return { + content: [ + { + type: 'text', + text: diffSummary + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/glob.ts b/src/main/mcpServers/filesystem/tools/glob.ts new file mode 100644 index 0000000000..d6a6b4a757 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/glob.ts @@ -0,0 +1,149 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import type { FileInfo } from '../types' +import { logger, MAX_FILES_LIMIT, runRipgrep, validatePath } from '../types' + +// Schema definition +export const GlobToolSchema = z.object({ + pattern: z.string().describe('The glob pattern to match files against'), + path: z + .string() + .optional() + .describe('The directory to search in (must be absolute path). Defaults to the base directory') +}) + +// Tool definition with detailed description +export const globToolDefinition = { + name: 'glob', + description: `Fast file pattern matching tool that works with any codebase size. + +- Supports glob patterns like "**/*.js" or "src/**/*.ts" +- Returns matching absolute file paths sorted by modification time (newest first) +- Use this when you need to find files by name patterns +- Patterns without "/" (e.g., "*.txt") match files at ANY depth in the directory tree +- Patterns with "/" (e.g., "src/*.ts") match relative to the search path +- Pattern syntax: * (any chars), ** (any path), {a,b} (alternatives), ? 
(single char) +- Results are limited to 100 files +- The path parameter must be an absolute path if specified +- If path is not specified, defaults to the base directory +- IMPORTANT: Omit the path field for the default directory (don't use "undefined" or "null")`, + inputSchema: z.toJSONSchema(GlobToolSchema) +} + +// Handler implementation +export async function handleGlobTool(args: unknown, baseDir: string) { + const parsed = GlobToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for glob: ${parsed.error}`) + } + + const searchPath = parsed.data.path || baseDir + const validPath = await validatePath(searchPath, baseDir) + + // Verify the search directory exists + try { + const stats = await fs.stat(validPath) + if (!stats.isDirectory()) { + throw new Error(`Path is not a directory: ${validPath}`) + } + } catch (error: unknown) { + if (error && typeof error === 'object' && 'code' in error && error.code === 'ENOENT') { + throw new Error(`Directory not found: ${validPath}`) + } + throw error + } + + // Validate pattern + const pattern = parsed.data.pattern.trim() + if (!pattern) { + throw new Error('Pattern cannot be empty') + } + + const files: FileInfo[] = [] + let truncated = false + + // Build ripgrep arguments for file listing using --glob=pattern format + const rgArgs: string[] = [ + '--files', + '--follow', + '--hidden', + `--glob=${pattern}`, + '--glob=!.git/*', + '--glob=!node_modules/*', + '--glob=!dist/*', + '--glob=!build/*', + '--glob=!__pycache__/*', + validPath + ] + + // Use ripgrep for file listing + logger.debug('Running ripgrep with args', { rgArgs }) + const rgResult = await runRipgrep(rgArgs) + logger.debug('Ripgrep result', { + ok: rgResult.ok, + exitCode: rgResult.exitCode, + stdoutLength: rgResult.stdout.length, + stdoutPreview: rgResult.stdout.slice(0, 500) + }) + + // Process results if we have stdout content + // Exit code 2 can indicate partial errors (e.g., permission denied on some dirs) but still have valid results + if (rgResult.ok && rgResult.stdout.length > 0) { + const lines = rgResult.stdout.split('\n').filter(Boolean) + logger.debug('Parsed lines from ripgrep', { lineCount: lines.length, lines }) + + for (const line of lines) { + if (files.length >= MAX_FILES_LIMIT) { + truncated = true + break + } + + const filePath = line.trim() + if (!filePath) continue + + const absolutePath = path.isAbsolute(filePath) ? filePath : path.resolve(validPath, filePath) + + try { + const stats = await fs.stat(absolutePath) + files.push({ + path: absolutePath, + type: 'file', // ripgrep --files only returns files + size: stats.size, + modified: stats.mtime + }) + } catch (error) { + logger.debug('Failed to stat file from ripgrep output, skipping', { file: absolutePath, error }) + } + } + } + + // Sort by modification time (newest first) + files.sort((a, b) => { + const aTime = a.modified ? a.modified.getTime() : 0 + const bTime = b.modified ? b.modified.getTime() : 0 + return bTime - aTime + }) + + // Format output - always use absolute paths + const output: string[] = [] + if (files.length === 0) { + output.push(`No files found matching pattern "${parsed.data.pattern}" in ${validPath}`) + } else { + output.push(...files.map((f) => f.path)) + if (truncated) { + output.push('') + output.push(`(Results truncated to ${MAX_FILES_LIMIT} files. 
Consider using a more specific pattern.)`)
+    }
+  }
+
+  return {
+    content: [
+      {
+        type: 'text',
+        text: output.join('\n')
+      }
+    ]
+  }
+}
diff --git a/src/main/mcpServers/filesystem/tools/grep.ts b/src/main/mcpServers/filesystem/tools/grep.ts
new file mode 100644
index 0000000000..d822db9d88
--- /dev/null
+++ b/src/main/mcpServers/filesystem/tools/grep.ts
@@ -0,0 +1,266 @@
+import fs from 'fs/promises'
+import path from 'path'
+import * as z from 'zod'
+
+import type { GrepMatch } from '../types'
+import { isBinaryFile, MAX_GREP_MATCHES, MAX_LINE_LENGTH, runRipgrep, validatePath } from '../types'
+
+// Schema definition
+export const GrepToolSchema = z.object({
+  pattern: z.string().describe('The regex pattern to search for in file contents'),
+  path: z
+    .string()
+    .optional()
+    .describe('The directory to search in (must be absolute path). Defaults to the base directory'),
+  include: z.string().optional().describe('File pattern to include in the search (e.g. "*.js", "*.{ts,tsx}")')
+})
+
+// Tool definition with detailed description
+export const grepToolDefinition = {
+  name: 'grep',
+  description: `Fast content search tool that works with any codebase size.
+
+- Searches file contents using regular expressions
+- Supports full regex syntax (e.g., "log.*Error", "function\\s+\\w+")
+- Filter files by pattern with include (e.g., "*.js", "*.{ts,tsx}")
+- Returns absolute file paths and line numbers with matching content
+- Results are limited to 100 matches
+- Binary files are automatically skipped
+- Common directories (node_modules, .git, dist) are excluded
+- The path parameter must be an absolute path if specified
+- If path is not specified, defaults to the base directory`,
+  inputSchema: z.toJSONSchema(GrepToolSchema)
+}
+
+// Handler implementation
+export async function handleGrepTool(args: unknown, baseDir: string) {
+  const parsed = GrepToolSchema.safeParse(args)
+  if (!parsed.success) {
+    throw new Error(`Invalid arguments for grep: ${parsed.error}`)
+  }
+
+  const data = parsed.data
+
+  if (!data.pattern) {
+    throw new Error('Pattern is required for grep')
+  }
+
+  const searchPath = data.path || baseDir
+  const validPath = await validatePath(searchPath, baseDir)
+
+  const matches: GrepMatch[] = []
+  let truncated = false
+  let regex: RegExp
+
+  // Build ripgrep arguments
+  const rgArgs: string[] = [
+    '--no-heading',
+    '--line-number',
+    '--color',
+    'never',
+    '--ignore-case',
+    '--glob',
+    '!.git/**',
+    '--glob',
+    '!node_modules/**',
+    '--glob',
+    '!dist/**',
+    '--glob',
+    '!build/**',
+    '--glob',
+    '!__pycache__/**'
+  ]
+
+  if (data.include) {
+    for (const pat of data.include
+      .split(',')
+      .map((p) => p.trim())
+      .filter(Boolean)) {
+      rgArgs.push('--glob', pat)
+    }
+  }
+
+  rgArgs.push(data.pattern)
+  rgArgs.push(validPath)
+
+  try {
+    // No global flag: a stateful 'g' regex would carry lastIndex across the
+    // per-line test() calls below and silently skip matches
+    regex = new RegExp(data.pattern, 'i')
+  } catch (error) {
+    throw new Error(`Invalid regex pattern: ${data.pattern}`)
+  }
+
+  async function searchFile(filePath: string): Promise<void> {
+    if (matches.length >= MAX_GREP_MATCHES) {
+      truncated = true
+      return
+    }
+
+    try {
+      // Skip binary files
+      if (await isBinaryFile(filePath)) {
+        return
+      }
+
+      const content = await fs.readFile(filePath, 'utf-8')
+      const lines = content.split('\n')
+
+      lines.forEach((line, index) => {
+        if (matches.length >= MAX_GREP_MATCHES) {
+          truncated = true
+          return
+        }
+
+        if (regex.test(line)) {
+          // Truncate long lines
+          const truncatedLine = line.length > MAX_LINE_LENGTH ? line.substring(0, MAX_LINE_LENGTH) + '...' : line
+
+          matches.push({
+            file: filePath,
+            line: index + 1,
+            content: truncatedLine.trim()
+          })
+        }
+      })
+    } catch (error) {
+      // Skip files we can't read
+    }
+  }
+
+  async function searchDirectory(dir: string): Promise<void> {
+    if (matches.length >= MAX_GREP_MATCHES) {
+      truncated = true
+      return
+    }
+
+    try {
+      const entries = await fs.readdir(dir, { withFileTypes: true })
+
+      for (const entry of entries) {
+        if (matches.length >= MAX_GREP_MATCHES) {
+          truncated = true
+          break
+        }
+
+        const fullPath = path.join(dir, entry.name)
+
+        // Skip common ignore patterns
+        if (entry.name.startsWith('.') && entry.name !== '.env.example') {
+          continue
+        }
+        if (['node_modules', 'dist', 'build', '__pycache__', '.git'].includes(entry.name)) {
+          continue
+        }
+
+        if (entry.isFile()) {
+          // Check if file matches include pattern
+          if (data.include) {
+            const includePatterns = data.include.split(',').map((p) => p.trim())
+            const fileName = path.basename(fullPath)
+            const matchesInclude = includePatterns.some((pattern) => {
+              // Simple glob pattern matching
+              const regexPattern = pattern
+                .replace(/\*/g, '.*')
+                .replace(/\?/g, '.')
+                .replace(/\{([^}]+)\}/g, (_, group) => `(${group.split(',').join('|')})`)
+              return new RegExp(`^${regexPattern}$`).test(fileName)
+            })
+            if (!matchesInclude) {
+              continue
+            }
+          }
+
+          await searchFile(fullPath)
+        } else if (entry.isDirectory()) {
+          await searchDirectory(fullPath)
+        }
+      }
+    } catch (error) {
+      // Skip directories we can't read
+    }
+  }
+
+  // Perform the search
+  let usedRipgrep = false
+  try {
+    const rgResult = await runRipgrep(rgArgs)
+    if (rgResult.ok && rgResult.exitCode !== null && rgResult.exitCode !== 2) {
+      usedRipgrep = true
+      const lines = rgResult.stdout.split('\n').filter(Boolean)
+      for (const line of lines) {
+        if (matches.length >= MAX_GREP_MATCHES) {
+          truncated = true
+          break
+        }
+
+        const firstColon = line.indexOf(':')
+        const secondColon = line.indexOf(':', firstColon + 1)
+        if (firstColon === -1 || secondColon === -1) continue
+
+        const filePart = line.slice(0, firstColon)
+        const linePart = line.slice(firstColon + 1, secondColon)
+        const contentPart = line.slice(secondColon + 1)
+        const lineNum = Number.parseInt(linePart, 10)
+        if (!Number.isFinite(lineNum)) continue
+
+        const absoluteFilePath = path.isAbsolute(filePart) ? filePart : path.resolve(baseDir, filePart)
+        const truncatedLine =
+          contentPart.length > MAX_LINE_LENGTH ? contentPart.substring(0, MAX_LINE_LENGTH) + '...' : contentPart
+
+        matches.push({
+          file: absoluteFilePath,
+          line: lineNum,
+          content: truncatedLine.trim()
+        })
+      }
+    }
+  } catch {
+    usedRipgrep = false
+  }
+
+  if (!usedRipgrep) {
+    const stats = await fs.stat(validPath)
+    if (stats.isFile()) {
+      await searchFile(validPath)
+    } else {
+      await searchDirectory(validPath)
+    }
+  }
+
+  // Format output
+  const output: string[] = []
+
+  if (matches.length === 0) {
+    output.push('No matches found')
+  } else {
+    // Group matches by file
+    const fileGroups = new Map<string, GrepMatch[]>()
+    matches.forEach((match) => {
+      if (!fileGroups.has(match.file)) {
+        fileGroups.set(match.file, [])
+      }
+      fileGroups.get(match.file)!.push(match)
+    })
+
+    // Format grouped matches - always use absolute paths
+    fileGroups.forEach((fileMatches, filePath) => {
+      output.push(`\n${filePath}:`)
+      fileMatches.forEach((match) => {
+        output.push(`  ${match.line}: ${match.content}`)
+      })
+    })
+
+    if (truncated) {
+      output.push('')
+      output.push(`(Results truncated to ${MAX_GREP_MATCHES} matches. 
Consider using a more specific pattern or path.)`) + } + } + + return { + content: [ + { + type: 'text', + text: output.join('\n') + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/index.ts b/src/main/mcpServers/filesystem/tools/index.ts new file mode 100644 index 0000000000..2e02d613c4 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/index.ts @@ -0,0 +1,8 @@ +// Export all tool definitions and handlers +export { deleteToolDefinition, handleDeleteTool } from './delete' +export { editToolDefinition, handleEditTool } from './edit' +export { globToolDefinition, handleGlobTool } from './glob' +export { grepToolDefinition, handleGrepTool } from './grep' +export { handleLsTool, lsToolDefinition } from './ls' +export { handleReadTool, readToolDefinition } from './read' +export { handleWriteTool, writeToolDefinition } from './write' diff --git a/src/main/mcpServers/filesystem/tools/ls.ts b/src/main/mcpServers/filesystem/tools/ls.ts new file mode 100644 index 0000000000..22672c9fb9 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/ls.ts @@ -0,0 +1,150 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { MAX_FILES_LIMIT, validatePath } from '../types' + +// Schema definition +export const LsToolSchema = z.object({ + path: z.string().optional().describe('The directory to list (must be absolute path). Defaults to the base directory'), + recursive: z.boolean().optional().describe('Whether to list directories recursively (default: false)') +}) + +// Tool definition with detailed description +export const lsToolDefinition = { + name: 'ls', + description: `Lists files and directories in a specified path. + +- Returns a tree-like structure with icons (📁 directories, 📄 files) +- Shows the absolute directory path in the header +- Entries are sorted alphabetically with directories first +- Can list recursively with recursive=true (up to 5 levels deep) +- Common directories (node_modules, dist, .git) are excluded +- Hidden files (starting with .) 
are excluded except .env.example
+- Results are limited to 100 entries
+- The path parameter must be an absolute path if specified
+- If path is not specified, defaults to the base directory`,
+  inputSchema: z.toJSONSchema(LsToolSchema)
+}
+
+// Handler implementation
+export async function handleLsTool(args: unknown, baseDir: string) {
+  const parsed = LsToolSchema.safeParse(args)
+  if (!parsed.success) {
+    throw new Error(`Invalid arguments for ls: ${parsed.error}`)
+  }
+
+  const targetPath = parsed.data.path || baseDir
+  const validPath = await validatePath(targetPath, baseDir)
+  const recursive = parsed.data.recursive || false
+
+  interface TreeNode {
+    name: string
+    type: 'file' | 'directory'
+    children?: TreeNode[]
+  }
+
+  let fileCount = 0
+  let truncated = false
+
+  async function buildTree(dirPath: string, depth: number = 0): Promise<TreeNode[]> {
+    if (fileCount >= MAX_FILES_LIMIT) {
+      truncated = true
+      return []
+    }
+
+    try {
+      const entries = await fs.readdir(dirPath, { withFileTypes: true })
+      const nodes: TreeNode[] = []
+
+      // Sort entries: directories first, then files, alphabetically
+      entries.sort((a, b) => {
+        if (a.isDirectory() && !b.isDirectory()) return -1
+        if (!a.isDirectory() && b.isDirectory()) return 1
+        return a.name.localeCompare(b.name)
+      })
+
+      for (const entry of entries) {
+        if (fileCount >= MAX_FILES_LIMIT) {
+          truncated = true
+          break
+        }
+
+        // Skip hidden files and common ignore patterns
+        if (entry.name.startsWith('.') && entry.name !== '.env.example') {
+          continue
+        }
+        if (['node_modules', 'dist', 'build', '__pycache__'].includes(entry.name)) {
+          continue
+        }
+
+        fileCount++
+        const node: TreeNode = {
+          name: entry.name,
+          type: entry.isDirectory() ? 'directory' : 'file'
+        }
+
+        if (entry.isDirectory() && recursive && depth < 5) {
+          // Limit depth to prevent infinite recursion
+          const childPath = path.join(dirPath, entry.name)
+          node.children = await buildTree(childPath, depth + 1)
+        }
+
+        nodes.push(node)
+      }
+
+      return nodes
+    } catch (error) {
+      return []
+    }
+  }
+
+  // Build the tree
+  const tree = await buildTree(validPath)
+
+  // Format as text output
+  function formatTree(nodes: TreeNode[], prefix: string = ''): string[] {
+    const lines: string[] = []
+
+    nodes.forEach((node, index) => {
+      const isLastNode = index === nodes.length - 1
+      const connector = isLastNode ? '└── ' : '├── '
+      const icon = node.type === 'directory' ? '📁 ' : '📄 '
+
+      lines.push(prefix + connector + icon + node.name)
+
+      if (node.children && node.children.length > 0) {
+        const childPrefix = prefix + (isLastNode ? '    ' : '│   ')
+        lines.push(...formatTree(node.children, childPrefix))
+      }
+    })
+
+    return lines
+  }
+
+  // Generate output
+  const output: string[] = []
+  output.push(`Directory: ${validPath}`)
+  output.push('')
+
+  if (tree.length === 0) {
+    output.push('(empty directory)')
+  } else {
+    const treeLines = formatTree(tree, '')
+    output.push(...treeLines)
+
+    if (truncated) {
+      output.push('')
+      output.push(`(Results truncated to ${MAX_FILES_LIMIT} files. 
Consider listing a more specific directory.)`) + } + } + + return { + content: [ + { + type: 'text', + text: output.join('\n') + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/read.ts b/src/main/mcpServers/filesystem/tools/read.ts new file mode 100644 index 0000000000..460c88dda4 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/read.ts @@ -0,0 +1,101 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { DEFAULT_READ_LIMIT, isBinaryFile, MAX_LINE_LENGTH, validatePath } from '../types' + +// Schema definition +export const ReadToolSchema = z.object({ + file_path: z.string().describe('The path to the file to read'), + offset: z.number().optional().describe('The line number to start reading from (1-based)'), + limit: z.number().optional().describe('The number of lines to read (defaults to 2000)') +}) + +// Tool definition with detailed description +export const readToolDefinition = { + name: 'read', + description: `Reads a file from the local filesystem. + +- Assumes this tool can read all files on the machine +- The file_path parameter must be an absolute path, not a relative path +- By default, reads up to 2000 lines starting from the beginning +- You can optionally specify a line offset and limit for long files +- Any lines longer than 2000 characters will be truncated +- Results are returned with line numbers starting at 1 +- Binary files are detected and rejected with an error +- Empty files return a warning`, + inputSchema: z.toJSONSchema(ReadToolSchema) +} + +// Handler implementation +export async function handleReadTool(args: unknown, baseDir: string) { + const parsed = ReadToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for read: ${parsed.error}`) + } + + const filePath = parsed.data.file_path + const validPath = await validatePath(filePath, baseDir) + + // Check if file exists + try { + const stats = await fs.stat(validPath) + if (!stats.isFile()) { + throw new Error(`Path is not a file: ${filePath}`) + } + } catch (error: any) { + if (error.code === 'ENOENT') { + throw new Error(`File not found: ${filePath}`) + } + throw error + } + + // Check if file is binary + if (await isBinaryFile(validPath)) { + throw new Error(`Cannot read binary file: ${filePath}`) + } + + // Read file content + const content = await fs.readFile(validPath, 'utf-8') + const lines = content.split('\n') + + // Apply offset and limit + const offset = (parsed.data.offset || 1) - 1 // Convert to 0-based + const limit = parsed.data.limit || DEFAULT_READ_LIMIT + + if (offset < 0 || offset >= lines.length) { + throw new Error(`Invalid offset: ${offset + 1}. File has ${lines.length} lines.`) + } + + const selectedLines = lines.slice(offset, offset + limit) + + // Format output with line numbers and truncate long lines + const output: string[] = [] + const relativePath = path.relative(baseDir, validPath) + + output.push(`File: ${relativePath}`) + if (offset > 0 || limit < lines.length) { + output.push(`Lines ${offset + 1} to ${Math.min(offset + limit, lines.length)} of ${lines.length}`) + } + output.push('') + + selectedLines.forEach((line, index) => { + const lineNumber = offset + index + 1 + const truncatedLine = line.length > MAX_LINE_LENGTH ? line.substring(0, MAX_LINE_LENGTH) + '...' 
: line + output.push(`${lineNumber.toString().padStart(6)}\t${truncatedLine}`) + }) + + if (offset + limit < lines.length) { + output.push('') + output.push(`(${lines.length - (offset + limit)} more lines not shown)`) + } + + return { + content: [ + { + type: 'text', + text: output.join('\n') + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/write.ts b/src/main/mcpServers/filesystem/tools/write.ts new file mode 100644 index 0000000000..2898f2f874 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/write.ts @@ -0,0 +1,83 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { logger, validatePath } from '../types' + +// Schema definition +export const WriteToolSchema = z.object({ + file_path: z.string().describe('The path to the file to write'), + content: z.string().describe('The content to write to the file') +}) + +// Tool definition with detailed description +export const writeToolDefinition = { + name: 'write', + description: `Writes a file to the local filesystem. + +- This tool will overwrite the existing file if one exists at the path +- You MUST use the read tool first to understand what you're overwriting +- ALWAYS prefer using the 'edit' tool for existing files +- NEVER proactively create documentation files unless explicitly requested +- Parent directories will be created automatically if they don't exist +- The file_path must be an absolute path, not a relative path`, + inputSchema: z.toJSONSchema(WriteToolSchema) +} + +// Handler implementation +export async function handleWriteTool(args: unknown, baseDir: string) { + const parsed = WriteToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for write: ${parsed.error}`) + } + + const filePath = parsed.data.file_path + const validPath = await validatePath(filePath, baseDir) + + // Create parent directory if it doesn't exist + const parentDir = path.dirname(validPath) + try { + await fs.mkdir(parentDir, { recursive: true }) + } catch (error: any) { + if (error.code !== 'EEXIST') { + throw new Error(`Failed to create parent directory: ${error.message}`) + } + } + + // Check if file exists (for logging) + let isOverwrite = false + try { + await fs.stat(validPath) + isOverwrite = true + } catch { + // File doesn't exist, that's fine + } + + // Write the file + try { + await fs.writeFile(validPath, parsed.data.content, 'utf-8') + } catch (error: any) { + throw new Error(`Failed to write file: ${error.message}`) + } + + // Log the operation + logger.info('File written', { + path: validPath, + overwrite: isOverwrite, + size: parsed.data.content.length + }) + + // Format output + const relativePath = path.relative(baseDir, validPath) + const action = isOverwrite ? 
'Updated' : 'Created'
+  const lines = parsed.data.content.split('\n').length
+
+  return {
+    content: [
+      {
+        type: 'text',
+        text: `${action} file: ${relativePath}\n` + `Size: ${parsed.data.content.length} bytes\n` + `Lines: ${lines}`
+      }
+    ]
+  }
+}
diff --git a/src/main/mcpServers/filesystem/types.ts b/src/main/mcpServers/filesystem/types.ts
new file mode 100644
index 0000000000..922fe0b23a
--- /dev/null
+++ b/src/main/mcpServers/filesystem/types.ts
@@ -0,0 +1,627 @@
+import { loggerService } from '@logger'
+import { isMac, isWin } from '@main/constant'
+import { spawn } from 'child_process'
+import fs from 'fs/promises'
+import os from 'os'
+import path from 'path'
+
+export const logger = loggerService.withContext('MCP:FileSystemServer')
+
+// Constants
+export const MAX_LINE_LENGTH = 2000
+export const DEFAULT_READ_LIMIT = 2000
+export const MAX_FILES_LIMIT = 100
+export const MAX_GREP_MATCHES = 100
+
+// Common types
+export interface FileInfo {
+  path: string
+  type: 'file' | 'directory'
+  size?: number
+  modified?: Date
+}
+
+export interface GrepMatch {
+  file: string
+  line: number
+  content: string
+}
+
+// Utility functions for path handling
+export function normalizePath(p: string): string {
+  return path.normalize(p)
+}
+
+export function expandHome(filepath: string): string {
+  if (filepath.startsWith('~/') || filepath === '~') {
+    return path.join(os.homedir(), filepath.slice(1))
+  }
+  return filepath
+}
+
+// Security validation
+export async function validatePath(requestedPath: string, baseDir?: string): Promise<string> {
+  const expandedPath = expandHome(requestedPath)
+  const root = baseDir ?? process.cwd()
+  const absolute = path.isAbsolute(expandedPath) ? path.resolve(expandedPath) : path.resolve(root, expandedPath)
+
+  // Handle symlinks by checking their real path
+  try {
+    const realPath = await fs.realpath(absolute)
+    return normalizePath(realPath)
+  } catch (error) {
+    // For new files that don't exist yet, verify parent directory
+    const parentDir = path.dirname(absolute)
+    try {
+      const realParentPath = await fs.realpath(parentDir)
+      normalizePath(realParentPath)
+      return normalizePath(absolute)
+    } catch {
+      return normalizePath(absolute)
+    }
+  }
+}
+
+// ============================================================================
+// Edit Tool Utilities - Fuzzy matching replacers from opencode
+// ============================================================================
+
+export type Replacer = (content: string, find: string) => Generator<string>
+
+// Similarity thresholds for block anchor fallback matching
+const SINGLE_CANDIDATE_SIMILARITY_THRESHOLD = 0.0
+const MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD = 0.3
+
+/**
+ * Levenshtein distance algorithm implementation
+ */
+function levenshtein(a: string, b: string): number {
+  if (a === '' || b === '') {
+    return Math.max(a.length, b.length)
+  }
+  const matrix = Array.from({ length: a.length + 1 }, (_, i) =>
+    Array.from({ length: b.length + 1 }, (_, j) => (i === 0 ? j : j === 0 ? i : 0))
+  )
+
+  for (let i = 1; i <= a.length; i++) {
+    for (let j = 1; j <= b.length; j++) {
+      const cost = a[i - 1] === b[j - 1] ? 
0 : 1 + matrix[i][j] = Math.min(matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + cost) + } + } + return matrix[a.length][b.length] +} + +export const SimpleReplacer: Replacer = function* (_content, find) { + yield find +} + +export const LineTrimmedReplacer: Replacer = function* (content, find) { + const originalLines = content.split('\n') + const searchLines = find.split('\n') + + if (searchLines[searchLines.length - 1] === '') { + searchLines.pop() + } + + for (let i = 0; i <= originalLines.length - searchLines.length; i++) { + let matches = true + + for (let j = 0; j < searchLines.length; j++) { + const originalTrimmed = originalLines[i + j].trim() + const searchTrimmed = searchLines[j].trim() + + if (originalTrimmed !== searchTrimmed) { + matches = false + break + } + } + + if (matches) { + let matchStartIndex = 0 + for (let k = 0; k < i; k++) { + matchStartIndex += originalLines[k].length + 1 + } + + let matchEndIndex = matchStartIndex + for (let k = 0; k < searchLines.length; k++) { + matchEndIndex += originalLines[i + k].length + if (k < searchLines.length - 1) { + matchEndIndex += 1 + } + } + + yield content.substring(matchStartIndex, matchEndIndex) + } + } +} + +export const BlockAnchorReplacer: Replacer = function* (content, find) { + const originalLines = content.split('\n') + const searchLines = find.split('\n') + + if (searchLines.length < 3) { + return + } + + if (searchLines[searchLines.length - 1] === '') { + searchLines.pop() + } + + const firstLineSearch = searchLines[0].trim() + const lastLineSearch = searchLines[searchLines.length - 1].trim() + const searchBlockSize = searchLines.length + + const candidates: Array<{ startLine: number; endLine: number }> = [] + for (let i = 0; i < originalLines.length; i++) { + if (originalLines[i].trim() !== firstLineSearch) { + continue + } + + for (let j = i + 2; j < originalLines.length; j++) { + if (originalLines[j].trim() === lastLineSearch) { + candidates.push({ startLine: i, endLine: j }) + break + } + } + } + + if (candidates.length === 0) { + return + } + + if (candidates.length === 1) { + const { startLine, endLine } = candidates[0] + const actualBlockSize = endLine - startLine + 1 + + let similarity = 0 + const linesToCheck = Math.min(searchBlockSize - 2, actualBlockSize - 2) + + if (linesToCheck > 0) { + for (let j = 1; j < searchBlockSize - 1 && j < actualBlockSize - 1; j++) { + const originalLine = originalLines[startLine + j].trim() + const searchLine = searchLines[j].trim() + const maxLen = Math.max(originalLine.length, searchLine.length) + if (maxLen === 0) { + continue + } + const distance = levenshtein(originalLine, searchLine) + similarity += (1 - distance / maxLen) / linesToCheck + + if (similarity >= SINGLE_CANDIDATE_SIMILARITY_THRESHOLD) { + break + } + } + } else { + similarity = 1.0 + } + + if (similarity >= SINGLE_CANDIDATE_SIMILARITY_THRESHOLD) { + let matchStartIndex = 0 + for (let k = 0; k < startLine; k++) { + matchStartIndex += originalLines[k].length + 1 + } + let matchEndIndex = matchStartIndex + for (let k = startLine; k <= endLine; k++) { + matchEndIndex += originalLines[k].length + if (k < endLine) { + matchEndIndex += 1 + } + } + yield content.substring(matchStartIndex, matchEndIndex) + } + return + } + + let bestMatch: { startLine: number; endLine: number } | null = null + let maxSimilarity = -1 + + for (const candidate of candidates) { + const { startLine, endLine } = candidate + const actualBlockSize = endLine - startLine + 1 + + let similarity = 0 + const linesToCheck = 
Math.min(searchBlockSize - 2, actualBlockSize - 2) + + if (linesToCheck > 0) { + for (let j = 1; j < searchBlockSize - 1 && j < actualBlockSize - 1; j++) { + const originalLine = originalLines[startLine + j].trim() + const searchLine = searchLines[j].trim() + const maxLen = Math.max(originalLine.length, searchLine.length) + if (maxLen === 0) { + continue + } + const distance = levenshtein(originalLine, searchLine) + similarity += 1 - distance / maxLen + } + similarity /= linesToCheck + } else { + similarity = 1.0 + } + + if (similarity > maxSimilarity) { + maxSimilarity = similarity + bestMatch = candidate + } + } + + if (maxSimilarity >= MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD && bestMatch) { + const { startLine, endLine } = bestMatch + let matchStartIndex = 0 + for (let k = 0; k < startLine; k++) { + matchStartIndex += originalLines[k].length + 1 + } + let matchEndIndex = matchStartIndex + for (let k = startLine; k <= endLine; k++) { + matchEndIndex += originalLines[k].length + if (k < endLine) { + matchEndIndex += 1 + } + } + yield content.substring(matchStartIndex, matchEndIndex) + } +} + +export const WhitespaceNormalizedReplacer: Replacer = function* (content, find) { + const normalizeWhitespace = (text: string) => text.replace(/\s+/g, ' ').trim() + const normalizedFind = normalizeWhitespace(find) + + const lines = content.split('\n') + for (let i = 0; i < lines.length; i++) { + const line = lines[i] + if (normalizeWhitespace(line) === normalizedFind) { + yield line + } else { + const normalizedLine = normalizeWhitespace(line) + if (normalizedLine.includes(normalizedFind)) { + const words = find.trim().split(/\s+/) + if (words.length > 0) { + const pattern = words.map((word) => word.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')).join('\\s+') + try { + const regex = new RegExp(pattern) + const match = line.match(regex) + if (match) { + yield match[0] + } + } catch { + // Invalid regex pattern, skip + } + } + } + } + } + + const findLines = find.split('\n') + if (findLines.length > 1) { + for (let i = 0; i <= lines.length - findLines.length; i++) { + const block = lines.slice(i, i + findLines.length) + if (normalizeWhitespace(block.join('\n')) === normalizedFind) { + yield block.join('\n') + } + } + } +} + +export const IndentationFlexibleReplacer: Replacer = function* (content, find) { + const removeIndentation = (text: string) => { + const lines = text.split('\n') + const nonEmptyLines = lines.filter((line) => line.trim().length > 0) + if (nonEmptyLines.length === 0) return text + + const minIndent = Math.min( + ...nonEmptyLines.map((line) => { + const match = line.match(/^(\s*)/) + return match ? match[1].length : 0 + }) + ) + + return lines.map((line) => (line.trim().length === 0 ? 
line : line.slice(minIndent))).join('\n') + } + + const normalizedFind = removeIndentation(find) + const contentLines = content.split('\n') + const findLines = find.split('\n') + + for (let i = 0; i <= contentLines.length - findLines.length; i++) { + const block = contentLines.slice(i, i + findLines.length).join('\n') + if (removeIndentation(block) === normalizedFind) { + yield block + } + } +} + +export const EscapeNormalizedReplacer: Replacer = function* (content, find) { + const unescapeString = (str: string): string => { + return str.replace(/\\(n|t|r|'|"|`|\\|\n|\$)/g, (match, capturedChar) => { + switch (capturedChar) { + case 'n': + return '\n' + case 't': + return '\t' + case 'r': + return '\r' + case "'": + return "'" + case '"': + return '"' + case '`': + return '`' + case '\\': + return '\\' + case '\n': + return '\n' + case '$': + return '$' + default: + return match + } + }) + } + + const unescapedFind = unescapeString(find) + + if (content.includes(unescapedFind)) { + yield unescapedFind + } + + const lines = content.split('\n') + const findLines = unescapedFind.split('\n') + + for (let i = 0; i <= lines.length - findLines.length; i++) { + const block = lines.slice(i, i + findLines.length).join('\n') + const unescapedBlock = unescapeString(block) + + if (unescapedBlock === unescapedFind) { + yield block + } + } +} + +export const TrimmedBoundaryReplacer: Replacer = function* (content, find) { + const trimmedFind = find.trim() + + if (trimmedFind === find) { + return + } + + if (content.includes(trimmedFind)) { + yield trimmedFind + } + + const lines = content.split('\n') + const findLines = find.split('\n') + + for (let i = 0; i <= lines.length - findLines.length; i++) { + const block = lines.slice(i, i + findLines.length).join('\n') + + if (block.trim() === trimmedFind) { + yield block + } + } +} + +export const ContextAwareReplacer: Replacer = function* (content, find) { + const findLines = find.split('\n') + if (findLines.length < 3) { + return + } + + if (findLines[findLines.length - 1] === '') { + findLines.pop() + } + + const contentLines = content.split('\n') + + const firstLine = findLines[0].trim() + const lastLine = findLines[findLines.length - 1].trim() + + for (let i = 0; i < contentLines.length; i++) { + if (contentLines[i].trim() !== firstLine) continue + + for (let j = i + 2; j < contentLines.length; j++) { + if (contentLines[j].trim() === lastLine) { + const blockLines = contentLines.slice(i, j + 1) + const block = blockLines.join('\n') + + if (blockLines.length === findLines.length) { + let matchingLines = 0 + let totalNonEmptyLines = 0 + + for (let k = 1; k < blockLines.length - 1; k++) { + const blockLine = blockLines[k].trim() + const findLine = findLines[k].trim() + + if (blockLine.length > 0 || findLine.length > 0) { + totalNonEmptyLines++ + if (blockLine === findLine) { + matchingLines++ + } + } + } + + if (totalNonEmptyLines === 0 || matchingLines / totalNonEmptyLines >= 0.5) { + yield block + break + } + } + break + } + } + } +} + +export const MultiOccurrenceReplacer: Replacer = function* (content, find) { + let startIndex = 0 + + while (true) { + const index = content.indexOf(find, startIndex) + if (index === -1) break + + yield find + startIndex = index + find.length + } +} + +/** + * All replacers in order of specificity + */ +export const ALL_REPLACERS: Replacer[] = [ + SimpleReplacer, + LineTrimmedReplacer, + BlockAnchorReplacer, + WhitespaceNormalizedReplacer, + IndentationFlexibleReplacer, + EscapeNormalizedReplacer, + TrimmedBoundaryReplacer, 
+  ContextAwareReplacer,
+  MultiOccurrenceReplacer
+]
+
+/**
+ * Replace oldString with newString in content using fuzzy matching
+ */
+export function replaceWithFuzzyMatch(
+  content: string,
+  oldString: string,
+  newString: string,
+  replaceAll = false
+): string {
+  if (oldString === newString) {
+    throw new Error('old_string and new_string must be different')
+  }
+
+  let notFound = true
+
+  for (const replacer of ALL_REPLACERS) {
+    for (const search of replacer(content, oldString)) {
+      const index = content.indexOf(search)
+      if (index === -1) continue
+      notFound = false
+      if (replaceAll) {
+        return content.replaceAll(search, newString)
+      }
+      // Candidate text occurs more than once; skip it and try a more specific one
+      const lastIndex = content.lastIndexOf(search)
+      if (index !== lastIndex) continue
+      return content.substring(0, index) + newString + content.substring(index + search.length)
+    }
+  }
+
+  if (notFound) {
+    throw new Error('old_string not found in content')
+  }
+  throw new Error(
+    'Found multiple matches for old_string. Provide more surrounding lines in old_string to identify the correct match.'
+  )
+}
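
Reviewer sketch (illustrative, not part of the patch; assumes the module layout above): how an edit whose old_string has drifted indentation still resolves. SimpleReplacer finds no exact occurrence, so LineTrimmedReplacer matches on trimmed line content; note that new_string is spliced in verbatim, so callers must supply the target indentation, as the edit tool's description instructs.

// Illustrative only - exercises replaceWithFuzzyMatch from a sibling module
import { replaceWithFuzzyMatch } from './types'

// The file on disk uses 4-space indentation
const content = ['function greet() {', '    return "hi"', '}'].join('\n')

const edited = replaceWithFuzzyMatch(
  content,
  '  return "hi"', // old_string with 2-space indent: no exact occurrence
  '    return "hello"' // new_string is inserted verbatim, so it carries the real indent
)

console.log(edited === ['function greet() {', '    return "hello"', '}'].join('\n')) // true
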
+
+// ============================================================================
+// Binary File Detection
+// ============================================================================
+
+// Check if a file is likely binary
+export async function isBinaryFile(filePath: string): Promise<boolean> {
+  try {
+    const buffer = Buffer.alloc(4096)
+    const fd = await fs.open(filePath, 'r')
+    const { bytesRead } = await fd.read(buffer, 0, buffer.length, 0)
+    await fd.close()
+
+    if (bytesRead === 0) return false
+
+    const view = buffer.subarray(0, bytesRead)
+
+    let zeroBytes = 0
+    let evenZeros = 0
+    let oddZeros = 0
+    let nonPrintable = 0
+
+    for (let i = 0; i < view.length; i++) {
+      const b = view[i]
+
+      if (b === 0) {
+        zeroBytes++
+        if (i % 2 === 0) evenZeros++
+        else oddZeros++
+        continue
+      }
+
+      // treat common whitespace as printable
+      if (b === 9 || b === 10 || b === 13) continue
+
+      // basic ASCII printable range
+      if (b >= 32 && b <= 126) continue
+
+      // bytes >= 128 are likely part of UTF-8 sequences; count as printable
+      if (b >= 128) continue
+
+      nonPrintable++
+    }
+
+    // If there are lots of null bytes, it's probably binary unless it looks like UTF-16 text.
+    if (zeroBytes > 0) {
+      const evenSlots = Math.ceil(view.length / 2)
+      const oddSlots = Math.floor(view.length / 2)
+      const evenZeroRatio = evenSlots > 0 ? evenZeros / evenSlots : 0
+      const oddZeroRatio = oddSlots > 0 ? oddZeros / oddSlots : 0
+
+      // UTF-16LE/BE tends to have zeros on every other byte.
+      if (evenZeroRatio > 0.7 || oddZeroRatio > 0.7) return false
+
+      if (zeroBytes / view.length > 0.05) return true
+    }
+
+    // Heuristic: too many non-printable bytes => binary.
+    return nonPrintable / view.length > 0.3
+  } catch {
+    return false
+  }
+}
+
+// ============================================================================
+// Ripgrep Utilities
+// ============================================================================
+
+export interface RipgrepResult {
+  ok: boolean
+  stdout: string
+  exitCode: number | null
+}
+
+export function getRipgrepAddonPath(): string {
+  const pkgJsonPath = require.resolve('@anthropic-ai/claude-agent-sdk/package.json')
+  const pkgRoot = path.dirname(pkgJsonPath)
+  const platform = isMac ? 'darwin' : isWin ? 'win32' : 'linux'
+  const arch = process.arch === 'arm64' ? 'arm64' : 'x64'
+  return path.join(pkgRoot, 'vendor', 'ripgrep', `${arch}-${platform}`, 'ripgrep.node')
+}
+
+export async function runRipgrep(args: string[]): Promise<RipgrepResult> {
+  const addonPath = getRipgrepAddonPath()
+  const childScript = `const { ripgrepMain } = require(process.env.RIPGREP_ADDON_PATH); process.exit(ripgrepMain(process.argv.slice(1)));`
+
+  return new Promise((resolve) => {
+    const child = spawn(process.execPath, ['--eval', childScript, 'rg', ...args], {
+      cwd: process.cwd(),
+      env: {
+        ...process.env,
+        ELECTRON_RUN_AS_NODE: '1',
+        RIPGREP_ADDON_PATH: addonPath
+      },
+      stdio: ['ignore', 'pipe', 'pipe']
+    })
+
+    let stdout = ''
+
+    child.stdout?.on('data', (chunk) => {
+      stdout += chunk.toString('utf-8')
+    })
+
+    child.on('error', () => {
+      resolve({ ok: false, stdout: '', exitCode: null })
+    })
+
+    child.on('close', (code) => {
+      resolve({ ok: true, stdout, exitCode: code })
+    })
+  })
+}
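
Reviewer sketch (illustrative, not part of the patch): the zero-byte heuristics above in action. UTF-16LE text is roughly half null bytes, but they alternate byte-for-byte, so it is still classified as text; a blob with scattered nulls and control bytes is not.

import fs from 'fs/promises'
import os from 'os'
import path from 'path'

import { isBinaryFile } from './types'

async function demo(): Promise<void> {
  const dir = await fs.mkdtemp(path.join(os.tmpdir(), 'bincheck-'))

  // 'hello world' in UTF-16LE: a zero byte at every odd offset
  const utf16 = path.join(dir, 'utf16.txt')
  await fs.writeFile(utf16, Buffer.from('hello world', 'utf16le'))
  console.log(await isBinaryFile(utf16)) // false - alternating zeros look like UTF-16

  // PNG-like header: null and control bytes without the UTF-16 pattern
  const blob = path.join(dir, 'blob.bin')
  await fs.writeFile(blob, Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x00, 0x01, 0x02, 0x03]))
  console.log(await isBinaryFile(blob)) // true - >5% null bytes, not UTF-16-shaped
}

demo()
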
diff --git a/src/main/services/BackupManager.ts b/src/main/services/BackupManager.ts
index f331254fdf..e08bbd4d7b 100644
--- a/src/main/services/BackupManager.ts
+++ b/src/main/services/BackupManager.ts
@@ -1,3 +1,19 @@
+/**
+ * @deprecated Scheduled for removal in v2.0.0
+ * --------------------------------------------------------------------------
+ * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
+ * --------------------------------------------------------------------------
+ * STOP: Feature PRs affecting this file are currently BLOCKED.
+ * Only critical bug fixes are accepted during this migration phase.
+ *
+ * This file is being refactored to v2 standards.
+ * Any non-critical changes will conflict with the ongoing work.
+ *
+ * 🔗 Context & Status:
+ * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
+ * - v2 Refactor PR   : https://github.com/CherryHQ/cherry-studio/pull/10162
+ * --------------------------------------------------------------------------
+ */
 import { loggerService } from '@logger'
 import { IpcChannel } from '@shared/IpcChannel'
 import type { WebDavConfig } from '@types'
@@ -767,6 +783,56 @@ class BackupManager {
     const s3Client = this.getS3Storage(s3Config)
     return await s3Client.checkConnection()
   }
+
+  /**
+   * Create a temporary backup for LAN transfer
+   * Creates a lightweight backup (skipBackupFile=true) in the temp directory
+   * Returns the path to the created ZIP file
+   */
+  async createLanTransferBackup(_: Electron.IpcMainInvokeEvent, data: string): Promise<string> {
+    const timestamp = new Date()
+      .toISOString()
+      .replace(/[-:T.Z]/g, '')
+      .slice(0, 12)
+    const fileName = `cherry-studio.${timestamp}.zip`
+    const tempPath = path.join(app.getPath('temp'), 'cherry-studio', 'lan-transfer')
+
+    // Ensure temp directory exists
+    await fs.ensureDir(tempPath)
+
+    // Create backup with skipBackupFile=true (no Data folder)
+    const backupedFilePath = await this.backup(_, fileName, data, tempPath, true)
+
+    logger.info(`[BackupManager] Created LAN transfer backup at: ${backupedFilePath}`)
+    return backupedFilePath
+  }
+
+  /**
+   * Delete a temporary backup file after LAN transfer completes
+   */
+  async deleteTempBackup(_: Electron.IpcMainInvokeEvent, filePath: string): Promise<boolean> {
+    try {
+      // Security check: only allow deletion within temp directory
+      const tempBase = path.normalize(path.join(app.getPath('temp'), 'cherry-studio', 'lan-transfer'))
+      const resolvedPath = path.normalize(path.resolve(filePath))
+
+      // Use normalized paths with trailing separator to prevent prefix attacks (e.g., /temp-evil)
+      if (!resolvedPath.startsWith(tempBase + path.sep) && resolvedPath !== tempBase) {
+        logger.warn(`[BackupManager] Attempted to delete file outside temp directory: ${filePath}`)
+        return false
+      }
+
+      if (await fs.pathExists(resolvedPath)) {
+        await fs.remove(resolvedPath)
+        logger.info(`[BackupManager] Deleted temp backup: ${resolvedPath}`)
+        return true
+      }
+      return false
+    } catch (error) {
+      logger.error('[BackupManager] Failed to delete temp backup:', error as Error)
+      return false
+    }
+  }
 }

 export default BackupManager
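
Reviewer sketch (illustrative, not part of the patch; POSIX separators for brevity): what the trailing-separator check above rejects that a bare startsWith would not.

import path from 'path'

const tempBase = path.normalize('/tmp/cherry-studio/lan-transfer')

// A sibling directory that shares tempBase as a raw string prefix
const evil = '/tmp/cherry-studio/lan-transfer-evil/backup.zip'
console.log(evil.startsWith(tempBase)) // true  - naive prefix check is fooled
console.log(evil.startsWith(tempBase + path.sep)) // false - sibling rejected

// A genuine child of the temp directory still passes
const ok = '/tmp/cherry-studio/lan-transfer/backup.zip'
console.log(ok.startsWith(tempBase + path.sep)) // true
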
diff --git a/src/main/services/ConfigManager.ts b/src/main/services/ConfigManager.ts
index 1767ce7b48..6ee96580e4 100644
--- a/src/main/services/ConfigManager.ts
+++ b/src/main/services/ConfigManager.ts
@@ -1,3 +1,19 @@
+/**
+ * @deprecated Scheduled for removal in v2.0.0
+ * --------------------------------------------------------------------------
+ * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
+ * --------------------------------------------------------------------------
+ * STOP: Feature PRs affecting this file are currently BLOCKED.
+ * Only critical bug fixes are accepted during this migration phase.
+ *
+ * This file is being refactored to v2 standards.
+ * Any non-critical changes will conflict with the ongoing work.
+ *
+ * 🔗 Context & Status:
+ * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
+ * - v2 Refactor PR   : https://github.com/CherryHQ/cherry-studio/pull/10162
+ * --------------------------------------------------------------------------
+ */
 import { ZOOM_SHORTCUTS } from '@shared/config/constant'
 import type { Shortcut } from '@types'
 import Store from 'electron-store'
@@ -27,7 +43,8 @@ export enum ConfigKeys {
   Proxy = 'proxy',
   EnableDeveloperMode = 'enableDeveloperMode',
   ClientId = 'clientId',
-  GitBashPath = 'gitBashPath'
+  GitBashPath = 'gitBashPath',
+  GitBashPathSource = 'gitBashPathSource' // 'manual' | 'auto' | null
 }

 export class ConfigManager {
diff --git a/src/main/services/FileStorage.ts b/src/main/services/FileStorage.ts
index 81f5c15bd9..2d7520ca67 100644
--- a/src/main/services/FileStorage.ts
+++ b/src/main/services/FileStorage.ts
@@ -2,7 +2,7 @@ import { loggerService } from '@logger'
 import {
   checkName,
   getFilesDir,
-  getFileType,
+  getFileType as getFileTypeByExt,
   getName,
   getNotesDir,
   getTempDir,
@@ -11,13 +11,13 @@ import {
 } from '@main/utils/file'
 import { documentExts, imageExts, KB, MB } from '@shared/config/constant'
 import type { FileMetadata, NotesTreeNode } from '@types'
+import { FileTypes } from '@types'
 import chardet from 'chardet'
 import type { FSWatcher } from 'chokidar'
 import chokidar from 'chokidar'
 import * as crypto from 'crypto'
 import type { OpenDialogOptions, OpenDialogReturnValue, SaveDialogOptions, SaveDialogReturnValue } from 'electron'
-import { app } from 'electron'
-import { dialog, net, shell } from 'electron'
+import { app, dialog, net, shell } from 'electron'
 import * as fs from 'fs'
 import { writeFileSync } from 'fs'
 import { readFile } from 'fs/promises'
@@ -130,16 +130,18 @@ interface DirectoryListOptions {
   includeDirectories?: boolean
   maxEntries?: number
   searchPattern?: string
+  fuzzy?: boolean
 }

 const DEFAULT_DIRECTORY_LIST_OPTIONS: Required<DirectoryListOptions> = {
   recursive: true,
-  maxDepth: 3,
+  maxDepth: 10,
   includeHidden: false,
   includeFiles: true,
   includeDirectories: true,
-  maxEntries: 10,
-  searchPattern: '.'
+  maxEntries: 20,
+  searchPattern: '.',
+  fuzzy: true
 }

 class FileStorage {
@@ -163,7 +165,7 @@ class FileStorage {
       fs.mkdirSync(this.storageDir, { recursive: true })
     }
     if (!fs.existsSync(this.notesDir)) {
-      fs.mkdirSync(this.storageDir, { recursive: true })
+      fs.mkdirSync(this.notesDir, { recursive: true })
     }
     if (!fs.existsSync(this.tempDir)) {
       fs.mkdirSync(this.tempDir, { recursive: true })
@@ -185,7 +187,7 @@ class FileStorage {
     })
   }

-  findDuplicateFile = async (filePath: string): Promise<FileMetadata | null> => {
+  private findDuplicateFile = async (filePath: string): Promise<FileMetadata | null> => {
     const stats = fs.statSync(filePath)
     logger.debug(`stats: ${stats}, filePath: ${filePath}`)
     const fileSize = stats.size
@@ -204,6 +206,8 @@
       if (originalHash === storedHash) {
         const ext = path.extname(file)
         const id = path.basename(file, ext)
+        const type = await this.getFileType(filePath)
+
         return {
           id,
           origin_name: file,
@@ -212,7 +216,7 @@
           created_at: storedStats.birthtime.toISOString(),
           size: storedStats.size,
           ext,
-          type: getFileType(ext),
+          type,
           count: 2
         }
       }
@@ -222,6 +226,13 @@
     return null
   }

+  public getFileType = async (filePath: string): Promise<FileTypes> => {
+    const ext = path.extname(filePath)
+    const fileType = getFileTypeByExt(ext)
+
+    return fileType === FileTypes.OTHER && (await this._isTextFile(filePath)) ? FileTypes.TEXT : fileType
+  }
+
   public selectFile = async (
     _: Electron.IpcMainInvokeEvent,
     options?: OpenDialogOptions
@@ -241,7 +252,7 @@
     const fileMetadataPromises = result.filePaths.map(async (filePath) => {
       const stats = fs.statSync(filePath)
       const ext = path.extname(filePath)
-      const fileType = getFileType(ext)
+      const fileType = await this.getFileType(filePath)

       return {
         id: uuidv4(),
@@ -307,7 +318,7 @@
     }

     const stats = await fs.promises.stat(destPath)
-    const fileType = getFileType(ext)
+    const fileType = await this.getFileType(destPath)

     const fileMetadata: FileMetadata = {
       id: uuid,
@@ -332,8 +343,7 @@
     }

     const stats = fs.statSync(filePath)
-    const ext = path.extname(filePath)
-    const fileType = getFileType(ext)
+    const fileType = await this.getFileType(filePath)

     return {
       id: uuidv4(),
@@ -342,7 +352,7 @@
       path: filePath,
       created_at: stats.birthtime.toISOString(),
       size: stats.size,
-      ext: ext,
+      ext: path.extname(filePath),
       type: fileType,
       count: 1
     }
@@ -690,7 +700,7 @@
       created_at: new Date().toISOString(),
       size: buffer.length,
       ext: ext.slice(1),
-      type: getFileType(ext),
+      type: getFileTypeByExt(ext),
       count: 1
     }
   } catch (error) {
@@ -740,7 +750,7 @@
       created_at: new Date().toISOString(),
       size: stats.size,
       ext: ext.slice(1),
-      type: getFileType(ext),
+      type: getFileTypeByExt(ext),
       count: 1
     }
   } catch (error) {
@@ -1038,10 +1048,226 @@ class FileStorage {
   }

   /**
-   * Search files by content pattern
+   * Fuzzy match: checks if all characters in query appear in text in order (case-insensitive)
+   * Example: "updater" matches "packages/update/src/node/updateController.ts"
    */
-  private async searchByContent(resolvedPath: string, options: Required<DirectoryListOptions>): Promise<string[]> {
-    const args: string[] = ['-l']
+  private isFuzzyMatch(text: string, query: string): boolean {
+    let i = 0 // text index
+    let j = 0 // query index
+    const textLower = text.toLowerCase()
+    const queryLower = query.toLowerCase()
+
+    while (i < textLower.length && j < queryLower.length) {
+      if (textLower[i] === queryLower[j]) {
+        j++
+      }
+      i++
+    }
+    return j === queryLower.length
+  }
+
+  /**
+   * 
Scoring constants for fuzzy match relevance ranking + * Higher values = higher priority in search results + */ + private static readonly SCORE_SEGMENT_MATCH = 60 // Per path segment that matches query + private static readonly SCORE_FILENAME_CONTAINS = 80 // Filename contains exact query substring + private static readonly SCORE_FILENAME_STARTS = 100 // Filename starts with query (highest priority) + private static readonly SCORE_CONSECUTIVE_CHAR = 15 // Per consecutive character match + private static readonly SCORE_WORD_BOUNDARY = 20 // Query matches start of a word + private static readonly PATH_LENGTH_PENALTY_FACTOR = 4 // Logarithmic penalty multiplier for longer paths + + /** + * Calculate fuzzy match score (higher is better) + * Scoring factors: + * - Consecutive character matches (bonus) + * - Match at word boundaries (bonus) + * - Shorter path length (bonus) + * - Match in filename vs directory (bonus) + */ + private getFuzzyMatchScore(filePath: string, query: string): number { + const pathLower = filePath.toLowerCase() + const queryLower = query.toLowerCase() + const fileName = filePath.split('/').pop() || '' + const fileNameLower = fileName.toLowerCase() + + let score = 0 + + // Count how many times query-related words appear in path segments + const pathSegments = pathLower.split(/[/\\]/) + let segmentMatchCount = 0 + for (const segment of pathSegments) { + if (this.isFuzzyMatch(segment, queryLower)) { + segmentMatchCount++ + } + } + score += segmentMatchCount * FileStorage.SCORE_SEGMENT_MATCH + + // Bonus for filename starting with query (stronger than generic "contains") + if (fileNameLower.startsWith(queryLower)) { + score += FileStorage.SCORE_FILENAME_STARTS + } else if (fileNameLower.includes(queryLower)) { + // Bonus for exact substring match in filename (e.g., "updater" in "RCUpdater.js") + score += FileStorage.SCORE_FILENAME_CONTAINS + } + + // Calculate consecutive match bonus + let i = 0 + let j = 0 + let consecutiveCount = 0 + let maxConsecutive = 0 + + while (i < pathLower.length && j < queryLower.length) { + if (pathLower[i] === queryLower[j]) { + consecutiveCount++ + maxConsecutive = Math.max(maxConsecutive, consecutiveCount) + j++ + } else { + consecutiveCount = 0 + } + i++ + } + score += maxConsecutive * FileStorage.SCORE_CONSECUTIVE_CHAR + + // Bonus for word boundary matches (e.g., "upd" matches start of "update") + // Only count once to avoid inflating scores for paths with repeated patterns + const boundaryPrefix = queryLower.slice(0, Math.min(3, queryLower.length)) + const words = pathLower.split(/[/\\._-]/) + for (const word of words) { + if (word.startsWith(boundaryPrefix)) { + score += FileStorage.SCORE_WORD_BOUNDARY + break + } + } + + // Penalty for longer paths (prefer shorter, more specific matches) + // Use logarithmic scaling to prevent long paths from dominating the score + // A 50-char path gets ~-16 penalty, 100-char gets ~-18, 200-char gets ~-21 + score -= Math.log(filePath.length + 1) * FileStorage.PATH_LENGTH_PENALTY_FACTOR + + return score + } + + /** + * Convert query to glob pattern for ripgrep pre-filtering + * e.g., "updater" -> "*u*p*d*a*t*e*r*" + */ + private queryToGlobPattern(query: string): string { + // Escape special glob characters (including ! 
for negation)
+    const escaped = query.replace(/[[\]{}()*+?.,\\^$|#!]/g, '\\$&')
+    // Convert to fuzzy glob: each char separated by *
+    return '*' + escaped.split('').join('*') + '*'
+  }
+
+  /**
+   * Greedy substring match: check if all characters in query can be matched
+   * by finding consecutive substrings in text (not necessarily single chars)
+   * e.g., "upcontroller" matches "updateController" by: "up" + "controller"
+   */
+  private isGreedySubstringMatch(text: string, query: string): boolean {
+    const textLower = text.toLowerCase()
+    const queryLower = query.toLowerCase()
+
+    let queryIndex = 0
+    let searchStart = 0
+
+    while (queryIndex < queryLower.length) {
+      // Try to find the longest matching substring starting at queryIndex
+      let bestMatchLen = 0
+      let bestMatchPos = -1
+
+      for (let len = queryLower.length - queryIndex; len >= 1; len--) {
+        const substr = queryLower.slice(queryIndex, queryIndex + len)
+        const foundAt = textLower.indexOf(substr, searchStart)
+        if (foundAt !== -1) {
+          bestMatchLen = len
+          bestMatchPos = foundAt
+          break // Found longest possible match
+        }
+      }
+
+      if (bestMatchLen === 0) {
+        // No substring match found, query cannot be matched
+        return false
+      }
+
+      queryIndex += bestMatchLen
+      searchStart = bestMatchPos + bestMatchLen
+    }
+
+    return true
+  }
+
+  /**
+   * Calculate greedy substring match score (higher is better)
+   * Rewards: fewer match fragments, shorter match span, matches in filename
+   */
+  private getGreedyMatchScore(filePath: string, query: string): number {
+    const textLower = filePath.toLowerCase()
+    const queryLower = query.toLowerCase()
+    const fileName = filePath.split('/').pop() || ''
+    const fileNameLower = fileName.toLowerCase()
+
+    let queryIndex = 0
+    let searchStart = 0
+    let fragmentCount = 0
+    let firstMatchPos = -1
+    let lastMatchEnd = 0
+
+    while (queryIndex < queryLower.length) {
+      let bestMatchLen = 0
+      let bestMatchPos = -1
+
+      for (let len = queryLower.length - queryIndex; len >= 1; len--) {
+        const substr = queryLower.slice(queryIndex, queryIndex + len)
+        const foundAt = textLower.indexOf(substr, searchStart)
+        if (foundAt !== -1) {
+          bestMatchLen = len
+          bestMatchPos = foundAt
+          break
+        }
+      }
+
+      if (bestMatchLen === 0) {
+        return -Infinity // No match
+      }
+
+      fragmentCount++
+      if (firstMatchPos === -1) firstMatchPos = bestMatchPos
+      lastMatchEnd = bestMatchPos + bestMatchLen
+      queryIndex += bestMatchLen
+      searchStart = lastMatchEnd
+    }
+
+    const matchSpan = lastMatchEnd - firstMatchPos
+    let score = 0
+
+    // Fewer fragments = better (single continuous match is best)
+    // Max bonus when fragmentCount=1, decreases as fragments increase
+    score += Math.max(0, 100 - (fragmentCount - 1) * 30)
+
+    // Shorter span relative to query length = better (tighter match)
+    // Perfect match: span equals query length
+    const spanRatio = queryLower.length / matchSpan
+    score += spanRatio * 50
+
+    // Bonus for match in filename
+    if (this.isGreedySubstringMatch(fileNameLower, queryLower)) {
+      score += 80
+    }
+
+    // Penalty for longer paths
+    score -= Math.log(filePath.length + 1) * 4
+
+    return score
+  }
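
Reviewer sketch (illustrative, not part of the patch): a standalone copy of the greedy matcher above that returns the fragments it consumes, for quick experimentation with the fallback behaviour. Note the fragments must occur in query order.

// Mirrors the private isGreedySubstringMatch helper; names are illustrative
function greedyMatch(text: string, query: string): string[] | null {
  const t = text.toLowerCase()
  const q = query.toLowerCase()
  const fragments: string[] = []
  let qi = 0
  let from = 0

  while (qi < q.length) {
    let found = false
    // Longest prefix of the remaining query that still occurs after `from`
    for (let len = q.length - qi; len >= 1; len--) {
      const sub = q.slice(qi, qi + len)
      const at = t.indexOf(sub, from)
      if (at !== -1) {
        fragments.push(sub)
        qi += len
        from = at + len
        found = true
        break
      }
    }
    if (!found) return null
  }
  return fragments
}

console.log(greedyMatch('updateController.ts', 'upcontroller')) // [ 'up', 'controller' ]
console.log(greedyMatch('updateController.ts', 'controllerup')) // null: fragments out of order
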
+
+  /**
+   * Build common ripgrep arguments for file listing
+   */
+  private buildRipgrepBaseArgs(options: Required<DirectoryListOptions>, resolvedPath: string): string[] {
+    const args: string[] = ['--files']

     // Handle hidden files
     if (!options.includeHidden) {
@@ -1068,82 +1294,74 @@
       args.push('--max-depth', options.maxDepth.toString())
     }

-    // Handle max count
-    if (options.maxEntries > 0) {
-      args.push('--max-count', options.maxEntries.toString())
-    }
-
-    // Add search pattern (search in content)
-    args.push(options.searchPattern)
-
-    // Add the directory path
     args.push(resolvedPath)

-    const { exitCode, output } = await executeRipgrep(args)
-
-    // Exit code 0 means files found, 1 means no files found (still success), 2+ means error
-    if (exitCode >= 2) {
-      throw new Error(`Ripgrep failed with exit code ${exitCode}: ${output}`)
-    }
-
-    // Parse ripgrep output (already sorted by relevance)
-    const results = output
-      .split('\n')
-      .filter((line) => line.trim())
-      .map((line) => line.replace(/\\/g, '/'))
-      .slice(0, options.maxEntries)
-
-    return results
+    return args
   }

   private async listDirectoryWithRipgrep(
     resolvedPath: string,
     options: Required<DirectoryListOptions>
   ): Promise<string[]> {
-    const maxEntries = options.maxEntries
+    // Fuzzy search mode: use ripgrep glob for pre-filtering, then score in JS
+    if (options.fuzzy && options.searchPattern && options.searchPattern !== '.') {
+      const args = this.buildRipgrepBaseArgs(options, resolvedPath)

-    // Step 1: Search by filename first
+      // Insert glob pattern before the path (last element)
+      const globPattern = this.queryToGlobPattern(options.searchPattern)
+      args.splice(args.length - 1, 0, '--iglob', globPattern)
+
+      const { exitCode, output } = await executeRipgrep(args)
+
+      if (exitCode >= 2) {
+        throw new Error(`Ripgrep failed with exit code ${exitCode}: ${output}`)
+      }
+
+      const filteredFiles = output
+        .split('\n')
+        .filter((line) => line.trim())
+        .map((line) => line.replace(/\\/g, '/'))
+
+      // If fuzzy glob found results, validate fuzzy match, sort and return
+      if (filteredFiles.length > 0) {
+        return filteredFiles
+          .filter((file) => this.isFuzzyMatch(file, options.searchPattern))
+          .map((file) => ({ file, score: this.getFuzzyMatchScore(file, options.searchPattern) }))
+          .sort((a, b) => b.score - a.score)
+          .slice(0, options.maxEntries)
+          .map((item) => item.file)
+      }
+
+      // Fallback: if no results, try greedy substring match on all files
+      logger.debug('Fuzzy glob returned no results, falling back to greedy substring match')
+      const fallbackArgs = this.buildRipgrepBaseArgs(options, resolvedPath)
+
+      const fallbackResult = await executeRipgrep(fallbackArgs)
+
+      if (fallbackResult.exitCode >= 2) {
+        return []
+      }
+
+      const allFiles = fallbackResult.output
+        .split('\n')
+        .filter((line) => line.trim())
+        .map((line) => line.replace(/\\/g, '/'))
+
+      const greedyMatched = allFiles.filter((file) => this.isGreedySubstringMatch(file, options.searchPattern))
+
+      return greedyMatched
+        .map((file) => ({ file, score: this.getGreedyMatchScore(file, options.searchPattern) }))
+        .sort((a, b) => b.score - a.score)
+        .slice(0, options.maxEntries)
+        .map((item) => item.file)
+    }
+
+    // Fallback: search by filename only (non-fuzzy mode)
     logger.debug('Searching by filename pattern', { pattern: options.searchPattern, path: resolvedPath })
     const filenameResults = await this.searchByFilename(resolvedPath, options)
     logger.debug('Found matches by filename', { count: filenameResults.length })
-
-    // If we have enough filename matches, return them
-    if (filenameResults.length >= maxEntries) {
-      return filenameResults.slice(0, maxEntries)
-    }
-
-    // Step 2: If filename matches are less than maxEntries, search by content to fill up
-    logger.debug('Filename matches insufficient, searching by content to fill up', {
-      filenameCount: filenameResults.length,
-      needed: maxEntries - filenameResults.length
-    })
-
-    // Adjust maxEntries for content search to get enough results
-    const contentOptions = {
contentOptions = { - ...options, - maxEntries: maxEntries - filenameResults.length + 20 // Request extra to account for duplicates - } - - const contentResults = await this.searchByContent(resolvedPath, contentOptions) - - logger.debug('Found matches by content', { count: contentResults.length }) - - // Combine results: filename matches first, then content matches (deduplicated) - const combined = [...filenameResults] - const filenameSet = new Set(filenameResults) - - for (const filePath of contentResults) { - if (!filenameSet.has(filePath)) { - combined.push(filePath) - if (combined.length >= maxEntries) { - break - } - } - } - - logger.debug('Combined results', { total: combined.length, filenameCount: filenameResults.length }) - return combined.slice(0, maxEntries) + return filenameResults.slice(0, options.maxEntries) } public validateNotesDirectory = async (_: Electron.IpcMainInvokeEvent, dirPath: string): Promise => { @@ -1317,7 +1535,7 @@ class FileStorage { await fs.promises.writeFile(destPath, buffer) const stats = await fs.promises.stat(destPath) - const fileType = getFileType(ext) + const fileType = await this.getFileType(destPath) return { id: uuid, @@ -1604,6 +1822,10 @@ class FileStorage { } public isTextFile = async (_: Electron.IpcMainInvokeEvent, filePath: string): Promise => { + return this._isTextFile(filePath) + } + + private _isTextFile = async (filePath: string): Promise => { try { const isBinary = await isBinaryFile(filePath) if (isBinary) { diff --git a/src/main/services/LocalTransferService.ts b/src/main/services/LocalTransferService.ts new file mode 100644 index 0000000000..bc2743757c --- /dev/null +++ b/src/main/services/LocalTransferService.ts @@ -0,0 +1,207 @@ +import { loggerService } from '@logger' +import type { LocalTransferPeer, LocalTransferState } from '@shared/config/types' +import { IpcChannel } from '@shared/IpcChannel' +import type { Browser, Service } from 'bonjour-service' +import Bonjour from 'bonjour-service' + +import { windowService } from './WindowService' + +const SERVICE_TYPE = 'cherrystudio' +const SERVICE_PROTOCOL = 'tcp' as const + +const logger = loggerService.withContext('LocalTransferService') + +type StartDiscoveryOptions = { + resetList?: boolean +} + +class LocalTransferService { + private static instance: LocalTransferService + private bonjour: Bonjour | null = null + private browser: Browser | null = null + private services = new Map() + private isScanning = false + private lastScanStartedAt?: number + private lastUpdatedAt = Date.now() + private lastError?: string + + private constructor() {} + + public static getInstance(): LocalTransferService { + if (!LocalTransferService.instance) { + LocalTransferService.instance = new LocalTransferService() + } + return LocalTransferService.instance + } + + public startDiscovery(options?: StartDiscoveryOptions): LocalTransferState { + if (options?.resetList) { + this.services.clear() + } + + this.isScanning = true + this.lastScanStartedAt = Date.now() + this.lastUpdatedAt = Date.now() + this.lastError = undefined + this.restartBrowser() + this.broadcastState() + return this.getState() + } + + public stopDiscovery(): LocalTransferState { + if (this.browser) { + try { + this.browser.stop() + } catch (error) { + logger.warn('Failed to stop local transfer browser', error as Error) + } + } + this.isScanning = false + this.lastUpdatedAt = Date.now() + this.broadcastState() + return this.getState() + } + + public getState(): LocalTransferState { + const services = 
Array.from(this.services.values()).sort((a, b) => a.name.localeCompare(b.name)) + return { + services, + isScanning: this.isScanning, + lastScanStartedAt: this.lastScanStartedAt, + lastUpdatedAt: this.lastUpdatedAt, + lastError: this.lastError + } + } + + public getPeerById(id: string): LocalTransferPeer | undefined { + return this.services.get(id) + } + + public dispose(): void { + this.stopDiscovery() + this.services.clear() + this.browser?.removeAllListeners() + this.browser = null + if (this.bonjour) { + try { + this.bonjour.destroy() + } catch (error) { + logger.warn('Failed to destroy Bonjour instance', error as Error) + } + this.bonjour = null + } + } + + private getBonjour(): Bonjour { + if (!this.bonjour) { + this.bonjour = new Bonjour() + } + return this.bonjour + } + + private restartBrowser(): void { + // Clean up existing browser + if (this.browser) { + this.browser.removeAllListeners() + try { + this.browser.stop() + } catch (error) { + logger.warn('Error while stopping Bonjour browser', error as Error) + } + this.browser = null + } + + // Destroy and recreate Bonjour instance to prevent socket leaks + if (this.bonjour) { + try { + this.bonjour.destroy() + } catch (error) { + logger.warn('Error while destroying Bonjour instance', error as Error) + } + this.bonjour = null + } + + const browser = this.getBonjour().find({ type: SERVICE_TYPE, protocol: SERVICE_PROTOCOL }) + this.browser = browser + this.bindBrowserEvents(browser) + + try { + browser.start() + logger.info('Local transfer discovery started') + } catch (error) { + const err = error instanceof Error ? error : new Error(String(error)) + this.lastError = err.message + logger.error('Failed to start local transfer discovery', err) + } + } + + private bindBrowserEvents(browser: Browser) { + browser.on('up', (service) => { + const peer = this.normalizeService(service) + logger.info(`LAN peer detected: ${peer.name} (${peer.addresses.join(', ')})`) + this.services.set(peer.id, peer) + this.lastUpdatedAt = Date.now() + this.broadcastState() + }) + + browser.on('down', (service) => { + const key = this.buildServiceKey(service.fqdn || service.name, service.host, service.port) + if (this.services.delete(key)) { + logger.info(`LAN peer removed: ${service.name}`) + this.lastUpdatedAt = Date.now() + this.broadcastState() + } + }) + + browser.on('error', (error) => { + const err = error instanceof Error ? error : new Error(String(error)) + logger.error('Local transfer discovery error', err) + this.lastError = err.message + this.broadcastState() + }) + } + + private normalizeService(service: Service): LocalTransferPeer { + const addressCandidates = [...(service.addresses || []), service.referer?.address].filter( + (value): value is string => typeof value === 'string' && value.length > 0 + ) + const addresses = Array.from(new Set(addressCandidates)) + const txtEntries = Object.entries(service.txt || {}) + const txt = + txtEntries.length > 0 + ? Object.fromEntries( + txtEntries.map(([key, value]) => [key, value === undefined || value === null ? 
'' : String(value)]) + ) + : undefined + + const peer: LocalTransferPeer = { + id: this.buildServiceKey(service.fqdn || service.name, service.host, service.port), + name: service.name, + host: service.host, + fqdn: service.fqdn, + port: service.port, + type: service.type, + protocol: service.protocol, + addresses, + txt, + updatedAt: Date.now() + } + + return peer + } + + private buildServiceKey(name?: string, host?: string, port?: number): string { + const raw = [name, host, port?.toString()].filter(Boolean).join('-') + return raw || `service-${Date.now()}` + } + + private broadcastState() { + const mainWindow = windowService.getMainWindow() + if (!mainWindow || mainWindow.isDestroyed()) { + return + } + mainWindow.webContents.send(IpcChannel.LocalTransfer_ServicesUpdated, this.getState()) + } +} + +export const localTransferService = LocalTransferService.getInstance() diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index ec636dd220..eec1f04b02 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -7,7 +7,7 @@ import { loggerService } from '@logger' import { createInMemoryMCPServer } from '@main/mcpServers/factory' import { makeSureDirExists, removeEnvProxy } from '@main/utils' import { buildFunctionCallToolName } from '@main/utils/mcp' -import { getBinaryName, getBinaryPath } from '@main/utils/process' +import { findCommandInShellEnv, getBinaryName, getBinaryPath, isBinaryExists } from '@main/utils/process' import getLoginShellEnvironment from '@main/utils/shell-env' import { TraceMethod, withSpanFunc } from '@mcp-trace/trace-core' import { Client } from '@modelcontextprotocol/sdk/client/index.js' @@ -249,6 +249,26 @@ class McpService { StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport > => { // Create appropriate transport based on configuration + + // Special case for nowledgeMem - uses HTTP transport instead of in-memory + if (isBuiltinMCPServer(server) && server.name === BuiltinMCPServerNames.nowledgeMem) { + const nowledgeMemUrl = 'http://127.0.0.1:14242/mcp' + const options: StreamableHTTPClientTransportOptions = { + fetch: async (url, init) => { + return net.fetch(typeof url === 'string' ? 
url : url.toString(), init) + }, + requestInit: { + headers: { + ...defaultAppHeaders(), + APP: 'Cherry Studio' + } + }, + authProvider + } + getServerLogger(server).debug(`Using StreamableHTTPClientTransport for ${server.name}`) + return new StreamableHTTPClientTransport(new URL(nowledgeMemUrl), options) + } + if (isBuiltinMCPServer(server) && server.name !== BuiltinMCPServerNames.mcpAutoInstall) { getServerLogger(server).debug(`Using in-memory transport`) const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() @@ -298,6 +318,10 @@ class McpService { } else if (server.command) { let cmd = server.command + // Get login shell environment first - needed for command detection and server execution + // Note: getLoginShellEnvironment() is memoized, so subsequent calls are fast + const loginShellEnv = await getLoginShellEnvironment() + // For DXT servers, use resolved configuration with platform overrides and variable substitution if (server.dxtPath) { const resolvedConfig = this.dxtService.getResolvedMcpConfig(server.dxtPath) @@ -319,18 +343,45 @@ class McpService { } if (server.command === 'npx') { - cmd = await getBinaryPath('bun') - getServerLogger(server).debug(`Using command`, { command: cmd }) + // First, check if npx is available in user's shell environment + const npxPath = await findCommandInShellEnv('npx', loginShellEnv) - // add -x to args if args exist - if (args && args.length > 0) { - if (!args.includes('-y')) { - args.unshift('-y') - } - if (!args.includes('x')) { - args.unshift('x') + if (npxPath) { + // Use system npx + cmd = npxPath + getServerLogger(server).debug(`Using system npx`, { command: cmd }) + } else { + // System npx not found, try bundled bun as fallback + getServerLogger(server).debug(`System npx not found, checking for bundled bun`) + + if (await isBinaryExists('bun')) { + // Fall back to bundled bun + cmd = await getBinaryPath('bun') + getServerLogger(server).info(`Using bundled bun as fallback (npx not found in PATH)`, { + command: cmd + }) + + // Transform args for bun x format + if (args && args.length > 0) { + if (!args.includes('-y')) { + args.unshift('-y') + } + if (!args.includes('x')) { + args.unshift('x') + } + } + } else { + // Neither npx nor bun available + throw new Error( + 'npx not found in PATH and bundled bun is not available. This may indicate an installation issue.\n' + + 'Please either:\n' + + '1. Install Node.js (which includes npx) from https://nodejs.org\n' + + '2. Run the MCP dependencies installer from Settings\n' + + '3. 
Restart the application if you recently installed Node.js' + ) } } + if (server.registryUrl) { server.env = { ...server.env, @@ -345,7 +396,35 @@ class McpService { } } } else if (server.command === 'uvx' || server.command === 'uv') { - cmd = await getBinaryPath(server.command) + // First, check if uvx/uv is available in user's shell environment + const uvPath = await findCommandInShellEnv(server.command, loginShellEnv) + + if (uvPath) { + // Use system uvx/uv + cmd = uvPath + getServerLogger(server).debug(`Using system ${server.command}`, { command: cmd }) + } else { + // System command not found, try bundled version as fallback + getServerLogger(server).debug(`System ${server.command} not found, checking for bundled version`) + + if (await isBinaryExists(server.command)) { + // Fall back to bundled version + cmd = await getBinaryPath(server.command) + getServerLogger(server).info(`Using bundled ${server.command} as fallback (not found in PATH)`, { + command: cmd + }) + } else { + // Neither system nor bundled available + throw new Error( + `${server.command} not found in PATH and bundled version is not available. This may indicate an installation issue.\n` + + 'Please either:\n' + + '1. Install uv from https://github.com/astral-sh/uv\n' + + '2. Run the MCP dependencies installer from Settings\n' + + `3. Restart the application if you recently installed ${server.command}` + ) + } + } + if (server.registryUrl) { server.env = { ...server.env, @@ -356,8 +435,6 @@ class McpService { } getServerLogger(server).debug(`Starting server`, { command: cmd, args }) - // Logger.info(`[MCP] Environment variables for server:`, server.env) - const loginShellEnv = await getLoginShellEnvironment() // Bun not support proxy https://github.com/oven-sh/bun/issues/16812 if (cmd.includes('bun')) { @@ -708,7 +785,7 @@ class McpService { ...tool, inputSchema: z.parse(MCPToolInputSchema, tool.inputSchema), outputSchema: tool.outputSchema ? 
z.parse(MCPToolOutputSchema, tool.outputSchema) : undefined, - id: buildFunctionCallToolName(server.name, tool.name, server.id), + id: buildFunctionCallToolName(server.name, tool.name), serverId: server.id, serverName: server.name, type: 'mcp' diff --git a/src/main/services/OvmsManager.ts b/src/main/services/OvmsManager.ts index 3a32d74ecf..67d6d9a9df 100644 --- a/src/main/services/OvmsManager.ts +++ b/src/main/services/OvmsManager.ts @@ -3,6 +3,8 @@ import { homedir } from 'node:os' import { promisify } from 'node:util' import { loggerService } from '@logger' +import { isWin } from '@main/constant' +import { getCpuName } from '@main/utils/system' import { HOME_CHERRY_DIR } from '@shared/config/constant' import * as fs from 'fs-extra' import * as path from 'path' @@ -11,6 +13,8 @@ const logger = loggerService.withContext('OvmsManager') const execAsync = promisify(exec) +export const isOvmsSupported = isWin && getCpuName().toLowerCase().includes('intel') + interface OvmsProcess { pid: number path: string @@ -29,6 +33,12 @@ interface OvmsConfig { class OvmsManager { private ovms: OvmsProcess | null = null + constructor() { + if (!isOvmsSupported) { + throw new Error('OVMS Manager is only supported on Windows platform with Intel CPU.') + } + } + /** * Recursively terminate a process and all its child processes * @param pid Process ID to terminate @@ -102,32 +112,10 @@ class OvmsManager { */ public async stopOvms(): Promise<{ success: boolean; message?: string }> { try { - // Check if OVMS process is running - const psCommand = `Get-Process -Name "ovms" -ErrorAction SilentlyContinue | Select-Object Id, Path | ConvertTo-Json` - const { stdout } = await execAsync(`powershell -Command "${psCommand}"`) - - if (!stdout.trim()) { - logger.info('OVMS process is not running') - return { success: true, message: 'OVMS process is not running' } - } - - const processes = JSON.parse(stdout) - const processList = Array.isArray(processes) ? processes : [processes] - - if (processList.length === 0) { - logger.info('OVMS process is not running') - return { success: true, message: 'OVMS process is not running' } - } - - // Terminate all OVMS processes using terminalProcess - for (const process of processList) { - const result = await this.terminalProcess(process.Id) - if (!result.success) { - logger.error(`Failed to terminate OVMS process with PID: ${process.Id}, ${result.message}`) - return { success: false, message: `Failed to terminate OVMS process: ${result.message}` } - } - logger.info(`Terminated OVMS process with PID: ${process.Id}`) - } + // close the OVMS process + await execAsync( + `powershell -Command "Get-WmiObject Win32_Process | Where-Object { $_.CommandLine -like 'ovms.exe*' } | ForEach-Object { Stop-Process -Id $_.ProcessId -Force }"` + ) // Reset the ovms instance this.ovms = null @@ -584,4 +572,5 @@ class OvmsManager { } } -export default OvmsManager +// Export singleton instance +export const ovmsManager = isOvmsSupported ? 
new OvmsManager() : undefined diff --git a/src/main/services/ReduxService.ts b/src/main/services/ReduxService.ts index 2df15bf3eb..8880691a24 100644 --- a/src/main/services/ReduxService.ts +++ b/src/main/services/ReduxService.ts @@ -1,7 +1,19 @@ /** - * @deprecated this file will be removed after v2 refactor + * @deprecated Scheduled for removal in v2.0.0 + * -------------------------------------------------------------------------- + * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex) + * -------------------------------------------------------------------------- + * STOP: Feature PRs affecting this file are currently BLOCKED. + * Only critical bug fixes are accepted during this migration phase. + * + * This file is being refactored to v2 standards. + * Any non-critical changes will conflict with the ongoing work. + * + * 🔗 Context & Status: + * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954 + * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162 + * -------------------------------------------------------------------------- */ - import { loggerService } from '@logger' import { IpcChannel } from '@shared/IpcChannel' import { ipcMain } from 'electron' diff --git a/src/main/services/SearchService.ts b/src/main/services/SearchService.ts index 8a4e42099a..6c69f80889 100644 --- a/src/main/services/SearchService.ts +++ b/src/main/services/SearchService.ts @@ -14,38 +14,36 @@ export class SearchService { return SearchService.instance } - constructor() { - // Initialize the service - } - - private async createNewSearchWindow(uid: string): Promise { + private async createNewSearchWindow(uid: string, show: boolean = false): Promise { const newWindow = new BrowserWindow({ - width: 800, - height: 600, - show: false, + width: 1280, + height: 768, + show, webPreferences: { nodeIntegration: true, contextIsolation: false, devTools: is.dev } }) - newWindow.webContents.session.webRequest.onBeforeSendHeaders({ urls: ['*://*/*'] }, (details, callback) => { - const headers = { - ...details.requestHeaders, - 'User-Agent': - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' - } - callback({ requestHeaders: headers }) - }) + this.searchWindows[uid] = newWindow - newWindow.on('closed', () => { - delete this.searchWindows[uid] - }) + newWindow.on('closed', () => delete this.searchWindows[uid]) + + newWindow.webContents.userAgent = + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36' + return newWindow } - public async openSearchWindow(uid: string): Promise { - await this.createNewSearchWindow(uid) + public async openSearchWindow(uid: string, show: boolean = false): Promise { + const existingWindow = this.searchWindows[uid] + + if (existingWindow) { + show && existingWindow.show() + return + } + + await this.createNewSearchWindow(uid, show) } public async closeSearchWindow(uid: string): Promise { diff --git a/src/main/services/SelectionService.ts b/src/main/services/SelectionService.ts index aeafce5be5..0adf50c6d1 100644 --- a/src/main/services/SelectionService.ts +++ b/src/main/services/SelectionService.ts @@ -1440,6 +1440,12 @@ export class SelectionService { } actionWindow.setBounds({ x, y, width, height }) + + // [Windows only] Update remembered window size for custom resize + // setBounds() may not trigger the 'resized' event, so we need to update manually + if (this.isRemeberWinSize) { + this.lastActionWindowSize = { width, height } + } } /** diff 
--git a/src/main/services/ShortcutService.ts b/src/main/services/ShortcutService.ts index 536c164a0d..af57a41b7b 100644 --- a/src/main/services/ShortcutService.ts +++ b/src/main/services/ShortcutService.ts @@ -1,3 +1,19 @@ +/** + * @deprecated Scheduled for removal in v2.0.0 + * -------------------------------------------------------------------------- + * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex) + * -------------------------------------------------------------------------- + * STOP: Feature PRs affecting this file are currently BLOCKED. + * Only critical bug fixes are accepted during this migration phase. + * + * This file is being refactored to v2 standards. + * Any non-critical changes will conflict with the ongoing work. + * + * 🔗 Context & Status: + * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954 + * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162 + * -------------------------------------------------------------------------- + */ import { preferenceService } from '@data/PreferenceService' import { loggerService } from '@logger' import { handleZoomFactor } from '@main/utils/zoom' diff --git a/src/main/services/StoreSyncService.ts b/src/main/services/StoreSyncService.ts index 57f07195b6..6013afdd57 100644 --- a/src/main/services/StoreSyncService.ts +++ b/src/main/services/StoreSyncService.ts @@ -1,3 +1,19 @@ +/** + * @deprecated Scheduled for removal in v2.0.0 + * -------------------------------------------------------------------------- + * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex) + * -------------------------------------------------------------------------- + * STOP: Feature PRs affecting this file are currently BLOCKED. + * Only critical bug fixes are accepted during this migration phase. + * + * This file is being refactored to v2 standards. + * Any non-critical changes will conflict with the ongoing work. 
+ *
+ * 🔗 Context & Status:
+ * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
+ * - v2 Refactor PR   : https://github.com/CherryHQ/cherry-studio/pull/10162
+ * --------------------------------------------------------------------------
+ */
 import { IpcChannel } from '@shared/IpcChannel'
 import type { StoreSyncAction } from '@types'
 import { BrowserWindow, ipcMain } from 'electron'
diff --git a/src/main/services/WebSocketService.ts b/src/main/services/WebSocketService.ts
deleted file mode 100644
index e52919e96a..0000000000
--- a/src/main/services/WebSocketService.ts
+++ /dev/null
@@ -1,359 +0,0 @@
-import { loggerService } from '@logger'
-import type { WebSocketCandidatesResponse, WebSocketStatusResponse } from '@shared/config/types'
-import * as fs from 'fs'
-import { networkInterfaces } from 'os'
-import * as path from 'path'
-import type { Socket } from 'socket.io'
-import { Server } from 'socket.io'
-
-import { windowService } from './WindowService'
-
-const logger = loggerService.withContext('WebSocketService')
-
-class WebSocketService {
-  private io: Server | null = null
-  private isStarted = false
-  private port = 7017
-  private connectedClients = new Set<string>()
-
-  private getLocalIpAddress(): string | undefined {
-    const interfaces = networkInterfaces()
-
-    // Network interface name patterns, sorted by priority
-    const interfacePriority = [
-      // macOS: Ethernet/Wi-Fi first
-      /^en[0-9]+$/, // en0, en1 (Ethernet/Wi-Fi)
-      /^(en|eth)[0-9]+$/, // Ethernet interfaces
-      /^wlan[0-9]+$/, // Wireless interfaces
-      // Windows: Ethernet/Wi-Fi first
-      /^(Ethernet|Wi-Fi|Local Area Connection)/,
-      /^(Wi-Fi|无线网络连接)/,
-      // Linux: Ethernet/Wi-Fi first
-      /^(eth|enp|wlp|wlan)[0-9]+/,
-      // Virtualization interfaces (low priority)
-      /^bridge[0-9]+$/, // Docker bridge
-      /^veth[0-9]+$/, // Docker veth
-      /^docker[0-9]+/, // Docker interfaces
-      /^br-[0-9a-f]+/, // Docker bridge
-      /^vmnet[0-9]+$/, // VMware
-      /^vboxnet[0-9]+$/, // VirtualBox
-      // VPN tunnel interfaces (low priority)
-      /^utun[0-9]+$/, // macOS VPN
-      /^tun[0-9]+$/, // Linux/Unix VPN
-      /^tap[0-9]+$/, // TAP interfaces
-      /^tailscale[0-9]*$/, // Tailscale VPN
-      /^wg[0-9]+$/ // WireGuard VPN
-    ]
-
-    const candidates: Array<{ interface: string; address: string; priority: number }> = []
-
-    for (const [name, ifaces] of Object.entries(interfaces)) {
-      for (const iface of ifaces || []) {
-        if (iface.family === 'IPv4' && !iface.internal) {
-          // Compute the interface priority
-          let priority = 999 // lowest priority by default
-          for (let i = 0; i < interfacePriority.length; i++) {
-            if (interfacePriority[i].test(name)) {
-              priority = i
-              break
-            }
-          }
-
-          candidates.push({
-            interface: name,
-            address: iface.address,
-            priority
-          })
-        }
-      }
-    }
-
-    if (candidates.length === 0) {
-      logger.warn('Unable to determine LAN IP, using default IP: 127.0.0.1')
-      return '127.0.0.1'
-    }
-
-    // Sort by priority and pick the highest-priority candidate
-    candidates.sort((a, b) => a.priority - b.priority)
-    const best = candidates[0]
-
-    logger.info(`Resolved LAN IP: ${best.address} (interface: ${best.interface})`)
-    return best.address
-  }
-
-  public start = async (): Promise<{ success: boolean; port?: number; error?: string }> => {
-    if (this.isStarted && this.io) {
-      return { success: true, port: this.port }
-    }
-
-    try {
-      this.io = new Server(this.port, {
-        cors: {
-          origin: '*',
-          methods: ['GET', 'POST']
-        },
-        transports: ['websocket', 'polling'],
-        allowEIO3: true,
-        pingTimeout: 60000,
-        pingInterval: 25000
-      })
-
-      this.io.on('connection', (socket: Socket) => {
-        this.connectedClients.add(socket.id)
-
-        const mainWindow = windowService.getMainWindow()
-        if (!mainWindow) {
-          logger.error('Main window is null, cannot send connection event')
-        } else {
-          mainWindow.webContents.send('websocket-client-connected', {
-            connected: true,
-            clientId: socket.id
-          })
-          logger.info(`Connection event sent to renderer, total clients: ${this.connectedClients.size}`)
-        }
-
-        socket.on('message', (data) => {
-          logger.info('Received message from mobile:', data)
-          mainWindow?.webContents.send('websocket-message-received', data)
-          socket.emit('message_received', { success: true })
-        })
-
-        socket.on('disconnect', () => {
-          logger.info(`Client disconnected: ${socket.id}`)
-          this.connectedClients.delete(socket.id)
-
-          if (this.connectedClients.size === 0) {
-            mainWindow?.webContents.send('websocket-client-connected', {
-              connected: false,
-              clientId: socket.id
-            })
-          }
-        })
-      })
-
-      // Engine-level event listeners
-      this.io.engine.on('connection_error', (err) => {
-        logger.error('Engine connection error:', err)
-      })
-
-      this.io.engine.on('connection', (rawSocket) => {
-        const remoteAddr = rawSocket.request.connection.remoteAddress
-        logger.info(`[Engine] Raw connection from: ${remoteAddr}`)
-        logger.info(`[Engine] Transport: ${rawSocket.transport.name}`)
-
-        rawSocket.on('packet', (packet: { type: string; data?: any }) => {
-          logger.info(
-            `[Engine] ← Packet from ${remoteAddr}: type="${packet.type}"`,
-            packet.data ? { data: packet.data } : {}
-          )
-        })
-
-        rawSocket.on('packetCreate', (packet: { type: string; data?: any }) => {
-          logger.info(`[Engine] → Packet to ${remoteAddr}: type="${packet.type}"`)
-        })
-
-        rawSocket.on('close', (reason: string) => {
-          logger.warn(`[Engine] Connection closed from ${remoteAddr}, reason: ${reason}`)
-        })
-
-        rawSocket.on('error', (error: Error) => {
-          logger.error(`[Engine] Connection error from ${remoteAddr}:`, error)
-        })
-      })
-
-      // Listen for Socket.IO handshake failures
-      this.io.on('connection_error', (err) => {
-        logger.error('[Socket.IO] Connection error during handshake:', err)
-      })
-
-      this.isStarted = true
-      logger.info(`WebSocket server started on port ${this.port}`)
-
-      return { success: true, port: this.port }
-    } catch (error) {
-      logger.error('Failed to start WebSocket server:', error as Error)
-      return {
-        success: false,
-        error: error instanceof Error ? error.message : 'Unknown error'
-      }
-    }
-  }
-
-  public stop = async (): Promise<{ success: boolean }> => {
-    if (!this.isStarted || !this.io) {
-      return { success: true }
-    }
-
-    try {
-      await new Promise<void>((resolve) => {
-        this.io!.close(() => {
-          resolve()
-        })
-      })
-
-      this.io = null
-      this.isStarted = false
-      this.connectedClients.clear()
-      logger.info('WebSocket server stopped')
-
-      return { success: true }
-    } catch (error) {
-      logger.error('Failed to stop WebSocket server:', error as Error)
-      return { success: false }
-    }
-  }
-
-  public getStatus = async (): Promise<WebSocketStatusResponse> => {
-    return {
-      isRunning: this.isStarted,
-      port: this.isStarted ? this.port : undefined,
-      ip: this.isStarted ? this.getLocalIpAddress() : undefined,
-      clientConnected: this.connectedClients.size > 0
-    }
-  }
-
-  public getAllCandidates = async (): Promise<WebSocketCandidatesResponse> => {
-    const interfaces = networkInterfaces()
-
-    // Network interface name patterns, sorted by priority
-    const interfacePriority = [
-      // macOS: Ethernet/Wi-Fi first
-      /^en[0-9]+$/, // en0, en1 (Ethernet/Wi-Fi)
-      /^(en|eth)[0-9]+$/, // Ethernet interfaces
-      /^wlan[0-9]+$/, // Wireless interfaces
-      // Windows: Ethernet/Wi-Fi first
-      /^(Ethernet|Wi-Fi|Local Area Connection)/,
-      /^(Wi-Fi|无线网络连接)/,
-      // Linux: Ethernet/Wi-Fi first
-      /^(eth|enp|wlp|wlan)[0-9]+/,
-      // Virtualization interfaces (low priority)
-      /^bridge[0-9]+$/, // Docker bridge
-      /^veth[0-9]+$/, // Docker veth
-      /^docker[0-9]+/, // Docker interfaces
-      /^br-[0-9a-f]+/, // Docker bridge
-      /^vmnet[0-9]+$/, // VMware
-      /^vboxnet[0-9]+$/, // VirtualBox
-      // VPN tunnel interfaces (low priority)
-      /^utun[0-9]+$/, // macOS VPN
-      /^tun[0-9]+$/, // Linux/Unix VPN
-      /^tap[0-9]+$/, // TAP interfaces
-      /^tailscale[0-9]*$/, // Tailscale VPN
-      /^wg[0-9]+$/ // WireGuard VPN
-    ]
-
-    const candidates: Array<{ host: string; interface: string; priority: number }> = []
-
-    for (const [name, ifaces] of Object.entries(interfaces)) {
-      for (const iface of ifaces || []) {
-        if (iface.family === 'IPv4' && !iface.internal) {
-          // Compute the interface priority
-          let priority = 999 // lowest priority by default
-          for (let i = 0; i < interfacePriority.length; i++) {
-            if (interfacePriority[i].test(name)) {
-              priority = i
-              break
-            }
-          }
-
-          candidates.push({
-            host: iface.address,
-            interface: name,
-            priority
-          })
-
-          logger.debug(`Found interface: ${name} -> ${iface.address} (priority: ${priority})`)
-        }
-      }
-    }
-
-    // Return candidates sorted by priority
-    candidates.sort((a, b) => a.priority - b.priority)
-    logger.info(
-      `Found ${candidates.length} IP candidates: ${candidates.map((c) => `${c.host}(${c.interface})`).join(', ')}`
-    )
-    return candidates
-  }
-
-  public sendFile = async (
-    _: Electron.IpcMainInvokeEvent,
-    filePath: string
-  ): Promise<{ success: boolean; error?: string }> => {
-    if (!this.isStarted || !this.io) {
-      const errorMsg = 'WebSocket server is not running.'
-      logger.error(errorMsg)
-      return { success: false, error: errorMsg }
-    }
-
-    if (this.connectedClients.size === 0) {
-      const errorMsg = 'No client connected.'
-      logger.error(errorMsg)
-      return { success: false, error: errorMsg }
-    }
-
-    const mainWindow = windowService.getMainWindow()
-
-    return new Promise((resolve, reject) => {
-      const stats = fs.statSync(filePath)
-      const totalSize = stats.size
-      const filename = path.basename(filePath)
-      const stream = fs.createReadStream(filePath)
-      let bytesSent = 0
-      const startTime = Date.now()
-
-      logger.info(`Starting file transfer: ${filename} (${this.formatFileSize(totalSize)})`)
-
-      // Signal transfer start to the client, including the filename and total size
-      this.io!.emit('zip-file-start', { filename, totalSize })
-
-      stream.on('data', (chunk) => {
-        bytesSent += chunk.length
-        const progress = (bytesSent / totalSize) * 100
-
-        // Send a file chunk to the client
-        this.io!.emit('zip-file-chunk', chunk)
-
-        // Send a progress update to the renderer process
-        mainWindow?.webContents.send('file-send-progress', { progress })
-
-        // Log progress every 10%
-        if (Math.floor(progress) % 10 === 0) {
-          const elapsed = (Date.now() - startTime) / 1000
-          const speed = elapsed > 0 ? bytesSent / elapsed : 0
-          logger.info(`Transfer progress: ${Math.floor(progress)}% (${this.formatFileSize(speed)}/s)`)
-        }
-      })
-
-      stream.on('end', () => {
-        const totalTime = (Date.now() - startTime) / 1000
-        const avgSpeed = totalTime > 0 ? totalSize / totalTime : 0
-        logger.info(
-          `File transfer completed: ${filename} in ${totalTime.toFixed(1)}s (${this.formatFileSize(avgSpeed)}/s)`
-        )
-
-        // Make sure a 100% progress event is sent
-        mainWindow?.webContents.send('file-send-progress', { progress: 100 })
-        // Signal transfer end to the client
-        this.io!.emit('zip-file-end')
-        resolve({ success: true })
-      })
-
-      stream.on('error', (error) => {
-        logger.error(`File transfer failed: ${filename}`, error)
-        reject({
-          success: false,
-          error: error instanceof Error ? error.message : 'Unknown error'
-        })
-      })
-    })
-  }
-
-  private formatFileSize(bytes: number): string {
-    if (bytes === 0) return '0 B'
-    const k = 1024
-    const sizes = ['B', 'KB', 'MB', 'GB']
-    const i = Math.floor(Math.log(bytes) / Math.log(k))
-    return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
-  }
-}
-
-export default new WebSocketService()
diff --git a/src/main/services/WindowService.ts b/src/main/services/WindowService.ts
index 2210582fee..0be118b0ad 100644
--- a/src/main/services/WindowService.ts
+++ b/src/main/services/WindowService.ts
@@ -255,6 +255,12 @@ export class WindowService {
   }

   private setupWebContentsHandlers(mainWindow: BrowserWindow) {
+    // Fix for Electron bug where zoom resets during in-page navigation (route changes)
+    // This complements the resize-based workaround by catching navigation events
+    mainWindow.webContents.on('did-navigate-in-page', () => {
+      mainWindow.webContents.setZoomFactor(preferenceService.get('app.zoom_factor'))
+    })
+
     mainWindow.webContents.on('will-navigate', (event, url) => {
       if (url.includes('localhost:517')) {
         return
@@ -516,7 +522,9 @@ export class WindowService {
     miniWindowState.manage(this.miniWindow)

     //miniWindow should show in current desktop
-    this.miniWindow?.setVisibleOnAllWorkspaces(true, { visibleOnFullScreen: true })
+    this.miniWindow?.setVisibleOnAllWorkspaces(true, {
+      visibleOnFullScreen: true
+    })

     //make miniWindow always on top of fullscreen apps with level set
     //[mac] level higher than 'floating' will cover the pinyin input method
     this.miniWindow.setAlwaysOnTop(true, 'floating')
@@ -635,6 +643,11 @@ export class WindowService {
       return
     } else if (isMac) {
       this.miniWindow.hide()
+      const majorVersion = parseInt(process.getSystemVersion().split('.')[0], 10)
+      if (majorVersion >= 26) {
+        // on macOS 26+, the popup of the miniWindow would not change the focus to the previous application.
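+        // process.getSystemVersion() returns the host OS version string (e.g. '26.0.1'), so the first
+        // dotted component is the macOS major version; since focus was never taken from the previous
+        // application, the app.hide() restore below is unnecessary here.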
+ return + } if (!this.wasMainWindowFocused) { app.hide() } diff --git a/src/main/services/__tests__/AppUpdater.test.ts b/src/main/services/__tests__/AppUpdater.test.ts index babc76ca81..7774738028 100644 --- a/src/main/services/__tests__/AppUpdater.test.ts +++ b/src/main/services/__tests__/AppUpdater.test.ts @@ -14,7 +14,7 @@ vi.mock('@logger', () => ({ // Mock PreferenceService using the existing mock vi.mock('@data/PreferenceService', async () => { - const { MockMainPreferenceServiceExport } = await import('../../../../tests/__mocks__/main/PreferenceService') + const { MockMainPreferenceServiceExport } = await import('@test-mocks/main/PreferenceService') return MockMainPreferenceServiceExport }) @@ -84,9 +84,9 @@ vi.mock('electron-updater', () => ({ // Import after mocks import { preferenceService } from '@data/PreferenceService' import { UpdateMirror } from '@shared/config/constant' +import { MockMainPreferenceServiceUtils } from '@test-mocks/main/PreferenceService' import { app, net } from 'electron' -import { MockMainPreferenceServiceUtils } from '../../../../tests/__mocks__/main/PreferenceService' import AppUpdater from '../AppUpdater' // Mock clientId for ConfigManager since it's not migrated yet diff --git a/src/main/services/__tests__/BackupManager.deleteTempBackup.test.ts b/src/main/services/__tests__/BackupManager.deleteTempBackup.test.ts new file mode 100644 index 0000000000..a371ed0a3c --- /dev/null +++ b/src/main/services/__tests__/BackupManager.deleteTempBackup.test.ts @@ -0,0 +1,279 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Use vi.hoisted to define mocks that are available during hoisting +const { mockLogger } = vi.hoisted(() => ({ + mockLogger: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn() + } +})) + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => mockLogger + } +})) + +vi.mock('electron', () => ({ + app: { + getPath: vi.fn((key: string) => { + if (key === 'temp') return '/tmp' + if (key === 'userData') return '/mock/userData' + return '/mock/unknown' + }) + } +})) + +vi.mock('fs-extra', () => ({ + default: { + pathExists: vi.fn(), + remove: vi.fn(), + ensureDir: vi.fn(), + copy: vi.fn(), + readdir: vi.fn(), + stat: vi.fn(), + readFile: vi.fn(), + writeFile: vi.fn(), + createWriteStream: vi.fn(), + createReadStream: vi.fn() + }, + pathExists: vi.fn(), + remove: vi.fn(), + ensureDir: vi.fn(), + copy: vi.fn(), + readdir: vi.fn(), + stat: vi.fn(), + readFile: vi.fn(), + writeFile: vi.fn(), + createWriteStream: vi.fn(), + createReadStream: vi.fn() +})) + +vi.mock('../WindowService', () => ({ + windowService: { + getMainWindow: vi.fn() + } +})) + +vi.mock('../WebDav', () => ({ + default: vi.fn() +})) + +vi.mock('../S3Storage', () => ({ + default: vi.fn() +})) + +vi.mock('../../utils', () => ({ + getDataPath: vi.fn(() => '/mock/data') +})) + +vi.mock('archiver', () => ({ + default: vi.fn() +})) + +vi.mock('node-stream-zip', () => ({ + default: vi.fn() +})) + +// Import after mocks +import * as fs from 'fs-extra' +import * as path from 'path' + +import BackupManager from '../BackupManager' + +// Helper to construct platform-independent paths for assertions +// The implementation uses path.normalize() which converts to platform separators +const normalizePath = (p: string): string => path.normalize(p) + +describe('BackupManager.deleteTempBackup - Security Tests', () => { + let backupManager: BackupManager + + beforeEach(() => { + vi.clearAllMocks() + backupManager = new BackupManager() + }) + + describe('Normal 
Operations', () => { + it('should delete valid file in allowed directory', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const validPath = '/tmp/cherry-studio/lan-transfer/backup.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath) + + expect(result).toBe(true) + expect(fs.remove).toHaveBeenCalledWith(normalizePath(validPath)) + expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('Deleted temp backup')) + }) + + it('should delete file in nested subdirectory', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const nestedPath = '/tmp/cherry-studio/lan-transfer/sub/dir/file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, nestedPath) + + expect(result).toBe(true) + expect(fs.remove).toHaveBeenCalledWith(normalizePath(nestedPath)) + }) + + it('should return false when file does not exist', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(false as never) + + const missingPath = '/tmp/cherry-studio/lan-transfer/missing.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, missingPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + }) + + describe('Path Traversal Attacks', () => { + it('should block basic directory traversal attack (../../../../etc/passwd)', async () => { + const attackPath = '/tmp/cherry-studio/lan-transfer/../../../../etc/passwd' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.pathExists).not.toHaveBeenCalled() + expect(fs.remove).not.toHaveBeenCalled() + expect(mockLogger.warn).toHaveBeenCalledWith(expect.stringContaining('outside temp directory')) + }) + + it('should block absolute path escape (/etc/passwd)', async () => { + const attackPath = '/etc/passwd' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + expect(mockLogger.warn).toHaveBeenCalled() + }) + + it('should block traversal with multiple slashes', async () => { + const attackPath = '/tmp/cherry-studio/lan-transfer/../../../etc/passwd' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + + it('should block relative path traversal from current directory', async () => { + const attackPath = '../../../etc/passwd' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + + it('should block traversal to parent directory', async () => { + const attackPath = '/tmp/cherry-studio/lan-transfer/../backup/secret.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + }) + + describe('Prefix Attacks', () => { + it('should block similar prefix attack (lan-transfer-evil)', async () => { + const attackPath = '/tmp/cherry-studio/lan-transfer-evil/file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + 
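+      // The ../ segments normalize to /etc/passwd, outside the allowed lan-transfer root,
+      // so the guard must refuse before making any filesystem call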
expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + expect(mockLogger.warn).toHaveBeenCalled() + }) + + it('should block path without separator (lan-transferx)', async () => { + const attackPath = '/tmp/cherry-studio/lan-transferx' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + + it('should block different temp directory prefix', async () => { + const attackPath = '/tmp-evil/cherry-studio/lan-transfer/file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath) + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + }) + + describe('Error Handling', () => { + it('should return false and log error on permission denied', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockRejectedValue(new Error('EACCES: permission denied') as never) + + const validPath = '/tmp/cherry-studio/lan-transfer/file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath) + + expect(result).toBe(false) + expect(mockLogger.error).toHaveBeenCalledWith(expect.stringContaining('Failed to delete'), expect.any(Error)) + }) + + it('should return false on fs.pathExists error', async () => { + vi.mocked(fs.pathExists).mockRejectedValue(new Error('ENOENT') as never) + + const validPath = '/tmp/cherry-studio/lan-transfer/file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath) + + expect(result).toBe(false) + expect(mockLogger.error).toHaveBeenCalled() + }) + + it('should handle empty path string', async () => { + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, '') + + expect(result).toBe(false) + expect(fs.remove).not.toHaveBeenCalled() + }) + }) + + describe('Edge Cases', () => { + it('should allow deletion of the temp directory itself', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const tempDir = '/tmp/cherry-studio/lan-transfer' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, tempDir) + + expect(result).toBe(true) + expect(fs.remove).toHaveBeenCalledWith(normalizePath(tempDir)) + }) + + it('should handle path with trailing slash', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const pathWithSlash = '/tmp/cherry-studio/lan-transfer/sub/' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, pathWithSlash) + + // path.normalize removes trailing slash + expect(result).toBe(true) + }) + + it('should handle file with special characters in name', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const specialPath = '/tmp/cherry-studio/lan-transfer/file with spaces & (special).zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, specialPath) + + expect(result).toBe(true) + expect(fs.remove).toHaveBeenCalled() + }) + + it('should handle path with double slashes', async () => { + vi.mocked(fs.pathExists).mockResolvedValue(true as never) + vi.mocked(fs.remove).mockResolvedValue(undefined as never) + + const doubleSlashPath = 
'/tmp/cherry-studio//lan-transfer//file.zip' + const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, doubleSlashPath) + + // path.normalize handles double slashes + expect(result).toBe(true) + }) + }) +}) diff --git a/src/main/services/__tests__/LocalTransferService.test.ts b/src/main/services/__tests__/LocalTransferService.test.ts new file mode 100644 index 0000000000..d00c7c269b --- /dev/null +++ b/src/main/services/__tests__/LocalTransferService.test.ts @@ -0,0 +1,481 @@ +import { EventEmitter } from 'events' +import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest' + +// Create mock objects before vi.mock calls +const mockLogger = { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn() +} + +let mockMainWindow: { + isDestroyed: Mock + webContents: { send: Mock } +} | null = null + +let mockBrowser: EventEmitter & { + start: Mock + stop: Mock + removeAllListeners: Mock +} + +let mockBonjour: { + find: Mock + destroy: Mock +} + +// Mock dependencies before importing the service +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => mockLogger + } +})) + +vi.mock('../WindowService', () => ({ + windowService: { + getMainWindow: vi.fn(() => mockMainWindow) + } +})) + +vi.mock('bonjour-service', () => ({ + default: vi.fn(() => mockBonjour) +})) + +describe('LocalTransferService', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.resetModules() + + // Reset mock objects + mockMainWindow = { + isDestroyed: vi.fn(() => false), + webContents: { send: vi.fn() } + } + + mockBrowser = Object.assign(new EventEmitter(), { + start: vi.fn(), + stop: vi.fn(), + removeAllListeners: vi.fn() + }) + + mockBonjour = { + find: vi.fn(() => mockBrowser), + destroy: vi.fn() + } + }) + + afterEach(() => { + vi.resetAllMocks() + }) + + describe('startDiscovery', () => { + it('should set isScanning to true and start browser', async () => { + const { localTransferService } = await import('../LocalTransferService') + + const state = localTransferService.startDiscovery() + + expect(state.isScanning).toBe(true) + expect(state.lastScanStartedAt).toBeDefined() + expect(mockBonjour.find).toHaveBeenCalledWith({ type: 'cherrystudio', protocol: 'tcp' }) + expect(mockBrowser.start).toHaveBeenCalled() + }) + + it('should clear services when resetList is true', async () => { + const { localTransferService } = await import('../LocalTransferService') + + // First, start discovery and add a service + localTransferService.startDiscovery() + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + fqdn: 'test.local' + }) + + expect(localTransferService.getState().services).toHaveLength(1) + + // Now restart with resetList + const state = localTransferService.startDiscovery({ resetList: true }) + + expect(state.services).toHaveLength(0) + }) + + it('should broadcast state after starting discovery', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + expect(mockMainWindow?.webContents.send).toHaveBeenCalled() + }) + + it('should handle browser.start() error', async () => { + mockBrowser.start.mockImplementation(() => { + throw new Error('Failed to start mDNS') + }) + + const { localTransferService } = await import('../LocalTransferService') + + const state = localTransferService.startDiscovery() + + expect(state.lastError).toBe('Failed to start mDNS') + expect(mockLogger.error).toHaveBeenCalled() + }) + }) + + 
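+  // Each test re-imports the service after vi.resetModules() (see beforeEach),
+  // so the module-level singleton is recreated fresh and no discovery state
+  // leaks between test cases.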
describe('stopDiscovery', () => { + it('should set isScanning to false and stop browser', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + const state = localTransferService.stopDiscovery() + + expect(state.isScanning).toBe(false) + expect(mockBrowser.stop).toHaveBeenCalled() + }) + + it('should handle browser.stop() error gracefully', async () => { + mockBrowser.stop.mockImplementation(() => { + throw new Error('Stop failed') + }) + + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + // Should not throw + expect(() => localTransferService.stopDiscovery()).not.toThrow() + expect(mockLogger.warn).toHaveBeenCalled() + }) + + it('should broadcast state after stopping', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + vi.clearAllMocks() + + localTransferService.stopDiscovery() + + expect(mockMainWindow?.webContents.send).toHaveBeenCalled() + }) + }) + + describe('browser events', () => { + it('should add service on "up" event', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + fqdn: 'test.local', + type: 'cherrystudio', + protocol: 'tcp' + }) + + const state = localTransferService.getState() + expect(state.services).toHaveLength(1) + expect(state.services[0].name).toBe('Test Service') + expect(state.services[0].port).toBe(12345) + expect(state.services[0].addresses).toContain('192.168.1.100') + }) + + it('should remove service on "down" event', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + // Add service + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + fqdn: 'test.local' + }) + + expect(localTransferService.getState().services).toHaveLength(1) + + // Remove service + mockBrowser.emit('down', { + name: 'Test Service', + host: 'localhost', + port: 12345, + fqdn: 'test.local' + }) + + expect(localTransferService.getState().services).toHaveLength(0) + expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('removed')) + }) + + it('should set lastError on "error" event', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('error', new Error('Discovery failed')) + + const state = localTransferService.getState() + expect(state.lastError).toBe('Discovery failed') + expect(mockLogger.error).toHaveBeenCalled() + }) + + it('should handle non-Error objects in error event', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('error', 'String error message') + + const state = localTransferService.getState() + expect(state.lastError).toBe('String error message') + }) + }) + + describe('getState', () => { + it('should return sorted services by name', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Zebra Service', + host: 'host1', + port: 1001, + addresses: ['192.168.1.1'] + }) + + 
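+      // a second peer that should sort ahead of 'Zebra Service' once getState() orders by name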
mockBrowser.emit('up', { + name: 'Alpha Service', + host: 'host2', + port: 1002, + addresses: ['192.168.1.2'] + }) + + const state = localTransferService.getState() + expect(state.services[0].name).toBe('Alpha Service') + expect(state.services[1].name).toBe('Zebra Service') + }) + + it('should include all state properties', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + const state = localTransferService.getState() + + expect(state).toHaveProperty('services') + expect(state).toHaveProperty('isScanning') + expect(state).toHaveProperty('lastScanStartedAt') + expect(state).toHaveProperty('lastUpdatedAt') + }) + }) + + describe('getPeerById', () => { + it('should return peer when exists', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + fqdn: 'test.local' + }) + + const services = localTransferService.getState().services + const peer = localTransferService.getPeerById(services[0].id) + + expect(peer).toBeDefined() + expect(peer?.name).toBe('Test Service') + }) + + it('should return undefined when peer does not exist', async () => { + const { localTransferService } = await import('../LocalTransferService') + + const peer = localTransferService.getPeerById('non-existent-id') + + expect(peer).toBeUndefined() + }) + }) + + describe('normalizeService', () => { + it('should deduplicate addresses', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100', '192.168.1.100', '10.0.0.1'], + referer: { address: '192.168.1.100' } + }) + + const services = localTransferService.getState().services + expect(services[0].addresses).toHaveLength(2) + expect(services[0].addresses).toContain('192.168.1.100') + expect(services[0].addresses).toContain('10.0.0.1') + }) + + it('should filter empty addresses', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100', '', null as any] + }) + + const services = localTransferService.getState().services + expect(services[0].addresses).toEqual(['192.168.1.100']) + }) + + it('should convert txt null/undefined values to empty strings', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + txt: { + version: '1.0', + nullValue: null, + undefinedValue: undefined, + numberValue: 42 + } + }) + + const services = localTransferService.getState().services + expect(services[0].txt).toEqual({ + version: '1.0', + nullValue: '', + undefinedValue: '', + numberValue: '42' + }) + }) + + it('should not include txt when empty', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'], + txt: {} + }) + + const services = 
localTransferService.getState().services + expect(services[0].txt).toBeUndefined() + }) + }) + + describe('dispose', () => { + it('should clean up all resources', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + mockBrowser.emit('up', { + name: 'Test Service', + host: 'localhost', + port: 12345, + addresses: ['192.168.1.100'] + }) + + localTransferService.dispose() + + expect(localTransferService.getState().services).toHaveLength(0) + expect(localTransferService.getState().isScanning).toBe(false) + expect(mockBrowser.removeAllListeners).toHaveBeenCalled() + expect(mockBonjour.destroy).toHaveBeenCalled() + }) + + it('should handle bonjour.destroy() error gracefully', async () => { + mockBonjour.destroy.mockImplementation(() => { + throw new Error('Destroy failed') + }) + + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + // Should not throw + expect(() => localTransferService.dispose()).not.toThrow() + expect(mockLogger.warn).toHaveBeenCalled() + }) + + it('should be safe to call multiple times', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + + expect(() => { + localTransferService.dispose() + localTransferService.dispose() + }).not.toThrow() + }) + }) + + describe('broadcastState', () => { + it('should not throw when main window is null', async () => { + mockMainWindow = null + + const { localTransferService } = await import('../LocalTransferService') + + // Should not throw + expect(() => localTransferService.startDiscovery()).not.toThrow() + }) + + it('should not throw when main window is destroyed', async () => { + mockMainWindow = { + isDestroyed: vi.fn(() => true), + webContents: { send: vi.fn() } + } + + const { localTransferService } = await import('../LocalTransferService') + + // Should not throw + expect(() => localTransferService.startDiscovery()).not.toThrow() + expect(mockMainWindow.webContents.send).not.toHaveBeenCalled() + }) + }) + + describe('restartBrowser', () => { + it('should destroy old bonjour instance to prevent socket leaks', async () => { + const { localTransferService } = await import('../LocalTransferService') + + // First start + localTransferService.startDiscovery() + expect(mockBonjour.destroy).not.toHaveBeenCalled() + + // Restart - should destroy old instance + localTransferService.startDiscovery() + expect(mockBonjour.destroy).toHaveBeenCalled() + }) + + it('should remove all listeners from old browser', async () => { + const { localTransferService } = await import('../LocalTransferService') + + localTransferService.startDiscovery() + localTransferService.startDiscovery() + + expect(mockBrowser.removeAllListeners).toHaveBeenCalled() + }) + }) +}) diff --git a/src/main/services/agents/BaseService.ts b/src/main/services/agents/BaseService.ts index a4c1a23240..0b0340c82b 100644 --- a/src/main/services/agents/BaseService.ts +++ b/src/main/services/agents/BaseService.ts @@ -1,6 +1,7 @@ import { loggerService } from '@logger' import { mcpApiService } from '@main/apiServer/services/mcp' import { type ModelValidationError, validateModelId } from '@main/apiServer/utils' +import { buildFunctionCallToolName } from '@main/utils/mcp' import type { AgentType, MCPTool, SlashCommand, Tool } from '@types' import { objectKeys } from '@types' import fs from 'fs' @@ -12,6 +13,17 @@ import { builtinSlashCommands } from 
diff --git a/src/main/services/agents/BaseService.ts b/src/main/services/agents/BaseService.ts
index a4c1a23240..0b0340c82b 100644
--- a/src/main/services/agents/BaseService.ts
+++ b/src/main/services/agents/BaseService.ts
@@ -1,6 +1,7 @@
 import { loggerService } from '@logger'
 import { mcpApiService } from '@main/apiServer/services/mcp'
 import { type ModelValidationError, validateModelId } from '@main/apiServer/utils'
+import { buildFunctionCallToolName } from '@main/utils/mcp'
 import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
 import { objectKeys } from '@types'
 import fs from 'fs'
@@ -12,6 +13,17 @@ import { builtinSlashCommands } from './services/claudecode/commands'
 import { builtinTools } from './services/claudecode/tools'
 
 const logger = loggerService.withContext('BaseService')
+const MCP_TOOL_ID_PREFIX = 'mcp__'
+const MCP_TOOL_LEGACY_PREFIX = 'mcp_'
+
+const buildMcpToolId = (serverId: string, toolName: string) => `${MCP_TOOL_ID_PREFIX}${serverId}__${toolName}`
+const toLegacyMcpToolId = (toolId: string) => {
+  if (!toolId.startsWith(MCP_TOOL_ID_PREFIX)) {
+    return null
+  }
+  const rawId = toolId.slice(MCP_TOOL_ID_PREFIX.length)
+  return `${MCP_TOOL_LEGACY_PREFIX}${rawId.replace(/__/g, '_')}`
+}
 
 /**
  * Base service class providing shared utilities for all agent-related services.
@@ -33,8 +45,12 @@ export abstract class BaseService {
     'slash_commands'
   ]
 
-  public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> {
+  public async listMcpTools(
+    agentType: AgentType,
+    ids?: string[]
+  ): Promise<{ tools: Tool[]; legacyIdMap: Map<string, string> }> {
     const tools: Tool[] = []
+    const legacyIdMap = new Map<string, string>()
     if (agentType === 'claude-code') {
       tools.push(...builtinTools)
     }
@@ -44,13 +60,21 @@ export abstract class BaseService {
         const server = await mcpApiService.getServerInfo(id)
         if (server) {
           server.tools.forEach((tool: MCPTool) => {
+            const canonicalId = buildFunctionCallToolName(server.name, tool.name)
+            const serverIdBasedId = buildMcpToolId(id, tool.name)
+            const legacyId = toLegacyMcpToolId(serverIdBasedId)
+
             tools.push({
-              id: `mcp_${id}_${tool.name}`,
+              id: canonicalId,
               name: tool.name,
               type: 'mcp',
               description: tool.description || '',
               requirePermissions: true
             })
+            legacyIdMap.set(serverIdBasedId, canonicalId)
+            if (legacyId) {
+              legacyIdMap.set(legacyId, canonicalId)
+            }
           })
         }
       } catch (error) {
@@ -62,7 +86,53 @@
       }
     }
 
-    return tools
+    return { tools, legacyIdMap }
+  }
+
+  /**
+   * Normalize MCP tool IDs in allowed_tools to the current format.
+   *
+   * Legacy formats:
+   * - "mcp__<serverId>__<toolName>" (double underscore separators, server ID based)
+   * - "mcp_<serverId>_<toolName>" (single underscore separators)
+   * Current format: "mcp__<serverName>__<toolName>" (double underscore separators).
+   *
+   * This keeps persisted data compatible without requiring a database migration.
+   */
+  protected normalizeAllowedTools(
+    allowedTools: string[] | undefined,
+    tools: Tool[],
+    legacyIdMap?: Map<string, string>
+  ): string[] | undefined {
+    if (!allowedTools || allowedTools.length === 0) {
+      return allowedTools
+    }
+
+    const resolvedLegacyIdMap = new Map<string, string>()
+
+    if (legacyIdMap) {
+      for (const [legacyId, canonicalId] of legacyIdMap) {
+        resolvedLegacyIdMap.set(legacyId, canonicalId)
+      }
+    }
+
+    for (const tool of tools) {
+      if (tool.type !== 'mcp') {
+        continue
+      }
+      const legacyId = toLegacyMcpToolId(tool.id)
+      if (!legacyId) {
+        continue
+      }
+      resolvedLegacyIdMap.set(legacyId, tool.id)
+    }
+
+    if (resolvedLegacyIdMap.size === 0) {
+      return allowedTools
+    }
+
+    const normalized = allowedTools.map((toolId) => resolvedLegacyIdMap.get(toolId) ?? toolId)
+    return Array.from(new Set(normalized))
   }
 
   public async listSlashCommands(agentType: AgentType): Promise<SlashCommand[]> {
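The normalization above is easiest to see with concrete IDs. A minimal usage sketch (the IDs are illustrative; buildFunctionCallToolName is assumed to produce the server-name-based canonical form):

// Illustrative only: both legacy spellings map onto the canonical
// server-name-based ID, unknown IDs pass through, duplicates collapse via the Set.
const tools = [{ id: 'mcp__my_server__search', name: 'search', type: 'mcp' as const }]
const legacyIdMap = new Map<string, string>([
  ['mcp__a1b2c3__search', 'mcp__my_server__search'], // server-ID-based legacy form
  ['mcp_a1b2c3_search', 'mcp__my_server__search'] // single-underscore legacy form
])

// normalizeAllowedTools(['mcp_a1b2c3_search', 'mcp__my_server__search', 'custom_tool'], tools, legacyIdMap)
// => ['mcp__my_server__search', 'custom_tool']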
diff --git a/src/main/services/agents/database/DatabaseManager.ts b/src/main/services/agents/database/DatabaseManager.ts
index f4b13971c7..913f9e4a66 100644
--- a/src/main/services/agents/database/DatabaseManager.ts
+++ b/src/main/services/agents/database/DatabaseManager.ts
@@ -1,3 +1,19 @@
+/**
+ * @deprecated Scheduled for removal in v2.0.0
+ * --------------------------------------------------------------------------
+ * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
+ * --------------------------------------------------------------------------
+ * STOP: Feature PRs affecting this file are currently BLOCKED.
+ * Only critical bug fixes are accepted during this migration phase.
+ *
+ * This file is being refactored to v2 standards.
+ * Any non-critical changes will conflict with the ongoing work.
+ *
+ * 🔗 Context & Status:
+ * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
+ * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
+ * --------------------------------------------------------------------------
+ */
 import { type Client, createClient } from '@libsql/client'
 import { loggerService } from '@logger'
 import type { LibSQLDatabase } from 'drizzle-orm/libsql'
diff --git a/src/main/services/agents/drizzle.config.ts b/src/main/services/agents/drizzle.config.ts
index e12518c069..7278883c11 100644
--- a/src/main/services/agents/drizzle.config.ts
+++ b/src/main/services/agents/drizzle.config.ts
@@ -1,3 +1,19 @@
+/**
+ * @deprecated Scheduled for removal in v2.0.0
+ * --------------------------------------------------------------------------
+ * ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
+ * --------------------------------------------------------------------------
+ * STOP: Feature PRs affecting this file are currently BLOCKED.
+ * Only critical bug fixes are accepted during this migration phase.
+ *
+ * This file is being refactored to v2 standards.
+ * Any non-critical changes will conflict with the ongoing work.
+ *
+ * 🔗 Context & Status:
+ * - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
+ * - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
+ * --------------------------------------------------------------------------
+ */
 /**
  * Drizzle Kit configuration for agents database
  */
diff --git a/src/main/services/agents/services/AgentService.ts b/src/main/services/agents/services/AgentService.ts
index 2faa87bb45..7542c1935b 100644
--- a/src/main/services/agents/services/AgentService.ts
+++ b/src/main/services/agents/services/AgentService.ts
@@ -89,7 +89,9 @@
     }
 
     const agent = this.deserializeJsonFields(result[0]) as GetAgentResponse
-    agent.tools = await this.listMcpTools(agent.type, agent.mcps)
+    const { tools, legacyIdMap } = await this.listMcpTools(agent.type, agent.mcps)
+    agent.tools = tools
+    agent.allowed_tools = this.normalizeAllowedTools(agent.allowed_tools, agent.tools, legacyIdMap)
 
     // Load installed_plugins from cache file instead of database
     const workdir = agent.accessible_paths?.[0]
@@ -134,7 +136,9 @@
     const agents = result.map((row) => this.deserializeJsonFields(row)) as GetAgentResponse[]
 
     for (const agent of agents) {
-      agent.tools = await this.listMcpTools(agent.type, agent.mcps)
+      const { tools, legacyIdMap } = await this.listMcpTools(agent.type, agent.mcps)
+      agent.tools = tools
+      agent.allowed_tools = this.normalizeAllowedTools(agent.allowed_tools, agent.tools, legacyIdMap)
     }
 
     return { agents, total: totalResult[0].count }
diff --git a/src/main/services/agents/services/SessionService.ts b/src/main/services/agents/services/SessionService.ts
index 5ba4721646..8f08160060 100644
--- a/src/main/services/agents/services/SessionService.ts
+++ b/src/main/services/agents/services/SessionService.ts
@@ -157,7 +157,9 @@
     }
 
     const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse
-    session.tools = await this.listMcpTools(session.agent_type, session.mcps)
+    const { tools, legacyIdMap } = await this.listMcpTools(session.agent_type, session.mcps)
+    session.tools = tools
+    session.allowed_tools = this.normalizeAllowedTools(session.allowed_tools, session.tools, legacyIdMap)
 
     // If slash_commands is not in database yet (e.g., first invoke before init message),
     // fall back to builtin + local commands. Otherwise, use the merged commands from database.
@@ -203,6 +205,12 @@
 
     const sessions = result.map((row) => this.deserializeJsonFields(row)) as GetAgentSessionResponse[]
 
+    for (const session of sessions) {
+      const { tools, legacyIdMap } = await this.listMcpTools(session.agent_type, session.mcps)
+      session.tools = tools
+      session.allowed_tools = this.normalizeAllowedTools(session.allowed_tools, session.tools, legacyIdMap)
+    }
+
     return { sessions, total }
   }
diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts
index 7b1a43ab82..1a18b6aba9 100644
--- a/src/main/services/agents/services/claudecode/index.ts
+++ b/src/main/services/agents/services/claudecode/index.ts
@@ -15,9 +15,10 @@ import { query } from '@anthropic-ai/claude-agent-sdk'
 import { preferenceService } from '@data/PreferenceService'
 import { loggerService } from '@logger'
 import { validateModelId } from '@main/apiServer/utils'
-import { ConfigKeys, configManager } from '@main/services/ConfigManager'
-import { validateGitBashPath } from '@main/utils/process'
+import { isWin } from '@main/constant'
+import { autoDiscoverGitBash } from '@main/utils/process'
 import getLoginShellEnvironment from '@main/utils/shell-env'
+import { withoutTrailingApiVersion } from '@shared/utils'
 import { app } from 'electron'
 
 import type { GetAgentSessionResponse } from '../..'
@@ -113,7 +114,15 @@
       Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy'))
     ) as Record<string, string>
 
-    const customGitBashPath = validateGitBashPath(configManager.get(ConfigKeys.GitBashPath) as string | undefined)
+    // Auto-discover Git Bash path on Windows (already logs internally)
+    const customGitBashPath = isWin ? autoDiscoverGitBash() : null
+
+    // Claude Agent SDK builds the final endpoint as `${ANTHROPIC_BASE_URL}/v1/messages`.
+    // To avoid malformed URLs like `/v1/v1/messages`, we normalize the provider host
+    // by stripping any trailing API version (e.g. `/v1`).
+    const anthropicBaseUrl = withoutTrailingApiVersion(
+      modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost
+    )
 
     const env = {
       ...loginShellEnvWithoutProxies,
@@ -123,7 +132,7 @@
       // ANTHROPIC_BASE_URL: `http://${apiConfig['feature.csaas.host']}:${apiConfig['feature.csaas.port']}/${modelInfo.provider.id}`,
       ANTHROPIC_API_KEY: modelInfo.provider.apiKey,
       ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey,
-      ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost,
+      ANTHROPIC_BASE_URL: anthropicBaseUrl,
       ANTHROPIC_MODEL: modelInfo.modelId,
       ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId,
      ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId,
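The withoutTrailingApiVersion helper is imported from @shared/utils and its implementation is not part of this diff; a minimal sketch of the behavior the comment describes might look like this (hypothetical equivalent, not the shipped code):

// Strip one trailing "/v<N>" segment so the SDK's "/v1/messages" suffix is not doubled.
function stripTrailingApiVersion(baseUrl: string): string {
  return baseUrl.replace(/\/v\d+\/?$/, '')
}

// stripTrailingApiVersion('https://api.example.com/v1') // => 'https://api.example.com'
// stripTrailingApiVersion('https://api.example.com')    // => unchanged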
diff --git a/src/main/services/agents/tests/BaseService.test.ts b/src/main/services/agents/tests/BaseService.test.ts
new file mode 100644
index 0000000000..fe2f4e103a
--- /dev/null
+++ b/src/main/services/agents/tests/BaseService.test.ts
@@ -0,0 +1,91 @@
+import type { Tool } from '@types'
+import { describe, expect, it, vi } from 'vitest'
+
+vi.mock('@main/apiServer/services/mcp', () => ({
+  mcpApiService: {
+    getServerInfo: vi.fn()
+  }
+}))
+
+vi.mock('@main/apiServer/utils', () => ({
+  validateModelId: vi.fn()
+}))
+
+import { BaseService } from '../BaseService'
+
+class TestBaseService extends BaseService {
+  public normalize(
+    allowedTools: string[] | undefined,
+    tools: Tool[],
+    legacyIdMap?: Map<string, string>
+  ): string[] | undefined {
+    return this.normalizeAllowedTools(allowedTools, tools, legacyIdMap)
+  }
+}
+
+const buildMcpTool = (id: string): Tool => ({
+  id,
+  name: id,
+  type: 'mcp',
+  description: 'test tool',
+  requirePermissions: true
+})
+
+describe('BaseService.normalizeAllowedTools', () => {
+  const service = new TestBaseService()
+
+  it('returns undefined or empty inputs unchanged', () => {
+    expect(service.normalize(undefined, [])).toBeUndefined()
+    expect(service.normalize([], [])).toEqual([])
+  })
+
+  it('normalizes legacy MCP tool IDs and deduplicates entries', () => {
+    const tools: Tool[] = [
+      buildMcpTool('mcp__server_one__tool_one'),
+      buildMcpTool('mcp__server_two__tool_two'),
+      { id: 'custom_tool', name: 'custom_tool', type: 'custom' }
+    ]
+
+    const legacyIdMap = new Map<string, string>([
+      ['mcp__server-1__tool-one', 'mcp__server_one__tool_one'],
+      ['mcp_server-1_tool-one', 'mcp__server_one__tool_one'],
+      ['mcp__server-2__tool-two', 'mcp__server_two__tool_two']
+    ])
+
+    const allowedTools = [
+      'mcp__server-1__tool-one',
+      'mcp_server-1_tool-one',
+      'mcp_server_one_tool_one',
+      'mcp__server_one__tool_one',
+      'custom_tool',
+      'mcp__server_two__tool_two',
+      'mcp_server_two_tool_two',
+      'mcp__server-2__tool-two'
+    ]
+
+    expect(service.normalize(allowedTools, tools, legacyIdMap)).toEqual([
+      'mcp__server_one__tool_one',
+      'custom_tool',
+      'mcp__server_two__tool_two'
+    ])
+  })
+
+  it('keeps legacy IDs when no matching MCP tool exists', () => {
+    const tools: Tool[] = [buildMcpTool('mcp__server_one__tool_one')]
+    const legacyIdMap = new Map<string, string>([['mcp__server-1__tool-one', 'mcp__server_one__tool_one']])
+
+    const allowedTools = ['mcp__unknown__tool', 'mcp__server_one__tool_one']
+
+    expect(service.normalize(allowedTools, tools, legacyIdMap)).toEqual([
+      'mcp__unknown__tool',
+      'mcp__server_one__tool_one'
+    ])
+  })
+
+  it('returns allowed tools unchanged when no MCP tools are available', () => {
+    const allowedTools = ['custom_tool', 'builtin_tool']
+    const tools: Tool[] = [{ id: 'custom_tool', name: 'custom_tool', type: 'custom' }]
+
+    expect(service.normalize(allowedTools, tools)).toEqual(allowedTools)
+  })
+})
diff --git a/src/main/services/lanTransfer/LanTransferClientService.ts b/src/main/services/lanTransfer/LanTransferClientService.ts
new file mode 100644
index 0000000000..a6da2f1a20
--- /dev/null
+++ b/src/main/services/lanTransfer/LanTransferClientService.ts
@@ -0,0 +1,525 @@
+import * as crypto from 'node:crypto'
+import { createConnection, type Socket } from 'node:net'
+
+import { loggerService } from '@logger'
+import type {
+  LanClientEvent,
+  LanFileCompleteMessage,
+  LanHandshakeAckMessage,
+  LocalTransferConnectPayload,
+  LocalTransferPeer
+} from '@shared/config/types'
+import { LAN_TRANSFER_GLOBAL_TIMEOUT_MS } from '@shared/config/types'
+import { IpcChannel } from '@shared/IpcChannel'
+
+import { localTransferService } from '../LocalTransferService'
+import { windowService } from '../WindowService'
+import {
+  abortTransfer,
+  buildHandshakeMessage,
+  calculateFileChecksum,
+  cleanupTransfer,
+  createDataHandler,
+  createTransferState,
+  formatFileSize,
+  HANDSHAKE_PROTOCOL_VERSION,
+  pickHost,
+  sendFileEnd,
+  sendFileStart,
+  sendTestPing,
+  streamFileChunks,
+  validateFile,
+  waitForFileComplete,
+  waitForFileStartAck
+} from './handlers'
+import { ResponseManager } from './responseManager'
+import type { ActiveFileTransfer, ConnectionContext, FileTransferContext } from './types'
+
+const DEFAULT_HANDSHAKE_TIMEOUT_MS = 10_000
+
+const logger = loggerService.withContext('LanTransferClientService')
+
+/**
+ * LAN Transfer Client Service
+ *
+ * Handles outgoing file transfers to LAN peers via TCP.
+ * Protocol v1 with streaming mode (no per-chunk acknowledgment).
+ */
+class LanTransferClientService {
+  private socket: Socket | null = null
+  private currentPeer?: LocalTransferPeer
+  private dataHandler?: ReturnType<typeof createDataHandler>
+  private responseManager = new ResponseManager()
+  private isConnecting = false
+  private activeTransfer?: ActiveFileTransfer
+  private lastConnectOptions?: LocalTransferConnectPayload
+  private consecutiveJsonErrors = 0
+  private static readonly MAX_CONSECUTIVE_JSON_ERRORS = 3
+  private reconnectPromise: Promise<void> | null = null
+
+  constructor() {
+    this.responseManager.setTimeoutCallback(() => void this.disconnect())
+  }
+
+  /**
+   * Connect to a LAN peer and perform handshake.
+   */
+  public async connectAndHandshake(options: LocalTransferConnectPayload): Promise<LanHandshakeAckMessage> {
+    if (this.isConnecting) {
+      throw new Error('LAN transfer client is busy')
+    }
+
+    const peer = localTransferService.getPeerById(options.peerId)
+    if (!peer) {
+      throw new Error('Selected LAN peer is no longer available')
+    }
+    if (!peer.port) {
+      throw new Error('Selected peer does not expose a TCP port')
+    }
+
+    const host = pickHost(peer)
+    if (!host) {
+      throw new Error('Unable to resolve a reachable host for the peer')
+    }
+
+    await this.disconnect()
+    this.isConnecting = true
+
+    return new Promise<LanHandshakeAckMessage>((resolve, reject) => {
+      const socket = createConnection({ host, port: peer.port as number }, () => {
+        logger.info(`Connected to LAN peer ${peer.name} (${host}:${peer.port})`)
+        socket.setKeepAlive(true, 30_000)
+        this.socket = socket
+        this.currentPeer = peer
+        this.attachSocketListeners(socket)
+
+        this.responseManager.waitForResponse(
+          'handshake_ack',
+          options.timeoutMs ?? DEFAULT_HANDSHAKE_TIMEOUT_MS,
+          (payload) => {
+            const ack = payload as LanHandshakeAckMessage
+            if (!ack.accepted) {
+              const message = ack.message || 'Handshake rejected by remote device'
+              logger.warn(`Handshake rejected by ${peer.name}: ${message}`)
+              this.broadcastClientEvent({
+                type: 'error',
+                message,
+                timestamp: Date.now()
+              })
+              reject(new Error(message))
+              void this.disconnect()
+              return
+            }
+            logger.info(`Handshake accepted by ${peer.name}`)
+            socket.setTimeout(0)
+            this.isConnecting = false
+            this.lastConnectOptions = options
+            sendTestPing(this.createConnectionContext())
+            resolve(ack)
+          },
+          (error) => {
+            this.isConnecting = false
+            reject(error)
+          }
+        )
+
+        const handshakeMessage = buildHandshakeMessage()
+        this.sendControlMessage(handshakeMessage)
+      })
+
+      socket.setTimeout(options.timeoutMs ?? DEFAULT_HANDSHAKE_TIMEOUT_MS, () => {
+        const error = new Error('Handshake timed out')
+        logger.error('LAN transfer socket timeout', error)
+        this.broadcastClientEvent({
+          type: 'error',
+          message: error.message,
+          timestamp: Date.now()
+        })
+        reject(error)
+        socket.destroy(error)
+        void this.disconnect()
+      })
+
+      socket.once('error', (error) => {
+        logger.error('LAN transfer socket error', error as Error)
+        const message = error instanceof Error ? error.message : String(error)
+        this.broadcastClientEvent({
+          type: 'error',
+          message,
+          timestamp: Date.now()
+        })
+        this.isConnecting = false
+        reject(error instanceof Error ? error : new Error(message))
+        void this.disconnect()
+      })
+
+      socket.once('close', () => {
+        logger.info('LAN transfer socket closed')
+        if (this.socket === socket) {
+          this.socket = null
+          this.dataHandler?.resetBuffer()
+          this.responseManager.rejectAll(new Error('LAN transfer socket closed'))
+          this.currentPeer = undefined
+          abortTransfer(this.activeTransfer, new Error('LAN transfer socket closed'))
+        }
+        this.isConnecting = false
+        this.broadcastClientEvent({
+          type: 'socket_closed',
+          reason: 'connection_closed',
+          timestamp: Date.now()
+        })
+      })
+    })
+  }
+
+  /**
+   * Disconnect from the current peer.
+   */
+  public async disconnect(): Promise<void> {
+    const socket = this.socket
+    if (!socket) {
+      return
+    }
+
+    this.socket = null
+    this.dataHandler?.resetBuffer()
+    this.currentPeer = undefined
+    this.responseManager.rejectAll(new Error('LAN transfer socket disconnected'))
+    abortTransfer(this.activeTransfer, new Error('LAN transfer socket disconnected'))
+
+    const DISCONNECT_TIMEOUT_MS = 3000
+    await new Promise<void>((resolve) => {
+      const timeout = setTimeout(() => {
+        logger.warn('Disconnect timeout, forcing cleanup')
+        socket.removeAllListeners()
+        resolve()
+      }, DISCONNECT_TIMEOUT_MS)
+
+      socket.once('close', () => {
+        clearTimeout(timeout)
+        resolve()
+      })
+
+      socket.destroy()
+    })
+  }
+
+  /**
+   * Dispose the service and clean up all resources.
+   */
+  public dispose(): void {
+    this.responseManager.rejectAll(new Error('LAN transfer client disposed'))
+    cleanupTransfer(this.activeTransfer)
+    this.activeTransfer = undefined
+    if (this.socket) {
+      this.socket.destroy()
+      this.socket = null
+    }
+    this.dataHandler?.resetBuffer()
+    this.isConnecting = false
+  }
+  /**
+   * Send a ZIP file to the connected peer.
+   */
+  public async sendFile(filePath: string): Promise<LanFileCompleteMessage> {
+    await this.ensureConnection()
+
+    if (this.activeTransfer) {
+      throw new Error('A file transfer is already in progress')
+    }
+
+    // Validate file
+    const { stats, fileName } = await validateFile(filePath)
+
+    // Calculate checksum
+    logger.info('Calculating file checksum...')
+    const checksum = await calculateFileChecksum(filePath)
+    logger.info(`File checksum: ${checksum.substring(0, 16)}...`)
+
+    // Connection can drop while validating/checksumming the file; ensure it is still ready before starting the transfer.
+    await this.ensureConnection()
+
+    // Initialize transfer state
+    const transferId = crypto.randomUUID()
+    this.activeTransfer = createTransferState(transferId, fileName, stats.size, checksum)
+
+    logger.info(
+      `Starting file transfer: ${fileName} (${formatFileSize(stats.size)}, ${this.activeTransfer.totalChunks} chunks)`
+    )
+
+    // Global timeout
+    const globalTimeoutError = new Error('Transfer timed out (global timeout exceeded)')
+    const globalTimeoutHandle = setTimeout(() => {
+      logger.warn('Global transfer timeout exceeded, aborting transfer', { transferId, fileName })
+      abortTransfer(this.activeTransfer, globalTimeoutError)
+    }, LAN_TRANSFER_GLOBAL_TIMEOUT_MS)
+
+    try {
+      const result = await this.performFileTransfer(filePath, transferId, fileName)
+      return result
+    } catch (error) {
+      const message = error instanceof Error ? error.message : String(error)
+      logger.error(`File transfer failed: ${message}`)
+
+      this.broadcastClientEvent({
+        type: 'file_transfer_complete',
+        transferId,
+        fileName,
+        success: false,
+        error: message,
+        timestamp: Date.now()
+      })
+
+      throw error
+    } finally {
+      clearTimeout(globalTimeoutHandle)
+      cleanupTransfer(this.activeTransfer)
+      this.activeTransfer = undefined
+    }
+  }
+
+  /**
+   * Cancel the current file transfer.
+   */
+  public cancelTransfer(): void {
+    if (!this.activeTransfer) {
+      logger.warn('No active transfer to cancel')
+      return
+    }
+
+    const { transferId, fileName } = this.activeTransfer
+    logger.info(`Cancelling file transfer: ${fileName}`)
+
+    this.activeTransfer.isCancelled = true
+
+    try {
+      this.sendControlMessage({
+        type: 'file_cancel',
+        transferId,
+        reason: 'Cancelled by user'
+      })
+    } catch (error) {
+      // Expected when connection is already broken
+      logger.warn('Failed to send cancel message', error as Error)
+    }
+
+    abortTransfer(this.activeTransfer, new Error('Transfer cancelled by user'))
+  }
+
+  // =============================================================================
+  // Private Methods
+  // =============================================================================
+  private async ensureConnection(): Promise<void> {
+    // Check socket is valid and writable (not just undestroyed)
+    if (this.socket && !this.socket.destroyed && this.socket.writable && this.currentPeer) {
+      return
+    }
+
+    if (!this.lastConnectOptions) {
+      throw new Error('No active connection. Please connect to a peer first.')
+    }
+
+    // Prevent concurrent reconnection attempts
+    if (this.reconnectPromise) {
+      logger.debug('Waiting for existing reconnection attempt...')
+      await this.reconnectPromise
+      return
+    }
+
+    logger.info('Connection lost, attempting to reconnect...')
+    this.reconnectPromise = this.connectAndHandshake(this.lastConnectOptions)
+      .then(() => {
+        // Handshake succeeded, connection restored
+      })
+      .finally(() => {
+        this.reconnectPromise = null
+      })
+
+    await this.reconnectPromise
+  }
+
+  private async performFileTransfer(
+    filePath: string,
+    transferId: string,
+    fileName: string
+  ): Promise<LanFileCompleteMessage> {
+    const transfer = this.activeTransfer!
+    const ctx = this.createFileTransferContext()
+
+    // Step 1: Send file_start
+    sendFileStart(ctx, transfer)
+
+    // Step 2: Wait for file_start_ack
+    const startAck = await waitForFileStartAck(ctx, transferId, transfer.abortController.signal)
+    if (!startAck.accepted) {
+      throw new Error(startAck.message || 'Transfer rejected by receiver')
+    }
+    logger.info('Received file_start_ack: accepted')
+
+    // Step 3: Stream file chunks
+    await streamFileChunks(this.socket!, filePath, transfer, transfer.abortController.signal, (bytesSent, chunkIndex) =>
+      this.onTransferProgress(transfer, bytesSent, chunkIndex)
+    )
+
+    // Step 4: Send file_end
+    sendFileEnd(ctx, transferId)
+
+    // Step 5: Wait for file_complete
+    const result = await waitForFileComplete(ctx, transferId, transfer.abortController.signal)
+    logger.info(`File transfer ${result.success ? 'completed' : 'failed'}`)
+
+    // Broadcast completion
+    this.broadcastClientEvent({
+      type: 'file_transfer_complete',
+      transferId,
+      fileName,
+      success: result.success,
+      filePath: result.filePath,
+      error: result.error,
+      timestamp: Date.now()
+    })
+
+    return result
+  }
+  private onTransferProgress(transfer: ActiveFileTransfer, bytesSent: number, chunkIndex: number): void {
+    const progress = (bytesSent / transfer.fileSize) * 100
+    const elapsed = (Date.now() - transfer.startedAt) / 1000
+    const speed = elapsed > 0 ? bytesSent / elapsed : 0
+
+    this.broadcastClientEvent({
+      type: 'file_transfer_progress',
+      transferId: transfer.transferId,
+      fileName: transfer.fileName,
+      bytesSent,
+      totalBytes: transfer.fileSize,
+      chunkIndex,
+      totalChunks: transfer.totalChunks,
+      progress: Math.round(progress * 100) / 100,
+      speed,
+      timestamp: Date.now()
+    })
+  }
+
+  private attachSocketListeners(socket: Socket): void {
+    this.dataHandler = createDataHandler((line) => this.handleControlLine(line))
+    socket.on('data', (chunk: Buffer) => {
+      try {
+        this.dataHandler?.handleData(chunk)
+      } catch (error) {
+        logger.error('Data handler error', error as Error)
+        void this.disconnect()
+      }
+    })
+  }
+
+  private handleControlLine(line: string): void {
+    let payload: Record<string, unknown>
+    try {
+      payload = JSON.parse(line)
+      this.consecutiveJsonErrors = 0 // Reset on successful parse
+    } catch {
+      this.consecutiveJsonErrors++
+      logger.warn('Received invalid JSON control message', { line, consecutiveErrors: this.consecutiveJsonErrors })
+
+      if (this.consecutiveJsonErrors >= LanTransferClientService.MAX_CONSECUTIVE_JSON_ERRORS) {
+        const message = `Protocol error: ${this.consecutiveJsonErrors} consecutive invalid messages, disconnecting`
+        logger.error(message)
+        this.broadcastClientEvent({
+          type: 'error',
+          message,
+          timestamp: Date.now()
+        })
+        void this.disconnect()
+      }
+      return
+    }
+
+    const type = payload?.type as string | undefined
+    if (!type) {
+      logger.warn('Received control message without type', payload)
+      return
+    }
+
+    // Try to resolve a pending response
+    const transferId = payload?.transferId as string | undefined
+    const chunkIndex = payload?.chunkIndex as number | undefined
+    if (this.responseManager.tryResolve(type, payload, transferId, chunkIndex)) {
+      return
+    }
+
+    logger.info('Received control message', payload)
+
+    if (type === 'pong') {
+      this.broadcastClientEvent({
+        type: 'pong',
+        payload: payload?.payload as string | undefined,
+        received: payload?.received as boolean | undefined,
+        timestamp: Date.now()
+      })
+      return
+    }
+
+    // Ignore late-arriving file transfer messages
+    const fileTransferMessageTypes = ['file_start_ack', 'file_complete']
+    if (fileTransferMessageTypes.includes(type)) {
+      logger.debug('Ignoring late file transfer message', { type, payload })
+      return
+    }
+
+    this.broadcastClientEvent({
+      type: 'error',
+      message: `Unexpected control message type: ${type}`,
+      timestamp: Date.now()
+    })
+  }
+
+  private sendControlMessage(message: Record<string, unknown>): void {
+    if (!this.socket || this.socket.destroyed || !this.socket.writable) {
+      throw new Error('Socket is not connected')
+    }
+    const payload = JSON.stringify(message)
+    this.socket.write(`${payload}\n`)
+  }
+
+  private createConnectionContext(): ConnectionContext {
+    return {
+      socket: this.socket,
+      currentPeer: this.currentPeer,
+      sendControlMessage: (msg) => this.sendControlMessage(msg),
+      broadcastClientEvent: (event) => this.broadcastClientEvent(event)
+    }
+  }
+
+  private createFileTransferContext(): FileTransferContext {
+    return {
+      ...this.createConnectionContext(),
+      activeTransfer: this.activeTransfer,
+      setActiveTransfer: (transfer) => {
+        this.activeTransfer = transfer
+      },
+      waitForResponse: (type, timeoutMs, resolve, reject, transferId, chunkIndex, abortSignal) => {
+        this.responseManager.waitForResponse(type, timeoutMs, resolve, reject, transferId, chunkIndex, abortSignal)
+      }
+    }
+  }
+  private broadcastClientEvent(event: LanClientEvent): void {
+    const mainWindow = windowService.getMainWindow()
+    if (!mainWindow || mainWindow.isDestroyed()) {
+      return
+    }
+    mainWindow.webContents.send(IpcChannel.LocalTransfer_ClientEvent, {
+      ...event,
+      peerId: event.peerId ?? this.currentPeer?.id,
+      peerName: event.peerName ?? this.currentPeer?.name
+    })
+  }
+}
+
+export const lanTransferClientService = new LanTransferClientService()
+
+// Re-export for backward compatibility
+export { HANDSHAKE_PROTOCOL_VERSION }
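Taken together, the public surface of the client is small. A hedged sketch of the expected call flow from the main process (the peer ID and file path are illustrative; error handling is elided):

// Illustrative: peerId would come from localTransferService discovery state,
// zipPath from a preceding backup/export step.
async function sendBackupToPeer(peerId: string, zipPath: string): Promise<boolean> {
  await lanTransferClientService.connectAndHandshake({ peerId })
  try {
    // Streams the ZIP and resolves with the receiver's file_complete message.
    const result = await lanTransferClientService.sendFile(zipPath)
    return result.success
  } finally {
    await lanTransferClientService.disconnect()
  }
}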
diff --git a/src/main/services/lanTransfer/__tests__/LanTransferClientService.test.ts b/src/main/services/lanTransfer/__tests__/LanTransferClientService.test.ts
new file mode 100644
index 0000000000..16f188aa93
--- /dev/null
+++ b/src/main/services/lanTransfer/__tests__/LanTransferClientService.test.ts
@@ -0,0 +1,133 @@
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+
+// Mock dependencies before importing the service
+vi.mock('node:net', async (importOriginal) => {
+  const actual = (await importOriginal()) as Record<string, unknown>
+  return {
+    ...actual,
+    createConnection: vi.fn()
+  }
+})
+
+vi.mock('electron', () => ({
+  app: {
+    getName: vi.fn(() => 'Cherry Studio'),
+    getVersion: vi.fn(() => '1.0.0')
+  }
+}))
+
+vi.mock('../../LocalTransferService', () => ({
+  localTransferService: {
+    getPeerById: vi.fn()
+  }
+}))
+
+vi.mock('../../WindowService', () => ({
+  windowService: {
+    getMainWindow: vi.fn(() => ({
+      isDestroyed: () => false,
+      webContents: {
+        send: vi.fn()
+      }
+    }))
+  }
+}))
+
+// Import after mocks
+import { localTransferService } from '../../LocalTransferService'
+
+describe('LanTransferClientService', () => {
+  beforeEach(() => {
+    vi.clearAllMocks()
+    vi.resetModules()
+  })
+
+  afterEach(() => {
+    vi.resetAllMocks()
+  })
+
+  describe('connectAndHandshake - validation', () => {
+    it('should throw error when peer is not found', async () => {
+      vi.mocked(localTransferService.getPeerById).mockReturnValue(undefined)
+
+      const { lanTransferClientService } = await import('../LanTransferClientService')
+
+      await expect(
+        lanTransferClientService.connectAndHandshake({
+          peerId: 'non-existent'
+        })
+      ).rejects.toThrow('Selected LAN peer is no longer available')
+    })
+
+    it('should throw error when peer has no port', async () => {
+      vi.mocked(localTransferService.getPeerById).mockReturnValue({
+        id: 'test-peer',
+        name: 'Test Peer',
+        addresses: ['192.168.1.100'],
+        updatedAt: Date.now()
+      })
+
+      const { lanTransferClientService } = await import('../LanTransferClientService')
+
+      await expect(
+        lanTransferClientService.connectAndHandshake({
+          peerId: 'test-peer'
+        })
+      ).rejects.toThrow('Selected peer does not expose a TCP port')
+    })
+
+    it('should throw error when no reachable host', async () => {
+      vi.mocked(localTransferService.getPeerById).mockReturnValue({
+        id: 'test-peer',
+        name: 'Test Peer',
+        port: 12345,
+        addresses: [],
+        updatedAt: Date.now()
+      })
+
+      const { lanTransferClientService } = await import('../LanTransferClientService')
+
+      await expect(
+        lanTransferClientService.connectAndHandshake({
+          peerId: 'test-peer'
+        })
+      ).rejects.toThrow('Unable to resolve a reachable host for the peer')
+    })
+  })
+
+  describe('cancelTransfer', () => {
+    it('should not throw when no active transfer', async () => {
+      const { lanTransferClientService } = await import('../LanTransferClientService')
+
+      // Should not throw, just log warning
+      expect(() => lanTransferClientService.cancelTransfer()).not.toThrow()
+    })
+  })
+
+  describe('dispose', () => {
+    it('should clean up resources without throwing', async () => {
+      const { lanTransferClientService } = await import('../LanTransferClientService')
+
+      // Should not throw
+      expect(() => lanTransferClientService.dispose()).not.toThrow()
+    })
+  })
+
+  describe('sendFile', () => {
+    it('should throw error when not connected', async () => {
+      const { lanTransferClientService } = await import('../LanTransferClientService')
+
+      await expect(lanTransferClientService.sendFile('/path/to/file.zip')).rejects.toThrow(
+        'No active connection. Please connect to a peer first.'
+      )
+    })
+  })
+
+  describe('HANDSHAKE_PROTOCOL_VERSION', () => {
+    it('should export protocol version', async () => {
+      const { HANDSHAKE_PROTOCOL_VERSION } = await import('../LanTransferClientService')
+
+      expect(HANDSHAKE_PROTOCOL_VERSION).toBe('1')
+    })
+  })
+})
diff --git a/src/main/services/lanTransfer/__tests__/binaryProtocol.test.ts b/src/main/services/lanTransfer/__tests__/binaryProtocol.test.ts
new file mode 100644
index 0000000000..c485a33098
--- /dev/null
+++ b/src/main/services/lanTransfer/__tests__/binaryProtocol.test.ts
@@ -0,0 +1,103 @@
+import { EventEmitter } from 'node:events'
+import type { Socket } from 'node:net'
+
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+import { BINARY_TYPE_FILE_CHUNK, sendBinaryChunk } from '../binaryProtocol'
+
+describe('binaryProtocol', () => {
+  describe('sendBinaryChunk', () => {
+    let mockSocket: Socket
+    let writtenBuffers: Buffer[]
+
+    beforeEach(() => {
+      writtenBuffers = []
+      mockSocket = Object.assign(new EventEmitter(), {
+        destroyed: false,
+        writable: true,
+        write: vi.fn((buffer: Buffer) => {
+          writtenBuffers.push(Buffer.from(buffer))
+          return true
+        }),
+        cork: vi.fn(),
+        uncork: vi.fn()
+      }) as unknown as Socket
+    })
+
+    it('should send binary chunk with correct frame format', () => {
+      const transferId = 'test-uuid-1234'
+      const chunkIndex = 5
+      const data = Buffer.from('test data chunk')
+
+      const result = sendBinaryChunk(mockSocket, transferId, chunkIndex, data)
+
+      expect(result).toBe(true)
+      expect(mockSocket.cork).toHaveBeenCalled()
+      expect(mockSocket.uncork).toHaveBeenCalled()
+      expect(mockSocket.write).toHaveBeenCalledTimes(2)
+
+      // Verify header structure
+      const header = writtenBuffers[0]
+
+      // Magic bytes "CS"
+      expect(header[0]).toBe(0x43)
+      expect(header[1]).toBe(0x53)
+
+      // Type byte
+      const typeOffset = 2 + 4 // magic + totalLen
+      expect(header[typeOffset]).toBe(BINARY_TYPE_FILE_CHUNK)
+
+      // TransferId length
+      const tidLenOffset = typeOffset + 1
+      const tidLen = header.readUInt16BE(tidLenOffset)
+      expect(tidLen).toBe(Buffer.from(transferId).length)
+
+      // ChunkIndex
+      const chunkIdxOffset = tidLenOffset + 2 + tidLen
+      expect(header.readUInt32BE(chunkIdxOffset)).toBe(chunkIndex)
+
+      // Data buffer
+      expect(writtenBuffers[1].toString()).toBe('test data chunk')
+    })
+
+    it('should return false when socket write returns false (backpressure)', () => {
+      ;(mockSocket.write as ReturnType<typeof vi.fn>).mockReturnValueOnce(false)
+
+      const result = sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data'))
+
+      expect(result).toBe(false)
+    })
+
+    it('should correctly calculate totalLen in frame header', () => {
+      const transferId = 'uuid-1234'
+      const data = Buffer.from('chunk data here')
+
+      sendBinaryChunk(mockSocket, transferId, 0, data)
+
+      const header = writtenBuffers[0]
+      const totalLen = header.readUInt32BE(2) // After magic bytes
+
+      // totalLen = type(1) + tidLen(2) + tid(n) + idx(4) + data(m)
+      const expectedTotalLen = 1 + 2 + Buffer.from(transferId).length + 4 + data.length
+      expect(totalLen).toBe(expectedTotalLen)
+    })
+
+    it('should throw error when socket is not writable', () => {
+      ;(mockSocket as any).writable = false
+
+      expect(() => sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data'))).toThrow('Socket is not writable')
+    })
+
+    it('should throw error when socket is destroyed', () => {
+      ;(mockSocket as any).destroyed = true
+
+      expect(() => sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data'))).toThrow('Socket is not writable')
+    })
+  })
+
+  describe('BINARY_TYPE_FILE_CHUNK', () => {
+    it('should be 0x01', () => {
+      expect(BINARY_TYPE_FILE_CHUNK).toBe(0x01)
+    })
+  })
+})
diff --git a/src/main/services/lanTransfer/__tests__/handlers/connection.test.ts b/src/main/services/lanTransfer/__tests__/handlers/connection.test.ts
new file mode 100644
index 0000000000..3983e538d3
--- /dev/null
+++ b/src/main/services/lanTransfer/__tests__/handlers/connection.test.ts
@@ -0,0 +1,265 @@
+import { EventEmitter } from 'node:events'
+import type { Socket } from 'node:net'
+
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+
+import {
+  buildHandshakeMessage,
+  createDataHandler,
+  getAbortError,
+  HANDSHAKE_PROTOCOL_VERSION,
+  pickHost,
+  waitForSocketDrain
+} from '../../handlers/connection'
+
+// Mock electron app
+vi.mock('electron', () => ({
+  app: {
+    getName: vi.fn(() => 'Cherry Studio'),
+    getVersion: vi.fn(() => '1.0.0')
+  }
+}))
+
+describe('connection handlers', () => {
+  describe('buildHandshakeMessage', () => {
+    it('should build handshake message with correct structure', () => {
+      const message = buildHandshakeMessage()
+
+      expect(message.type).toBe('handshake')
+      expect(message.deviceName).toBe('Cherry Studio')
+      expect(message.version).toBe(HANDSHAKE_PROTOCOL_VERSION)
+      expect(message.appVersion).toBe('1.0.0')
+      expect(typeof message.platform).toBe('string')
+    })
+
+    it('should use protocol version 1', () => {
+      expect(HANDSHAKE_PROTOCOL_VERSION).toBe('1')
+    })
+  })
+
+  describe('pickHost', () => {
+    it('should prefer IPv4 addresses', () => {
+      const peer = {
+        id: '1',
+        name: 'Test',
+        addresses: ['fe80::1', '192.168.1.100', '::1'],
+        updatedAt: Date.now()
+      }
+
+      expect(pickHost(peer)).toBe('192.168.1.100')
+    })
+
+    it('should fall back to first address if no IPv4', () => {
+      const peer = {
+        id: '1',
+        name: 'Test',
+        addresses: ['fe80::1', '::1'],
+        updatedAt: Date.now()
+      }
+
+      expect(pickHost(peer)).toBe('fe80::1')
+    })
+
+    it('should fall back to host property if no addresses', () => {
+      const peer = {
+        id: '1',
+        name: 'Test',
+        host: 'example.local',
+        addresses: [],
+        updatedAt: Date.now()
+      }
+
+      expect(pickHost(peer)).toBe('example.local')
+    })
+
+    it('should return undefined if no addresses or host', () => {
+      const peer = {
+        id: '1',
+        name: 'Test',
+        addresses: [],
+        updatedAt: Date.now()
+      }
+
+      expect(pickHost(peer)).toBeUndefined()
+    })
+  })
+
+  describe('createDataHandler', () => {
+    it('should parse complete lines from buffer', () => {
+      const lines: string[] = []
+      const handler = createDataHandler((line) => lines.push(line))
+
+      handler.handleData(Buffer.from('{"type":"test"}\n'))
+
+      expect(lines).toEqual(['{"type":"test"}'])
+    })
+
+    it('should handle partial lines across multiple chunks', () => {
+      const lines: string[] = []
+      const handler = createDataHandler((line) => lines.push(line))
+
+      handler.handleData(Buffer.from('{"type":'))
+      handler.handleData(Buffer.from('"test"}\n'))
+
+      expect(lines).toEqual(['{"type":"test"}'])
+    })
+
+    it('should handle multiple lines in single chunk', () => {
+      const lines: string[] = []
+      const handler = createDataHandler((line) => lines.push(line))
+
+      handler.handleData(Buffer.from('{"a":1}\n{"b":2}\n'))
+
+      expect(lines).toEqual(['{"a":1}', '{"b":2}'])
+    })
+
+    it('should reset buffer', () => {
+      const lines: string[] = []
+      const handler = createDataHandler((line) => lines.push(line))
+
+      handler.handleData(Buffer.from('partial'))
+      handler.resetBuffer()
+      handler.handleData(Buffer.from('{"complete":true}\n'))
+
+      expect(lines).toEqual(['{"complete":true}'])
+    })
+
+    it('should trim whitespace from lines', () => {
+      const lines: string[] = []
+      const handler = createDataHandler((line) => lines.push(line))
+
+      handler.handleData(Buffer.from('  {"type":"test"}  \n'))
+
+      expect(lines).toEqual(['{"type":"test"}'])
+    })
+
+    it('should skip empty lines', () => {
+      const lines: string[] = []
+      const handler = createDataHandler((line) => lines.push(line))
+
+      handler.handleData(Buffer.from('\n\n{"type":"test"}\n\n'))
+
+      expect(lines).toEqual(['{"type":"test"}'])
+    })
+
+    it('should throw error when buffer exceeds MAX_LINE_BUFFER_SIZE', () => {
+      const handler = createDataHandler(vi.fn())
+
+      // Create a buffer larger than 1MB (MAX_LINE_BUFFER_SIZE)
+      const largeData = 'x'.repeat(1024 * 1024 + 1)
+
+      expect(() => handler.handleData(Buffer.from(largeData))).toThrow('Control message too large')
+    })
+
+    it('should reset buffer after exceeding MAX_LINE_BUFFER_SIZE', () => {
+      const lines: string[] = []
+      const handler = createDataHandler((line) => lines.push(line))
+
+      // Create a buffer larger than 1MB
+      const largeData = 'x'.repeat(1024 * 1024 + 1)
+
+      try {
+        handler.handleData(Buffer.from(largeData))
+      } catch {
+        // Expected error
+      }
+
+      // Buffer should be reset, so lineBuffer should be empty
+      expect(handler.lineBuffer).toBe('')
+    })
+  })
+
+  describe('waitForSocketDrain', () => {
+    let mockSocket: Socket & EventEmitter
+
+    beforeEach(() => {
+      mockSocket = Object.assign(new EventEmitter(), {
+        destroyed: false,
+        writable: true,
+        write: vi.fn(),
+        off: vi.fn(),
+        removeAllListeners: vi.fn()
+      }) as unknown as Socket & EventEmitter
+    })
+
+    afterEach(() => {
+      vi.resetAllMocks()
+    })
+
+    it('should throw error when abort signal is already aborted', async () => {
+      const abortController = new AbortController()
+      abortController.abort(new Error('Already aborted'))
+
+      await expect(waitForSocketDrain(mockSocket, abortController.signal)).rejects.toThrow('Already aborted')
+    })
+
+    it('should throw error when socket is destroyed', async () => {
+      ;(mockSocket as any).destroyed = true
+      const abortController = new AbortController()
+
+      await expect(waitForSocketDrain(mockSocket, abortController.signal)).rejects.toThrow('Socket is closed')
+    })
+
+    it('should resolve when drain event is emitted', async () => {
+      const abortController = new AbortController()
+
+      const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)
+
+      // Emit drain event after a short delay
+      setImmediate(() => mockSocket.emit('drain'))
+
+      await expect(drainPromise).resolves.toBeUndefined()
+    })
+
+    it('should reject when close event is emitted', async () => {
+      const abortController = new AbortController()
+
+      const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)
+
+      setImmediate(() => mockSocket.emit('close'))
+
+      await expect(drainPromise).rejects.toThrow('Socket closed while waiting for drain')
+    })
+
+    it('should reject when error event is emitted', async () => {
+      const abortController = new AbortController()
+
+      const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)
+
+      setImmediate(() => mockSocket.emit('error', new Error('Network error')))
+
+      await expect(drainPromise).rejects.toThrow('Network error')
+    })
+
+    it('should reject when abort signal is triggered', async () => {
+      const abortController = new AbortController()
+
+      const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)
+
+      setImmediate(() => abortController.abort(new Error('User cancelled')))
+
+      await expect(drainPromise).rejects.toThrow('User cancelled')
+    })
+  })
+
+  describe('getAbortError', () => {
+    it('should return Error reason directly', () => {
+      const originalError = new Error('Original')
+      const signal = { aborted: true, reason: originalError } as AbortSignal
+
+      expect(getAbortError(signal, 'Fallback')).toBe(originalError)
+    })
+
+    it('should create Error from string reason', () => {
+      const signal = { aborted: true, reason: 'String reason' } as AbortSignal
+
+      expect(getAbortError(signal, 'Fallback').message).toBe('String reason')
+    })
+
+    it('should use fallback for empty reason', () => {
+      const signal = { aborted: true, reason: '' } as AbortSignal
+
+      expect(getAbortError(signal, 'Fallback').message).toBe('Fallback')
+    })
+  })
+})
diff --git a/src/main/services/lanTransfer/__tests__/handlers/fileTransfer.test.ts b/src/main/services/lanTransfer/__tests__/handlers/fileTransfer.test.ts
new file mode 100644
index 0000000000..814fd2f5c9
--- /dev/null
+++ b/src/main/services/lanTransfer/__tests__/handlers/fileTransfer.test.ts
@@ -0,0 +1,216 @@
+import { EventEmitter } from 'node:events'
+import type * as fs from 'node:fs'
+import type { Socket } from 'node:net'
+
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+
+import {
+  abortTransfer,
+  cleanupTransfer,
+  createTransferState,
+  formatFileSize,
+  streamFileChunks
+} from '../../handlers/fileTransfer'
+import type { ActiveFileTransfer } from '../../types'
+
+// Mock binaryProtocol
+vi.mock('../../binaryProtocol', () => ({
+  sendBinaryChunk: vi.fn().mockReturnValue(true)
+}))
+
+// Mock connection handlers
+vi.mock('../../handlers/connection', () => ({
+  waitForSocketDrain: vi.fn().mockResolvedValue(undefined),
+  getAbortError: vi.fn((signal, fallback) => {
+    const reason = (signal as AbortSignal & { reason?: unknown }).reason
+    if (reason instanceof Error) return reason
+    if (typeof reason === 'string' && reason.length > 0) return new Error(reason)
+    return new Error(fallback)
+  })
+}))
+
+// Note: validateFile and calculateFileChecksum tests are skipped because
+// the test environment has globally mocked node:fs and node:os modules.
+// These functions are tested through integration tests instead.
+
+describe('fileTransfer handlers', () => {
+  describe('createTransferState', () => {
+    it('should create transfer state with correct defaults', () => {
+      const state = createTransferState('uuid-123', 'test.zip', 1024000, 'abc123')
+
+      expect(state.transferId).toBe('uuid-123')
+      expect(state.fileName).toBe('test.zip')
+      expect(state.fileSize).toBe(1024000)
+      expect(state.checksum).toBe('abc123')
+      expect(state.bytesSent).toBe(0)
+      expect(state.currentChunk).toBe(0)
+      expect(state.isCancelled).toBe(false)
+      expect(state.abortController).toBeInstanceOf(AbortController)
+    })
+
+    it('should calculate totalChunks based on chunk size', () => {
+      // 512KB chunk size
+      const state = createTransferState('id', 'test.zip', 1024 * 1024, 'checksum') // 1MB
+
+      expect(state.totalChunks).toBe(2) // 1MB / 512KB = 2
+    })
+  })
+
+  describe('abortTransfer', () => {
+    it('should abort transfer and destroy stream', () => {
+      const mockStream = {
+        destroyed: false,
+        destroy: vi.fn()
+      } as unknown as fs.ReadStream
+
+      const transfer: ActiveFileTransfer = {
+        transferId: 'test',
+        fileName: 'test.zip',
+        fileSize: 1000,
+        checksum: 'abc',
+        totalChunks: 1,
+        chunkSize: 512000,
+        bytesSent: 0,
+        currentChunk: 0,
+        startedAt: Date.now(),
+        stream: mockStream,
+        isCancelled: false,
+        abortController: new AbortController()
+      }
+
+      const error = new Error('Test abort')
+      abortTransfer(transfer, error)
+
+      expect(transfer.isCancelled).toBe(true)
+      expect(transfer.abortController.signal.aborted).toBe(true)
+      expect(mockStream.destroy).toHaveBeenCalledWith(error)
+    })
+
+    it('should handle undefined transfer', () => {
+      expect(() => abortTransfer(undefined, new Error('test'))).not.toThrow()
+    })
+
+    it('should not abort already aborted controller', () => {
+      const transfer: ActiveFileTransfer = {
+        transferId: 'test',
+        fileName: 'test.zip',
+        fileSize: 1000,
+        checksum: 'abc',
+        totalChunks: 1,
+        chunkSize: 512000,
+        bytesSent: 0,
+        currentChunk: 0,
+        startedAt: Date.now(),
+        isCancelled: false,
+        abortController: new AbortController()
+      }
+
+      transfer.abortController.abort()
+
+      // Should not throw when aborting again
+      expect(() => abortTransfer(transfer, new Error('test'))).not.toThrow()
+    })
+  })
+
+  describe('cleanupTransfer', () => {
+    it('should cleanup transfer resources', () => {
+      const mockStream = {
+        destroyed: false,
+        destroy: vi.fn()
+      } as unknown as fs.ReadStream
+
+      const transfer: ActiveFileTransfer = {
+        transferId: 'test',
+        fileName: 'test.zip',
+        fileSize: 1000,
+        checksum: 'abc',
+        totalChunks: 1,
+        chunkSize: 512000,
+        bytesSent: 0,
+        currentChunk: 0,
+        startedAt: Date.now(),
+        stream: mockStream,
+        isCancelled: false,
+        abortController: new AbortController()
+      }
+
+      cleanupTransfer(transfer)
+
+      expect(transfer.abortController.signal.aborted).toBe(true)
+      expect(mockStream.destroy).toHaveBeenCalled()
+    })
+
+    it('should handle undefined transfer', () => {
+      expect(() => cleanupTransfer(undefined)).not.toThrow()
+    })
+  })
+
+  describe('formatFileSize', () => {
+    it('should format 0 bytes', () => {
+      expect(formatFileSize(0)).toBe('0 B')
+    })
+
+    it('should format bytes', () => {
+      expect(formatFileSize(500)).toBe('500 B')
+    })
+
+    it('should format kilobytes', () => {
+      expect(formatFileSize(1024)).toBe('1 KB')
+      expect(formatFileSize(2048)).toBe('2 KB')
+    })
+
+    it('should format megabytes', () => {
+      expect(formatFileSize(1024 * 1024)).toBe('1 MB')
+      expect(formatFileSize(5 * 1024 * 1024)).toBe('5 MB')
+    })
+
+    it('should format gigabytes', () => {
+      expect(formatFileSize(1024 * 1024 * 1024)).toBe('1 GB')
+    })
+
+    it('should format with decimal precision', () => {
+      expect(formatFileSize(1536)).toBe('1.5 KB')
+      expect(formatFileSize(1.5 * 1024 * 1024)).toBe('1.5 MB')
+    })
+  })
+
+  // Note: streamFileChunks tests require careful mocking of fs.createReadStream
+  // which is globally mocked in the test environment. These tests verify the
+  // streaming logic works correctly with mock streams.
+  describe('streamFileChunks', () => {
+    let mockSocket: Socket & EventEmitter
+    let mockProgress: ReturnType<typeof vi.fn>
+
+    beforeEach(() => {
+      vi.clearAllMocks()
+
+      mockSocket = Object.assign(new EventEmitter(), {
+        destroyed: false,
+        writable: true,
+        write: vi.fn().mockReturnValue(true),
+        cork: vi.fn(),
+        uncork: vi.fn()
+      }) as unknown as Socket & EventEmitter
+
+      mockProgress = vi.fn()
+    })
+
+    afterEach(() => {
+      vi.resetAllMocks()
+    })
+
+    it('should throw when abort signal is already aborted', async () => {
+      const transfer = createTransferState('test-id', 'test.zip', 1024, 'checksum')
+      transfer.abortController.abort(new Error('Already cancelled'))
+
+      await expect(
+        streamFileChunks(mockSocket, '/fake/path.zip', transfer, transfer.abortController.signal, mockProgress)
+      ).rejects.toThrow()
+    })
+
+    // Note: Full integration testing of streamFileChunks with actual file streaming
+    // requires a real file system, which cannot be easily mocked in ESM.
+    // The abort signal test above verifies the early abort path.
+    // Additional streaming tests are covered through integration tests.
+  })
+})
diff --git a/src/main/services/lanTransfer/__tests__/responseManager.test.ts b/src/main/services/lanTransfer/__tests__/responseManager.test.ts
new file mode 100644
index 0000000000..170ee2de8c
--- /dev/null
+++ b/src/main/services/lanTransfer/__tests__/responseManager.test.ts
@@ -0,0 +1,177 @@
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+
+import { ResponseManager } from '../responseManager'
+
+describe('ResponseManager', () => {
+  let manager: ResponseManager
+
+  beforeEach(() => {
+    vi.useFakeTimers()
+    manager = new ResponseManager()
+  })
+
+  afterEach(() => {
+    vi.useRealTimers()
+  })
+
+  describe('buildResponseKey', () => {
+    it('should build key with type only', () => {
+      expect(manager.buildResponseKey('handshake_ack')).toBe('handshake_ack')
+    })
+
+    it('should build key with type and transferId', () => {
+      expect(manager.buildResponseKey('file_start_ack', 'uuid-123')).toBe('file_start_ack:uuid-123')
+    })
+
+    it('should build key with type, transferId, and chunkIndex', () => {
+      expect(manager.buildResponseKey('file_chunk_ack', 'uuid-123', 5)).toBe('file_chunk_ack:uuid-123:5')
+    })
+  })
+
+  describe('waitForResponse', () => {
+    it('should resolve when tryResolve is called with matching key', async () => {
+      const resolvePromise = new Promise((resolve, reject) => {
+        manager.waitForResponse('handshake_ack', 5000, resolve, reject)
+      })
+
+      const payload = { type: 'handshake_ack', accepted: true }
+      const resolved = manager.tryResolve('handshake_ack', payload)
+
+      expect(resolved).toBe(true)
+      await expect(resolvePromise).resolves.toEqual(payload)
+    })
+
+    it('should reject on timeout', async () => {
+      const resolvePromise = new Promise((resolve, reject) => {
+        manager.waitForResponse('handshake_ack', 1000, resolve, reject)
+      })
+
+      vi.advanceTimersByTime(1001)
+
+      await expect(resolvePromise).rejects.toThrow('Timeout waiting for handshake_ack')
+    })
+
+    it('should call onTimeout callback when timeout occurs', async () => {
+      const onTimeout = vi.fn()
+      manager.setTimeoutCallback(onTimeout)
+
+      const resolvePromise = new Promise((resolve, reject) => {
+        manager.waitForResponse('test', 1000, resolve, reject)
+      })
+
+      vi.advanceTimersByTime(1001)
+
+      await expect(resolvePromise).rejects.toThrow()
+      expect(onTimeout).toHaveBeenCalled()
+    })
+
+    it('should reject when abort signal is triggered', async () => {
+      const abortController = new AbortController()
+
+      const resolvePromise = new Promise((resolve, reject) => {
+        manager.waitForResponse('test', 10000, resolve, reject, undefined, undefined, abortController.signal)
+      })
+
+      abortController.abort(new Error('User cancelled'))
+
+      await expect(resolvePromise).rejects.toThrow('User cancelled')
+    })
+
+    it('should replace existing response with same key', async () => {
+      const firstReject = vi.fn()
+      const secondResolve = vi.fn()
+      const secondReject = vi.fn()
+
+      manager.waitForResponse('test', 5000, vi.fn(), firstReject)
+      manager.waitForResponse('test', 5000, secondResolve, secondReject)
+
+      // First should be cleared (no rejection since it's replaced)
+      const payload = { type: 'test' }
+      manager.tryResolve('test', payload)
+
+      expect(secondResolve).toHaveBeenCalledWith(payload)
+    })
+  })
+
+  describe('tryResolve', () => {
+    it('should return false when no matching response', () => {
+      expect(manager.tryResolve('nonexistent', {})).toBe(false)
+    })
+
+    it('should match with transferId', async () => {
+      const resolvePromise = new Promise((resolve, reject) => {
+        manager.waitForResponse('file_start_ack', 5000, resolve, reject, 'uuid-123')
+      })
+
+      const payload = { type: 'file_start_ack', transferId: 'uuid-123' }
+      manager.tryResolve('file_start_ack', payload, 'uuid-123')
+
+      await expect(resolvePromise).resolves.toEqual(payload)
+    })
+  })
+
+  describe('rejectAll', () => {
+    it('should reject all pending responses', async () => {
+      const promises = [
+        new Promise((resolve, reject) => {
+          manager.waitForResponse('test1', 5000, resolve, reject)
+        }),
+        new Promise((resolve, reject) => {
+          manager.waitForResponse('test2', 5000, resolve, reject, 'uuid')
+        })
+      ]
+
+      manager.rejectAll(new Error('Connection closed'))
+
+      await expect(promises[0]).rejects.toThrow('Connection closed')
+      await expect(promises[1]).rejects.toThrow('Connection closed')
+    })
+  })
+
+  describe('clearPendingResponse', () => {
+    it('should clear specific response by key', () => {
+      manager.waitForResponse('test', 5000, vi.fn(), vi.fn())
+
+      manager.clearPendingResponse('test')
+
+      expect(manager.tryResolve('test', {})).toBe(false)
+    })
+
+    it('should clear all responses when no key provided', () => {
+      manager.waitForResponse('test1', 5000, vi.fn(), vi.fn())
+      manager.waitForResponse('test2', 5000, vi.fn(), vi.fn())
+
+      manager.clearPendingResponse()
+
+      expect(manager.tryResolve('test1', {})).toBe(false)
+      expect(manager.tryResolve('test2', {})).toBe(false)
+    })
+  })
+
+  describe('getAbortError', () => {
+    it('should return Error reason directly', () => {
+      const originalError = new Error('Original error')
+      const signal = { aborted: true, reason: originalError } as AbortSignal
+
+      const error = manager.getAbortError(signal, 'Fallback')
+
+      expect(error).toBe(originalError)
+    })
+
+    it('should create Error from string reason', () => {
+      const signal = { aborted: true, reason: 'String reason' } as AbortSignal
+
+      const error = manager.getAbortError(signal, 'Fallback')
+
+      expect(error.message).toBe('String reason')
+    })
+
+    it('should use fallback message when no reason', () => {
+      const signal = { aborted: true } as AbortSignal
+
+      const error = manager.getAbortError(signal, 'Fallback message')
+
+      expect(error.message).toBe('Fallback message')
+    })
+  })
+})
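As the tests show, ResponseManager is a keyed promise registry. A hedged usage sketch of how a caller awaits a protocol reply (the message type and timeout are illustrative):

// Illustrative: wrap waitForResponse in a promise and key it by type + transferId.
const manager = new ResponseManager()

function awaitFileStartAck(transferId: string, timeoutMs = 30_000): Promise<unknown> {
  return new Promise((resolve, reject) => {
    manager.waitForResponse('file_start_ack', timeoutMs, resolve, reject, transferId)
  })
}

// Elsewhere, the socket's control-message handler resolves it:
// manager.tryResolve('file_start_ack', payload, payload.transferId)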
diff --git a/src/main/services/lanTransfer/binaryProtocol.ts b/src/main/services/lanTransfer/binaryProtocol.ts
new file mode 100644
index 0000000000..864a8b95bd
--- /dev/null
+++ b/src/main/services/lanTransfer/binaryProtocol.ts
@@ -0,0 +1,67 @@
+import type { Socket } from 'node:net'
+
+/**
+ * Binary protocol constants (v1)
+ */
+export const BINARY_TYPE_FILE_CHUNK = 0x01
+
+/**
+ * Send file chunk as binary frame (protocol v1 - streaming mode)
+ *
+ * Frame format:
+ * ```
+ * ┌──────────┬──────────┬──────────┬───────────────┬──────────────┬────────────┬───────────┐
+ * │ Magic    │ TotalLen │ Type     │ TransferId Len│ TransferId   │ ChunkIdx   │ Data      │
+ * │ 0x43 0x53│ (4B BE)  │ 0x01     │ (2B BE)       │ (variable)   │ (4B BE)    │ (raw)     │
+ * └──────────┴──────────┴──────────┴───────────────┴──────────────┴────────────┴───────────┘
+ * ```
+ *
+ * @param socket - TCP socket to write to
+ * @param transferId - UUID of the transfer
+ * @param chunkIndex - Index of the chunk (0-based)
+ * @param data - Raw chunk data buffer
+ * @returns true if data was buffered, false if backpressure should be applied
+ */
+export function sendBinaryChunk(socket: Socket, transferId: string, chunkIndex: number, data: Buffer): boolean {
+  if (!socket || socket.destroyed || !socket.writable) {
+    throw new Error('Socket is not writable')
+  }
+
+  const tidBuffer = Buffer.from(transferId, 'utf8')
+  const tidLen = tidBuffer.length
+
+  // totalLen = type(1) + tidLen(2) + tid(n) + idx(4) + data(m)
+  const totalLen = 1 + 2 + tidLen + 4 + data.length
+
+  const header = Buffer.allocUnsafe(2 + 4 + 1 + 2 + tidLen + 4)
+  let offset = 0
+
+  // Magic (2 bytes): "CS"
+  header[offset++] = 0x43
+  header[offset++] = 0x53
+
+  // TotalLen (4 bytes, Big-Endian)
+  header.writeUInt32BE(totalLen, offset)
+  offset += 4
+
+  // Type (1 byte)
+  header[offset++] = BINARY_TYPE_FILE_CHUNK
+
+  // TransferId length (2 bytes, Big-Endian)
+  header.writeUInt16BE(tidLen, offset)
+  offset += 2
+
+  // TransferId (variable)
+  tidBuffer.copy(header, offset)
+  offset += tidLen
+
+  // ChunkIndex (4 bytes, Big-Endian)
+  header.writeUInt32BE(chunkIndex, offset)
+
+  socket.cork()
+  const wroteHeader = socket.write(header)
+  const wroteData = socket.write(data)
+  socket.uncork()
+
+  return wroteHeader && wroteData
+}
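For the other side of the wire, a hedged sketch of decoding one such frame (the receiver implementation is not part of this diff; offsets follow the documented layout, and the function assumes `buf` starts at a frame boundary and contains the complete frame):

// Illustrative decoder for the frame format documented above.
function decodeChunkFrame(buf: Buffer) {
  if (buf[0] !== 0x43 || buf[1] !== 0x53) throw new Error('Bad magic') // "CS"
  const totalLen = buf.readUInt32BE(2) // length of everything after this field
  const type = buf[6]
  const tidLen = buf.readUInt16BE(7)
  const transferId = buf.subarray(9, 9 + tidLen).toString('utf8')
  const chunkIndex = buf.readUInt32BE(9 + tidLen)
  const data = buf.subarray(9 + tidLen + 4, 6 + totalLen) // raw chunk payload
  return { type, transferId, chunkIndex, data }
}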
+ */ +export function buildHandshakeMessage(): LanHandshakeRequestMessage { + return { + type: 'handshake', + deviceName: app.getName(), + version: HANDSHAKE_PROTOCOL_VERSION, + platform: platform(), + appVersion: app.getVersion() + } +} + +/** + * Pick the best host address from a peer's available addresses. + * Prefers IPv4 addresses over IPv6. + */ +export function pickHost(peer: LocalTransferPeer): string | undefined { + const preferred = peer.addresses?.find((addr) => isIP(addr) === 4) || peer.addresses?.[0] + return preferred || peer.host +} + +/** + * Send a test ping message after successful handshake. + */ +export function sendTestPing(ctx: ConnectionContext): void { + const payload = 'hello world' + try { + ctx.sendControlMessage({ type: 'ping', payload }) + logger.info('Sent LAN ping test payload') + ctx.broadcastClientEvent({ + type: 'ping_sent', + payload, + timestamp: Date.now() + }) + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + logger.error('Failed to send LAN test ping', error as Error) + ctx.broadcastClientEvent({ + type: 'error', + message, + timestamp: Date.now() + }) + } +} + +/** + * Attach a data listener to a socket for receiving control messages. + * Returns an object exposing the current line buffer, a chunk handler, and a buffer reset helper. + */ +export function createDataHandler(onControlLine: (line: string) => void): { + lineBuffer: string + handleData: (chunk: Buffer) => void + resetBuffer: () => void +} { + let lineBuffer = '' + + return { + get lineBuffer() { + return lineBuffer + }, + handleData(chunk: Buffer) { + lineBuffer += chunk.toString('utf8') + + // Prevent memory exhaustion from malicious peers sending data without newlines + if (lineBuffer.length > MAX_LINE_BUFFER_SIZE) { + logger.error('Line buffer exceeded maximum size, resetting') + lineBuffer = '' + throw new Error('Control message too large') + } + + let newlineIndex = lineBuffer.indexOf('\n') + while (newlineIndex !== -1) { + const line = lineBuffer.slice(0, newlineIndex).trim() + lineBuffer = lineBuffer.slice(newlineIndex + 1) + if (line.length > 0) { + onControlLine(line) + } + newlineIndex = lineBuffer.indexOf('\n') + } + }, + resetBuffer() { + lineBuffer = '' + } + } +} + +/** + * Wait for socket to drain (backpressure handling). + */ +export async function waitForSocketDrain(socket: Socket, abortSignal: AbortSignal): Promise<void> { + if (abortSignal.aborted) { + throw getAbortError(abortSignal, 'Transfer aborted while waiting for socket drain') + } + if (socket.destroyed) { + throw new Error('Socket is closed') + } + + await new Promise<void>((resolve, reject) => { + const cleanup = () => { + socket.off('drain', onDrain) + socket.off('close', onClose) + socket.off('error', onError) + abortSignal.removeEventListener('abort', onAbort) + } + + const onDrain = () => { + cleanup() + resolve() + } + + const onClose = () => { + cleanup() + reject(new Error('Socket closed while waiting for drain')) + } + + const onError = (error: Error) => { + cleanup() + reject(error) + } + + const onAbort = () => { + cleanup() + reject(getAbortError(abortSignal, 'Transfer aborted while waiting for socket drain')) + } + + socket.once('drain', onDrain) + socket.once('close', onClose) + socket.once('error', onError) + abortSignal.addEventListener('abort', onAbort, { once: true }) + }) +} + +/** + * Get the error from an abort signal, or create a fallback error.
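+ *
+ * Illustrative precedence, per the body below: an `Error` reason is returned as-is,
+ * a non-empty string reason becomes `new Error(reason)`, and anything else falls
+ * back to `new Error(fallbackMessage)`.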
+ */ +export function getAbortError(signal: AbortSignal, fallbackMessage: string): Error { + const reason = (signal as AbortSignal & { reason?: unknown }).reason + if (reason instanceof Error) { + return reason + } + if (typeof reason === 'string' && reason.length > 0) { + return new Error(reason) + } + return new Error(fallbackMessage) +} diff --git a/src/main/services/lanTransfer/handlers/fileTransfer.ts b/src/main/services/lanTransfer/handlers/fileTransfer.ts new file mode 100644 index 0000000000..c469a58421 --- /dev/null +++ b/src/main/services/lanTransfer/handlers/fileTransfer.ts @@ -0,0 +1,267 @@ +import * as crypto from 'node:crypto' +import * as fs from 'node:fs' +import type { Socket } from 'node:net' +import * as path from 'node:path' + +import { loggerService } from '@logger' +import type { + LanFileCompleteMessage, + LanFileEndMessage, + LanFileStartAckMessage, + LanFileStartMessage +} from '@shared/config/types' +import { + LAN_TRANSFER_CHUNK_SIZE, + LAN_TRANSFER_COMPLETE_TIMEOUT_MS, + LAN_TRANSFER_MAX_FILE_SIZE +} from '@shared/config/types' + +import { sendBinaryChunk } from '../binaryProtocol' +import type { ActiveFileTransfer, FileTransferContext } from '../types' +import { getAbortError, waitForSocketDrain } from './connection' + +const DEFAULT_FILE_START_ACK_TIMEOUT_MS = 30_000 // 30s for file_start_ack + +const logger = loggerService.withContext('LanTransferFileHandler') + +/** + * Validate a file for transfer. + * Checks existence, type, extension, and size limits. + */ +export async function validateFile(filePath: string): Promise<{ stats: fs.Stats; fileName: string }> { + let stats: fs.Stats + try { + stats = await fs.promises.stat(filePath) + } catch (error) { + const nodeError = error as NodeJS.ErrnoException + if (nodeError.code === 'ENOENT') { + throw new Error(`File not found: ${filePath}`) + } else if (nodeError.code === 'EACCES') { + throw new Error(`Permission denied: ${filePath}`) + } else if (nodeError.code === 'ENOTDIR') { + throw new Error(`Invalid path: ${filePath}`) + } else { + throw new Error(`Cannot access file: ${filePath} (${nodeError.code || 'unknown error'})`) + } + } + + if (!stats.isFile()) { + throw new Error('Path is not a file') + } + + const fileName = path.basename(filePath) + const ext = path.extname(fileName).toLowerCase() + if (ext !== '.zip') { + throw new Error('Only ZIP files are supported') + } + + if (stats.size > LAN_TRANSFER_MAX_FILE_SIZE) { + throw new Error(`File too large. Maximum size is ${formatFileSize(LAN_TRANSFER_MAX_FILE_SIZE)}`) + } + + return { stats, fileName } +} + +/** + * Calculate SHA-256 checksum of a file. + */ +export async function calculateFileChecksum(filePath: string): Promise<string> { + return new Promise((resolve, reject) => { + const hash = crypto.createHash('sha256') + const stream = fs.createReadStream(filePath) + stream.on('data', (data) => hash.update(data)) + stream.on('end', () => resolve(hash.digest('hex'))) + stream.on('error', reject) + }) +} + +/** + * Create initial transfer state for a new file transfer.
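+ *
+ * Worked example (chunk size assumed to be 1 MiB for illustration): a
+ * 5,242,881-byte file yields totalChunks = Math.ceil(5242881 / 1048576) = 6,
+ * with a short final chunk.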
+ */ +export function createTransferState( + transferId: string, + fileName: string, + fileSize: number, + checksum: string +): ActiveFileTransfer { + const chunkSize = LAN_TRANSFER_CHUNK_SIZE + const totalChunks = Math.ceil(fileSize / chunkSize) + + return { + transferId, + fileName, + fileSize, + checksum, + totalChunks, + chunkSize, + bytesSent: 0, + currentChunk: 0, + startedAt: Date.now(), + isCancelled: false, + abortController: new AbortController() + } +} + +/** + * Send file_start message to receiver. + */ +export function sendFileStart(ctx: FileTransferContext, transfer: ActiveFileTransfer): void { + const startMessage: LanFileStartMessage = { + type: 'file_start', + transferId: transfer.transferId, + fileName: transfer.fileName, + fileSize: transfer.fileSize, + mimeType: 'application/zip', + checksum: transfer.checksum, + totalChunks: transfer.totalChunks, + chunkSize: transfer.chunkSize + } + ctx.sendControlMessage(startMessage) + logger.info('Sent file_start message') +} + +/** + * Wait for file_start_ack from receiver. + */ +export function waitForFileStartAck( + ctx: FileTransferContext, + transferId: string, + abortSignal?: AbortSignal +): Promise<LanFileStartAckMessage> { + return new Promise((resolve, reject) => { + ctx.waitForResponse( + 'file_start_ack', + DEFAULT_FILE_START_ACK_TIMEOUT_MS, + (payload) => resolve(payload as LanFileStartAckMessage), + reject, + transferId, + undefined, + abortSignal + ) + }) +} + +/** + * Wait for file_complete from receiver after all chunks sent. + */ +export function waitForFileComplete( + ctx: FileTransferContext, + transferId: string, + abortSignal?: AbortSignal +): Promise<LanFileCompleteMessage> { + return new Promise((resolve, reject) => { + ctx.waitForResponse( + 'file_complete', + LAN_TRANSFER_COMPLETE_TIMEOUT_MS, + (payload) => resolve(payload as LanFileCompleteMessage), + reject, + transferId, + undefined, + abortSignal + ) + }) +} + +/** + * Send file_end message to receiver. + */ +export function sendFileEnd(ctx: FileTransferContext, transferId: string): void { + const endMessage: LanFileEndMessage = { + type: 'file_end', + transferId + } + ctx.sendControlMessage(endMessage) + logger.info('Sent file_end message') +} + +/** + * Stream file chunks to the receiver (v1 streaming mode - no per-chunk acknowledgment). + */ +export async function streamFileChunks( + socket: Socket, + filePath: string, + transfer: ActiveFileTransfer, + abortSignal: AbortSignal, + onProgress: (bytesSent: number, chunkIndex: number) => void +): Promise<void> { + const { chunkSize, transferId } = transfer + + const stream = fs.createReadStream(filePath, { highWaterMark: chunkSize }) + transfer.stream = stream + + let chunkIndex = 0 + let bytesSent = 0 + + try { + for await (const chunk of stream) { + if (abortSignal.aborted) { + throw getAbortError(abortSignal, 'Transfer aborted') + } + + const buffer = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk) + bytesSent += buffer.length + + // Send chunk as binary frame (v1 streaming) with backpressure handling + const canContinue = sendBinaryChunk(socket, transferId, chunkIndex, buffer) + if (!canContinue) { + await waitForSocketDrain(socket, abortSignal) + } + + // Update progress + transfer.bytesSent = bytesSent + transfer.currentChunk = chunkIndex + + onProgress(bytesSent, chunkIndex) + chunkIndex++ + } + + logger.info(`File streaming completed: ${chunkIndex} chunks sent`) + } catch (error) { + logger.error('File streaming failed', error as Error) + throw error + } +} + +/** + * Abort an active transfer and clean up resources.
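+ *
+ * Sketch of a typical call site (names assumed): `abortTransfer(activeTransfer, new Error('User cancelled'))`
+ * marks the transfer cancelled, aborts its AbortController, and destroys the read stream.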
+ */ +export function abortTransfer(transfer: ActiveFileTransfer | undefined, error: Error): void { + if (!transfer) { + return + } + + transfer.isCancelled = true + if (!transfer.abortController.signal.aborted) { + transfer.abortController.abort(error) + } + if (transfer.stream && !transfer.stream.destroyed) { + transfer.stream.destroy(error) + } +} + +/** + * Clean up transfer resources without error. + */ +export function cleanupTransfer(transfer: ActiveFileTransfer | undefined): void { + if (!transfer) { + return + } + + if (!transfer.abortController.signal.aborted) { + transfer.abortController.abort() + } + if (transfer.stream && !transfer.stream.destroyed) { + transfer.stream.destroy() + } +} + +/** + * Format bytes into human-readable size string. + */ +export function formatFileSize(bytes: number): string { + if (bytes === 0) return '0 B' + const k = 1024 + const sizes = ['B', 'KB', 'MB', 'GB'] + const i = Math.floor(Math.log(bytes) / Math.log(k)) + return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i] +} diff --git a/src/main/services/lanTransfer/handlers/index.ts b/src/main/services/lanTransfer/handlers/index.ts new file mode 100644 index 0000000000..33620d188c --- /dev/null +++ b/src/main/services/lanTransfer/handlers/index.ts @@ -0,0 +1,22 @@ +export { + buildHandshakeMessage, + createDataHandler, + getAbortError, + HANDSHAKE_PROTOCOL_VERSION, + pickHost, + sendTestPing, + waitForSocketDrain +} from './connection' +export { + abortTransfer, + calculateFileChecksum, + cleanupTransfer, + createTransferState, + formatFileSize, + sendFileEnd, + sendFileStart, + streamFileChunks, + validateFile, + waitForFileComplete, + waitForFileStartAck +} from './fileTransfer' diff --git a/src/main/services/lanTransfer/index.ts b/src/main/services/lanTransfer/index.ts new file mode 100644 index 0000000000..12f3c38afc --- /dev/null +++ b/src/main/services/lanTransfer/index.ts @@ -0,0 +1,21 @@ +/** + * LAN Transfer Client Module + * + * Protocol: v1.0 (streaming mode) + * + * Features: + * - Binary frame format for file chunks (no base64 overhead) + * - Streaming mode (no per-chunk acknowledgment) + * - JSON messages for control flow (handshake, file_start, file_end, etc.) + * - Global timeout protection + * - Backpressure handling + * + * Binary Frame Format: + * ┌──────────┬──────────┬──────────┬───────────────┬──────────────┬────────────┬───────────┐ + * │ Magic │ TotalLen │ Type │ TransferId Len│ TransferId │ ChunkIdx │ Data │ + * │ 0x43 0x53│ (4B BE) │ 0x01 │ (2B BE) │ (variable) │ (4B BE) │ (raw) │ + * └──────────┴──────────┴──────────┴───────────────┴──────────────┴────────────┴───────────┘ + */ + +export { HANDSHAKE_PROTOCOL_VERSION, lanTransferClientService } from './LanTransferClientService' +export type { ActiveFileTransfer, ConnectionContext, FileTransferContext, PendingResponse } from './types' diff --git a/src/main/services/lanTransfer/responseManager.ts b/src/main/services/lanTransfer/responseManager.ts new file mode 100644 index 0000000000..74d5196dba --- /dev/null +++ b/src/main/services/lanTransfer/responseManager.ts @@ -0,0 +1,144 @@ +import type { PendingResponse } from './types' + +/** + * Manages pending response handlers for awaiting control messages. + * Handles timeouts, abort signals, and cleanup. + */ +export class ResponseManager { + private pendingResponses = new Map<string, PendingResponse>() + private onTimeout?: () => void + + /** + * Set a callback to be called when a response times out. + * Typically used to trigger disconnect on timeout.
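+ *
+ * Sketch (handler body assumed): `manager.setTimeoutCallback(() => service.disconnect())`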
+ */ + setTimeoutCallback(callback: () => void): void { + this.onTimeout = callback + } + + /** + * Build a composite key for identifying pending responses. + */ + buildResponseKey(type: string, transferId?: string, chunkIndex?: number): string { + const parts = [type] + if (transferId !== undefined) parts.push(transferId) + if (chunkIndex !== undefined) parts.push(String(chunkIndex)) + return parts.join(':') + } + + /** + * Register a response listener with timeout and optional abort signal. + */ + waitForResponse( + type: string, + timeoutMs: number, + resolve: (payload: unknown) => void, + reject: (error: Error) => void, + transferId?: string, + chunkIndex?: number, + abortSignal?: AbortSignal + ): void { + const responseKey = this.buildResponseKey(type, transferId, chunkIndex) + + // Clear any existing response with the same key + this.clearPendingResponse(responseKey) + + const timeoutHandle = setTimeout(() => { + this.clearPendingResponse(responseKey) + const error = new Error(`Timeout waiting for ${type}`) + reject(error) + this.onTimeout?.() + }, timeoutMs) + + const pending: PendingResponse = { + type, + transferId, + chunkIndex, + resolve, + reject, + timeoutHandle, + abortSignal + } + + if (abortSignal) { + const abortListener = () => { + this.clearPendingResponse(responseKey) + reject(this.getAbortError(abortSignal, `Aborted while waiting for ${type}`)) + } + pending.abortListener = abortListener + abortSignal.addEventListener('abort', abortListener, { once: true }) + } + + this.pendingResponses.set(responseKey, pending) + } + + /** + * Try to resolve a pending response by type and optional identifiers. + * Returns true if a matching response was found and resolved. + */ + tryResolve(type: string, payload: unknown, transferId?: string, chunkIndex?: number): boolean { + const responseKey = this.buildResponseKey(type, transferId, chunkIndex) + const pendingResponse = this.pendingResponses.get(responseKey) + + if (pendingResponse) { + const resolver = pendingResponse.resolve + this.clearPendingResponse(responseKey) + resolver(payload) + return true + } + + return false + } + + /** + * Clear a single pending response by key, or all responses if no key provided. + */ + clearPendingResponse(key?: string): void { + if (key) { + const pending = this.pendingResponses.get(key) + if (pending?.timeoutHandle) { + clearTimeout(pending.timeoutHandle) + } + if (pending?.abortSignal && pending.abortListener) { + pending.abortSignal.removeEventListener('abort', pending.abortListener) + } + this.pendingResponses.delete(key) + } else { + // Clear all pending responses + for (const pending of this.pendingResponses.values()) { + if (pending.timeoutHandle) { + clearTimeout(pending.timeoutHandle) + } + if (pending.abortSignal && pending.abortListener) { + pending.abortSignal.removeEventListener('abort', pending.abortListener) + } + } + this.pendingResponses.clear() + } + } + + /** + * Reject all pending responses with the given error. + */ + rejectAll(error: Error): void { + for (const key of Array.from(this.pendingResponses.keys())) { + const pending = this.pendingResponses.get(key) + this.clearPendingResponse(key) + pending?.reject(error) + } + } + + /** + * Get the abort error from an abort signal, or create a fallback error. 
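+ *
+ * Illustrative example: for a signal aborted via `controller.abort('peer disconnected')`,
+ * this returns `new Error('peer disconnected')` rather than the fallback.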
+ */ + getAbortError(signal: AbortSignal, fallbackMessage: string): Error { + const reason = (signal as AbortSignal & { reason?: unknown }).reason + if (reason instanceof Error) { + return reason + } + if (typeof reason === 'string' && reason.length > 0) { + return new Error(reason) + } + return new Error(fallbackMessage) + } +} diff --git a/src/main/services/lanTransfer/types.ts b/src/main/services/lanTransfer/types.ts new file mode 100644 index 0000000000..52be660af3 --- /dev/null +++ b/src/main/services/lanTransfer/types.ts @@ -0,0 +1,65 @@ +import type * as fs from 'node:fs' +import type { Socket } from 'node:net' + +import type { LanClientEvent, LocalTransferPeer } from '@shared/config/types' + +/** + * Pending response handler for awaiting control messages + */ +export type PendingResponse = { + type: string + transferId?: string + chunkIndex?: number + resolve: (payload: unknown) => void + reject: (error: Error) => void + timeoutHandle?: NodeJS.Timeout + abortSignal?: AbortSignal + abortListener?: () => void +} + +/** + * Active file transfer state tracking + */ +export type ActiveFileTransfer = { + transferId: string + fileName: string + fileSize: number + checksum: string + totalChunks: number + chunkSize: number + bytesSent: number + currentChunk: number + startedAt: number + stream?: fs.ReadStream + isCancelled: boolean + abortController: AbortController +} + +/** + * Context interface for connection handlers + * Provides access to service methods without circular dependencies + */ +export type ConnectionContext = { + socket: Socket | null + currentPeer?: LocalTransferPeer + sendControlMessage: (message: Record<string, unknown>) => void + broadcastClientEvent: (event: LanClientEvent) => void +} + +/** + * Context interface for file transfer handlers + * Extends connection context with transfer-specific methods + */ +export type FileTransferContext = ConnectionContext & { + activeTransfer?: ActiveFileTransfer + setActiveTransfer: (transfer: ActiveFileTransfer | undefined) => void + waitForResponse: ( + type: string, + timeoutMs: number, + resolve: (payload: unknown) => void, + reject: (error: Error) => void, + transferId?: string, + chunkIndex?: number, + abortSignal?: AbortSignal + ) => void +} diff --git a/src/main/services/mcp/oauth/callback.ts b/src/main/services/mcp/oauth/callback.ts index 0cf63853c4..8b3bd7216b 100644 --- a/src/main/services/mcp/oauth/callback.ts +++ b/src/main/services/mcp/oauth/callback.ts @@ -127,8 +127,8 @@ export class CallBackServer { }) return new Promise((resolve, reject) => { - server.listen(port, () => { - logger.info(`OAuth callback server listening on port ${port}`) + server.listen(port, '127.0.0.1', () => { + logger.info(`OAuth callback server listening on 127.0.0.1:${port}`) resolve(server) }) diff --git a/src/main/services/memory/MemoryService.ts b/src/main/services/memory/MemoryService.ts index 3466e2c3c6..101dd54294 100644 --- a/src/main/services/memory/MemoryService.ts +++ b/src/main/services/memory/MemoryService.ts @@ -1,7 +1,9 @@ import type { Client } from '@libsql/client' import { createClient } from '@libsql/client' import { loggerService } from '@logger' +import { DATA_PATH } from '@main/config' import Embeddings from '@main/knowledge/embedjs/embeddings/Embeddings' +import { makeSureDirExists } from '@main/utils' import type { AddMemoryOptions, AssistantMessage, @@ -13,6 +15,7 @@ import type { } from '@types' import crypto from 'crypto' import { app } from 'electron' +import fs from 'fs' import path from 'path' import { MemoryQueries } from
'./queries' @@ -71,6 +74,21 @@ export class MemoryService { return MemoryService.instance } + /** + * Migrate the memory database from the old path to the new path + * If the old memory database exists, rename it to the new path + */ + public migrateMemoryDb(): void { + const oldMemoryDbPath = path.join(app.getPath('userData'), 'memories.db') + const memoryDbPath = path.join(DATA_PATH, 'Memory', 'memories.db') + + makeSureDirExists(path.dirname(memoryDbPath)) + + if (fs.existsSync(oldMemoryDbPath)) { + fs.renameSync(oldMemoryDbPath, memoryDbPath) + } + } + /** * Initialize the database connection and create tables */ @@ -80,11 +98,12 @@ export class MemoryService { } try { - const userDataPath = app.getPath('userData') - const dbPath = path.join(userDataPath, 'memories.db') + const memoryDbPath = path.join(DATA_PATH, 'Memory', 'memories.db') + + makeSureDirExists(path.dirname(memoryDbPath)) this.db = createClient({ - url: `file:${dbPath}`, + url: `file:${memoryDbPath}`, intMode: 'number' }) @@ -168,12 +187,13 @@ export class MemoryService { // Generate embedding if model is configured let embedding: number[] | null = null - const embedderApiClient = this.config?.embedderApiClient - if (embedderApiClient) { + const embeddingModel = this.config?.embeddingModel + + if (embeddingModel) { try { embedding = await this.generateEmbedding(trimmedMemory) logger.debug( - `Generated embedding for restored memory with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})` + `Generated embedding for restored memory with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})` ) } catch (error) { logger.error('Failed to generate embedding for restored memory:', error as Error) @@ -211,11 +231,11 @@ export class MemoryService { // Generate embedding if model is configured let embedding: number[] | null = null - if (this.config?.embedderApiClient) { + if (this.config?.embeddingModel) { try { embedding = await this.generateEmbedding(trimmedMemory) logger.debug( - `Generated embedding with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})` + `Generated embedding with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})` ) // Check for similar memories using vector similarity @@ -300,7 +320,7 @@ export class MemoryService { try { // If we have an embedder model configured, use vector search - if (this.config?.embedderApiClient) { + if (this.config?.embeddingModel) { try { const queryEmbedding = await this.generateEmbedding(query) return await this.hybridSearch(query, queryEmbedding, { limit, userId, agentId, filters }) @@ -497,11 +517,11 @@ export class MemoryService { // Generate new embedding if model is configured let embedding: number[] | null = null - if (this.config?.embedderApiClient) { + if (this.config?.embeddingModel) { try { embedding = await this.generateEmbedding(memory) logger.debug( - `Updated embedding with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})` + `Updated embedding with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})` ) } catch (error) { logger.error('Failed to generate embedding for update:', error as Error) @@ -710,21 +730,22 @@ export class MemoryService { * Generate embedding for text */ private async generateEmbedding(text: 
string): Promise<number[]> { - if (!this.config?.embedderApiClient) { + if (!this.config?.embeddingModel) { throw new Error('Embedder model not configured') } try { // Initialize embeddings instance if needed if (!this.embeddings) { - if (!this.config.embedderApiClient) { + if (!this.config.embeddingApiClient) { throw new Error('Embedder provider not configured') } this.embeddings = new Embeddings({ - embedApiClient: this.config.embedderApiClient, - dimensions: this.config.embedderDimensions + embedApiClient: this.config.embeddingApiClient, + dimensions: this.config.embeddingDimensions }) + await this.embeddings.init() } diff --git a/src/main/utils/__tests__/mcp.test.ts b/src/main/utils/__tests__/mcp.test.ts index b1a35f925e..706a44bc84 100644 --- a/src/main/utils/__tests__/mcp.test.ts +++ b/src/main/utils/__tests__/mcp.test.ts @@ -3,194 +3,223 @@ import { describe, expect, it } from 'vitest' import { buildFunctionCallToolName } from '../mcp' describe('buildFunctionCallToolName', () => { - describe('basic functionality', () => { - it('should combine server name and tool name', () => { + describe('basic format', () => { + it('should return format mcp__{server}__{tool}', () => { const result = buildFunctionCallToolName('github', 'search_issues') - expect(result).toContain('github') - expect(result).toContain('search') + expect(result).toBe('mcp__github__search_issues') }) - it('should sanitize names by replacing dashes with underscores', () => { - const result = buildFunctionCallToolName('my-server', 'my-tool') - // Input dashes are replaced, but the separator between server and tool is a dash - expect(result).toBe('my_serv-my_tool') - expect(result).toContain('_') - }) - - it('should handle empty server names gracefully', () => { - const result = buildFunctionCallToolName('', 'tool') - expect(result).toBeTruthy() + it('should handle simple server and tool names', () => { + expect(buildFunctionCallToolName('fetch', 'get_page')).toBe('mcp__fetch__get_page') + expect(buildFunctionCallToolName('database', 'query')).toBe('mcp__database__query') + expect(buildFunctionCallToolName('cherry_studio', 'search')).toBe('mcp__cherry_studio__search') }) }) - describe('uniqueness with serverId', () => { - it('should generate different IDs for same server name but different serverIds', () => { - const serverId1 = 'server-id-123456' - const serverId2 = 'server-id-789012' - const serverName = 'github' - const toolName = 'search_repos' - - const result1 = buildFunctionCallToolName(serverName, toolName, serverId1) - const result2 = buildFunctionCallToolName(serverName, toolName, serverId2) - - expect(result1).not.toBe(result2) - expect(result1).toContain('123456') - expect(result2).toContain('789012') + describe('valid JavaScript identifier', () => { + it('should always start with mcp__ prefix (valid JS identifier start)', () => { + const result = buildFunctionCallToolName('123server', '456tool') + expect(result).toMatch(/^mcp__/) + expect(result).toBe('mcp__123server__456tool') }) - it('should generate same ID when serverId is not provided', () => { + it('should only contain alphanumeric chars and underscores', () => { + const result = buildFunctionCallToolName('my-server', 'my-tool') + expect(result).toBe('mcp__my_server__my_tool') + expect(result).toMatch(/^[a-zA-Z][a-zA-Z0-9_]*$/) + }) + + it('should be a valid JavaScript identifier', () => { + const testCases = [ + ['github', 'create_issue'], + ['my-server', 'fetch-data'], + ['test@server', 'tool#name'], + ['server.name', 'tool.action'], + ['123abc', 'def456'] + ] +
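+ // Each pairing above mixes hostile characters; all must sanitize to a legal identifier.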
+ for (const [server, tool] of testCases) { + const result = buildFunctionCallToolName(server, tool) + // Valid JS identifiers match this pattern + expect(result).toMatch(/^[a-zA-Z_][a-zA-Z0-9_]*$/) + } + }) + }) + + describe('character sanitization', () => { + it('should replace dashes with underscores', () => { + const result = buildFunctionCallToolName('my-server', 'my-tool-name') + expect(result).toBe('mcp__my_server__my_tool_name') + }) + + it('should replace special characters with underscores', () => { + const result = buildFunctionCallToolName('test@server!', 'tool#name$') + expect(result).toBe('mcp__test_server__tool_name') + }) + + it('should replace dots with underscores', () => { + const result = buildFunctionCallToolName('server.name', 'tool.action') + expect(result).toBe('mcp__server_name__tool_action') + }) + + it('should replace spaces with underscores', () => { + const result = buildFunctionCallToolName('my server', 'my tool') + expect(result).toBe('mcp__my_server__my_tool') + }) + + it('should collapse consecutive underscores', () => { + const result = buildFunctionCallToolName('my--server', 'my___tool') + expect(result).toBe('mcp__my_server__my_tool') + expect(result).not.toMatch(/_{3,}/) + }) + + it('should trim leading and trailing underscores from parts', () => { + const result = buildFunctionCallToolName('_server_', '_tool_') + expect(result).toBe('mcp__server__tool') + }) + + it('should handle names with only special characters', () => { + const result = buildFunctionCallToolName('---', '###') + expect(result).toBe('mcp____') + }) + }) + + describe('length constraints', () => { + it('should not exceed 63 characters', () => { + const longServerName = 'a'.repeat(50) + const longToolName = 'b'.repeat(50) + const result = buildFunctionCallToolName(longServerName, longToolName) + + expect(result.length).toBeLessThanOrEqual(63) + }) + + it('should truncate server name to max 20 chars', () => { + const longServerName = 'abcdefghijklmnopqrstuvwxyz' // 26 chars + const result = buildFunctionCallToolName(longServerName, 'tool') + + expect(result).toBe('mcp__abcdefghijklmnopqrst__tool') + expect(result).toContain('abcdefghijklmnopqrst') // First 20 chars + expect(result).not.toContain('uvwxyz') // Truncated + }) + + it('should truncate tool name to max 35 chars', () => { + const longToolName = 'a'.repeat(40) + const result = buildFunctionCallToolName('server', longToolName) + + const expectedTool = 'a'.repeat(35) + expect(result).toBe(`mcp__server__${expectedTool}`) + }) + + it('should not end with underscores after truncation', () => { + // Create a name that would end with underscores after truncation + const longServerName = 'a'.repeat(20) + const longToolName = 'b'.repeat(35) + '___extra' + const result = buildFunctionCallToolName(longServerName, longToolName) + + expect(result).not.toMatch(/_+$/) + expect(result.length).toBeLessThanOrEqual(63) + }) + + it('should handle max length edge case exactly', () => { + // mcp__ (5) + server (20) + __ (2) + tool (35) = 62 chars + const server = 'a'.repeat(20) + const tool = 'b'.repeat(35) + const result = buildFunctionCallToolName(server, tool) + + expect(result.length).toBe(62) + expect(result).toBe(`mcp__${'a'.repeat(20)}__${'b'.repeat(35)}`) + }) + }) + + describe('edge cases', () => { + it('should handle empty server name', () => { + const result = buildFunctionCallToolName('', 'tool') + expect(result).toBe('mcp____tool') + }) + + it('should handle empty tool name', () => { + const result = buildFunctionCallToolName('server', 
'') + expect(result).toBe('mcp__server__') + }) + + it('should handle both empty names', () => { + const result = buildFunctionCallToolName('', '') + expect(result).toBe('mcp____') + }) + + it('should handle whitespace-only names', () => { + const result = buildFunctionCallToolName(' ', ' ') + expect(result).toBe('mcp____') + }) + + it('should trim whitespace from names', () => { + const result = buildFunctionCallToolName(' server ', ' tool ') + expect(result).toBe('mcp__server__tool') + }) + + it('should handle unicode characters', () => { + const result = buildFunctionCallToolName('服务器', '工具') + // Unicode chars are replaced with underscores, then collapsed + expect(result).toMatch(/^mcp__/) + }) + + it('should handle mixed case', () => { + const result = buildFunctionCallToolName('MyServer', 'MyTool') + expect(result).toBe('mcp__MyServer__MyTool') + }) + }) + + describe('deterministic output', () => { + it('should produce consistent results for same input', () => { const serverName = 'github' const toolName = 'search_repos' const result1 = buildFunctionCallToolName(serverName, toolName) const result2 = buildFunctionCallToolName(serverName, toolName) + const result3 = buildFunctionCallToolName(serverName, toolName) expect(result1).toBe(result2) + expect(result2).toBe(result3) }) - it('should include serverId suffix when provided', () => { - const serverId = 'abc123def456' - const result = buildFunctionCallToolName('server', 'tool', serverId) + it('should produce different results for different inputs', () => { + const result1 = buildFunctionCallToolName('server1', 'tool') + const result2 = buildFunctionCallToolName('server2', 'tool') + const result3 = buildFunctionCallToolName('server', 'tool1') + const result4 = buildFunctionCallToolName('server', 'tool2') - // Should include last 6 chars of serverId - expect(result).toContain('ef456') - }) - }) - - describe('character sanitization', () => { - it('should replace invalid characters with underscores', () => { - const result = buildFunctionCallToolName('test@server', 'tool#name') - expect(result).not.toMatch(/[@#]/) - expect(result).toMatch(/^[a-zA-Z0-9_-]+$/) - }) - - it('should ensure name starts with a letter', () => { - const result = buildFunctionCallToolName('123server', '456tool') - expect(result).toMatch(/^[a-zA-Z]/) - }) - - it('should handle consecutive underscores/dashes', () => { - const result = buildFunctionCallToolName('my--server', 'my__tool') - expect(result).not.toMatch(/[_-]{2,}/) - }) - }) - - describe('length constraints', () => { - it('should truncate names longer than 63 characters', () => { - const longServerName = 'a'.repeat(50) - const longToolName = 'b'.repeat(50) - const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456') - - expect(result.length).toBeLessThanOrEqual(63) - }) - - it('should not end with underscore or dash after truncation', () => { - const longServerName = 'a'.repeat(50) - const longToolName = 'b'.repeat(50) - const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456') - - expect(result).not.toMatch(/[_-]$/) - }) - - it('should preserve serverId suffix even with long server/tool names', () => { - const longServerName = 'a'.repeat(50) - const longToolName = 'b'.repeat(50) - const serverId = 'server-id-xyz789' - - const result = buildFunctionCallToolName(longServerName, longToolName, serverId) - - // The suffix should be preserved and not truncated - expect(result).toContain('xyz789') - expect(result.length).toBeLessThanOrEqual(63) - }) - - 
it('should ensure two long-named servers with different IDs produce different results', () => { - const longServerName = 'a'.repeat(50) - const longToolName = 'b'.repeat(50) - const serverId1 = 'server-id-abc123' - const serverId2 = 'server-id-def456' - - const result1 = buildFunctionCallToolName(longServerName, longToolName, serverId1) - const result2 = buildFunctionCallToolName(longServerName, longToolName, serverId2) - - // Both should be within limit - expect(result1.length).toBeLessThanOrEqual(63) - expect(result2.length).toBeLessThanOrEqual(63) - - // They should be different due to preserved suffix expect(result1).not.toBe(result2) - }) - }) - - describe('edge cases with serverId', () => { - it('should handle serverId with only non-alphanumeric characters', () => { - const serverId = '------' // All dashes - const result = buildFunctionCallToolName('server', 'tool', serverId) - - // Should still produce a valid unique suffix via fallback hash - expect(result).toBeTruthy() - expect(result.length).toBeLessThanOrEqual(63) - expect(result).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/) - // Should have a suffix (underscore followed by something) - expect(result).toMatch(/_[a-z0-9]+$/) - }) - - it('should produce different results for different non-alphanumeric serverIds', () => { - const serverId1 = '------' - const serverId2 = '!!!!!!' - - const result1 = buildFunctionCallToolName('server', 'tool', serverId1) - const result2 = buildFunctionCallToolName('server', 'tool', serverId2) - - // Should be different because the hash fallback produces different values - expect(result1).not.toBe(result2) - }) - - it('should handle empty string serverId differently from undefined', () => { - const resultWithEmpty = buildFunctionCallToolName('server', 'tool', '') - const resultWithUndefined = buildFunctionCallToolName('server', 'tool', undefined) - - // Empty string is falsy, so both should behave the same (no suffix) - expect(resultWithEmpty).toBe(resultWithUndefined) - }) - - it('should handle serverId with mixed alphanumeric and special chars', () => { - const serverId = 'ab@#cd' // Mixed chars, last 6 chars contain some alphanumeric - const result = buildFunctionCallToolName('server', 'tool', serverId) - - // Should extract alphanumeric chars: 'abcd' from 'ab@#cd' - expect(result).toContain('abcd') + expect(result3).not.toBe(result4) }) }) describe('real-world scenarios', () => { - it('should handle GitHub MCP server instances correctly', () => { - const serverName = 'github' - const toolName = 'search_repositories' - - const githubComId = 'server-github-com-abc123' - const gheId = 'server-ghe-internal-xyz789' - - const tool1 = buildFunctionCallToolName(serverName, toolName, githubComId) - const tool2 = buildFunctionCallToolName(serverName, toolName, gheId) - - // Should be different - expect(tool1).not.toBe(tool2) - - // Both should be valid identifiers - expect(tool1).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/) - expect(tool2).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/) - - // Both should be <= 63 chars - expect(tool1.length).toBeLessThanOrEqual(63) - expect(tool2.length).toBeLessThanOrEqual(63) + it('should handle GitHub MCP server', () => { + expect(buildFunctionCallToolName('github', 'create_issue')).toBe('mcp__github__create_issue') + expect(buildFunctionCallToolName('github', 'search_repositories')).toBe('mcp__github__search_repositories') + expect(buildFunctionCallToolName('github', 'get_pull_request')).toBe('mcp__github__get_pull_request') }) - it('should handle tool names that already include server name 
prefix', () => { - const result = buildFunctionCallToolName('github', 'github_search_repos') - expect(result).toBeTruthy() - // Should not double the server name - expect(result.split('github').length - 1).toBeLessThanOrEqual(2) + it('should handle filesystem MCP server', () => { + expect(buildFunctionCallToolName('filesystem', 'read_file')).toBe('mcp__filesystem__read_file') + expect(buildFunctionCallToolName('filesystem', 'write_file')).toBe('mcp__filesystem__write_file') + expect(buildFunctionCallToolName('filesystem', 'list_directory')).toBe('mcp__filesystem__list_directory') + }) + + it('should handle hyphenated server names (common in npm packages)', () => { + expect(buildFunctionCallToolName('cherry-fetch', 'get_page')).toBe('mcp__cherry_fetch__get_page') + expect(buildFunctionCallToolName('mcp-server-github', 'search')).toBe('mcp__mcp_server_github__search') + }) + + it('should handle scoped npm package style names', () => { + const result = buildFunctionCallToolName('@anthropic/mcp-server', 'chat') + expect(result).toBe('mcp__anthropic_mcp_server__chat') + }) + + it('should handle tools with long descriptive names', () => { + const result = buildFunctionCallToolName('github', 'search_repositories_by_language_and_stars') + expect(result.length).toBeLessThanOrEqual(63) + expect(result).toMatch(/^mcp__github__search_repositories_by_lan/) + }) }) }) diff --git a/src/main/utils/__tests__/process.test.ts b/src/main/utils/__tests__/process.test.ts index 0485ec5fad..faeb73994c 100644 --- a/src/main/utils/__tests__/process.test.ts +++ b/src/main/utils/__tests__/process.test.ts @@ -1,9 +1,29 @@ -import { execFileSync } from 'child_process' +import { configManager } from '@main/services/ConfigManager' +import { execFileSync, spawn } from 'child_process' +import { EventEmitter } from 'events' import fs from 'fs' import path from 'path' -import { beforeEach, describe, expect, it, vi } from 'vitest' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' -import { findExecutable, findGitBash, validateGitBashPath } from '../process' +import { + autoDiscoverGitBash, + findCommandInShellEnv, + findExecutable, + findGitBash, + validateGitBashPath +} from '../process' + +// Mock configManager +vi.mock('@main/services/ConfigManager', () => ({ + ConfigKeys: { + GitBashPath: 'gitBashPath', + GitBashPathSource: 'gitBashPathSource' + }, + configManager: { + get: vi.fn(), + set: vi.fn() + } +})) // Mock dependencies vi.mock('child_process') @@ -695,4 +715,530 @@ describe.skipIf(process.platform !== 'win32')('process utilities', () => { }) }) }) + + describe('autoDiscoverGitBash', () => { + const originalEnvVar = process.env.CLAUDE_CODE_GIT_BASH_PATH + + beforeEach(() => { + vi.mocked(configManager.get).mockReset() + vi.mocked(configManager.set).mockReset() + delete process.env.CLAUDE_CODE_GIT_BASH_PATH + }) + + afterEach(() => { + // Restore original environment variable + if (originalEnvVar !== undefined) { + process.env.CLAUDE_CODE_GIT_BASH_PATH = originalEnvVar + } else { + delete process.env.CLAUDE_CODE_GIT_BASH_PATH + } + }) + + /** + * Helper to mock fs.existsSync with a set of valid paths + */ + const mockExistingPaths = (...validPaths: string[]) => { + vi.mocked(fs.existsSync).mockImplementation((p) => validPaths.includes(p as string)) + } + + describe('with no existing config path', () => { + it('should discover and persist Git Bash path when not configured', () => { + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined)
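+ // Nothing configured yet; the standard install locations below are what discovery should find.
+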
process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(bashPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath) + }) + + it('should return null and not persist when Git Bash is not found', () => { + vi.mocked(configManager.get).mockReturnValue(undefined) + vi.mocked(fs.existsSync).mockReturnValue(false) + vi.mocked(execFileSync).mockImplementation(() => { + throw new Error('Not found') + }) + + const result = autoDiscoverGitBash() + + expect(result).toBeNull() + expect(configManager.set).not.toHaveBeenCalled() + }) + }) + + describe('environment variable precedence', () => { + it('should use env var over valid config path', () => { + const envPath = 'C:\\EnvGit\\bin\\bash.exe' + const configPath = 'C:\\ConfigGit\\bin\\bash.exe' + + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + vi.mocked(configManager.get).mockReturnValue(configPath) + mockExistingPaths(envPath, configPath) + + const result = autoDiscoverGitBash() + + // Env var should take precedence + expect(result).toBe(envPath) + // Should not persist env var path (it's a runtime override) + expect(configManager.set).not.toHaveBeenCalled() + }) + + it('should fall back to config path when env var is invalid', () => { + const envPath = 'C:\\Invalid\\bash.exe' + const configPath = 'C:\\ConfigGit\\bin\\bash.exe' + + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + vi.mocked(configManager.get).mockReturnValue(configPath) + // Env path is invalid (doesn't exist), only config path exists + mockExistingPaths(configPath) + + const result = autoDiscoverGitBash() + + // Should fall back to config path + expect(result).toBe(configPath) + expect(configManager.set).not.toHaveBeenCalled() + }) + + it('should fall back to auto-discovery when both env var and config are invalid', () => { + const envPath = 'C:\\InvalidEnv\\bash.exe' + const configPath = 'C:\\InvalidConfig\\bash.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + process.env.ProgramFiles = 'C:\\Program Files' + vi.mocked(configManager.get).mockReturnValue(configPath) + // Both env and config paths are invalid, only standard Git exists + mockExistingPaths(gitPath, discoveredPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(discoveredPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath) + }) + }) + + describe('with valid existing config path', () => { + it('should validate and return existing path without re-discovering', () => { + const existingPath = 'C:\\CustomGit\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + mockExistingPaths(existingPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(existingPath) + // Should not call findGitBash or persist again + expect(configManager.set).not.toHaveBeenCalled() + // Should not call execFileSync (which findGitBash would use for discovery) + expect(execFileSync).not.toHaveBeenCalled() + }) + + it('should not override existing valid config with auto-discovery', () => { + const existingPath = 'C:\\CustomGit\\bin\\bash.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + mockExistingPaths(existingPath, discoveredPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(existingPath) + 
expect(configManager.set).not.toHaveBeenCalled() + }) + }) + + describe('with invalid existing config path', () => { + it('should attempt auto-discovery when existing path does not exist', () => { + const existingPath = 'C:\\NonExistent\\bin\\bash.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + process.env.ProgramFiles = 'C:\\Program Files' + // Invalid path doesn't exist, but Git is installed at standard location + mockExistingPaths(gitPath, discoveredPath) + + const result = autoDiscoverGitBash() + + // Should discover and return the new path + expect(result).toBe(discoveredPath) + // Should persist the discovered path (overwrites invalid) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath) + }) + + it('should attempt auto-discovery when existing path is not bash.exe', () => { + const existingPath = 'C:\\CustomGit\\bin\\git.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + process.env.ProgramFiles = 'C:\\Program Files' + // Invalid path exists but is not bash.exe (validation will fail) + // Git is installed at standard location + mockExistingPaths(existingPath, gitPath, discoveredPath) + + const result = autoDiscoverGitBash() + + // Should discover and return the new path + expect(result).toBe(discoveredPath) + // Should persist the discovered path (overwrites invalid) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath) + }) + + it('should return null when existing path is invalid and discovery fails', () => { + const existingPath = 'C:\\NonExistent\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + vi.mocked(fs.existsSync).mockReturnValue(false) + vi.mocked(execFileSync).mockImplementation(() => { + throw new Error('Not found') + }) + + const result = autoDiscoverGitBash() + + // Both validation and discovery failed + expect(result).toBeNull() + // Should not persist when discovery fails + expect(configManager.set).not.toHaveBeenCalled() + }) + }) + + describe('config persistence verification', () => { + it('should persist discovered path with correct config key', () => { + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + autoDiscoverGitBash() + + // Verify the exact call to configManager.set + expect(configManager.set).toHaveBeenCalledTimes(2) + expect(configManager.set).toHaveBeenNthCalledWith(1, 'gitBashPath', bashPath) + expect(configManager.set).toHaveBeenNthCalledWith(2, 'gitBashPathSource', 'auto') + }) + + it('should persist on each discovery when config remains undefined', () => { + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + autoDiscoverGitBash() + autoDiscoverGitBash() + + // Each call discovers and persists since config remains undefined (mocked) + expect(configManager.set).toHaveBeenCalledTimes(4) + expect(configManager.set).toHaveBeenNthCalledWith(1, 'gitBashPath', bashPath) + 
expect(configManager.set).toHaveBeenNthCalledWith(2, 'gitBashPathSource', 'auto') + expect(configManager.set).toHaveBeenNthCalledWith(3, 'gitBashPath', bashPath) + expect(configManager.set).toHaveBeenNthCalledWith(4, 'gitBashPathSource', 'auto') + }) + }) + + describe('real-world scenarios', () => { + it('should discover and persist standard Git for Windows installation', () => { + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(bashPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath) + }) + + it('should discover portable Git via where.exe and persist', () => { + const gitPath = 'D:\\PortableApps\\Git\\bin\\git.exe' + const bashPath = 'D:\\PortableApps\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + + vi.mocked(fs.existsSync).mockImplementation((p) => { + const pathStr = p?.toString() || '' + // Common git paths don't exist + if (pathStr.includes('Program Files\\Git\\cmd\\git.exe')) return false + if (pathStr.includes('Program Files (x86)\\Git\\cmd\\git.exe')) return false + // Portable bash path exists + if (pathStr === bashPath) return true + return false + }) + + vi.mocked(execFileSync).mockReturnValue(gitPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(bashPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath) + }) + + it('should respect user-configured path over auto-discovery', () => { + const userConfiguredPath = 'D:\\MyGit\\bin\\bash.exe' + const systemPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(userConfiguredPath) + mockExistingPaths(userConfiguredPath, systemPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(userConfiguredPath) + expect(configManager.set).not.toHaveBeenCalled() + // Verify findGitBash was not called for discovery + expect(execFileSync).not.toHaveBeenCalled() + }) + }) + }) +}) + +/** + * Helper to create a mock child process for spawn + */ +function createMockChildProcess() { + const mockChild = new EventEmitter() as EventEmitter & { + stdout: EventEmitter + stderr: EventEmitter + kill: ReturnType<typeof vi.fn> + } + mockChild.stdout = new EventEmitter() + mockChild.stderr = new EventEmitter() + mockChild.kill = vi.fn() + return mockChild +} + +describe('findCommandInShellEnv', () => { + beforeEach(() => { + vi.clearAllMocks() + // Reset path.isAbsolute to real implementation for these tests + vi.mocked(path.isAbsolute).mockImplementation((p) => p.startsWith('/') || /^[A-Z]:/i.test(p)) + }) + + describe('command name validation', () => { + it('should reject empty command name', async () => { + const result = await findCommandInShellEnv('', {}) + expect(result).toBeNull() + expect(spawn).not.toHaveBeenCalled() + }) + + it('should reject command names with shell metacharacters', async () => { + const maliciousCommands = [ + 'npx; rm -rf /', + 'npx && malicious', + 'npx | cat /etc/passwd', + 'npx`whoami`', + '$(whoami)', + 'npx\nmalicious' + ] + + for (const cmd of maliciousCommands) { + const result = await findCommandInShellEnv(cmd, {}) + expect(result).toBeNull() + expect(spawn).not.toHaveBeenCalled() + } + }) + + it('should reject command names starting with hyphen', async () => { + const result = await
findCommandInShellEnv('-npx', {}) + expect(result).toBeNull() + expect(spawn).not.toHaveBeenCalled() + }) + + it('should reject path traversal attempts', async () => { + const pathTraversalCommands = ['../npx', '../../malicious', 'foo/bar', 'foo\\bar'] + + for (const cmd of pathTraversalCommands) { + const result = await findCommandInShellEnv(cmd, {}) + expect(result).toBeNull() + expect(spawn).not.toHaveBeenCalled() + } + }) + + it('should reject command names exceeding max length', async () => { + const longCommand = 'a'.repeat(129) + const result = await findCommandInShellEnv(longCommand, {}) + expect(result).toBeNull() + expect(spawn).not.toHaveBeenCalled() + }) + + it('should accept valid command names', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + // Don't await - just start the call + const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' }) + + // Simulate command not found + mockChild.emit('close', 1) + + const result = await resultPromise + expect(result).toBeNull() + expect(spawn).toHaveBeenCalled() + }) + + it('should accept command names with underscores and hyphens', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('my_command-name', { PATH: '/usr/bin' }) + mockChild.emit('close', 1) + + await resultPromise + expect(spawn).toHaveBeenCalled() + }) + + it('should accept command names at max length (128 chars)', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const maxLengthCommand = 'a'.repeat(128) + const resultPromise = findCommandInShellEnv(maxLengthCommand, { PATH: '/usr/bin' }) + mockChild.emit('close', 1) + + await resultPromise + expect(spawn).toHaveBeenCalled() + }) + }) + + describe.skipIf(process.platform === 'win32')('Unix/macOS behavior', () => { + it('should find command and return absolute path', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' }) + + // Simulate successful command -v output + mockChild.stdout.emit('data', '/usr/local/bin/npx\n') + mockChild.emit('close', 0) + + const result = await resultPromise + expect(result).toBe('/usr/local/bin/npx') + expect(spawn).toHaveBeenCalledWith('/bin/sh', ['-c', 'command -v "$1"', '--', 'npx'], expect.any(Object)) + }) + + it('should return null for non-absolute paths (aliases/builtins)', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('cd', { PATH: '/usr/bin' }) + + // Simulate builtin output (just command name) + mockChild.stdout.emit('data', 'cd\n') + mockChild.emit('close', 0) + + const result = await resultPromise + expect(result).toBeNull() + }) + + it('should return null when command not found', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('nonexistent', { PATH: '/usr/bin' }) + + // Simulate command not found (exit code 1) + mockChild.emit('close', 1) + + const result = await resultPromise + expect(result).toBeNull() + }) + + it('should handle spawn errors gracefully', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) 
+ + const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' }) + + // Simulate spawn error + mockChild.emit('error', new Error('spawn failed')) + + const result = await resultPromise + expect(result).toBeNull() + }) + + it('should handle timeout gracefully', async () => { + vi.useFakeTimers() + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' }) + + // Fast-forward past timeout (5000ms) + vi.advanceTimersByTime(6000) + + const result = await resultPromise + expect(result).toBeNull() + expect(mockChild.kill).toHaveBeenCalledWith('SIGKILL') + + vi.useRealTimers() + }) + }) + + describe.skipIf(process.platform !== 'win32')('Windows behavior', () => { + it('should find .exe files via where command', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' }) + + // Simulate where output + mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.exe\r\n') + mockChild.emit('close', 0) + + const result = await resultPromise + expect(result).toBe('C:\\Program Files\\nodejs\\npx.exe') + expect(spawn).toHaveBeenCalledWith('where', ['npx'], expect.any(Object)) + }) + + it('should reject .cmd files on Windows', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' }) + + // Simulate where output with only .cmd file + mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.cmd\r\n') + mockChild.emit('close', 0) + + const result = await resultPromise + expect(result).toBeNull() + }) + + it('should prefer .exe over .cmd when both exist', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' }) + + // Simulate where output with both .cmd and .exe + mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.cmd\r\nC:\\Program Files\\nodejs\\npx.exe\r\n') + mockChild.emit('close', 0) + + const result = await resultPromise + expect(result).toBe('C:\\Program Files\\nodejs\\npx.exe') + }) + + it('should handle spawn errors gracefully', async () => { + const mockChild = createMockChildProcess() + vi.mocked(spawn).mockReturnValue(mockChild as never) + + const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' }) + + // Simulate spawn error + mockChild.emit('error', new Error('spawn failed')) + + const result = await resultPromise + expect(result).toBeNull() + }) + }) }) diff --git a/src/main/utils/language.ts b/src/main/utils/language.ts index a0f2c8dc9b..a5dcb1bbf3 100644 --- a/src/main/utils/language.ts +++ b/src/main/utils/language.ts @@ -13,6 +13,7 @@ import esES from '../../renderer/src/i18n/translate/es-es.json' import frFR from '../../renderer/src/i18n/translate/fr-fr.json' import JaJP from '../../renderer/src/i18n/translate/ja-jp.json' import ptPT from '../../renderer/src/i18n/translate/pt-pt.json' +import roRO from '../../renderer/src/i18n/translate/ro-ro.json' import RuRu from '../../renderer/src/i18n/translate/ru-ru.json' export const locales = Object.fromEntries( @@ -26,7 +27,8 @@ export const locales = Object.fromEntries( ['el-GR', elGR], ['es-ES', esES], ['fr-FR', frFR], - ['pt-PT', ptPT] + ['pt-PT', ptPT], + ['ro-RO', roRO] 
].map(([locale, translation]) => [locale, { translation }]) ) diff --git a/src/main/utils/mcp.ts b/src/main/utils/mcp.ts index cfa700f2e6..34eb0e63e7 100644 --- a/src/main/utils/mcp.ts +++ b/src/main/utils/mcp.ts @@ -1,56 +1,28 @@ -export function buildFunctionCallToolName(serverName: string, toolName: string, serverId?: string) { - const sanitizedServer = serverName.trim().replace(/-/g, '_') - const sanitizedTool = toolName.trim().replace(/-/g, '_') +/** + * Builds a valid JavaScript function name for MCP tool calls. + * Format: mcp__{server_name}__{tool_name} + * + * @param serverName - The MCP server name + * @param toolName - The tool name from the server + * @returns A valid JS identifier in format mcp__{server}__{tool}, max 63 chars + */ +export function buildFunctionCallToolName(serverName: string, toolName: string): string { + // Sanitize to valid JS identifier chars (alphanumeric + underscore only) + const sanitize = (str: string): string => + str + .trim() + .replace(/[^a-zA-Z0-9]/g, '_') // Replace all non-alphanumeric with underscore + .replace(/_{2,}/g, '_') // Collapse multiple underscores + .replace(/^_+|_+$/g, '') // Trim leading/trailing underscores - // Calculate suffix first to reserve space for it - // Suffix format: "_" + 6 alphanumeric chars = 7 chars total - let serverIdSuffix = '' - if (serverId) { - // Take the last 6 characters of the serverId for brevity - serverIdSuffix = serverId.slice(-6).replace(/[^a-zA-Z0-9]/g, '') + const server = sanitize(serverName).slice(0, 20) // Keep server name short + const tool = sanitize(toolName).slice(0, 35) // More room for tool name - // Fallback: if suffix becomes empty (all non-alphanumeric chars), use a simple hash - if (!serverIdSuffix) { - const hash = serverId.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0) - serverIdSuffix = hash.toString(36).slice(-6) || 'x' - } - } + let name = `mcp__${server}__${tool}` - // Reserve space for suffix when calculating max base name length - const SUFFIX_LENGTH = serverIdSuffix ? 
serverIdSuffix.length + 1 : 0 // +1 for underscore - const MAX_BASE_LENGTH = 63 - SUFFIX_LENGTH - - // Combine server name and tool name - let name = sanitizedTool - if (!sanitizedTool.includes(sanitizedServer.slice(0, 7))) { - name = `${sanitizedServer.slice(0, 7) || ''}-${sanitizedTool || ''}` - } - - // Replace invalid characters with underscores or dashes - // Keep a-z, A-Z, 0-9, underscores and dashes - name = name.replace(/[^a-zA-Z0-9_-]/g, '_') - - // Ensure name starts with a letter or underscore (for valid JavaScript identifier) - if (!/^[a-zA-Z]/.test(name)) { - name = `tool-${name}` - } - - // Remove consecutive underscores/dashes (optional improvement) - name = name.replace(/[_-]{2,}/g, '_') - - // Truncate base name BEFORE adding suffix to ensure suffix is never cut off - if (name.length > MAX_BASE_LENGTH) { - name = name.slice(0, MAX_BASE_LENGTH) - } - - // Handle edge case: ensure we still have a valid name if truncation left invalid chars at edges - if (name.endsWith('_') || name.endsWith('-')) { - name = name.slice(0, -1) - } - - // Now append the suffix - it will always fit within 63 chars - if (serverIdSuffix) { - name = `${name}_${serverIdSuffix}` + // Ensure max 63 chars and clean trailing underscores + if (name.length > 63) { + name = name.slice(0, 63).replace(/_+$/, '') } return name diff --git a/src/main/utils/process.ts b/src/main/utils/process.ts index 7175af7e75..7c928949fe 100644 --- a/src/main/utils/process.ts +++ b/src/main/utils/process.ts @@ -1,4 +1,5 @@ import { loggerService } from '@logger' +import type { GitBashPathInfo, GitBashPathSource } from '@shared/config/constant' import { HOME_CHERRY_DIR } from '@shared/config/constant' import { execFileSync, spawn } from 'child_process' import fs from 'fs' @@ -6,6 +7,7 @@ import os from 'os' import path from 'path' import { isWin } from '../constant' +import { ConfigKeys, configManager } from '../services/ConfigManager' import { getResourcePath } from '.' 
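+// NOTE: the lookup helpers added below resolve executables through the user's login
+// shell environment, with validation and timeouts to guard against injection and hung children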
 const logger = loggerService.withContext('Utils:Process')
@@ -59,7 +61,146 @@ export async function getBinaryPath(name?: string): Promise<string> {
 export async function isBinaryExists(name: string): Promise<boolean> {
   const cmd = await getBinaryPath(name)
-  return await fs.existsSync(cmd)
+  return fs.existsSync(cmd)
+}
+
+// Timeout for command lookup operations (in milliseconds)
+const COMMAND_LOOKUP_TIMEOUT_MS = 5000
+
+// Regex to validate command names - must start with alphanumeric or underscore, max 128 chars
+const VALID_COMMAND_NAME_REGEX = /^[a-zA-Z0-9_][a-zA-Z0-9_-]{0,127}$/
+
+// Maximum output size to prevent buffer overflow (10KB)
+const MAX_OUTPUT_SIZE = 10240
+
+/**
+ * Check if a command is available in the user's login shell environment
+ * @param command - Command name to check (e.g., 'npx', 'uvx')
+ * @param loginShellEnv - The login shell environment from getLoginShellEnvironment()
+ * @returns Full path to the command if found, null otherwise
+ */
+export async function findCommandInShellEnv(
+  command: string,
+  loginShellEnv: Record<string, string>
+): Promise<string | null> {
+  // Validate command name to prevent command injection
+  if (!VALID_COMMAND_NAME_REGEX.test(command)) {
+    logger.warn(`Invalid command name '${command}' - must only contain alphanumeric characters, underscore, or hyphen`)
+    return null
+  }
+
+  return new Promise<string | null>((resolve) => {
+    let resolved = false
+
+    const safeResolve = (value: string | null) => {
+      if (resolved) return
+      resolved = true
+      resolve(value)
+    }
+
+    if (isWin) {
+      // On Windows, use 'where' command
+      const child = spawn('where', [command], {
+        env: loginShellEnv,
+        stdio: ['ignore', 'pipe', 'pipe']
+      })
+
+      let output = ''
+      const timeoutId = setTimeout(() => {
+        if (resolved) return
+        child.kill('SIGKILL')
+        logger.debug(`Timeout checking command '${command}' on Windows`)
+        safeResolve(null)
+      }, COMMAND_LOOKUP_TIMEOUT_MS)
+
+      child.stdout.on('data', (data) => {
+        if (output.length < MAX_OUTPUT_SIZE) {
+          output += data.toString()
+        }
+      })
+
+      child.on('close', (code) => {
+        clearTimeout(timeoutId)
+        if (resolved) return
+
+        if (code === 0 && output.trim()) {
+          const paths = output.trim().split(/\r?\n/)
+          // Only accept .exe files on Windows - .cmd/.bat files cannot be executed
+          // with spawn({ shell: false }) which is used by MCP SDK's StdioClientTransport
+          const exePath = paths.find((p) => p.toLowerCase().endsWith('.exe'))
+          if (exePath) {
+            logger.debug(`Found command '${command}' at: ${exePath}`)
+            safeResolve(exePath)
+          } else {
+            logger.debug(`Command '${command}' found but not as .exe (${paths[0]}), treating as not found`)
+            safeResolve(null)
+          }
+        } else {
+          logger.debug(`Command '${command}' not found in shell environment`)
+          safeResolve(null)
+        }
+      })
+
+      child.on('error', (error) => {
+        clearTimeout(timeoutId)
+        if (resolved) return
+        logger.warn(`Error checking command '${command}':`, { error, platform: 'windows' })
+        safeResolve(null)
+      })
+    } else {
+      // Unix/Linux/macOS: use 'command -v' which is POSIX standard
+      // Use /bin/sh for reliability - it's POSIX compliant and always available
+      // This avoids issues with user's custom shell (csh, fish, etc.)
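+      // For example, looking up 'npx' effectively runs: /bin/sh -c 'command -v "$1"' -- npx,
+      // binding the name to $1 instead of splicing it into the script text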
+ // SECURITY: Use positional parameter $1 to prevent command injection + const child = spawn('/bin/sh', ['-c', 'command -v "$1"', '--', command], { + env: loginShellEnv, + stdio: ['ignore', 'pipe', 'pipe'] + }) + + let output = '' + const timeoutId = setTimeout(() => { + if (resolved) return + child.kill('SIGKILL') + logger.debug(`Timeout checking command '${command}'`) + safeResolve(null) + }, COMMAND_LOOKUP_TIMEOUT_MS) + + child.stdout.on('data', (data) => { + if (output.length < MAX_OUTPUT_SIZE) { + output += data.toString() + } + }) + + child.on('close', (code) => { + clearTimeout(timeoutId) + if (resolved) return + + if (code === 0 && output.trim()) { + const commandPath = output.trim().split('\n')[0] + + // Validate the output is an absolute path (not an alias, function, or builtin) + // command -v can return just the command name for aliases/builtins + if (path.isAbsolute(commandPath)) { + logger.debug(`Found command '${command}' at: ${commandPath}`) + safeResolve(commandPath) + } else { + logger.debug(`Command '${command}' resolved to non-path '${commandPath}', treating as not found`) + safeResolve(null) + } + } else { + logger.debug(`Command '${command}' not found in shell environment`) + safeResolve(null) + } + }) + + child.on('error', (error) => { + clearTimeout(timeoutId) + if (resolved) return + logger.warn(`Error checking command '${command}':`, { error, platform: 'unix' }) + safeResolve(null) + }) + } + }) } /** @@ -225,3 +366,77 @@ export function validateGitBashPath(customPath?: string | null): string | null { logger.debug('Validated custom Git Bash path', { path: resolved }) return resolved } + +/** + * Auto-discover and persist Git Bash path if not already configured + * Only called when Git Bash is actually needed + * + * Precedence order: + * 1. CLAUDE_CODE_GIT_BASH_PATH environment variable (highest - runtime override) + * 2. Configured path from settings (manual or auto) + * 3. Auto-discovery via findGitBash (only if no valid config exists) + */ +export function autoDiscoverGitBash(): string | null { + if (!isWin) { + return null + } + + // 1. Check environment variable override first (highest priority) + const envOverride = process.env.CLAUDE_CODE_GIT_BASH_PATH + if (envOverride) { + const validated = validateGitBashPath(envOverride) + if (validated) { + logger.debug('Using CLAUDE_CODE_GIT_BASH_PATH override', { path: validated }) + return validated + } + logger.warn('CLAUDE_CODE_GIT_BASH_PATH provided but path is invalid', { path: envOverride }) + } + + // 2. Check if a path is already configured + const existingPath = configManager.get(ConfigKeys.GitBashPath) + const existingSource = configManager.get(ConfigKeys.GitBashPathSource) + + if (existingPath) { + const validated = validateGitBashPath(existingPath) + if (validated) { + return validated + } + // Existing path is invalid, try to auto-discover + logger.warn('Existing Git Bash path is invalid, attempting auto-discovery', { + path: existingPath, + source: existingSource + }) + } + + // 3. 
Try to find Git Bash via auto-discovery
+  const discoveredPath = findGitBash()
+  if (discoveredPath) {
+    // Persist the discovered path with 'auto' source
+    configManager.set(ConfigKeys.GitBashPath, discoveredPath)
+    configManager.set(ConfigKeys.GitBashPathSource, 'auto')
+    logger.info('Auto-discovered Git Bash path', { path: discoveredPath })
+  }
+
+  return discoveredPath
+}
+
+/**
+ * Get Git Bash path info including source
+ * If no path is configured, triggers auto-discovery first
+ */
+export function getGitBashPathInfo(): GitBashPathInfo {
+  if (!isWin) {
+    return { path: null, source: null }
+  }
+
+  let path = configManager.get(ConfigKeys.GitBashPath) ?? null
+  let source = configManager.get(ConfigKeys.GitBashPathSource) ?? null
+
+  // If no path configured, trigger auto-discovery (handles upgrade from old versions)
+  if (!path) {
+    path = autoDiscoverGitBash()
+    source = path ? 'auto' : null
+  }
+
+  return { path, source }
+}
diff --git a/src/main/utils/system.ts b/src/main/utils/system.ts
new file mode 100644
index 0000000000..2cd9e4bf22
--- /dev/null
+++ b/src/main/utils/system.ts
@@ -0,0 +1,19 @@
+import os from 'node:os'
+
+import { isMac, isWin } from '@main/constant'
+
+export const getDeviceType = () => (isMac ? 'mac' : isWin ? 'windows' : 'linux')
+
+export const getHostname = () => os.hostname()
+
+export const getCpuName = () => {
+  try {
+    const cpus = os.cpus()
+    if (!cpus || cpus.length === 0 || !cpus[0].model) {
+      return 'Unknown CPU'
+    }
+    return cpus[0].model
+  } catch {
+    return 'Unknown CPU'
+  }
+}
diff --git a/src/preload/index.ts b/src/preload/index.ts
index 0d50ea5f62..a6f68fe07b 100644
--- a/src/preload/index.ts
+++ b/src/preload/index.ts
@@ -2,11 +2,19 @@ import type { PermissionUpdate } from '@anthropic-ai/claude-agent-sdk'
 import { electronAPI } from '@electron-toolkit/preload'
 import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
 import type { SpanContext } from '@opentelemetry/api'
-import type { TerminalConfig } from '@shared/config/constant'
+import type { GitBashPathInfo, TerminalConfig } from '@shared/config/constant'
 import type { LogLevel, LogSourceWithContext } from '@shared/config/logger'
-import type { FileChangeEvent, WebviewKeyEvent } from '@shared/config/types'
+import type {
+  FileChangeEvent,
+  LanClientEvent,
+  LanFileCompleteMessage,
+  LanHandshakeAckMessage,
+  LocalTransferConnectPayload,
+  LocalTransferState,
+  WebviewKeyEvent
+} from '@shared/config/types'
 import type { MCPServerLogEntry } from '@shared/config/types'
-import type { CacheSyncMessage } from '@shared/data/cache/cacheTypes'
+import type { CacheEntry, CacheSyncMessage } from '@shared/data/cache/cacheTypes'
 import type {
   PreferenceDefaultScopeType,
   PreferenceKeyType,
@@ -132,6 +140,7 @@ const api = {
     getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName),
     checkGitBash: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.System_CheckGitBash),
     getGitBashPath: (): Promise<string | null> => ipcRenderer.invoke(IpcChannel.System_GetGitBashPath),
+    getGitBashPathInfo: (): Promise<GitBashPathInfo> => ipcRenderer.invoke(IpcChannel.System_GetGitBashPathInfo),
     setGitBashPath: (newPath: string | null): Promise =>
       ipcRenderer.invoke(IpcChannel.System_SetGitBashPath, newPath)
   },
@@ -177,7 +186,11 @@ const api = {
     listS3Files: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_ListS3Files, s3Config),
     deleteS3File: (fileName: string, s3Config: S3Config) =>
       ipcRenderer.invoke(IpcChannel.Backup_DeleteS3File, fileName, s3Config),
-    checkS3Connection: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_CheckS3Connection, s3Config)
+    checkS3Connection: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_CheckS3Connection, s3Config),
+    createLanTransferBackup: (data: string): Promise<string> =>
+      ipcRenderer.invoke(IpcChannel.Backup_CreateLanTransferBackup, data),
+    deleteTempBackup: (filePath: string): Promise<void> =>
+      ipcRenderer.invoke(IpcChannel.Backup_DeleteTempBackup, filePath)
   },
   file: {
     select: (options?: OpenDialogOptions): Promise =>
@@ -303,7 +316,8 @@ const api = {
     deleteUser: (userId: string) => ipcRenderer.invoke(IpcChannel.Memory_DeleteUser, userId),
     deleteAllMemoriesForUser: (userId: string) =>
       ipcRenderer.invoke(IpcChannel.Memory_DeleteAllMemoriesForUser, userId),
-    getUsersList: () => ipcRenderer.invoke(IpcChannel.Memory_GetUsersList)
+    getUsersList: () => ipcRenderer.invoke(IpcChannel.Memory_GetUsersList),
+    migrateMemoryDb: () => ipcRenderer.invoke(IpcChannel.Memory_MigrateMemoryDb)
   },
   window: {
     setMinimumSize: (width: number, height: number) =>
@@ -332,6 +346,7 @@ const api = {
       ipcRenderer.invoke(IpcChannel.VertexAI_ClearAuthCache, projectId, clientEmail)
   },
   ovms: {
+    isSupported: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.Ovms_IsSupported),
     addModel: (modelName: string, modelId: string, modelSource: string, task: string) =>
       ipcRenderer.invoke(IpcChannel.Ovms_AddModel, modelName, modelId, modelSource, task),
     stopAddModel: () => ipcRenderer.invoke(IpcChannel.Ovms_StopAddModel),
@@ -434,7 +449,7 @@ const api = {
       ipcRenderer.invoke(IpcChannel.Nutstore_GetDirectoryContents, token, path)
   },
   searchService: {
-    openSearchWindow: (uid: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_Open, uid),
+    openSearchWindow: (uid: string, show?: boolean) => ipcRenderer.invoke(IpcChannel.SearchWindow_Open, uid, show),
     closeSearchWindow: (uid: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_Close, uid),
     openUrlInSearchWindow: (uid: string, url: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_OpenUrl, uid, url)
   },
@@ -566,7 +581,10 @@ const api = {
       const listener = (_: any, message: CacheSyncMessage) => callback(message)
       ipcRenderer.on(IpcChannel.Cache_Sync, listener)
       return () => ipcRenderer.off(IpcChannel.Cache_Sync, listener)
-    }
+    },
+
+    // Get all shared cache entries from Main for initialization sync
+    getAllShared: (): Promise<Record<string, CacheEntry>> => ipcRenderer.invoke(IpcChannel.Cache_GetAllShared)
   },

   // PreferenceService related APIs
@@ -591,8 +609,6 @@ const api = {
   // Data API related APIs
   dataApi: {
     request: (req: any) => ipcRenderer.invoke(IpcChannel.DataApi_Request, req),
-    batch: (req: any) => ipcRenderer.invoke(IpcChannel.DataApi_Batch, req),
-    transaction: (req: any) => ipcRenderer.invoke(IpcChannel.DataApi_Transaction, req),
     subscribe: (path: string, callback: (data: any, event: string) => void) => {
       const channel = `${IpcChannel.DataApi_Stream}:${path}`
       const listener = (_: any, data: any, event: string) => callback(data, event)
@@ -630,12 +646,32 @@ const api = {
     writeContent: (options: WritePluginContentOptions): Promise> =>
       ipcRenderer.invoke(IpcChannel.ClaudeCodePlugin_WriteContent, options)
   },
-  webSocket: {
-    start: () => ipcRenderer.invoke(IpcChannel.WebSocket_Start),
-    stop: () => ipcRenderer.invoke(IpcChannel.WebSocket_Stop),
-    status: () => ipcRenderer.invoke(IpcChannel.WebSocket_Status),
-    sendFile: (filePath: string) => ipcRenderer.invoke(IpcChannel.WebSocket_SendFile, filePath),
-    getAllCandidates: () => ipcRenderer.invoke(IpcChannel.WebSocket_GetAllCandidates)
+  localTransfer: {
+    getState: (): Promise<LocalTransferState> =>
+      ipcRenderer.invoke(IpcChannel.LocalTransfer_ListServices),
+    startScan: (): Promise<void> => ipcRenderer.invoke(IpcChannel.LocalTransfer_StartScan),
+    stopScan: (): Promise<void> => ipcRenderer.invoke(IpcChannel.LocalTransfer_StopScan),
+    connect: (payload: LocalTransferConnectPayload): Promise<LanHandshakeAckMessage> =>
+      ipcRenderer.invoke(IpcChannel.LocalTransfer_Connect, payload),
+    disconnect: (): Promise<void> => ipcRenderer.invoke(IpcChannel.LocalTransfer_Disconnect),
+    onServicesUpdated: (callback: (state: LocalTransferState) => void): (() => void) => {
+      const channel = IpcChannel.LocalTransfer_ServicesUpdated
+      const listener = (_: Electron.IpcRendererEvent, state: LocalTransferState) => callback(state)
+      ipcRenderer.on(channel, listener)
+      return () => {
+        ipcRenderer.removeListener(channel, listener)
+      }
+    },
+    onClientEvent: (callback: (event: LanClientEvent) => void): (() => void) => {
+      const channel = IpcChannel.LocalTransfer_ClientEvent
+      const listener = (_: Electron.IpcRendererEvent, event: LanClientEvent) => callback(event)
+      ipcRenderer.on(channel, listener)
+      return () => {
+        ipcRenderer.removeListener(channel, listener)
+      }
+    },
+    sendFile: (filePath: string): Promise<LanFileCompleteMessage> =>
+      ipcRenderer.invoke(IpcChannel.LocalTransfer_SendFile, { filePath }),
+    cancelTransfer: (): Promise<void> => ipcRenderer.invoke(IpcChannel.LocalTransfer_CancelTransfer)
   }
 }
diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
index 9f9542a92a..00e0e4e9ce 100644
--- a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
+++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
@@ -119,6 +119,21 @@ export class AiSdkToChunkAdapter {
     }
   }

+  /**
+   * If there is accumulated thinking content, emit a THINKING_COMPLETE chunk and clear it
+   * @param final - state object holding the accumulated reasoningContent
+   * Resets final.reasoningContent to '' after the chunk is emitted
+   */
+  private emitThinkingCompleteIfNeeded(final: { reasoningContent: string; [key: string]: any }) {
+    if (final.reasoningContent) {
+      this.onChunk({
+        type: ChunkType.THINKING_COMPLETE,
+        text: final.reasoningContent
+      })
+      final.reasoningContent = ''
+    }
+  }
+
   /**
    * Convert an AI SDK chunk into a Cherry Studio chunk and invoke the callback
    * @param chunk - chunk data from the AI SDK
    */
@@ -144,6 +159,9 @@ export class AiSdkToChunkAdapter {
       }
       // === Text events ===
       case 'text-start':
+        // If there is unfinished thinking content, emit THINKING_COMPLETE first
+        // This handles providers that do not send a reasoning-end event
+        this.emitThinkingCompleteIfNeeded(final)
         this.onChunk({
           type: ChunkType.TEXT_START
         })
@@ -214,11 +232,7 @@ export class AiSdkToChunkAdapter {
         })
         break
       case 'reasoning-end':
-        this.onChunk({
-          type: ChunkType.THINKING_COMPLETE,
-          text: final.reasoningContent || ''
-        })
-        final.reasoningContent = ''
+        this.emitThinkingCompleteIfNeeded(final)
         break

       // === Tool-call events (raw AI SDK events, if not handled by middleware) ===
diff --git a/src/renderer/src/aiCore/legacy/clients/__tests__/OpenAIBaseClient.azureEndpoint.test.ts b/src/renderer/src/aiCore/legacy/clients/__tests__/OpenAIBaseClient.azureEndpoint.test.ts
new file mode 100644
index 0000000000..e3b2ef2676
--- /dev/null
+++ b/src/renderer/src/aiCore/legacy/clients/__tests__/OpenAIBaseClient.azureEndpoint.test.ts
@@ -0,0 +1,38 @@
+import { describe, expect, it } from 'vitest'
+
+import { normalizeAzureOpenAIEndpoint } from '../openai/azureOpenAIEndpoint'
+
+describe('normalizeAzureOpenAIEndpoint', () => {
+  it.each([
+    {
+      apiHost: 'https://example.openai.azure.com/openai',
+      expectedEndpoint: 'https://example.openai.azure.com'
+    },
+    {
+      apiHost: 'https://example.openai.azure.com/openai/',
+      expectedEndpoint: 'https://example.openai.azure.com'
+    },
+    {
+      apiHost:
'https://example.openai.azure.com/openai/v1', + expectedEndpoint: 'https://example.openai.azure.com' + }, + { + apiHost: 'https://example.openai.azure.com/openai/v1/', + expectedEndpoint: 'https://example.openai.azure.com' + }, + { + apiHost: 'https://example.openai.azure.com', + expectedEndpoint: 'https://example.openai.azure.com' + }, + { + apiHost: 'https://example.openai.azure.com/', + expectedEndpoint: 'https://example.openai.azure.com' + }, + { + apiHost: 'https://example.openai.azure.com/OPENAI/V1', + expectedEndpoint: 'https://example.openai.azure.com' + } + ])('strips trailing /openai from $apiHost', ({ apiHost, expectedEndpoint }) => { + expect(normalizeAzureOpenAIEndpoint(apiHost)).toBe(expectedEndpoint) + }) +}) diff --git a/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts index ac10106f37..d7f14326f6 100644 --- a/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts @@ -46,7 +46,6 @@ import type { GeminiSdkRawOutput, GeminiSdkToolCall } from '@renderer/types/sdk' -import { getTrailingApiVersion, withoutTrailingApiVersion } from '@renderer/utils' import { isToolUseModeFunction } from '@renderer/utils/assistant' import { geminiFunctionCallToMcpTool, @@ -56,6 +55,7 @@ import { } from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { defaultTimeout, MB } from '@shared/config/constant' +import { getTrailingApiVersion, withoutTrailingApiVersion } from '@shared/utils' import { t } from 'i18next' import type { GenericChunk } from '../../middleware/schemas' diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts index cfc9087545..73a5bed4fe 100644 --- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts @@ -10,7 +10,7 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { findTokenLimit, GEMINI_FLASH_MODEL_REGEX, - getThinkModelType, + getModelSupportedReasoningEffortOptions, isDeepSeekHybridInferenceModel, isDoubaoThinkingAutoModel, isGPT5SeriesModel, @@ -33,7 +33,6 @@ import { isSupportedThinkingTokenQwenModel, isSupportedThinkingTokenZhipuModel, isVisionModel, - MODEL_SUPPORTED_REASONING_EFFORT, ZHIPU_RESULT_TOKENS } from '@renderer/config/models' import { mapLanguageToQwenMTModel } from '@renderer/config/translate' @@ -143,6 +142,10 @@ export class OpenAIAPIClient extends OpenAIBaseClient< return { thinking: { type: reasoningEffort ? 
'enabled' : 'disabled' } } } + if (reasoningEffort === 'default') { + return {} + } + if (!reasoningEffort) { // DeepSeek hybrid inference models, v3.1 and maybe more in the future // 不同的 provider 有不同的思考控制方式,在这里统一解决 @@ -304,16 +307,15 @@ export class OpenAIAPIClient extends OpenAIBaseClient< // Grok models/Perplexity models/OpenAI models if (isSupportedReasoningEffortModel(model)) { // 检查模型是否支持所选选项 - const modelType = getThinkModelType(model) - const supportedOptions = MODEL_SUPPORTED_REASONING_EFFORT[modelType] - if (supportedOptions.includes(reasoningEffort)) { + const supportedOptions = getModelSupportedReasoningEffortOptions(model)?.filter((option) => option !== 'default') + if (supportedOptions?.includes(reasoningEffort)) { return { reasoning_effort: reasoningEffort } } else { // 如果不支持,fallback到第一个支持的值 return { - reasoning_effort: supportedOptions[0] + reasoning_effort: supportedOptions?.[0] } } } diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts index c51f8aac8a..efc3f4f7ce 100644 --- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts @@ -29,6 +29,7 @@ import { withoutTrailingSlash } from '@renderer/utils/api' import { isOllamaProvider } from '@renderer/utils/provider' import { BaseApiClient } from '../BaseApiClient' +import { normalizeAzureOpenAIEndpoint } from './azureOpenAIEndpoint' const logger = loggerService.withContext('OpenAIBaseClient') @@ -69,7 +70,7 @@ export abstract class OpenAIBaseClient< const sdk = await this.getSdkInstance() const response = (await sdk.request({ method: 'post', - path: '/images/generations', + path: '/v1/images/generations', signal, body: { model, @@ -88,7 +89,11 @@ export abstract class OpenAIBaseClient< } override async getEmbeddingDimensions(model: Model): Promise { - const sdk = await this.getSdkInstance() + let sdk: OpenAI = await this.getSdkInstance() + if (isOllamaProvider(this.provider)) { + const embedBaseUrl = `${this.provider.apiHost.replace(/(\/(api|v1))\/?$/, '')}/v1` + sdk = sdk.withOptions({ baseURL: embedBaseUrl }) + } const data = await sdk.embeddings.create({ model: model.id, @@ -209,7 +214,7 @@ export abstract class OpenAIBaseClient< dangerouslyAllowBrowser: true, apiKey: apiKeyForSdkInstance, apiVersion: this.provider.apiVersion, - endpoint: this.provider.apiHost + endpoint: normalizeAzureOpenAIEndpoint(this.provider.apiHost) }) as TSdkInstance } else { this.sdkInstance = new OpenAI({ diff --git a/src/renderer/src/aiCore/legacy/clients/openai/azureOpenAIEndpoint.ts b/src/renderer/src/aiCore/legacy/clients/openai/azureOpenAIEndpoint.ts new file mode 100644 index 0000000000..777dbe74d7 --- /dev/null +++ b/src/renderer/src/aiCore/legacy/clients/openai/azureOpenAIEndpoint.ts @@ -0,0 +1,4 @@ +export function normalizeAzureOpenAIEndpoint(apiHost: string): string { + const normalizedHost = apiHost.replace(/\/+$/, '') + return normalizedHost.replace(/\/openai(?:\/v1)?$/i, '') +} diff --git a/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts b/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts index 2be36057bf..3ee4c6a15c 100644 --- a/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts @@ -2,7 +2,8 @@ import type OpenAI from '@cherrystudio/openai' import { loggerService } from '@logger' import { isSupportedModel } from '@renderer/config/models' import { objectKeys, type 
Provider } from '@renderer/types' -import { formatApiHost, withoutTrailingApiVersion } from '@renderer/utils' +import { formatApiHost } from '@renderer/utils' +import { withoutTrailingApiVersion } from '@shared/utils' import { OpenAIAPIClient } from '../openai/OpenAIApiClient' diff --git a/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts index 2c77649634..d31632de57 100644 --- a/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts @@ -65,6 +65,11 @@ export class ZhipuAPIClient extends OpenAIAPIClient { public async listModels(): Promise { const models = [ + 'glm-4.7', + 'glm-4.6', + 'glm-4.6v', + 'glm-4.6v-flash', + 'glm-4.6v-flashx', 'glm-4.5', 'glm-4.5-x', 'glm-4.5-air', diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts index 10a4d59384..b2a796bd33 100644 --- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts @@ -7,7 +7,6 @@ import type { Chunk } from '@renderer/types/chunk' import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider' import type { LanguageModelMiddleware } from 'ai' import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai' -import { isEmpty } from 'lodash' import { getAiSdkProviderId } from '../provider/factory' import { isOpenRouterGeminiGenerateImageModel } from '../utils/image' @@ -16,7 +15,6 @@ import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMidd import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware' import { qwenThinkingMiddleware } from './qwenThinkingMiddleware' import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware' -import { toolChoiceMiddleware } from './toolChoiceMiddleware' const logger = loggerService.withContext('AiSdkMiddlewareBuilder') @@ -136,15 +134,6 @@ export class AiSdkMiddlewareBuilder { export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] { const builder = new AiSdkMiddlewareBuilder() - // 0. 知识库强制调用中间件(必须在最前面,确保第一轮强制调用知识库) - if (!isEmpty(config.assistant?.knowledge_bases?.map((base) => base.id)) && config.knowledgeRecognition !== 'on') { - builder.add({ - name: 'force-knowledge-first', - middleware: toolChoiceMiddleware('builtin_knowledge_search') - }) - logger.debug('Added toolChoice middleware to force knowledge base search on first round') - } - // 1. 
根据provider添加特定中间件 if (config.provider) { addProviderSpecificMiddlewares(builder, config) diff --git a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts index aa4a48693b..38b8f4bd5c 100644 --- a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts +++ b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts @@ -32,7 +32,7 @@ import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool' const logger = loggerService.withContext('SearchOrchestrationPlugin') -const getMessageContent = (message: ModelMessage) => { +export const getMessageContent = (message: ModelMessage) => { if (typeof message.content === 'string') return message.content return message.content.reduce((acc, part) => { if (part.type === 'text') { @@ -267,14 +267,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) // 判断是否需要各种搜索 const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) - const knowledgeRecognition = assistant.knowledgeRecognition || 'on' + const knowledgeRecognition = assistant.knowledgeRecognition || 'off' const globalMemoryEnabled = await preferenceService.get('feature.memory.enabled') const shouldWebSearch = !!assistant.webSearchProviderId const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on' const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory // 执行意图分析 - if (shouldWebSearch || hasKnowledgeBase) { + if (shouldWebSearch || shouldKnowledgeSearch) { const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, { shouldWebSearch, shouldKnowledgeSearch, @@ -331,41 +331,25 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) // 📚 知识库搜索工具配置 const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) - const knowledgeRecognition = assistant.knowledgeRecognition || 'on' + const knowledgeRecognition = assistant.knowledgeRecognition || 'off' + const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on' - if (hasKnowledgeBase) { - if (knowledgeRecognition === 'off') { - // off 模式:直接添加知识库搜索工具,使用用户消息作为搜索关键词 + if (shouldKnowledgeSearch) { + // on 模式:根据意图识别结果决定是否添加工具 + const needsKnowledgeSearch = + analysisResult?.knowledge && + analysisResult.knowledge.question && + analysisResult.knowledge.question[0] !== 'not_needed' + + if (needsKnowledgeSearch && analysisResult.knowledge) { + // logger.info('📚 Adding knowledge search tool (intent-based)') const userMessage = userMessages[context.requestId] - const fallbackKeywords = { - question: [getMessageContent(userMessage) || 'search'], - rewrite: getMessageContent(userMessage) || 'search' - } - // logger.info('📚 Adding knowledge search tool (force mode)') params.tools['builtin_knowledge_search'] = knowledgeSearchTool( assistant, - fallbackKeywords, + analysisResult.knowledge, getMessageContent(userMessage), topicId ) - // params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' } - } else { - // on 模式:根据意图识别结果决定是否添加工具 - const needsKnowledgeSearch = - analysisResult?.knowledge && - analysisResult.knowledge.question && - analysisResult.knowledge.question[0] !== 'not_needed' - - if (needsKnowledgeSearch && analysisResult.knowledge) { - // logger.info('📚 Adding knowledge search tool (intent-based)') - const userMessage = userMessages[context.requestId] - 
params.tools['builtin_knowledge_search'] = knowledgeSearchTool( - assistant, - analysisResult.knowledge, - getMessageContent(userMessage), - topicId - ) - } } } diff --git a/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts b/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts index cb0c5cf9a6..2a69f3bcef 100644 --- a/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts +++ b/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts @@ -109,6 +109,20 @@ const createImageBlock = ( ...overrides }) +const createThinkingBlock = ( + messageId: string, + overrides: Partial> = {} +): ThinkingMessageBlock => ({ + id: overrides.id ?? `thinking-block-${++blockCounter}`, + messageId, + type: MessageBlockType.THINKING, + createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(), + status: overrides.status ?? MessageBlockStatus.SUCCESS, + content: overrides.content ?? 'Let me think...', + thinking_millsec: overrides.thinking_millsec ?? 1000, + ...overrides +}) + describe('messageConverter', () => { beforeEach(() => { convertFileBlockToFilePartMock.mockReset() @@ -229,6 +243,23 @@ describe('messageConverter', () => { } ]) }) + + it('includes reasoning parts for assistant messages with thinking blocks', async () => { + const model = createModel() + const message = createMessage('assistant') + message.__mockContent = 'Here is my answer' + message.__mockThinkingBlocks = [createThinkingBlock(message.id, { content: 'Let me think...' })] + + const result = await convertMessageToSdkParam(message, false, model) + + expect(result).toEqual({ + role: 'assistant', + content: [ + { type: 'text', text: 'Here is my answer' }, + { type: 'reasoning', text: 'Let me think...' } + ] + }) + }) }) describe('convertMessagesToSdkMessages', () => { diff --git a/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts b/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts index 70b4ac84b7..a4f345e3e5 100644 --- a/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts +++ b/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts @@ -18,7 +18,7 @@ vi.mock('@renderer/services/AssistantService', () => ({ toolUseMode: assistant.settings?.toolUseMode ?? 'prompt', defaultModel: assistant.defaultModel, customParameters: assistant.settings?.customParameters ?? [], - reasoning_effort: assistant.settings?.reasoning_effort, + reasoning_effort: assistant.settings?.reasoning_effort ?? 
'default', reasoning_effort_cache: assistant.settings?.reasoning_effort_cache, qwenThinkMode: assistant.settings?.qwenThinkMode }) diff --git a/src/renderer/src/aiCore/prepareParams/messageConverter.ts b/src/renderer/src/aiCore/prepareParams/messageConverter.ts index 328a10b941..c3798c1f43 100644 --- a/src/renderer/src/aiCore/prepareParams/messageConverter.ts +++ b/src/renderer/src/aiCore/prepareParams/messageConverter.ts @@ -3,6 +3,7 @@ * 将 Cherry Studio 消息格式转换为 AI SDK 消息格式 */ +import type { ReasoningPart } from '@ai-sdk/provider-utils' import { loggerService } from '@logger' import { isImageEnhancementModel, isVisionModel } from '@renderer/config/models' import type { Message, Model } from '@renderer/types' @@ -163,13 +164,13 @@ async function convertMessageToAssistantModelMessage( thinkingBlocks: ThinkingMessageBlock[], model?: Model ): Promise { - const parts: Array = [] + const parts: Array = [] if (content) { parts.push({ type: 'text', text: content }) } for (const thinkingBlock of thinkingBlocks) { - parts.push({ type: 'text', text: thinkingBlock.content }) + parts.push({ type: 'reasoning', text: thinkingBlock.content }) } for (const fileBlock of fileBlocks) { diff --git a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts index cba7fcdb10..52234c5f1f 100644 --- a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts +++ b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts @@ -21,6 +21,7 @@ import { isGrokModel, isOpenAIModel, isOpenRouterBuiltInWebSearchModel, + isPureGenerateImageModel, isSupportedReasoningEffortModel, isSupportedThinkingTokenModel, isWebSearchModel @@ -33,7 +34,7 @@ import { type Assistant, type MCPTool, type Provider, SystemProviderIds } from ' import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern' import { replacePromptVariables } from '@renderer/utils/prompt' -import { isAIGatewayProvider, isAwsBedrockProvider } from '@renderer/utils/provider' +import { isAIGatewayProvider, isAwsBedrockProvider, isSupportUrlContextProvider } from '@renderer/utils/provider' import type { ModelMessage, Tool } from 'ai' import { stepCountIs } from 'ai' @@ -118,7 +119,13 @@ export async function buildStreamTextParams( isOpenRouterBuiltInWebSearchModel(model) || model.id.includes('sonar')) - const enableUrlContext = assistant.enableUrlContext || false + // Validate provider and model support to prevent stale state from triggering urlContext + const enableUrlContext = !!( + assistant.enableUrlContext && + isSupportUrlContextProvider(provider) && + !isPureGenerateImageModel(model) && + (isGeminiModel(model) || isAnthropicModel(model)) + ) const enableGenerateImage = !!(isGenerateImageModel(model) && assistant.enableGenerateImage) diff --git a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts index 20aa78dcbd..b1d8e34fcd 100644 --- a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts +++ b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts @@ -79,7 +79,7 @@ vi.mock('@renderer/services/AssistantService', () => ({ import { getProviderByModel } from '@renderer/services/AssistantService' import type { Model, Provider } from '@renderer/types' import { formatApiHost } from '@renderer/utils/api' -import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' +import { isAzureOpenAIProvider, 
isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants' import { getActualProvider, providerToAiSdkConfig } from '../providerConfig' @@ -133,6 +133,17 @@ const createPerplexityProvider = (): Provider => ({ isSystem: false }) +const createAzureProvider = (apiVersion: string): Provider => ({ + id: 'azure-openai', + type: 'azure-openai', + name: 'Azure OpenAI', + apiKey: 'test-key', + apiHost: 'https://example.openai.azure.com/openai', + apiVersion, + models: [], + isSystem: true +}) + describe('Copilot responses routing', () => { beforeEach(() => { ;(globalThis as any).window = { @@ -504,3 +515,46 @@ describe('Stream options includeUsage configuration', () => { expect(config.providerId).toBe('github-copilot-openai-compatible') }) }) + +describe('Azure OpenAI traditional API routing', () => { + beforeEach(() => { + ;(globalThis as any).window = { + ...(globalThis as any).window, + keyv: createWindowKeyv() + } + mockGetState.mockReturnValue({ + settings: { + openAI: { + streamOptions: { + includeUsage: undefined + } + } + } + }) + + vi.mocked(isAzureOpenAIProvider).mockImplementation((provider) => provider.type === 'azure-openai') + }) + + it('uses deployment-based URLs when apiVersion is a date version', () => { + const provider = createAzureProvider('2024-02-15-preview') + const config = providerToAiSdkConfig(provider, createModel('gpt-4o', 'GPT-4o', provider.id)) + + expect(config.providerId).toBe('azure') + expect(config.options.apiVersion).toBe('2024-02-15-preview') + expect(config.options.useDeploymentBasedUrls).toBe(true) + }) + + it('does not force deployment-based URLs for apiVersion v1/preview', () => { + const v1Provider = createAzureProvider('v1') + const v1Config = providerToAiSdkConfig(v1Provider, createModel('gpt-4o', 'GPT-4o', v1Provider.id)) + expect(v1Config.providerId).toBe('azure-responses') + expect(v1Config.options.apiVersion).toBe('v1') + expect(v1Config.options.useDeploymentBasedUrls).toBeUndefined() + + const previewProvider = createAzureProvider('preview') + const previewConfig = providerToAiSdkConfig(previewProvider, createModel('gpt-4o', 'GPT-4o', previewProvider.id)) + expect(previewConfig.providerId).toBe('azure-responses') + expect(previewConfig.options.apiVersion).toBe('preview') + expect(previewConfig.options.useDeploymentBasedUrls).toBeUndefined() + }) +}) diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts index ff100051b7..d18aa02eeb 100644 --- a/src/renderer/src/aiCore/provider/factory.ts +++ b/src/renderer/src/aiCore/provider/factory.ts @@ -31,7 +31,8 @@ const STATIC_PROVIDER_MAPPING: Record = { 'azure-openai': 'azure', // Azure OpenAI -> azure 'openai-response': 'openai', // OpenAI Responses -> openai grok: 'xai', // Grok -> xai - copilot: 'github-copilot-openai-compatible' + copilot: 'github-copilot-openai-compatible', + tokenflux: 'openrouter' // TokenFlux -> openrouter (fully compatible) } /** diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index 1c410bf124..0ad15ea895 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -32,6 +32,7 @@ import { isSupportStreamOptionsProvider, isVertexProvider } from '@renderer/utils/provider' +import { defaultAppHeaders } from '@shared/utils' import { cloneDeep, isEmpty } from 'lodash' import type { AiSdkConfig 
} from '../types' @@ -197,18 +198,13 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A extraOptions.mode = 'chat' } - // 添加额外headers - if (actualProvider.extra_headers) { - extraOptions.headers = actualProvider.extra_headers - // copy from openaiBaseClient/openaiResponseApiClient - if (aiSdkProviderId === 'openai') { - extraOptions.headers = { - ...extraOptions.headers, - 'HTTP-Referer': 'https://cherry-ai.com', - 'X-Title': 'Cherry Studio', - 'X-Api-Key': baseConfig.apiKey - } - } + extraOptions.headers = { + ...defaultAppHeaders(), + ...actualProvider.extra_headers + } + + if (aiSdkProviderId === 'openai') { + extraOptions.headers['X-Api-Key'] = baseConfig.apiKey } // azure // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest @@ -218,6 +214,15 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A } else if (aiSdkProviderId === 'azure') { extraOptions.mode = 'chat' } + if (isAzureOpenAIProvider(actualProvider)) { + const apiVersion = actualProvider.apiVersion?.trim() + if (apiVersion) { + extraOptions.apiVersion = apiVersion + if (!['preview', 'v1'].includes(apiVersion)) { + extraOptions.useDeploymentBasedUrls = true + } + } + } // bedrock if (aiSdkProviderId === 'bedrock') { diff --git a/src/renderer/src/aiCore/tools/MemorySearchTool.ts b/src/renderer/src/aiCore/tools/MemorySearchTool.ts index bf2bbd286a..f185f79841 100644 --- a/src/renderer/src/aiCore/tools/MemorySearchTool.ts +++ b/src/renderer/src/aiCore/tools/MemorySearchTool.ts @@ -25,7 +25,8 @@ export const memorySearchTool = () => { } const memoryConfig = selectMemoryConfig(store.getState()) - if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) { + + if (!memoryConfig.llmModel || !memoryConfig.embeddingModel) { return [] } diff --git a/src/renderer/src/aiCore/utils/__tests__/options.test.ts b/src/renderer/src/aiCore/utils/__tests__/options.test.ts index 9eeeac725b..a6c9a6c95c 100644 --- a/src/renderer/src/aiCore/utils/__tests__/options.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/options.test.ts @@ -464,7 +464,8 @@ describe('options utils', () => { custom_param: 'custom_value', another_param: 123, serviceTier: undefined, - textVerbosity: undefined + textVerbosity: undefined, + store: false } }) }) diff --git a/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts b/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts index fec4d197e3..df7d69d0c2 100644 --- a/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts @@ -11,6 +11,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { getAnthropicReasoningParams, + getAnthropicThinkingBudget, getBedrockReasoningParams, getCustomParameters, getGeminiReasoningParams, @@ -89,7 +90,8 @@ vi.mock('@renderer/config/models', async (importOriginal) => { isQwenAlwaysThinkModel: vi.fn(() => false), isSupportedThinkingTokenHunyuanModel: vi.fn(() => false), isSupportedThinkingTokenModel: vi.fn(() => false), - isGPT51SeriesModel: vi.fn(() => false) + isGPT51SeriesModel: vi.fn(() => false), + findTokenLimit: vi.fn(actual.findTokenLimit) } }) @@ -596,7 +598,7 @@ describe('reasoning utils', () => { expect(result).toEqual({}) }) - it('should return disabled thinking when no reasoning effort', async () => { + it('should return disabled thinking when reasoning effort is none', async () => { const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models') 
vi.mocked(isReasoningModel).mockReturnValue(true) @@ -611,7 +613,9 @@ describe('reasoning utils', () => { const assistant: Assistant = { id: 'test', name: 'Test', - settings: {} + settings: { + reasoning_effort: 'none' + } } as Assistant const result = getAnthropicReasoningParams(assistant, model) @@ -647,7 +651,7 @@ describe('reasoning utils', () => { expect(result).toEqual({ thinking: { type: 'enabled', - budgetTokens: 2048 + budgetTokens: 4096 } }) }) @@ -675,7 +679,7 @@ describe('reasoning utils', () => { expect(result).toEqual({}) }) - it('should disable thinking for Flash models without reasoning effort', async () => { + it('should disable thinking for Flash models when reasoning effort is none', async () => { const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models') vi.mocked(isReasoningModel).mockReturnValue(true) @@ -690,7 +694,9 @@ describe('reasoning utils', () => { const assistant: Assistant = { id: 'test', name: 'Test', - settings: {} + settings: { + reasoning_effort: 'none' + } } as Assistant const result = getGeminiReasoningParams(assistant, model) @@ -725,7 +731,7 @@ describe('reasoning utils', () => { const result = getGeminiReasoningParams(assistant, model) expect(result).toEqual({ thinkingConfig: { - thinkingBudget: 16448, + thinkingBudget: expect.any(Number), includeThoughts: true } }) @@ -889,7 +895,7 @@ describe('reasoning utils', () => { expect(result).toEqual({ reasoningConfig: { type: 'enabled', - budgetTokens: 2048 + budgetTokens: 4096 } }) }) @@ -990,4 +996,89 @@ describe('reasoning utils', () => { }) }) }) + + describe('getAnthropicThinkingBudget', () => { + it('should return undefined when reasoningEffort is undefined', async () => { + const result = getAnthropicThinkingBudget(4096, undefined, 'claude-3-7-sonnet') + expect(result).toBeUndefined() + }) + + it('should return undefined when reasoningEffort is none', async () => { + const result = getAnthropicThinkingBudget(4096, 'none', 'claude-3-7-sonnet') + expect(result).toBeUndefined() + }) + + it('should return undefined when tokenLimit is not found', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue(undefined) + + const result = getAnthropicThinkingBudget(4096, 'medium', 'unknown-model') + expect(result).toBeUndefined() + }) + + it('should calculate budget correctly when maxTokens is provided', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(4096, 'medium', 'claude-3-7-sonnet') + // EFFORT_RATIO['medium'] = 0.5 + // budget = Math.floor((32768 - 1024) * 0.5 + 1024) + // = Math.floor(31744 * 0.5 + 1024) = Math.floor(15872 + 1024) = 16896 + // budgetTokens = Math.min(16896, 4096) = 4096 + // result = Math.max(1024, 4096) = 4096 + expect(result).toBe(4096) + }) + + it('should use tokenLimit.max when maxTokens is undefined', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(undefined, 'medium', 'claude-3-7-sonnet') + // When maxTokens is undefined, budget is not constrained by maxTokens + // EFFORT_RATIO['medium'] = 0.5 + // budget = Math.floor((32768 - 1024) * 0.5 + 1024) + // = Math.floor(31744 * 0.5 + 1024) = Math.floor(15872 + 1024) = 16896 + // result = Math.max(1024, 16896) = 16896 + 
expect(result).toBe(16896) + }) + + it('should enforce minimum budget of 1024', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 100, max: 1000 }) + + const result = getAnthropicThinkingBudget(500, 'low', 'claude-3-7-sonnet') + // EFFORT_RATIO['low'] = 0.05 + // budget = Math.floor((1000 - 100) * 0.05 + 100) + // = Math.floor(900 * 0.05 + 100) = Math.floor(45 + 100) = 145 + // budgetTokens = Math.min(145, 500) = 145 + // result = Math.max(1024, 145) = 1024 + expect(result).toBe(1024) + }) + + it('should respect effort ratio for high reasoning effort', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(8192, 'high', 'claude-3-7-sonnet') + // EFFORT_RATIO['high'] = 0.8 + // budget = Math.floor((32768 - 1024) * 0.8 + 1024) + // = Math.floor(31744 * 0.8 + 1024) = Math.floor(25395.2 + 1024) = 26419 + // budgetTokens = Math.min(26419, 8192) = 8192 + // result = Math.max(1024, 8192) = 8192 + expect(result).toBe(8192) + }) + + it('should use full token limit when maxTokens is undefined and reasoning effort is high', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(undefined, 'high', 'claude-3-7-sonnet') + // When maxTokens is undefined, budget is not constrained by maxTokens + // EFFORT_RATIO['high'] = 0.8 + // budget = Math.floor((32768 - 1024) * 0.8 + 1024) + // = Math.floor(31744 * 0.8 + 1024) = Math.floor(25395.2 + 1024) = 26419 + // result = Math.max(1024, 26419) = 26419 + expect(result).toBe(26419) + }) + }) }) diff --git a/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts b/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts index fa5e3c3b36..5c95a664aa 100644 --- a/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/websearch.test.ts @@ -259,7 +259,7 @@ describe('websearch utils', () => { expect(result).toEqual({ xai: { - maxSearchResults: 50, + maxSearchResults: 30, returnCitations: true, sources: [{ type: 'web', excludedWebsites: [] }, { type: 'news' }, { type: 'x' }], mode: 'on' diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index fd9bc590cd..8dc7a10af9 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -10,6 +10,7 @@ import { isAnthropicModel, isGeminiModel, isGrokModel, + isInterleavedThinkingModel, isOpenAIModel, isOpenAIOpenWeightModel, isQwenMTModel, @@ -396,10 +397,12 @@ function buildOpenAIProviderOptions( } } + // TODO: 支持配置是否在服务端持久化 providerOptions = { ...providerOptions, serviceTier, - textVerbosity + textVerbosity, + store: false } return { @@ -577,8 +580,10 @@ function buildOllamaProviderOptions( const reasoningEffort = assistant.settings?.reasoning_effort if (enableReasoning) { if (isOpenAIOpenWeightModel(model)) { - // @ts-ignore upstream type error - providerOptions.think = reasoningEffort as any + // For gpt-oss models, Ollama accepts: 'low' | 'medium' | 'high' + if (reasoningEffort === 'low' || reasoningEffort === 'medium' || reasoningEffort === 'high') { + providerOptions.think = reasoningEffort + } } else { providerOptions.think = !['none', undefined].includes(reasoningEffort) } @@ -601,7 +606,7 @@ function 
buildGenericProviderOptions( enableGenerateImage: boolean } ): Record { - const { enableWebSearch } = capabilities + const { enableWebSearch, enableReasoning } = capabilities let providerOptions: Record = {} const reasoningParams = getReasoningEffort(assistant, model) @@ -609,6 +614,14 @@ function buildGenericProviderOptions( ...providerOptions, ...reasoningParams } + if (enableReasoning) { + if (isInterleavedThinkingModel(model)) { + providerOptions = { + ...providerOptions, + sendReasoning: true + } + } + } if (enableWebSearch) { const webSearchParams = getWebSearchParams(model) diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts index 996d676761..ab8a0b7983 100644 --- a/src/renderer/src/aiCore/utils/reasoning.ts +++ b/src/renderer/src/aiCore/utils/reasoning.ts @@ -8,12 +8,12 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { findTokenLimit, GEMINI_FLASH_MODEL_REGEX, - getThinkModelType, + getModelSupportedReasoningEffortOptions, isDeepSeekHybridInferenceModel, + isDoubaoSeed18Model, isDoubaoSeedAfter251015, isDoubaoThinkingAutoModel, isGemini3ThinkingTokenModel, - isGPT51SeriesModel, isGrok4FastReasoningModel, isOpenAIDeepResearchModel, isOpenAIModel, @@ -28,14 +28,15 @@ import { isSupportedThinkingTokenDoubaoModel, isSupportedThinkingTokenGeminiModel, isSupportedThinkingTokenHunyuanModel, + isSupportedThinkingTokenMiMoModel, isSupportedThinkingTokenModel, isSupportedThinkingTokenQwenModel, isSupportedThinkingTokenZhipuModel, - MODEL_SUPPORTED_REASONING_EFFORT + isSupportNoneReasoningEffortModel } from '@renderer/config/models' import { getStoreSetting } from '@renderer/hooks/useSettings' import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService' -import type { Assistant, Model } from '@renderer/types' +import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types' import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types' import type { OpenAIReasoningSummary } from '@renderer/types/aiCoreTypes' import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk' @@ -65,7 +66,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin // reasoningEffort is not set, no extra reasoning setting // Generally, for every model which supports reasoning control, the reasoning effort won't be undefined. // It's for some reasoning models that don't support reasoning control, such as deepseek reasoner. - if (!reasoningEffort) { + if (!reasoningEffort || reasoningEffort === 'default') { return {} } @@ -73,9 +74,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin if (reasoningEffort === 'none') { // openrouter: use reasoning if (model.provider === SystemProviderIds.openrouter) { - // 'none' is not an available value for effort for now. - // I think they should resolve this issue soon, so I'll just go ahead and use this value. - if (isGPT51SeriesModel(model) && reasoningEffort === 'none') { + if (isSupportNoneReasoningEffortModel(model) && reasoningEffort === 'none') { return { reasoning: { effort: 'none' } } } return { reasoning: { enabled: false, exclude: true } } @@ -119,8 +118,8 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin return { thinking: { type: 'disabled' } } } - // Specially for GPT-5.1. 
Suppose this is an OpenAI Compatible provider - if (isGPT51SeriesModel(model)) { + // GPT 5.1, GPT 5.2, or newer + if (isSupportNoneReasoningEffortModel(model)) { return { reasoningEffort: 'none' } @@ -330,16 +329,15 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin // Grok models/Perplexity models/OpenAI models, use reasoning_effort if (isSupportedReasoningEffortModel(model)) { // Check whether the model supports the selected option - const modelType = getThinkModelType(model) - const supportedOptions = MODEL_SUPPORTED_REASONING_EFFORT[modelType] - if (supportedOptions.includes(reasoningEffort)) { + const supportedOptions = getModelSupportedReasoningEffortOptions(model)?.filter((option) => option !== 'default') + if (supportedOptions?.includes(reasoningEffort)) { return { reasoningEffort } } else { // If unsupported, fall back to the first supported value return { - reasoningEffort: supportedOptions[0] + reasoningEffort: supportedOptions?.[0] } } } @@ -391,7 +389,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin // Use thinking, doubao, zhipu, etc. if (isSupportedThinkingTokenDoubaoModel(model)) { - if (isDoubaoSeedAfter251015(model)) { + if (isDoubaoSeedAfter251015(model) || isDoubaoSeed18Model(model)) { return { reasoningEffort } } if (reasoningEffort === 'high') { @@ -410,6 +408,12 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin return { thinking: { type: 'enabled' } } } + if (isSupportedThinkingTokenMiMoModel(model)) { + return { + thinking: { type: 'enabled' } + } + } + // Default case: no special thinking settings return {} } @@ -429,7 +433,7 @@ export function getOpenAIReasoningParams( let reasoningEffort = assistant?.settings?.reasoning_effort - if (!reasoningEffort) { + if (!reasoningEffort || reasoningEffort === 'default') { return {} } @@ -481,16 +485,14 @@ export function getAnthropicThinkingBudget( return undefined } - const budgetTokens = Math.max( - 1024, - Math.floor( - Math.min( - (tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min, - (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio - ) - ) - ) - return budgetTokens + const budget = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min) + + let budgetTokens = budget + if (maxTokens !== undefined) { + budgetTokens = Math.min(budget, maxTokens) + } + + return Math.max(1024, budgetTokens) }
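// A worked example of the rewritten getAnthropicThinkingBudget above; a minimal
// sketch using only the EFFORT_RATIO values quoted in the unit tests (low 0.05, high 0.8).
const EFFORT_RATIO_SKETCH: Record<string, number> = { low: 0.05, high: 0.8 }

function thinkingBudgetSketch(maxTokens: number | undefined, effort: string, limit: { min: number; max: number }): number {
  // Scale the model's thinking-token range by the effort ratio
  const budget = Math.floor((limit.max - limit.min) * EFFORT_RATIO_SKETCH[effort] + limit.min)
  // Cap by maxTokens only when it is actually set (the old code always applied a cap)
  const budgetTokens = maxTokens !== undefined ? Math.min(budget, maxTokens) : budget
  // Anthropic rejects thinking budgets below 1024 tokens, hence the floor
  return Math.max(1024, budgetTokens)
}

thinkingBudgetSketch(500, 'low', { min: 100, max: 1000 }) // 1024 (raw budget 145, raised to the floor)
thinkingBudgetSketch(8192, 'high', { min: 1024, max: 32768 }) // 8192 (budget 26419, capped by maxTokens)
thinkingBudgetSketch(undefined, 'high', { min: 1024, max: 32768 }) // 26419 (no maxTokens cap)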
/** @@ -507,7 +509,11 @@ export function getAnthropicReasoningParams( const reasoningEffort = assistant?.settings?.reasoning_effort - if (reasoningEffort === undefined || reasoningEffort === 'none') { + if (!reasoningEffort || reasoningEffort === 'default') { + return {} + } + + if (reasoningEffort === 'none') { return { thinking: { type: 'disabled' @@ -531,20 +537,25 @@ export function getAnthropicReasoningParams( return {} } -// type GoogleThinkingLevel = NonNullable['thinkingLevel'] +type GoogleThinkingLevel = NonNullable['thinkingLevel'] -// function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogleThinkingLevel { -// switch (reasoningEffort) { -// case 'low': -// return 'low' -// case 'medium': -// return 'medium' -// case 'high': -// return 'high' -// default: -// return 'medium' -// } -// } +function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogleThinkingLevel { + switch (reasoningEffort) { + case 'default': + return undefined + case 'minimal': + return 'minimal' + case 'low': + return 'low' + case 'medium': + return 'medium' + case 'high': + return 'high' + default: + logger.warn('Unknown thinking level for Gemini, falling back to medium.', { reasoningEffort }) + return 'medium' + } +} /** * Get Gemini reasoning parameters @@ -562,6 +573,10 @@ export function getGeminiReasoningParams( const reasoningEffort = assistant?.settings?.reasoning_effort + if (!reasoningEffort || reasoningEffort === 'default') { + return {} + } + // Gemini reasoning parameters if (isSupportedThinkingTokenGeminiModel(model)) { if (reasoningEffort === undefined || reasoningEffort === 'none') { return { } } } - // TODO: many relay providers don't support this yet // https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3 - // if (isGemini3ThinkingTokenModel(model)) { - // return { - // thinkingConfig: { - // thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort) - // } - // } - // } + if (isGemini3ThinkingTokenModel(model)) { + return { + thinkingConfig: { + includeThoughts: true, + thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort) + } + } + } const effortRatio = EFFORT_RATIO[reasoningEffort] @@ -622,10 +637,6 @@ export function getXAIReasoningParams(assistant: Assistant, model: Model): Pick< const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant) - if (!reasoningEffort || reasoningEffort === 'none') { - return {} - } - switch (reasoningEffort) { case 'auto': case 'minimal': @@ -636,6 +647,10 @@ export function getXAIReasoningParams(assistant: Assistant, model: Model): Pick< return { reasoningEffort } case 'xhigh': return { reasoningEffort: 'high' } + case 'default': + case 'none': + default: + return {} } } @@ -652,7 +667,7 @@ export function getBedrockReasoningParams( const reasoningEffort = assistant?.settings?.reasoning_effort - if (reasoningEffort === undefined) { + if (reasoningEffort === undefined || reasoningEffort === 'default') { return {} } diff --git a/src/renderer/src/aiCore/utils/websearch.ts b/src/renderer/src/aiCore/utils/websearch.ts index 127636a50b..14a99139be 100644 --- a/src/renderer/src/aiCore/utils/websearch.ts +++ b/src/renderer/src/aiCore/utils/websearch.ts @@ -9,6 +9,8 @@ import type { CherryWebSearchConfig } from '@renderer/store/websearch' import type { Model } from '@renderer/types' import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern' +const X_AI_MAX_SEARCH_RESULT = 30 + export function getWebSearchParams(model: Model): Record { if (model.provider === 'hunyuan') { return { enable_enhancement: true, citation: true, search_info: true } @@ -82,7 +84,7 @@ export function buildProviderBuiltinWebSearchConfig( const excludeDomains = mapRegexToPatterns(webSearchConfig.excludeDomains) return { xai: { - maxSearchResults: webSearchConfig.maxResults, + maxSearchResults: Math.min(webSearchConfig.maxResults, X_AI_MAX_SEARCH_RESULT), returnCitations: true, sources: [ {
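// The new X_AI_MAX_SEARCH_RESULT constant caps requested results at 30, presumably
// the limit accepted by xAI's search API; the websearch.test.ts expectation change
// above (50 -> 30) follows directly from it. A minimal sketch of the clamp:
const clampXaiSearchResults = (requested: number): number => Math.min(requested, 30)
clampXaiSearchResults(50) // 30, matching the updated test expectation
clampXaiSearchResults(10) // 10, values under the cap pass through unchanged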
diff --git a/src/renderer/src/assets/images/apps/aistudio.png b/src/renderer/src/assets/images/apps/aistudio.png new file mode 100644 index 0000000000..c7cb2adebe Binary files /dev/null and b/src/renderer/src/assets/images/apps/aistudio.png differ diff --git a/src/renderer/src/assets/images/apps/aistudio.svg b/src/renderer/src/assets/images/apps/aistudio.svg deleted file mode 100644 index 2c08015593..0000000000 --- a/src/renderer/src/assets/images/apps/aistudio.svg +++ /dev/null @@ -1,27 +0,0 @@ [27 deleted lines of SVG markup, not preserved in this document] diff --git a/src/renderer/src/assets/images/models/mimo.svg b/src/renderer/src/assets/images/models/mimo.svg new file mode 100644 index 0000000000..82370fece3 --- /dev/null +++ b/src/renderer/src/assets/images/models/mimo.svg @@ -0,0 +1,17 @@ [17 added lines of SVG markup, not preserved in this document] diff --git a/src/renderer/src/assets/images/providers/mimo.svg b/src/renderer/src/assets/images/providers/mimo.svg new file mode 100644 index 0000000000..82370fece3 --- /dev/null +++ b/src/renderer/src/assets/images/providers/mimo.svg @@ -0,0 +1,17 @@ [17 added lines of SVG markup, not preserved in this document] diff --git a/src/renderer/src/assets/images/search/baidu.svg b/src/renderer/src/assets/images/search/baidu.svg new file mode 100644 index 0000000000..ead7f89822 --- /dev/null +++ b/src/renderer/src/assets/images/search/baidu.svg @@ -0,0 +1 @@ +Baidu \ No newline at end of file diff --git a/src/renderer/src/assets/images/search/bing.svg b/src/renderer/src/assets/images/search/bing.svg new file mode 100644 index 0000000000..b411a4f068 --- /dev/null +++ b/src/renderer/src/assets/images/search/bing.svg @@ -0,0 +1 @@ +Bing \ No newline at end of file diff --git a/src/renderer/src/assets/images/search/google.svg b/src/renderer/src/assets/images/search/google.svg new file mode 100644 index 0000000000..e8e0f867bd --- /dev/null +++ b/src/renderer/src/assets/images/search/google.svg @@ -0,0 +1 @@ +Google \ No newline at end of file diff --git a/src/renderer/src/components/Avatar/AssistantAvatar.tsx b/src/renderer/src/components/Avatar/AssistantAvatar.tsx new file mode 100644 index 0000000000..7ab87e633e --- /dev/null +++ b/src/renderer/src/components/Avatar/AssistantAvatar.tsx @@ -0,0 +1,34 @@ +import EmojiIcon from '@renderer/components/EmojiIcon' +import { useSettings } from '@renderer/hooks/useSettings' +import { getDefaultModel } from '@renderer/services/AssistantService' +import type { Assistant } from '@renderer/types' +import { getLeadingEmoji } from '@renderer/utils' +import type { FC } from 'react' +import { useMemo } from 'react' + +import ModelAvatar from './ModelAvatar' + +interface AssistantAvatarProps { + assistant: Assistant + size?: number + className?: string +} + +const AssistantAvatar: FC = ({ assistant, size = 24, className }) => { + const { assistantIconType } = useSettings() + const defaultModel = getDefaultModel() + + const assistantName = useMemo(() => assistant.name || '', [assistant.name]) + + if (assistantIconType === 'model') { + return + } + + if (assistantIconType === 'emoji') { + return + } + + return null +} + +export default AssistantAvatar
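// The JSX returned by the two branches of AssistantAvatar above was stripped in
// extraction. A plausible sketch of those returns, inferred only from this file's
// imports and props (exact prop names are assumptions, not the committed code):
//
//   if (assistantIconType === 'model') {
//     return <ModelAvatar model={assistant.model || defaultModel} size={size} className={className} />
//   }
//   if (assistantIconType === 'emoji') {
//     return <EmojiIcon emoji={getLeadingEmoji(assistantName)} size={size} className={className} />
//   }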
diff --git a/src/renderer/src/components/Avatar/ModelAvatar.tsx b/src/renderer/src/components/Avatar/ModelAvatar.tsx index 9ce6d87c46..cff23486f7 100644 --- a/src/renderer/src/components/Avatar/ModelAvatar.tsx +++ b/src/renderer/src/components/Avatar/ModelAvatar.tsx @@ -1,7 +1,8 @@ import type { AvatarProps } from '@cherrystudio/ui' -import { Avatar, cn } from '@cherrystudio/ui' +import { Avatar } from '@cherrystudio/ui' import { getModelLogo } from '@renderer/config/models' import type { Model } from '@renderer/types' +import { cn } from '@renderer/utils' import { first } from 'lodash' import type { FC } from 'react' diff --git a/src/renderer/src/components/Buttons/ActionIconButton.tsx b/src/renderer/src/components/Buttons/ActionIconButton.tsx index 221a5eeb30..ec1da45ab8 100644 --- a/src/renderer/src/components/Buttons/ActionIconButton.tsx +++ b/src/renderer/src/components/Buttons/ActionIconButton.tsx @@ -1,4 +1,5 @@ -import { Button, cn } from '@cherrystudio/ui' +import { Button } from '@cherrystudio/ui' +import { cn } from '@renderer/utils' import React, { memo } from 'react' interface ActionIconButtonProps extends Omit, 'ref'> { diff --git a/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx b/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx index 36bfc559da..05c3f230e1 100644 --- a/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx +++ b/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx @@ -224,6 +224,7 @@ const HtmlArtifactsPopup: React.FC = ({ open, title, ht afterClose={onClose} centered={!isFullscreen} destroyOnHidden + forceRender={isFullscreen} mask={!isFullscreen} maskClosable={false} width={isFullscreen ? '100vw' : '90vw'} diff --git a/src/renderer/src/components/ContextMenu/index.tsx b/src/renderer/src/components/ContextMenu/index.tsx index 719cd14133..8db88955bd 100644 --- a/src/renderer/src/components/ContextMenu/index.tsx +++ b/src/renderer/src/components/ContextMenu/index.tsx @@ -6,6 +6,61 @@ interface ContextMenuProps { children: React.ReactNode } +/** + * Extract text content from selection, filtering out line numbers in code viewers. + * Preserves all content including plain text and code blocks, only removing line numbers. + * This ensures right-click copy in code blocks doesn't include line numbers while preserving indentation. + */ +function extractSelectedText(selection: Selection): string { + // Validate selection + if (selection.rangeCount === 0 || selection.isCollapsed) { + return '' + } + + const range = selection.getRangeAt(0) + const fragment = range.cloneContents() + + // Check if the selection contains code viewer elements + const hasLineNumbers = fragment.querySelectorAll('.line-number').length > 0 + + // If no line numbers, return the original text (preserves formatting) + if (!hasLineNumbers) { + return selection.toString() + } + + // Remove all line number elements + fragment.querySelectorAll('.line-number').forEach((el) => el.remove()) + + // Handle all content using optimized TreeWalker with precise node filtering + // This approach handles mixed content correctly while improving performance + const walker = document.createTreeWalker(fragment, NodeFilter.SHOW_TEXT | NodeFilter.SHOW_ELEMENT, null) + + let result = '' + let node = walker.nextNode() + + while (node) { + if (node.nodeType === Node.TEXT_NODE) { + // Preserve text content including whitespace + result += node.textContent + } else if (node.nodeType === Node.ELEMENT_NODE) { + const element = node as Element + + // Insert a newline on entering block elements and code lines to preserve structure + if (['H1', 'H2', 'H3', 'H4', 'H5', 'H6'].includes(element.tagName)) { + result += '\n' + } else if (element.classList.contains('line')) { + // A newline before each code line keeps lines separated + result += '\n' + } + } + + node = walker.nextNode() + } + + // Trim surrounding whitespace; the interior newlines added above preserve code structure + return result.trim() +} + // FIXME: Why is this named like a generic component when it is not customizable at all?
const ContextMenu: React.FC = ({ children }) => { const { t } = useTranslation() @@ -45,8 +100,12 @@ const ContextMenu: React.FC = ({ children }) => { const onOpenChange = (open: boolean) => { if (open) { - const selectedText = window.getSelection()?.toString() - setSelectedText(selectedText) + const selection = window.getSelection() + if (!selection || selection.rangeCount === 0 || selection.isCollapsed) { + setSelectedText(undefined) + return + } + setSelectedText(extractSelectedText(selection) || undefined) } } diff --git a/src/renderer/src/components/EmojiPicker/index.tsx b/src/renderer/src/components/EmojiPicker/index.tsx index 16561ea9be..245b997704 100644 --- a/src/renderer/src/components/EmojiPicker/index.tsx +++ b/src/renderer/src/components/EmojiPicker/index.tsx @@ -45,6 +45,7 @@ const i18nMap: Record = { 'fr-FR': fr, 'ja-JP': ja, 'pt-PT': pt_PT, + 'ro-RO': en, // No Romanian available, fallback to English 'ru-RU': ru_RU } @@ -60,6 +61,7 @@ const dataSourceMap: Record = { 'fr-FR': dataFR, 'ja-JP': dataJA, 'pt-PT': dataPT, + 'ro-RO': dataEN, // No Romanian CLDR available, fallback to English 'ru-RU': dataRU } @@ -75,6 +77,7 @@ const localeMap: Record = { 'fr-FR': 'fr', 'ja-JP': 'ja', 'pt-PT': 'pt', + 'ro-RO': 'en', 'ru-RU': 'ru' } diff --git a/src/renderer/src/components/Icons/SVGIcon.tsx b/src/renderer/src/components/Icons/SVGIcon.tsx index f866ad21d5..7b2d3549f7 100644 --- a/src/renderer/src/components/Icons/SVGIcon.tsx +++ b/src/renderer/src/components/Icons/SVGIcon.tsx @@ -113,6 +113,18 @@ export function MdiLightbulbOn(props: SVGProps) { ) } +export function MdiLightbulbQuestion(props: SVGProps) { + // {/* Icon from Material Design Icons by Pictogrammers - https://github.com/Templarian/MaterialDesign/blob/master/LICENSE */} + return ( + + + + ) +} + export function BingLogo(props: SVGProps) { return ( ) { ) } +export function McpLogo(props: SVGProps) { + return ( + + ModelContextProtocol + + + + ) +} + export function PoeLogo(props: SVGProps) { return ( void -} - -type ConnectionPhase = 'initializing' | 'waiting_qr_scan' | 'connecting' | 'connected' | 'disconnected' | 'error' -type TransferPhase = 'idle' | 'preparing' | 'sending' | 'completed' | 'error' - -const LoadingQRCode: React.FC = () => { - const { t } = useTranslation() - return ( -
- - - {t('settings.data.export_to_phone.lan.generating_qr')} - -
- ) -} - -const ScanQRCode: React.FC<{ qrCodeValue: string }> = ({ qrCodeValue }) => { - const { t } = useTranslation() - return ( -
- - - {t('settings.data.export_to_phone.lan.scan_qr')} - -
- ) -} - -const ConnectingAnimation: React.FC = () => { - const { t } = useTranslation() - return ( -
-
- - - {t('settings.data.export_to_phone.lan.status.connecting')} - -
-
- ) -} - -const ConnectedDisplay: React.FC = () => { - const { t } = useTranslation() - return ( -
-
- 📱 - - {t('settings.data.export_to_phone.lan.connected')} - -
-
- ) -} - -const ErrorQRCode: React.FC<{ error: string | null }> = ({ error }) => { - const { t } = useTranslation() - return ( -
- ⚠️ - - {t('settings.data.export_to_phone.lan.connection_failed')} - - {error && {error}} -
- ) -} - -const PopupContainer: React.FC = ({ resolve }) => { - const [isOpen, setIsOpen] = useState(true) - const [connectionPhase, setConnectionPhase] = useState('initializing') - const [transferPhase, setTransferPhase] = useState('idle') - const [qrCodeValue, setQrCodeValue] = useState('') - const [selectedFolderPath, setSelectedFolderPath] = useState(null) - const [sendProgress, setSendProgress] = useState(0) - const [error, setError] = useState(null) - const [autoCloseCountdown, setAutoCloseCountdown] = useState(null) - - const { t } = useTranslation() - - // 派生状态 - const isConnected = connectionPhase === 'connected' - const canSend = isConnected && selectedFolderPath && transferPhase === 'idle' - const isSending = transferPhase === 'preparing' || transferPhase === 'sending' - - // 状态文本映射 - const connectionStatusText = useMemo(() => { - const statusMap = { - initializing: t('settings.data.export_to_phone.lan.status.initializing'), - waiting_qr_scan: t('settings.data.export_to_phone.lan.status.waiting_qr_scan'), - connecting: t('settings.data.export_to_phone.lan.status.connecting'), - connected: t('settings.data.export_to_phone.lan.status.connected'), - disconnected: t('settings.data.export_to_phone.lan.status.disconnected'), - error: t('settings.data.export_to_phone.lan.status.error') - } - return statusMap[connectionPhase] - }, [connectionPhase, t]) - - const transferStatusText = useMemo(() => { - const statusMap = { - idle: '', - preparing: t('settings.data.export_to_phone.lan.status.preparing'), - sending: t('settings.data.export_to_phone.lan.status.sending'), - completed: t('settings.data.export_to_phone.lan.status.completed'), - error: t('settings.data.export_to_phone.lan.status.error') - } - return statusMap[transferPhase] - }, [transferPhase, t]) - - // 状态样式映射 - const connectionStatusStyles = useMemo(() => { - const styleMap = { - initializing: { - bg: 'var(--color-background-mute)', - border: 'var(--color-border-mute)' - }, - waiting_qr_scan: { - bg: 'var(--color-primary-mute)', - border: 'var(--color-primary-soft)' - }, - connecting: { bg: 'var(--color-status-warning)', border: 'var(--color-status-warning)' }, - connected: { - bg: 'var(--color-status-success)', - border: 'var(--color-status-success)' - }, - disconnected: { bg: 'var(--color-error)', border: 'var(--color-error)' }, - error: { bg: 'var(--color-error)', border: 'var(--color-error)' } - } - return styleMap[connectionPhase] - }, [connectionPhase]) - - const initWebSocket = useCallback(async () => { - try { - setConnectionPhase('initializing') - await window.api.webSocket.start() - const { port, ip } = await window.api.webSocket.status() - - if (ip && port) { - const candidatesData = await window.api.webSocket.getAllCandidates() - - const optimizeConnectionInfo = () => { - const ipToNumber = (ip: string) => { - return ip.split('.').reduce((acc, octet) => (acc << 8) + parseInt(octet), 0) - } - - const compressedData = [ - 'CSA', - ipToNumber(ip), - candidatesData.map((candidate: WebSocketCandidatesResponse) => ipToNumber(candidate.host)), - port, // 端口号 - Date.now() % 86400000 - ] - - return compressedData - } - - const compressedData = optimizeConnectionInfo() - const qrCodeValue = JSON.stringify(compressedData) - setQrCodeValue(qrCodeValue) - setConnectionPhase('waiting_qr_scan') - } else { - setError(t('settings.data.export_to_phone.lan.error.no_ip')) - setConnectionPhase('error') - } - } catch (error) { - setError( - `${t('settings.data.export_to_phone.lan.error.init_failed')}: ${error instanceof Error ? 
error.message : ''}` - ) - setConnectionPhase('error') - logger.error('Failed to initialize WebSocket:', error as Error) - } - }, [t]) - - const handleClientConnected = useCallback((_event: any, data: { connected: boolean }) => { - logger.info(`Client connection status: ${data.connected ? 'connected' : 'disconnected'}`) - if (data.connected) { - setConnectionPhase('connected') - setError(null) - } else { - setConnectionPhase('disconnected') - } - }, []) - - const handleMessageReceived = useCallback((_event: any, data: any) => { - logger.info(`Received message from mobile: ${JSON.stringify(data)}`) - }, []) - - const handleSendProgress = useCallback( - (_event: any, data: { progress: number }) => { - const progress = data.progress - setSendProgress(progress) - - if (transferPhase === 'preparing' && progress > 0) { - setTransferPhase('sending') - } - - if (progress >= 100) { - setTransferPhase('completed') - // 启动 3 秒倒计时自动关闭 - setAutoCloseCountdown(3) - } - }, - [transferPhase] - ) - - const handleSelectZip = useCallback(async () => { - const result = await window.api.file.select() - if (result) { - setSelectedFolderPath(result[0].path) - } - }, []) - - const handleSendZip = useCallback(async () => { - if (!selectedFolderPath) { - setError(t('settings.data.export_to_phone.lan.error.no_file')) - return - } - - setTransferPhase('preparing') - setError(null) - setSendProgress(0) - - try { - logger.info(`Starting file transfer: ${selectedFolderPath}`) - await window.api.webSocket.sendFile(selectedFolderPath) - } catch (error) { - setError( - `${t('settings.data.export_to_phone.lan.error.send_failed')}: ${error instanceof Error ? error.message : ''}` - ) - setTransferPhase('error') - logger.error('Failed to send file:', error as Error) - } - }, [selectedFolderPath, t]) - - // 尝试关闭弹窗 - 如果正在传输则显示确认 - const handleCancel = useCallback(() => { - if (isSending) { - window.modal.confirm({ - title: t('settings.data.export_to_phone.lan.confirm_close_title'), - content: t('settings.data.export_to_phone.lan.confirm_close_message'), - centered: true, - okButtonProps: { - danger: true - }, - okText: t('settings.data.export_to_phone.lan.force_close'), - onOk: () => setIsOpen(false) - }) - } else { - setIsOpen(false) - } - }, [isSending, t]) - - // 清理并关闭 - const handleClose = useCallback(async () => { - try { - // 主动断开 WebSocket 连接 - if (isConnected || connectionPhase !== 'disconnected') { - logger.info('Closing popup, stopping WebSocket') - await window.api.webSocket.stop() - } - } catch (error) { - logger.error('Failed to stop WebSocket on close:', error as Error) - } - resolve({}) - }, [resolve, isConnected, connectionPhase]) - - useEffect(() => { - initWebSocket() - - const removeClientConnectedListener = window.electron.ipcRenderer.on( - 'websocket-client-connected', - handleClientConnected - ) - const removeMessageReceivedListener = window.electron.ipcRenderer.on( - 'websocket-message-received', - handleMessageReceived - ) - const removeSendProgressListener = window.electron.ipcRenderer.on('file-send-progress', handleSendProgress) - - return () => { - removeClientConnectedListener() - removeMessageReceivedListener() - removeSendProgressListener() - window.api.webSocket.stop() - } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) - - // 自动关闭倒计时 - useEffect(() => { - if (autoCloseCountdown === null) return - - if (autoCloseCountdown <= 0) { - logger.debug('Auto-closing popup after transfer completion') - setIsOpen(false) - return - } - - const timer = setTimeout(() => { - 
setAutoCloseCountdown(autoCloseCountdown - 1) - }, 1000) - - return () => clearTimeout(timer) - }, [autoCloseCountdown]) - - // 状态指示器组件 - const StatusIndicator = useCallback( - () => ( -
- {connectionStatusText} -
- ), - [connectionStatusStyles, connectionStatusText] - ) - - // 二维码显示组件 - 使用显式条件渲染以避免类型不匹配 - const QRCodeDisplay = useCallback(() => { - switch (connectionPhase) { - case 'waiting_qr_scan': - case 'disconnected': - return - case 'initializing': - return - case 'connecting': - return - case 'connected': - return - case 'error': - return - default: - return null - } - }, [connectionPhase, qrCodeValue, error]) - - // 传输进度组件 - const TransferProgress = useCallback(() => { - if (!isSending && transferPhase !== 'completed') return null - - return ( -
-
-
- - {t('settings.data.export_to_phone.lan.transfer_progress')} - - - {transferPhase === 'completed' ? '✅ ' + t('common.completed') : `${Math.round(sendProgress)}%`} - -
- - -
-
- ) - }, [isSending, transferPhase, sendProgress, t]) - - const AutoCloseCountdown = useCallback(() => { - if (transferPhase !== 'completed' || autoCloseCountdown === null || autoCloseCountdown <= 0) return null - - return ( -
- {t('settings.data.export_to_phone.lan.auto_close_tip', { seconds: autoCloseCountdown })} -
- ) - }, [transferPhase, autoCloseCountdown, t]) - - // 错误显示组件 - const ErrorDisplay = useCallback(() => { - if (!error || transferPhase !== 'error') return null - - return ( -
- ❌ {error} -
- ) - }, [error, transferPhase]) - - return ( - - - - - - - - - - - - -
- - -
-
- - - {selectedFolderPath || t('settings.data.export_to_phone.lan.noZipSelected')} - - - - - -
- ) -} - -const TopViewKey = 'ExportToPhoneLanPopup' - -export default class ExportToPhoneLanPopup { - static topviewId = 0 - static hide() { - TopView.hide(TopViewKey) - } - static show() { - return new Promise((resolve) => { - TopView.show( - { - resolve(v) - TopView.hide(TopViewKey) - }} - />, - TopViewKey - ) - }) - } -} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/LanDeviceCard.tsx b/src/renderer/src/components/Popups/LanTransferPopup/LanDeviceCard.tsx new file mode 100644 index 0000000000..db16112e04 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/LanDeviceCard.tsx @@ -0,0 +1,97 @@ +import { cn } from '@renderer/utils' +import type { FC, KeyboardEventHandler } from 'react' +import { useTranslation } from 'react-i18next' + +import { ProgressIndicator } from './ProgressIndicator' +import type { LanDeviceCardProps } from './types' + +export const LanDeviceCard: FC = ({ + service, + transferState, + isConnected, + handshakeInProgress, + isDisabled, + onSendFile +}) => { + const { t } = useTranslation() + + // Device info + const deviceName = service.txt?.modelName || t('common.unknown') + const platform = service.txt?.platform + const appVersion = service.txt?.appVersion + const platformInfo = [platform, appVersion].filter(Boolean).join(' ') + const displayTitle = platformInfo ? `${deviceName} (${platformInfo})` : deviceName + + // Address info + const primaryAddress = service.addresses?.[0] + const addressesWithPort = primaryAddress ? (service.port ? `${primaryAddress}:${service.port}` : primaryAddress) : '' + + // Progress visibility + const shouldShowProgress = + transferState && ['selecting', 'transferring', 'completed', 'failed'].includes(transferState.status) + + // Status text + const statusText = handshakeInProgress + ? t('settings.data.export_to_phone.lan.handshake.in_progress') + : isConnected + ? t('settings.data.export_to_phone.lan.connected') + : t('settings.data.export_to_phone.lan.send_file') + + // Event handlers + const handleClick = () => { + if (isDisabled) return + onSendFile(service.id) + } + + const handleKeyDown: KeyboardEventHandler = (event) => { + if (event.key === 'Enter' || event.key === ' ') { + event.preventDefault() + handleClick() + } + } + + return ( +
+ {/* Header */} +
+
+
{displayTitle}
+ {statusText} +
+
+ + {/* Meta Row - IP Address */} +
+ + {t('settings.data.export_to_phone.lan.ip_addresses')} + + {addressesWithPort || t('common.unknown')} +
+ + {/* Footer with Progress */} +
+ {shouldShowProgress && transferState && ( + + )} +
+
+ ) +} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/ProgressIndicator.tsx b/src/renderer/src/components/Popups/LanTransferPopup/ProgressIndicator.tsx new file mode 100644 index 0000000000..b9707b4485 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/ProgressIndicator.tsx @@ -0,0 +1,55 @@ +import { cn } from '@renderer/utils' +import type { FC } from 'react' +import { useTranslation } from 'react-i18next' + +import type { ProgressIndicatorProps } from './types' + +export const ProgressIndicator: FC = ({ transferState, handshakeInProgress }) => { + const { t } = useTranslation() + + const progressPercent = Math.min(100, Math.max(0, transferState.progress ?? 0)) + + const progressLabel = (() => { + if (transferState.status === 'failed') { + return transferState.error || t('common.unknown_error') + } + if (transferState.status === 'selecting') { + return handshakeInProgress + ? t('settings.data.export_to_phone.lan.handshake.in_progress') + : t('settings.data.export_to_phone.lan.status.preparing') + } + return `${Math.round(progressPercent)}%` + })() + + const isFailed = transferState.status === 'failed' + const isCompleted = transferState.status === 'completed' + + return ( +
+ {/* Label Row */} +
+ {transferState.fileName} + {progressLabel} +
+ + {/* Progress Track */} +
+
+
+
+ ) +} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/hook.ts b/src/renderer/src/components/Popups/LanTransferPopup/hook.ts new file mode 100644 index 0000000000..6d2ea77527 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/hook.ts @@ -0,0 +1,397 @@ +import { loggerService } from '@logger' +import { getBackupData } from '@renderer/services/BackupService' +import type { LocalTransferPeer } from '@shared/config/types' +import { useCallback, useEffect, useMemo, useReducer, useRef } from 'react' +import { useTranslation } from 'react-i18next' + +import type { LanPeerTransferState, LanTransferAction, LanTransferReducerState } from './types' + +const logger = loggerService.withContext('useLanTransfer') + +// ========================================== +// Initial State +// ========================================== + +export const initialState: LanTransferReducerState = { + open: true, + lanState: null, + lanHandshakePeerId: null, + lastHandshakeResult: null, + fileTransferState: {}, + tempBackupPath: null +} + +// ========================================== +// Reducer +// ========================================== + +export function lanTransferReducer(state: LanTransferReducerState, action: LanTransferAction): LanTransferReducerState { + switch (action.type) { + case 'SET_OPEN': + return { ...state, open: action.payload } + + case 'SET_LAN_STATE': + return { ...state, lanState: action.payload } + + case 'SET_HANDSHAKE_PEER_ID': + return { ...state, lanHandshakePeerId: action.payload } + + case 'SET_HANDSHAKE_RESULT': + return { ...state, lastHandshakeResult: action.payload } + + case 'SET_TEMP_BACKUP_PATH': + return { ...state, tempBackupPath: action.payload } + + case 'UPDATE_TRANSFER_STATE': { + const { peerId, state: transferState } = action.payload + return { + ...state, + fileTransferState: { + ...state.fileTransferState, + [peerId]: { + ...(state.fileTransferState[peerId] ?? { progress: 0, status: 'idle' as const }), + ...transferState + } + } + } + } + + case 'SET_TRANSFER_STATE': { + const { peerId, state: transferState } = action.payload + return { + ...state, + fileTransferState: { + ...state.fileTransferState, + [peerId]: transferState + } + } + } + + case 'CLEANUP_STALE_PEERS': { + const activeIds = action.payload + const newFileTransferState: Record = {} + for (const id of Object.keys(state.fileTransferState)) { + if (activeIds.has(id)) { + newFileTransferState[id] = state.fileTransferState[id] + } + } + return { + ...state, + fileTransferState: newFileTransferState, + lastHandshakeResult: + state.lastHandshakeResult && activeIds.has(state.lastHandshakeResult.peerId) + ? state.lastHandshakeResult + : null, + lanHandshakePeerId: + state.lanHandshakePeerId && activeIds.has(state.lanHandshakePeerId) ? 
state.lanHandshakePeerId : null + } + } + + case 'RESET_CONNECTION_STATE': + return { + ...state, + fileTransferState: {}, + lastHandshakeResult: null, + lanHandshakePeerId: null, + tempBackupPath: null + } + + default: + return state + } +} + +// ========================================== +// Hook Return Type +// ========================================== + +export interface UseLanTransferReturn { + // State + state: LanTransferReducerState + + // Derived values + lanDevices: LocalTransferPeer[] + isAnyTransferring: boolean + lastError: string | undefined + + // Actions + handleSendFile: (peerId: string) => Promise + handleModalCancel: () => void + getTransferState: (peerId: string) => LanPeerTransferState | undefined + isConnected: (peerId: string) => boolean + isHandshakeInProgress: (peerId: string) => boolean + + // Dispatch (for advanced use) + dispatch: React.Dispatch +} + +// ========================================== +// Hook +// ========================================== + +export function useLanTransfer(): UseLanTransferReturn { + const { t } = useTranslation() + const [state, dispatch] = useReducer(lanTransferReducer, initialState) + const isSendingRef = useRef(false) + + // ========================================== + // Derived Values + // ========================================== + + const lanDevices = useMemo(() => state.lanState?.services ?? [], [state.lanState]) + + const isAnyTransferring = useMemo( + () => Object.values(state.fileTransferState).some((s) => s.status === 'transferring' || s.status === 'selecting'), + [state.fileTransferState] + ) + + const lastError = state.lanState?.lastError + + // ========================================== + // LAN State Sync + // ========================================== + + const syncLanState = useCallback(async () => { + if (!window.api?.localTransfer) { + logger.warn('Local transfer bridge is unavailable') + return + } + try { + const nextState = await window.api.localTransfer.getState() + dispatch({ type: 'SET_LAN_STATE', payload: nextState }) + } catch (error) { + logger.error('Failed to sync LAN state', error as Error) + } + }, []) + + // ========================================== + // Send File Handler + // ========================================== + + const handleSendFile = useCallback( + async (peerId: string) => { + if (!window.api?.localTransfer || isSendingRef.current) { + return + } + isSendingRef.current = true + + dispatch({ + type: 'SET_TRANSFER_STATE', + payload: { peerId, state: { progress: 0, status: 'selecting' } } + }) + + let backupPath: string | null = null + + try { + // Step 0: Ensure handshake (connect if needed) + if (!state.lastHandshakeResult?.ack.accepted || state.lastHandshakeResult.peerId !== peerId) { + dispatch({ type: 'SET_HANDSHAKE_PEER_ID', payload: peerId }) + try { + const ack = await window.api.localTransfer.connect({ peerId }) + dispatch({ + type: 'SET_HANDSHAKE_RESULT', + payload: { peerId, ack, timestamp: Date.now() } + }) + if (!ack.accepted) { + throw new Error(ack.message || t('settings.data.export_to_phone.lan.connection_failed')) + } + } finally { + dispatch({ type: 'SET_HANDSHAKE_PEER_ID', payload: null }) + } + } + + // Step 1: Create temporary backup + logger.info('Creating temporary backup for LAN transfer...') + const backupData = await getBackupData() + backupPath = await window.api.backup.createLanTransferBackup(backupData) + dispatch({ type: 'SET_TEMP_BACKUP_PATH', payload: backupPath }) + + // Extract filename from path + const fileName = backupPath.split(/[/\\]/).pop() || 
'backup.zip' + + // Step 2: Set transferring state + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { peerId, state: { fileName, progress: 0, status: 'transferring' } } + }) + + // Step 3: Send file + logger.info(`Sending backup file: ${backupPath}`) + const result = await window.api.localTransfer.sendFile(backupPath) + + if (result.success) { + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { peerId, state: { progress: 100, status: 'completed' } } + }) + } else { + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { peerId, state: { status: 'failed', error: result.error } } + }) + } + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { peerId, state: { status: 'failed', error: message } } + }) + logger.error('Failed to send file', error as Error) + } finally { + // Step 4: Clean up temp file + if (backupPath) { + try { + await window.api.backup.deleteTempBackup(backupPath) + logger.info('Cleaned up temporary backup file') + } catch (cleanupError) { + logger.warn('Failed to clean up temp backup', cleanupError as Error) + } + dispatch({ type: 'SET_TEMP_BACKUP_PATH', payload: null }) + } + isSendingRef.current = false + } + }, + [state.lastHandshakeResult, t] + ) + + // ========================================== + // Teardown + // ========================================== + + // Use ref to track temp backup path for cleanup without causing effect re-runs + const tempBackupPathRef = useRef(null) + tempBackupPathRef.current = state.tempBackupPath + + const teardownLan = useCallback(async () => { + if (!window.api?.localTransfer) { + return + } + try { + await window.api.localTransfer.cancelTransfer?.() + } catch (error) { + logger.warn('Failed to cancel LAN transfer on close', error as Error) + } + try { + await window.api.localTransfer.disconnect?.() + } catch (error) { + logger.warn('Failed to disconnect LAN on close', error as Error) + } + // Clean up temp backup if exists (use ref to get current value) + if (tempBackupPathRef.current) { + try { + await window.api.backup.deleteTempBackup(tempBackupPathRef.current) + } catch (error) { + logger.warn('Failed to cleanup temp backup on close', error as Error) + } + } + dispatch({ type: 'RESET_CONNECTION_STATE' }) + }, []) // No dependencies - uses ref for current value + + const handleModalCancel = useCallback(() => { + void teardownLan() + dispatch({ type: 'SET_OPEN', payload: false }) + }, [teardownLan]) + + // ========================================== + // Effects + // ========================================== + + // Initial sync and service listener + useEffect(() => { + if (!window.api?.localTransfer) { + return + } + syncLanState() + const removeListener = window.api.localTransfer.onServicesUpdated((lanState) => { + dispatch({ type: 'SET_LAN_STATE', payload: lanState }) + }) + return () => { + removeListener?.() + } + }, [syncLanState]) + + // Client events listener (progress, completion) + useEffect(() => { + if (!window.api?.localTransfer) { + return + } + const removeListener = window.api.localTransfer.onClientEvent((event) => { + const key = event.peerId ?? 
'global' + + if (event.type === 'file_transfer_progress') { + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { + peerId: key, + state: { + transferId: event.transferId, + fileName: event.fileName, + progress: event.progress, + speed: event.speed, + status: 'transferring' + } + } + }) + } else if (event.type === 'file_transfer_complete') { + dispatch({ + type: 'UPDATE_TRANSFER_STATE', + payload: { + peerId: key, + state: { + progress: event.success ? 100 : undefined, + status: event.success ? 'completed' : 'failed', + error: event.error + } + } + }) + } + }) + return () => { + removeListener?.() + } + }, []) + + // Cleanup stale peers when services change + useEffect(() => { + const activeIds = new Set(lanDevices.map((s) => s.id)) + dispatch({ type: 'CLEANUP_STALE_PEERS', payload: activeIds }) + }, [lanDevices]) + + // Cleanup on unmount only (teardownLan is stable with no deps) + useEffect(() => { + return () => { + void teardownLan() + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) + + // ========================================== + // Helper Functions + // ========================================== + + const getTransferState = useCallback((peerId: string) => state.fileTransferState[peerId], [state.fileTransferState]) + + const isConnected = useCallback( + (peerId: string) => + state.lastHandshakeResult?.peerId === peerId && state.lastHandshakeResult?.ack.accepted === true, + [state.lastHandshakeResult] + ) + + const isHandshakeInProgress = useCallback( + (peerId: string) => state.lanHandshakePeerId === peerId, + [state.lanHandshakePeerId] + ) + + return { + state, + lanDevices, + isAnyTransferring, + lastError, + handleSendFile, + handleModalCancel, + getTransferState, + isConnected, + isHandshakeInProgress, + dispatch + } +} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/index.tsx b/src/renderer/src/components/Popups/LanTransferPopup/index.tsx new file mode 100644 index 0000000000..66455f12a1 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/index.tsx @@ -0,0 +1,37 @@ +import { TopView } from '@renderer/components/TopView' + +import { getHideCallback, PopupContainer } from './popup' +import type { PopupResolveData } from './types' + +// Re-export types for external use +export type { LanPeerTransferState } from './types' + +const TopViewKey = 'LanTransferPopup' + +export default class LanTransferPopup { + static topviewId = 0 + + static hide() { + // Try to use the registered callback for proper cleanup, fallback to TopView.hide + const callback = getHideCallback() + if (callback) { + callback() + } else { + TopView.hide(TopViewKey) + } + } + + static show() { + return new Promise((resolve) => { + TopView.show( + { + resolve(v) + TopView.hide(TopViewKey) + }} + />, + TopViewKey + ) + }) + } +} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/popup.tsx b/src/renderer/src/components/Popups/LanTransferPopup/popup.tsx new file mode 100644 index 0000000000..34c53a6ad6 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/popup.tsx @@ -0,0 +1,88 @@ +import { Modal } from 'antd' +import { TriangleAlert } from 'lucide-react' +import type { FC } from 'react' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' + +import { useLanTransfer } from './hook' +import { LanDeviceCard } from './LanDeviceCard' +import type { PopupContainerProps } from './types' + +// Module-level callback for external hide access +let hideCallback: (() => void) | null = null +export const 
setHideCallback = (cb: () => void) => { + hideCallback = cb +} +export const getHideCallback = () => hideCallback + +export const PopupContainer: FC = ({ resolve }) => { + const { t } = useTranslation() + + const { + state, + lanDevices, + isAnyTransferring, + lastError, + handleSendFile, + handleModalCancel, + getTransferState, + isConnected, + isHandshakeInProgress + } = useLanTransfer() + + const contentTitle = useMemo(() => t('settings.data.export_to_phone.lan.title'), [t]) + + const onClose = () => resolve({}) + + // Register hide callback for external access + setHideCallback(handleModalCancel) + + return ( + +
+ {/* Error Display */} + {lastError &&
{lastError}
} + + {/* Device List */} +
+ {lanDevices.length === 0 ? ( + // Warning when no devices +
+ + + {t('settings.data.export_to_phone.lan.no_connection_warning')} + +
+ ) : ( + // Device cards + lanDevices.map((service) => { + const transferState = getTransferState(service.id) + const connected = isConnected(service.id) + const handshakeInProgress = isHandshakeInProgress(service.id) + const isCardDisabled = isAnyTransferring || handshakeInProgress + + return ( + + ) + }) + )} +
+
+
+ ) +} diff --git a/src/renderer/src/components/Popups/LanTransferPopup/types.ts b/src/renderer/src/components/Popups/LanTransferPopup/types.ts new file mode 100644 index 0000000000..644541bc27 --- /dev/null +++ b/src/renderer/src/components/Popups/LanTransferPopup/types.ts @@ -0,0 +1,84 @@ +import type { LanHandshakeAckMessage, LocalTransferPeer, LocalTransferState } from '@shared/config/types' + +// ========================================== +// Transfer Status +// ========================================== + +export type TransferStatus = 'idle' | 'selecting' | 'transferring' | 'completed' | 'failed' + +// ========================================== +// Per-Peer Transfer State +// ========================================== + +export interface LanPeerTransferState { + transferId?: string + fileName?: string + progress: number + speed?: number + status: TransferStatus + error?: string +} + +// ========================================== +// Handshake Result +// ========================================== + +export type HandshakeResult = { + peerId: string + ack: LanHandshakeAckMessage + timestamp: number +} | null + +// ========================================== +// Reducer State +// ========================================== + +export interface LanTransferReducerState { + open: boolean + lanState: LocalTransferState | null + lanHandshakePeerId: string | null + lastHandshakeResult: HandshakeResult + fileTransferState: Record + tempBackupPath: string | null +} + +// ========================================== +// Reducer Actions +// ========================================== + +export type LanTransferAction = + | { type: 'SET_OPEN'; payload: boolean } + | { type: 'SET_LAN_STATE'; payload: LocalTransferState | null } + | { type: 'SET_HANDSHAKE_PEER_ID'; payload: string | null } + | { type: 'SET_HANDSHAKE_RESULT'; payload: HandshakeResult } + | { type: 'SET_TEMP_BACKUP_PATH'; payload: string | null } + | { type: 'UPDATE_TRANSFER_STATE'; payload: { peerId: string; state: Partial } } + | { type: 'SET_TRANSFER_STATE'; payload: { peerId: string; state: LanPeerTransferState } } + | { type: 'CLEANUP_STALE_PEERS'; payload: Set } + | { type: 'RESET_CONNECTION_STATE' } + +// ========================================== +// Component Props +// ========================================== + +export interface LanDeviceCardProps { + service: LocalTransferPeer + transferState?: LanPeerTransferState + isConnected: boolean + handshakeInProgress: boolean + isDisabled: boolean + onSendFile: (peerId: string) => void +} + +export interface ProgressIndicatorProps { + transferState: LanPeerTransferState + handshakeInProgress: boolean +} + +export interface PopupResolveData { + // Empty for now, can be extended +} + +export interface PopupContainerProps { + resolve: (data: PopupResolveData) => void +} diff --git a/src/renderer/src/components/Popups/agent/AgentModal.tsx b/src/renderer/src/components/Popups/agent/AgentModal.tsx index 1e8b2a6fa7..fe77c9bf03 100644 --- a/src/renderer/src/components/Popups/agent/AgentModal.tsx +++ b/src/renderer/src/components/Popups/agent/AgentModal.tsx @@ -3,6 +3,7 @@ import { loggerService } from '@logger' import { ErrorBoundary } from '@renderer/components/ErrorBoundary' import { TopView } from '@renderer/components/TopView' import { permissionModeCards } from '@renderer/config/agent' +import { isWin } from '@renderer/config/constant' import { useAgents } from '@renderer/hooks/agents/useAgents' import { useUpdateAgent } from '@renderer/hooks/agents/useUpdateAgent' import 
SelectAgentBaseModelButton from '@renderer/pages/home/components/SelectAgentBaseModelButton' @@ -16,7 +17,8 @@ import type { UpdateAgentForm } from '@renderer/types' import { AgentConfigurationSchema, isAgentType } from '@renderer/types' -import { Alert, Button, Input, Modal, Select } from 'antd' +import type { GitBashPathInfo } from '@shared/config/constant' +import { Button, Input, Modal, Select } from 'antd' import { AlertTriangleIcon } from 'lucide-react' import type { ChangeEvent, FormEvent } from 'react' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' @@ -59,8 +61,7 @@ const PopupContainer: React.FC = ({ agent, afterSubmit, resolve }) => { const isEditing = (agent?: AgentWithTools) => agent !== undefined const [form, setForm] = useState(() => buildAgentForm(agent)) - const [hasGitBash, setHasGitBash] = useState(true) - const [customGitBashPath, setCustomGitBashPath] = useState('') + const [gitBashPathInfo, setGitBashPathInfo] = useState({ path: null, source: null }) useEffect(() => { if (open) { @@ -68,29 +69,15 @@ const PopupContainer: React.FC = ({ agent, afterSubmit, resolve }) => { } }, [agent, open]) - const checkGitBash = useCallback( - async (showToast = false) => { - try { - const [gitBashInstalled, savedPath] = await Promise.all([ - window.api.system.checkGitBash(), - window.api.system.getGitBashPath().catch(() => null) - ]) - setCustomGitBashPath(savedPath ?? '') - setHasGitBash(gitBashInstalled) - if (showToast) { - if (gitBashInstalled) { - window.toast.success(t('agent.gitBash.success', 'Git Bash detected successfully!')) - } else { - window.toast.error(t('agent.gitBash.notFound', 'Git Bash not found. Please install it first.')) - } - } - } catch (error) { - logger.error('Failed to check Git Bash:', error as Error) - setHasGitBash(true) // Default to true on error to avoid false warnings - } - }, - [t] - ) + const checkGitBash = useCallback(async () => { + if (!isWin) return + try { + const pathInfo = await window.api.system.getGitBashPathInfo() + setGitBashPathInfo(pathInfo) + } catch (error) { + logger.error('Failed to check Git Bash:', error as Error) + } + }, []) useEffect(() => { checkGitBash() @@ -119,24 +106,22 @@ const PopupContainer: React.FC = ({ agent, afterSubmit, resolve }) => { return } - setCustomGitBashPath(pickedPath) - await checkGitBash(true) + await checkGitBash() } catch (error) { logger.error('Failed to pick Git Bash path', error as Error) window.toast.error(t('agent.gitBash.pick.failed', 'Failed to set Git Bash path')) } }, [checkGitBash, t]) - const handleClearGitBash = useCallback(async () => { + const handleResetGitBash = useCallback(async () => { try { + // Clear manual setting and re-run auto-discovery await window.api.system.setGitBashPath(null) - setCustomGitBashPath('') - await checkGitBash(true) + await checkGitBash() } catch (error) { - logger.error('Failed to clear Git Bash path', error as Error) - window.toast.error(t('agent.gitBash.pick.failed', 'Failed to set Git Bash path')) + logger.error('Failed to reset Git Bash path', error as Error) } - }, [checkGitBash, t]) + }, [checkGitBash]) const onPermissionModeChange = useCallback((value: PermissionMode) => { setForm((prev) => { @@ -268,6 +253,12 @@ const PopupContainer: React.FC = ({ agent, afterSubmit, resolve }) => { return } + if (isWin && !gitBashPathInfo.path) { + window.toast.error(t('agent.gitBash.error.required', 'Git Bash path is required on Windows')) + loadingRef.current = false + return + } + if (isEditing(agent)) { if (!agent) { loadingRef.current = 
false @@ -327,7 +318,8 @@ const PopupContainer: React.FC = ({ agent, afterSubmit, resolve }) => { t, updateAgent, afterSubmit, - addAgent + addAgent, + gitBashPathInfo.path ] ) @@ -346,66 +338,6 @@ const PopupContainer: React.FC = ({ agent, afterSubmit, resolve }) => { footer={null}> - {!hasGitBash && ( - -
- {t( - 'agent.gitBash.error.description', - 'Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from' - )}{' '} - { - e.preventDefault() - window.api.openWebsite('https://git-scm.com/download/win') - }} - style={{ textDecoration: 'underline' }}> - git-scm.com - -
- - -
- } - type="error" - showIcon - style={{ marginBottom: 16 }} - /> - )} - - {hasGitBash && customGitBashPath && ( - -
- {t('agent.gitBash.customPath', { - defaultValue: 'Using custom path: {{path}}', - path: customGitBashPath - })} -
-
- - -
-
- } - type="success" - showIcon - style={{ marginBottom: 16 }} - /> - )} + {isWin && ( + +
+ + +
+ + + + {gitBashPathInfo.source === 'manual' && ( + + )} + + {gitBashPathInfo.path && gitBashPathInfo.source === 'auto' && ( + {t('agent.gitBash.autoDiscoveredHint', 'Auto-discovered')} + )} +
+ )} +