Mirror of https://github.com/CherryHQ/cherry-studio.git (synced 2025-12-24 02:20:10 +08:00)

Compare commits

No commits in common. "main" and "v1.7.3" have entirely different histories.
.github/workflows/auto-i18n.yml (vendored, 2 lines changed)

@@ -54,7 +54,7 @@ jobs:
           yarn install

       - name: 🏃♀️ Translate
-        run: yarn i18n:sync && yarn i18n:translate
+        run: yarn sync:i18n && yarn auto:i18n

       - name: 🔍 Format
         run: yarn format
.github/workflows/pr-ci.yml (vendored, 2 lines changed)

@@ -58,7 +58,7 @@ jobs:
        run: yarn typecheck

      - name: i18n Check
-       run: yarn i18n:check
+       run: yarn check:i18n

      - name: Test
        run: yarn test
.github/workflows/sync-to-gitcode.yml (vendored, 48 lines changed)

@@ -216,7 +216,6 @@ jobs:
           local filename=$(basename "$file")
           local max_retries=3
           local retry=0
-          local curl_status=0

           echo "Uploading: $filename"

@@ -225,45 +224,34 @@ jobs:

           while [ $retry -lt $max_retries ]; do
             # Get upload URL
-            curl_status=0
             UPLOAD_INFO=$(curl -s --connect-timeout 30 --max-time 60 \
               -H "Authorization: Bearer ${GITCODE_TOKEN}" \
-              "${API_URL}/repos/${GITCODE_OWNER}/${GITCODE_REPO}/releases/${TAG_NAME}/upload_url?file_name=${encoded_filename}") || curl_status=$?
+              "${API_URL}/repos/${GITCODE_OWNER}/${GITCODE_REPO}/releases/${TAG_NAME}/upload_url?file_name=${encoded_filename}")

-            if [ $curl_status -eq 0 ]; then
-              UPLOAD_URL=$(echo "$UPLOAD_INFO" | jq -r '.url // empty')
+            UPLOAD_URL=$(echo "$UPLOAD_INFO" | jq -r '.url // empty')

             if [ -n "$UPLOAD_URL" ]; then
               # Write headers to temp file to avoid shell escaping issues
               echo "$UPLOAD_INFO" | jq -r '.headers | to_entries[] | "header = \"" + .key + ": " + .value + "\""' > /tmp/upload_headers.txt

               # Upload file using PUT with headers from file
-              curl_status=0
               UPLOAD_RESPONSE=$(curl -s -w "\n%{http_code}" -X PUT \
                 -K /tmp/upload_headers.txt \
                 --data-binary "@${file}" \
-                "$UPLOAD_URL") || curl_status=$?
+                "$UPLOAD_URL")

-              if [ $curl_status -eq 0 ]; then
-                HTTP_CODE=$(echo "$UPLOAD_RESPONSE" | tail -n1)
-                RESPONSE_BODY=$(echo "$UPLOAD_RESPONSE" | sed '$d')
+              HTTP_CODE=$(echo "$UPLOAD_RESPONSE" | tail -n1)
+              RESPONSE_BODY=$(echo "$UPLOAD_RESPONSE" | sed '$d')

               if [ "$HTTP_CODE" -ge 200 ] && [ "$HTTP_CODE" -lt 300 ]; then
                 echo " Uploaded: $filename"
                 return 0
-              else
-                echo " Failed (HTTP $HTTP_CODE), retry $((retry + 1))/$max_retries"
-                echo " Response: $RESPONSE_BODY"
-              fi
-              else
-                echo " Upload request failed (curl exit $curl_status), retry $((retry + 1))/$max_retries"
-              fi
               else
-                echo " Failed to get upload URL, retry $((retry + 1))/$max_retries"
-                echo " Response: $UPLOAD_INFO"
+                echo " Failed (HTTP $HTTP_CODE), retry $((retry + 1))/$max_retries"
+                echo " Response: $RESPONSE_BODY"
               fi
             else
-              echo " Failed to get upload URL (curl exit $curl_status), retry $((retry + 1))/$max_retries"
+              echo " Failed to get upload URL, retry $((retry + 1))/$max_retries"
               echo " Response: $UPLOAD_INFO"
             fi

@@ -1,5 +1,5 @@
 diff --git a/dist/index.js b/dist/index.js
-index d004b415c5841a1969705823614f395265ea5a8a..6b1e0dad4610b0424393ecc12e9114723bbe316b 100644
+index 51ce7e423934fb717cb90245cdfcdb3dae6780e6..0f7f7009e2f41a79a8669d38c8a44867bbff5e1f 100644
 --- a/dist/index.js
 +++ b/dist/index.js
 @@ -474,7 +474,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
@@ -12,7 +12,7 @@ index d004b415c5841a1969705823614f395265ea5a8a..6b1e0dad4610b0424393ecc12e911472

 // src/google-generative-ai-options.ts
 diff --git a/dist/index.mjs b/dist/index.mjs
-index 1780dd2391b7f42224a0b8048c723d2f81222c44..1f12ed14399d6902107ce9b435d7d8e6cc61e06b 100644
+index f4b77e35c0cbfece85a3ef0d4f4e67aa6dde6271..8d2fecf8155a226006a0bde72b00b6036d4014b6 100644
 --- a/dist/index.mjs
 +++ b/dist/index.mjs
 @@ -480,7 +480,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
@@ -24,14 +24,3 @@ index 1780dd2391b7f42224a0b8048c723d2f81222c44..1f12ed14399d6902107ce9b435d7d8e6
 }

 // src/google-generative-ai-options.ts
-@@ -1909,8 +1909,7 @@ function createGoogleGenerativeAI(options = {}) {
- }
- var google = createGoogleGenerativeAI();
- export {
-- VERSION,
- createGoogleGenerativeAI,
-- google
-+ google, VERSION
- };
- //# sourceMappingURL=index.mjs.map
-\ No newline at end of file
.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch (vendored, new file, 140 lines)

@@ -0,0 +1,140 @@
+diff --git a/dist/index.js b/dist/index.js
+index 73045a7d38faafdc7f7d2cd79d7ff0e2b031056b..8d948c9ac4ea4b474db9ef3c5491961e7fcf9a07 100644
+--- a/dist/index.js
++++ b/dist/index.js
+@@ -421,6 +421,17 @@ var OpenAICompatibleChatLanguageModel = class {
+ text: reasoning
+ });
+ }
++ if (choice.message.images) {
++ for (const image of choice.message.images) {
++ const match1 = image.image_url.url.match(/^data:([^;]+)/)
++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
++ content.push({
++ type: 'file',
++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
++ data: match2 ? match2[1] : image.image_url.url,
++ });
++ }
++ }
+ if (choice.message.tool_calls != null) {
+ for (const toolCall of choice.message.tool_calls) {
+ content.push({
+@@ -598,6 +609,17 @@ var OpenAICompatibleChatLanguageModel = class {
+ delta: delta.content
+ });
+ }
++ if (delta.images) {
++ for (const image of delta.images) {
++ const match1 = image.image_url.url.match(/^data:([^;]+)/)
++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
++ controller.enqueue({
++ type: 'file',
++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
++ data: match2 ? match2[1] : image.image_url.url,
++ });
++ }
++ }
+ if (delta.tool_calls != null) {
+ for (const toolCallDelta of delta.tool_calls) {
+ const index = toolCallDelta.index;
+@@ -765,6 +787,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({
+ arguments: import_v43.z.string()
+ })
+ })
++ ).nullish(),
++ images: import_v43.z.array(
++ import_v43.z.object({
++ type: import_v43.z.literal('image_url'),
++ image_url: import_v43.z.object({
++ url: import_v43.z.string(),
++ })
++ })
+ ).nullish()
+ }),
+ finish_reason: import_v43.z.string().nullish()
+@@ -795,6 +825,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union(
+ arguments: import_v43.z.string().nullish()
+ })
+ })
++ ).nullish(),
++ images: import_v43.z.array(
++ import_v43.z.object({
++ type: import_v43.z.literal('image_url'),
++ image_url: import_v43.z.object({
++ url: import_v43.z.string(),
++ })
++ })
+ ).nullish()
+ }).nullish(),
+ finish_reason: import_v43.z.string().nullish()
+diff --git a/dist/index.mjs b/dist/index.mjs
+index 1c2b9560bbfbfe10cb01af080aeeed4ff59db29c..2c8ddc4fc9bfc5e7e06cfca105d197a08864c427 100644
+--- a/dist/index.mjs
++++ b/dist/index.mjs
+@@ -405,6 +405,17 @@ var OpenAICompatibleChatLanguageModel = class {
+ text: reasoning
+ });
+ }
++ if (choice.message.images) {
++ for (const image of choice.message.images) {
++ const match1 = image.image_url.url.match(/^data:([^;]+)/)
++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
++ content.push({
++ type: 'file',
++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
++ data: match2 ? match2[1] : image.image_url.url,
++ });
++ }
++ }
+ if (choice.message.tool_calls != null) {
+ for (const toolCall of choice.message.tool_calls) {
+ content.push({
+@@ -582,6 +593,17 @@ var OpenAICompatibleChatLanguageModel = class {
+ delta: delta.content
+ });
+ }
++ if (delta.images) {
++ for (const image of delta.images) {
++ const match1 = image.image_url.url.match(/^data:([^;]+)/)
++ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
++ controller.enqueue({
++ type: 'file',
++ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
++ data: match2 ? match2[1] : image.image_url.url,
++ });
++ }
++ }
+ if (delta.tool_calls != null) {
+ for (const toolCallDelta of delta.tool_calls) {
+ const index = toolCallDelta.index;
+@@ -749,6 +771,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({
+ arguments: z3.string()
+ })
+ })
++ ).nullish(),
++ images: z3.array(
++ z3.object({
++ type: z3.literal('image_url'),
++ image_url: z3.object({
++ url: z3.string(),
++ })
++ })
+ ).nullish()
+ }),
+ finish_reason: z3.string().nullish()
+@@ -779,6 +809,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([
+ arguments: z3.string().nullish()
+ })
+ })
++ ).nullish(),
++ images: z3.array(
++ z3.object({
++ type: z3.literal('image_url'),
++ image_url: z3.object({
++ url: z3.string(),
++ })
++ })
+ ).nullish()
+ }).nullish(),
+ finish_reason: z3.string().nullish()
@@ -1,266 +0,0 @@
-diff --git a/dist/index.d.ts b/dist/index.d.ts
-index 48e2f6263c6ee4c75d7e5c28733e64f6ebe92200..00d0729c4a3cbf9a48e8e1e962c7e2b256b75eba 100644
---- a/dist/index.d.ts
-+++ b/dist/index.d.ts
-@@ -7,6 +7,7 @@ declare const openaiCompatibleProviderOptions: z.ZodObject<{
- user: z.ZodOptional<z.ZodString>;
- reasoningEffort: z.ZodOptional<z.ZodString>;
- textVerbosity: z.ZodOptional<z.ZodString>;
-+ sendReasoning: z.ZodOptional<z.ZodBoolean>;
- }, z.core.$strip>;
- type OpenAICompatibleProviderOptions = z.infer<typeof openaiCompatibleProviderOptions>;
-
-diff --git a/dist/index.js b/dist/index.js
-index da237bb35b7fa8e24b37cd861ee73dfc51cdfc72..b3060fbaf010e30b64df55302807828e5bfe0f9a 100644
---- a/dist/index.js
-+++ b/dist/index.js
-@@ -41,7 +41,7 @@ function getOpenAIMetadata(message) {
- var _a, _b;
- return (_b = (_a = message == null ? void 0 : message.providerOptions) == null ? void 0 : _a.openaiCompatible) != null ? _b : {};
- }
--function convertToOpenAICompatibleChatMessages(prompt) {
-+function convertToOpenAICompatibleChatMessages({prompt, options}) {
- const messages = [];
- for (const { role, content, ...message } of prompt) {
- const metadata = getOpenAIMetadata({ ...message });
-@@ -91,6 +91,7 @@ function convertToOpenAICompatibleChatMessages(prompt) {
- }
- case "assistant": {
- let text = "";
-+ let reasoning_text = "";
- const toolCalls = [];
- for (const part of content) {
- const partMetadata = getOpenAIMetadata(part);
-@@ -99,6 +100,12 @@ function convertToOpenAICompatibleChatMessages(prompt) {
- text += part.text;
- break;
- }
-+ case "reasoning": {
-+ if (options.sendReasoning) {
-+ reasoning_text += part.text;
-+ }
-+ break;
-+ }
- case "tool-call": {
- toolCalls.push({
- id: part.toolCallId,
-@@ -116,6 +123,7 @@ function convertToOpenAICompatibleChatMessages(prompt) {
- messages.push({
- role: "assistant",
- content: text,
-+ reasoning_content: reasoning_text ?? undefined,
- tool_calls: toolCalls.length > 0 ? toolCalls : void 0,
- ...metadata
- });
-@@ -200,7 +208,8 @@ var openaiCompatibleProviderOptions = import_v4.z.object({
- /**
- * Controls the verbosity of the generated text. Defaults to `medium`.
- */
-- textVerbosity: import_v4.z.string().optional()
-+ textVerbosity: import_v4.z.string().optional(),
-+ sendReasoning: import_v4.z.boolean().optional()
- });
-
- // src/openai-compatible-error.ts
-@@ -378,7 +387,7 @@ var OpenAICompatibleChatLanguageModel = class {
- reasoning_effort: compatibleOptions.reasoningEffort,
- verbosity: compatibleOptions.textVerbosity,
- // messages:
-- messages: convertToOpenAICompatibleChatMessages(prompt),
-+ messages: convertToOpenAICompatibleChatMessages({prompt, options: compatibleOptions}),
- // tools:
- tools: openaiTools,
- tool_choice: openaiToolChoice
-@@ -421,6 +430,17 @@ var OpenAICompatibleChatLanguageModel = class {
- text: reasoning
- });
- }
-+ if (choice.message.images) {
-+ for (const image of choice.message.images) {
-+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
-+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
-+ content.push({
-+ type: 'file',
-+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
-+ data: match2 ? match2[1] : image.image_url.url,
-+ });
-+ }
-+ }
- if (choice.message.tool_calls != null) {
- for (const toolCall of choice.message.tool_calls) {
- content.push({
-@@ -598,6 +618,17 @@ var OpenAICompatibleChatLanguageModel = class {
- delta: delta.content
- });
- }
-+ if (delta.images) {
-+ for (const image of delta.images) {
-+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
-+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
-+ controller.enqueue({
-+ type: 'file',
-+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
-+ data: match2 ? match2[1] : image.image_url.url,
-+ });
-+ }
-+ }
- if (delta.tool_calls != null) {
- for (const toolCallDelta of delta.tool_calls) {
- const index = toolCallDelta.index;
-@@ -765,6 +796,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({
- arguments: import_v43.z.string()
- })
- })
-+ ).nullish(),
-+ images: import_v43.z.array(
-+ import_v43.z.object({
-+ type: import_v43.z.literal('image_url'),
-+ image_url: import_v43.z.object({
-+ url: import_v43.z.string(),
-+ })
-+ })
- ).nullish()
- }),
- finish_reason: import_v43.z.string().nullish()
-@@ -795,6 +834,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union(
- arguments: import_v43.z.string().nullish()
- })
- })
-+ ).nullish(),
-+ images: import_v43.z.array(
-+ import_v43.z.object({
-+ type: import_v43.z.literal('image_url'),
-+ image_url: import_v43.z.object({
-+ url: import_v43.z.string(),
-+ })
-+ })
- ).nullish()
- }).nullish(),
- finish_reason: import_v43.z.string().nullish()
-diff --git a/dist/index.mjs b/dist/index.mjs
-index a809a7aa0e148bfd43e01dd7b018568b151c8ad5..565b605eeacd9830b2b0e817e58ad0c5700264de 100644
---- a/dist/index.mjs
-+++ b/dist/index.mjs
-@@ -23,7 +23,7 @@ function getOpenAIMetadata(message) {
- var _a, _b;
- return (_b = (_a = message == null ? void 0 : message.providerOptions) == null ? void 0 : _a.openaiCompatible) != null ? _b : {};
- }
--function convertToOpenAICompatibleChatMessages(prompt) {
-+function convertToOpenAICompatibleChatMessages({prompt, options}) {
- const messages = [];
- for (const { role, content, ...message } of prompt) {
- const metadata = getOpenAIMetadata({ ...message });
-@@ -73,6 +73,7 @@ function convertToOpenAICompatibleChatMessages(prompt) {
- }
- case "assistant": {
- let text = "";
-+ let reasoning_text = "";
- const toolCalls = [];
- for (const part of content) {
- const partMetadata = getOpenAIMetadata(part);
-@@ -81,6 +82,12 @@ function convertToOpenAICompatibleChatMessages(prompt) {
- text += part.text;
- break;
- }
-+ case "reasoning": {
-+ if (options.sendReasoning) {
-+ reasoning_text += part.text;
-+ }
-+ break;
-+ }
- case "tool-call": {
- toolCalls.push({
- id: part.toolCallId,
-@@ -98,6 +105,7 @@ function convertToOpenAICompatibleChatMessages(prompt) {
- messages.push({
- role: "assistant",
- content: text,
-+ reasoning_content: reasoning_text ?? undefined,
- tool_calls: toolCalls.length > 0 ? toolCalls : void 0,
- ...metadata
- });
-@@ -182,7 +190,8 @@ var openaiCompatibleProviderOptions = z.object({
- /**
- * Controls the verbosity of the generated text. Defaults to `medium`.
- */
-- textVerbosity: z.string().optional()
-+ textVerbosity: z.string().optional(),
-+ sendReasoning: z.boolean().optional()
- });
-
- // src/openai-compatible-error.ts
-@@ -362,7 +371,7 @@ var OpenAICompatibleChatLanguageModel = class {
- reasoning_effort: compatibleOptions.reasoningEffort,
- verbosity: compatibleOptions.textVerbosity,
- // messages:
-- messages: convertToOpenAICompatibleChatMessages(prompt),
-+ messages: convertToOpenAICompatibleChatMessages({prompt, options: compatibleOptions}),
- // tools:
- tools: openaiTools,
- tool_choice: openaiToolChoice
-@@ -405,6 +414,17 @@ var OpenAICompatibleChatLanguageModel = class {
- text: reasoning
- });
- }
-+ if (choice.message.images) {
-+ for (const image of choice.message.images) {
-+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
-+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
-+ content.push({
-+ type: 'file',
-+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
-+ data: match2 ? match2[1] : image.image_url.url,
-+ });
-+ }
-+ }
- if (choice.message.tool_calls != null) {
- for (const toolCall of choice.message.tool_calls) {
- content.push({
-@@ -582,6 +602,17 @@ var OpenAICompatibleChatLanguageModel = class {
- delta: delta.content
- });
- }
-+ if (delta.images) {
-+ for (const image of delta.images) {
-+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
-+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
-+ controller.enqueue({
-+ type: 'file',
-+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
-+ data: match2 ? match2[1] : image.image_url.url,
-+ });
-+ }
-+ }
- if (delta.tool_calls != null) {
- for (const toolCallDelta of delta.tool_calls) {
- const index = toolCallDelta.index;
-@@ -749,6 +780,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({
- arguments: z3.string()
- })
- })
-+ ).nullish(),
-+ images: z3.array(
-+ z3.object({
-+ type: z3.literal('image_url'),
-+ image_url: z3.object({
-+ url: z3.string(),
-+ })
-+ })
- ).nullish()
- }),
- finish_reason: z3.string().nullish()
-@@ -779,6 +818,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([
- arguments: z3.string().nullish()
- })
- })
-+ ).nullish(),
-+ images: z3.array(
-+ z3.object({
-+ type: z3.literal('image_url'),
-+ image_url: z3.object({
-+ url: z3.string(),
-+ })
-+ })
- ).nullish()
- }).nullish(),
- finish_reason: z3.string().nullish()
@@ -1,8 +1,8 @@
 diff --git a/dist/index.js b/dist/index.js
-index 130094d194ea1e8e7d3027d07d82465741192124..4d13dcee8c962ca9ee8f1c3d748f8ffe6a3cfb47 100644
+index bf900591bf2847a3253fe441aad24c06da19c6c1..c1d9bb6fefa2df1383339324073db0a70ea2b5a2 100644
 --- a/dist/index.js
 +++ b/dist/index.js
-@@ -290,6 +290,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)(
+@@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)(
 message: import_v42.z.object({
 role: import_v42.z.literal("assistant").nullish(),
 content: import_v42.z.string().nullish(),
@@ -10,7 +10,7 @@ index 130094d194ea1e8e7d3027d07d82465741192124..4d13dcee8c962ca9ee8f1c3d748f8ffe
 tool_calls: import_v42.z.array(
 import_v42.z.object({
 id: import_v42.z.string().nullish(),
-@@ -356,6 +357,7 @@ var openaiChatChunkSchema = (0, import_provider_utils3.lazyValidator)(
+@@ -340,6 +341,7 @@ var openaiChatChunkSchema = (0, import_provider_utils3.lazyValidator)(
 delta: import_v42.z.object({
 role: import_v42.z.enum(["assistant"]).nullish(),
 content: import_v42.z.string().nullish(),
@@ -18,7 +18,7 @@ index 130094d194ea1e8e7d3027d07d82465741192124..4d13dcee8c962ca9ee8f1c3d748f8ffe
 tool_calls: import_v42.z.array(
 import_v42.z.object({
 index: import_v42.z.number(),
-@@ -814,6 +816,13 @@ var OpenAIChatLanguageModel = class {
+@@ -795,6 +797,13 @@ var OpenAIChatLanguageModel = class {
 if (text != null && text.length > 0) {
 content.push({ type: "text", text });
 }
@@ -32,7 +32,7 @@ index 130094d194ea1e8e7d3027d07d82465741192124..4d13dcee8c962ca9ee8f1c3d748f8ffe
 for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) {
 content.push({
 type: "tool-call",
-@@ -895,6 +904,7 @@ var OpenAIChatLanguageModel = class {
+@@ -876,6 +885,7 @@ var OpenAIChatLanguageModel = class {
 };
 let metadataExtracted = false;
 let isActiveText = false;
@@ -40,7 +40,7 @@ index 130094d194ea1e8e7d3027d07d82465741192124..4d13dcee8c962ca9ee8f1c3d748f8ffe
 const providerMetadata = { openai: {} };
 return {
 stream: response.pipeThrough(
-@@ -952,6 +962,21 @@ var OpenAIChatLanguageModel = class {
+@@ -933,6 +943,21 @@ var OpenAIChatLanguageModel = class {
 return;
 }
 const delta = choice.delta;
@@ -62,7 +62,7 @@ index 130094d194ea1e8e7d3027d07d82465741192124..4d13dcee8c962ca9ee8f1c3d748f8ffe
 if (delta.content != null) {
 if (!isActiveText) {
 controller.enqueue({ type: "text-start", id: "0" });
-@@ -1064,6 +1089,9 @@ var OpenAIChatLanguageModel = class {
+@@ -1045,6 +1070,9 @@ var OpenAIChatLanguageModel = class {
 }
 },
 flush(controller) {
@@ -1,145 +0,0 @@
-diff --git a/dist/index.d.ts b/dist/index.d.ts
-index 8dd9b498050dbecd8dd6b901acf1aa8ca38a49af..ed644349c9d38fe2a66b2fb44214f7c18eb97f89 100644
---- a/dist/index.d.ts
-+++ b/dist/index.d.ts
-@@ -4,7 +4,7 @@ import { z } from 'zod/v4';
-
- type OllamaChatModelId = "athene-v2" | "athene-v2:72b" | "aya-expanse" | "aya-expanse:8b" | "aya-expanse:32b" | "codegemma" | "codegemma:2b" | "codegemma:7b" | "codellama" | "codellama:7b" | "codellama:13b" | "codellama:34b" | "codellama:70b" | "codellama:code" | "codellama:python" | "command-r" | "command-r:35b" | "command-r-plus" | "command-r-plus:104b" | "command-r7b" | "command-r7b:7b" | "deepseek-r1" | "deepseek-r1:1.5b" | "deepseek-r1:7b" | "deepseek-r1:8b" | "deepseek-r1:14b" | "deepseek-r1:32b" | "deepseek-r1:70b" | "deepseek-r1:671b" | "deepseek-coder-v2" | "deepseek-coder-v2:16b" | "deepseek-coder-v2:236b" | "deepseek-v3" | "deepseek-v3:671b" | "devstral" | "devstral:24b" | "dolphin3" | "dolphin3:8b" | "exaone3.5" | "exaone3.5:2.4b" | "exaone3.5:7.8b" | "exaone3.5:32b" | "falcon2" | "falcon2:11b" | "falcon3" | "falcon3:1b" | "falcon3:3b" | "falcon3:7b" | "falcon3:10b" | "firefunction-v2" | "firefunction-v2:70b" | "gemma" | "gemma:2b" | "gemma:7b" | "gemma2" | "gemma2:2b" | "gemma2:9b" | "gemma2:27b" | "gemma3" | "gemma3:1b" | "gemma3:4b" | "gemma3:12b" | "gemma3:27b" | "granite3-dense" | "granite3-dense:2b" | "granite3-dense:8b" | "granite3-guardian" | "granite3-guardian:2b" | "granite3-guardian:8b" | "granite3-moe" | "granite3-moe:1b" | "granite3-moe:3b" | "granite3.1-dense" | "granite3.1-dense:2b" | "granite3.1-dense:8b" | "granite3.1-moe" | "granite3.1-moe:1b" | "granite3.1-moe:3b" | "llama2" | "llama2:7b" | "llama2:13b" | "llama2:70b" | "llama3" | "llama3:8b" | "llama3:70b" | "llama3-chatqa" | "llama3-chatqa:8b" | "llama3-chatqa:70b" | "llama3-gradient" | "llama3-gradient:8b" | "llama3-gradient:70b" | "llama3.1" | "llama3.1:8b" | "llama3.1:70b" | "llama3.1:405b" | "llama3.2" | "llama3.2:1b" | "llama3.2:3b" | "llama3.2-vision" | "llama3.2-vision:11b" | "llama3.2-vision:90b" | "llama3.3" | "llama3.3:70b" | "llama4" | "llama4:16x17b" | "llama4:128x17b" | "llama-guard3" | "llama-guard3:1b" | "llama-guard3:8b" | "llava" | "llava:7b" | "llava:13b" | "llava:34b" | "llava-llama3" | "llava-llama3:8b" | "llava-phi3" | "llava-phi3:3.8b" | "marco-o1" | "marco-o1:7b" | "mistral" | "mistral:7b" | "mistral-large" | "mistral-large:123b" | "mistral-nemo" | "mistral-nemo:12b" | "mistral-small" | "mistral-small:22b" | "mixtral" | "mixtral:8x7b" | "mixtral:8x22b" | "moondream" | "moondream:1.8b" | "openhermes" | "openhermes:v2.5" | "nemotron" | "nemotron:70b" | "nemotron-mini" | "nemotron-mini:4b" | "olmo" | "olmo:7b" | "olmo:13b" | "opencoder" | "opencoder:1.5b" | "opencoder:8b" | "phi3" | "phi3:3.8b" | "phi3:14b" | "phi3.5" | "phi3.5:3.8b" | "phi4" | "phi4:14b" | "qwen" | "qwen:7b" | "qwen:14b" | "qwen:32b" | "qwen:72b" | "qwen:110b" | "qwen2" | "qwen2:0.5b" | "qwen2:1.5b" | "qwen2:7b" | "qwen2:72b" | "qwen2.5" | "qwen2.5:0.5b" | "qwen2.5:1.5b" | "qwen2.5:3b" | "qwen2.5:7b" | "qwen2.5:14b" | "qwen2.5:32b" | "qwen2.5:72b" | "qwen2.5-coder" | "qwen2.5-coder:0.5b" | "qwen2.5-coder:1.5b" | "qwen2.5-coder:3b" | "qwen2.5-coder:7b" | "qwen2.5-coder:14b" | "qwen2.5-coder:32b" | "qwen3" | "qwen3:0.6b" | "qwen3:1.7b" | "qwen3:4b" | "qwen3:8b" | "qwen3:14b" | "qwen3:30b" | "qwen3:32b" | "qwen3:235b" | "qwq" | "qwq:32b" | "sailor2" | "sailor2:1b" | "sailor2:8b" | "sailor2:20b" | "shieldgemma" | "shieldgemma:2b" | "shieldgemma:9b" | "shieldgemma:27b" | "smallthinker" | "smallthinker:3b" | "smollm" | "smollm:135m" | "smollm:360m" | "smollm:1.7b" | "tinyllama" | "tinyllama:1.1b" | "tulu3" | "tulu3:8b" | "tulu3:70b" | (string & {});
- declare const ollamaProviderOptions: z.ZodObject<{
-- think: z.ZodOptional<z.ZodBoolean>;
-+ think: z.ZodOptional<z.ZodUnion<[z.ZodBoolean, z.ZodEnum<['low', 'medium', 'high']>]>>;
- options: z.ZodOptional<z.ZodObject<{
- num_ctx: z.ZodOptional<z.ZodNumber>;
- repeat_last_n: z.ZodOptional<z.ZodNumber>;
-@@ -27,9 +27,11 @@ interface OllamaCompletionSettings {
- * the model's thinking from the model's output. When disabled, the model will not think
- * and directly output the content.
- *
-+ * For gpt-oss models, you can also use 'low', 'medium', or 'high' to control the depth of thinking.
-+ *
- * Only supported by certain models like DeepSeek R1 and Qwen 3.
- */
-- think?: boolean;
-+ think?: boolean | 'low' | 'medium' | 'high';
- /**
- * Echo back the prompt in addition to the completion.
- */
-@@ -146,7 +148,7 @@ declare const ollamaEmbeddingProviderOptions: z.ZodObject<{
- type OllamaEmbeddingProviderOptions = z.infer<typeof ollamaEmbeddingProviderOptions>;
-
- declare const ollamaCompletionProviderOptions: z.ZodObject<{
-- think: z.ZodOptional<z.ZodBoolean>;
-+ think: z.ZodOptional<z.ZodUnion<[z.ZodBoolean, z.ZodEnum<['low', 'medium', 'high']>]>>;
- user: z.ZodOptional<z.ZodString>;
- suffix: z.ZodOptional<z.ZodString>;
- echo: z.ZodOptional<z.ZodBoolean>;
-diff --git a/dist/index.js b/dist/index.js
-index 35b5142ce8476ce2549ed7c2ec48e7d8c46c90d9..2ef64dc9a4c2be043e6af608241a6a8309a5a69f 100644
---- a/dist/index.js
-+++ b/dist/index.js
-@@ -158,7 +158,7 @@ function getResponseMetadata({
-
- // src/completion/ollama-completion-language-model.ts
- var ollamaCompletionProviderOptions = import_v42.z.object({
-- think: import_v42.z.boolean().optional(),
-+ think: import_v42.z.union([import_v42.z.boolean(), import_v42.z.enum(['low', 'medium', 'high'])]).optional(),
- user: import_v42.z.string().optional(),
- suffix: import_v42.z.string().optional(),
- echo: import_v42.z.boolean().optional()
-@@ -662,7 +662,7 @@ function convertToOllamaChatMessages({
- const images = content.filter((part) => part.type === "file" && part.mediaType.startsWith("image/")).map((part) => part.data);
- messages.push({
- role: "user",
-- content: userText.length > 0 ? userText : [],
-+ content: userText.length > 0 ? userText : '',
- images: images.length > 0 ? images : void 0
- });
- break;
-@@ -813,9 +813,11 @@ var ollamaProviderOptions = import_v44.z.object({
- * the model's thinking from the model's output. When disabled, the model will not think
- * and directly output the content.
- *
-+ * For gpt-oss models, you can also use 'low', 'medium', or 'high' to control the depth of thinking.
-+ *
- * Only supported by certain models like DeepSeek R1 and Qwen 3.
- */
-- think: import_v44.z.boolean().optional(),
-+ think: import_v44.z.union([import_v44.z.boolean(), import_v44.z.enum(['low', 'medium', 'high'])]).optional(),
- options: import_v44.z.object({
- num_ctx: import_v44.z.number().optional(),
- repeat_last_n: import_v44.z.number().optional(),
-@@ -929,14 +931,16 @@ var OllamaRequestBuilder = class {
- prompt,
- systemMessageMode: "system"
- }),
-- temperature,
-- top_p: topP,
- max_output_tokens: maxOutputTokens,
- ...(responseFormat == null ? void 0 : responseFormat.type) === "json" && {
- format: responseFormat.schema != null ? responseFormat.schema : "json"
- },
- think: (_a = ollamaOptions == null ? void 0 : ollamaOptions.think) != null ? _a : false,
-- options: (_b = ollamaOptions == null ? void 0 : ollamaOptions.options) != null ? _b : void 0
-+ options: {
-+ ...temperature !== void 0 && { temperature },
-+ ...topP !== void 0 && { top_p: topP },
-+ ...((_b = ollamaOptions == null ? void 0 : ollamaOptions.options) != null ? _b : {})
-+ }
- };
- }
- };
-diff --git a/dist/index.mjs b/dist/index.mjs
-index e2a634a78d80ac9542f2cc4f96cf2291094b10cf..67b23efce3c1cf4f026693d3ff9246988a3ef26e 100644
---- a/dist/index.mjs
-+++ b/dist/index.mjs
-@@ -144,7 +144,7 @@ function getResponseMetadata({
-
- // src/completion/ollama-completion-language-model.ts
- var ollamaCompletionProviderOptions = z2.object({
-- think: z2.boolean().optional(),
-+ think: z2.union([z2.boolean(), z2.enum(['low', 'medium', 'high'])]).optional(),
- user: z2.string().optional(),
- suffix: z2.string().optional(),
- echo: z2.boolean().optional()
-@@ -662,7 +662,7 @@ function convertToOllamaChatMessages({
- const images = content.filter((part) => part.type === "file" && part.mediaType.startsWith("image/")).map((part) => part.data);
- messages.push({
- role: "user",
-- content: userText.length > 0 ? userText : [],
-+ content: userText.length > 0 ? userText : '',
- images: images.length > 0 ? images : void 0
- });
- break;
-@@ -815,9 +815,11 @@ var ollamaProviderOptions = z4.object({
- * the model's thinking from the model's output. When disabled, the model will not think
- * and directly output the content.
- *
-+ * For gpt-oss models, you can also use 'low', 'medium', or 'high' to control the depth of thinking.
-+ *
- * Only supported by certain models like DeepSeek R1 and Qwen 3.
- */
-- think: z4.boolean().optional(),
-+ think: z4.union([z4.boolean(), z4.enum(['low', 'medium', 'high'])]).optional(),
- options: z4.object({
- num_ctx: z4.number().optional(),
- repeat_last_n: z4.number().optional(),
-@@ -931,14 +933,16 @@ var OllamaRequestBuilder = class {
- prompt,
- systemMessageMode: "system"
- }),
-- temperature,
-- top_p: topP,
- max_output_tokens: maxOutputTokens,
- ...(responseFormat == null ? void 0 : responseFormat.type) === "json" && {
- format: responseFormat.schema != null ? responseFormat.schema : "json"
- },
- think: (_a = ollamaOptions == null ? void 0 : ollamaOptions.think) != null ? _a : false,
-- options: (_b = ollamaOptions == null ? void 0 : ollamaOptions.options) != null ? _b : void 0
-+ options: {
-+ ...temperature !== void 0 && { temperature },
-+ ...topP !== void 0 && { top_p: topP },
-+ ...((_b = ollamaOptions == null ? void 0 : ollamaOptions.options) != null ? _b : {})
-+ }
- };
- }
- };
CLAUDE.md (11 lines changed)

@@ -28,7 +28,7 @@ When creating a Pull Request, you MUST:
 - **Development**: `yarn dev` - Runs Electron app in development mode with hot reload
 - **Debug**: `yarn debug` - Starts with debugging enabled, use `chrome://inspect` to attach debugger
 - **Build Check**: `yarn build:check` - **REQUIRED** before commits (lint + test + typecheck)
-  - If having i18n sort issues, run `yarn i18n:sync` first to sync template
+  - If having i18n sort issues, run `yarn sync:i18n` first to sync template
   - If having formatting issues, run `yarn format` first
 - **Test**: `yarn test` - Run all tests (Vitest) across main and renderer processes
 - **Single Test**:
@@ -40,23 +40,20 @@ When creating a Pull Request, you MUST:
 ## Project Architecture

 ### Electron Structure

 - **Main Process** (`src/main/`): Node.js backend with services (MCP, Knowledge, Storage, etc.)
 - **Renderer Process** (`src/renderer/`): React UI with Redux state management
 - **Preload Scripts** (`src/preload/`): Secure IPC bridge

 ### Key Components

 - **AI Core** (`src/renderer/src/aiCore/`): Middleware pipeline for multiple AI providers.
 - **Services** (`src/main/services/`): MCPService, KnowledgeService, WindowService, etc.
 - **Build System**: Electron-Vite with experimental rolldown-vite, yarn workspaces.
 - **State Management**: Redux Toolkit (`src/renderer/src/store/`) for predictable state.

 ### Logging

 ```typescript
-import { loggerService } from "@logger";
-const logger = loggerService.withContext("moduleName");
+import { loggerService } from '@logger'
+const logger = loggerService.withContext('moduleName')
 // Renderer: loggerService.initWindowSource('windowName') first
-logger.info("message", CONTEXT);
+logger.info('message', CONTEXT)
 ```
@@ -23,7 +23,7 @@
   },
   "files": {
     "ignoreUnknown": false,
-    "includes": ["**", "!**/.claude/**", "!**/.vscode/**", "!**/.conductor/**"],
+    "includes": ["**", "!**/.claude/**", "!**/.vscode/**"],
     "maxSize": 2097152
   },
   "formatter": {
@@ -12,13 +12,8 @@

 ; https://github.com/electron-userland/electron-builder/issues/1122
 !ifndef BUILD_UNINSTALLER
-  ; Check VC++ Redistributable based on architecture stored in $1
   Function checkVCRedist
-    ${If} $1 == "arm64"
-      ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\ARM64" "Installed"
-    ${Else}
-      ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\x64" "Installed"
-    ${EndIf}
+    ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\x64" "Installed"
   FunctionEnd

   Function checkArchitectureCompatibility

@@ -102,47 +97,29 @@

     Call checkVCRedist
     ${If} $0 != "1"
-      ; VC++ is required - install automatically since declining would abort anyway
-      ; Select download URL based on system architecture (stored in $1)
-      ${If} $1 == "arm64"
-        StrCpy $2 "https://aka.ms/vs/17/release/vc_redist.arm64.exe"
-        StrCpy $3 "$TEMP\vc_redist.arm64.exe"
-      ${Else}
-        StrCpy $2 "https://aka.ms/vs/17/release/vc_redist.x64.exe"
-        StrCpy $3 "$TEMP\vc_redist.x64.exe"
-      ${EndIf}
+      MessageBox MB_YESNO "\
+        NOTE: ${PRODUCT_NAME} requires $\r$\n\
+        'Microsoft Visual C++ Redistributable'$\r$\n\
+        to function properly.$\r$\n$\r$\n\
+        Download and install now?" /SD IDYES IDYES InstallVCRedist IDNO DontInstall
+      InstallVCRedist:
+        inetc::get /CAPTION " " /BANNER "Downloading Microsoft Visual C++ Redistributable..." "https://aka.ms/vs/17/release/vc_redist.x64.exe" "$TEMP\vc_redist.x64.exe"
+        ExecWait "$TEMP\vc_redist.x64.exe /install /norestart"
+        ;IfErrors InstallError ContinueInstall ; vc_redist exit code is unreliable :(
+        Call checkVCRedist
+        ${If} $0 == "1"
+          Goto ContinueInstall
+        ${EndIf}

-      inetc::get /CAPTION " " /BANNER "Downloading Microsoft Visual C++ Redistributable..." \
-        $2 $3 /END
-      Pop $0 ; Get download status from inetc::get
-      ${If} $0 != "OK"
-        MessageBox MB_ICONSTOP|MB_YESNO "\
-          Failed to download Microsoft Visual C++ Redistributable.$\r$\n$\r$\n\
-          Error: $0$\r$\n$\r$\n\
-          Would you like to open the download page in your browser?$\r$\n\
-          $2" IDYES openDownloadUrl IDNO skipDownloadUrl
-        openDownloadUrl:
-          ExecShell "open" $2
-        skipDownloadUrl:
+      ;InstallError:
+        MessageBox MB_ICONSTOP "\
+          There was an unexpected error installing$\r$\n\
+          Microsoft Visual C++ Redistributable.$\r$\n\
+          The installation of ${PRODUCT_NAME} cannot continue."
+      DontInstall:
           Abort
-      ${EndIf}
-
-      ExecWait "$3 /install /quiet /norestart"
-      ; Note: vc_redist exit code is unreliable, verify via registry check instead
-
-      Call checkVCRedist
-      ${If} $0 != "1"
-        MessageBox MB_ICONSTOP|MB_YESNO "\
-          Microsoft Visual C++ Redistributable installation failed.$\r$\n$\r$\n\
-          Would you like to open the download page in your browser?$\r$\n\
-          $2$\r$\n$\r$\n\
-          The installation of ${PRODUCT_NAME} cannot continue." IDYES openInstallUrl IDNO skipInstallUrl
-        openInstallUrl:
-          ExecShell "open" $2
-        skipInstallUrl:
-          Abort
-      ${EndIf}
     ${EndIf}
+    ContinueInstall:
     Pop $4
     Pop $3
     Pop $2
@@ -71,7 +71,7 @@ Tools like i18n Ally cannot parse dynamic content within template strings, resul

 ```javascript
 // Not recommended - Plugin cannot resolve
-const message = t(`fruits.${fruit}`);
+const message = t(`fruits.${fruit}`)
 ```

 #### 2. **No Real-time Rendering in Editor**
@@ -91,14 +91,14 @@ For example:
 ```ts
 // src/renderer/src/i18n/label.ts
 const themeModeKeyMap = {
-  dark: "settings.theme.dark",
-  light: "settings.theme.light",
-  system: "settings.theme.system",
-} as const;
+  dark: 'settings.theme.dark',
+  light: 'settings.theme.light',
+  system: 'settings.theme.system'
+} as const

 export const getThemeModeLabel = (key: string): string => {
-  return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key;
-};
+  return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key
+}
 ```

 By avoiding template strings, you gain better developer experience, more reliable translation checks, and a more maintainable codebase.
@@ -107,7 +107,7 @@ By avoiding template strings, you gain more reliabl

 The project includes several scripts to automate i18n-related tasks:

-### `i18n:check` - Validate i18n Structure
+### `check:i18n` - Validate i18n Structure

 This script checks:

@@ -116,10 +116,10 @@ This script checks:
 - Whether keys are properly sorted

 ```bash
-yarn i18n:check
+yarn check:i18n
 ```

-### `i18n:sync` - Synchronize JSON Structure and Sort Order
+### `sync:i18n` - Synchronize JSON Structure and Sort Order

 This script uses `zh-cn.json` as the source of truth to sync structure across all language files, including:

@@ -128,14 +128,14 @@ This script uses `zh-cn.json` as the source of truth to sync structure across al
 3. Sorting keys automatically

 ```bash
-yarn i18n:sync
+yarn sync:i18n
 ```

-### `i18n:translate` - Automatically Translate Pending Texts
+### `auto:i18n` - Automatically Translate Pending Texts

 This script fills in texts marked as `[to be translated]` using machine translation.

-Typically, after adding new texts in `zh-cn.json`, run `i18n:sync`, then `i18n:translate` to complete translations.
+Typically, after adding new texts in `zh-cn.json`, run `sync:i18n`, then `auto:i18n` to complete translations.

 Before using this script, set the required environment variables:

@@ -148,20 +148,30 @@ MODEL="qwen-plus-latest"
 Alternatively, add these variables directly to your `.env` file.

 ```bash
-yarn i18n:translate
+yarn auto:i18n
+```
+
+### `update:i18n` - Object-level Translation Update
+
+Updates translations in language files under `src/renderer/src/i18n/translate` at the object level, preserving existing translations and only updating new content.
+
+**Not recommended** — prefer `auto:i18n` for translation tasks.
+
+```bash
+yarn update:i18n
 ```

 ### Workflow

 1. During development, first add the required text in `zh-cn.json`
 2. Confirm it displays correctly in the Chinese environment
-3. Run `yarn i18n:sync` to propagate the keys to other language files
-4. Run `yarn i18n:translate` to perform machine translation
+3. Run `yarn sync:i18n` to propagate the keys to other language files
+4. Run `yarn auto:i18n` to perform machine translation
 5. Grab a coffee and let the magic happen!

 ## Best Practices

 1. **Use Chinese as Source Language**: All development starts in Chinese, then translates to other languages.
-2. **Run Check Script Before Commit**: Use `yarn i18n:check` to catch i18n issues early.
+2. **Run Check Script Before Commit**: Use `yarn check:i18n` to catch i18n issues early.
 3. **Translate in Small Increments**: Avoid accumulating a large backlog of untranslated content.
 4. **Keep Keys Semantically Clear**: Keys should clearly express their purpose, e.g., `user.profile.avatar.upload.error`
@@ -1,17 +1,17 @@
 # 如何优雅地做好 i18n

-## 使用 i18n ally 插件提升开发体验
+## 使用i18n ally插件提升开发体验

-i18n ally 是一个强大的 VSCode 插件,它能在开发阶段提供实时反馈,帮助开发者更早发现文案缺失和错译问题。
+i18n ally是一个强大的VSCode插件,它能在开发阶段提供实时反馈,帮助开发者更早发现文案缺失和错译问题。

 项目中已经配置好了插件设置,直接安装即可。

 ### 开发时优势

 - **实时预览**:翻译文案会直接显示在编辑器中
-- **错误检测**:自动追踪标记出缺失的翻译或未使用的 key
-- **快速跳转**:可通过 key 直接跳转到定义处(Ctrl/Cmd + click)
-- **自动补全**:输入 i18n key 时提供自动补全建议
+- **错误检测**:自动追踪标记出缺失的翻译或未使用的key
+- **快速跳转**:可通过key直接跳转到定义处(Ctrl/Cmd + click)
+- **自动补全**:输入i18n key时提供自动补全建议

 ### 效果展示

@@ -23,9 +23,9 @@ i18n ally 是一个强大的 VSCode 插件,它能在开发阶段提供实时
 ## i18n 约定

-### **绝对避免使用 flat 格式**
+### **绝对避免使用flat格式**

-绝对避免使用 flat 格式,如`"add.button.tip": "添加"`。应采用清晰的嵌套结构:
+绝对避免使用flat格式,如`"add.button.tip": "添加"`。应采用清晰的嵌套结构:

 ```json
 // 错误示例 - flat结构
@@ -52,14 +52,14 @@ i18n ally 是一个强大的 VSCode 插件,它能在开发阶段提供实时
 #### 为什么要使用嵌套结构

 1. **自然分组**:通过对象结构天然能将相关上下文的文案分到一个组别中
-2. **插件要求**:i18n ally 插件需要嵌套或 flat 格式其一的文件才能正常分析
+2. **插件要求**:i18n ally 插件需要嵌套或flat格式其一的文件才能正常分析

 ### **避免在`t()`中使用模板字符串**

-**强烈建议避免使用模板字符串**进行动态插值。虽然模板字符串在 JavaScript 开发中非常方便,但在国际化场景下会带来一系列问题。
+**强烈建议避免使用模板字符串**进行动态插值。虽然模板字符串在JavaScript开发中非常方便,但在国际化场景下会带来一系列问题。

 1. **插件无法跟踪**
-   i18n ally 等工具无法解析模板字符串中的动态内容,导致:
+   i18n ally等工具无法解析模板字符串中的动态内容,导致:

 - 无法正确显示实时预览
 - 无法检测翻译缺失
@@ -67,11 +67,11 @@ i18n ally 是一个强大的 VSCode 插件,它能在开发阶段提供实时

 ```javascript
 // 不推荐 - 插件无法解析
-const message = t(`fruits.${fruit}`);
+const message = t(`fruits.${fruit}`)
 ```

 2. **编辑器无法实时渲染**
-   在 IDE 中,模板字符串会显示为原始代码而非最终翻译结果,降低了开发体验。
+   在IDE中,模板字符串会显示为原始代码而非最终翻译结果,降低了开发体验。

 3. **更难以维护**
    由于插件无法跟踪这样的文案,编辑器中也无法渲染,开发者必须人工确认语言文件中是否存在相应的文案。
@@ -85,36 +85,36 @@ i18n ally 是一个强大的 VSCode 插件,它能在开发阶段提供实时
 ```ts
 // src/renderer/src/i18n/label.ts
 const themeModeKeyMap = {
-  dark: "settings.theme.dark",
-  light: "settings.theme.light",
-  system: "settings.theme.system",
-} as const;
+  dark: 'settings.theme.dark',
+  light: 'settings.theme.light',
+  system: 'settings.theme.system'
+} as const

 export const getThemeModeLabel = (key: string): string => {
-  return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key;
-};
+  return themeModeKeyMap[key] ? t(themeModeKeyMap[key]) : key
+}
 ```

 通过避免模板字符串,可以获得更好的开发体验、更可靠的翻译检查以及更易维护的代码库。

 ## 自动化脚本

-项目中有一系列脚本来自动化 i18n 相关任务:
+项目中有一系列脚本来自动化i18n相关任务:

-### `i18n:check` - 检查 i18n 结构
+### `check:i18n` - 检查i18n结构

 此脚本会检查:

 - 所有语言文件是否为嵌套结构
-- 是否存在缺失的 key
-- 是否存在多余的 key
+- 是否存在缺失的key
+- 是否存在多余的key
 - 是否已经有序

 ```bash
-yarn i18n:check
+yarn check:i18n
 ```

-### `i18n:sync` - 同步 json 结构与排序
+### `sync:i18n` - 同步json结构与排序

 此脚本以`zh-cn.json`文件为基准,将结构同步到其他语言文件,包括:

@@ -123,14 +123,14 @@ yarn i18n:check
 3. 自动排序

 ```bash
-yarn i18n:sync
+yarn sync:i18n
 ```

-### `i18n:translate` - 自动翻译待翻译文本
+### `auto:i18n` - 自动翻译待翻译文本

 次脚本自动将标记为待翻译的文本通过机器翻译填充。

-通常,在`zh-cn.json`中添加所需文案后,执行`i18n:sync`即可自动完成翻译。
+通常,在`zh-cn.json`中添加所需文案后,执行`sync:i18n`即可自动完成翻译。

 使用该脚本前,需要配置环境变量,例如:

@@ -143,19 +143,29 @@ MODEL="qwen-plus-latest"
 你也可以通过直接编辑`.env`文件来添加环境变量。

 ```bash
-yarn i18n:translate
+yarn auto:i18n
+```
+
+### `update:i18n` - 对象级别翻译更新
+
+对`src/renderer/src/i18n/translate`中的语言文件进行对象级别的翻译更新,保留已有翻译,只更新新增内容。
+
+**不建议**使用该脚本,更推荐使用`auto:i18n`进行翻译。
+
+```bash
+yarn update:i18n
 ```

 ### 工作流

 1. 开发阶段,先在`zh-cn.json`中添加所需文案
-2. 确认在中文环境下显示无误后,使用`yarn i18n:sync`将文案同步到其他语言文件
-3. 使用`yarn i18n:translate`进行自动翻译
+2. 确认在中文环境下显示无误后,使用`yarn sync:i18n`将文案同步到其他语言文件
+3. 使用`yarn auto:i18n`进行自动翻译
 4. 喝杯咖啡,等翻译完成吧!

 ## 最佳实践

 1. **以中文为源语言**:所有开发首先使用中文,再翻译为其他语言
-2. **提交前运行检查脚本**:使用`yarn i18n:check`检查 i18n 是否有问题
+2. **提交前运行检查脚本**:使用`yarn check:i18n`检查i18n是否有问题
 3. **小步提交翻译**:避免积累大量未翻译文本
-4. **保持 key 语义明确**:key 应能清晰表达其用途,如`user.profile.avatar.upload.error`
+4. **保持key语义明确**:key应能清晰表达其用途,如`user.profile.avatar.upload.error`
|
@@ -1,850 +0,0 @@

# Cherry Studio LAN Transfer Protocol Specification

> Version: 1.0

> Last updated: 2025-12

This document defines the LAN file-transfer protocol between the Cherry Studio desktop client (Electron) and the mobile client (Expo).

---

## Contents

1. [Protocol Overview](#1-协议概述)
2. [Service Discovery (Bonjour/mDNS)](#2-服务发现bonjourmdns)
3. [TCP Connection & Handshake](#3-tcp-连接与握手)
4. [Message Format Specification](#4-消息格式规范)
5. [File Transfer Protocol](#5-文件传输协议)
6. [Heartbeat & Keep-Alive](#6-心跳与连接保活)
7. [Error Handling](#7-错误处理)
8. [Constants & Configuration](#8-常量与配置)
9. [Full Sequence Diagram](#9-完整时序图)
10. [Mobile Implementation Guide](#10-移动端实现指南)

---

## 1. Protocol Overview

### 1.1 Architecture roles

| Role | Platform | Responsibilities |
| -------------------- | ---------------- | -------------------------------------------------- |
| **Client** | Electron desktop | Scan for services, initiate connections, send files |
| **Server** | Expo mobile | Publish the service, accept connections, receive files |

### 1.2 Protocol stack (v1)

```
┌─────────────────────────────────────┐
│ Application layer (file transfer)   │
├─────────────────────────────────────┤
│ Message layer (control: JSON \n)    │
│               (data: binary frames) │
├─────────────────────────────────────┤
│ Transport layer (TCP)               │
├─────────────────────────────────────┤
│ Discovery layer (Bonjour/mDNS)      │
└─────────────────────────────────────┘
```

### 1.3 Communication flow overview

```
1. Service discovery → the mobile side publishes an mDNS service; the desktop side scans and discovers it
2. TCP handshake    → establish the connection and exchange device info (`version=1`)
3. File transfer    → control messages use JSON; `file_chunk` is streamed as chunked binary frames
4. Keep-alive       → ping/pong heartbeat
```

---

## 2. Service Discovery (Bonjour/mDNS)

### 2.1 Service type

| Property | Value |
| ------------------------ | -------------------- |
| Service type | `cherrystudio` |
| Protocol | `tcp` |
| Full service identifier | `_cherrystudio._tcp` |

### 2.2 Service publication (mobile)

The mobile side publishes its service via mDNS/Bonjour:

```typescript
// Service publication parameters
{
  name: "Cherry Studio Mobile", // device name
  type: "cherrystudio",         // service type
  protocol: "tcp",              // protocol
  port: 53317,                  // TCP listening port
  txt: {                        // TXT record (optional)
    version: "1",
    platform: "ios"             // or "android"
  }
}
```

### 2.3 Service discovery (desktop)

The desktop side scans for and resolves service information:

```typescript
// Structure of a discovered service
type LocalTransferPeer = {
  id: string;                   // unique identifier
  name: string;                 // device name
  host?: string;                // host name
  fqdn?: string;                // fully qualified domain name
  port?: number;                // TCP port
  type?: string;                // service type
  protocol?: "tcp" | "udp";     // protocol
  addresses: string[];          // list of IP addresses
  txt?: Record<string, string>; // TXT record
  updatedAt: number;            // discovery timestamp
};
```

### 2.4 IP address selection strategy

When a service advertises multiple IP addresses, IPv4 is preferred:

```typescript
// Prefer an IPv4 address
const preferredAddress = addresses.find((addr) => isIPv4(addr)) || addresses[0];
```

---

## 3. TCP Connection & Handshake

### 3.1 Connection establishment

1. The client opens a TCP connection to the discovered `host:port`
2. Immediately after connecting, it sends the handshake message
3. It then waits for the server's handshake acknowledgment

### 3.2 Handshake messages (protocol version v1)

#### Client → Server: `handshake`

```typescript
type LanTransferHandshakeMessage = {
  type: "handshake";
  deviceName: string;  // device name
  version: string;     // protocol version, currently "1"
  platform?: string;   // platform: 'darwin' | 'win32' | 'linux'
  appVersion?: string; // application version
};
```

**Example:**

```json
{
  "type": "handshake",
  "deviceName": "Cherry Studio 1.7.2",
  "version": "1",
  "platform": "darwin",
  "appVersion": "1.7.2"
}
```
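To ground sections 3.1–3.2, here is a minimal client-side sketch in Node.js (the function name is illustrative and the 10 s limit mirrors the handshake timeout defined in section 7.1; it also assumes the newline-delimited JSON control channel of section 4.1):

```ts
import net from 'node:net'

// Illustrative client-side connect + handshake, per sections 3.1-3.2.
function connectAndHandshake(host: string, port: number): Promise<net.Socket> {
  return new Promise((resolve, reject) => {
    const socket = net.connect({ host, port })
    const timer = setTimeout(() => {
      socket.destroy()
      reject(new Error('handshake timeout')) // 10 s, per section 7.1
    }, 10_000)

    socket.once('connect', () => {
      socket.write(
        JSON.stringify({
          type: 'handshake',
          deviceName: 'Cherry Studio',
          version: '1',
          platform: process.platform,
          appVersion: '1.7.2'
        }) + '\n'
      )
    })

    let buffer = ''
    socket.on('data', (chunk) => {
      buffer += chunk.toString('utf8')
      const nl = buffer.indexOf('\n')
      if (nl === -1) return
      const msg = JSON.parse(buffer.slice(0, nl))
      if (msg.type === 'handshake_ack') {
        clearTimeout(timer)
        msg.accepted ? resolve(socket) : reject(new Error(msg.message ?? 'rejected'))
      }
    })
    socket.once('error', reject)
  })
}
```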
## 4. Message Format Specification (hybrid protocol)

v1 uses a hybrid "control JSON + binary data frame" protocol (streaming mode, no per-chunk ACK):

- **Control messages** (handshake, heartbeat, file_start/ack, file_end, file_complete): UTF-8 JSON, delimited by `\n`
- **Data messages** (`file_chunk`): binary frames, framed with Magic + total length, with no Base64 encoding

### 4.1 Control message encoding (JSON + `\n`)

| Property | Specification |
| --------------------- | ------------- |
| Character encoding | UTF-8 |
| Serialization format | JSON |
| Message delimiter | `\n` (0x0A) |

```typescript
function sendControlMessage(socket: Socket, message: object): void {
  socket.write(`${JSON.stringify(message)}\n`);
}
```
### 4.2 `file_chunk` binary frame format

To handle TCP fragmentation/coalescing and eliminate the Base64 overhead, `file_chunk` uses a length-prefixed binary frame:

```
┌──────────┬──────────┬────────┬───────────────┬──────────────┬────────────┬───────────┐
│ Magic    │ TotalLen │ Type   │ TransferId Len│ TransferId   │ ChunkIdx   │ Data      │
│ 0x43 0x53│ (4B BE)  │ 0x01   │ (2B BE)       │ (UTF-8)      │ (4B BE)    │ (raw)     │
└──────────┴──────────┴────────┴───────────────┴──────────────┴────────────┴───────────┘
```

| Field | Size | Description |
| -------------- | ---- | ----------------------------------------------------------- |
| Magic | 2B | Constant `0x43 0x53` ("CS"), distinguishes frames from JSON messages |
| TotalLen | 4B | Big-endian, total frame length (excluding Magic/TotalLen) |
| Type | 1B | `0x01` means `file_chunk` |
| TransferId Len | 2B | Big-endian, length of the transferId string |
| TransferId | nB | UTF-8 transferId (length given by the previous field) |
| ChunkIdx | 4B | Big-endian chunk index, starting at 0 |
| Data | mB | Raw (unencoded) file bytes |

> Computing the frame total length: `TotalLen = 1 + 2 + transferIdLen + 4 + dataLen` (i.e. the combined length of Type through Data).
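As a sanity check of the layout above, here is a minimal Node.js encoder sketch (the function name is illustrative, not taken from the implementation):

```ts
// Illustrative encoder for the v1 file_chunk frame described above.
function encodeFileChunkFrame(transferId: string, chunkIndex: number, data: Buffer): Buffer {
  const idBytes = Buffer.from(transferId, 'utf8')
  const totalLen = 1 + 2 + idBytes.length + 4 + data.length // Type..Data

  const header = Buffer.alloc(9)
  header.writeUInt16BE(0x4353, 0)         // Magic "CS"
  header.writeUInt32BE(totalLen, 2)       // TotalLen (excludes Magic/TotalLen)
  header.writeUInt8(0x01, 6)              // Type: file_chunk
  header.writeUInt16BE(idBytes.length, 7) // TransferId length

  const idx = Buffer.alloc(4)
  idx.writeUInt32BE(chunkIndex, 0)        // ChunkIdx, starting at 0

  return Buffer.concat([header, idBytes, idx, data])
}
```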
### 4.3 Message parsing strategy

1. Read socket data into a buffer;
2. If the first two bytes are `0x43 0x53` → parse as a binary frame:
   - At least 6 header bytes (Magic + TotalLen) are required; if fewer, wait for more data
   - Read `TotalLen` to determine the full frame length; if the buffer is still short, keep waiting
   - Parse Type/TransferId/ChunkIdx/Data and hand them to the file-receiving logic
3. Otherwise, if the first byte is `{` → parse as a JSON + `\n` control message
4. Anything else: drop one byte and continue the loop to avoid stalling (a sketch of this loop follows).
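A minimal sketch of that demultiplexing loop (the callback names `onControl` and `onFileChunk` are assumptions, not names from the real code):

```ts
function consume(
  buffer: Buffer,
  onControl: (msg: unknown) => void,
  onFileChunk: (id: string, idx: number, data: Buffer) => void
): Buffer {
  while (buffer.length >= 2) {
    if (buffer.readUInt16BE(0) === 0x4353) {
      // Binary frame: need Magic + TotalLen before the frame can be sized
      if (buffer.length < 6) break
      const totalLen = buffer.readUInt32BE(2)
      if (buffer.length < 6 + totalLen) break // incomplete frame, wait for more data
      const idLen = buffer.readUInt16BE(7) // byte 6 is Type (0x01)
      const id = buffer.toString('utf8', 9, 9 + idLen)
      const chunkIdx = buffer.readUInt32BE(9 + idLen)
      onFileChunk(id, chunkIdx, buffer.subarray(9 + idLen + 4, 6 + totalLen))
      buffer = buffer.subarray(6 + totalLen)
    } else if (buffer[0] === 0x7b /* '{' */) {
      const nl = buffer.indexOf(0x0a)
      if (nl === -1) break // wait for the \n delimiter
      onControl(JSON.parse(buffer.toString('utf8', 0, nl)))
      buffer = buffer.subarray(nl + 1)
    } else {
      buffer = buffer.subarray(1) // resync: drop one byte, per step 4
    }
  }
  return buffer // leftover bytes, prepend to the next socket read
}
```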
### 4.4 Message type summary (v1)

| Type | Direction | Encoding | Purpose |
| ---------------- | --------------- | ------------ | ------------------------------------------------------- |
| `handshake` | Client → Server | JSON+`\n` | Handshake request (version=1) |
| `handshake_ack` | Server → Client | JSON+`\n` | Handshake response |
| `ping` | Client → Server | JSON+`\n` | Heartbeat request |
| `pong` | Server → Client | JSON+`\n` | Heartbeat response |
| `file_start` | Client → Server | JSON+`\n` | Start a file transfer |
| `file_start_ack` | Server → Client | JSON+`\n` | Transfer accept/reject |
| `file_chunk` | Client → Server | binary frame | File data chunk (no Base64; streamed, no per-chunk ACK) |
| `file_end` | Client → Server | JSON+`\n` | End of file transfer |
| `file_complete` | Server → Client | JSON+`\n` | Final transfer result |

```
{"type":"message_type",...other fields...}\n
```

---

## 5. File Transfer Protocol

### 5.1 Transfer flow

```
Client (Sender)                      Server (Receiver)
       |                                    |
       |──── 1. file_start ────────────────>|
       |        (file metadata)             |
       |                                    |
       |<─── 2. file_start_ack ─────────────|
       |        (accept/reject)             |
       |                                    |
       |══════ chunk loop (streaming, no ACK) ══════|
       |                                    |
       |──── 3. file_chunk [0] ────────────>|
       |                                    |
       |──── 3. file_chunk [1] ────────────>|
       |                                    |
       |   ... repeat until all chunks are sent ...
       |                                    |
       |════════════════════════════════════|
       |                                    |
       |──── 5. file_end ──────────────────>|
       |        (all chunks sent)           |
       |                                    |
       |<─── 6. file_complete ──────────────|
       |        (final result)              |
```

### 5.2 Message definitions

#### 5.2.1 `file_start` - start a transfer

**Direction:** Client → Server

```typescript
type LanTransferFileStartMessage = {
  type: "file_start";
  transferId: string;  // UUID, unique transfer identifier
  fileName: string;    // file name (with extension)
  fileSize: number;    // total file size in bytes
  mimeType: string;    // MIME type
  checksum: string;    // SHA-256 of the whole file (hex)
  totalChunks: number; // total number of chunks
  chunkSize: number;   // size of each chunk in bytes
};
```

**Example:**

```json
{
  "type": "file_start",
  "transferId": "550e8400-e29b-41d4-a716-446655440000",
  "fileName": "backup.zip",
  "fileSize": 524288000,
  "mimeType": "application/zip",
  "checksum": "a1b2c3d4e5f6789012345678901234567890abcdef1234567890abcdef123456",
  "totalChunks": 8000,
  "chunkSize": 65536
}
```

#### 5.2.2 `file_start_ack` - transfer acknowledgment

**Direction:** Server → Client

```typescript
type LanTransferFileStartAckMessage = {
  type: "file_start_ack";
  transferId: string; // corresponding transfer ID
  accepted: boolean;  // whether the transfer is accepted
  message?: string;   // rejection reason
};
```

**Accepted example:**

```json
{
  "type": "file_start_ack",
  "transferId": "550e8400-e29b-41d4-a716-446655440000",
  "accepted": true
}
```

**Rejected example:**

```json
{
  "type": "file_start_ack",
  "transferId": "550e8400-e29b-41d4-a716-446655440000",
  "accepted": false,
  "message": "Insufficient storage space"
}
```

#### 5.2.3 `file_chunk` - data chunk

**Direction:** Client → Server (**binary frame**, see 4.2)

- No longer JSON/`\n`, and no longer Base64
- Frame structure: `Magic` + `TotalLen` + `Type` + `TransferId` + `ChunkIdx` + `Data`
- `Type` is fixed at `0x01`; `Data` carries the raw file bytes
- Transfer integrity relies on `file_start.checksum` (whole-file SHA-256); per-chunk checksums are optional and are not sent in the frame

#### 5.2.4 `file_chunk_ack` - chunk acknowledgment (unused in v1 streaming)

v1 streams chunks and sends no per-chunk ACK. The type in this section is kept only as a backward-compatibility reference and is never actually sent.

#### 5.2.5 `file_end` - end of transfer

**Direction:** Client → Server

```typescript
type LanTransferFileEndMessage = {
  type: "file_end";
  transferId: string; // transfer ID
};
```

**Example:**

```json
{
  "type": "file_end",
  "transferId": "550e8400-e29b-41d4-a716-446655440000"
}
```

#### 5.2.6 `file_complete` - transfer complete

**Direction:** Server → Client

```typescript
type LanTransferFileCompleteMessage = {
  type: "file_complete";
  transferId: string; // transfer ID
  success: boolean;   // whether the transfer succeeded
  filePath?: string;  // save path (on success)
  error?: string;     // error message (on failure)
};
```

**Success example:**

```json
{
  "type": "file_complete",
  "transferId": "550e8400-e29b-41d4-a716-446655440000",
  "success": true,
  "filePath": "/storage/emulated/0/Documents/backup.zip"
}
```

**Failure example:**

```json
{
  "type": "file_complete",
  "transferId": "550e8400-e29b-41d4-a716-446655440000",
  "success": false,
  "error": "File checksum verification failed"
}
```

### 5.3 Checksum algorithm

#### Whole-file checksum (unchanged)

```typescript
async function calculateFileChecksum(filePath: string): Promise<string> {
  const hash = crypto.createHash("sha256");
  const stream = fs.createReadStream(filePath);

  for await (const chunk of stream) {
    hash.update(chunk);
  }

  return hash.digest("hex");
}
```

#### Per-chunk checksum

v1 does **not** transmit per-chunk checksums by default; integrity relies on the final file checksum. If needed, per-chunk verification can be added at the application layer (it is not a protocol field).

### 5.4 Verification flow

**Sender (Client):**

1. Before sending, compute the SHA-256 of the whole file → `file_start.checksum`
2. Send each chunk as raw binary (no Base64)

**Receiver (Server):**

1. On each `file_chunk`, use the binary payload directly
2. Write to disk as chunks arrive and update an incremental SHA-256 (recommended)
3. After all chunks have been received, finalize the incremental hash to obtain the file's SHA-256
4. Compare it against `file_start.checksum` and report the result in `file_complete` (the receiver side is sketched below)
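A minimal receiver-side sketch of steps 2–4, using Node's streaming crypto and fs APIs (the class name is illustrative; error handling and out-of-order protection are omitted):

```ts
import { createHash, type Hash } from 'node:crypto'
import { createWriteStream, type WriteStream } from 'node:fs'

// Incremental write + hash, as recommended in section 5.4.
class ChunkSink {
  private hash: Hash = createHash('sha256')
  private out: WriteStream

  constructor(tempPath: string, private expectedChecksum: string) {
    this.out = createWriteStream(tempPath)
  }

  write(data: Buffer): void {
    this.out.write(data)   // step 2: append to disk as chunks arrive
    this.hash.update(data) // ...while updating the running SHA-256
  }

  finish(): boolean {
    this.out.end()
    const digest = this.hash.digest('hex')  // step 3: finalize
    return digest === this.expectedChecksum // step 4: compare
  }
}
```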
### 5.5 Chunk size calculation

```typescript
const CHUNK_SIZE = 512 * 1024; // 512KB

const totalChunks = Math.ceil(fileSize / CHUNK_SIZE);

// The last chunk may be smaller than CHUNK_SIZE
const lastChunkSize = fileSize % CHUNK_SIZE || CHUNK_SIZE;
```
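As a worked check of these formulas (the values are illustrative), together with the slicing step a sender would use when framing chunks:

```ts
// Worked example of the formulas above.
const CHUNK_SIZE = 512 * 1024           // 524_288 bytes
const fileSize = 3 * CHUNK_SIZE + 1000  // 1_573_864 bytes

const totalChunks = Math.ceil(fileSize / CHUNK_SIZE)      // 4
const lastChunkSize = fileSize % CHUNK_SIZE || CHUNK_SIZE // 1000

// Slicing chunk i out of an in-memory buffer for framing:
const chunkAt = (file: Buffer, i: number) =>
  file.subarray(i * CHUNK_SIZE, Math.min((i + 1) * CHUNK_SIZE, file.length))
```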
---

## 6. Heartbeat & Keep-Alive

### 6.1 Heartbeat messages

#### `ping`

**Direction:** Client → Server

```typescript
type LanTransferPingMessage = {
  type: "ping";
  payload?: string; // optional payload
};
```

```json
{
  "type": "ping",
  "payload": "heartbeat"
}
```

#### `pong`

**Direction:** Server → Client

```typescript
type LanTransferPongMessage = {
  type: "pong";
  received: boolean; // acknowledgment of receipt
  payload?: string;  // echoes the ping payload
};
```

```json
{
  "type": "pong",
  "received": true,
  "payload": "heartbeat"
}
```

### 6.2 Heartbeat strategy

- Immediately after a successful handshake, send one `ping` to verify the connection
- Optional: send periodic heartbeats to keep the connection alive
- `pong` should echo the `payload` from the `ping` (optional; see the sketch below)
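A minimal client-side sketch of that strategy (the 30 s period is an assumption; the spec does not mandate one):

```ts
import type { Socket } from 'node:net'

// Send an immediate verification ping, then optional periodic keep-alives.
function startHeartbeat(socket: Socket, periodMs = 30_000): () => void {
  const ping = () => socket.write(JSON.stringify({ type: 'ping', payload: 'heartbeat' }) + '\n')
  ping() // verify right after the handshake succeeds
  const timer = setInterval(ping, periodMs)
  return () => clearInterval(timer) // call on disconnect to stop the heartbeat
}
```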
---

## 7. Error Handling

### 7.1 Timeout configuration

| Operation | Timeout | Description |
| -------------------- | ------- | -------------------------------- |
| TCP connect | 10 s | Connection establishment timeout |
| Handshake wait | 10 s | Waiting for `handshake_ack` |
| Transfer completion | 60 s | Waiting for `file_complete` |
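A generic helper matching the table above can wrap each wait (a minimal sketch; the helper name is illustrative):

```ts
// Reject if `promise` does not settle within `ms`.
// Usage: await withTimeout(waitForAck(socket), 10_000, 'handshake')
function withTimeout<T>(promise: Promise<T>, ms: number, label = 'operation'): Promise<T> {
  return new Promise((resolve, reject) => {
    const timer = setTimeout(() => reject(new Error(`${label} timed out after ${ms}ms`)), ms)
    promise.then(
      (value) => { clearTimeout(timer); resolve(value) },
      (err) => { clearTimeout(timer); reject(err) }
    )
  })
}
```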
### 7.2 Error scenario handling

| Scenario | Client handling | Server handling |
| ------------------------- | ---------------------------------- | ---------------------- |
| TCP connection failure | Notify the UI, allow retry | - |
| Handshake timeout | Disconnect, notify the UI | Close the socket |
| Handshake rejected | Show the rejection reason | - |
| Chunk processing failure | Abort the transfer, clean up state | Remove temporary files |
| Unexpected disconnect | Clean up state, notify the UI | Remove temporary files |
| Insufficient storage | - | Send `accepted: false` |

### 7.3 Resource cleanup

**Client side:**

```typescript
function cleanup(): void {
  // 1. Destroy the file read stream
  if (readStream) {
    readStream.destroy();
  }
  // 2. Clear transfer state
  activeTransfer = undefined;
  // 3. Close the socket (if needed)
  socket?.destroy();
}
```

**Server side:**

```typescript
function cleanup(): void {
  // 1. Close the file write stream
  if (writeStream) {
    writeStream.end();
  }
  // 2. Delete the unfinished temporary file
  if (tempFilePath) {
    fs.unlinkSync(tempFilePath);
  }
  // 3. Clear transfer state
  activeTransfer = undefined;
}
```

---

## 8. Constants & Configuration

### 8.1 Protocol constants

```typescript
// Protocol version (v1 = control JSON + binary chunks + streaming)
export const LAN_TRANSFER_PROTOCOL_VERSION = "1";

// Service discovery
export const LAN_TRANSFER_SERVICE_TYPE = "cherrystudio";
export const LAN_TRANSFER_SERVICE_FULL_NAME = "_cherrystudio._tcp";

// TCP port
export const LAN_TRANSFER_TCP_PORT = 53317;

// File transfer (consistent with the binary frame format)
export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024; // 512KB
export const LAN_TRANSFER_GLOBAL_TIMEOUT_MS = 10 * 60 * 1000; // 10 minutes

// Timeouts
export const LAN_TRANSFER_HANDSHAKE_TIMEOUT_MS = 10_000; // 10 seconds
export const LAN_TRANSFER_CHUNK_TIMEOUT_MS = 30_000; // 30 seconds
export const LAN_TRANSFER_COMPLETE_TIMEOUT_MS = 60_000; // 60 seconds
```

### 8.2 Supported file types

Currently only ZIP files are supported:

```typescript
export const LAN_TRANSFER_ALLOWED_EXTENSIONS = [".zip"];
export const LAN_TRANSFER_ALLOWED_MIME_TYPES = [
  "application/zip",
  "application/x-zip-compressed",
];
```

---

## 9. Full Sequence Diagram

### 9.1 Complete transfer flow (v1, streaming)

```
┌─────────┐                        ┌─────────┐                        ┌─────────┐
│ Renderer│                        │ Main    │                        │ Mobile  │
│ (UI)    │                        │ Process │                        │ Server  │
└────┬────┘                        └────┬────┘                        └────┬────┘
     │                                  │                                  │
     │ ═══════════ Service discovery ═══════════                           │
     │                                  │                                  │
     │ startScan()                      │                                  │
     │─────────────────────────────────>│                                  │
     │                                  │ mDNS browse                      │
     │                                  │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ >│
     │                                  │                                  │
     │                                  │<─ ─ ─ service discovered ─ ─ ─ ─ │
     │                                  │                                  │
     │<────── onServicesUpdated ────────│                                  │
     │                                  │                                  │
     │ ═══════════ Handshake & connect ═══════════                         │
     │                                  │                                  │
     │ connect(peer)                    │                                  │
     │─────────────────────────────────>│                                  │
     │                                  │──────── TCP Connect ────────────>│
     │                                  │                                  │
     │                                  │──────── handshake ──────────────>│
     │                                  │                                  │
     │                                  │<─────── handshake_ack ───────────│
     │                                  │                                  │
     │                                  │──────── ping ───────────────────>│
     │                                  │<─────── pong ────────────────────│
     │                                  │                                  │
     │<────── connect result ───────────│                                  │
     │                                  │                                  │
     │ ═══════════ File transfer ═══════════                               │
     │                                  │                                  │
     │ sendFile(path)                   │                                  │
     │─────────────────────────────────>│                                  │
     │                                  │──────── file_start ─────────────>│
     │                                  │                                  │
     │                                  │<─────── file_start_ack ──────────│
     │                                  │                                  │
     │                                  │══════ chunk send loop ═══════════│
     │                                  │                                  │
     │                                  │──────── file_chunk[0] (binary) ─>│
     │<────── progress event ───────────│                                  │
     │                                  │                                  │
     │                                  │──────── file_chunk[1] (binary) ─>│
     │<────── progress event ───────────│                                  │
     │                                  │                                  │
     │                                  │        ... repeat ...            │
     │                                  │                                  │
     │                                  │══════════════════════════════════│
     │                                  │                                  │
     │                                  │──────── file_end ───────────────>│
     │                                  │                                  │
     │                                  │<─────── file_complete ───────────│
     │                                  │                                  │
     │<────── complete event ───────────│                                  │
     │<────── sendFile result ──────────│                                  │
     │                                  │                                  │
```

---

## 10. Mobile Implementation Guide (v1 essentials)

### 10.1 Required functionality

1. **mDNS service publication**

   - Publish the `_cherrystudio._tcp` service
   - Advertise TCP port `53317`
   - Optional: TXT record (version, platform info)

2. **TCP server**

   - Listen on the designated port
   - Support a single connection or multiple connections

3. **Message parsing**

   - Control messages: UTF-8 + `\n` JSON
   - Data messages: binary frames (Magic+TotalLen framing)

4. **Handshake handling**

   - Validate the `handshake` message
   - Send a `handshake_ack` response
   - Respond to `ping` messages

5. **File reception (streaming mode)**
   - Parse `file_start` and prepare to receive
   - Receive `file_chunk` binary frames; write straight to the file/buffer and hash incrementally
   - v1 sends no per-chunk ACK (streaming)
   - Handle `file_end`: finalize the incremental hash and verify the checksum
   - Send the `file_complete` result

### 10.2 Recommended libraries

**React Native / Expo:**

- mDNS: `react-native-zeroconf` or `@homielab/react-native-bonjour`
- TCP: `react-native-tcp-socket`
- Crypto: `expo-crypto` or `react-native-quick-crypto`

### 10.3 Receiver pseudocode

```typescript
class FileReceiver {
  private transfer?: {
    id: string;
    fileName: string;
    fileSize: number;
    checksum: string;
    totalChunks: number;
    receivedChunks: number;
    tempPath: string;
    // v1: write to file as chunks arrive to avoid OOM on large files
    // stream: FileSystem writable stream (platform-specific wrapper)
  };

  handleMessage(message: any) {
    switch (message.type) {
      case "handshake":
        this.handleHandshake(message);
        break;
      case "ping":
        this.sendPong(message);
        break;
      case "file_start":
        this.handleFileStart(message);
        break;
      // v1: file_chunk arrives as a binary frame and no longer goes through the JSON branch
      case "file_end":
        this.handleFileEnd(message);
        break;
    }
  }

  handleFileStart(msg: LanTransferFileStartMessage) {
    // 1. Check storage space
    // 2. Create the temporary file
    // 3. Initialize transfer state
    // 4. Send file_start_ack
  }

  // v1: binary frames are parsed in the socket data loop, which then calls handleBinaryFileChunk
  handleBinaryFileChunk(transferId: string, chunkIndex: number, data: Buffer) {
    // Use the binary data directly; compute the length from chunkSize/lastChunk
    // Write to the file stream and update the incremental SHA-256
    this.transfer.receivedChunks++;
    // v1: streaming - no per-chunk ACK is sent
  }

  handleFileEnd(msg: LanTransferFileEndMessage) {
    // 1. Merge all chunks
    // 2. Verify the whole-file checksum
    // 3. Move the file to its final location
    // 4. Send file_complete
  }
}
```

---

## Appendix A: TypeScript Type Definitions

The complete type definitions live in `packages/shared/config/types.ts`:

```typescript
// Handshake messages
export interface LanTransferHandshakeMessage {
  type: "handshake";
  deviceName: string;
  version: string;
  platform?: string;
  appVersion?: string;
}

export interface LanTransferHandshakeAckMessage {
  type: "handshake_ack";
  accepted: boolean;
  message?: string;
}

// Heartbeat messages
export interface LanTransferPingMessage {
  type: "ping";
  payload?: string;
}

export interface LanTransferPongMessage {
  type: "pong";
  received: boolean;
  payload?: string;
}

// File transfer messages (Client -> Server)
export interface LanTransferFileStartMessage {
  type: "file_start";
  transferId: string;
  fileName: string;
  fileSize: number;
  mimeType: string;
  checksum: string;
  totalChunks: number;
  chunkSize: number;
}

export interface LanTransferFileChunkMessage {
  type: "file_chunk";
  transferId: string;
  chunkIndex: number;
  data: string; // Base64 encoded (unused in v1 binary-frame mode)
}

export interface LanTransferFileEndMessage {
  type: "file_end";
  transferId: string;
}

// File transfer response messages (Server -> Client)
export interface LanTransferFileStartAckMessage {
  type: "file_start_ack";
  transferId: string;
  accepted: boolean;
  message?: string;
}

// v1 streaming sends no per-chunk ACK; the following type is kept for backward-compatibility reference only
export interface LanTransferFileChunkAckMessage {
  type: "file_chunk_ack";
  transferId: string;
  chunkIndex: number;
  received: boolean;
  error?: string;
}

export interface LanTransferFileCompleteMessage {
  type: "file_complete";
  transferId: string;
  success: boolean;
  filePath?: string;
  error?: string;
}

// Constants
export const LAN_TRANSFER_TCP_PORT = 53317;
export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024;
export const LAN_TRANSFER_CHUNK_TIMEOUT_MS = 30_000;
```

---

## Appendix B: Version History

| Version | Date | Changes |
| ------- | ------- | ----------------------------------------------------------- |
| 1.0 | 2025-12 | Initial release: binary frame format and streaming transfer |
@@ -134,38 +134,60 @@ artifactBuildCompleted: scripts/artifact-build-completed.js

releaseInfo:
  releaseNotes: |
<!--LANG:en-->
-Cherry Studio 1.7.6 - New Models & MCP Enhancements
+Cherry Studio 1.7.3 - Feature & Stability Update

-This release adds support for new AI models and includes a new MCP server for memory management.
+This release brings new features, UI improvements, and important bug fixes.

✨ New Features
-- [Models] Add support for Xiaomi MiMo model
-- [Models] Add support for Gemini 3 Flash and Pro model detection
-- [Models] Add support for Volcengine Doubao-Seed-1.8 model
-- [MCP] Add Nowledge Mem builtin MCP server for memory management
-- [Settings] Add default reasoning effort option to resolve confusion between undefined and none
+- Add MCP server log viewer for better debugging
+- Support custom Git Bash path configuration
+- Add print to PDF and save as HTML for mini program webviews
+- Add CherryIN API host selection settings
+- Enhance assistant presets with sort and batch delete modes
+- Open URL directly for SelectionAssistant search action
+- Enhance web search tool switching with provider-specific context
+
+🔧 Improvements
+- Remove Intel Ultra limit for OVMS
+- Improve settings tab and assistant item UI

🐛 Bug Fixes
-- [Azure] Restore deployment-based URLs for non-v1 apiVersion
-- [Translation] Disable reasoning mode for translation to improve efficiency
-- [Image] Update API path for image generation requests in OpenAIBaseClient
-- [Windows] Auto-discover and persist Git Bash path on Windows for scoop users
+- Fix stack overflow with base64 images
+- Fix infinite loop in knowledge queue processing
+- Fix quick panel closing in multiple selection mode
+- Fix thinking timer not stopping when reply is aborted
+- Fix ThinkingButton icon display for fixed reasoning mode
+- Fix knowledge query prioritization and intent prompt
+- Fix OpenRouter embeddings support
+- Fix SelectionAction window resize on Windows
+- Add gpustack provider support for qwen3 thinking mode

<!--LANG:zh-CN-->
-Cherry Studio 1.7.6 - 新模型与 MCP 增强
+Cherry Studio 1.7.3 - 功能与稳定性更新

-本次更新添加了多个新 AI 模型支持,并新增记忆管理 MCP 服务器。
+本次更新带来新功能、界面改进和重要的问题修复。

✨ 新功能
-- [模型] 添加小米 MiMo 模型支持
-- [模型] 添加 Gemini 3 Flash 和 Pro 模型检测支持
-- [模型] 添加火山引擎 Doubao-Seed-1.8 模型支持
-- [MCP] 新增 Nowledge Mem 内置 MCP 服务器,用于记忆管理
-- [设置] 添加默认推理强度选项,解决 undefined 和 none 之间的混淆
+- 新增 MCP 服务器日志查看器,便于调试
+- 支持自定义 Git Bash 路径配置
+- 小程序 webview 支持打印 PDF 和保存为 HTML
+- 新增 CherryIN API 主机选择设置
+- 助手预设增强:支持排序和批量删除模式
+- 划词助手搜索操作直接打开 URL
+- 增强网页搜索工具切换逻辑,支持服务商特定上下文
+
+🔧 功能改进
+- 移除 OVMS 的 Intel Ultra 限制
+- 优化设置标签页和助手项目 UI

🐛 问题修复
-- [Azure] 修复非 v1 apiVersion 的部署 URL 问题
-- [翻译] 禁用翻译时的推理模式以提高效率
-- [图像] 更新 OpenAIBaseClient 中图像生成请求的 API 路径
-- [Windows] 自动发现并保存 Windows scoop 用户的 Git Bash 路径
+- 修复 base64 图片导致的栈溢出问题
+- 修复知识库队列处理的无限循环问题
+- 修复多选模式下快捷面板意外关闭的问题
+- 修复回复中止时思考计时器未停止的问题
+- 修复固定推理模式下思考按钮图标显示问题
+- 修复知识库查询优先级和意图提示
+- 修复 OpenRouter 嵌入模型支持
+- 修复 Windows 上划词助手窗口大小调整问题
+- 为 gpustack 服务商添加 qwen3 思考模式支持
<!--LANG:END-->
@ -61,7 +61,6 @@ export default defineConfig([
|
|||||||
'tests/**',
|
'tests/**',
|
||||||
'.yarn/**',
|
'.yarn/**',
|
||||||
'.gitignore',
|
'.gitignore',
|
||||||
'.conductor/**',
|
|
||||||
'scripts/cloudflare-worker.js',
|
'scripts/cloudflare-worker.js',
|
||||||
'src/main/integration/nutstore/sso/lib/**',
|
'src/main/integration/nutstore/sso/lib/**',
|
||||||
'src/main/integration/cherryai/index.js',
|
'src/main/integration/cherryai/index.js',
|
||||||
|
|||||||
32 package.json

@@ -1,6 +1,6 @@
{
  "name": "CherryStudio",
-  "version": "1.7.6",
+  "version": "1.7.3",
  "private": true,
  "description": "A powerful AI assistant for producer.",
  "main": "./out/main/index.js",
@@ -53,10 +53,10 @@
  "typecheck": "concurrently -n \"node,web\" -c \"cyan,magenta\" \"npm run typecheck:node\" \"npm run typecheck:web\"",
  "typecheck:node": "tsgo --noEmit -p tsconfig.node.json --composite false",
  "typecheck:web": "tsgo --noEmit -p tsconfig.web.json --composite false",
-  "i18n:check": "dotenv -e .env -- tsx scripts/check-i18n.ts",
-  "i18n:sync": "dotenv -e .env -- tsx scripts/sync-i18n.ts",
-  "i18n:translate": "dotenv -e .env -- tsx scripts/auto-translate-i18n.ts",
-  "i18n:all": "yarn i18n:check && yarn i18n:sync && yarn i18n:translate",
+  "check:i18n": "dotenv -e .env -- tsx scripts/check-i18n.ts",
+  "sync:i18n": "dotenv -e .env -- tsx scripts/sync-i18n.ts",
+  "update:i18n": "dotenv -e .env -- tsx scripts/update-i18n.ts",
+  "auto:i18n": "dotenv -e .env -- tsx scripts/auto-translate-i18n.ts",
  "update:languages": "tsx scripts/update-languages.ts",
  "update:upgrade-config": "tsx scripts/update-app-upgrade-config.ts",
  "test": "vitest run --silent",
@@ -70,7 +70,7 @@
  "test:e2e": "yarn playwright test",
  "test:lint": "oxlint --deny-warnings && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --cache",
  "test:scripts": "vitest scripts",
-  "lint": "oxlint --fix && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --cache && yarn typecheck && yarn i18n:check && yarn format:check",
+  "lint": "oxlint --fix && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --cache && yarn typecheck && yarn check:i18n && yarn format:check",
  "format": "biome format --write && biome lint --write",
  "format:check": "biome format && biome lint",
  "prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky",
@@ -87,7 +87,6 @@
  "@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
  "@paymoapp/electron-shutdown-handler": "^1.1.2",
  "@strongtz/win32-arm64-msvc": "^0.4.7",
-  "bonjour-service": "^1.3.0",
  "emoji-picker-element-data": "^1",
  "express": "^5.1.0",
  "font-list": "^2.0.0",
@@ -98,8 +97,10 @@
  "node-stream-zip": "^1.15.0",
  "officeparser": "^4.2.0",
  "os-proxy-config": "^1.1.2",
+  "qrcode.react": "^4.2.0",
  "selection-hook": "^1.0.12",
  "sharp": "^0.34.3",
+  "socket.io": "^4.8.1",
  "swagger-jsdoc": "^6.2.8",
  "swagger-ui-express": "^5.0.1",
  "tesseract.js": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch",
@@ -113,11 +114,11 @@
  "@ai-sdk/anthropic": "^2.0.49",
  "@ai-sdk/cerebras": "^1.0.31",
  "@ai-sdk/gateway": "^2.0.15",
-  "@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.49#~/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch",
-  "@ai-sdk/google-vertex": "^3.0.94",
+  "@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.43#~/.yarn/patches/@ai-sdk-google-npm-2.0.43-689ed559b3.patch",
+  "@ai-sdk/google-vertex": "^3.0.79",
  "@ai-sdk/huggingface": "^0.0.10",
  "@ai-sdk/mistral": "^2.0.24",
-  "@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.85#~/.yarn/patches/@ai-sdk-openai-npm-2.0.85-27483d1d6a.patch",
+  "@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.72#~/.yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch",
  "@ai-sdk/perplexity": "^2.0.20",
  "@ai-sdk/test-server": "^0.0.1",
  "@ant-design/v5-patch-for-react-19": "^1.0.3",
@@ -141,7 +142,7 @@
  "@cherrystudio/embedjs-ollama": "^0.1.31",
  "@cherrystudio/embedjs-openai": "^0.1.31",
  "@cherrystudio/extension-table-plus": "workspace:^",
-  "@cherrystudio/openai": "^6.12.0",
+  "@cherrystudio/openai": "^6.9.0",
  "@dnd-kit/core": "^6.3.1",
  "@dnd-kit/modifiers": "^9.0.0",
  "@dnd-kit/sortable": "^10.0.0",
@@ -317,7 +318,7 @@
  "motion": "^12.10.5",
  "notion-helper": "^1.3.22",
  "npx-scope-finder": "^1.2.0",
-  "ollama-ai-provider-v2": "patch:ollama-ai-provider-v2@npm%3A1.5.5#~/.yarn/patches/ollama-ai-provider-v2-npm-1.5.5-8bef249af9.patch",
+  "ollama-ai-provider-v2": "^1.5.5",
  "oxlint": "^1.22.0",
  "oxlint-tsgolint": "^0.2.0",
  "p-queue": "^8.1.0",
@@ -413,12 +414,9 @@
  "@langchain/openai@npm:>=0.1.0 <0.6.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
  "@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
  "@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
-  "@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.85#~/.yarn/patches/@ai-sdk-openai-npm-2.0.85-27483d1d6a.patch",
+  "@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.72#~/.yarn/patches/@ai-sdk-openai-npm-2.0.72-234e68da87.patch",
  "@ai-sdk/google@npm:^2.0.40": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
-  "@ai-sdk/openai-compatible@npm:^1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
-  "@ai-sdk/google@npm:2.0.49": "patch:@ai-sdk/google@npm%3A2.0.49#~/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch",
-  "@ai-sdk/openai-compatible@npm:1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch",
-  "@ai-sdk/openai-compatible@npm:^1.0.19": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch"
+  "@ai-sdk/openai-compatible@npm:^1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch"
},
"packageManager": "yarn@4.9.1",
"lint-staged": {
@@ -41,7 +41,7 @@
  "ai": "^5.0.26"
},
"dependencies": {
-  "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch",
+  "@ai-sdk/openai-compatible": "^1.0.28",
  "@ai-sdk/provider": "^2.0.0",
  "@ai-sdk/provider-utils": "^3.0.17"
},
@@ -40,9 +40,9 @@
},
"dependencies": {
  "@ai-sdk/anthropic": "^2.0.49",
-  "@ai-sdk/azure": "^2.0.87",
+  "@ai-sdk/azure": "^2.0.74",
  "@ai-sdk/deepseek": "^1.0.31",
-  "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch",
+  "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
  "@ai-sdk/provider": "^2.0.0",
  "@ai-sdk/provider-utils": "^3.0.17",
  "@ai-sdk/xai": "^2.0.36",
@@ -62,7 +62,7 @@ export class StreamEventManager {
      const recursiveResult = await context.recursiveCall(recursiveParams)

      if (recursiveResult && recursiveResult.fullStream) {
-        await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
+        await this.pipeRecursiveStream(controller, recursiveResult.fullStream, context)
      } else {
        console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
      }
@@ -74,7 +74,11 @@ export class StreamEventManager {
  /**
   * Pipe the recursive stream's data into the current stream
   */
-  private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
+  private async pipeRecursiveStream(
+    controller: StreamController,
+    recursiveStream: ReadableStream,
+    context?: AiRequestContext
+  ): Promise<void> {
    const reader = recursiveStream.getReader()
    try {
      while (true) {
@@ -82,14 +86,18 @@ export class StreamEventManager {
        if (done) {
          break
        }
-        if (value.type === 'start') {
-          continue
-        }
-
        if (value.type === 'finish') {
+          // The recursive stream does not forward its own finish, but its usage must still be accumulated
+          if (value.usage && context?.accumulatedUsage) {
+            this.accumulateUsage(context.accumulatedUsage, value.usage)
+          }
          break
        }
+        // For finish-step events, accumulate their usage as well
+        if (value.type === 'finish-step' && value.usage && context?.accumulatedUsage) {
+          this.accumulateUsage(context.accumulatedUsage, value.usage)
+        }
+        // Forward the recursive stream's data into the current stream
        controller.enqueue(value)
      }
    } finally {
@@ -151,7 +159,7 @@ export class StreamEventManager {
  /**
   * Accumulate usage data
   */
-  accumulateUsage(target: any, source: any): void {
+  private accumulateUsage(target: any, source: any): void {
    if (!target || !source) return

    // Accumulate each token type
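For context on what this hunk is diffing: usage accumulation here is plain field-wise addition across stream steps. A hedged sketch of what such a helper typically looks like (the token field names are assumptions from the AI SDK's usage shape, not copied from this file):

```ts
// Field-wise accumulation of token usage across stream steps (illustrative).
type Usage = { inputTokens?: number; outputTokens?: number; totalTokens?: number }

function accumulateUsage(target: Usage, source: Usage): void {
  if (!target || !source) return
  for (const key of ['inputTokens', 'outputTokens', 'totalTokens'] as const) {
    if (typeof source[key] === 'number') {
      target[key] = (target[key] ?? 0) + source[key]!
    }
  }
}
```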
@@ -22,10 +22,10 @@ const TOOL_USE_TAG_CONFIG: TagConfig = {
}

/**
- * Default system prompt template
+ * Default system prompt template (extracted from Cherry Studio)
 */
-export const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \
-You can use one or more tools per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
+const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \\
+You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.

## Tool Use Formatting

@@ -74,13 +74,10 @@ Here are the rules you should always follow to solve your task:
4. Never re-do a tool call that you previously did with the exact same parameters.
5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format.

-## Response rules
-
-Respond in the language of the user's query, unless the user instructions specify additional requirements for the language to be used.
-
# User Instructions
{{ USER_SYSTEM_PROMPT }}
-`
+
+Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.`

/**
 * Default tool-use examples (extracted from Cherry Studio)
 */
@@ -414,10 +411,7 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
      }
    }

-    // If no tool call was executed, accumulate usage and then pass the finish-step event through
-    if (chunk.usage && context.accumulatedUsage) {
-      streamEventManager.accumulateUsage(context.accumulatedUsage, chunk.usage)
-    }
+    // If no tool call was executed, pass the original finish-step event straight through
    controller.enqueue(chunk)

    // Clean up state
@@ -233,8 +233,6 @@ export enum IpcChannel {
  Backup_ListS3Files = 'backup:listS3Files',
  Backup_DeleteS3File = 'backup:deleteS3File',
  Backup_CheckS3Connection = 'backup:checkS3Connection',
-  Backup_CreateLanTransferBackup = 'backup:createLanTransferBackup',
-  Backup_DeleteTempBackup = 'backup:deleteTempBackup',

  // zip
  Zip_Compress = 'zip:compress',
@@ -246,7 +244,6 @@ export enum IpcChannel {
  System_GetCpuName = 'system:getCpuName',
  System_CheckGitBash = 'system:checkGitBash',
  System_GetGitBashPath = 'system:getGitBashPath',
-  System_GetGitBashPathInfo = 'system:getGitBashPathInfo',
  System_SetGitBashPath = 'system:setGitBashPath',

  // DevTools
@@ -318,7 +315,6 @@ export enum IpcChannel {
  Memory_DeleteUser = 'memory:delete-user',
  Memory_DeleteAllMemoriesForUser = 'memory:delete-all-memories-for-user',
  Memory_GetUsersList = 'memory:get-users-list',
-  Memory_MigrateMemoryDb = 'memory:migrate-memory-db',

  // TRACE
  TRACE_SAVE_DATA = 'trace:saveData',
@@ -384,14 +380,10 @@ export enum IpcChannel {
  ClaudeCodePlugin_ReadContent = 'claudeCodePlugin:read-content',
  ClaudeCodePlugin_WriteContent = 'claudeCodePlugin:write-content',

-  // Local Transfer
-  LocalTransfer_ListServices = 'local-transfer:list',
-  LocalTransfer_StartScan = 'local-transfer:start-scan',
-  LocalTransfer_StopScan = 'local-transfer:stop-scan',
-  LocalTransfer_ServicesUpdated = 'local-transfer:services-updated',
-  LocalTransfer_Connect = 'local-transfer:connect',
-  LocalTransfer_Disconnect = 'local-transfer:disconnect',
-  LocalTransfer_ClientEvent = 'local-transfer:client-event',
-  LocalTransfer_SendFile = 'local-transfer:send-file',
-  LocalTransfer_CancelTransfer = 'local-transfer:cancel-transfer'
+  // WebSocket
+  WebSocket_Start = 'webSocket:start',
+  WebSocket_Stop = 'webSocket:stop',
+  WebSocket_Status = 'webSocket:status',
+  WebSocket_SendFile = 'webSocket:send-file',
+  WebSocket_GetAllCandidates = 'webSocket:get-all-candidates'
}
@@ -488,11 +488,3 @@ export const MACOS_TERMINALS_WITH_COMMANDS: TerminalConfigWithCommand[] = [

// resources/scripts should be maintained manually
export const HOME_CHERRY_DIR = '.cherrystudio'
-
-// Git Bash path configuration types
-export type GitBashPathSource = 'manual' | 'auto'
-
-export interface GitBashPathInfo {
-  path: string | null
-  source: GitBashPathSource | null
-}
@ -52,196 +52,3 @@ export interface WebSocketCandidatesResponse {
|
|||||||
interface: string
|
interface: string
|
||||||
priority: number
|
priority: number
|
||||||
}
|
}
|
||||||
|
|
||||||
export type LocalTransferPeer = {
|
|
||||||
id: string
|
|
||||||
name: string
|
|
||||||
host?: string
|
|
||||||
fqdn?: string
|
|
||||||
port?: number
|
|
||||||
type?: string
|
|
||||||
protocol?: 'tcp' | 'udp'
|
|
||||||
addresses: string[]
|
|
||||||
txt?: Record<string, string>
|
|
||||||
updatedAt: number
|
|
||||||
}
|
|
||||||
|
|
||||||
export type LocalTransferState = {
|
|
||||||
services: LocalTransferPeer[]
|
|
||||||
isScanning: boolean
|
|
||||||
lastScanStartedAt?: number
|
|
||||||
lastUpdatedAt: number
|
|
||||||
lastError?: string
|
|
||||||
}
|
|
||||||
|
|
||||||
export type LanHandshakeRequestMessage = {
|
|
||||||
type: 'handshake'
|
|
||||||
deviceName: string
|
|
||||||
version: string
|
|
||||||
platform?: string
|
|
||||||
appVersion?: string
|
|
||||||
}
|
|
||||||
|
|
||||||
export type LanHandshakeAckMessage = {
|
|
||||||
type: 'handshake_ack'
|
|
||||||
accepted: boolean
|
|
||||||
message?: string
|
|
||||||
}
|
|
||||||
|
|
||||||
export type LocalTransferConnectPayload = {
|
|
||||||
peerId: string
|
|
||||||
metadata?: Record<string, string>
|
|
||||||
timeoutMs?: number
|
|
||||||
}
|
|
||||||
|
|
||||||
export type LanClientEvent =
|
|
||||||
| {
|
|
||||||
type: 'ping_sent'
|
|
||||||
payload: string
|
|
||||||
timestamp: number
|
|
||||||
peerId?: string
|
|
||||||
peerName?: string
|
|
||||||
}
|
|
||||||
| {
|
|
||||||
type: 'pong'
|
|
||||||
payload?: string
|
|
||||||
received?: boolean
|
|
||||||
timestamp: number
|
|
||||||
peerId?: string
|
|
||||||
peerName?: string
|
|
||||||
}
|
|
||||||
| {
|
|
||||||
type: 'socket_closed'
|
|
||||||
reason?: string
|
|
||||||
timestamp: number
|
|
||||||
peerId?: string
|
|
||||||
peerName?: string
|
|
||||||
}
|
|
||||||
| {
|
|
||||||
type: 'error'
|
|
||||||
message: string
|
|
||||||
timestamp: number
|
|
||||||
peerId?: string
|
|
||||||
peerName?: string
|
|
||||||
}
|
|
||||||
| {
|
|
||||||
type: 'file_transfer_progress'
|
|
||||||
transferId: string
|
|
||||||
fileName: string
|
|
||||||
bytesSent: number
|
|
||||||
totalBytes: number
|
|
||||||
chunkIndex: number
|
|
||||||
totalChunks: number
|
|
||||||
progress: number // 0-100
|
|
||||||
speed: number // bytes/sec
|
|
||||||
timestamp: number
|
|
||||||
peerId?: string
|
|
||||||
peerName?: string
|
|
||||||
}
|
|
||||||
| {
|
|
||||||
type: 'file_transfer_complete'
|
|
||||||
transferId: string
|
|
||||||
fileName: string
|
|
||||||
success: boolean
|
|
||||||
filePath?: string
|
|
||||||
error?: string
|
|
||||||
timestamp: number
|
|
||||||
peerId?: string
|
|
||||||
peerName?: string
|
|
||||||
}
|
|
||||||
|
|
||||||
-// =============================================================================
-// LAN File Transfer Protocol Types
-// =============================================================================
-
-// Constants for file transfer
-export const LAN_TRANSFER_TCP_PORT = 53317
-export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024 // 512KB
-export const LAN_TRANSFER_MAX_FILE_SIZE = 500 * 1024 * 1024 // 500MB
-export const LAN_TRANSFER_COMPLETE_TIMEOUT_MS = 60_000 // 60s - wait for file_complete after file_end
-export const LAN_TRANSFER_GLOBAL_TIMEOUT_MS = 10 * 60 * 1000 // 10 minutes - global transfer timeout
-
-// Binary protocol constants (v1)
-export const LAN_TRANSFER_PROTOCOL_VERSION = '1'
-export const LAN_BINARY_FRAME_MAGIC = 0x4353 // "CS" as uint16
-export const LAN_BINARY_TYPE_FILE_CHUNK = 0x01
-
-// Messages from Electron (Client/Sender) to Mobile (Server/Receiver)
-
-/** Request to start file transfer */
-export type LanFileStartMessage = {
-  type: 'file_start'
-  transferId: string
-  fileName: string
-  fileSize: number
-  mimeType: string // 'application/zip'
-  checksum: string // SHA-256 of entire file
-  totalChunks: number
-  chunkSize: number
-}
-
-/**
- * File chunk data (JSON format)
- * @deprecated Use binary frame format in protocol v1. This type is kept for reference only.
- */
-export type LanFileChunkMessage = {
-  type: 'file_chunk'
-  transferId: string
-  chunkIndex: number
-  data: string // Base64 encoded
-  chunkChecksum: string // SHA-256 of this chunk
-}
-
-/** Notification that all chunks have been sent */
-export type LanFileEndMessage = {
-  type: 'file_end'
-  transferId: string
-}
-
-/** Request to cancel file transfer */
-export type LanFileCancelMessage = {
-  type: 'file_cancel'
-  transferId: string
-  reason?: string
-}
-
-// Messages from Mobile (Server/Receiver) to Electron (Client/Sender)
-
-/** Acknowledgment of file transfer request */
-export type LanFileStartAckMessage = {
-  type: 'file_start_ack'
-  transferId: string
-  accepted: boolean
-  message?: string // Rejection reason
-}
-
-/**
- * Acknowledgment of file chunk received
- * @deprecated Protocol v1 uses streaming mode without per-chunk acknowledgment.
- * This type is kept for backward compatibility reference only.
- */
-export type LanFileChunkAckMessage = {
-  type: 'file_chunk_ack'
-  transferId: string
-  chunkIndex: number
-  received: boolean
-  message?: string
-}
-
-/** Final result of file transfer */
-export type LanFileCompleteMessage = {
-  type: 'file_complete'
-  transferId: string
-  success: boolean
-  filePath?: string // Path where file was saved on mobile
-  error?: string
-  // Enhanced error diagnostics
-  errorCode?: 'CHECKSUM_MISMATCH' | 'INCOMPLETE_TRANSFER' | 'DISK_ERROR' | 'CANCELLED'
-  receivedChunks?: number
-  receivedBytes?: number
-}
-
-/** Payload for sending a file via IPC */
-export type LanFileSendPayload = {
-  filePath: string
-}
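To make the handshake concrete, here is a minimal sender-side sketch built only from the constants and LanFileStartMessage type above. The sha256File helper and the node:crypto/node:fs usage are illustrative assumptions, not code from the repository; the constants and message shape are assumed to be in scope.

import { createHash, randomUUID } from 'node:crypto'
import { readFile, stat } from 'node:fs/promises'
import path from 'node:path'

// Hypothetical helper: hex-encoded SHA-256 of the whole file.
async function sha256File(filePath: string): Promise<string> {
  return createHash('sha256').update(await readFile(filePath)).digest('hex')
}

// Sketch: derive the file_start handshake for a backup archive.
async function buildFileStart(filePath: string): Promise<LanFileStartMessage> {
  const { size } = await stat(filePath)
  if (size > LAN_TRANSFER_MAX_FILE_SIZE) {
    throw new Error('file exceeds the 500MB transfer limit')
  }
  return {
    type: 'file_start',
    transferId: randomUUID(),
    fileName: path.basename(filePath),
    fileSize: size,
    mimeType: 'application/zip',
    checksum: await sha256File(filePath),
    // 512KB chunks; the receiver can verify that every chunk arrived.
    totalChunks: Math.ceil(size / LAN_TRANSFER_CHUNK_SIZE),
    chunkSize: LAN_TRANSFER_CHUNK_SIZE
  }
}

After this JSON handshake, protocol v1 streams the payload as binary frames tagged with LAN_BINARY_FRAME_MAGIC (0x4353, "CS") and LAN_BINARY_TYPE_FILE_CHUNK rather than base64 JSON chunks, which is why LanFileChunkMessage and its per-chunk ack are marked deprecated above.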

@@ -6,8 +6,8 @@ const { downloadWithPowerShell } = require('./download')
 
 // Base URL for downloading OVMS binaries
 const OVMS_RELEASE_BASE_URL =
-  'https://storage.openvinotoolkit.org/repositories/openvino_model_server/packages/2025.4.1/ovms_windows_python_on.zip'
-const OVMS_EX_URL = 'https://gitcode.com/gcw_ggDjjkY3/kjfile/releases/download/download/ovms_25.4_ex.zip'
+  'https://storage.openvinotoolkit.org/repositories/openvino_model_server/packages/2025.3.0/ovms_windows_python_on.zip'
+const OVMS_EX_URL = 'https://gitcode.com/gcw_ggDjjkY3/kjfile/releases/download/download/ovms_25.3_ex.zip'
 
 /**
  * error code:

@@ -50,7 +50,7 @@ Usage Instructions:
 - pt-pt (Portuguese)
 
 Run Command:
-yarn i18n:translate
+yarn auto:i18n
 
 Performance Optimization Recommendations:
 - For stable API services: MAX_CONCURRENT_TRANSLATIONS=8, TRANSLATION_DELAY_MS=50

@@ -145,7 +145,7 @@ export function main() {
     console.log('i18n 检查已通过')
   } catch (e) {
     console.error(e)
-    throw new Error(`检查未通过。尝试运行 yarn i18n:sync 以解决问题。`)
+    throw new Error(`检查未通过。尝试运行 yarn sync:i18n 以解决问题。`)
   }
 }
 

@@ -19,10 +19,8 @@ import { agentService } from './services/agents'
 import { apiServerService } from './services/ApiServerService'
 import { appMenuService } from './services/AppMenuService'
 import { configManager } from './services/ConfigManager'
-import { lanTransferClientService } from './services/lanTransfer'
-import mcpService from './services/MCPService'
-import { localTransferService } from './services/LocalTransferService'
 import { nodeTraceService } from './services/NodeTraceService'
+import mcpService from './services/MCPService'
 import powerMonitorService from './services/PowerMonitorService'
 import {
   CHERRY_STUDIO_PROTOCOL,

@@ -158,7 +156,6 @@ if (!app.requestSingleInstanceLock()) {
     registerShortcuts(mainWindow)
 
     registerIpc(mainWindow, app)
-    localTransferService.startDiscovery({ resetList: true })
 
     replaceDevtoolsFont(mainWindow)
 

@@ -240,9 +237,6 @@ if (!app.requestSingleInstanceLock()) {
     if (selectionService) {
       selectionService.quit()
     }
 
-    lanTransferClientService.dispose()
-    localTransferService.dispose()
   })
 
   app.on('will-quit', async () => {

@@ -6,19 +6,11 @@ import { loggerService } from '@logger'
 import { isLinux, isMac, isPortable, isWin } from '@main/constant'
 import { generateSignature } from '@main/integration/cherryai'
 import anthropicService from '@main/services/AnthropicService'
-import {
-  autoDiscoverGitBash,
-  getBinaryPath,
-  getGitBashPathInfo,
-  isBinaryExists,
-  runInstallScript,
-  validateGitBashPath
-} from '@main/utils/process'
+import { findGitBash, getBinaryPath, isBinaryExists, runInstallScript, validateGitBashPath } from '@main/utils/process'
 import { handleZoomFactor } from '@main/utils/zoom'
 import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
 import type { UpgradeChannel } from '@shared/config/constant'
 import { MIN_WINDOW_HEIGHT, MIN_WINDOW_WIDTH } from '@shared/config/constant'
-import type { LocalTransferConnectPayload } from '@shared/config/types'
 import { IpcChannel } from '@shared/IpcChannel'
 import type { PluginError } from '@types'
 import type {

@@ -50,8 +42,6 @@ import { ExportService } from './services/ExportService'
 import { fileStorage as fileManager } from './services/FileStorage'
 import FileService from './services/FileSystemService'
 import KnowledgeService from './services/KnowledgeService'
-import { lanTransferClientService } from './services/lanTransfer'
-import { localTransferService } from './services/LocalTransferService'
 import mcpService from './services/MCPService'
 import MemoryService from './services/memory/MemoryService'
 import { openTraceWindow, setTraceWindowTitle } from './services/NodeTraceService'

@@ -83,6 +73,7 @@ import {
 import storeSyncService from './services/StoreSyncService'
 import { themeService } from './services/ThemeService'
 import VertexAIService from './services/VertexAIService'
+import WebSocketService from './services/WebSocketService'
 import { setOpenLinkExternal } from './services/WebviewService'
 import { windowService } from './services/WindowService'
 import { calculateDirectorySize, getResourcePath } from './utils'

@@ -508,8 +499,9 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
     }
 
     try {
-      // Use autoDiscoverGitBash to handle auto-discovery and persistence
-      const bashPath = autoDiscoverGitBash()
+      const customPath = configManager.get(ConfigKeys.GitBashPath) as string | undefined
+      const bashPath = findGitBash(customPath)
 
       if (bashPath) {
         logger.info('Git Bash is available', { path: bashPath })
         return true

@@ -532,22 +524,13 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
     return customPath ?? null
   })
 
-  // Returns { path, source } where source is 'manual' | 'auto' | null
-  ipcMain.handle(IpcChannel.System_GetGitBashPathInfo, () => {
-    return getGitBashPathInfo()
-  })
-
   ipcMain.handle(IpcChannel.System_SetGitBashPath, (_, newPath: string | null) => {
     if (!isWin) {
       return false
     }
 
     if (!newPath) {
-      // Clear manual setting and re-run auto-discovery
      configManager.set(ConfigKeys.GitBashPath, null)
-      configManager.set(ConfigKeys.GitBashPathSource, null)
-      // Re-run auto-discovery to restore auto-discovered path if available
-      autoDiscoverGitBash()
      return true
    }
 

@@ -556,9 +539,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
       return false
     }
 
-    // Set path with 'manual' source
     configManager.set(ConfigKeys.GitBashPath, validated)
-    configManager.set(ConfigKeys.GitBashPathSource, 'manual')
     return true
   })
 

@@ -585,8 +566,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
   ipcMain.handle(IpcChannel.Backup_ListS3Files, backupManager.listS3Files.bind(backupManager))
   ipcMain.handle(IpcChannel.Backup_DeleteS3File, backupManager.deleteS3File.bind(backupManager))
   ipcMain.handle(IpcChannel.Backup_CheckS3Connection, backupManager.checkS3Connection.bind(backupManager))
-  ipcMain.handle(IpcChannel.Backup_CreateLanTransferBackup, backupManager.createLanTransferBackup.bind(backupManager))
-  ipcMain.handle(IpcChannel.Backup_DeleteTempBackup, backupManager.deleteTempBackup.bind(backupManager))
 
   // file
   ipcMain.handle(IpcChannel.File_Open, fileManager.open.bind(fileManager))

@@ -686,19 +665,36 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
   ipcMain.handle(IpcChannel.KnowledgeBase_Check_Quota, KnowledgeService.checkQuota.bind(KnowledgeService))
 
   // memory
-  ipcMain.handle(IpcChannel.Memory_Add, (_, messages, config) => memoryService.add(messages, config))
-  ipcMain.handle(IpcChannel.Memory_Search, (_, query, config) => memoryService.search(query, config))
-  ipcMain.handle(IpcChannel.Memory_List, (_, config) => memoryService.list(config))
-  ipcMain.handle(IpcChannel.Memory_Delete, (_, id) => memoryService.delete(id))
-  ipcMain.handle(IpcChannel.Memory_Update, (_, id, memory, metadata) => memoryService.update(id, memory, metadata))
-  ipcMain.handle(IpcChannel.Memory_Get, (_, memoryId) => memoryService.get(memoryId))
-  ipcMain.handle(IpcChannel.Memory_SetConfig, (_, config) => memoryService.setConfig(config))
-  ipcMain.handle(IpcChannel.Memory_DeleteUser, (_, userId) => memoryService.deleteUser(userId))
-  ipcMain.handle(IpcChannel.Memory_DeleteAllMemoriesForUser, (_, userId) =>
-    memoryService.deleteAllMemoriesForUser(userId)
-  )
-  ipcMain.handle(IpcChannel.Memory_GetUsersList, () => memoryService.getUsersList())
-  ipcMain.handle(IpcChannel.Memory_MigrateMemoryDb, () => memoryService.migrateMemoryDb())
+  ipcMain.handle(IpcChannel.Memory_Add, async (_, messages, config) => {
+    return await memoryService.add(messages, config)
+  })
+  ipcMain.handle(IpcChannel.Memory_Search, async (_, query, config) => {
+    return await memoryService.search(query, config)
+  })
+  ipcMain.handle(IpcChannel.Memory_List, async (_, config) => {
+    return await memoryService.list(config)
+  })
+  ipcMain.handle(IpcChannel.Memory_Delete, async (_, id) => {
+    return await memoryService.delete(id)
+  })
+  ipcMain.handle(IpcChannel.Memory_Update, async (_, id, memory, metadata) => {
+    return await memoryService.update(id, memory, metadata)
+  })
+  ipcMain.handle(IpcChannel.Memory_Get, async (_, memoryId) => {
+    return await memoryService.get(memoryId)
+  })
+  ipcMain.handle(IpcChannel.Memory_SetConfig, async (_, config) => {
+    memoryService.setConfig(config)
+  })
+  ipcMain.handle(IpcChannel.Memory_DeleteUser, async (_, userId) => {
+    return await memoryService.deleteUser(userId)
+  })
+  ipcMain.handle(IpcChannel.Memory_DeleteAllMemoriesForUser, async (_, userId) => {
+    return await memoryService.deleteAllMemoriesForUser(userId)
+  })
+  ipcMain.handle(IpcChannel.Memory_GetUsersList, async () => {
+    return await memoryService.getUsersList()
+  })
 
   // window
   ipcMain.handle(IpcChannel.Windows_SetMinimumSize, (_, width: number, height: number) => {
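The two registration styles in this hunk are behaviorally equivalent: ipcMain.handle resolves whatever the handler returns and awaits it if it is a promise, so the concise arrow form on the main side and the verbose async form on the v1.7.3 side answer invoke calls identically (main only adds the Memory_MigrateMemoryDb channel). A minimal sketch; the channel names and the stand-in service are illustrative:

import { ipcMain } from 'electron'

// Illustrative stand-in for the real memory service.
const memoryService = { get: async (id: string) => ({ id, memory: 'example' }) }

// Both forms register the same behavior for ipcRenderer.invoke:
ipcMain.handle('memory-get', (_event, id: string) => memoryService.get(id))
ipcMain.handle('memory-get-verbose', async (_event, id: string) => {
  return await memoryService.get(id)
})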

@@ -858,8 +854,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
   )
 
   // search window
-  ipcMain.handle(IpcChannel.SearchWindow_Open, async (_, uid: string, show?: boolean) => {
-    await searchService.openSearchWindow(uid, show)
+  ipcMain.handle(IpcChannel.SearchWindow_Open, async (_, uid: string) => {
+    await searchService.openSearchWindow(uid)
   })
   ipcMain.handle(IpcChannel.SearchWindow_Close, async (_, uid: string) => {
     await searchService.closeSearchWindow(uid)

@@ -1101,17 +1097,12 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
     }
   })
 
-  ipcMain.handle(IpcChannel.LocalTransfer_ListServices, () => localTransferService.getState())
-  ipcMain.handle(IpcChannel.LocalTransfer_StartScan, () => localTransferService.startDiscovery({ resetList: true }))
-  ipcMain.handle(IpcChannel.LocalTransfer_StopScan, () => localTransferService.stopDiscovery())
-  ipcMain.handle(IpcChannel.LocalTransfer_Connect, (_, payload: LocalTransferConnectPayload) =>
-    lanTransferClientService.connectAndHandshake(payload)
-  )
-  ipcMain.handle(IpcChannel.LocalTransfer_Disconnect, () => lanTransferClientService.disconnect())
-  ipcMain.handle(IpcChannel.LocalTransfer_SendFile, (_, payload: { filePath: string }) =>
-    lanTransferClientService.sendFile(payload.filePath)
-  )
-  ipcMain.handle(IpcChannel.LocalTransfer_CancelTransfer, () => lanTransferClientService.cancelTransfer())
+  // WebSocket
+  ipcMain.handle(IpcChannel.WebSocket_Start, WebSocketService.start)
+  ipcMain.handle(IpcChannel.WebSocket_Stop, WebSocketService.stop)
+  ipcMain.handle(IpcChannel.WebSocket_Status, WebSocketService.getStatus)
+  ipcMain.handle(IpcChannel.WebSocket_SendFile, WebSocketService.sendFile)
+  ipcMain.handle(IpcChannel.WebSocket_GetAllCandidates, WebSocketService.getAllCandidates)
 
   ipcMain.handle(IpcChannel.APP_CrashRenderProcess, () => {
     mainWindow.webContents.forcefullyCrashRenderer()
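For context, the removed LocalTransfer channels were driven from the renderer via ipcRenderer.invoke. A rough sketch of the sender flow follows; the preload bridge shape (window.electron) and the connect payload fields are assumptions for illustration, since the real LocalTransferConnectPayload type lived in @shared/config/types and is not shown here.

import { IpcChannel } from '@shared/IpcChannel'

// Sketch: scan, connect, then stream a backup file over the LAN.
async function sendBackupOverLan(filePath: string) {
  const { ipcRenderer } = (window as any).electron // exposed by the preload script (assumed)
  await ipcRenderer.invoke(IpcChannel.LocalTransfer_StartScan)
  // Payload fields are illustrative only.
  await ipcRenderer.invoke(IpcChannel.LocalTransfer_Connect, { host: '192.168.0.42', port: 53317 })
  await ipcRenderer.invoke(IpcChannel.LocalTransfer_SendFile, { filePath })
}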

@@ -1,134 +0,0 @@
-import { describe, expect, it, vi } from 'vitest'
-
-vi.mock('electron', () => {
-  const sendCommand = vi.fn(async (command: string, params?: { expression?: string }) => {
-    if (command === 'Runtime.evaluate') {
-      if (params?.expression === 'document.documentElement.outerHTML') {
-        return { result: { value: '<html><body><h1>Test</h1><p>Content</p></body></html>' } }
-      }
-      if (params?.expression === 'document.body.innerText') {
-        return { result: { value: 'Test\nContent' } }
-      }
-      return { result: { value: 'ok' } }
-    }
-    return {}
-  })
-
-  const debuggerObj = {
-    isAttached: vi.fn(() => true),
-    attach: vi.fn(),
-    detach: vi.fn(),
-    sendCommand
-  }
-
-  const webContents = {
-    debugger: debuggerObj,
-    setUserAgent: vi.fn(),
-    getURL: vi.fn(() => 'https://example.com/'),
-    getTitle: vi.fn(async () => 'Example Title'),
-    once: vi.fn(),
-    removeListener: vi.fn(),
-    on: vi.fn()
-  }
-
-  const loadURL = vi.fn(async () => {})
-
-  const windows: any[] = []
-
-  class MockBrowserWindow {
-    private destroyed = false
-    public webContents = webContents
-    public loadURL = loadURL
-    public isDestroyed = vi.fn(() => this.destroyed)
-    public close = vi.fn(() => {
-      this.destroyed = true
-    })
-    public destroy = vi.fn(() => {
-      this.destroyed = true
-    })
-    public on = vi.fn()
-
-    constructor() {
-      windows.push(this)
-    }
-  }
-
-  const app = {
-    isReady: vi.fn(() => true),
-    whenReady: vi.fn(async () => {}),
-    on: vi.fn()
-  }
-
-  return {
-    BrowserWindow: MockBrowserWindow as any,
-    app,
-    __mockDebugger: debuggerObj,
-    __mockSendCommand: sendCommand,
-    __mockLoadURL: loadURL,
-    __mockWindows: windows
-  }
-})
-
-import * as electron from 'electron'
-const { __mockWindows } = electron as typeof electron & { __mockWindows: any[] }
-
-import { CdpBrowserController } from '../browser'
-
-describe('CdpBrowserController', () => {
-  it('executes single-line code via Runtime.evaluate', async () => {
-    const controller = new CdpBrowserController()
-    const result = await controller.execute('1+1')
-    expect(result).toBe('ok')
-  })
-
-  it('opens a URL (hidden) and returns current page info', async () => {
-    const controller = new CdpBrowserController()
-    const result = await controller.open('https://foo.bar/', 5000, false)
-    expect(result.currentUrl).toBe('https://example.com/')
-    expect(result.title).toBe('Example Title')
-  })
-
-  it('opens a URL (visible) when show=true', async () => {
-    const controller = new CdpBrowserController()
-    const result = await controller.open('https://foo.bar/', 5000, true, 'session-a')
-    expect(result.currentUrl).toBe('https://example.com/')
-    expect(result.title).toBe('Example Title')
-  })
-
-  it('reuses session for execute and supports multiline', async () => {
-    const controller = new CdpBrowserController()
-    await controller.open('https://foo.bar/', 5000, false, 'session-b')
-    const result = await controller.execute('const a=1; const b=2; a+b;', 5000, 'session-b')
-    expect(result).toBe('ok')
-  })
-
-  it('evicts least recently used session when exceeding maxSessions', async () => {
-    const controller = new CdpBrowserController({ maxSessions: 2, idleTimeoutMs: 1000 * 60 })
-    await controller.open('https://foo.bar/', 5000, false, 's1')
-    await controller.open('https://foo.bar/', 5000, false, 's2')
-    await controller.open('https://foo.bar/', 5000, false, 's3')
-    const destroyedCount = __mockWindows.filter(
-      (w: any) => w.destroy.mock.calls.length > 0 || w.close.mock.calls.length > 0
-    ).length
-    expect(destroyedCount).toBeGreaterThanOrEqual(1)
-  })
-
-  it('fetches URL and returns html format', async () => {
-    const controller = new CdpBrowserController()
-    const result = await controller.fetch('https://example.com/', 'html')
-    expect(result).toBe('<html><body><h1>Test</h1><p>Content</p></body></html>')
-  })
-
-  it('fetches URL and returns txt format', async () => {
-    const controller = new CdpBrowserController()
-    const result = await controller.fetch('https://example.com/', 'txt')
-    expect(result).toBe('Test\nContent')
-  })
-
-  it('fetches URL and returns markdown format (default)', async () => {
-    const controller = new CdpBrowserController()
-    const result = await controller.fetch('https://example.com/')
-    expect(typeof result).toBe('string')
-    expect(result).toContain('Test')
-  })
-})

@@ -1,307 +0,0 @@
-import { app, BrowserWindow } from 'electron'
-import TurndownService from 'turndown'
-
-import { logger, userAgent } from './types'
-
-/**
- * Controller for managing browser windows via Chrome DevTools Protocol (CDP).
- * Supports multiple sessions with LRU eviction and idle timeout cleanup.
- */
-export class CdpBrowserController {
-  private windows: Map<string, { win: BrowserWindow; lastActive: number }> = new Map()
-  private readonly maxSessions: number
-  private readonly idleTimeoutMs: number
-
-  constructor(options?: { maxSessions?: number; idleTimeoutMs?: number }) {
-    this.maxSessions = options?.maxSessions ?? 5
-    this.idleTimeoutMs = options?.idleTimeoutMs ?? 5 * 60 * 1000
-  }
-
-  private async ensureAppReady() {
-    if (!app.isReady()) {
-      await app.whenReady()
-    }
-  }
-
-  private touch(sessionId: string) {
-    const entry = this.windows.get(sessionId)
-    if (entry) entry.lastActive = Date.now()
-  }
-
-  private closeWindow(win: BrowserWindow, sessionId: string) {
-    try {
-      if (!win.isDestroyed()) {
-        if (win.webContents.debugger.isAttached()) {
-          win.webContents.debugger.detach()
-        }
-        win.close()
-      }
-    } catch (error) {
-      logger.warn('Error closing window', { error, sessionId })
-    }
-  }
-
-  private async ensureDebuggerAttached(dbg: Electron.Debugger, sessionId: string) {
-    if (!dbg.isAttached()) {
-      try {
-        logger.info('Attaching debugger', { sessionId })
-        dbg.attach('1.3')
-        await dbg.sendCommand('Page.enable')
-        await dbg.sendCommand('Runtime.enable')
-        logger.info('Debugger attached and domains enabled')
-      } catch (error) {
-        logger.error('Failed to attach debugger', { error })
-        throw error
-      }
-    }
-  }
-
-  private sweepIdle() {
-    const now = Date.now()
-    for (const [id, entry] of this.windows.entries()) {
-      if (now - entry.lastActive > this.idleTimeoutMs) {
-        this.closeWindow(entry.win, id)
-        this.windows.delete(id)
-      }
-    }
-  }
-
-  private evictIfNeeded(newSessionId: string) {
-    if (this.windows.size < this.maxSessions) return
-    let lruId: string | null = null
-    let lruTime = Number.POSITIVE_INFINITY
-    for (const [id, entry] of this.windows.entries()) {
-      if (id === newSessionId) continue
-      if (entry.lastActive < lruTime) {
-        lruTime = entry.lastActive
-        lruId = id
-      }
-    }
-    if (lruId) {
-      const entry = this.windows.get(lruId)
-      if (entry) {
-        this.closeWindow(entry.win, lruId)
-      }
-      this.windows.delete(lruId)
-      logger.info('Evicted session to respect maxSessions', { evicted: lruId })
-    }
-  }
-
-  private async getWindow(sessionId = 'default', forceNew = false, show = false): Promise<BrowserWindow> {
-    await this.ensureAppReady()
-
-    this.sweepIdle()
-
-    const existing = this.windows.get(sessionId)
-    if (existing && !existing.win.isDestroyed() && !forceNew) {
-      this.touch(sessionId)
-      return existing.win
-    }
-
-    if (existing && !existing.win.isDestroyed() && forceNew) {
-      try {
-        if (existing.win.webContents.debugger.isAttached()) {
-          existing.win.webContents.debugger.detach()
-        }
-      } catch (error) {
-        logger.warn('Error detaching debugger before recreate', { error, sessionId })
-      }
-      existing.win.destroy()
-      this.windows.delete(sessionId)
-    }
-
-    this.evictIfNeeded(sessionId)
-
-    const win = new BrowserWindow({
-      show,
-      webPreferences: {
-        contextIsolation: true,
-        sandbox: true,
-        nodeIntegration: false,
-        devTools: true
-      }
-    })
-
-    // Use a standard Chrome UA to avoid some anti-bot blocks
-    win.webContents.setUserAgent(userAgent)
-
-    // Log navigation lifecycle to help diagnose slow loads
-    win.webContents.on('did-start-loading', () => logger.info(`did-start-loading`, { sessionId }))
-    win.webContents.on('dom-ready', () => logger.info(`dom-ready`, { sessionId }))
-    win.webContents.on('did-finish-load', () => logger.info(`did-finish-load`, { sessionId }))
-    win.webContents.on('did-fail-load', (_e, code, desc) => logger.warn('Navigation failed', { code, desc }))
-
-    win.on('closed', () => {
-      this.windows.delete(sessionId)
-    })
-
-    this.windows.set(sessionId, { win, lastActive: Date.now() })
-    return win
-  }
-
-  /**
-   * Opens a URL in a browser window and waits for navigation to complete.
-   * @param url - The URL to navigate to
-   * @param timeout - Navigation timeout in milliseconds (default: 10000)
-   * @param show - Whether to show the browser window (default: false)
-   * @param sessionId - Session identifier for window reuse (default: 'default')
-   * @returns Object containing the current URL and page title after navigation
-   */
-  public async open(url: string, timeout = 10000, show = false, sessionId = 'default') {
-    const win = await this.getWindow(sessionId, true, show)
-    logger.info('Loading URL', { url, sessionId })
-    const { webContents } = win
-    this.touch(sessionId)
-
-    // Track resolution state to prevent multiple handlers from firing
-    let resolved = false
-    let onFinish: () => void
-    let onDomReady: () => void
-    let onFail: (_event: Electron.Event, code: number, desc: string) => void
-
-    // Define cleanup outside Promise to ensure it's callable in finally block,
-    // preventing memory leaks when timeout occurs before navigation completes
-    const cleanup = () => {
-      webContents.removeListener('did-finish-load', onFinish)
-      webContents.removeListener('did-fail-load', onFail)
-      webContents.removeListener('dom-ready', onDomReady)
-    }
-
-    const loadPromise = new Promise<void>((resolve, reject) => {
-      onFinish = () => {
-        if (resolved) return
-        resolved = true
-        cleanup()
-        resolve()
-      }
-      onDomReady = () => {
-        if (resolved) return
-        resolved = true
-        cleanup()
-        resolve()
-      }
-      onFail = (_event: Electron.Event, code: number, desc: string) => {
-        if (resolved) return
-        resolved = true
-        cleanup()
-        reject(new Error(`Navigation failed (${code}): ${desc}`))
-      }
-      webContents.once('did-finish-load', onFinish)
-      webContents.once('dom-ready', onDomReady)
-      webContents.once('did-fail-load', onFail)
-    })
-
-    const timeoutPromise = new Promise<void>((_, reject) => {
-      setTimeout(() => reject(new Error('Navigation timed out')), timeout)
-    })
-
-    try {
-      await Promise.race([win.loadURL(url), loadPromise, timeoutPromise])
-    } finally {
-      // Always cleanup listeners to prevent memory leaks on timeout
-      cleanup()
-    }
-
-    const currentUrl = webContents.getURL()
-    const title = await webContents.getTitle()
-    return { currentUrl, title }
-  }
-
-  public async execute(code: string, timeout = 5000, sessionId = 'default') {
-    const win = await this.getWindow(sessionId)
-    this.touch(sessionId)
-    const dbg = win.webContents.debugger
-
-    await this.ensureDebuggerAttached(dbg, sessionId)
-
-    const evalPromise = dbg.sendCommand('Runtime.evaluate', {
-      expression: code,
-      awaitPromise: true,
-      returnByValue: true
-    })
-
-    const result = await Promise.race([
-      evalPromise,
-      new Promise((_, reject) => setTimeout(() => reject(new Error('Execution timed out')), timeout))
-    ])
-
-    const evalResult = result as any
-
-    if (evalResult?.exceptionDetails) {
-      const message = evalResult.exceptionDetails.exception?.description || 'Unknown script error'
-      logger.warn('Runtime.evaluate raised exception', { message })
-      throw new Error(message)
-    }
-
-    const value = evalResult?.result?.value ?? evalResult?.result?.description ?? null
-    return value
-  }
-
-  public async reset(sessionId?: string) {
-    if (sessionId) {
-      const entry = this.windows.get(sessionId)
-      if (entry) {
-        this.closeWindow(entry.win, sessionId)
-      }
-      this.windows.delete(sessionId)
-      logger.info('Browser CDP context reset', { sessionId })
-      return
-    }
-
-    for (const [id, entry] of this.windows.entries()) {
-      this.closeWindow(entry.win, id)
-      this.windows.delete(id)
-    }
-    logger.info('Browser CDP context reset (all sessions)')
-  }
-
-  /**
-   * Fetches a URL and returns content in the specified format.
-   * @param url - The URL to fetch
-   * @param format - Output format: 'html', 'txt', 'markdown', or 'json' (default: 'markdown')
-   * @param timeout - Navigation timeout in milliseconds (default: 10000)
-   * @param sessionId - Session identifier (default: 'default')
-   * @returns Content in the requested format. For 'json', returns parsed object or { data: rawContent } if parsing fails
-   */
-  public async fetch(
-    url: string,
-    format: 'html' | 'txt' | 'markdown' | 'json' = 'markdown',
-    timeout = 10000,
-    sessionId = 'default'
-  ) {
-    await this.open(url, timeout, false, sessionId)
-
-    const win = await this.getWindow(sessionId)
-    const dbg = win.webContents.debugger
-
-    await this.ensureDebuggerAttached(dbg, sessionId)
-
-    let expression: string
-    if (format === 'json' || format === 'txt') {
-      expression = 'document.body.innerText'
-    } else {
-      expression = 'document.documentElement.outerHTML'
-    }
-
-    const result = (await dbg.sendCommand('Runtime.evaluate', {
-      expression,
-      returnByValue: true
-    })) as { result?: { value?: string } }
-
-    const content = result?.result?.value ?? ''
-
-    if (format === 'markdown') {
-      const turndownService = new TurndownService()
-      return turndownService.turndown(content)
-    }
-    if (format === 'json') {
-      // Attempt to parse as JSON; if content is not valid JSON, wrap it in a data object
-      try {
-        return JSON.parse(content)
-      } catch {
-        return { data: content }
-      }
-    }
-    return content
-  }
-}
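The deleted controller's public surface is just open/execute/fetch/reset, so a usage sketch is short; the URLs and session name below are illustrative only:

import { CdpBrowserController } from './controller'

async function demo() {
  // One hidden window per sessionId, reused until idle timeout or LRU eviction.
  const browser = new CdpBrowserController({ maxSessions: 3 })

  const { currentUrl, title } = await browser.open('https://example.com/', 10000, false, 'docs')
  const markdown = await browser.fetch('https://example.com/', 'markdown', 10000, 'docs')
  const linkCount = await browser.execute('document.querySelectorAll("a").length', 5000, 'docs')
  console.log(currentUrl, title, linkCount, markdown.length)

  await browser.reset() // close every session and detach debuggers
}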

@@ -1,3 +0,0 @@
-export { CdpBrowserController } from './controller'
-export { BrowserServer } from './server'
-export { BrowserServer as default } from './server'

@@ -1,50 +0,0 @@
-import type { Server } from '@modelcontextprotocol/sdk/server/index.js'
-import { Server as MCServer } from '@modelcontextprotocol/sdk/server/index.js'
-import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
-import { app } from 'electron'
-
-import { CdpBrowserController } from './controller'
-import { toolDefinitions, toolHandlers } from './tools'
-
-export class BrowserServer {
-  public server: Server
-  private controller = new CdpBrowserController()
-
-  constructor() {
-    const server = new MCServer(
-      {
-        name: '@cherry/browser',
-        version: '0.1.0'
-      },
-      {
-        capabilities: {
-          resources: {},
-          tools: {}
-        }
-      }
-    )
-
-    server.setRequestHandler(ListToolsRequestSchema, async () => {
-      return {
-        tools: toolDefinitions
-      }
-    })
-
-    server.setRequestHandler(CallToolRequestSchema, async (request) => {
-      const { name, arguments: args } = request.params
-      const handler = toolHandlers[name]
-      if (!handler) {
-        throw new Error('Tool not found')
-      }
-      return handler(this.controller, args)
-    })
-
-    app.on('before-quit', () => {
-      void this.controller.reset()
-    })
-
-    this.server = server
-  }
-}
-
-export default BrowserServer
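Since BrowserServer exposes a plain MCP Server instance, it can be exercised end to end without real transport plumbing. The sketch below assumes the MCP SDK's InMemoryTransport and Client helpers behave as documented; the client name and tool arguments are illustrative.

import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import { InMemoryTransport } from '@modelcontextprotocol/sdk/inMemory.js'

import { BrowserServer } from './server'

async function demo() {
  // Linked pair: whatever the server writes, the client reads, and vice versa.
  const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
  await new BrowserServer().server.connect(serverTransport)

  const client = new Client({ name: 'example-client', version: '0.0.1' }, { capabilities: {} })
  await client.connect(clientTransport)

  const { tools } = await client.listTools() // open, execute, reset, fetch
  const page = await client.callTool({ name: 'fetch', arguments: { url: 'https://example.com/', format: 'txt' } })
  console.log(tools.map((t) => t.name), page)
}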

@@ -1,48 +0,0 @@
-import * as z from 'zod'
-
-import type { CdpBrowserController } from '../controller'
-import { errorResponse, successResponse } from './utils'
-
-export const ExecuteSchema = z.object({
-  code: z
-    .string()
-    .describe(
-      'JavaScript evaluated via Chrome DevTools Runtime.evaluate. Keep it short; prefer one-line with semicolons for multiple statements.'
-    ),
-  timeout: z.number().default(5000).describe('Timeout in milliseconds for code execution (default: 5000ms)'),
-  sessionId: z.string().optional().describe('Session identifier to target a specific page (default: default)')
-})
-
-export const executeToolDefinition = {
-  name: 'execute',
-  description:
-    'Run JavaScript in the current page via Runtime.evaluate. Prefer short, single-line snippets; use semicolons for multiple statements.',
-  inputSchema: {
-    type: 'object',
-    properties: {
-      code: {
-        type: 'string',
-        description: 'One-line JS to evaluate in page context'
-      },
-      timeout: {
-        type: 'number',
-        description: 'Timeout in milliseconds (default 5000)'
-      },
-      sessionId: {
-        type: 'string',
-        description: 'Session identifier; targets a specific page (default: default)'
-      }
-    },
-    required: ['code']
-  }
-}
-
-export async function handleExecute(controller: CdpBrowserController, args: unknown) {
-  const { code, timeout, sessionId } = ExecuteSchema.parse(args)
-  try {
-    const value = await controller.execute(code, timeout, sessionId ?? 'default')
-    return successResponse(typeof value === 'string' ? value : JSON.stringify(value))
-  } catch (error) {
-    return errorResponse(error as Error)
-  }
-}

@@ -1,49 +0,0 @@
-import * as z from 'zod'
-
-import type { CdpBrowserController } from '../controller'
-import { errorResponse, successResponse } from './utils'
-
-export const FetchSchema = z.object({
-  url: z.url().describe('URL to fetch'),
-  format: z.enum(['html', 'txt', 'markdown', 'json']).default('markdown').describe('Output format (default: markdown)'),
-  timeout: z.number().optional().describe('Timeout in milliseconds for navigation (default: 10000)'),
-  sessionId: z.string().optional().describe('Session identifier (default: default)')
-})
-
-export const fetchToolDefinition = {
-  name: 'fetch',
-  description: 'Fetch a URL using the browser and return content in specified format (html, txt, markdown, json)',
-  inputSchema: {
-    type: 'object',
-    properties: {
-      url: {
-        type: 'string',
-        description: 'URL to fetch'
-      },
-      format: {
-        type: 'string',
-        enum: ['html', 'txt', 'markdown', 'json'],
-        description: 'Output format (default: markdown)'
-      },
-      timeout: {
-        type: 'number',
-        description: 'Navigation timeout in milliseconds (default: 10000)'
-      },
-      sessionId: {
-        type: 'string',
-        description: 'Session identifier (default: default)'
-      }
-    },
-    required: ['url']
-  }
-}
-
-export async function handleFetch(controller: CdpBrowserController, args: unknown) {
-  const { url, format, timeout, sessionId } = FetchSchema.parse(args)
-  try {
-    const content = await controller.fetch(url, format, timeout ?? 10000, sessionId ?? 'default')
-    return successResponse(typeof content === 'string' ? content : JSON.stringify(content))
-  } catch (error) {
-    return errorResponse(error as Error)
-  }
-}

@@ -1,25 +0,0 @@
-export { ExecuteSchema, executeToolDefinition, handleExecute } from './execute'
-export { FetchSchema, fetchToolDefinition, handleFetch } from './fetch'
-export { handleOpen, OpenSchema, openToolDefinition } from './open'
-export { handleReset, resetToolDefinition } from './reset'
-
-import type { CdpBrowserController } from '../controller'
-import { executeToolDefinition, handleExecute } from './execute'
-import { fetchToolDefinition, handleFetch } from './fetch'
-import { handleOpen, openToolDefinition } from './open'
-import { handleReset, resetToolDefinition } from './reset'
-
-export const toolDefinitions = [openToolDefinition, executeToolDefinition, resetToolDefinition, fetchToolDefinition]
-
-export const toolHandlers: Record<
-  string,
-  (
-    controller: CdpBrowserController,
-    args: unknown
-  ) => Promise<{ content: { type: string; text: string }[]; isError: boolean }>
-> = {
-  open: handleOpen,
-  execute: handleExecute,
-  reset: handleReset,
-  fetch: handleFetch
-}
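This registry is exactly what server.ts dispatches through on CallToolRequest: look the handler up by tool name and pass the shared controller plus the raw arguments. A small sketch of that lookup; the tool name and arguments are illustrative:

import { CdpBrowserController } from '../controller'
import { toolHandlers } from './index'

async function demo() {
  const controller = new CdpBrowserController()

  // Handlers receive the raw, unvalidated arguments and zod-parse them internally.
  const response = await toolHandlers['open'](controller, { url: 'https://example.com/', timeout: 10000 })
  if (response.isError) {
    throw new Error(response.content[0]?.text)
  }
  console.log(response.content[0]?.text) // JSON string: { currentUrl, title }
}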

@@ -1,47 +0,0 @@
-import * as z from 'zod'
-
-import type { CdpBrowserController } from '../controller'
-import { successResponse } from './utils'
-
-export const OpenSchema = z.object({
-  url: z.url().describe('URL to open in the controlled Electron window'),
-  timeout: z.number().optional().describe('Timeout in milliseconds for navigation (default: 10000)'),
-  show: z.boolean().optional().describe('Whether to show the browser window (default: false)'),
-  sessionId: z
-    .string()
-    .optional()
-    .describe('Session identifier; separate sessions keep separate pages (default: default)')
-})
-
-export const openToolDefinition = {
-  name: 'open',
-  description: 'Open a URL in a hidden Electron window controlled via Chrome DevTools Protocol',
-  inputSchema: {
-    type: 'object',
-    properties: {
-      url: {
-        type: 'string',
-        description: 'URL to load'
-      },
-      timeout: {
-        type: 'number',
-        description: 'Navigation timeout in milliseconds (default 10000)'
-      },
-      show: {
-        type: 'boolean',
-        description: 'Whether to show the browser window (default false)'
-      },
-      sessionId: {
-        type: 'string',
-        description: 'Session identifier; separate sessions keep separate pages (default: default)'
-      }
-    },
-    required: ['url']
-  }
-}
-
-export async function handleOpen(controller: CdpBrowserController, args: unknown) {
-  const { url, timeout, show, sessionId } = OpenSchema.parse(args)
-  const res = await controller.open(url, timeout ?? 10000, show ?? false, sessionId ?? 'default')
-  return successResponse(JSON.stringify(res))
-}

@@ -1,34 +0,0 @@
-import * as z from 'zod'
-
-import type { CdpBrowserController } from '../controller'
-import { successResponse } from './utils'
-
-/** Zod schema for validating reset tool arguments */
-export const ResetSchema = z.object({
-  sessionId: z.string().optional().describe('Session identifier to reset; omit to reset all sessions')
-})
-
-/** MCP tool definition for the reset tool */
-export const resetToolDefinition = {
-  name: 'reset',
-  description: 'Reset the controlled window and detach debugger',
-  inputSchema: {
-    type: 'object',
-    properties: {
-      sessionId: {
-        type: 'string',
-        description: 'Session identifier to reset; omit to reset all sessions'
-      }
-    }
-  }
-}
-
-/**
- * Handler for the reset MCP tool.
- * Closes browser window(s) and detaches debugger for the specified session or all sessions.
- */
-export async function handleReset(controller: CdpBrowserController, args: unknown) {
-  const { sessionId } = ResetSchema.parse(args)
-  await controller.reset(sessionId)
-  return successResponse('reset')
-}

@@ -1,13 +0,0 @@
-export function successResponse(text: string) {
-  return {
-    content: [{ type: 'text', text }],
-    isError: false
-  }
-}
-
-export function errorResponse(error: Error) {
-  return {
-    content: [{ type: 'text', text: error.message }],
-    isError: true
-  }
-}

@@ -1,4 +0,0 @@
-import { loggerService } from '@logger'
-
-export const logger = loggerService.withContext('MCPBrowserCDP')
-export const userAgent = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:145.0) Gecko/20100101 Firefox/145.0'

@@ -4,7 +4,6 @@ import type { BuiltinMCPServerName } from '@types'
 import { BuiltinMCPServerNames } from '@types'
 
 import BraveSearchServer from './brave-search'
-import BrowserServer from './browser'
 import DiDiMcpServer from './didi-mcp'
 import DifyKnowledgeServer from './dify-knowledge'
 import FetchServer from './fetch'
@@ -36,7 +35,7 @@ export function createInMemoryMCPServer(
       return new FetchServer().server
     }
     case BuiltinMCPServerNames.filesystem: {
-      return new FileSystemServer(envs.WORKSPACE_ROOT).server
+      return new FileSystemServer(args).server
    }
     case BuiltinMCPServerNames.difyKnowledge: {
       const difyKey = envs.DIFY_KEY
@@ -49,9 +48,6 @@ export function createInMemoryMCPServer(
       const apiKey = envs.DIDI_API_KEY
      return new DiDiMcpServer(apiKey).server
    }
-    case BuiltinMCPServerNames.browser: {
-      return new BrowserServer().server
-    }
     default:
       throw new Error(`Unknown in-memory MCP server: ${name}`)
  }
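The practical effect of this factory change: in v1.7.3 the filesystem server takes its allowed directories from the args array, where the main branch reads envs.WORKSPACE_ROOT, and the browser server case disappears entirely. A sketch of the v1.7.3 call; the parameter order is inferred from the visible cases, so treat it as an assumption:

import { BuiltinMCPServerNames } from '@types'

import { createInMemoryMCPServer } from './factory'

// v1.7.3: allowed directories arrive via args; envs stays empty here (assumed signature).
const server = createInMemoryMCPServer(BuiltinMCPServerNames.filesystem, ['/Users/me/workspace'], {})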
|
|||||||
652
src/main/mcpServers/filesystem.ts
Normal file
652
src/main/mcpServers/filesystem.ts
Normal file
@ -0,0 +1,652 @@
|
|||||||
|
// port https://github.com/modelcontextprotocol/servers/blob/main/src/filesystem/index.ts
|
||||||
|
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||||
|
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
|
||||||
|
import { createTwoFilesPatch } from 'diff'
|
||||||
|
import fs from 'fs/promises'
|
||||||
|
import { minimatch } from 'minimatch'
|
||||||
|
import os from 'os'
|
||||||
|
import path from 'path'
|
||||||
|
import * as z from 'zod'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('MCP:FileSystemServer')
|
||||||
|
|
||||||
|
// Normalize all paths consistently
|
||||||
|
function normalizePath(p: string): string {
|
||||||
|
return path.normalize(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
function expandHome(filepath: string): string {
|
||||||
|
if (filepath.startsWith('~/') || filepath === '~') {
|
||||||
|
return path.join(os.homedir(), filepath.slice(1))
|
||||||
|
}
|
||||||
|
return filepath
|
||||||
|
}
|
||||||
|
|
||||||
|
// Security utilities
|
||||||
|
async function validatePath(allowedDirectories: string[], requestedPath: string): Promise<string> {
|
||||||
|
const expandedPath = expandHome(requestedPath)
|
||||||
|
const absolute = path.isAbsolute(expandedPath)
|
||||||
|
? path.resolve(expandedPath)
|
||||||
|
: path.resolve(process.cwd(), expandedPath)
|
||||||
|
|
||||||
|
const normalizedRequested = normalizePath(absolute)
|
||||||
|
|
||||||
|
// Check if path is within allowed directories
|
||||||
|
const isAllowed = allowedDirectories.some((dir) => normalizedRequested.startsWith(dir))
|
||||||
|
if (!isAllowed) {
|
||||||
|
throw new Error(
|
||||||
|
`Access denied - path outside allowed directories: ${absolute} not in ${allowedDirectories.join(', ')}`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle symlinks by checking their real path
|
||||||
|
try {
|
||||||
|
const realPath = await fs.realpath(absolute)
|
||||||
|
const normalizedReal = normalizePath(realPath)
|
||||||
|
const isRealPathAllowed = allowedDirectories.some((dir) => normalizedReal.startsWith(dir))
|
||||||
|
if (!isRealPathAllowed) {
|
||||||
|
throw new Error('Access denied - symlink target outside allowed directories')
|
||||||
|
}
|
||||||
|
return realPath
|
||||||
|
} catch (error) {
|
||||||
|
// For new files that don't exist yet, verify parent directory
|
||||||
|
const parentDir = path.dirname(absolute)
|
||||||
|
try {
|
||||||
|
const realParentPath = await fs.realpath(parentDir)
|
||||||
|
const normalizedParent = normalizePath(realParentPath)
|
||||||
|
const isParentAllowed = allowedDirectories.some((dir) => normalizedParent.startsWith(dir))
|
||||||
|
if (!isParentAllowed) {
|
||||||
|
throw new Error('Access denied - parent directory outside allowed directories')
|
||||||
|
}
|
||||||
|
return absolute
|
||||||
|
} catch {
|
||||||
|
throw new Error(`Parent directory does not exist: ${parentDir}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schema definitions
|
||||||
|
const ReadFileArgsSchema = z.object({
|
||||||
|
path: z.string()
|
||||||
|
})
|
||||||
|
|
||||||
|
const ReadMultipleFilesArgsSchema = z.object({
|
||||||
|
paths: z.array(z.string())
|
||||||
|
})
|
||||||
|
|
||||||
|
const WriteFileArgsSchema = z.object({
|
||||||
|
path: z.string(),
|
||||||
|
content: z.string()
|
||||||
|
})
|
||||||
|
|
||||||
|
const EditOperation = z.object({
|
||||||
|
oldText: z.string().describe('Text to search for - must match exactly'),
|
||||||
|
newText: z.string().describe('Text to replace with')
|
||||||
|
})
|
||||||
|
|
||||||
|
const EditFileArgsSchema = z.object({
|
||||||
|
path: z.string(),
|
||||||
|
edits: z.array(EditOperation),
|
||||||
|
dryRun: z.boolean().default(false).describe('Preview changes using git-style diff format')
|
||||||
|
})
|
||||||
|
|
||||||
|
const CreateDirectoryArgsSchema = z.object({
|
||||||
|
path: z.string()
|
||||||
|
})
|
||||||
|
|
||||||
|
const ListDirectoryArgsSchema = z.object({
|
||||||
|
path: z.string()
|
||||||
|
})
|
||||||
|
|
||||||
|
const DirectoryTreeArgsSchema = z.object({
|
||||||
|
path: z.string()
|
||||||
|
})
|
||||||
|
|
||||||
|
const MoveFileArgsSchema = z.object({
|
||||||
|
source: z.string(),
|
||||||
|
destination: z.string()
|
||||||
|
})
|
||||||
|
|
||||||
|
const SearchFilesArgsSchema = z.object({
|
||||||
|
path: z.string(),
|
||||||
|
pattern: z.string(),
|
||||||
|
excludePatterns: z.array(z.string()).optional().default([])
|
||||||
|
})
|
||||||
|
|
||||||
|
const GetFileInfoArgsSchema = z.object({
|
||||||
|
path: z.string()
|
||||||
|
})
|
||||||
|
|
||||||
|
interface FileInfo {
|
||||||
|
size: number
|
||||||
|
created: Date
|
||||||
|
modified: Date
|
||||||
|
accessed: Date
|
||||||
|
isDirectory: boolean
|
||||||
|
isFile: boolean
|
||||||
|
permissions: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool implementations
|
||||||
|
async function getFileStats(filePath: string): Promise<FileInfo> {
|
||||||
|
const stats = await fs.stat(filePath)
|
||||||
|
return {
|
||||||
|
size: stats.size,
|
||||||
|
created: stats.birthtime,
|
||||||
|
modified: stats.mtime,
|
||||||
|
accessed: stats.atime,
|
||||||
|
isDirectory: stats.isDirectory(),
|
||||||
|
isFile: stats.isFile(),
|
||||||
|
permissions: stats.mode.toString(8).slice(-3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function searchFiles(
|
||||||
|
allowedDirectories: string[],
|
||||||
|
rootPath: string,
|
||||||
|
pattern: string,
|
||||||
|
excludePatterns: string[] = []
|
||||||
|
): Promise<string[]> {
|
||||||
|
const results: string[] = []
|
||||||
|
|
||||||
|
async function search(currentPath: string) {
|
||||||
|
const entries = await fs.readdir(currentPath, { withFileTypes: true })
|
||||||
|
|
||||||
|
for (const entry of entries) {
|
||||||
|
const fullPath = path.join(currentPath, entry.name)
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Validate each path before processing
|
||||||
|
await validatePath(allowedDirectories, fullPath)
|
||||||
|
|
||||||
|
// Check if path matches any exclude pattern
|
||||||
|
const relativePath = path.relative(rootPath, fullPath)
|
||||||
|
const shouldExclude = excludePatterns.some((pattern) => {
|
||||||
|
const globPattern = pattern.includes('*') ? pattern : `**/${pattern}/**`
|
||||||
|
return minimatch(relativePath, globPattern, { dot: true })
|
||||||
|
})
|
||||||
|
|
||||||
|
if (shouldExclude) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (entry.name.toLowerCase().includes(pattern.toLowerCase())) {
|
||||||
|
results.push(fullPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (entry.isDirectory()) {
|
||||||
|
await search(fullPath)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// Skip invalid paths during search
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await search(rootPath)
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// file editing and diffing utilities
|
||||||
|
function normalizeLineEndings(text: string): string {
|
||||||
|
return text.replace(/\r\n/g, '\n')
|
||||||
|
}
|
||||||
|
|
||||||
|
function createUnifiedDiff(originalContent: string, newContent: string, filepath: string = 'file'): string {
|
||||||
|
// Ensure consistent line endings for diff
|
||||||
|
const normalizedOriginal = normalizeLineEndings(originalContent)
|
||||||
|
const normalizedNew = normalizeLineEndings(newContent)
|
||||||
|
|
||||||
|
return createTwoFilesPatch(filepath, filepath, normalizedOriginal, normalizedNew, 'original', 'modified')
|
||||||
|
}

async function applyFileEdits(
  filePath: string,
  edits: Array<{ oldText: string; newText: string }>,
  dryRun = false
): Promise<string> {
  // Read file content and normalize line endings
  const content = normalizeLineEndings(await fs.readFile(filePath, 'utf-8'))

  // Apply edits sequentially
  let modifiedContent = content
  for (const edit of edits) {
    const normalizedOld = normalizeLineEndings(edit.oldText)
    const normalizedNew = normalizeLineEndings(edit.newText)

    // If exact match exists, use it
    if (modifiedContent.includes(normalizedOld)) {
      modifiedContent = modifiedContent.replace(normalizedOld, normalizedNew)
      continue
    }

    // Otherwise, try line-by-line matching with flexibility for whitespace
    const oldLines = normalizedOld.split('\n')
    const contentLines = modifiedContent.split('\n')
    let matchFound = false

    for (let i = 0; i <= contentLines.length - oldLines.length; i++) {
      const potentialMatch = contentLines.slice(i, i + oldLines.length)

      // Compare lines with normalized whitespace
      const isMatch = oldLines.every((oldLine, j) => {
        const contentLine = potentialMatch[j]
        return oldLine.trim() === contentLine.trim()
      })

      if (isMatch) {
        // Preserve original indentation of first line
        const originalIndent = contentLines[i].match(/^\s*/)?.[0] || ''
        const newLines = normalizedNew.split('\n').map((line, j) => {
          if (j === 0) return originalIndent + line.trimStart()
          // For subsequent lines, try to preserve relative indentation
          const oldIndent = oldLines[j]?.match(/^\s*/)?.[0] || ''
          const newIndent = line.match(/^\s*/)?.[0] || ''
          if (oldIndent && newIndent) {
            const relativeIndent = newIndent.length - oldIndent.length
            return originalIndent + ' '.repeat(Math.max(0, relativeIndent)) + line.trimStart()
          }
          return line
        })

        contentLines.splice(i, oldLines.length, ...newLines)
        modifiedContent = contentLines.join('\n')
        matchFound = true
        break
      }
    }

    if (!matchFound) {
      throw new Error(`Could not find exact match for edit:\n${edit.oldText}`)
    }
  }

  // Create unified diff
  const diff = createUnifiedDiff(content, modifiedContent, filePath)

  // Format diff with appropriate number of backticks
  let numBackticks = 3
  while (diff.includes('`'.repeat(numBackticks))) {
    numBackticks++
  }
  const formattedDiff = `${'`'.repeat(numBackticks)}diff\n${diff}${'`'.repeat(numBackticks)}\n\n`

  if (!dryRun) {
    await fs.writeFile(filePath, modifiedContent, 'utf-8')
  }

  return formattedDiff
}
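
// Usage sketch (hypothetical file and edit): with dryRun=true the file on disk is left
// untouched and only the backtick-fenced diff string is returned for preview.
// const preview = await applyFileEdits('/tmp/example.ts', [{ oldText: 'const x = 1', newText: 'const x = 2' }], true)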

class FileSystemServer {
  public server: Server
  private allowedDirectories: string[]
  constructor(allowedDirs: string[]) {
    if (!Array.isArray(allowedDirs) || allowedDirs.length === 0) {
      throw new Error('No allowed directories provided, please specify at least one directory in args')
    }

    this.allowedDirectories = allowedDirs.map((dir) => normalizePath(path.resolve(expandHome(dir))))

    // Validate that all directories exist and are accessible
    this.validateDirs().catch((error) => {
      logger.error('Error validating allowed directories:', error)
      throw new Error(`Error validating allowed directories: ${error}`)
    })

    this.server = new Server(
      {
        name: 'secure-filesystem-server',
        version: '0.2.0'
      },
      {
        capabilities: {
          tools: {}
        }
      }
    )
    this.initialize()
  }

  async validateDirs() {
    // Validate that all directories exist and are accessible
    await Promise.all(
      this.allowedDirectories.map(async (dir) => {
        try {
          const stats = await fs.stat(expandHome(dir))
          if (!stats.isDirectory()) {
            logger.error(`Error: ${dir} is not a directory`)
            throw new Error(`Error: ${dir} is not a directory`)
          }
        } catch (error: any) {
          logger.error(`Error accessing directory ${dir}:`, error)
          throw new Error(`Error accessing directory ${dir}: ${error}`)
        }
      })
    )
  }

  initialize() {
    // Tool handlers
    this.server.setRequestHandler(ListToolsRequestSchema, async () => {
      return {
        tools: [
          {
            name: 'read_file',
            description:
              'Read the complete contents of a file from the file system. ' +
              'Handles various text encodings and provides detailed error messages ' +
              'if the file cannot be read. Use this tool when you need to examine ' +
              'the contents of a single file. Only works within allowed directories.',
            inputSchema: z.toJSONSchema(ReadFileArgsSchema)
          },
          {
            name: 'read_multiple_files',
            description:
              'Read the contents of multiple files simultaneously. This is more ' +
              'efficient than reading files one by one when you need to analyze ' +
              "or compare multiple files. Each file's content is returned with its " +
              "path as a reference. Failed reads for individual files won't stop " +
              'the entire operation. Only works within allowed directories.',
            inputSchema: z.toJSONSchema(ReadMultipleFilesArgsSchema)
          },
          {
            name: 'write_file',
            description:
              'Create a new file or completely overwrite an existing file with new content. ' +
              'Use with caution as it will overwrite existing files without warning. ' +
              'Handles text content with proper encoding. Only works within allowed directories.',
            inputSchema: z.toJSONSchema(WriteFileArgsSchema)
          },
          {
            name: 'edit_file',
            description:
              'Make line-based edits to a text file. Each edit replaces exact line sequences ' +
              'with new content. Returns a git-style diff showing the changes made. ' +
              'Only works within allowed directories.',
            inputSchema: z.toJSONSchema(EditFileArgsSchema)
          },
          {
            name: 'create_directory',
            description:
              'Create a new directory or ensure a directory exists. Can create multiple ' +
              'nested directories in one operation. If the directory already exists, ' +
              'this operation will succeed silently. Perfect for setting up directory ' +
              'structures for projects or ensuring required paths exist. Only works within allowed directories.',
            inputSchema: z.toJSONSchema(CreateDirectoryArgsSchema)
          },
          {
            name: 'list_directory',
            description:
              'Get a detailed listing of all files and directories in a specified path. ' +
              'Results clearly distinguish between files and directories with [FILE] and [DIR] ' +
              'prefixes. This tool is essential for understanding directory structure and ' +
              'finding specific files within a directory. Only works within allowed directories.',
            inputSchema: z.toJSONSchema(ListDirectoryArgsSchema)
          },
          {
            name: 'directory_tree',
            description:
              'Get a recursive tree view of files and directories as a JSON structure. ' +
              "Each entry includes 'name', 'type' (file/directory), and 'children' for directories. " +
              'Files have no children array, while directories always have a children array (which may be empty). ' +
              'The output is formatted with 2-space indentation for readability. Only works within allowed directories.',
            inputSchema: z.toJSONSchema(DirectoryTreeArgsSchema)
          },
          {
            name: 'move_file',
            description:
              'Move or rename files and directories. Can move files between directories ' +
              'and rename them in a single operation. If the destination exists, the ' +
              'operation will fail. Works across different directories and can be used ' +
              'for simple renaming within the same directory. Both source and destination must be within allowed directories.',
            inputSchema: z.toJSONSchema(MoveFileArgsSchema)
          },
          {
            name: 'search_files',
            description:
              'Recursively search for files and directories matching a pattern. ' +
              'Searches through all subdirectories from the starting path. The search ' +
              'is case-insensitive and matches partial names. Returns full paths to all ' +
              "matching items. Great for finding files when you don't know their exact location. " +
              'Only searches within allowed directories.',
            inputSchema: z.toJSONSchema(SearchFilesArgsSchema)
          },
          {
            name: 'get_file_info',
            description:
              'Retrieve detailed metadata about a file or directory. Returns comprehensive ' +
              'information including size, creation time, last modified time, permissions, ' +
              'and type. This tool is perfect for understanding file characteristics ' +
              'without reading the actual content. Only works within allowed directories.',
            inputSchema: z.toJSONSchema(GetFileInfoArgsSchema)
          },
          {
            name: 'list_allowed_directories',
            description:
              'Returns the list of directories that this server is allowed to access. ' +
              'Use this to understand which directories are available before trying to access files.',
            inputSchema: {
              type: 'object',
              properties: {},
              required: []
            }
          }
        ]
      }
    })

    this.server.setRequestHandler(CallToolRequestSchema, async (request) => {
      try {
        const { name, arguments: args } = request.params

        switch (name) {
          case 'read_file': {
            const parsed = ReadFileArgsSchema.safeParse(args)
            if (!parsed.success) {
              throw new Error(`Invalid arguments for read_file: ${parsed.error}`)
            }
            const validPath = await validatePath(this.allowedDirectories, parsed.data.path)
            const content = await fs.readFile(validPath, 'utf-8')
            return {
              content: [{ type: 'text', text: content }]
            }
          }

          case 'read_multiple_files': {
            const parsed = ReadMultipleFilesArgsSchema.safeParse(args)
            if (!parsed.success) {
              throw new Error(`Invalid arguments for read_multiple_files: ${parsed.error}`)
            }
            const results = await Promise.all(
              parsed.data.paths.map(async (filePath: string) => {
                try {
                  const validPath = await validatePath(this.allowedDirectories, filePath)
                  const content = await fs.readFile(validPath, 'utf-8')
                  return `${filePath}:\n${content}\n`
                } catch (error) {
                  const errorMessage = error instanceof Error ? error.message : String(error)
                  return `${filePath}: Error - ${errorMessage}`
                }
              })
            )
            return {
              content: [{ type: 'text', text: results.join('\n---\n') }]
            }
          }

          case 'write_file': {
            const parsed = WriteFileArgsSchema.safeParse(args)
            if (!parsed.success) {
              throw new Error(`Invalid arguments for write_file: ${parsed.error}`)
            }
            const validPath = await validatePath(this.allowedDirectories, parsed.data.path)
            await fs.writeFile(validPath, parsed.data.content, 'utf-8')
            return {
              content: [{ type: 'text', text: `Successfully wrote to ${parsed.data.path}` }]
            }
          }

          case 'edit_file': {
            const parsed = EditFileArgsSchema.safeParse(args)
            if (!parsed.success) {
              throw new Error(`Invalid arguments for edit_file: ${parsed.error}`)
            }
            const validPath = await validatePath(this.allowedDirectories, parsed.data.path)
            const result = await applyFileEdits(validPath, parsed.data.edits, parsed.data.dryRun)
            return {
              content: [{ type: 'text', text: result }]
            }
          }

          case 'create_directory': {
            const parsed = CreateDirectoryArgsSchema.safeParse(args)
            if (!parsed.success) {
              throw new Error(`Invalid arguments for create_directory: ${parsed.error}`)
            }
            const validPath = await validatePath(this.allowedDirectories, parsed.data.path)
            await fs.mkdir(validPath, { recursive: true })
            return {
              content: [{ type: 'text', text: `Successfully created directory ${parsed.data.path}` }]
            }
          }

          case 'list_directory': {
            const parsed = ListDirectoryArgsSchema.safeParse(args)
            if (!parsed.success) {
              throw new Error(`Invalid arguments for list_directory: ${parsed.error}`)
            }
            const validPath = await validatePath(this.allowedDirectories, parsed.data.path)
            const entries = await fs.readdir(validPath, { withFileTypes: true })
            const formatted = entries
              .map((entry) => `${entry.isDirectory() ? '[DIR]' : '[FILE]'} ${entry.name}`)
              .join('\n')
            return {
              content: [{ type: 'text', text: formatted }]
            }
          }

          case 'directory_tree': {
            const parsed = DirectoryTreeArgsSchema.safeParse(args)
            if (!parsed.success) {
              throw new Error(`Invalid arguments for directory_tree: ${parsed.error}`)
            }

            interface TreeEntry {
              name: string
              type: 'file' | 'directory'
              children?: TreeEntry[]
            }

            async function buildTree(allowedDirectories: string[], currentPath: string): Promise<TreeEntry[]> {
              const validPath = await validatePath(allowedDirectories, currentPath)
              const entries = await fs.readdir(validPath, { withFileTypes: true })
              const result: TreeEntry[] = []

              for (const entry of entries) {
                const entryData: TreeEntry = {
                  name: entry.name,
                  type: entry.isDirectory() ? 'directory' : 'file'
                }

                if (entry.isDirectory()) {
                  const subPath = path.join(currentPath, entry.name)
                  entryData.children = await buildTree(allowedDirectories, subPath)
                }

                result.push(entryData)
              }

              return result
            }

            const treeData = await buildTree(this.allowedDirectories, parsed.data.path)
            return {
              content: [
                {
                  type: 'text',
                  text: JSON.stringify(treeData, null, 2)
                }
              ]
            }
          }

          case 'move_file': {
            const parsed = MoveFileArgsSchema.safeParse(args)
            if (!parsed.success) {
              throw new Error(`Invalid arguments for move_file: ${parsed.error}`)
            }
            const validSourcePath = await validatePath(this.allowedDirectories, parsed.data.source)
            const validDestPath = await validatePath(this.allowedDirectories, parsed.data.destination)
            await fs.rename(validSourcePath, validDestPath)
            return {
              content: [
                { type: 'text', text: `Successfully moved ${parsed.data.source} to ${parsed.data.destination}` }
              ]
            }
          }

          case 'search_files': {
            const parsed = SearchFilesArgsSchema.safeParse(args)
            if (!parsed.success) {
              throw new Error(`Invalid arguments for search_files: ${parsed.error}`)
            }
            const validPath = await validatePath(this.allowedDirectories, parsed.data.path)
            const results = await searchFiles(
              this.allowedDirectories,
              validPath,
              parsed.data.pattern,
              parsed.data.excludePatterns
            )
            return {
              content: [{ type: 'text', text: results.length > 0 ? results.join('\n') : 'No matches found' }]
            }
          }

          case 'get_file_info': {
            const parsed = GetFileInfoArgsSchema.safeParse(args)
            if (!parsed.success) {
              throw new Error(`Invalid arguments for get_file_info: ${parsed.error}`)
            }
            const validPath = await validatePath(this.allowedDirectories, parsed.data.path)
            const info = await getFileStats(validPath)
            return {
              content: [
                {
                  type: 'text',
                  text: Object.entries(info)
                    .map(([key, value]) => `${key}: ${value}`)
                    .join('\n')
                }
              ]
            }
          }

          case 'list_allowed_directories': {
            return {
              content: [
                {
                  type: 'text',
                  text: `Allowed directories:\n${this.allowedDirectories.join('\n')}`
                }
              ]
            }
          }

          default:
            throw new Error(`Unknown tool: ${name}`)
        }
      } catch (error) {
        const errorMessage = error instanceof Error ? error.message : String(error)
        return {
          content: [{ type: 'text', text: `Error: ${errorMessage}` }],
          isError: true
        }
      }
    })
  }
}

export default FileSystemServer
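
// Wiring sketch (hypothetical; assumes the MCP SDK's stdio transport):
// import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
// const fsServer = new FileSystemServer(['~/projects'])
// await fsServer.server.connect(new StdioServerTransport())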
@ -1,2 +0,0 @@
// Re-export FileSystemServer to maintain existing import pattern
export { default, FileSystemServer } from './server'
@ -1,118 +0,0 @@
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
import { app } from 'electron'
import fs from 'fs/promises'
import path from 'path'

import {
  deleteToolDefinition,
  editToolDefinition,
  globToolDefinition,
  grepToolDefinition,
  handleDeleteTool,
  handleEditTool,
  handleGlobTool,
  handleGrepTool,
  handleLsTool,
  handleReadTool,
  handleWriteTool,
  lsToolDefinition,
  readToolDefinition,
  writeToolDefinition
} from './tools'
import { logger } from './types'

export class FileSystemServer {
  public server: Server
  private baseDir: string

  constructor(baseDir?: string) {
    if (baseDir && path.isAbsolute(baseDir)) {
      this.baseDir = baseDir
      logger.info(`Using provided baseDir for filesystem MCP: ${baseDir}`)
    } else {
      const userData = app.getPath('userData')
      this.baseDir = path.join(userData, 'Data', 'Workspace')
      logger.info(`Using default workspace for filesystem MCP baseDir: ${this.baseDir}`)
    }

    this.server = new Server(
      {
        name: 'filesystem-server',
        version: '2.0.0'
      },
      {
        capabilities: {
          tools: {}
        }
      }
    )

    this.initialize()
  }

  async initialize() {
    try {
      await fs.mkdir(this.baseDir, { recursive: true })
    } catch (error) {
      logger.error('Failed to create filesystem MCP baseDir', { error, baseDir: this.baseDir })
    }

    // Register tool list handler
    this.server.setRequestHandler(ListToolsRequestSchema, async () => {
      return {
        tools: [
          globToolDefinition,
          lsToolDefinition,
          grepToolDefinition,
          readToolDefinition,
          editToolDefinition,
          writeToolDefinition,
          deleteToolDefinition
        ]
      }
    })

    // Register tool call handler
    this.server.setRequestHandler(CallToolRequestSchema, async (request) => {
      try {
        const { name, arguments: args } = request.params

        switch (name) {
          case 'glob':
            return await handleGlobTool(args, this.baseDir)

          case 'ls':
            return await handleLsTool(args, this.baseDir)

          case 'grep':
            return await handleGrepTool(args, this.baseDir)

          case 'read':
            return await handleReadTool(args, this.baseDir)

          case 'edit':
            return await handleEditTool(args, this.baseDir)

          case 'write':
            return await handleWriteTool(args, this.baseDir)

          case 'delete':
            return await handleDeleteTool(args, this.baseDir)

          default:
            throw new Error(`Unknown tool: ${name}`)
        }
      } catch (error) {
        const errorMessage = error instanceof Error ? error.message : String(error)
        logger.error(`Tool execution error for ${request.params.name}:`, { error })
        return {
          content: [{ type: 'text', text: `Error: ${errorMessage}` }],
          isError: true
        }
      }
    })
  }
}

export default FileSystemServer
@ -1,93 +0,0 @@
import fs from 'fs/promises'
import path from 'path'
import * as z from 'zod'

import { logger, validatePath } from '../types'

// Schema definition
export const DeleteToolSchema = z.object({
  path: z.string().describe('The path to the file or directory to delete'),
  recursive: z.boolean().optional().describe('For directories, whether to delete recursively (default: false)')
})

// Tool definition with detailed description
export const deleteToolDefinition = {
  name: 'delete',
  description: `Deletes a file or directory from the filesystem.

CAUTION: This operation cannot be undone!

- For files: simply provide the path
- For empty directories: provide the path
- For non-empty directories: set recursive=true
- The path must be an absolute path, not a relative path
- Always verify the path before deleting to avoid data loss`,
  inputSchema: z.toJSONSchema(DeleteToolSchema)
}

// Handler implementation
export async function handleDeleteTool(args: unknown, baseDir: string) {
  const parsed = DeleteToolSchema.safeParse(args)
  if (!parsed.success) {
    throw new Error(`Invalid arguments for delete: ${parsed.error}`)
  }

  const targetPath = parsed.data.path
  const validPath = await validatePath(targetPath, baseDir)
  const recursive = parsed.data.recursive || false

  // Check if path exists and get stats
  let stats
  try {
    stats = await fs.stat(validPath)
  } catch (error: any) {
    if (error.code === 'ENOENT') {
      throw new Error(`Path not found: ${targetPath}`)
    }
    throw error
  }

  const isDirectory = stats.isDirectory()
  const relativePath = path.relative(baseDir, validPath)

  // Perform deletion
  try {
    if (isDirectory) {
      if (recursive) {
        // Delete directory recursively
        await fs.rm(validPath, { recursive: true, force: true })
      } else {
        // Try to delete empty directory
        await fs.rmdir(validPath)
      }
    } else {
      // Delete file
      await fs.unlink(validPath)
    }
  } catch (error: any) {
    if (error.code === 'ENOTEMPTY') {
      throw new Error(`Directory not empty: ${targetPath}. Use recursive=true to delete non-empty directories.`)
    }
    throw new Error(`Failed to delete: ${error.message}`)
  }

  // Log the operation
  logger.info('Path deleted', {
    path: validPath,
    type: isDirectory ? 'directory' : 'file',
    recursive: isDirectory ? recursive : undefined
  })

  // Format output
  const itemType = isDirectory ? 'Directory' : 'File'
  const recursiveNote = isDirectory && recursive ? ' (recursive)' : ''

  return {
    content: [
      {
        type: 'text',
        text: `${itemType} deleted${recursiveNote}: ${relativePath}`
      }
    ]
  }
}
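
// Usage sketch (hypothetical paths): deleting a non-empty directory requires recursive=true.
// const result = await handleDeleteTool({ path: '/workspace/tmp/cache', recursive: true }, '/workspace')
// // -> content: [{ type: 'text', text: 'Directory deleted (recursive): tmp/cache' }]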
@ -1,130 +0,0 @@
import fs from 'fs/promises'
import path from 'path'
import * as z from 'zod'

import { logger, replaceWithFuzzyMatch, validatePath } from '../types'

// Schema definition
export const EditToolSchema = z.object({
  file_path: z.string().describe('The path to the file to modify'),
  old_string: z.string().describe('The text to replace'),
  new_string: z.string().describe('The text to replace it with'),
  replace_all: z.boolean().optional().default(false).describe('Replace all occurrences of old_string (default false)')
})

// Tool definition with detailed description
export const editToolDefinition = {
  name: 'edit',
  description: `Performs exact string replacements in files.

- You must use the 'read' tool at least once before editing
- The file_path must be an absolute path, not a relative path
- Preserve exact indentation from read output (after the line number prefix)
- Never include line number prefixes in old_string or new_string
- ALWAYS prefer editing existing files over creating new ones
- The edit will FAIL if old_string is not found in the file
- The edit will FAIL if old_string appears multiple times (provide more context or use replace_all)
- The edit will FAIL if old_string equals new_string
- Use replace_all to rename variables or replace all occurrences`,
  inputSchema: z.toJSONSchema(EditToolSchema)
}

// Handler implementation
export async function handleEditTool(args: unknown, baseDir: string) {
  const parsed = EditToolSchema.safeParse(args)
  if (!parsed.success) {
    throw new Error(`Invalid arguments for edit: ${parsed.error}`)
  }

  const { file_path: filePath, old_string: oldString, new_string: newString, replace_all: replaceAll } = parsed.data

  // Validate path
  const validPath = await validatePath(filePath, baseDir)

  // Check if file exists
  try {
    const stats = await fs.stat(validPath)
    if (!stats.isFile()) {
      throw new Error(`Path is not a file: ${filePath}`)
    }
  } catch (error: any) {
    if (error.code === 'ENOENT') {
      // If old_string is empty, this is a create new file operation
      if (oldString === '') {
        // Create parent directory if needed
        const parentDir = path.dirname(validPath)
        await fs.mkdir(parentDir, { recursive: true })

        // Write the new content
        await fs.writeFile(validPath, newString, 'utf-8')

        logger.info('File created', { path: validPath })

        const relativePath = path.relative(baseDir, validPath)
        return {
          content: [
            {
              type: 'text',
              text: `Created new file: ${relativePath}\nLines: ${newString.split('\n').length}`
            }
          ]
        }
      }
      throw new Error(`File not found: ${filePath}`)
    }
    throw error
  }

  // Read current content
  const content = await fs.readFile(validPath, 'utf-8')

  // Handle special case: old_string is empty (create file with content)
  if (oldString === '') {
    await fs.writeFile(validPath, newString, 'utf-8')

    logger.info('File overwritten', { path: validPath })

    const relativePath = path.relative(baseDir, validPath)
    return {
      content: [
        {
          type: 'text',
          text: `Overwrote file: ${relativePath}\nLines: ${newString.split('\n').length}`
        }
      ]
    }
  }

  // Perform the replacement with fuzzy matching
  const newContent = replaceWithFuzzyMatch(content, oldString, newString, replaceAll)

  // Write the modified content
  await fs.writeFile(validPath, newContent, 'utf-8')

  logger.info('File edited', {
    path: validPath,
    replaceAll
  })

  // Generate a simple diff summary
  const oldLines = content.split('\n').length
  const newLines = newContent.split('\n').length
  const lineDiff = newLines - oldLines

  const relativePath = path.relative(baseDir, validPath)
  let diffSummary = `Edited: ${relativePath}`
  if (lineDiff > 0) {
    diffSummary += `\n+${lineDiff} lines`
  } else if (lineDiff < 0) {
    diffSummary += `\n${lineDiff} lines`
  }

  return {
    content: [
      {
        type: 'text',
        text: diffSummary
      }
    ]
  }
}
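
// Usage sketch (hypothetical file): an empty old_string creates the file; replace_all
// substitutes every occurrence via replaceWithFuzzyMatch.
// await handleEditTool(
//   { file_path: '/workspace/src/app.ts', old_string: 'oldName', new_string: 'newName', replace_all: true },
//   '/workspace'
// )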
@ -1,149 +0,0 @@
import fs from 'fs/promises'
import path from 'path'
import * as z from 'zod'

import type { FileInfo } from '../types'
import { logger, MAX_FILES_LIMIT, runRipgrep, validatePath } from '../types'

// Schema definition
export const GlobToolSchema = z.object({
  pattern: z.string().describe('The glob pattern to match files against'),
  path: z
    .string()
    .optional()
    .describe('The directory to search in (must be absolute path). Defaults to the base directory')
})

// Tool definition with detailed description
export const globToolDefinition = {
  name: 'glob',
  description: `Fast file pattern matching tool that works with any codebase size.

- Supports glob patterns like "**/*.js" or "src/**/*.ts"
- Returns matching absolute file paths sorted by modification time (newest first)
- Use this when you need to find files by name patterns
- Patterns without "/" (e.g., "*.txt") match files at ANY depth in the directory tree
- Patterns with "/" (e.g., "src/*.ts") match relative to the search path
- Pattern syntax: * (any chars), ** (any path), {a,b} (alternatives), ? (single char)
- Results are limited to 100 files
- The path parameter must be an absolute path if specified
- If path is not specified, defaults to the base directory
- IMPORTANT: Omit the path field for the default directory (don't use "undefined" or "null")`,
  inputSchema: z.toJSONSchema(GlobToolSchema)
}

// Handler implementation
export async function handleGlobTool(args: unknown, baseDir: string) {
  const parsed = GlobToolSchema.safeParse(args)
  if (!parsed.success) {
    throw new Error(`Invalid arguments for glob: ${parsed.error}`)
  }

  const searchPath = parsed.data.path || baseDir
  const validPath = await validatePath(searchPath, baseDir)

  // Verify the search directory exists
  try {
    const stats = await fs.stat(validPath)
    if (!stats.isDirectory()) {
      throw new Error(`Path is not a directory: ${validPath}`)
    }
  } catch (error: unknown) {
    if (error && typeof error === 'object' && 'code' in error && error.code === 'ENOENT') {
      throw new Error(`Directory not found: ${validPath}`)
    }
    throw error
  }

  // Validate pattern
  const pattern = parsed.data.pattern.trim()
  if (!pattern) {
    throw new Error('Pattern cannot be empty')
  }

  const files: FileInfo[] = []
  let truncated = false

  // Build ripgrep arguments for file listing using --glob=pattern format
  const rgArgs: string[] = [
    '--files',
    '--follow',
    '--hidden',
    `--glob=${pattern}`,
    '--glob=!.git/*',
    '--glob=!node_modules/*',
    '--glob=!dist/*',
    '--glob=!build/*',
    '--glob=!__pycache__/*',
    validPath
  ]

  // Use ripgrep for file listing
  logger.debug('Running ripgrep with args', { rgArgs })
  const rgResult = await runRipgrep(rgArgs)
  logger.debug('Ripgrep result', {
    ok: rgResult.ok,
    exitCode: rgResult.exitCode,
    stdoutLength: rgResult.stdout.length,
    stdoutPreview: rgResult.stdout.slice(0, 500)
  })

  // Process results if we have stdout content
  // Exit code 2 can indicate partial errors (e.g., permission denied on some dirs) but still have valid results
  if (rgResult.ok && rgResult.stdout.length > 0) {
    const lines = rgResult.stdout.split('\n').filter(Boolean)
    logger.debug('Parsed lines from ripgrep', { lineCount: lines.length, lines })

    for (const line of lines) {
      if (files.length >= MAX_FILES_LIMIT) {
        truncated = true
        break
      }

      const filePath = line.trim()
      if (!filePath) continue

      const absolutePath = path.isAbsolute(filePath) ? filePath : path.resolve(validPath, filePath)

      try {
        const stats = await fs.stat(absolutePath)
        files.push({
          path: absolutePath,
          type: 'file', // ripgrep --files only returns files
          size: stats.size,
          modified: stats.mtime
        })
      } catch (error) {
        logger.debug('Failed to stat file from ripgrep output, skipping', { file: absolutePath, error })
      }
    }
  }

  // Sort by modification time (newest first)
  files.sort((a, b) => {
    const aTime = a.modified ? a.modified.getTime() : 0
    const bTime = b.modified ? b.modified.getTime() : 0
    return bTime - aTime
  })

  // Format output - always use absolute paths
  const output: string[] = []
  if (files.length === 0) {
    output.push(`No files found matching pattern "${parsed.data.pattern}" in ${validPath}`)
  } else {
    output.push(...files.map((f) => f.path))
    if (truncated) {
      output.push('')
      output.push(`(Results truncated to ${MAX_FILES_LIMIT} files. Consider using a more specific pattern.)`)
    }
  }

  return {
    content: [
      {
        type: 'text',
        text: output.join('\n')
      }
    ]
  }
}
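
// Usage sketch (hypothetical paths): a bare pattern matches at any depth, so '*.ts'
// finds TypeScript files anywhere under the search path, newest first.
// const tsFiles = await handleGlobTool({ pattern: '*.ts', path: '/workspace' }, '/workspace')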
@ -1,266 +0,0 @@
import fs from 'fs/promises'
import path from 'path'
import * as z from 'zod'

import type { GrepMatch } from '../types'
import { isBinaryFile, MAX_GREP_MATCHES, MAX_LINE_LENGTH, runRipgrep, validatePath } from '../types'

// Schema definition
export const GrepToolSchema = z.object({
  pattern: z.string().describe('The regex pattern to search for in file contents'),
  path: z
    .string()
    .optional()
    .describe('The directory to search in (must be absolute path). Defaults to the base directory'),
  include: z.string().optional().describe('File pattern to include in the search (e.g. "*.js", "*.{ts,tsx}")')
})

// Tool definition with detailed description
export const grepToolDefinition = {
  name: 'grep',
  description: `Fast content search tool that works with any codebase size.

- Searches file contents using regular expressions
- Supports full regex syntax (e.g., "log.*Error", "function\\s+\\w+")
- Filter files by pattern with include (e.g., "*.js", "*.{ts,tsx}")
- Returns absolute file paths and line numbers with matching content
- Results are limited to 100 matches
- Binary files are automatically skipped
- Common directories (node_modules, .git, dist) are excluded
- The path parameter must be an absolute path if specified
- If path is not specified, defaults to the base directory`,
  inputSchema: z.toJSONSchema(GrepToolSchema)
}

// Handler implementation
export async function handleGrepTool(args: unknown, baseDir: string) {
  const parsed = GrepToolSchema.safeParse(args)
  if (!parsed.success) {
    throw new Error(`Invalid arguments for grep: ${parsed.error}`)
  }

  const data = parsed.data

  if (!data.pattern) {
    throw new Error('Pattern is required for grep')
  }

  const searchPath = data.path || baseDir
  const validPath = await validatePath(searchPath, baseDir)

  const matches: GrepMatch[] = []
  let truncated = false
  let regex: RegExp

  // Build ripgrep arguments
  const rgArgs: string[] = [
    '--no-heading',
    '--line-number',
    '--color',
    'never',
    '--ignore-case',
    '--glob',
    '!.git/**',
    '--glob',
    '!node_modules/**',
    '--glob',
    '!dist/**',
    '--glob',
    '!build/**',
    '--glob',
    '!__pycache__/**'
  ]

  if (data.include) {
    for (const pat of data.include
      .split(',')
      .map((p) => p.trim())
      .filter(Boolean)) {
      rgArgs.push('--glob', pat)
    }
  }

  rgArgs.push(data.pattern)
  rgArgs.push(validPath)

  try {
    regex = new RegExp(data.pattern, 'gi')
  } catch (error) {
    throw new Error(`Invalid regex pattern: ${data.pattern}`)
  }

  async function searchFile(filePath: string): Promise<void> {
    if (matches.length >= MAX_GREP_MATCHES) {
      truncated = true
      return
    }

    try {
      // Skip binary files
      if (await isBinaryFile(filePath)) {
        return
      }

      const content = await fs.readFile(filePath, 'utf-8')
      const lines = content.split('\n')

      lines.forEach((line, index) => {
        if (matches.length >= MAX_GREP_MATCHES) {
          truncated = true
          return
        }

        if (regex.test(line)) {
          // Truncate long lines
          const truncatedLine = line.length > MAX_LINE_LENGTH ? line.substring(0, MAX_LINE_LENGTH) + '...' : line

          matches.push({
            file: filePath,
            line: index + 1,
            content: truncatedLine.trim()
          })
        }
      })
    } catch (error) {
      // Skip files we can't read
    }
  }

  async function searchDirectory(dir: string): Promise<void> {
    if (matches.length >= MAX_GREP_MATCHES) {
      truncated = true
      return
    }

    try {
      const entries = await fs.readdir(dir, { withFileTypes: true })

      for (const entry of entries) {
        if (matches.length >= MAX_GREP_MATCHES) {
          truncated = true
          break
        }

        const fullPath = path.join(dir, entry.name)

        // Skip common ignore patterns
        if (entry.name.startsWith('.') && entry.name !== '.env.example') {
          continue
        }
        if (['node_modules', 'dist', 'build', '__pycache__', '.git'].includes(entry.name)) {
          continue
        }

        if (entry.isFile()) {
          // Check if file matches include pattern
          if (data.include) {
            const includePatterns = data.include.split(',').map((p) => p.trim())
            const fileName = path.basename(fullPath)
            const matchesInclude = includePatterns.some((pattern) => {
              // Simple glob pattern matching
              const regexPattern = pattern
                .replace(/\*/g, '.*')
                .replace(/\?/g, '.')
                .replace(/\{([^}]+)\}/g, (_, group) => `(${group.split(',').join('|')})`)
              return new RegExp(`^${regexPattern}$`).test(fileName)
            })
            if (!matchesInclude) {
              continue
            }
          }

          await searchFile(fullPath)
        } else if (entry.isDirectory()) {
          await searchDirectory(fullPath)
        }
      }
    } catch (error) {
      // Skip directories we can't read
    }
  }

  // Perform the search
  let usedRipgrep = false
  try {
    const rgResult = await runRipgrep(rgArgs)
    if (rgResult.ok && rgResult.exitCode !== null && rgResult.exitCode !== 2) {
      usedRipgrep = true
      const lines = rgResult.stdout.split('\n').filter(Boolean)
      for (const line of lines) {
        if (matches.length >= MAX_GREP_MATCHES) {
          truncated = true
          break
        }

        const firstColon = line.indexOf(':')
        const secondColon = line.indexOf(':', firstColon + 1)
        if (firstColon === -1 || secondColon === -1) continue

        const filePart = line.slice(0, firstColon)
        const linePart = line.slice(firstColon + 1, secondColon)
        const contentPart = line.slice(secondColon + 1)
        const lineNum = Number.parseInt(linePart, 10)
        if (!Number.isFinite(lineNum)) continue

        const absoluteFilePath = path.isAbsolute(filePart) ? filePart : path.resolve(baseDir, filePart)
        const truncatedLine =
          contentPart.length > MAX_LINE_LENGTH ? contentPart.substring(0, MAX_LINE_LENGTH) + '...' : contentPart

        matches.push({
          file: absoluteFilePath,
          line: lineNum,
          content: truncatedLine.trim()
        })
      }
    }
  } catch {
    usedRipgrep = false
  }

  if (!usedRipgrep) {
    const stats = await fs.stat(validPath)
    if (stats.isFile()) {
      await searchFile(validPath)
    } else {
      await searchDirectory(validPath)
    }
  }

  // Format output
  const output: string[] = []

  if (matches.length === 0) {
    output.push('No matches found')
  } else {
    // Group matches by file
    const fileGroups = new Map<string, GrepMatch[]>()
    matches.forEach((match) => {
      if (!fileGroups.has(match.file)) {
        fileGroups.set(match.file, [])
      }
      fileGroups.get(match.file)!.push(match)
    })

    // Format grouped matches - always use absolute paths
    fileGroups.forEach((fileMatches, filePath) => {
      output.push(`\n${filePath}:`)
      fileMatches.forEach((match) => {
        output.push(`  ${match.line}: ${match.content}`)
      })
    })

    if (truncated) {
      output.push('')
      output.push(`(Results truncated to ${MAX_GREP_MATCHES} matches. Consider using a more specific pattern or path.)`)
    }
  }

  return {
    content: [
      {
        type: 'text',
        text: output.join('\n')
      }
    ]
  }
}
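
// Usage sketch (hypothetical values): include takes a comma-separated list of file globs,
// each forwarded to ripgrep as an extra --glob filter.
// const todos = await handleGrepTool({ pattern: 'TODO|FIXME', include: '*.ts,*.tsx' }, '/workspace')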
@ -1,8 +0,0 @@
// Export all tool definitions and handlers
export { deleteToolDefinition, handleDeleteTool } from './delete'
export { editToolDefinition, handleEditTool } from './edit'
export { globToolDefinition, handleGlobTool } from './glob'
export { grepToolDefinition, handleGrepTool } from './grep'
export { handleLsTool, lsToolDefinition } from './ls'
export { handleReadTool, readToolDefinition } from './read'
export { handleWriteTool, writeToolDefinition } from './write'
@ -1,150 +0,0 @@
import fs from 'fs/promises'
import path from 'path'
import * as z from 'zod'

import { MAX_FILES_LIMIT, validatePath } from '../types'

// Schema definition
export const LsToolSchema = z.object({
  path: z.string().optional().describe('The directory to list (must be absolute path). Defaults to the base directory'),
  recursive: z.boolean().optional().describe('Whether to list directories recursively (default: false)')
})

// Tool definition with detailed description
export const lsToolDefinition = {
  name: 'ls',
  description: `Lists files and directories in a specified path.

- Returns a tree-like structure with icons (📁 directories, 📄 files)
- Shows the absolute directory path in the header
- Entries are sorted alphabetically with directories first
- Can list recursively with recursive=true (up to 5 levels deep)
- Common directories (node_modules, dist, .git) are excluded
- Hidden files (starting with .) are excluded except .env.example
- Results are limited to 100 entries
- The path parameter must be an absolute path if specified
- If path is not specified, defaults to the base directory`,
  inputSchema: z.toJSONSchema(LsToolSchema)
}

// Handler implementation
export async function handleLsTool(args: unknown, baseDir: string) {
  const parsed = LsToolSchema.safeParse(args)
  if (!parsed.success) {
    throw new Error(`Invalid arguments for ls: ${parsed.error}`)
  }

  const targetPath = parsed.data.path || baseDir
  const validPath = await validatePath(targetPath, baseDir)
  const recursive = parsed.data.recursive || false

  interface TreeNode {
    name: string
    type: 'file' | 'directory'
    children?: TreeNode[]
  }

  let fileCount = 0
  let truncated = false

  async function buildTree(dirPath: string, depth: number = 0): Promise<TreeNode[]> {
    if (fileCount >= MAX_FILES_LIMIT) {
      truncated = true
      return []
    }

    try {
      const entries = await fs.readdir(dirPath, { withFileTypes: true })
      const nodes: TreeNode[] = []

      // Sort entries: directories first, then files, alphabetically
      entries.sort((a, b) => {
        if (a.isDirectory() && !b.isDirectory()) return -1
        if (!a.isDirectory() && b.isDirectory()) return 1
        return a.name.localeCompare(b.name)
      })

      for (const entry of entries) {
        if (fileCount >= MAX_FILES_LIMIT) {
          truncated = true
          break
        }

        // Skip hidden files and common ignore patterns
        if (entry.name.startsWith('.') && entry.name !== '.env.example') {
          continue
        }
        if (['node_modules', 'dist', 'build', '__pycache__'].includes(entry.name)) {
          continue
        }

        fileCount++
        const node: TreeNode = {
          name: entry.name,
          type: entry.isDirectory() ? 'directory' : 'file'
        }

        if (entry.isDirectory() && recursive && depth < 5) {
          // Limit depth to prevent infinite recursion
          const childPath = path.join(dirPath, entry.name)
          node.children = await buildTree(childPath, depth + 1)
        }

        nodes.push(node)
      }

      return nodes
    } catch (error) {
      return []
    }
  }

  // Build the tree
  const tree = await buildTree(validPath)

  // Format as text output
  function formatTree(nodes: TreeNode[], prefix: string = ''): string[] {
    const lines: string[] = []

    nodes.forEach((node, index) => {
      const isLastNode = index === nodes.length - 1
      const connector = isLastNode ? '└── ' : '├── '
      const icon = node.type === 'directory' ? '📁 ' : '📄 '

      lines.push(prefix + connector + icon + node.name)

      if (node.children && node.children.length > 0) {
        const childPrefix = prefix + (isLastNode ? '    ' : '│   ')
        lines.push(...formatTree(node.children, childPrefix))
      }
    })

    return lines
  }

  // Generate output
  const output: string[] = []
  output.push(`Directory: ${validPath}`)
  output.push('')

  if (tree.length === 0) {
    output.push('(empty directory)')
  } else {
    const treeLines = formatTree(tree, '')
    output.push(...treeLines)

    if (truncated) {
      output.push('')
      output.push(`(Results truncated to ${MAX_FILES_LIMIT} files. Consider listing a more specific directory.)`)
    }
  }

  return {
    content: [
      {
        type: 'text',
        text: output.join('\n')
      }
    ]
  }
}
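
// Usage sketch: omitting path lists the base directory; recursive descends at most 5 levels.
// const listing = await handleLsTool({ recursive: true }, '/workspace')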
@ -1,101 +0,0 @@
import fs from 'fs/promises'
import path from 'path'
import * as z from 'zod'

import { DEFAULT_READ_LIMIT, isBinaryFile, MAX_LINE_LENGTH, validatePath } from '../types'

// Schema definition
export const ReadToolSchema = z.object({
  file_path: z.string().describe('The path to the file to read'),
  offset: z.number().optional().describe('The line number to start reading from (1-based)'),
  limit: z.number().optional().describe('The number of lines to read (defaults to 2000)')
})

// Tool definition with detailed description
export const readToolDefinition = {
  name: 'read',
  description: `Reads a file from the local filesystem.

- Assumes this tool can read all files on the machine
- The file_path parameter must be an absolute path, not a relative path
- By default, reads up to 2000 lines starting from the beginning
- You can optionally specify a line offset and limit for long files
- Any lines longer than 2000 characters will be truncated
- Results are returned with line numbers starting at 1
- Binary files are detected and rejected with an error
- Empty files return a warning`,
  inputSchema: z.toJSONSchema(ReadToolSchema)
}

// Handler implementation
export async function handleReadTool(args: unknown, baseDir: string) {
  const parsed = ReadToolSchema.safeParse(args)
  if (!parsed.success) {
    throw new Error(`Invalid arguments for read: ${parsed.error}`)
  }

  const filePath = parsed.data.file_path
  const validPath = await validatePath(filePath, baseDir)

  // Check if file exists
  try {
    const stats = await fs.stat(validPath)
    if (!stats.isFile()) {
      throw new Error(`Path is not a file: ${filePath}`)
    }
  } catch (error: any) {
    if (error.code === 'ENOENT') {
      throw new Error(`File not found: ${filePath}`)
    }
    throw error
  }

  // Check if file is binary
  if (await isBinaryFile(validPath)) {
    throw new Error(`Cannot read binary file: ${filePath}`)
  }

  // Read file content
  const content = await fs.readFile(validPath, 'utf-8')
  const lines = content.split('\n')

  // Apply offset and limit
  const offset = (parsed.data.offset || 1) - 1 // Convert to 0-based
  const limit = parsed.data.limit || DEFAULT_READ_LIMIT

  if (offset < 0 || offset >= lines.length) {
    throw new Error(`Invalid offset: ${offset + 1}. File has ${lines.length} lines.`)
  }

  const selectedLines = lines.slice(offset, offset + limit)

  // Format output with line numbers and truncate long lines
  const output: string[] = []
  const relativePath = path.relative(baseDir, validPath)

  output.push(`File: ${relativePath}`)
  if (offset > 0 || limit < lines.length) {
    output.push(`Lines ${offset + 1} to ${Math.min(offset + limit, lines.length)} of ${lines.length}`)
  }
  output.push('')

  selectedLines.forEach((line, index) => {
    const lineNumber = offset + index + 1
    const truncatedLine = line.length > MAX_LINE_LENGTH ? line.substring(0, MAX_LINE_LENGTH) + '...' : line
    output.push(`${lineNumber.toString().padStart(6)}\t${truncatedLine}`)
  })

  if (offset + limit < lines.length) {
    output.push('')
    output.push(`(${lines.length - (offset + limit)} more lines not shown)`)
  }

  return {
    content: [
      {
        type: 'text',
        text: output.join('\n')
      }
    ]
  }
}
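
// Usage sketch (hypothetical file): offset is 1-based, so this returns lines 100-149,
// each prefixed with a padded line number.
// const chunk = await handleReadTool({ file_path: '/workspace/src/big.ts', offset: 100, limit: 50 }, '/workspace')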
@ -1,83 +0,0 @@
import fs from 'fs/promises'
import path from 'path'
import * as z from 'zod'

import { logger, validatePath } from '../types'

// Schema definition
export const WriteToolSchema = z.object({
  file_path: z.string().describe('The path to the file to write'),
  content: z.string().describe('The content to write to the file')
})

// Tool definition with detailed description
export const writeToolDefinition = {
  name: 'write',
  description: `Writes a file to the local filesystem.

- This tool will overwrite the existing file if one exists at the path
- You MUST use the read tool first to understand what you're overwriting
- ALWAYS prefer using the 'edit' tool for existing files
- NEVER proactively create documentation files unless explicitly requested
- Parent directories will be created automatically if they don't exist
- The file_path must be an absolute path, not a relative path`,
  inputSchema: z.toJSONSchema(WriteToolSchema)
}

// Handler implementation
export async function handleWriteTool(args: unknown, baseDir: string) {
  const parsed = WriteToolSchema.safeParse(args)
  if (!parsed.success) {
    throw new Error(`Invalid arguments for write: ${parsed.error}`)
  }

  const filePath = parsed.data.file_path
  const validPath = await validatePath(filePath, baseDir)

  // Create parent directory if it doesn't exist
  const parentDir = path.dirname(validPath)
  try {
    await fs.mkdir(parentDir, { recursive: true })
  } catch (error: any) {
    if (error.code !== 'EEXIST') {
      throw new Error(`Failed to create parent directory: ${error.message}`)
    }
  }

  // Check if file exists (for logging)
  let isOverwrite = false
  try {
    await fs.stat(validPath)
    isOverwrite = true
  } catch {
    // File doesn't exist, that's fine
  }

  // Write the file
  try {
    await fs.writeFile(validPath, parsed.data.content, 'utf-8')
  } catch (error: any) {
    throw new Error(`Failed to write file: ${error.message}`)
  }

  // Log the operation
  logger.info('File written', {
    path: validPath,
    overwrite: isOverwrite,
    size: parsed.data.content.length
  })

  // Format output
  const relativePath = path.relative(baseDir, validPath)
  const action = isOverwrite ? 'Updated' : 'Created'
  const lines = parsed.data.content.split('\n').length

  return {
    content: [
      {
        type: 'text',
        text: `${action} file: ${relativePath}\n` + `Size: ${parsed.data.content.length} bytes\n` + `Lines: ${lines}`
      }
    ]
  }
}
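A hedged usage sketch of the handler above (the workspace path and payload are illustrative, not from the repo):

// Usage sketch (illustrative import path and arguments)
import { handleWriteTool } from './write'

async function demo() {
  const result = await handleWriteTool(
    { file_path: '/workspace/notes.txt', content: 'hello\nworld' },
    '/workspace'
  )
  // result.content[0].text reads e.g. "Created file: notes.txt\nSize: 11 bytes\nLines: 2"
}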
@@ -1,627 +0,0 @@
import { loggerService } from '@logger'
import { isMac, isWin } from '@main/constant'
import { spawn } from 'child_process'
import fs from 'fs/promises'
import os from 'os'
import path from 'path'

export const logger = loggerService.withContext('MCP:FileSystemServer')

// Constants
export const MAX_LINE_LENGTH = 2000
export const DEFAULT_READ_LIMIT = 2000
export const MAX_FILES_LIMIT = 100
export const MAX_GREP_MATCHES = 100

// Common types
export interface FileInfo {
  path: string
  type: 'file' | 'directory'
  size?: number
  modified?: Date
}

export interface GrepMatch {
  file: string
  line: number
  content: string
}

// Utility functions for path handling
export function normalizePath(p: string): string {
  return path.normalize(p)
}

export function expandHome(filepath: string): string {
  if (filepath.startsWith('~/') || filepath === '~') {
    return path.join(os.homedir(), filepath.slice(1))
  }
  return filepath
}

// Security validation
export async function validatePath(requestedPath: string, baseDir?: string): Promise<string> {
  const expandedPath = expandHome(requestedPath)
  const root = baseDir ?? process.cwd()
  const absolute = path.isAbsolute(expandedPath) ? path.resolve(expandedPath) : path.resolve(root, expandedPath)

  // Handle symlinks by checking their real path
  try {
    const realPath = await fs.realpath(absolute)
    return normalizePath(realPath)
  } catch (error) {
    // For new files that don't exist yet, verify parent directory
    const parentDir = path.dirname(absolute)
    try {
      const realParentPath = await fs.realpath(parentDir)
      normalizePath(realParentPath)
      return normalizePath(absolute)
    } catch {
      return normalizePath(absolute)
    }
  }
}
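A hedged sketch of the resolution behavior (paths illustrative; actual results depend on the local filesystem and any symlinks):

// Usage sketch (illustrative import path)
import { expandHome, validatePath } from './types'

async function demo() {
  expandHome('~/projects') // -> path.join(os.homedir(), '/projects')
  await validatePath('src/index.ts', '/repo') // -> '/repo/src/index.ts' (real path if it exists)
  await validatePath('/etc/hosts') // absolute input is resolved as-is, ignoring baseDir
}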
// ============================================================================
// Edit Tool Utilities - Fuzzy matching replacers from opencode
// ============================================================================

export type Replacer = (content: string, find: string) => Generator<string, void, unknown>

// Similarity thresholds for block anchor fallback matching
const SINGLE_CANDIDATE_SIMILARITY_THRESHOLD = 0.0
const MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD = 0.3

/**
 * Levenshtein distance algorithm implementation
 */
function levenshtein(a: string, b: string): number {
  if (a === '' || b === '') {
    return Math.max(a.length, b.length)
  }
  const matrix = Array.from({ length: a.length + 1 }, (_, i) =>
    Array.from({ length: b.length + 1 }, (_, j) => (i === 0 ? j : j === 0 ? i : 0))
  )

  for (let i = 1; i <= a.length; i++) {
    for (let j = 1; j <= b.length; j++) {
      const cost = a[i - 1] === b[j - 1] ? 0 : 1
      matrix[i][j] = Math.min(matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + cost)
    }
  }
  return matrix[a.length][b.length]
}
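As a quick sanity check on the recurrence (an illustrative trace, not part of the module):

// levenshtein('kitten', 'sitting') === 3:
//   kitten -> sitten  (substitute 'k' with 's')
//   sitten -> sittin  (substitute 'e' with 'i')
//   sittin -> sitting (insert 'g')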
export const SimpleReplacer: Replacer = function* (_content, find) {
  yield find
}

export const LineTrimmedReplacer: Replacer = function* (content, find) {
  const originalLines = content.split('\n')
  const searchLines = find.split('\n')

  if (searchLines[searchLines.length - 1] === '') {
    searchLines.pop()
  }

  for (let i = 0; i <= originalLines.length - searchLines.length; i++) {
    let matches = true

    for (let j = 0; j < searchLines.length; j++) {
      const originalTrimmed = originalLines[i + j].trim()
      const searchTrimmed = searchLines[j].trim()

      if (originalTrimmed !== searchTrimmed) {
        matches = false
        break
      }
    }

    if (matches) {
      let matchStartIndex = 0
      for (let k = 0; k < i; k++) {
        matchStartIndex += originalLines[k].length + 1
      }

      let matchEndIndex = matchStartIndex
      for (let k = 0; k < searchLines.length; k++) {
        matchEndIndex += originalLines[i + k].length
        if (k < searchLines.length - 1) {
          matchEndIndex += 1
        }
      }

      yield content.substring(matchStartIndex, matchEndIndex)
    }
  }
}
export const BlockAnchorReplacer: Replacer = function* (content, find) {
  const originalLines = content.split('\n')
  const searchLines = find.split('\n')

  if (searchLines.length < 3) {
    return
  }

  if (searchLines[searchLines.length - 1] === '') {
    searchLines.pop()
  }

  const firstLineSearch = searchLines[0].trim()
  const lastLineSearch = searchLines[searchLines.length - 1].trim()
  const searchBlockSize = searchLines.length

  const candidates: Array<{ startLine: number; endLine: number }> = []
  for (let i = 0; i < originalLines.length; i++) {
    if (originalLines[i].trim() !== firstLineSearch) {
      continue
    }

    for (let j = i + 2; j < originalLines.length; j++) {
      if (originalLines[j].trim() === lastLineSearch) {
        candidates.push({ startLine: i, endLine: j })
        break
      }
    }
  }

  if (candidates.length === 0) {
    return
  }

  if (candidates.length === 1) {
    const { startLine, endLine } = candidates[0]
    const actualBlockSize = endLine - startLine + 1

    let similarity = 0
    const linesToCheck = Math.min(searchBlockSize - 2, actualBlockSize - 2)

    if (linesToCheck > 0) {
      for (let j = 1; j < searchBlockSize - 1 && j < actualBlockSize - 1; j++) {
        const originalLine = originalLines[startLine + j].trim()
        const searchLine = searchLines[j].trim()
        const maxLen = Math.max(originalLine.length, searchLine.length)
        if (maxLen === 0) {
          continue
        }
        const distance = levenshtein(originalLine, searchLine)
        similarity += (1 - distance / maxLen) / linesToCheck

        if (similarity >= SINGLE_CANDIDATE_SIMILARITY_THRESHOLD) {
          break
        }
      }
    } else {
      similarity = 1.0
    }

    if (similarity >= SINGLE_CANDIDATE_SIMILARITY_THRESHOLD) {
      let matchStartIndex = 0
      for (let k = 0; k < startLine; k++) {
        matchStartIndex += originalLines[k].length + 1
      }
      let matchEndIndex = matchStartIndex
      for (let k = startLine; k <= endLine; k++) {
        matchEndIndex += originalLines[k].length
        if (k < endLine) {
          matchEndIndex += 1
        }
      }
      yield content.substring(matchStartIndex, matchEndIndex)
    }
    return
  }

  let bestMatch: { startLine: number; endLine: number } | null = null
  let maxSimilarity = -1

  for (const candidate of candidates) {
    const { startLine, endLine } = candidate
    const actualBlockSize = endLine - startLine + 1

    let similarity = 0
    const linesToCheck = Math.min(searchBlockSize - 2, actualBlockSize - 2)

    if (linesToCheck > 0) {
      for (let j = 1; j < searchBlockSize - 1 && j < actualBlockSize - 1; j++) {
        const originalLine = originalLines[startLine + j].trim()
        const searchLine = searchLines[j].trim()
        const maxLen = Math.max(originalLine.length, searchLine.length)
        if (maxLen === 0) {
          continue
        }
        const distance = levenshtein(originalLine, searchLine)
        similarity += 1 - distance / maxLen
      }
      similarity /= linesToCheck
    } else {
      similarity = 1.0
    }

    if (similarity > maxSimilarity) {
      maxSimilarity = similarity
      bestMatch = candidate
    }
  }

  if (maxSimilarity >= MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD && bestMatch) {
    const { startLine, endLine } = bestMatch
    let matchStartIndex = 0
    for (let k = 0; k < startLine; k++) {
      matchStartIndex += originalLines[k].length + 1
    }
    let matchEndIndex = matchStartIndex
    for (let k = startLine; k <= endLine; k++) {
      matchEndIndex += originalLines[k].length
      if (k < endLine) {
        matchEndIndex += 1
      }
    }
    yield content.substring(matchStartIndex, matchEndIndex)
  }
}
export const WhitespaceNormalizedReplacer: Replacer = function* (content, find) {
  const normalizeWhitespace = (text: string) => text.replace(/\s+/g, ' ').trim()
  const normalizedFind = normalizeWhitespace(find)

  const lines = content.split('\n')
  for (let i = 0; i < lines.length; i++) {
    const line = lines[i]
    if (normalizeWhitespace(line) === normalizedFind) {
      yield line
    } else {
      const normalizedLine = normalizeWhitespace(line)
      if (normalizedLine.includes(normalizedFind)) {
        const words = find.trim().split(/\s+/)
        if (words.length > 0) {
          const pattern = words.map((word) => word.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')).join('\\s+')
          try {
            const regex = new RegExp(pattern)
            const match = line.match(regex)
            if (match) {
              yield match[0]
            }
          } catch {
            // Invalid regex pattern, skip
          }
        }
      }
    }
  }

  const findLines = find.split('\n')
  if (findLines.length > 1) {
    for (let i = 0; i <= lines.length - findLines.length; i++) {
      const block = lines.slice(i, i + findLines.length)
      if (normalizeWhitespace(block.join('\n')) === normalizedFind) {
        yield block.join('\n')
      }
    }
  }
}
export const IndentationFlexibleReplacer: Replacer = function* (content, find) {
  const removeIndentation = (text: string) => {
    const lines = text.split('\n')
    const nonEmptyLines = lines.filter((line) => line.trim().length > 0)
    if (nonEmptyLines.length === 0) return text

    const minIndent = Math.min(
      ...nonEmptyLines.map((line) => {
        const match = line.match(/^(\s*)/)
        return match ? match[1].length : 0
      })
    )

    return lines.map((line) => (line.trim().length === 0 ? line : line.slice(minIndent))).join('\n')
  }

  const normalizedFind = removeIndentation(find)
  const contentLines = content.split('\n')
  const findLines = find.split('\n')

  for (let i = 0; i <= contentLines.length - findLines.length; i++) {
    const block = contentLines.slice(i, i + findLines.length).join('\n')
    if (removeIndentation(block) === normalizedFind) {
      yield block
    }
  }
}
export const EscapeNormalizedReplacer: Replacer = function* (content, find) {
  const unescapeString = (str: string): string => {
    return str.replace(/\\(n|t|r|'|"|`|\\|\n|\$)/g, (match, capturedChar) => {
      switch (capturedChar) {
        case 'n':
          return '\n'
        case 't':
          return '\t'
        case 'r':
          return '\r'
        case "'":
          return "'"
        case '"':
          return '"'
        case '`':
          return '`'
        case '\\':
          return '\\'
        case '\n':
          return '\n'
        case '$':
          return '$'
        default:
          return match
      }
    })
  }

  const unescapedFind = unescapeString(find)

  if (content.includes(unescapedFind)) {
    yield unescapedFind
  }

  const lines = content.split('\n')
  const findLines = unescapedFind.split('\n')

  for (let i = 0; i <= lines.length - findLines.length; i++) {
    const block = lines.slice(i, i + findLines.length).join('\n')
    const unescapedBlock = unescapeString(block)

    if (unescapedBlock === unescapedFind) {
      yield block
    }
  }
}
export const TrimmedBoundaryReplacer: Replacer = function* (content, find) {
  const trimmedFind = find.trim()

  if (trimmedFind === find) {
    return
  }

  if (content.includes(trimmedFind)) {
    yield trimmedFind
  }

  const lines = content.split('\n')
  const findLines = find.split('\n')

  for (let i = 0; i <= lines.length - findLines.length; i++) {
    const block = lines.slice(i, i + findLines.length).join('\n')

    if (block.trim() === trimmedFind) {
      yield block
    }
  }
}
export const ContextAwareReplacer: Replacer = function* (content, find) {
  const findLines = find.split('\n')
  if (findLines.length < 3) {
    return
  }

  if (findLines[findLines.length - 1] === '') {
    findLines.pop()
  }

  const contentLines = content.split('\n')

  const firstLine = findLines[0].trim()
  const lastLine = findLines[findLines.length - 1].trim()

  for (let i = 0; i < contentLines.length; i++) {
    if (contentLines[i].trim() !== firstLine) continue

    for (let j = i + 2; j < contentLines.length; j++) {
      if (contentLines[j].trim() === lastLine) {
        const blockLines = contentLines.slice(i, j + 1)
        const block = blockLines.join('\n')

        if (blockLines.length === findLines.length) {
          let matchingLines = 0
          let totalNonEmptyLines = 0

          for (let k = 1; k < blockLines.length - 1; k++) {
            const blockLine = blockLines[k].trim()
            const findLine = findLines[k].trim()

            if (blockLine.length > 0 || findLine.length > 0) {
              totalNonEmptyLines++
              if (blockLine === findLine) {
                matchingLines++
              }
            }
          }

          if (totalNonEmptyLines === 0 || matchingLines / totalNonEmptyLines >= 0.5) {
            yield block
            break
          }
        }
        break
      }
    }
  }
}
export const MultiOccurrenceReplacer: Replacer = function* (content, find) {
  let startIndex = 0

  while (true) {
    const index = content.indexOf(find, startIndex)
    if (index === -1) break

    yield find
    startIndex = index + find.length
  }
}

/**
 * All replacers in order of specificity
 */
export const ALL_REPLACERS: Replacer[] = [
  SimpleReplacer,
  LineTrimmedReplacer,
  BlockAnchorReplacer,
  WhitespaceNormalizedReplacer,
  IndentationFlexibleReplacer,
  EscapeNormalizedReplacer,
  TrimmedBoundaryReplacer,
  ContextAwareReplacer,
  MultiOccurrenceReplacer
]
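Each Replacer is a generator that proposes candidate substrings of `content` considered equivalent to `find`; replaceWithFuzzyMatch below walks this list in order and verifies each proposal with `content.indexOf`. A small illustrative driver:

// Sketch: collecting candidates from one replacer
const content = 'function a() {\n  return 1\n}\n'
const find = 'function a() {\nreturn 1\n}'

for (const candidate of LineTrimmedReplacer(content, find)) {
  // candidate is the exact substring of `content` whose lines match `find` after trimming
  console.log(JSON.stringify(candidate)) // "function a() {\n  return 1\n}"
}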
/**
 * Replace oldString with newString in content using fuzzy matching
 */
export function replaceWithFuzzyMatch(
  content: string,
  oldString: string,
  newString: string,
  replaceAll = false
): string {
  if (oldString === newString) {
    throw new Error('old_string and new_string must be different')
  }

  let notFound = true

  for (const replacer of ALL_REPLACERS) {
    for (const search of replacer(content, oldString)) {
      const index = content.indexOf(search)
      if (index === -1) continue
      notFound = false
      if (replaceAll) {
        return content.replaceAll(search, newString)
      }
      const lastIndex = content.lastIndexOf(search)
      if (index !== lastIndex) continue
      return content.substring(0, index) + newString + content.substring(index + search.length)
    }
  }

  if (notFound) {
    throw new Error('old_string not found in content')
  }
  throw new Error(
    'Found multiple matches for old_string. Provide more surrounding lines in old_string to identify the correct match.'
  )
}
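A usage sketch: whitespace differences in old_string still resolve to a unique match via the whitespace-normalizing replacer (values illustrative):

const source = 'const x = 1\nconst y = 2\n'
const updated = replaceWithFuzzyMatch(source, 'const y =  2', 'const y = 3')
// updated === 'const x = 1\nconst y = 3\n'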
// ============================================================================
// Binary File Detection
// ============================================================================

// Check if a file is likely binary
export async function isBinaryFile(filePath: string): Promise<boolean> {
  try {
    const buffer = Buffer.alloc(4096)
    const fd = await fs.open(filePath, 'r')
    const { bytesRead } = await fd.read(buffer, 0, buffer.length, 0)
    await fd.close()

    if (bytesRead === 0) return false

    const view = buffer.subarray(0, bytesRead)

    let zeroBytes = 0
    let evenZeros = 0
    let oddZeros = 0
    let nonPrintable = 0

    for (let i = 0; i < view.length; i++) {
      const b = view[i]

      if (b === 0) {
        zeroBytes++
        if (i % 2 === 0) evenZeros++
        else oddZeros++
        continue
      }

      // treat common whitespace as printable
      if (b === 9 || b === 10 || b === 13) continue

      // basic ASCII printable range
      if (b >= 32 && b <= 126) continue

      // bytes >= 128 are likely part of UTF-8 sequences; count as printable
      if (b >= 128) continue

      nonPrintable++
    }

    // If there are lots of null bytes, it's probably binary unless it looks like UTF-16 text.
    if (zeroBytes > 0) {
      const evenSlots = Math.ceil(view.length / 2)
      const oddSlots = Math.floor(view.length / 2)
      const evenZeroRatio = evenSlots > 0 ? evenZeros / evenSlots : 0
      const oddZeroRatio = oddSlots > 0 ? oddZeros / oddSlots : 0

      // UTF-16LE/BE tends to have zeros on every other byte.
      if (evenZeroRatio > 0.7 || oddZeroRatio > 0.7) return false

      if (zeroBytes / view.length > 0.05) return true
    }

    // Heuristic: too many non-printable bytes => binary.
    return nonPrintable / view.length > 0.3
  } catch {
    return false
  }
}
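A hedged usage sketch (file paths illustrative; results depend on actual file contents):

async function demo() {
  await isBinaryFile('/repo/src/index.ts') // likely false: mostly printable ASCII/UTF-8
  await isBinaryFile('/repo/assets/logo.png') // likely true: many null and non-printable bytes
  // UTF-16 text is deliberately not flagged: its null bytes cluster on alternating offsets
}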
// ============================================================================
// Ripgrep Utilities
// ============================================================================

export interface RipgrepResult {
  ok: boolean
  stdout: string
  exitCode: number | null
}

export function getRipgrepAddonPath(): string {
  const pkgJsonPath = require.resolve('@anthropic-ai/claude-agent-sdk/package.json')
  const pkgRoot = path.dirname(pkgJsonPath)
  const platform = isMac ? 'darwin' : isWin ? 'win32' : 'linux'
  const arch = process.arch === 'arm64' ? 'arm64' : 'x64'
  return path.join(pkgRoot, 'vendor', 'ripgrep', `${arch}-${platform}`, 'ripgrep.node')
}

export async function runRipgrep(args: string[]): Promise<RipgrepResult> {
  const addonPath = getRipgrepAddonPath()
  const childScript = `const { ripgrepMain } = require(process.env.RIPGREP_ADDON_PATH); process.exit(ripgrepMain(process.argv.slice(1)));`

  return new Promise((resolve) => {
    const child = spawn(process.execPath, ['--eval', childScript, 'rg', ...args], {
      cwd: process.cwd(),
      env: {
        ...process.env,
        ELECTRON_RUN_AS_NODE: '1',
        RIPGREP_ADDON_PATH: addonPath
      },
      stdio: ['ignore', 'pipe', 'pipe']
    })

    let stdout = ''

    child.stdout?.on('data', (chunk) => {
      stdout += chunk.toString('utf-8')
    })

    child.on('error', () => {
      resolve({ ok: false, stdout: '', exitCode: null })
    })

    child.on('close', (code) => {
      resolve({ ok: true, stdout, exitCode: code })
    })
  })
}
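A hedged usage sketch; --line-number and --no-heading are standard ripgrep flags, and ripgrep exits with code 1 when nothing matched (which is not a failure):

async function searchTodos(dir: string): Promise<string[]> {
  const result = await runRipgrep(['--line-number', '--no-heading', 'TODO', dir])
  if (!result.ok) throw new Error('failed to spawn ripgrep')
  if (result.exitCode === 1) return [] // no matches
  return result.stdout.split('\n').filter(Boolean)
}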
@@ -767,56 +767,6 @@ class BackupManager
     const s3Client = this.getS3Storage(s3Config)
     return await s3Client.checkConnection()
   }
-
-  /**
-   * Create a temporary backup for LAN transfer
-   * Creates a lightweight backup (skipBackupFile=true) in the temp directory
-   * Returns the path to the created ZIP file
-   */
-  async createLanTransferBackup(_: Electron.IpcMainInvokeEvent, data: string): Promise<string> {
-    const timestamp = new Date()
-      .toISOString()
-      .replace(/[-:T.Z]/g, '')
-      .slice(0, 12)
-    const fileName = `cherry-studio.${timestamp}.zip`
-    const tempPath = path.join(app.getPath('temp'), 'cherry-studio', 'lan-transfer')
-
-    // Ensure temp directory exists
-    await fs.ensureDir(tempPath)
-
-    // Create backup with skipBackupFile=true (no Data folder)
-    const backupedFilePath = await this.backup(_, fileName, data, tempPath, true)
-
-    logger.info(`[BackupManager] Created LAN transfer backup at: ${backupedFilePath}`)
-    return backupedFilePath
-  }
-
-  /**
-   * Delete a temporary backup file after LAN transfer completes
-   */
-  async deleteTempBackup(_: Electron.IpcMainInvokeEvent, filePath: string): Promise<boolean> {
-    try {
-      // Security check: only allow deletion within temp directory
-      const tempBase = path.normalize(path.join(app.getPath('temp'), 'cherry-studio', 'lan-transfer'))
-      const resolvedPath = path.normalize(path.resolve(filePath))
-
-      // Use normalized paths with trailing separator to prevent prefix attacks (e.g., /temp-evil)
-      if (!resolvedPath.startsWith(tempBase + path.sep) && resolvedPath !== tempBase) {
-        logger.warn(`[BackupManager] Attempted to delete file outside temp directory: ${filePath}`)
-        return false
-      }
-
-      if (await fs.pathExists(resolvedPath)) {
-        await fs.remove(resolvedPath)
-        logger.info(`[BackupManager] Deleted temp backup: ${resolvedPath}`)
-        return true
-      }
-      return false
-    } catch (error) {
-      logger.error('[BackupManager] Failed to delete temp backup:', error as Error)
-      return false
-    }
-  }
 }

 export default BackupManager
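The containment check in the removed deleteTempBackup is the standard defence against path prefix attacks; a standalone sketch of the same idea (function name and paths are illustrative):

import path from 'path'

function isInsideBase(base: string, target: string): boolean {
  const resolved = path.normalize(path.resolve(target))
  // A bare startsWith(base) would also accept siblings like '/tmp/cherry-studio-evil';
  // appending path.sep turns the prefix into a real directory boundary.
  return resolved === base || resolved.startsWith(base + path.sep)
}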
@@ -32,8 +32,7 @@ export enum ConfigKeys
   Proxy = 'proxy',
   EnableDeveloperMode = 'enableDeveloperMode',
   ClientId = 'clientId',
-  GitBashPath = 'gitBashPath',
-  GitBashPathSource = 'gitBashPathSource' // 'manual' | 'auto' | null
+  GitBashPath = 'gitBashPath'
 }

 export class ConfigManager {
@@ -163,7 +163,7 @@ class FileStorage
       fs.mkdirSync(this.storageDir, { recursive: true })
     }
     if (!fs.existsSync(this.notesDir)) {
-      fs.mkdirSync(this.notesDir, { recursive: true })
+      fs.mkdirSync(this.storageDir, { recursive: true })
     }
     if (!fs.existsSync(this.tempDir)) {
       fs.mkdirSync(this.tempDir, { recursive: true })
@@ -1,207 +0,0 @@
import { loggerService } from '@logger'
import type { LocalTransferPeer, LocalTransferState } from '@shared/config/types'
import { IpcChannel } from '@shared/IpcChannel'
import type { Browser, Service } from 'bonjour-service'
import Bonjour from 'bonjour-service'

import { windowService } from './WindowService'

const SERVICE_TYPE = 'cherrystudio'
const SERVICE_PROTOCOL = 'tcp' as const

const logger = loggerService.withContext('LocalTransferService')

type StartDiscoveryOptions = {
  resetList?: boolean
}

class LocalTransferService {
  private static instance: LocalTransferService
  private bonjour: Bonjour | null = null
  private browser: Browser | null = null
  private services = new Map<string, LocalTransferPeer>()
  private isScanning = false
  private lastScanStartedAt?: number
  private lastUpdatedAt = Date.now()
  private lastError?: string

  private constructor() {}

  public static getInstance(): LocalTransferService {
    if (!LocalTransferService.instance) {
      LocalTransferService.instance = new LocalTransferService()
    }
    return LocalTransferService.instance
  }

  public startDiscovery(options?: StartDiscoveryOptions): LocalTransferState {
    if (options?.resetList) {
      this.services.clear()
    }

    this.isScanning = true
    this.lastScanStartedAt = Date.now()
    this.lastUpdatedAt = Date.now()
    this.lastError = undefined
    this.restartBrowser()
    this.broadcastState()
    return this.getState()
  }

  public stopDiscovery(): LocalTransferState {
    if (this.browser) {
      try {
        this.browser.stop()
      } catch (error) {
        logger.warn('Failed to stop local transfer browser', error as Error)
      }
    }
    this.isScanning = false
    this.lastUpdatedAt = Date.now()
    this.broadcastState()
    return this.getState()
  }

  public getState(): LocalTransferState {
    const services = Array.from(this.services.values()).sort((a, b) => a.name.localeCompare(b.name))
    return {
      services,
      isScanning: this.isScanning,
      lastScanStartedAt: this.lastScanStartedAt,
      lastUpdatedAt: this.lastUpdatedAt,
      lastError: this.lastError
    }
  }

  public getPeerById(id: string): LocalTransferPeer | undefined {
    return this.services.get(id)
  }

  public dispose(): void {
    this.stopDiscovery()
    this.services.clear()
    this.browser?.removeAllListeners()
    this.browser = null
    if (this.bonjour) {
      try {
        this.bonjour.destroy()
      } catch (error) {
        logger.warn('Failed to destroy Bonjour instance', error as Error)
      }
      this.bonjour = null
    }
  }

  private getBonjour(): Bonjour {
    if (!this.bonjour) {
      this.bonjour = new Bonjour()
    }
    return this.bonjour
  }

  private restartBrowser(): void {
    // Clean up existing browser
    if (this.browser) {
      this.browser.removeAllListeners()
      try {
        this.browser.stop()
      } catch (error) {
        logger.warn('Error while stopping Bonjour browser', error as Error)
      }
      this.browser = null
    }

    // Destroy and recreate Bonjour instance to prevent socket leaks
    if (this.bonjour) {
      try {
        this.bonjour.destroy()
      } catch (error) {
        logger.warn('Error while destroying Bonjour instance', error as Error)
      }
      this.bonjour = null
    }

    const browser = this.getBonjour().find({ type: SERVICE_TYPE, protocol: SERVICE_PROTOCOL })
    this.browser = browser
    this.bindBrowserEvents(browser)

    try {
      browser.start()
      logger.info('Local transfer discovery started')
    } catch (error) {
      const err = error instanceof Error ? error : new Error(String(error))
      this.lastError = err.message
      logger.error('Failed to start local transfer discovery', err)
    }
  }

  private bindBrowserEvents(browser: Browser) {
    browser.on('up', (service) => {
      const peer = this.normalizeService(service)
      logger.info(`LAN peer detected: ${peer.name} (${peer.addresses.join(', ')})`)
      this.services.set(peer.id, peer)
      this.lastUpdatedAt = Date.now()
      this.broadcastState()
    })

    browser.on('down', (service) => {
      const key = this.buildServiceKey(service.fqdn || service.name, service.host, service.port)
      if (this.services.delete(key)) {
        logger.info(`LAN peer removed: ${service.name}`)
        this.lastUpdatedAt = Date.now()
        this.broadcastState()
      }
    })

    browser.on('error', (error) => {
      const err = error instanceof Error ? error : new Error(String(error))
      logger.error('Local transfer discovery error', err)
      this.lastError = err.message
      this.broadcastState()
    })
  }

  private normalizeService(service: Service): LocalTransferPeer {
    const addressCandidates = [...(service.addresses || []), service.referer?.address].filter(
      (value): value is string => typeof value === 'string' && value.length > 0
    )
    const addresses = Array.from(new Set(addressCandidates))
    const txtEntries = Object.entries(service.txt || {})
    const txt =
      txtEntries.length > 0
        ? Object.fromEntries(
            txtEntries.map(([key, value]) => [key, value === undefined || value === null ? '' : String(value)])
          )
        : undefined

    const peer: LocalTransferPeer = {
      id: this.buildServiceKey(service.fqdn || service.name, service.host, service.port),
      name: service.name,
      host: service.host,
      fqdn: service.fqdn,
      port: service.port,
      type: service.type,
      protocol: service.protocol,
      addresses,
      txt,
      updatedAt: Date.now()
    }

    return peer
  }

  private buildServiceKey(name?: string, host?: string, port?: number): string {
    const raw = [name, host, port?.toString()].filter(Boolean).join('-')
    return raw || `service-${Date.now()}`
  }

  private broadcastState() {
    const mainWindow = windowService.getMainWindow()
    if (!mainWindow || mainWindow.isDestroyed()) {
      return
    }
    mainWindow.webContents.send(IpcChannel.LocalTransfer_ServicesUpdated, this.getState())
  }
}

export const localTransferService = LocalTransferService.getInstance()
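For orientation, a sketch of the renderer side of this broadcast; only the channel name and state shape come from the service above, the preload bridge is an assumption:

// Renderer-side sketch (assumes a preload script exposing ipcRenderer as window.electron.ipcRenderer)
window.electron.ipcRenderer.on(IpcChannel.LocalTransfer_ServicesUpdated, (_event, state: LocalTransferState) => {
  console.log(`scanning=${state.isScanning}, peers=${state.services.map((p) => p.name).join(', ')}`)
})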
@@ -6,7 +6,7 @@ import { loggerService } from '@logger'
 import { createInMemoryMCPServer } from '@main/mcpServers/factory'
 import { makeSureDirExists, removeEnvProxy } from '@main/utils'
 import { buildFunctionCallToolName } from '@main/utils/mcp'
-import { findCommandInShellEnv, getBinaryName, getBinaryPath, isBinaryExists } from '@main/utils/process'
+import { getBinaryName, getBinaryPath } from '@main/utils/process'
 import getLoginShellEnvironment from '@main/utils/shell-env'
 import { TraceMethod, withSpanFunc } from '@mcp-trace/trace-core'
 import { Client } from '@modelcontextprotocol/sdk/client/index.js'
@@ -249,26 +249,6 @@ class McpService
     StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
   > => {
     // Create appropriate transport based on configuration
-
-    // Special case for nowledgeMem - uses HTTP transport instead of in-memory
-    if (isBuiltinMCPServer(server) && server.name === BuiltinMCPServerNames.nowledgeMem) {
-      const nowledgeMemUrl = 'http://127.0.0.1:14242/mcp'
-      const options: StreamableHTTPClientTransportOptions = {
-        fetch: async (url, init) => {
-          return net.fetch(typeof url === 'string' ? url : url.toString(), init)
-        },
-        requestInit: {
-          headers: {
-            ...defaultAppHeaders(),
-            APP: 'Cherry Studio'
-          }
-        },
-        authProvider
-      }
-      getServerLogger(server).debug(`Using StreamableHTTPClientTransport for ${server.name}`)
-      return new StreamableHTTPClientTransport(new URL(nowledgeMemUrl), options)
-    }
-
     if (isBuiltinMCPServer(server) && server.name !== BuiltinMCPServerNames.mcpAutoInstall) {
       getServerLogger(server).debug(`Using in-memory transport`)
       const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
@@ -318,10 +298,6 @@ class McpService
     } else if (server.command) {
       let cmd = server.command
-
-      // Get login shell environment first - needed for command detection and server execution
-      // Note: getLoginShellEnvironment() is memoized, so subsequent calls are fast
-      const loginShellEnv = await getLoginShellEnvironment()
-
       // For DXT servers, use resolved configuration with platform overrides and variable substitution
       if (server.dxtPath) {
         const resolvedConfig = this.dxtService.getResolvedMcpConfig(server.dxtPath)
@@ -343,45 +319,18 @@ class McpService
       }

       if (server.command === 'npx') {
-        // First, check if npx is available in user's shell environment
-        const npxPath = await findCommandInShellEnv('npx', loginShellEnv)
-
-        if (npxPath) {
-          // Use system npx
-          cmd = npxPath
-          getServerLogger(server).debug(`Using system npx`, { command: cmd })
-        } else {
-          // System npx not found, try bundled bun as fallback
-          getServerLogger(server).debug(`System npx not found, checking for bundled bun`)
-
-          if (await isBinaryExists('bun')) {
-            // Fall back to bundled bun
-            cmd = await getBinaryPath('bun')
-            getServerLogger(server).info(`Using bundled bun as fallback (npx not found in PATH)`, {
-              command: cmd
-            })
-
-            // Transform args for bun x format
-            if (args && args.length > 0) {
-              if (!args.includes('-y')) {
-                args.unshift('-y')
-              }
-              if (!args.includes('x')) {
-                args.unshift('x')
-              }
-            }
-          } else {
-            // Neither npx nor bun available
-            throw new Error(
-              'npx not found in PATH and bundled bun is not available. This may indicate an installation issue.\n' +
-                'Please either:\n' +
-                '1. Install Node.js (which includes npx) from https://nodejs.org\n' +
-                '2. Run the MCP dependencies installer from Settings\n' +
-                '3. Restart the application if you recently installed Node.js'
-            )
-          }
+        cmd = await getBinaryPath('bun')
+        getServerLogger(server).debug(`Using command`, { command: cmd })
+
+        // add -x to args if args exist
+        if (args && args.length > 0) {
+          if (!args.includes('-y')) {
+            args.unshift('-y')
+          }
+          if (!args.includes('x')) {
+            args.unshift('x')
+          }
         }

         if (server.registryUrl) {
           server.env = {
             ...server.env,
@@ -396,35 +345,7 @@ class McpService
           }
         }
       } else if (server.command === 'uvx' || server.command === 'uv') {
-        // First, check if uvx/uv is available in user's shell environment
-        const uvPath = await findCommandInShellEnv(server.command, loginShellEnv)
-
-        if (uvPath) {
-          // Use system uvx/uv
-          cmd = uvPath
-          getServerLogger(server).debug(`Using system ${server.command}`, { command: cmd })
-        } else {
-          // System command not found, try bundled version as fallback
-          getServerLogger(server).debug(`System ${server.command} not found, checking for bundled version`)
-
-          if (await isBinaryExists(server.command)) {
-            // Fall back to bundled version
-            cmd = await getBinaryPath(server.command)
-            getServerLogger(server).info(`Using bundled ${server.command} as fallback (not found in PATH)`, {
-              command: cmd
-            })
-          } else {
-            // Neither system nor bundled available
-            throw new Error(
-              `${server.command} not found in PATH and bundled version is not available. This may indicate an installation issue.\n` +
-                'Please either:\n' +
-                '1. Install uv from https://github.com/astral-sh/uv\n' +
-                '2. Run the MCP dependencies installer from Settings\n' +
-                `3. Restart the application if you recently installed ${server.command}`
-            )
-          }
-        }
+        cmd = await getBinaryPath(server.command)

         if (server.registryUrl) {
           server.env = {
             ...server.env,
@@ -435,6 +356,8 @@ class McpService
       }

       getServerLogger(server).debug(`Starting server`, { command: cmd, args })
+      // Logger.info(`[MCP] Environment variables for server:`, server.env)
+      const loginShellEnv = await getLoginShellEnvironment()

       // Bun not support proxy https://github.com/oven-sh/bun/issues/16812
       if (cmd.includes('bun')) {
@@ -14,36 +14,38 @@ export class SearchService
     return SearchService.instance
   }

-  private async createNewSearchWindow(uid: string, show: boolean = false): Promise<BrowserWindow> {
+  constructor() {
+    // Initialize the service
+  }
+
+  private async createNewSearchWindow(uid: string): Promise<BrowserWindow> {
     const newWindow = new BrowserWindow({
-      width: 1280,
-      height: 768,
-      show,
+      width: 800,
+      height: 600,
+      show: false,
       webPreferences: {
         nodeIntegration: true,
         contextIsolation: false,
         devTools: is.dev
       }
     })
+    newWindow.webContents.session.webRequest.onBeforeSendHeaders({ urls: ['*://*/*'] }, (details, callback) => {
+      const headers = {
+        ...details.requestHeaders,
+        'User-Agent':
+          'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
+      }
+      callback({ requestHeaders: headers })
+    })
     this.searchWindows[uid] = newWindow
-    newWindow.on('closed', () => delete this.searchWindows[uid])
-    newWindow.webContents.userAgent =
-      'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36'
+    newWindow.on('closed', () => {
+      delete this.searchWindows[uid]
+    })

     return newWindow
   }

-  public async openSearchWindow(uid: string, show: boolean = false): Promise<void> {
-    const existingWindow = this.searchWindows[uid]
-
-    if (existingWindow) {
-      show && existingWindow.show()
-      return
-    }
-
-    await this.createNewSearchWindow(uid, show)
+  public async openSearchWindow(uid: string): Promise<void> {
+    await this.createNewSearchWindow(uid)
   }

   public async closeSearchWindow(uid: string): Promise<void> {
359
src/main/services/WebSocketService.ts
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
import { loggerService } from '@logger'
|
||||||
|
import type { WebSocketCandidatesResponse, WebSocketStatusResponse } from '@shared/config/types'
|
||||||
|
import * as fs from 'fs'
|
||||||
|
import { networkInterfaces } from 'os'
|
||||||
|
import * as path from 'path'
|
||||||
|
import type { Socket } from 'socket.io'
|
||||||
|
import { Server } from 'socket.io'
|
||||||
|
|
||||||
|
import { windowService } from './WindowService'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('WebSocketService')
|
||||||
|
|
||||||
|
class WebSocketService {
|
||||||
|
private io: Server | null = null
|
||||||
|
private isStarted = false
|
||||||
|
private port = 7017
|
||||||
|
private connectedClients = new Set<string>()
|
||||||
|
|
||||||
|
private getLocalIpAddress(): string | undefined {
|
||||||
|
const interfaces = networkInterfaces()
|
||||||
|
|
||||||
|
// 按优先级排序的网络接口名称模式
|
||||||
|
const interfacePriority = [
|
||||||
|
// macOS: 以太网/Wi-Fi 优先
|
||||||
|
/^en[0-9]+$/, // en0, en1 (以太网/Wi-Fi)
|
||||||
|
/^(en|eth)[0-9]+$/, // 以太网接口
|
||||||
|
/^wlan[0-9]+$/, // 无线接口
|
||||||
|
// Windows: 以太网/Wi-Fi 优先
|
||||||
|
/^(Ethernet|Wi-Fi|Local Area Connection)/,
|
||||||
|
/^(Wi-Fi|无线网络连接)/,
|
||||||
|
// Linux: 以太网/Wi-Fi 优先
|
||||||
|
/^(eth|enp|wlp|wlan)[0-9]+/,
|
||||||
|
// 虚拟化接口(低优先级)
|
||||||
|
/^bridge[0-9]+$/, // Docker bridge
|
||||||
|
/^veth[0-9]+$/, // Docker veth
|
||||||
|
/^docker[0-9]+/, // Docker interfaces
|
||||||
|
/^br-[0-9a-f]+/, // Docker bridge
|
||||||
|
/^vmnet[0-9]+$/, // VMware
|
||||||
|
/^vboxnet[0-9]+$/, // VirtualBox
|
||||||
|
// VPN 隧道接口(低优先级)
|
||||||
|
/^utun[0-9]+$/, // macOS VPN
|
||||||
|
/^tun[0-9]+$/, // Linux/Unix VPN
|
||||||
|
/^tap[0-9]+$/, // TAP interfaces
|
||||||
|
/^tailscale[0-9]*$/, // Tailscale VPN
|
||||||
|
/^wg[0-9]+$/ // WireGuard VPN
|
||||||
|
]
|
||||||
|
|
||||||
|
const candidates: Array<{ interface: string; address: string; priority: number }> = []
|
||||||
|
|
||||||
|
for (const [name, ifaces] of Object.entries(interfaces)) {
|
||||||
|
for (const iface of ifaces || []) {
|
||||||
|
if (iface.family === 'IPv4' && !iface.internal) {
|
||||||
|
// 计算接口优先级
|
||||||
|
let priority = 999 // 默认最低优先级
|
||||||
|
for (let i = 0; i < interfacePriority.length; i++) {
|
||||||
|
if (interfacePriority[i].test(name)) {
|
||||||
|
priority = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates.push({
|
||||||
|
interface: name,
|
||||||
|
address: iface.address,
|
||||||
|
priority
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (candidates.length === 0) {
|
||||||
|
logger.warn('无法获取局域网 IP,使用默认 IP: 127.0.0.1')
|
||||||
|
return '127.0.0.1'
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按优先级排序,选择优先级最高的
|
||||||
|
candidates.sort((a, b) => a.priority - b.priority)
|
||||||
|
const best = candidates[0]
|
||||||
|
|
||||||
|
logger.info(`获取局域网 IP: ${best.address} (interface: ${best.interface})`)
|
||||||
|
return best.address
|
||||||
|
}
|
||||||
|
|
||||||
|
public start = async (): Promise<{ success: boolean; port?: number; error?: string }> => {
|
||||||
|
if (this.isStarted && this.io) {
|
||||||
|
return { success: true, port: this.port }
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.io = new Server(this.port, {
|
||||||
|
cors: {
|
||||||
|
origin: '*',
|
||||||
|
methods: ['GET', 'POST']
|
||||||
|
},
|
||||||
|
transports: ['websocket', 'polling'],
|
||||||
|
allowEIO3: true,
|
||||||
|
pingTimeout: 60000,
|
||||||
|
pingInterval: 25000
|
||||||
|
})
|
||||||
|
|
||||||
|
this.io.on('connection', (socket: Socket) => {
|
||||||
|
this.connectedClients.add(socket.id)
|
||||||
|
|
||||||
|
const mainWindow = windowService.getMainWindow()
|
||||||
|
if (!mainWindow) {
|
||||||
|
logger.error('Main window is null, cannot send connection event')
|
||||||
|
} else {
|
||||||
|
mainWindow.webContents.send('websocket-client-connected', {
|
||||||
|
connected: true,
|
||||||
|
clientId: socket.id
|
||||||
|
})
|
||||||
|
logger.info(`Connection event sent to renderer, total clients: ${this.connectedClients.size}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
socket.on('message', (data) => {
|
||||||
|
logger.info('Received message from mobile:', data)
|
||||||
|
mainWindow?.webContents.send('websocket-message-received', data)
|
||||||
|
socket.emit('message_received', { success: true })
|
||||||
|
})
|
||||||
|
|
||||||
|
socket.on('disconnect', () => {
|
||||||
|
logger.info(`Client disconnected: ${socket.id}`)
|
||||||
|
this.connectedClients.delete(socket.id)
|
||||||
|
|
||||||
|
if (this.connectedClients.size === 0) {
|
||||||
|
mainWindow?.webContents.send('websocket-client-connected', {
|
||||||
|
connected: false,
|
||||||
|
clientId: socket.id
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Engine 层面的事件监听
|
||||||
|
this.io.engine.on('connection_error', (err) => {
|
||||||
|
logger.error('Engine connection error:', err)
|
||||||
|
})
|
||||||
|
|
||||||
|
this.io.engine.on('connection', (rawSocket) => {
|
||||||
|
const remoteAddr = rawSocket.request.connection.remoteAddress
|
||||||
|
logger.info(`[Engine] Raw connection from: ${remoteAddr}`)
|
||||||
|
logger.info(`[Engine] Transport: ${rawSocket.transport.name}`)
|
||||||
|
|
||||||
|
rawSocket.on('packet', (packet: { type: string; data?: any }) => {
|
||||||
|
logger.info(
|
||||||
|
`[Engine] ← Packet from ${remoteAddr}: type="${packet.type}"`,
|
||||||
|
packet.data ? { data: packet.data } : {}
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
rawSocket.on('packetCreate', (packet: { type: string; data?: any }) => {
|
||||||
|
logger.info(`[Engine] → Packet to ${remoteAddr}: type="${packet.type}"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
rawSocket.on('close', (reason: string) => {
|
||||||
|
logger.warn(`[Engine] Connection closed from ${remoteAddr}, reason: ${reason}`)
|
||||||
|
})
|
||||||
|
|
||||||
|
rawSocket.on('error', (error: Error) => {
|
||||||
|
logger.error(`[Engine] Connection error from ${remoteAddr}:`, error)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Socket.IO 握手失败监听
|
||||||
|
this.io.on('connection_error', (err) => {
|
||||||
|
logger.error('[Socket.IO] Connection error during handshake:', err)
|
||||||
|
})
|
||||||
|
|
||||||
|
this.isStarted = true
|
||||||
|
logger.info(`WebSocket server started on port ${this.port}`)
|
||||||
|
|
||||||
|
return { success: true, port: this.port }
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to start WebSocket server:', error as Error)
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
error: error instanceof Error ? error.message : 'Unknown error'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public stop = async (): Promise<{ success: boolean }> => {
|
||||||
|
if (!this.isStarted || !this.io) {
|
||||||
|
return { success: true }
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
await new Promise<void>((resolve) => {
|
||||||
|
this.io!.close(() => {
|
||||||
|
resolve()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
this.io = null
|
||||||
|
this.isStarted = false
|
||||||
|
this.connectedClients.clear()
|
||||||
|
logger.info('WebSocket server stopped')
|
||||||
|
|
||||||
|
return { success: true }
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to stop WebSocket server:', error as Error)
|
||||||
|
return { success: false }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public getStatus = async (): Promise<WebSocketStatusResponse> => {
|
||||||
|
return {
|
||||||
|
isRunning: this.isStarted,
|
||||||
|
port: this.isStarted ? this.port : undefined,
|
||||||
|
ip: this.isStarted ? this.getLocalIpAddress() : undefined,
|
||||||
|
clientConnected: this.connectedClients.size > 0
|
    }
  }

  public getAllCandidates = async (): Promise<WebSocketCandidatesResponse[]> => {
    const interfaces = networkInterfaces()

    // Network interface name patterns, ordered by priority
    const interfacePriority = [
      // macOS: Ethernet/Wi-Fi first
      /^en[0-9]+$/, // en0, en1 (Ethernet/Wi-Fi)
      /^(en|eth)[0-9]+$/, // Ethernet interfaces
      /^wlan[0-9]+$/, // Wireless interfaces
      // Windows: Ethernet/Wi-Fi first
      /^(Ethernet|Wi-Fi|Local Area Connection)/,
      /^(Wi-Fi|无线网络连接)/,
      // Linux: Ethernet/Wi-Fi first
      /^(eth|enp|wlp|wlan)[0-9]+/,
      // Virtualization interfaces (low priority)
      /^bridge[0-9]+$/, // Docker bridge
      /^veth[0-9]+$/, // Docker veth
      /^docker[0-9]+/, // Docker interfaces
      /^br-[0-9a-f]+/, // Docker bridge
      /^vmnet[0-9]+$/, // VMware
      /^vboxnet[0-9]+$/, // VirtualBox
      // VPN tunnel interfaces (low priority)
      /^utun[0-9]+$/, // macOS VPN
      /^tun[0-9]+$/, // Linux/Unix VPN
      /^tap[0-9]+$/, // TAP interfaces
      /^tailscale[0-9]*$/, // Tailscale VPN
      /^wg[0-9]+$/ // WireGuard VPN
    ]

    const candidates: Array<{ host: string; interface: string; priority: number }> = []

    for (const [name, ifaces] of Object.entries(interfaces)) {
      for (const iface of ifaces || []) {
        if (iface.family === 'IPv4' && !iface.internal) {
          // Compute the interface priority
          let priority = 999 // Lowest priority by default
          for (let i = 0; i < interfacePriority.length; i++) {
            if (interfacePriority[i].test(name)) {
              priority = i
              break
            }
          }

          candidates.push({
            host: iface.address,
            interface: name,
            priority
          })

          logger.debug(`Found interface: ${name} -> ${iface.address} (priority: ${priority})`)
        }
      }
    }

    // Return the candidates sorted by priority
    candidates.sort((a, b) => a.priority - b.priority)
    logger.info(
      `Found ${candidates.length} IP candidates: ${candidates.map((c) => `${c.host}(${c.interface})`).join(', ')}`
    )
    return candidates
  }

  public sendFile = async (
    _: Electron.IpcMainInvokeEvent,
    filePath: string
  ): Promise<{ success: boolean; error?: string }> => {
    if (!this.isStarted || !this.io) {
      const errorMsg = 'WebSocket server is not running.'
      logger.error(errorMsg)
      return { success: false, error: errorMsg }
    }

    if (this.connectedClients.size === 0) {
      const errorMsg = 'No client connected.'
      logger.error(errorMsg)
      return { success: false, error: errorMsg }
    }

    const mainWindow = windowService.getMainWindow()

    return new Promise((resolve, reject) => {
      const stats = fs.statSync(filePath)
      const totalSize = stats.size
      const filename = path.basename(filePath)
      const stream = fs.createReadStream(filePath)
      let bytesSent = 0
      const startTime = Date.now()

      logger.info(`Starting file transfer: ${filename} (${this.formatFileSize(totalSize)})`)

      // Signal the client that the file transfer is starting, with the filename and total size
      this.io!.emit('zip-file-start', { filename, totalSize })

      stream.on('data', (chunk) => {
        bytesSent += chunk.length
        const progress = (bytesSent / totalSize) * 100

        // Send the file chunk to the client
        this.io!.emit('zip-file-chunk', chunk)

        // Send a progress update to the renderer process
        mainWindow?.webContents.send('file-send-progress', { progress })

        // Log progress every 10%
        if (Math.floor(progress) % 10 === 0) {
          const elapsed = (Date.now() - startTime) / 1000
          const speed = elapsed > 0 ? bytesSent / elapsed : 0
          logger.info(`Transfer progress: ${Math.floor(progress)}% (${this.formatFileSize(speed)}/s)`)
        }
      })

      stream.on('end', () => {
        const totalTime = (Date.now() - startTime) / 1000
        const avgSpeed = totalTime > 0 ? totalSize / totalTime : 0
        logger.info(
          `File transfer completed: ${filename} in ${totalTime.toFixed(1)}s (${this.formatFileSize(avgSpeed)}/s)`
        )

        // Make sure 100% progress is reported
        mainWindow?.webContents.send('file-send-progress', { progress: 100 })
        // Signal the client that the file transfer has finished
        this.io!.emit('zip-file-end')
        resolve({ success: true })
      })

      stream.on('error', (error) => {
        logger.error(`File transfer failed: ${filename}`, error)
        reject({
          success: false,
          error: error instanceof Error ? error.message : 'Unknown error'
        })
      })
    })
  }

  private formatFileSize(bytes: number): string {
    if (bytes === 0) return '0 B'
    const k = 1024
    const sizes = ['B', 'KB', 'MB', 'GB']
    const i = Math.floor(Math.log(bytes) / Math.log(k))
    return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
  }
}

export default new WebSocketService()
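A note on the priority table in getAllCandidates above: each interface name is tested against the regex list in order, and the index of the first match becomes the candidate's priority (lower sorts first; unmatched names fall back to 999). A minimal standalone sketch of that selection step, with an abbreviated pattern list and illustrative interface names:

// Sketch of the priority matching used by getAllCandidates; the pattern
// list here is abbreviated and the interface names are examples.
const patterns = [/^en[0-9]+$/, /^(eth|enp|wlp|wlan)[0-9]+/, /^utun[0-9]+$/]

function priorityOf(name: string): number {
  const index = patterns.findIndex((pattern) => pattern.test(name))
  return index === -1 ? 999 : index // unmatched names sort last
}

// 'en0' (0) sorts ahead of the VPN tunnel 'utun3' (2) and the unmatched 'bridge100' (999)
const sorted = ['utun3', 'bridge100', 'en0'].sort((a, b) => priorityOf(a) - priorityOf(b))
console.log(sorted) // ['en0', 'utun3', 'bridge100']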
@@ -1,274 +0,0 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'

// Use vi.hoisted to define mocks that are available during hoisting
const { mockLogger } = vi.hoisted(() => ({
  mockLogger: {
    info: vi.fn(),
    warn: vi.fn(),
    error: vi.fn()
  }
}))

vi.mock('@logger', () => ({
  loggerService: {
    withContext: () => mockLogger
  }
}))

vi.mock('electron', () => ({
  app: {
    getPath: vi.fn((key: string) => {
      if (key === 'temp') return '/tmp'
      if (key === 'userData') return '/mock/userData'
      return '/mock/unknown'
    })
  }
}))

vi.mock('fs-extra', () => ({
  default: {
    pathExists: vi.fn(),
    remove: vi.fn(),
    ensureDir: vi.fn(),
    copy: vi.fn(),
    readdir: vi.fn(),
    stat: vi.fn(),
    readFile: vi.fn(),
    writeFile: vi.fn(),
    createWriteStream: vi.fn(),
    createReadStream: vi.fn()
  },
  pathExists: vi.fn(),
  remove: vi.fn(),
  ensureDir: vi.fn(),
  copy: vi.fn(),
  readdir: vi.fn(),
  stat: vi.fn(),
  readFile: vi.fn(),
  writeFile: vi.fn(),
  createWriteStream: vi.fn(),
  createReadStream: vi.fn()
}))

vi.mock('../WindowService', () => ({
  windowService: {
    getMainWindow: vi.fn()
  }
}))

vi.mock('../WebDav', () => ({
  default: vi.fn()
}))

vi.mock('../S3Storage', () => ({
  default: vi.fn()
}))

vi.mock('../../utils', () => ({
  getDataPath: vi.fn(() => '/mock/data')
}))

vi.mock('archiver', () => ({
  default: vi.fn()
}))

vi.mock('node-stream-zip', () => ({
  default: vi.fn()
}))

// Import after mocks
import * as fs from 'fs-extra'

import BackupManager from '../BackupManager'

describe('BackupManager.deleteTempBackup - Security Tests', () => {
  let backupManager: BackupManager

  beforeEach(() => {
    vi.clearAllMocks()
    backupManager = new BackupManager()
  })

  describe('Normal Operations', () => {
    it('should delete valid file in allowed directory', async () => {
      vi.mocked(fs.pathExists).mockResolvedValue(true as never)
      vi.mocked(fs.remove).mockResolvedValue(undefined as never)

      const validPath = '/tmp/cherry-studio/lan-transfer/backup.zip'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath)

      expect(result).toBe(true)
      expect(fs.remove).toHaveBeenCalledWith(validPath)
      expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('Deleted temp backup'))
    })

    it('should delete file in nested subdirectory', async () => {
      vi.mocked(fs.pathExists).mockResolvedValue(true as never)
      vi.mocked(fs.remove).mockResolvedValue(undefined as never)

      const nestedPath = '/tmp/cherry-studio/lan-transfer/sub/dir/file.zip'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, nestedPath)

      expect(result).toBe(true)
      expect(fs.remove).toHaveBeenCalledWith(nestedPath)
    })

    it('should return false when file does not exist', async () => {
      vi.mocked(fs.pathExists).mockResolvedValue(false as never)

      const missingPath = '/tmp/cherry-studio/lan-transfer/missing.zip'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, missingPath)

      expect(result).toBe(false)
      expect(fs.remove).not.toHaveBeenCalled()
    })
  })

  describe('Path Traversal Attacks', () => {
    it('should block basic directory traversal attack (../../../../etc/passwd)', async () => {
      const attackPath = '/tmp/cherry-studio/lan-transfer/../../../../etc/passwd'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)

      expect(result).toBe(false)
      expect(fs.pathExists).not.toHaveBeenCalled()
      expect(fs.remove).not.toHaveBeenCalled()
      expect(mockLogger.warn).toHaveBeenCalledWith(expect.stringContaining('outside temp directory'))
    })

    it('should block absolute path escape (/etc/passwd)', async () => {
      const attackPath = '/etc/passwd'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)

      expect(result).toBe(false)
      expect(fs.remove).not.toHaveBeenCalled()
      expect(mockLogger.warn).toHaveBeenCalled()
    })

    it('should block traversal with multiple slashes', async () => {
      const attackPath = '/tmp/cherry-studio/lan-transfer/../../../etc/passwd'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)

      expect(result).toBe(false)
      expect(fs.remove).not.toHaveBeenCalled()
    })

    it('should block relative path traversal from current directory', async () => {
      const attackPath = '../../../etc/passwd'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)

      expect(result).toBe(false)
      expect(fs.remove).not.toHaveBeenCalled()
    })

    it('should block traversal to parent directory', async () => {
      const attackPath = '/tmp/cherry-studio/lan-transfer/../backup/secret.zip'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)

      expect(result).toBe(false)
      expect(fs.remove).not.toHaveBeenCalled()
    })
  })

  describe('Prefix Attacks', () => {
    it('should block similar prefix attack (lan-transfer-evil)', async () => {
      const attackPath = '/tmp/cherry-studio/lan-transfer-evil/file.zip'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)

      expect(result).toBe(false)
      expect(fs.remove).not.toHaveBeenCalled()
      expect(mockLogger.warn).toHaveBeenCalled()
    })

    it('should block path without separator (lan-transferx)', async () => {
      const attackPath = '/tmp/cherry-studio/lan-transferx'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)

      expect(result).toBe(false)
      expect(fs.remove).not.toHaveBeenCalled()
    })

    it('should block different temp directory prefix', async () => {
      const attackPath = '/tmp-evil/cherry-studio/lan-transfer/file.zip'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)

      expect(result).toBe(false)
      expect(fs.remove).not.toHaveBeenCalled()
    })
  })

  describe('Error Handling', () => {
    it('should return false and log error on permission denied', async () => {
      vi.mocked(fs.pathExists).mockResolvedValue(true as never)
      vi.mocked(fs.remove).mockRejectedValue(new Error('EACCES: permission denied') as never)

      const validPath = '/tmp/cherry-studio/lan-transfer/file.zip'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath)

      expect(result).toBe(false)
      expect(mockLogger.error).toHaveBeenCalledWith(expect.stringContaining('Failed to delete'), expect.any(Error))
    })

    it('should return false on fs.pathExists error', async () => {
      vi.mocked(fs.pathExists).mockRejectedValue(new Error('ENOENT') as never)

      const validPath = '/tmp/cherry-studio/lan-transfer/file.zip'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath)

      expect(result).toBe(false)
      expect(mockLogger.error).toHaveBeenCalled()
    })

    it('should handle empty path string', async () => {
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, '')

      expect(result).toBe(false)
      expect(fs.remove).not.toHaveBeenCalled()
    })
  })

  describe('Edge Cases', () => {
    it('should allow deletion of the temp directory itself', async () => {
      vi.mocked(fs.pathExists).mockResolvedValue(true as never)
      vi.mocked(fs.remove).mockResolvedValue(undefined as never)

      const tempDir = '/tmp/cherry-studio/lan-transfer'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, tempDir)

      expect(result).toBe(true)
      expect(fs.remove).toHaveBeenCalledWith(tempDir)
    })

    it('should handle path with trailing slash', async () => {
      vi.mocked(fs.pathExists).mockResolvedValue(true as never)
      vi.mocked(fs.remove).mockResolvedValue(undefined as never)

      const pathWithSlash = '/tmp/cherry-studio/lan-transfer/sub/'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, pathWithSlash)

      // path.normalize removes trailing slash
      expect(result).toBe(true)
    })

    it('should handle file with special characters in name', async () => {
      vi.mocked(fs.pathExists).mockResolvedValue(true as never)
      vi.mocked(fs.remove).mockResolvedValue(undefined as never)

      const specialPath = '/tmp/cherry-studio/lan-transfer/file with spaces & (special).zip'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, specialPath)

      expect(result).toBe(true)
      expect(fs.remove).toHaveBeenCalled()
    })

    it('should handle path with double slashes', async () => {
      vi.mocked(fs.pathExists).mockResolvedValue(true as never)
      vi.mocked(fs.remove).mockResolvedValue(undefined as never)

      const doubleSlashPath = '/tmp/cherry-studio//lan-transfer//file.zip'
      const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, doubleSlashPath)

      // path.normalize handles double slashes
      expect(result).toBe(true)
    })
  })
})
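The prefix-attack tests above (lan-transfer-evil, lan-transferx, /tmp-evil/...) pin down the containment rule: normalize the path first, then require either an exact match with the allowed root or a prefix that ends at a path separator. A minimal sketch of such a guard, assuming the root used by the fixtures (the actual deleteTempBackup implementation is not part of this diff):

import * as path from 'node:path'

// Containment guard consistent with the security tests above.
// ALLOWED_ROOT mirrors the test fixtures; the real constant may differ.
const ALLOWED_ROOT = path.normalize('/tmp/cherry-studio/lan-transfer')

function isInsideAllowedRoot(candidate: string): boolean {
  if (!candidate) return false
  const normalized = path.normalize(candidate)
  // Exact match lets the root itself be deleted (see the Edge Cases tests);
  // requiring a trailing separator blocks prefixes like 'lan-transfer-evil'.
  return normalized === ALLOWED_ROOT || normalized.startsWith(ALLOWED_ROOT + path.sep)
}

console.log(isInsideAllowedRoot('/tmp/cherry-studio/lan-transfer/backup.zip')) // true
console.log(isInsideAllowedRoot('/tmp/cherry-studio/lan-transfer/../../../../etc/passwd')) // false
console.log(isInsideAllowedRoot('/tmp/cherry-studio/lan-transfer-evil/file.zip')) // false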
@@ -1,481 +0,0 @@
import { EventEmitter } from 'events'
import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest'

// Create mock objects before vi.mock calls
const mockLogger = {
  info: vi.fn(),
  warn: vi.fn(),
  error: vi.fn()
}

let mockMainWindow: {
  isDestroyed: Mock
  webContents: { send: Mock }
} | null = null

let mockBrowser: EventEmitter & {
  start: Mock
  stop: Mock
  removeAllListeners: Mock
}

let mockBonjour: {
  find: Mock
  destroy: Mock
}

// Mock dependencies before importing the service
vi.mock('@logger', () => ({
  loggerService: {
    withContext: () => mockLogger
  }
}))

vi.mock('../WindowService', () => ({
  windowService: {
    getMainWindow: vi.fn(() => mockMainWindow)
  }
}))

vi.mock('bonjour-service', () => ({
  default: vi.fn(() => mockBonjour)
}))

describe('LocalTransferService', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    vi.resetModules()

    // Reset mock objects
    mockMainWindow = {
      isDestroyed: vi.fn(() => false),
      webContents: { send: vi.fn() }
    }

    mockBrowser = Object.assign(new EventEmitter(), {
      start: vi.fn(),
      stop: vi.fn(),
      removeAllListeners: vi.fn()
    })

    mockBonjour = {
      find: vi.fn(() => mockBrowser),
      destroy: vi.fn()
    }
  })

  afterEach(() => {
    vi.resetAllMocks()
  })

  describe('startDiscovery', () => {
    it('should set isScanning to true and start browser', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      const state = localTransferService.startDiscovery()

      expect(state.isScanning).toBe(true)
      expect(state.lastScanStartedAt).toBeDefined()
      expect(mockBonjour.find).toHaveBeenCalledWith({ type: 'cherrystudio', protocol: 'tcp' })
      expect(mockBrowser.start).toHaveBeenCalled()
    })

    it('should clear services when resetList is true', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      // First, start discovery and add a service
      localTransferService.startDiscovery()
      mockBrowser.emit('up', {
        name: 'Test Service',
        host: 'localhost',
        port: 12345,
        addresses: ['192.168.1.100'],
        fqdn: 'test.local'
      })

      expect(localTransferService.getState().services).toHaveLength(1)

      // Now restart with resetList
      const state = localTransferService.startDiscovery({ resetList: true })

      expect(state.services).toHaveLength(0)
    })

    it('should broadcast state after starting discovery', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      expect(mockMainWindow?.webContents.send).toHaveBeenCalled()
    })

    it('should handle browser.start() error', async () => {
      mockBrowser.start.mockImplementation(() => {
        throw new Error('Failed to start mDNS')
      })

      const { localTransferService } = await import('../LocalTransferService')

      const state = localTransferService.startDiscovery()

      expect(state.lastError).toBe('Failed to start mDNS')
      expect(mockLogger.error).toHaveBeenCalled()
    })
  })

  describe('stopDiscovery', () => {
    it('should set isScanning to false and stop browser', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()
      const state = localTransferService.stopDiscovery()

      expect(state.isScanning).toBe(false)
      expect(mockBrowser.stop).toHaveBeenCalled()
    })

    it('should handle browser.stop() error gracefully', async () => {
      mockBrowser.stop.mockImplementation(() => {
        throw new Error('Stop failed')
      })

      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      // Should not throw
      expect(() => localTransferService.stopDiscovery()).not.toThrow()
      expect(mockLogger.warn).toHaveBeenCalled()
    })

    it('should broadcast state after stopping', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()
      vi.clearAllMocks()

      localTransferService.stopDiscovery()

      expect(mockMainWindow?.webContents.send).toHaveBeenCalled()
    })
  })

  describe('browser events', () => {
    it('should add service on "up" event', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      mockBrowser.emit('up', {
        name: 'Test Service',
        host: 'localhost',
        port: 12345,
        addresses: ['192.168.1.100'],
        fqdn: 'test.local',
        type: 'cherrystudio',
        protocol: 'tcp'
      })

      const state = localTransferService.getState()
      expect(state.services).toHaveLength(1)
      expect(state.services[0].name).toBe('Test Service')
      expect(state.services[0].port).toBe(12345)
      expect(state.services[0].addresses).toContain('192.168.1.100')
    })

    it('should remove service on "down" event', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      // Add service
      mockBrowser.emit('up', {
        name: 'Test Service',
        host: 'localhost',
        port: 12345,
        addresses: ['192.168.1.100'],
        fqdn: 'test.local'
      })

      expect(localTransferService.getState().services).toHaveLength(1)

      // Remove service
      mockBrowser.emit('down', {
        name: 'Test Service',
        host: 'localhost',
        port: 12345,
        fqdn: 'test.local'
      })

      expect(localTransferService.getState().services).toHaveLength(0)
      expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('removed'))
    })

    it('should set lastError on "error" event', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      mockBrowser.emit('error', new Error('Discovery failed'))

      const state = localTransferService.getState()
      expect(state.lastError).toBe('Discovery failed')
      expect(mockLogger.error).toHaveBeenCalled()
    })

    it('should handle non-Error objects in error event', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      mockBrowser.emit('error', 'String error message')

      const state = localTransferService.getState()
      expect(state.lastError).toBe('String error message')
    })
  })

  describe('getState', () => {
    it('should return sorted services by name', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      mockBrowser.emit('up', {
        name: 'Zebra Service',
        host: 'host1',
        port: 1001,
        addresses: ['192.168.1.1']
      })

      mockBrowser.emit('up', {
        name: 'Alpha Service',
        host: 'host2',
        port: 1002,
        addresses: ['192.168.1.2']
      })

      const state = localTransferService.getState()
      expect(state.services[0].name).toBe('Alpha Service')
      expect(state.services[1].name).toBe('Zebra Service')
    })

    it('should include all state properties', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      const state = localTransferService.getState()

      expect(state).toHaveProperty('services')
      expect(state).toHaveProperty('isScanning')
      expect(state).toHaveProperty('lastScanStartedAt')
      expect(state).toHaveProperty('lastUpdatedAt')
    })
  })

  describe('getPeerById', () => {
    it('should return peer when exists', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      mockBrowser.emit('up', {
        name: 'Test Service',
        host: 'localhost',
        port: 12345,
        addresses: ['192.168.1.100'],
        fqdn: 'test.local'
      })

      const services = localTransferService.getState().services
      const peer = localTransferService.getPeerById(services[0].id)

      expect(peer).toBeDefined()
      expect(peer?.name).toBe('Test Service')
    })

    it('should return undefined when peer does not exist', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      const peer = localTransferService.getPeerById('non-existent-id')

      expect(peer).toBeUndefined()
    })
  })

  describe('normalizeService', () => {
    it('should deduplicate addresses', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      mockBrowser.emit('up', {
        name: 'Test Service',
        host: 'localhost',
        port: 12345,
        addresses: ['192.168.1.100', '192.168.1.100', '10.0.0.1'],
        referer: { address: '192.168.1.100' }
      })

      const services = localTransferService.getState().services
      expect(services[0].addresses).toHaveLength(2)
      expect(services[0].addresses).toContain('192.168.1.100')
      expect(services[0].addresses).toContain('10.0.0.1')
    })

    it('should filter empty addresses', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      mockBrowser.emit('up', {
        name: 'Test Service',
        host: 'localhost',
        port: 12345,
        addresses: ['192.168.1.100', '', null as any]
      })

      const services = localTransferService.getState().services
      expect(services[0].addresses).toEqual(['192.168.1.100'])
    })

    it('should convert txt null/undefined values to empty strings', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      mockBrowser.emit('up', {
        name: 'Test Service',
        host: 'localhost',
        port: 12345,
        addresses: ['192.168.1.100'],
        txt: {
          version: '1.0',
          nullValue: null,
          undefinedValue: undefined,
          numberValue: 42
        }
      })

      const services = localTransferService.getState().services
      expect(services[0].txt).toEqual({
        version: '1.0',
        nullValue: '',
        undefinedValue: '',
        numberValue: '42'
      })
    })

    it('should not include txt when empty', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      mockBrowser.emit('up', {
        name: 'Test Service',
        host: 'localhost',
        port: 12345,
        addresses: ['192.168.1.100'],
        txt: {}
      })

      const services = localTransferService.getState().services
      expect(services[0].txt).toBeUndefined()
    })
  })

  describe('dispose', () => {
    it('should clean up all resources', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      mockBrowser.emit('up', {
        name: 'Test Service',
        host: 'localhost',
        port: 12345,
        addresses: ['192.168.1.100']
      })

      localTransferService.dispose()

      expect(localTransferService.getState().services).toHaveLength(0)
      expect(localTransferService.getState().isScanning).toBe(false)
      expect(mockBrowser.removeAllListeners).toHaveBeenCalled()
      expect(mockBonjour.destroy).toHaveBeenCalled()
    })

    it('should handle bonjour.destroy() error gracefully', async () => {
      mockBonjour.destroy.mockImplementation(() => {
        throw new Error('Destroy failed')
      })

      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      // Should not throw
      expect(() => localTransferService.dispose()).not.toThrow()
      expect(mockLogger.warn).toHaveBeenCalled()
    })

    it('should be safe to call multiple times', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()

      expect(() => {
        localTransferService.dispose()
        localTransferService.dispose()
      }).not.toThrow()
    })
  })

  describe('broadcastState', () => {
    it('should not throw when main window is null', async () => {
      mockMainWindow = null

      const { localTransferService } = await import('../LocalTransferService')

      // Should not throw
      expect(() => localTransferService.startDiscovery()).not.toThrow()
    })

    it('should not throw when main window is destroyed', async () => {
      mockMainWindow = {
        isDestroyed: vi.fn(() => true),
        webContents: { send: vi.fn() }
      }

      const { localTransferService } = await import('../LocalTransferService')

      // Should not throw
      expect(() => localTransferService.startDiscovery()).not.toThrow()
      expect(mockMainWindow.webContents.send).not.toHaveBeenCalled()
    })
  })

  describe('restartBrowser', () => {
    it('should destroy old bonjour instance to prevent socket leaks', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      // First start
      localTransferService.startDiscovery()
      expect(mockBonjour.destroy).not.toHaveBeenCalled()

      // Restart - should destroy old instance
      localTransferService.startDiscovery()
      expect(mockBonjour.destroy).toHaveBeenCalled()
    })

    it('should remove all listeners from old browser', async () => {
      const { localTransferService } = await import('../LocalTransferService')

      localTransferService.startDiscovery()
      localTransferService.startDiscovery()

      expect(mockBrowser.removeAllListeners).toHaveBeenCalled()
    })
  })
})
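For orientation, the bonjour-service surface these tests mock can be exercised roughly as below. This is a minimal sketch of the discovery flow, not the service under test; the teardown sequence is an assumption based on the mocked calls (removeAllListeners, destroy):

import Bonjour from 'bonjour-service'

// Browse for the same service type the tests assert: cherrystudio over TCP.
const bonjour = new Bonjour()
const browser = bonjour.find({ type: 'cherrystudio', protocol: 'tcp' })

browser.on('up', (service) => {
  // A peer appeared; the real service normalizes addresses and txt records here
  console.log(`up: ${service.name} @ ${service.addresses?.join(', ')}:${service.port}`)
})

browser.on('down', (service) => {
  console.log(`down: ${service.name}`)
})

browser.start()

// Teardown mirrors dispose(): detach listeners, then destroy the instance
process.on('beforeExit', () => {
  browser.removeAllListeners()
  bonjour.destroy()
})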
@@ -15,8 +15,8 @@ import { query } from '@anthropic-ai/claude-agent-sdk'
 import { loggerService } from '@logger'
 import { config as apiConfigService } from '@main/apiServer/config'
 import { validateModelId } from '@main/apiServer/utils'
-import { isWin } from '@main/constant'
+import { ConfigKeys, configManager } from '@main/services/ConfigManager'
-import { autoDiscoverGitBash } from '@main/utils/process'
+import { validateGitBashPath } from '@main/utils/process'
 import getLoginShellEnvironment from '@main/utils/shell-env'
 import { app } from 'electron'

@@ -109,8 +109,7 @@ class ClaudeCodeService implements AgentServiceInterface {
       Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy'))
     ) as Record<string, string>

-    // Auto-discover Git Bash path on Windows (already logs internally)
-    const customGitBashPath = isWin ? autoDiscoverGitBash() : null
+    const customGitBashPath = validateGitBashPath(configManager.get(ConfigKeys.GitBashPath) as string | undefined)

     const env = {
       ...loginShellEnvWithoutProxies,
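The hunk above replaces automatic Git Bash discovery with a user-configured path that is validated before use. The validator's body is not part of this diff; a hypothetical sketch of its shape, for illustration only:

import * as fs from 'node:fs'

// Hypothetical sketch; the real validateGitBashPath in @main/utils/process may differ.
// Idea: only accept a configured bash.exe path that actually exists on disk.
function validateGitBashPathSketch(configured: string | undefined): string | null {
  if (!configured) return null
  try {
    return fs.statSync(configured).isFile() ? configured : null
  } catch {
    return null // path missing or unreadable
  }
}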
@@ -1,525 +0,0 @@
import * as crypto from 'node:crypto'
import { createConnection, type Socket } from 'node:net'

import { loggerService } from '@logger'
import type {
  LanClientEvent,
  LanFileCompleteMessage,
  LanHandshakeAckMessage,
  LocalTransferConnectPayload,
  LocalTransferPeer
} from '@shared/config/types'
import { LAN_TRANSFER_GLOBAL_TIMEOUT_MS } from '@shared/config/types'
import { IpcChannel } from '@shared/IpcChannel'

import { localTransferService } from '../LocalTransferService'
import { windowService } from '../WindowService'
import {
  abortTransfer,
  buildHandshakeMessage,
  calculateFileChecksum,
  cleanupTransfer,
  createDataHandler,
  createTransferState,
  formatFileSize,
  HANDSHAKE_PROTOCOL_VERSION,
  pickHost,
  sendFileEnd,
  sendFileStart,
  sendTestPing,
  streamFileChunks,
  validateFile,
  waitForFileComplete,
  waitForFileStartAck
} from './handlers'
import { ResponseManager } from './responseManager'
import type { ActiveFileTransfer, ConnectionContext, FileTransferContext } from './types'

const DEFAULT_HANDSHAKE_TIMEOUT_MS = 10_000

const logger = loggerService.withContext('LanTransferClientService')

/**
 * LAN Transfer Client Service
 *
 * Handles outgoing file transfers to LAN peers via TCP.
 * Protocol v1 with streaming mode (no per-chunk acknowledgment).
 */
class LanTransferClientService {
  private socket: Socket | null = null
  private currentPeer?: LocalTransferPeer
  private dataHandler?: ReturnType<typeof createDataHandler>
  private responseManager = new ResponseManager()
  private isConnecting = false
  private activeTransfer?: ActiveFileTransfer
  private lastConnectOptions?: LocalTransferConnectPayload
  private consecutiveJsonErrors = 0
  private static readonly MAX_CONSECUTIVE_JSON_ERRORS = 3
  private reconnectPromise: Promise<void> | null = null

  constructor() {
    this.responseManager.setTimeoutCallback(() => void this.disconnect())
  }

  /**
   * Connect to a LAN peer and perform handshake.
   */
  public async connectAndHandshake(options: LocalTransferConnectPayload): Promise<LanHandshakeAckMessage> {
    if (this.isConnecting) {
      throw new Error('LAN transfer client is busy')
    }

    const peer = localTransferService.getPeerById(options.peerId)
    if (!peer) {
      throw new Error('Selected LAN peer is no longer available')
    }
    if (!peer.port) {
      throw new Error('Selected peer does not expose a TCP port')
    }

    const host = pickHost(peer)
    if (!host) {
      throw new Error('Unable to resolve a reachable host for the peer')
    }

    await this.disconnect()
    this.isConnecting = true

    return new Promise<LanHandshakeAckMessage>((resolve, reject) => {
      const socket = createConnection({ host, port: peer.port as number }, () => {
        logger.info(`Connected to LAN peer ${peer.name} (${host}:${peer.port})`)
        socket.setKeepAlive(true, 30_000)
        this.socket = socket
        this.currentPeer = peer
        this.attachSocketListeners(socket)

        this.responseManager.waitForResponse(
          'handshake_ack',
          options.timeoutMs ?? DEFAULT_HANDSHAKE_TIMEOUT_MS,
          (payload) => {
            const ack = payload as LanHandshakeAckMessage
            if (!ack.accepted) {
              const message = ack.message || 'Handshake rejected by remote device'
              logger.warn(`Handshake rejected by ${peer.name}: ${message}`)
              this.broadcastClientEvent({
                type: 'error',
                message,
                timestamp: Date.now()
              })
              reject(new Error(message))
              void this.disconnect()
              return
            }
            logger.info(`Handshake accepted by ${peer.name}`)
            socket.setTimeout(0)
            this.isConnecting = false
            this.lastConnectOptions = options
            sendTestPing(this.createConnectionContext())
            resolve(ack)
          },
          (error) => {
            this.isConnecting = false
            reject(error)
          }
        )

        const handshakeMessage = buildHandshakeMessage()
        this.sendControlMessage(handshakeMessage)
      })

      socket.setTimeout(options.timeoutMs ?? DEFAULT_HANDSHAKE_TIMEOUT_MS, () => {
        const error = new Error('Handshake timed out')
        logger.error('LAN transfer socket timeout', error)
        this.broadcastClientEvent({
          type: 'error',
          message: error.message,
          timestamp: Date.now()
        })
        reject(error)
        socket.destroy(error)
        void this.disconnect()
      })

      socket.once('error', (error) => {
        logger.error('LAN transfer socket error', error as Error)
        const message = error instanceof Error ? error.message : String(error)
        this.broadcastClientEvent({
          type: 'error',
          message,
          timestamp: Date.now()
        })
        this.isConnecting = false
        reject(error instanceof Error ? error : new Error(message))
        void this.disconnect()
      })

      socket.once('close', () => {
        logger.info('LAN transfer socket closed')
        if (this.socket === socket) {
          this.socket = null
          this.dataHandler?.resetBuffer()
          this.responseManager.rejectAll(new Error('LAN transfer socket closed'))
          this.currentPeer = undefined
          abortTransfer(this.activeTransfer, new Error('LAN transfer socket closed'))
        }
        this.isConnecting = false
        this.broadcastClientEvent({
          type: 'socket_closed',
          reason: 'connection_closed',
          timestamp: Date.now()
        })
      })
    })
  }

  /**
   * Disconnect from the current peer.
   */
  public async disconnect(): Promise<void> {
    const socket = this.socket
    if (!socket) {
      return
    }

    this.socket = null
    this.dataHandler?.resetBuffer()
    this.currentPeer = undefined
    this.responseManager.rejectAll(new Error('LAN transfer socket disconnected'))
    abortTransfer(this.activeTransfer, new Error('LAN transfer socket disconnected'))

    const DISCONNECT_TIMEOUT_MS = 3000
    await new Promise<void>((resolve) => {
      const timeout = setTimeout(() => {
        logger.warn('Disconnect timeout, forcing cleanup')
        socket.removeAllListeners()
        resolve()
      }, DISCONNECT_TIMEOUT_MS)

      socket.once('close', () => {
        clearTimeout(timeout)
        resolve()
      })

      socket.destroy()
    })
  }

  /**
   * Dispose the service and clean up all resources.
   */
  public dispose(): void {
    this.responseManager.rejectAll(new Error('LAN transfer client disposed'))
    cleanupTransfer(this.activeTransfer)
    this.activeTransfer = undefined
    if (this.socket) {
      this.socket.destroy()
      this.socket = null
    }
    this.dataHandler?.resetBuffer()
    this.isConnecting = false
  }

  /**
   * Send a ZIP file to the connected peer.
   */
  public async sendFile(filePath: string): Promise<LanFileCompleteMessage> {
    await this.ensureConnection()

    if (this.activeTransfer) {
      throw new Error('A file transfer is already in progress')
    }

    // Validate file
    const { stats, fileName } = await validateFile(filePath)

    // Calculate checksum
    logger.info('Calculating file checksum...')
    const checksum = await calculateFileChecksum(filePath)
    logger.info(`File checksum: ${checksum.substring(0, 16)}...`)

    // Connection can drop while validating/checking file; ensure it is still ready before starting transfer.
    await this.ensureConnection()

    // Initialize transfer state
    const transferId = crypto.randomUUID()
    this.activeTransfer = createTransferState(transferId, fileName, stats.size, checksum)

    logger.info(
      `Starting file transfer: ${fileName} (${formatFileSize(stats.size)}, ${this.activeTransfer.totalChunks} chunks)`
    )

    // Global timeout
    const globalTimeoutError = new Error('Transfer timed out (global timeout exceeded)')
    const globalTimeoutHandle = setTimeout(() => {
      logger.warn('Global transfer timeout exceeded, aborting transfer', { transferId, fileName })
      abortTransfer(this.activeTransfer, globalTimeoutError)
    }, LAN_TRANSFER_GLOBAL_TIMEOUT_MS)

    try {
      const result = await this.performFileTransfer(filePath, transferId, fileName)
      return result
    } catch (error) {
      const message = error instanceof Error ? error.message : String(error)
      logger.error(`File transfer failed: ${message}`)

      this.broadcastClientEvent({
        type: 'file_transfer_complete',
        transferId,
        fileName,
        success: false,
        error: message,
        timestamp: Date.now()
      })

      throw error
    } finally {
      clearTimeout(globalTimeoutHandle)
      cleanupTransfer(this.activeTransfer)
      this.activeTransfer = undefined
    }
  }

  /**
   * Cancel the current file transfer.
   */
  public cancelTransfer(): void {
    if (!this.activeTransfer) {
      logger.warn('No active transfer to cancel')
      return
    }

    const { transferId, fileName } = this.activeTransfer
    logger.info(`Cancelling file transfer: ${fileName}`)

    this.activeTransfer.isCancelled = true

    try {
      this.sendControlMessage({
        type: 'file_cancel',
        transferId,
        reason: 'Cancelled by user'
      })
    } catch (error) {
      // Expected when connection is already broken
      logger.warn('Failed to send cancel message', error as Error)
    }

    abortTransfer(this.activeTransfer, new Error('Transfer cancelled by user'))
  }

  // =============================================================================
  // Private Methods
  // =============================================================================

  private async ensureConnection(): Promise<void> {
    // Check socket is valid and writable (not just undestroyed)
    if (this.socket && !this.socket.destroyed && this.socket.writable && this.currentPeer) {
      return
    }

    if (!this.lastConnectOptions) {
      throw new Error('No active connection. Please connect to a peer first.')
    }

    // Prevent concurrent reconnection attempts
    if (this.reconnectPromise) {
      logger.debug('Waiting for existing reconnection attempt...')
      await this.reconnectPromise
      return
    }

    logger.info('Connection lost, attempting to reconnect...')
    this.reconnectPromise = this.connectAndHandshake(this.lastConnectOptions)
      .then(() => {
        // Handshake succeeded, connection restored
      })
      .finally(() => {
        this.reconnectPromise = null
      })

    await this.reconnectPromise
  }

  private async performFileTransfer(
    filePath: string,
    transferId: string,
    fileName: string
  ): Promise<LanFileCompleteMessage> {
    const transfer = this.activeTransfer!
    const ctx = this.createFileTransferContext()

    // Step 1: Send file_start
    sendFileStart(ctx, transfer)

    // Step 2: Wait for file_start_ack
    const startAck = await waitForFileStartAck(ctx, transferId, transfer.abortController.signal)
    if (!startAck.accepted) {
      throw new Error(startAck.message || 'Transfer rejected by receiver')
    }
    logger.info('Received file_start_ack: accepted')

    // Step 3: Stream file chunks
    await streamFileChunks(this.socket!, filePath, transfer, transfer.abortController.signal, (bytesSent, chunkIndex) =>
      this.onTransferProgress(transfer, bytesSent, chunkIndex)
    )

    // Step 4: Send file_end
    sendFileEnd(ctx, transferId)

    // Step 5: Wait for file_complete
    const result = await waitForFileComplete(ctx, transferId, transfer.abortController.signal)
    logger.info(`File transfer ${result.success ? 'completed' : 'failed'}`)

    // Broadcast completion
    this.broadcastClientEvent({
      type: 'file_transfer_complete',
      transferId,
      fileName,
      success: result.success,
      filePath: result.filePath,
      error: result.error,
      timestamp: Date.now()
    })

    return result
  }

  private onTransferProgress(transfer: ActiveFileTransfer, bytesSent: number, chunkIndex: number): void {
    const progress = (bytesSent / transfer.fileSize) * 100
    const elapsed = (Date.now() - transfer.startedAt) / 1000
    const speed = elapsed > 0 ? bytesSent / elapsed : 0

    this.broadcastClientEvent({
      type: 'file_transfer_progress',
      transferId: transfer.transferId,
      fileName: transfer.fileName,
      bytesSent,
      totalBytes: transfer.fileSize,
      chunkIndex,
      totalChunks: transfer.totalChunks,
      progress: Math.round(progress * 100) / 100,
      speed,
      timestamp: Date.now()
    })
  }

  private attachSocketListeners(socket: Socket): void {
    this.dataHandler = createDataHandler((line) => this.handleControlLine(line))
    socket.on('data', (chunk: Buffer) => {
      try {
        this.dataHandler?.handleData(chunk)
      } catch (error) {
        logger.error('Data handler error', error as Error)
        void this.disconnect()
      }
    })
  }

  private handleControlLine(line: string): void {
    let payload: Record<string, unknown>
    try {
      payload = JSON.parse(line)
      this.consecutiveJsonErrors = 0 // Reset on successful parse
    } catch {
      this.consecutiveJsonErrors++
      logger.warn('Received invalid JSON control message', { line, consecutiveErrors: this.consecutiveJsonErrors })

      if (this.consecutiveJsonErrors >= LanTransferClientService.MAX_CONSECUTIVE_JSON_ERRORS) {
        const message = `Protocol error: ${this.consecutiveJsonErrors} consecutive invalid messages, disconnecting`
        logger.error(message)
        this.broadcastClientEvent({
          type: 'error',
          message,
          timestamp: Date.now()
        })
        void this.disconnect()
      }
      return
    }

    const type = payload?.type as string | undefined
    if (!type) {
      logger.warn('Received control message without type', payload)
      return
    }

    // Try to resolve a pending response
    const transferId = payload?.transferId as string | undefined
    const chunkIndex = payload?.chunkIndex as number | undefined
    if (this.responseManager.tryResolve(type, payload, transferId, chunkIndex)) {
      return
    }

    logger.info('Received control message', payload)

    if (type === 'pong') {
      this.broadcastClientEvent({
        type: 'pong',
        payload: payload?.payload as string | undefined,
        received: payload?.received as boolean | undefined,
        timestamp: Date.now()
      })
      return
    }

    // Ignore late-arriving file transfer messages
    const fileTransferMessageTypes = ['file_start_ack', 'file_complete']
    if (fileTransferMessageTypes.includes(type)) {
      logger.debug('Ignoring late file transfer message', { type, payload })
      return
    }

    this.broadcastClientEvent({
      type: 'error',
      message: `Unexpected control message type: ${type}`,
      timestamp: Date.now()
    })
  }

  private sendControlMessage(message: Record<string, unknown>): void {
    if (!this.socket || this.socket.destroyed || !this.socket.writable) {
      throw new Error('Socket is not connected')
    }
    const payload = JSON.stringify(message)
    this.socket.write(`${payload}\n`)
  }

  private createConnectionContext(): ConnectionContext {
    return {
      socket: this.socket,
      currentPeer: this.currentPeer,
      sendControlMessage: (msg) => this.sendControlMessage(msg),
      broadcastClientEvent: (event) => this.broadcastClientEvent(event)
    }
  }

  private createFileTransferContext(): FileTransferContext {
    return {
      ...this.createConnectionContext(),
      activeTransfer: this.activeTransfer,
      setActiveTransfer: (transfer) => {
        this.activeTransfer = transfer
      },
      waitForResponse: (type, timeoutMs, resolve, reject, transferId, chunkIndex, abortSignal) => {
        this.responseManager.waitForResponse(type, timeoutMs, resolve, reject, transferId, chunkIndex, abortSignal)
      }
    }
  }

  private broadcastClientEvent(event: LanClientEvent): void {
    const mainWindow = windowService.getMainWindow()
    if (!mainWindow || mainWindow.isDestroyed()) {
      return
    }
    mainWindow.webContents.send(IpcChannel.LocalTransfer_ClientEvent, {
      ...event,
      peerId: event.peerId ?? this.currentPeer?.id,
      peerName: event.peerName ?? this.currentPeer?.name
    })
  }
}

export const lanTransferClientService = new LanTransferClientService()

// Re-export for backward compatibility
export { HANDSHAKE_PROTOCOL_VERSION }
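The control channel above is newline-delimited JSON over the TCP socket (see sendControlMessage and createDataHandler), with binary chunk frames interleaved during step 3. A minimal sketch of the sender-side sequence performFileTransfer walks through; the message fields are illustrative and the real shapes live in @shared/config/types:

import type { Socket } from 'node:net'

// Newline-delimited JSON, exactly as sendControlMessage writes it.
function sendJson(socket: Socket, message: Record<string, unknown>): void {
  socket.write(`${JSON.stringify(message)}\n`)
}

function sketchTransfer(socket: Socket, transferId: string): void {
  // Step 1: announce the file; the receiver replies with file_start_ack (step 2)
  sendJson(socket, { type: 'file_start', transferId, fileName: 'backup.zip', fileSize: 1024, checksum: '<sha256>' })
  // Step 3: stream binary chunk frames here (protocol v1 sends no per-chunk ack)
  // Step 4: signal the end; the receiver replies with file_complete (step 5)
  sendJson(socket, { type: 'file_end', transferId })
}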
@@ -1,133 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

// Mock dependencies before importing the service
vi.mock('node:net', async (importOriginal) => {
  const actual = (await importOriginal()) as Record<string, unknown>
  return {
    ...actual,
    createConnection: vi.fn()
  }
})

vi.mock('electron', () => ({
  app: {
    getName: vi.fn(() => 'Cherry Studio'),
    getVersion: vi.fn(() => '1.0.0')
  }
}))

vi.mock('../../LocalTransferService', () => ({
  localTransferService: {
    getPeerById: vi.fn()
  }
}))

vi.mock('../../WindowService', () => ({
  windowService: {
    getMainWindow: vi.fn(() => ({
      isDestroyed: () => false,
      webContents: {
        send: vi.fn()
      }
    }))
  }
}))

// Import after mocks
import { localTransferService } from '../../LocalTransferService'

describe('LanTransferClientService', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    vi.resetModules()
  })

  afterEach(() => {
    vi.resetAllMocks()
  })

  describe('connectAndHandshake - validation', () => {
    it('should throw error when peer is not found', async () => {
      vi.mocked(localTransferService.getPeerById).mockReturnValue(undefined)

      const { lanTransferClientService } = await import('../LanTransferClientService')

      await expect(
        lanTransferClientService.connectAndHandshake({
          peerId: 'non-existent'
        })
      ).rejects.toThrow('Selected LAN peer is no longer available')
    })

    it('should throw error when peer has no port', async () => {
      vi.mocked(localTransferService.getPeerById).mockReturnValue({
        id: 'test-peer',
        name: 'Test Peer',
        addresses: ['192.168.1.100'],
        updatedAt: Date.now()
      })

      const { lanTransferClientService } = await import('../LanTransferClientService')

      await expect(
        lanTransferClientService.connectAndHandshake({
          peerId: 'test-peer'
        })
      ).rejects.toThrow('Selected peer does not expose a TCP port')
    })

    it('should throw error when no reachable host', async () => {
      vi.mocked(localTransferService.getPeerById).mockReturnValue({
        id: 'test-peer',
        name: 'Test Peer',
        port: 12345,
        addresses: [],
        updatedAt: Date.now()
      })

      const { lanTransferClientService } = await import('../LanTransferClientService')

      await expect(
        lanTransferClientService.connectAndHandshake({
          peerId: 'test-peer'
        })
      ).rejects.toThrow('Unable to resolve a reachable host for the peer')
    })
  })

  describe('cancelTransfer', () => {
    it('should not throw when no active transfer', async () => {
      const { lanTransferClientService } = await import('../LanTransferClientService')

      // Should not throw, just log warning
      expect(() => lanTransferClientService.cancelTransfer()).not.toThrow()
    })
  })

  describe('dispose', () => {
    it('should clean up resources without throwing', async () => {
      const { lanTransferClientService } = await import('../LanTransferClientService')

      // Should not throw
      expect(() => lanTransferClientService.dispose()).not.toThrow()
    })
  })

  describe('sendFile', () => {
    it('should throw error when not connected', async () => {
      const { lanTransferClientService } = await import('../LanTransferClientService')

      await expect(lanTransferClientService.sendFile('/path/to/file.zip')).rejects.toThrow(
        'No active connection. Please connect to a peer first.'
      )
    })
  })

  describe('HANDSHAKE_PROTOCOL_VERSION', () => {
    it('should export protocol version', async () => {
      const { HANDSHAKE_PROTOCOL_VERSION } = await import('../LanTransferClientService')

      expect(HANDSHAKE_PROTOCOL_VERSION).toBe('1')
    })
  })
})
@ -1,103 +0,0 @@
import { EventEmitter } from 'node:events'
import type { Socket } from 'node:net'

import { beforeEach, describe, expect, it, vi } from 'vitest'

import { BINARY_TYPE_FILE_CHUNK, sendBinaryChunk } from '../binaryProtocol'

describe('binaryProtocol', () => {
  describe('sendBinaryChunk', () => {
    let mockSocket: Socket
    let writtenBuffers: Buffer[]

    beforeEach(() => {
      writtenBuffers = []
      mockSocket = Object.assign(new EventEmitter(), {
        destroyed: false,
        writable: true,
        write: vi.fn((buffer: Buffer) => {
          writtenBuffers.push(Buffer.from(buffer))
          return true
        }),
        cork: vi.fn(),
        uncork: vi.fn()
      }) as unknown as Socket
    })

    it('should send binary chunk with correct frame format', () => {
      const transferId = 'test-uuid-1234'
      const chunkIndex = 5
      const data = Buffer.from('test data chunk')

      const result = sendBinaryChunk(mockSocket, transferId, chunkIndex, data)

      expect(result).toBe(true)
      expect(mockSocket.cork).toHaveBeenCalled()
      expect(mockSocket.uncork).toHaveBeenCalled()
      expect(mockSocket.write).toHaveBeenCalledTimes(2)

      // Verify header structure
      const header = writtenBuffers[0]

      // Magic bytes "CS"
      expect(header[0]).toBe(0x43)
      expect(header[1]).toBe(0x53)

      // Type byte
      const typeOffset = 2 + 4 // magic + totalLen
      expect(header[typeOffset]).toBe(BINARY_TYPE_FILE_CHUNK)

      // TransferId length
      const tidLenOffset = typeOffset + 1
      const tidLen = header.readUInt16BE(tidLenOffset)
      expect(tidLen).toBe(Buffer.from(transferId).length)

      // ChunkIndex
      const chunkIdxOffset = tidLenOffset + 2 + tidLen
      expect(header.readUInt32BE(chunkIdxOffset)).toBe(chunkIndex)

      // Data buffer
      expect(writtenBuffers[1].toString()).toBe('test data chunk')
    })

    it('should return false when socket write returns false (backpressure)', () => {
      ;(mockSocket.write as ReturnType<typeof vi.fn>).mockReturnValueOnce(false)

      const result = sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data'))

      expect(result).toBe(false)
    })

    it('should correctly calculate totalLen in frame header', () => {
      const transferId = 'uuid-1234'
      const data = Buffer.from('chunk data here')

      sendBinaryChunk(mockSocket, transferId, 0, data)

      const header = writtenBuffers[0]
      const totalLen = header.readUInt32BE(2) // After magic bytes

      // totalLen = type(1) + tidLen(2) + tid(n) + idx(4) + data(m)
      const expectedTotalLen = 1 + 2 + Buffer.from(transferId).length + 4 + data.length
      expect(totalLen).toBe(expectedTotalLen)
    })

    it('should throw error when socket is not writable', () => {
      ;(mockSocket as any).writable = false

      expect(() => sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data'))).toThrow('Socket is not writable')
    })

    it('should throw error when socket is destroyed', () => {
      ;(mockSocket as any).destroyed = true

      expect(() => sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data'))).toThrow('Socket is not writable')
    })
  })

  describe('BINARY_TYPE_FILE_CHUNK', () => {
    it('should be 0x01', () => {
      expect(BINARY_TYPE_FILE_CHUNK).toBe(0x01)
    })
  })
})
@ -1,265 +0,0 @@
import { EventEmitter } from 'node:events'
import type { Socket } from 'node:net'

import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

import {
  buildHandshakeMessage,
  createDataHandler,
  getAbortError,
  HANDSHAKE_PROTOCOL_VERSION,
  pickHost,
  waitForSocketDrain
} from '../../handlers/connection'

// Mock electron app
vi.mock('electron', () => ({
  app: {
    getName: vi.fn(() => 'Cherry Studio'),
    getVersion: vi.fn(() => '1.0.0')
  }
}))

describe('connection handlers', () => {
  describe('buildHandshakeMessage', () => {
    it('should build handshake message with correct structure', () => {
      const message = buildHandshakeMessage()

      expect(message.type).toBe('handshake')
      expect(message.deviceName).toBe('Cherry Studio')
      expect(message.version).toBe(HANDSHAKE_PROTOCOL_VERSION)
      expect(message.appVersion).toBe('1.0.0')
      expect(typeof message.platform).toBe('string')
    })

    it('should use protocol version 1', () => {
      expect(HANDSHAKE_PROTOCOL_VERSION).toBe('1')
    })
  })

  describe('pickHost', () => {
    it('should prefer IPv4 addresses', () => {
      const peer = {
        id: '1',
        name: 'Test',
        addresses: ['fe80::1', '192.168.1.100', '::1'],
        updatedAt: Date.now()
      }

      expect(pickHost(peer)).toBe('192.168.1.100')
    })

    it('should fall back to first address if no IPv4', () => {
      const peer = {
        id: '1',
        name: 'Test',
        addresses: ['fe80::1', '::1'],
        updatedAt: Date.now()
      }

      expect(pickHost(peer)).toBe('fe80::1')
    })

    it('should fall back to host property if no addresses', () => {
      const peer = {
        id: '1',
        name: 'Test',
        host: 'example.local',
        addresses: [],
        updatedAt: Date.now()
      }

      expect(pickHost(peer)).toBe('example.local')
    })

    it('should return undefined if no addresses or host', () => {
      const peer = {
        id: '1',
        name: 'Test',
        addresses: [],
        updatedAt: Date.now()
      }

      expect(pickHost(peer)).toBeUndefined()
    })
  })

  describe('createDataHandler', () => {
    it('should parse complete lines from buffer', () => {
      const lines: string[] = []
      const handler = createDataHandler((line) => lines.push(line))

      handler.handleData(Buffer.from('{"type":"test"}\n'))

      expect(lines).toEqual(['{"type":"test"}'])
    })

    it('should handle partial lines across multiple chunks', () => {
      const lines: string[] = []
      const handler = createDataHandler((line) => lines.push(line))

      handler.handleData(Buffer.from('{"type":'))
      handler.handleData(Buffer.from('"test"}\n'))

      expect(lines).toEqual(['{"type":"test"}'])
    })

    it('should handle multiple lines in single chunk', () => {
      const lines: string[] = []
      const handler = createDataHandler((line) => lines.push(line))

      handler.handleData(Buffer.from('{"a":1}\n{"b":2}\n'))

      expect(lines).toEqual(['{"a":1}', '{"b":2}'])
    })

    it('should reset buffer', () => {
      const lines: string[] = []
      const handler = createDataHandler((line) => lines.push(line))

      handler.handleData(Buffer.from('partial'))
      handler.resetBuffer()
      handler.handleData(Buffer.from('{"complete":true}\n'))

      expect(lines).toEqual(['{"complete":true}'])
    })

    it('should trim whitespace from lines', () => {
      const lines: string[] = []
      const handler = createDataHandler((line) => lines.push(line))

      handler.handleData(Buffer.from(' {"type":"test"} \n'))

      expect(lines).toEqual(['{"type":"test"}'])
    })

    it('should skip empty lines', () => {
      const lines: string[] = []
      const handler = createDataHandler((line) => lines.push(line))

      handler.handleData(Buffer.from('\n\n{"type":"test"}\n\n'))

      expect(lines).toEqual(['{"type":"test"}'])
    })

    it('should throw error when buffer exceeds MAX_LINE_BUFFER_SIZE', () => {
      const handler = createDataHandler(vi.fn())

      // Create a buffer larger than 1MB (MAX_LINE_BUFFER_SIZE)
      const largeData = 'x'.repeat(1024 * 1024 + 1)

      expect(() => handler.handleData(Buffer.from(largeData))).toThrow('Control message too large')
    })

    it('should reset buffer after exceeding MAX_LINE_BUFFER_SIZE', () => {
      const lines: string[] = []
      const handler = createDataHandler((line) => lines.push(line))

      // Create a buffer larger than 1MB
      const largeData = 'x'.repeat(1024 * 1024 + 1)

      try {
        handler.handleData(Buffer.from(largeData))
      } catch {
        // Expected error
      }

      // Buffer should be reset, so lineBuffer should be empty
      expect(handler.lineBuffer).toBe('')
    })
  })

  describe('waitForSocketDrain', () => {
    let mockSocket: Socket & EventEmitter

    beforeEach(() => {
      mockSocket = Object.assign(new EventEmitter(), {
        destroyed: false,
        writable: true,
        write: vi.fn(),
        off: vi.fn(),
        removeAllListeners: vi.fn()
      }) as unknown as Socket & EventEmitter
    })

    afterEach(() => {
      vi.resetAllMocks()
    })

    it('should throw error when abort signal is already aborted', async () => {
      const abortController = new AbortController()
      abortController.abort(new Error('Already aborted'))

      await expect(waitForSocketDrain(mockSocket, abortController.signal)).rejects.toThrow('Already aborted')
    })

    it('should throw error when socket is destroyed', async () => {
      ;(mockSocket as any).destroyed = true
      const abortController = new AbortController()

      await expect(waitForSocketDrain(mockSocket, abortController.signal)).rejects.toThrow('Socket is closed')
    })

    it('should resolve when drain event is emitted', async () => {
      const abortController = new AbortController()

      const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)

      // Emit drain event after a short delay
      setImmediate(() => mockSocket.emit('drain'))

      await expect(drainPromise).resolves.toBeUndefined()
    })

    it('should reject when close event is emitted', async () => {
      const abortController = new AbortController()

      const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)

      setImmediate(() => mockSocket.emit('close'))

      await expect(drainPromise).rejects.toThrow('Socket closed while waiting for drain')
    })

    it('should reject when error event is emitted', async () => {
      const abortController = new AbortController()

      const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)

      setImmediate(() => mockSocket.emit('error', new Error('Network error')))

      await expect(drainPromise).rejects.toThrow('Network error')
    })

    it('should reject when abort signal is triggered', async () => {
      const abortController = new AbortController()

      const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)

      setImmediate(() => abortController.abort(new Error('User cancelled')))

      await expect(drainPromise).rejects.toThrow('User cancelled')
    })
  })

  describe('getAbortError', () => {
    it('should return Error reason directly', () => {
      const originalError = new Error('Original')
      const signal = { aborted: true, reason: originalError } as AbortSignal

      expect(getAbortError(signal, 'Fallback')).toBe(originalError)
    })

    it('should create Error from string reason', () => {
      const signal = { aborted: true, reason: 'String reason' } as AbortSignal

      expect(getAbortError(signal, 'Fallback').message).toBe('String reason')
    })

    it('should use fallback for empty reason', () => {
      const signal = { aborted: true, reason: '' } as AbortSignal

      expect(getAbortError(signal, 'Fallback').message).toBe('Fallback')
    })
  })
})
@ -1,216 +0,0 @@
import { EventEmitter } from 'node:events'
import type * as fs from 'node:fs'
import type { Socket } from 'node:net'

import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

import {
  abortTransfer,
  cleanupTransfer,
  createTransferState,
  formatFileSize,
  streamFileChunks
} from '../../handlers/fileTransfer'
import type { ActiveFileTransfer } from '../../types'

// Mock binaryProtocol
vi.mock('../../binaryProtocol', () => ({
  sendBinaryChunk: vi.fn().mockReturnValue(true)
}))

// Mock connection handlers
vi.mock('./connection', () => ({
  waitForSocketDrain: vi.fn().mockResolvedValue(undefined),
  getAbortError: vi.fn((signal, fallback) => {
    const reason = (signal as AbortSignal & { reason?: unknown }).reason
    if (reason instanceof Error) return reason
    if (typeof reason === 'string' && reason.length > 0) return new Error(reason)
    return new Error(fallback)
  })
}))

// Note: validateFile and calculateFileChecksum tests are skipped because
// the test environment has globally mocked node:fs and node:os modules.
// These functions are tested through integration tests instead.

describe('fileTransfer handlers', () => {
  describe('createTransferState', () => {
    it('should create transfer state with correct defaults', () => {
      const state = createTransferState('uuid-123', 'test.zip', 1024000, 'abc123')

      expect(state.transferId).toBe('uuid-123')
      expect(state.fileName).toBe('test.zip')
      expect(state.fileSize).toBe(1024000)
      expect(state.checksum).toBe('abc123')
      expect(state.bytesSent).toBe(0)
      expect(state.currentChunk).toBe(0)
      expect(state.isCancelled).toBe(false)
      expect(state.abortController).toBeInstanceOf(AbortController)
    })

    it('should calculate totalChunks based on chunk size', () => {
      // 512KB chunk size
      const state = createTransferState('id', 'test.zip', 1024 * 1024, 'checksum') // 1MB

      expect(state.totalChunks).toBe(2) // 1MB / 512KB = 2
    })
  })

  describe('abortTransfer', () => {
    it('should abort transfer and destroy stream', () => {
      const mockStream = {
        destroyed: false,
        destroy: vi.fn()
      } as unknown as fs.ReadStream

      const transfer: ActiveFileTransfer = {
        transferId: 'test',
        fileName: 'test.zip',
        fileSize: 1000,
        checksum: 'abc',
        totalChunks: 1,
        chunkSize: 512000,
        bytesSent: 0,
        currentChunk: 0,
        startedAt: Date.now(),
        stream: mockStream,
        isCancelled: false,
        abortController: new AbortController()
      }

      const error = new Error('Test abort')
      abortTransfer(transfer, error)

      expect(transfer.isCancelled).toBe(true)
      expect(transfer.abortController.signal.aborted).toBe(true)
      expect(mockStream.destroy).toHaveBeenCalledWith(error)
    })

    it('should handle undefined transfer', () => {
      expect(() => abortTransfer(undefined, new Error('test'))).not.toThrow()
    })

    it('should not abort already aborted controller', () => {
      const transfer: ActiveFileTransfer = {
        transferId: 'test',
        fileName: 'test.zip',
        fileSize: 1000,
        checksum: 'abc',
        totalChunks: 1,
        chunkSize: 512000,
        bytesSent: 0,
        currentChunk: 0,
        startedAt: Date.now(),
        isCancelled: false,
        abortController: new AbortController()
      }

      transfer.abortController.abort()

      // Should not throw when aborting again
      expect(() => abortTransfer(transfer, new Error('test'))).not.toThrow()
    })
  })

  describe('cleanupTransfer', () => {
    it('should cleanup transfer resources', () => {
      const mockStream = {
        destroyed: false,
        destroy: vi.fn()
      } as unknown as fs.ReadStream

      const transfer: ActiveFileTransfer = {
        transferId: 'test',
        fileName: 'test.zip',
        fileSize: 1000,
        checksum: 'abc',
        totalChunks: 1,
        chunkSize: 512000,
        bytesSent: 0,
        currentChunk: 0,
        startedAt: Date.now(),
        stream: mockStream,
        isCancelled: false,
        abortController: new AbortController()
      }

      cleanupTransfer(transfer)

      expect(transfer.abortController.signal.aborted).toBe(true)
      expect(mockStream.destroy).toHaveBeenCalled()
    })

    it('should handle undefined transfer', () => {
      expect(() => cleanupTransfer(undefined)).not.toThrow()
    })
  })

  describe('formatFileSize', () => {
    it('should format 0 bytes', () => {
      expect(formatFileSize(0)).toBe('0 B')
    })

    it('should format bytes', () => {
      expect(formatFileSize(500)).toBe('500 B')
    })

    it('should format kilobytes', () => {
      expect(formatFileSize(1024)).toBe('1 KB')
      expect(formatFileSize(2048)).toBe('2 KB')
    })

    it('should format megabytes', () => {
      expect(formatFileSize(1024 * 1024)).toBe('1 MB')
      expect(formatFileSize(5 * 1024 * 1024)).toBe('5 MB')
    })

    it('should format gigabytes', () => {
      expect(formatFileSize(1024 * 1024 * 1024)).toBe('1 GB')
    })

    it('should format with decimal precision', () => {
      expect(formatFileSize(1536)).toBe('1.5 KB')
      expect(formatFileSize(1.5 * 1024 * 1024)).toBe('1.5 MB')
    })
  })

  // Note: streamFileChunks tests require careful mocking of fs.createReadStream
  // which is globally mocked in the test environment. These tests verify the
  // streaming logic works correctly with mock streams.
  describe('streamFileChunks', () => {
    let mockSocket: Socket & EventEmitter
    let mockProgress: ReturnType<typeof vi.fn>

    beforeEach(() => {
      vi.clearAllMocks()

      mockSocket = Object.assign(new EventEmitter(), {
        destroyed: false,
        writable: true,
        write: vi.fn().mockReturnValue(true),
        cork: vi.fn(),
        uncork: vi.fn()
      }) as unknown as Socket & EventEmitter

      mockProgress = vi.fn()
    })

    afterEach(() => {
      vi.resetAllMocks()
    })

    it('should throw when abort signal is already aborted', async () => {
      const transfer = createTransferState('test-id', 'test.zip', 1024, 'checksum')
      transfer.abortController.abort(new Error('Already cancelled'))

      await expect(
        streamFileChunks(mockSocket, '/fake/path.zip', transfer, transfer.abortController.signal, mockProgress)
      ).rejects.toThrow()
    })

    // Note: Full integration testing of streamFileChunks with actual file streaming
    // requires a real file system, which cannot be easily mocked in ESM.
    // The abort signal test above verifies the early abort path.
    // Additional streaming tests are covered through integration tests.
  })
})
@ -1,177 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

import { ResponseManager } from '../responseManager'

describe('ResponseManager', () => {
  let manager: ResponseManager

  beforeEach(() => {
    vi.useFakeTimers()
    manager = new ResponseManager()
  })

  afterEach(() => {
    vi.useRealTimers()
  })

  describe('buildResponseKey', () => {
    it('should build key with type only', () => {
      expect(manager.buildResponseKey('handshake_ack')).toBe('handshake_ack')
    })

    it('should build key with type and transferId', () => {
      expect(manager.buildResponseKey('file_start_ack', 'uuid-123')).toBe('file_start_ack:uuid-123')
    })

    it('should build key with type, transferId, and chunkIndex', () => {
      expect(manager.buildResponseKey('file_chunk_ack', 'uuid-123', 5)).toBe('file_chunk_ack:uuid-123:5')
    })
  })

  describe('waitForResponse', () => {
    it('should resolve when tryResolve is called with matching key', async () => {
      const resolvePromise = new Promise<unknown>((resolve, reject) => {
        manager.waitForResponse('handshake_ack', 5000, resolve, reject)
      })

      const payload = { type: 'handshake_ack', accepted: true }
      const resolved = manager.tryResolve('handshake_ack', payload)

      expect(resolved).toBe(true)
      await expect(resolvePromise).resolves.toEqual(payload)
    })

    it('should reject on timeout', async () => {
      const resolvePromise = new Promise<unknown>((resolve, reject) => {
        manager.waitForResponse('handshake_ack', 1000, resolve, reject)
      })

      vi.advanceTimersByTime(1001)

      await expect(resolvePromise).rejects.toThrow('Timeout waiting for handshake_ack')
    })

    it('should call onTimeout callback when timeout occurs', async () => {
      const onTimeout = vi.fn()
      manager.setTimeoutCallback(onTimeout)

      const resolvePromise = new Promise<unknown>((resolve, reject) => {
        manager.waitForResponse('test', 1000, resolve, reject)
      })

      vi.advanceTimersByTime(1001)

      await expect(resolvePromise).rejects.toThrow()
      expect(onTimeout).toHaveBeenCalled()
    })

    it('should reject when abort signal is triggered', async () => {
      const abortController = new AbortController()

      const resolvePromise = new Promise<unknown>((resolve, reject) => {
        manager.waitForResponse('test', 10000, resolve, reject, undefined, undefined, abortController.signal)
      })

      abortController.abort(new Error('User cancelled'))

      await expect(resolvePromise).rejects.toThrow('User cancelled')
    })

    it('should replace existing response with same key', async () => {
      const firstReject = vi.fn()
      const secondResolve = vi.fn()
      const secondReject = vi.fn()

      manager.waitForResponse('test', 5000, vi.fn(), firstReject)
      manager.waitForResponse('test', 5000, secondResolve, secondReject)

      // First should be cleared (no rejection since it's replaced)
      const payload = { type: 'test' }
      manager.tryResolve('test', payload)

      expect(secondResolve).toHaveBeenCalledWith(payload)
    })
  })

  describe('tryResolve', () => {
    it('should return false when no matching response', () => {
      expect(manager.tryResolve('nonexistent', {})).toBe(false)
    })

    it('should match with transferId', async () => {
      const resolvePromise = new Promise<unknown>((resolve, reject) => {
        manager.waitForResponse('file_start_ack', 5000, resolve, reject, 'uuid-123')
      })

      const payload = { type: 'file_start_ack', transferId: 'uuid-123' }
      manager.tryResolve('file_start_ack', payload, 'uuid-123')

      await expect(resolvePromise).resolves.toEqual(payload)
    })
  })

  describe('rejectAll', () => {
    it('should reject all pending responses', async () => {
      const promises = [
        new Promise<unknown>((resolve, reject) => {
          manager.waitForResponse('test1', 5000, resolve, reject)
        }),
        new Promise<unknown>((resolve, reject) => {
          manager.waitForResponse('test2', 5000, resolve, reject, 'uuid')
        })
      ]

      manager.rejectAll(new Error('Connection closed'))

      await expect(promises[0]).rejects.toThrow('Connection closed')
      await expect(promises[1]).rejects.toThrow('Connection closed')
    })
  })

  describe('clearPendingResponse', () => {
    it('should clear specific response by key', () => {
      manager.waitForResponse('test', 5000, vi.fn(), vi.fn())

      manager.clearPendingResponse('test')

      expect(manager.tryResolve('test', {})).toBe(false)
    })

    it('should clear all responses when no key provided', () => {
      manager.waitForResponse('test1', 5000, vi.fn(), vi.fn())
      manager.waitForResponse('test2', 5000, vi.fn(), vi.fn())

      manager.clearPendingResponse()

      expect(manager.tryResolve('test1', {})).toBe(false)
      expect(manager.tryResolve('test2', {})).toBe(false)
    })
  })

  describe('getAbortError', () => {
    it('should return Error reason directly', () => {
      const originalError = new Error('Original error')
      const signal = { aborted: true, reason: originalError } as AbortSignal

      const error = manager.getAbortError(signal, 'Fallback')

      expect(error).toBe(originalError)
    })

    it('should create Error from string reason', () => {
      const signal = { aborted: true, reason: 'String reason' } as AbortSignal

      const error = manager.getAbortError(signal, 'Fallback')

      expect(error.message).toBe('String reason')
    })

    it('should use fallback message when no reason', () => {
      const signal = { aborted: true } as AbortSignal

      const error = manager.getAbortError(signal, 'Fallback message')

      expect(error.message).toBe('Fallback message')
    })
  })
})
@ -1,67 +0,0 @@
import type { Socket } from 'node:net'

/**
 * Binary protocol constants (v1)
 */
export const BINARY_TYPE_FILE_CHUNK = 0x01

/**
 * Send file chunk as binary frame (protocol v1 - streaming mode)
 *
 * Frame format:
 * ```
 * ┌──────────┬──────────┬──────────┬───────────────┬──────────────┬────────────┬───────────┐
 * │ Magic    │ TotalLen │ Type     │ TransferId Len│ TransferId   │ ChunkIdx   │ Data      │
 * │ 0x43 0x53│ (4B BE)  │ 0x01     │ (2B BE)       │ (variable)   │ (4B BE)    │ (raw)     │
 * └──────────┴──────────┴──────────┴───────────────┴──────────────┴────────────┴───────────┘
 * ```
 *
 * @param socket - TCP socket to write to
 * @param transferId - UUID of the transfer
 * @param chunkIndex - Index of the chunk (0-based)
 * @param data - Raw chunk data buffer
 * @returns true if data was buffered, false if backpressure should be applied
 */
export function sendBinaryChunk(socket: Socket, transferId: string, chunkIndex: number, data: Buffer): boolean {
  if (!socket || socket.destroyed || !socket.writable) {
    throw new Error('Socket is not writable')
  }

  const tidBuffer = Buffer.from(transferId, 'utf8')
  const tidLen = tidBuffer.length

  // totalLen = type(1) + tidLen(2) + tid(n) + idx(4) + data(m)
  const totalLen = 1 + 2 + tidLen + 4 + data.length

  const header = Buffer.allocUnsafe(2 + 4 + 1 + 2 + tidLen + 4)
  let offset = 0

  // Magic (2 bytes): "CS"
  header[offset++] = 0x43
  header[offset++] = 0x53

  // TotalLen (4 bytes, Big-Endian)
  header.writeUInt32BE(totalLen, offset)
  offset += 4

  // Type (1 byte)
  header[offset++] = BINARY_TYPE_FILE_CHUNK

  // TransferId length (2 bytes, Big-Endian)
  header.writeUInt16BE(tidLen, offset)
  offset += 2

  // TransferId (variable)
  tidBuffer.copy(header, offset)
  offset += tidLen

  // ChunkIndex (4 bytes, Big-Endian)
  header.writeUInt32BE(chunkIndex, offset)

  socket.cork()
  const wroteHeader = socket.write(header)
  const wroteData = socket.write(data)
  socket.uncork()

  return wroteHeader && wroteData
}
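For orientation, a minimal sketch of how a caller might drive sendBinaryChunk over a raw TCP connection while honoring the boolean backpressure contract above. The host, port, transfer id, and import path are placeholder assumptions, not values from this repository:

import { createConnection } from 'node:net'

import { sendBinaryChunk } from './binaryProtocol'

// Sketch: open a TCP connection, send one chunk, and respect backpressure.
const socket = createConnection({ host: '192.168.1.100', port: 12345 }, () => {
  const keepWriting = sendBinaryChunk(socket, 'transfer-uuid', 0, Buffer.from('payload'))
  if (!keepWriting) {
    // write() buffered past the high-water mark; wait for the kernel to catch up
    socket.once('drain', () => socket.end())
  } else {
    socket.end()
  }
})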
@ -1,162 +0,0 @@
import { isIP, type Socket } from 'node:net'
import { platform } from 'node:os'

import { loggerService } from '@logger'
import type { LanHandshakeRequestMessage, LocalTransferPeer } from '@shared/config/types'
import { app } from 'electron'

import type { ConnectionContext } from '../types'

export const HANDSHAKE_PROTOCOL_VERSION = '1'

/** Maximum size for line buffer to prevent memory exhaustion from malicious peers */
const MAX_LINE_BUFFER_SIZE = 1024 * 1024 // 1MB limit for control messages

const logger = loggerService.withContext('LanTransferConnection')

/**
 * Build a handshake request message with device info.
 */
export function buildHandshakeMessage(): LanHandshakeRequestMessage {
  return {
    type: 'handshake',
    deviceName: app.getName(),
    version: HANDSHAKE_PROTOCOL_VERSION,
    platform: platform(),
    appVersion: app.getVersion()
  }
}

/**
 * Pick the best host address from a peer's available addresses.
 * Prefers IPv4 addresses over IPv6.
 */
export function pickHost(peer: LocalTransferPeer): string | undefined {
  const preferred = peer.addresses?.find((addr) => isIP(addr) === 4) || peer.addresses?.[0]
  return preferred || peer.host
}

/**
 * Send a test ping message after successful handshake.
 */
export function sendTestPing(ctx: ConnectionContext): void {
  const payload = 'hello world'
  try {
    ctx.sendControlMessage({ type: 'ping', payload })
    logger.info('Sent LAN ping test payload')
    ctx.broadcastClientEvent({
      type: 'ping_sent',
      payload,
      timestamp: Date.now()
    })
  } catch (error) {
    const message = error instanceof Error ? error.message : String(error)
    logger.error('Failed to send LAN test ping', error as Error)
    ctx.broadcastClientEvent({
      type: 'error',
      message,
      timestamp: Date.now()
    })
  }
}

/**
 * Attach data listener to socket for receiving control messages.
 * Returns a function to parse the line buffer.
 */
export function createDataHandler(onControlLine: (line: string) => void): {
  lineBuffer: string
  handleData: (chunk: Buffer) => void
  resetBuffer: () => void
} {
  let lineBuffer = ''

  return {
    get lineBuffer() {
      return lineBuffer
    },
    handleData(chunk: Buffer) {
      lineBuffer += chunk.toString('utf8')

      // Prevent memory exhaustion from malicious peers sending data without newlines
      if (lineBuffer.length > MAX_LINE_BUFFER_SIZE) {
        logger.error('Line buffer exceeded maximum size, resetting')
        lineBuffer = ''
        throw new Error('Control message too large')
      }

      let newlineIndex = lineBuffer.indexOf('\n')
      while (newlineIndex !== -1) {
        const line = lineBuffer.slice(0, newlineIndex).trim()
        lineBuffer = lineBuffer.slice(newlineIndex + 1)
        if (line.length > 0) {
          onControlLine(line)
        }
        newlineIndex = lineBuffer.indexOf('\n')
      }
    },
    resetBuffer() {
      lineBuffer = ''
    }
  }
}

/**
 * Wait for socket to drain (backpressure handling).
 */
export async function waitForSocketDrain(socket: Socket, abortSignal: AbortSignal): Promise<void> {
  if (abortSignal.aborted) {
    throw getAbortError(abortSignal, 'Transfer aborted while waiting for socket drain')
  }
  if (socket.destroyed) {
    throw new Error('Socket is closed')
  }

  await new Promise<void>((resolve, reject) => {
    const cleanup = () => {
      socket.off('drain', onDrain)
      socket.off('close', onClose)
      socket.off('error', onError)
      abortSignal.removeEventListener('abort', onAbort)
    }

    const onDrain = () => {
      cleanup()
      resolve()
    }

    const onClose = () => {
      cleanup()
      reject(new Error('Socket closed while waiting for drain'))
    }

    const onError = (error: Error) => {
      cleanup()
      reject(error)
    }

    const onAbort = () => {
      cleanup()
      reject(getAbortError(abortSignal, 'Transfer aborted while waiting for socket drain'))
    }

    socket.once('drain', onDrain)
    socket.once('close', onClose)
    socket.once('error', onError)
    abortSignal.addEventListener('abort', onAbort, { once: true })
  })
}

/**
 * Get the error from an abort signal, or create a fallback error.
 */
export function getAbortError(signal: AbortSignal, fallbackMessage: string): Error {
  const reason = (signal as AbortSignal & { reason?: unknown }).reason
  if (reason instanceof Error) {
    return reason
  }
  if (typeof reason === 'string' && reason.length > 0) {
    return new Error(reason)
  }
  return new Error(fallbackMessage)
}
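A sketch of how createDataHandler might be wired to a live socket so that each newline-delimited control line is parsed as JSON. The attachControlChannel helper and the JSON.parse step are illustrative assumptions, not part of this module:

import type { Socket } from 'node:net'

import { createDataHandler } from './connection'

// Sketch: frame raw socket data into lines, then JSON-parse each control message.
function attachControlChannel(socket: Socket, onMessage: (message: unknown) => void): void {
  const handler = createDataHandler((line) => {
    onMessage(JSON.parse(line)) // each complete line is one JSON control message
  })
  socket.on('data', (chunk: Buffer) => handler.handleData(chunk))
  socket.on('close', () => handler.resetBuffer())
}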
@ -1,267 +0,0 @@
import * as crypto from 'node:crypto'
import * as fs from 'node:fs'
import type { Socket } from 'node:net'
import * as path from 'node:path'

import { loggerService } from '@logger'
import type {
  LanFileCompleteMessage,
  LanFileEndMessage,
  LanFileStartAckMessage,
  LanFileStartMessage
} from '@shared/config/types'
import {
  LAN_TRANSFER_CHUNK_SIZE,
  LAN_TRANSFER_COMPLETE_TIMEOUT_MS,
  LAN_TRANSFER_MAX_FILE_SIZE
} from '@shared/config/types'

import { sendBinaryChunk } from '../binaryProtocol'
import type { ActiveFileTransfer, FileTransferContext } from '../types'
import { getAbortError, waitForSocketDrain } from './connection'

const DEFAULT_FILE_START_ACK_TIMEOUT_MS = 30_000 // 30s for file_start_ack

const logger = loggerService.withContext('LanTransferFileHandler')

/**
 * Validate a file for transfer.
 * Checks existence, type, extension, and size limits.
 */
export async function validateFile(filePath: string): Promise<{ stats: fs.Stats; fileName: string }> {
  let stats: fs.Stats
  try {
    stats = await fs.promises.stat(filePath)
  } catch (error) {
    const nodeError = error as NodeJS.ErrnoException
    if (nodeError.code === 'ENOENT') {
      throw new Error(`File not found: ${filePath}`)
    } else if (nodeError.code === 'EACCES') {
      throw new Error(`Permission denied: ${filePath}`)
    } else if (nodeError.code === 'ENOTDIR') {
      throw new Error(`Invalid path: ${filePath}`)
    } else {
      throw new Error(`Cannot access file: ${filePath} (${nodeError.code || 'unknown error'})`)
    }
  }

  if (!stats.isFile()) {
    throw new Error('Path is not a file')
  }

  const fileName = path.basename(filePath)
  const ext = path.extname(fileName).toLowerCase()
  if (ext !== '.zip') {
    throw new Error('Only ZIP files are supported')
  }

  if (stats.size > LAN_TRANSFER_MAX_FILE_SIZE) {
    throw new Error(`File too large. Maximum size is ${formatFileSize(LAN_TRANSFER_MAX_FILE_SIZE)}`)
  }

  return { stats, fileName }
}

/**
 * Calculate SHA-256 checksum of a file.
 */
export async function calculateFileChecksum(filePath: string): Promise<string> {
  return new Promise((resolve, reject) => {
    const hash = crypto.createHash('sha256')
    const stream = fs.createReadStream(filePath)
    stream.on('data', (data) => hash.update(data))
    stream.on('end', () => resolve(hash.digest('hex')))
    stream.on('error', reject)
  })
}

/**
 * Create initial transfer state for a new file transfer.
 */
export function createTransferState(
  transferId: string,
  fileName: string,
  fileSize: number,
  checksum: string
): ActiveFileTransfer {
  const chunkSize = LAN_TRANSFER_CHUNK_SIZE
  const totalChunks = Math.ceil(fileSize / chunkSize)

  return {
    transferId,
    fileName,
    fileSize,
    checksum,
    totalChunks,
    chunkSize,
    bytesSent: 0,
    currentChunk: 0,
    startedAt: Date.now(),
    isCancelled: false,
    abortController: new AbortController()
  }
}

/**
 * Send file_start message to receiver.
 */
export function sendFileStart(ctx: FileTransferContext, transfer: ActiveFileTransfer): void {
  const startMessage: LanFileStartMessage = {
    type: 'file_start',
    transferId: transfer.transferId,
    fileName: transfer.fileName,
    fileSize: transfer.fileSize,
    mimeType: 'application/zip',
    checksum: transfer.checksum,
    totalChunks: transfer.totalChunks,
    chunkSize: transfer.chunkSize
  }
  ctx.sendControlMessage(startMessage)
  logger.info('Sent file_start message')
}

/**
 * Wait for file_start_ack from receiver.
 */
export function waitForFileStartAck(
  ctx: FileTransferContext,
  transferId: string,
  abortSignal?: AbortSignal
): Promise<LanFileStartAckMessage> {
  return new Promise((resolve, reject) => {
    ctx.waitForResponse(
      'file_start_ack',
      DEFAULT_FILE_START_ACK_TIMEOUT_MS,
      (payload) => resolve(payload as LanFileStartAckMessage),
      reject,
      transferId,
      undefined,
      abortSignal
    )
  })
}

/**
 * Wait for file_complete from receiver after all chunks sent.
 */
export function waitForFileComplete(
  ctx: FileTransferContext,
  transferId: string,
  abortSignal?: AbortSignal
): Promise<LanFileCompleteMessage> {
  return new Promise((resolve, reject) => {
    ctx.waitForResponse(
      'file_complete',
      LAN_TRANSFER_COMPLETE_TIMEOUT_MS,
      (payload) => resolve(payload as LanFileCompleteMessage),
      reject,
      transferId,
      undefined,
      abortSignal
    )
  })
}

/**
 * Send file_end message to receiver.
 */
export function sendFileEnd(ctx: FileTransferContext, transferId: string): void {
  const endMessage: LanFileEndMessage = {
    type: 'file_end',
    transferId
  }
  ctx.sendControlMessage(endMessage)
  logger.info('Sent file_end message')
}

/**
 * Stream file chunks to the receiver (v1 streaming mode - no per-chunk acknowledgment).
 */
export async function streamFileChunks(
  socket: Socket,
  filePath: string,
  transfer: ActiveFileTransfer,
  abortSignal: AbortSignal,
  onProgress: (bytesSent: number, chunkIndex: number) => void
): Promise<void> {
  const { chunkSize, transferId } = transfer

  const stream = fs.createReadStream(filePath, { highWaterMark: chunkSize })
  transfer.stream = stream

  let chunkIndex = 0
  let bytesSent = 0

  try {
    for await (const chunk of stream) {
      if (abortSignal.aborted) {
        throw getAbortError(abortSignal, 'Transfer aborted')
      }

      const buffer = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk)
      bytesSent += buffer.length

      // Send chunk as binary frame (v1 streaming) with backpressure handling
      const canContinue = sendBinaryChunk(socket, transferId, chunkIndex, buffer)
      if (!canContinue) {
        await waitForSocketDrain(socket, abortSignal)
      }

      // Update progress
      transfer.bytesSent = bytesSent
      transfer.currentChunk = chunkIndex

      onProgress(bytesSent, chunkIndex)
      chunkIndex++
    }

    logger.info(`File streaming completed: ${chunkIndex} chunks sent`)
  } catch (error) {
    logger.error('File streaming failed', error as Error)
    throw error
  }
}

/**
 * Abort an active transfer and clean up resources.
 */
export function abortTransfer(transfer: ActiveFileTransfer | undefined, error: Error): void {
  if (!transfer) {
    return
  }

  transfer.isCancelled = true
  if (!transfer.abortController.signal.aborted) {
    transfer.abortController.abort(error)
  }
  if (transfer.stream && !transfer.stream.destroyed) {
    transfer.stream.destroy(error)
  }
}

/**
 * Clean up transfer resources without error.
 */
export function cleanupTransfer(transfer: ActiveFileTransfer | undefined): void {
  if (!transfer) {
    return
  }

  if (!transfer.abortController.signal.aborted) {
    transfer.abortController.abort()
  }
  if (transfer.stream && !transfer.stream.destroyed) {
    transfer.stream.destroy()
  }
}

/**
 * Format bytes into human-readable size string.
 */
export function formatFileSize(bytes: number): string {
  if (bytes === 0) return '0 B'
  const k = 1024
  const sizes = ['B', 'KB', 'MB', 'GB']
  const i = Math.floor(Math.log(bytes) / Math.log(k))
  return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
}
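Pieced together from the exports above, the sender-side happy path roughly follows the sketch below. The sendZip helper is hypothetical; the real client service also wires progress events, cancellation, and error recovery around these calls:

import { randomUUID } from 'node:crypto'
import type { Socket } from 'node:net'

import type { FileTransferContext } from '../types'
import {
  calculateFileChecksum,
  cleanupTransfer,
  createTransferState,
  sendFileEnd,
  sendFileStart,
  streamFileChunks,
  validateFile,
  waitForFileComplete,
  waitForFileStartAck
} from './fileTransfer'

// Sketch of the sender-side happy path; cancellation and error recovery omitted.
async function sendZip(ctx: FileTransferContext, socket: Socket, filePath: string): Promise<void> {
  const { stats, fileName } = await validateFile(filePath)
  const checksum = await calculateFileChecksum(filePath)
  const transfer = createTransferState(randomUUID(), fileName, stats.size, checksum)

  // Announce the transfer and wait for the receiver to accept it
  sendFileStart(ctx, transfer)
  await waitForFileStartAck(ctx, transfer.transferId, transfer.abortController.signal)

  // Stream all chunks without per-chunk acks (v1 streaming mode)
  await streamFileChunks(socket, filePath, transfer, transfer.abortController.signal, () => {})

  // Signal the end, then wait for the receiver's checksum verification
  sendFileEnd(ctx, transfer.transferId)
  await waitForFileComplete(ctx, transfer.transferId, transfer.abortController.signal)
  cleanupTransfer(transfer)
}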
@ -1,22 +0,0 @@
export {
  buildHandshakeMessage,
  createDataHandler,
  getAbortError,
  HANDSHAKE_PROTOCOL_VERSION,
  pickHost,
  sendTestPing,
  waitForSocketDrain
} from './connection'
export {
  abortTransfer,
  calculateFileChecksum,
  cleanupTransfer,
  createTransferState,
  formatFileSize,
  sendFileEnd,
  sendFileStart,
  streamFileChunks,
  validateFile,
  waitForFileComplete,
  waitForFileStartAck
} from './fileTransfer'
@ -1,21 +0,0 @@
/**
 * LAN Transfer Client Module
 *
 * Protocol: v1.0 (streaming mode)
 *
 * Features:
 * - Binary frame format for file chunks (no base64 overhead)
 * - Streaming mode (no per-chunk acknowledgment)
 * - JSON messages for control flow (handshake, file_start, file_end, etc.)
 * - Global timeout protection
 * - Backpressure handling
 *
 * Binary Frame Format:
 * ┌──────────┬──────────┬──────────┬───────────────┬──────────────┬────────────┬───────────┐
 * │ Magic    │ TotalLen │ Type     │ TransferId Len│ TransferId   │ ChunkIdx   │ Data      │
 * │ 0x43 0x53│ (4B BE)  │ 0x01     │ (2B BE)       │ (variable)   │ (4B BE)    │ (raw)     │
 * └──────────┴──────────┴──────────┴───────────────┴──────────────┴────────────┴───────────┘
 */

export { HANDSHAKE_PROTOCOL_VERSION, lanTransferClientService } from './LanTransferClientService'
export type { ActiveFileTransfer, ConnectionContext, FileTransferContext, PendingResponse } from './types'
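Reading the frame format concretely: for a transferId of 'abc' (3 UTF-8 bytes) and a 5-byte chunk, TotalLen = 1 (type) + 2 (id length) + 3 (id) + 4 (chunk index) + 5 (data) = 15, so the whole frame is 2 + 4 + 15 = 21 bytes. A hypothetical receiver-side parser under that layout might look like this sketch; it assumes one complete frame is already buffered and is not code shipped in this module:

// Sketch: decode one complete FILE_CHUNK frame.
function parseFileChunkFrame(frame: Buffer): { transferId: string; chunkIndex: number; data: Buffer } {
  if (frame[0] !== 0x43 || frame[1] !== 0x53) throw new Error('Bad magic')
  const totalLen = frame.readUInt32BE(2) // covers everything after the TotalLen field
  const tidLen = frame.readUInt16BE(7) // offset 7 = magic(2) + totalLen(4) + type(1)
  const transferId = frame.toString('utf8', 9, 9 + tidLen)
  const chunkIndex = frame.readUInt32BE(9 + tidLen)
  const data = frame.subarray(13 + tidLen, 6 + totalLen) // frame size = 6 + totalLen
  return { transferId, chunkIndex, data }
}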
@ -1,144 +0,0 @@
import type { PendingResponse } from './types'

/**
 * Manages pending response handlers for awaiting control messages.
 * Handles timeouts, abort signals, and cleanup.
 */
export class ResponseManager {
  private pendingResponses = new Map<string, PendingResponse>()
  private onTimeout?: () => void

  /**
   * Set a callback to be called when a response times out.
   * Typically used to trigger disconnect on timeout.
   */
  setTimeoutCallback(callback: () => void): void {
    this.onTimeout = callback
  }

  /**
   * Build a composite key for identifying pending responses.
   */
  buildResponseKey(type: string, transferId?: string, chunkIndex?: number): string {
    const parts = [type]
    if (transferId !== undefined) parts.push(transferId)
    if (chunkIndex !== undefined) parts.push(String(chunkIndex))
    return parts.join(':')
  }

  /**
   * Register a response listener with timeout and optional abort signal.
   */
  waitForResponse(
    type: string,
    timeoutMs: number,
    resolve: (payload: unknown) => void,
    reject: (error: Error) => void,
    transferId?: string,
    chunkIndex?: number,
    abortSignal?: AbortSignal
  ): void {
    const responseKey = this.buildResponseKey(type, transferId, chunkIndex)

    // Clear any existing response with the same key
    this.clearPendingResponse(responseKey)

    const timeoutHandle = setTimeout(() => {
      this.clearPendingResponse(responseKey)
      const error = new Error(`Timeout waiting for ${type}`)
      reject(error)
      this.onTimeout?.()
    }, timeoutMs)

    const pending: PendingResponse = {
      type,
      transferId,
      chunkIndex,
      resolve,
      reject,
      timeoutHandle,
      abortSignal
    }

    if (abortSignal) {
      const abortListener = () => {
        this.clearPendingResponse(responseKey)
        reject(this.getAbortError(abortSignal, `Aborted while waiting for ${type}`))
      }
      pending.abortListener = abortListener
      abortSignal.addEventListener('abort', abortListener, { once: true })
    }

    this.pendingResponses.set(responseKey, pending)
  }

  /**
   * Try to resolve a pending response by type and optional identifiers.
   * Returns true if a matching response was found and resolved.
   */
  tryResolve(type: string, payload: unknown, transferId?: string, chunkIndex?: number): boolean {
    const responseKey = this.buildResponseKey(type, transferId, chunkIndex)
    const pendingResponse = this.pendingResponses.get(responseKey)

    if (pendingResponse) {
      const resolver = pendingResponse.resolve
      this.clearPendingResponse(responseKey)
      resolver(payload)
      return true
    }

    return false
  }

  /**
   * Clear a single pending response by key, or all responses if no key provided.
   */
  clearPendingResponse(key?: string): void {
    if (key) {
      const pending = this.pendingResponses.get(key)
      if (pending?.timeoutHandle) {
        clearTimeout(pending.timeoutHandle)
      }
      if (pending?.abortSignal && pending.abortListener) {
        pending.abortSignal.removeEventListener('abort', pending.abortListener)
      }
      this.pendingResponses.delete(key)
    } else {
      // Clear all pending responses
      for (const pending of this.pendingResponses.values()) {
        if (pending.timeoutHandle) {
          clearTimeout(pending.timeoutHandle)
        }
        if (pending.abortSignal && pending.abortListener) {
          pending.abortSignal.removeEventListener('abort', pending.abortListener)
        }
      }
      this.pendingResponses.clear()
    }
  }

  /**
   * Reject all pending responses with the given error.
   */
  rejectAll(error: Error): void {
    for (const key of Array.from(this.pendingResponses.keys())) {
      const pending = this.pendingResponses.get(key)
      this.clearPendingResponse(key)
      pending?.reject(error)
    }
  }

  /**
   * Get the abort error from an abort signal, or create a fallback error.
   */
  getAbortError(signal: AbortSignal, fallbackMessage: string): Error {
    const reason = (signal as AbortSignal & { reason?: unknown }).reason
    if (reason instanceof Error) {
      return reason
    }
    if (typeof reason === 'string' && reason.length > 0) {
      return new Error(reason)
    }
    return new Error(fallbackMessage)
  }
}
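The intended call pattern, sketched below, is that the sender registers a waiter before emitting a request, while the socket's message dispatcher calls tryResolve for every parsed control message. Both helper functions here are illustrative assumptions:

import { ResponseManager } from './responseManager'

const responses = new ResponseManager()

// Waiting side: wrap the callback API in a promise keyed on type + transferId.
function waitForFileStartAck(transferId: string): Promise<unknown> {
  return new Promise((resolve, reject) => {
    responses.waitForResponse('file_start_ack', 30_000, resolve, reject, transferId)
  })
}

// Dispatch side: call tryResolve for every parsed control message;
// unmatched messages simply return false and can be handled elsewhere.
function onControlMessage(message: { type: string; transferId?: string }): void {
  responses.tryResolve(message.type, message, message.transferId)
}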
@@ -1,65 +0,0 @@
-import type * as fs from 'node:fs'
-import type { Socket } from 'node:net'
-
-import type { LanClientEvent, LocalTransferPeer } from '@shared/config/types'
-
-/**
- * Pending response handler for awaiting control messages
- */
-export type PendingResponse = {
-  type: string
-  transferId?: string
-  chunkIndex?: number
-  resolve: (payload: unknown) => void
-  reject: (error: Error) => void
-  timeoutHandle?: NodeJS.Timeout
-  abortSignal?: AbortSignal
-  abortListener?: () => void
-}
-
-/**
- * Active file transfer state tracking
- */
-export type ActiveFileTransfer = {
-  transferId: string
-  fileName: string
-  fileSize: number
-  checksum: string
-  totalChunks: number
-  chunkSize: number
-  bytesSent: number
-  currentChunk: number
-  startedAt: number
-  stream?: fs.ReadStream
-  isCancelled: boolean
-  abortController: AbortController
-}
-
-/**
- * Context interface for connection handlers
- * Provides access to service methods without circular dependencies
- */
-export type ConnectionContext = {
-  socket: Socket | null
-  currentPeer?: LocalTransferPeer
-  sendControlMessage: (message: Record<string, unknown>) => void
-  broadcastClientEvent: (event: LanClientEvent) => void
-}
-
-/**
- * Context interface for file transfer handlers
- * Extends connection context with transfer-specific methods
- */
-export type FileTransferContext = ConnectionContext & {
-  activeTransfer?: ActiveFileTransfer
-  setActiveTransfer: (transfer: ActiveFileTransfer | undefined) => void
-  waitForResponse: (
-    type: string,
-    timeoutMs: number,
-    resolve: (payload: unknown) => void,
-    reject: (error: Error) => void,
-    transferId?: string,
-    chunkIndex?: number,
-    abortSignal?: AbortSignal
-  ) => void
-}
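The callback-style waitForResponse in the deleted FileTransferContext type is straightforward to wrap in a Promise at call sites. A hedged sketch of that usage; the wrapper name and the 'chunk-ack' message type are assumptions for illustration, not code from this repo:

// Sketch: promisify FileTransferContext.waitForResponse
function awaitResponse(ctx: FileTransferContext, type: string, timeoutMs: number, transferId?: string) {
  return new Promise<unknown>((resolve, reject) => {
    ctx.waitForResponse(type, timeoutMs, resolve, reject, transferId)
  })
}

// e.g. wait for a per-chunk acknowledgement before streaming the next chunk:
// const ack = await awaitResponse(ctx, 'chunk-ack', 30_000, transfer.transferId)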
@@ -128,8 +128,8 @@ export class CallBackServer {
     })
 
     return new Promise<http.Server>((resolve, reject) => {
-      server.listen(port, '127.0.0.1', () => {
-        logger.info(`OAuth callback server listening on 127.0.0.1:${port}`)
+      server.listen(port, () => {
+        logger.info(`OAuth callback server listening on port ${port}`)
         resolve(server)
       })
 
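For context on the listen change above: passing '127.0.0.1' as the host binds the OAuth callback server to the loopback interface only, while omitting it binds to all interfaces and exposes the port to the local network. A minimal standalone sketch of the difference (port number is illustrative):

import http from 'node:http'

const server = http.createServer((_req, res) => res.end('ok'))

// Reachable only from the local machine (the behavior on main)
server.listen(12345, '127.0.0.1')

// Without a host argument the server accepts connections on all interfaces,
// so other hosts on the network could reach the callback URL:
// server.listen(12345)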
@@ -1,9 +1,7 @@
 import type { Client } from '@libsql/client'
 import { createClient } from '@libsql/client'
 import { loggerService } from '@logger'
-import { DATA_PATH } from '@main/config'
 import Embeddings from '@main/knowledge/embedjs/embeddings/Embeddings'
-import { makeSureDirExists } from '@main/utils'
 import type {
   AddMemoryOptions,
   AssistantMessage,
@@ -15,7 +13,6 @@ import type {
 } from '@types'
 import crypto from 'crypto'
 import { app } from 'electron'
-import fs from 'fs'
 import path from 'path'
 
 import { MemoryQueries } from './queries'
@@ -74,21 +71,6 @@ export class MemoryService {
     return MemoryService.instance
   }
 
-  /**
-   * Migrate the memory database from the old path to the new path
-   * If the old memory database exists, rename it to the new path
-   */
-  public migrateMemoryDb(): void {
-    const oldMemoryDbPath = path.join(app.getPath('userData'), 'memories.db')
-    const memoryDbPath = path.join(DATA_PATH, 'Memory', 'memories.db')
-
-    makeSureDirExists(path.dirname(memoryDbPath))
-
-    if (fs.existsSync(oldMemoryDbPath)) {
-      fs.renameSync(oldMemoryDbPath, memoryDbPath)
-    }
-  }
-
   /**
    * Initialize the database connection and create tables
    */
@@ -98,12 +80,11 @@ export class MemoryService {
     }
 
     try {
-      const memoryDbPath = path.join(DATA_PATH, 'Memory', 'memories.db')
-
-      makeSureDirExists(path.dirname(memoryDbPath))
+      const userDataPath = app.getPath('userData')
+      const dbPath = path.join(userDataPath, 'memories.db')
 
       this.db = createClient({
-        url: `file:${memoryDbPath}`,
+        url: `file:${dbPath}`,
         intMode: 'number'
       })
 
@@ -187,13 +168,12 @@ export class MemoryService {
 
       // Generate embedding if model is configured
       let embedding: number[] | null = null
-      const embeddingModel = this.config?.embeddingModel
-
-      if (embeddingModel) {
+      const embedderApiClient = this.config?.embedderApiClient
+      if (embedderApiClient) {
         try {
           embedding = await this.generateEmbedding(trimmedMemory)
           logger.debug(
-            `Generated embedding for restored memory with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})`
+            `Generated embedding for restored memory with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})`
           )
         } catch (error) {
           logger.error('Failed to generate embedding for restored memory:', error as Error)
@@ -231,11 +211,11 @@ export class MemoryService {
 
     // Generate embedding if model is configured
     let embedding: number[] | null = null
-    if (this.config?.embeddingModel) {
+    if (this.config?.embedderApiClient) {
       try {
         embedding = await this.generateEmbedding(trimmedMemory)
         logger.debug(
-          `Generated embedding with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})`
+          `Generated embedding with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})`
         )
 
         // Check for similar memories using vector similarity
@@ -320,7 +300,7 @@ export class MemoryService {
 
     try {
       // If we have an embedder model configured, use vector search
-      if (this.config?.embeddingModel) {
+      if (this.config?.embedderApiClient) {
         try {
           const queryEmbedding = await this.generateEmbedding(query)
           return await this.hybridSearch(query, queryEmbedding, { limit, userId, agentId, filters })
@@ -517,11 +497,11 @@ export class MemoryService {
 
     // Generate new embedding if model is configured
     let embedding: number[] | null = null
-    if (this.config?.embeddingModel) {
+    if (this.config?.embedderApiClient) {
      try {
        embedding = await this.generateEmbedding(memory)
        logger.debug(
-          `Updated embedding with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})`
+          `Updated embedding with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})`
        )
      } catch (error) {
        logger.error('Failed to generate embedding for update:', error as Error)
@@ -730,22 +710,21 @@ export class MemoryService {
    * Generate embedding for text
    */
   private async generateEmbedding(text: string): Promise<number[]> {
-    if (!this.config?.embeddingModel) {
+    if (!this.config?.embedderApiClient) {
       throw new Error('Embedder model not configured')
     }
 
     try {
       // Initialize embeddings instance if needed
       if (!this.embeddings) {
-        if (!this.config.embeddingApiClient) {
+        if (!this.config.embedderApiClient) {
           throw new Error('Embedder provider not configured')
         }
 
         this.embeddings = new Embeddings({
-          embedApiClient: this.config.embeddingApiClient,
-          dimensions: this.config.embeddingDimensions
+          embedApiClient: this.config.embedderApiClient,
+          dimensions: this.config.embedderDimensions
         })
 
         await this.embeddings.init()
       }
 
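The embeddingModel/embedderApiClient rename changes only the guard; every branch above follows the same pattern: generate an embedding when an embedder is configured, otherwise carry null and let callers fall back to non-vector search. A condensed sketch of that pattern, with illustrative types rather than the repo's real config shape:

// Sketch: embed only when an embedder is configured
async function maybeEmbed(
  text: string,
  config: { embedderApiClient?: unknown } | undefined,
  generate: (t: string) => Promise<number[]>
): Promise<number[] | null> {
  if (!config?.embedderApiClient) return null // callers fall back to keyword search
  try {
    return await generate(text)
  } catch {
    return null // an embedding failure must not block the write path
  }
}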
@@ -1,28 +1,9 @@
-import { configManager } from '@main/services/ConfigManager'
-import { execFileSync, spawn } from 'child_process'
-import { EventEmitter } from 'events'
+import { execFileSync } from 'child_process'
 import fs from 'fs'
 import path from 'path'
 import { beforeEach, describe, expect, it, vi } from 'vitest'
 
-import {
-  autoDiscoverGitBash,
-  findCommandInShellEnv,
-  findExecutable,
-  findGitBash,
-  validateGitBashPath
-} from '../process'
-
-// Mock configManager
-vi.mock('@main/services/ConfigManager', () => ({
-  ConfigKeys: {
-    GitBashPath: 'gitBashPath'
-  },
-  configManager: {
-    get: vi.fn(),
-    set: vi.fn()
-  }
-}))
+import { findExecutable, findGitBash, validateGitBashPath } from '../process'
 
 // Mock dependencies
 vi.mock('child_process')
@@ -714,525 +695,4 @@ describe.skipIf(process.platform !== 'win32')('process utilities', () => {
       })
     })
   })
-
-  describe('autoDiscoverGitBash', () => {
-    const originalEnvVar = process.env.CLAUDE_CODE_GIT_BASH_PATH
-
-    beforeEach(() => {
-      vi.mocked(configManager.get).mockReset()
-      vi.mocked(configManager.set).mockReset()
-      delete process.env.CLAUDE_CODE_GIT_BASH_PATH
-    })
-
-    afterEach(() => {
-      // Restore original environment variable
-      if (originalEnvVar !== undefined) {
-        process.env.CLAUDE_CODE_GIT_BASH_PATH = originalEnvVar
-      } else {
-        delete process.env.CLAUDE_CODE_GIT_BASH_PATH
-      }
-    })
-
-    /**
-     * Helper to mock fs.existsSync with a set of valid paths
-     */
-    const mockExistingPaths = (...validPaths: string[]) => {
-      vi.mocked(fs.existsSync).mockImplementation((p) => validPaths.includes(p as string))
-    }
-
-    describe('with no existing config path', () => {
-      it('should discover and persist Git Bash path when not configured', () => {
-        const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe'
-        const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe'
-
-        vi.mocked(configManager.get).mockReturnValue(undefined)
-        process.env.ProgramFiles = 'C:\\Program Files'
-        mockExistingPaths(gitPath, bashPath)
-
-        const result = autoDiscoverGitBash()
-
-        expect(result).toBe(bashPath)
-        expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath)
-      })
-
-      it('should return null and not persist when Git Bash is not found', () => {
-        vi.mocked(configManager.get).mockReturnValue(undefined)
-        vi.mocked(fs.existsSync).mockReturnValue(false)
-        vi.mocked(execFileSync).mockImplementation(() => {
-          throw new Error('Not found')
-        })
-
-        const result = autoDiscoverGitBash()
-
-        expect(result).toBeNull()
-        expect(configManager.set).not.toHaveBeenCalled()
-      })
-    })
-
-    describe('environment variable precedence', () => {
-      it('should use env var over valid config path', () => {
-        const envPath = 'C:\\EnvGit\\bin\\bash.exe'
-        const configPath = 'C:\\ConfigGit\\bin\\bash.exe'
-
-        process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath
-        vi.mocked(configManager.get).mockReturnValue(configPath)
-        mockExistingPaths(envPath, configPath)
-
-        const result = autoDiscoverGitBash()
-
-        // Env var should take precedence
-        expect(result).toBe(envPath)
-        // Should not persist env var path (it's a runtime override)
-        expect(configManager.set).not.toHaveBeenCalled()
-      })
-
-      it('should fall back to config path when env var is invalid', () => {
-        const envPath = 'C:\\Invalid\\bash.exe'
-        const configPath = 'C:\\ConfigGit\\bin\\bash.exe'
-
-        process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath
-        vi.mocked(configManager.get).mockReturnValue(configPath)
-        // Env path is invalid (doesn't exist), only config path exists
-        mockExistingPaths(configPath)
-
-        const result = autoDiscoverGitBash()
-
-        // Should fall back to config path
-        expect(result).toBe(configPath)
-        expect(configManager.set).not.toHaveBeenCalled()
-      })
-
-      it('should fall back to auto-discovery when both env var and config are invalid', () => {
-        const envPath = 'C:\\InvalidEnv\\bash.exe'
-        const configPath = 'C:\\InvalidConfig\\bash.exe'
-        const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe'
-        const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe'
-
-        process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath
-        process.env.ProgramFiles = 'C:\\Program Files'
-        vi.mocked(configManager.get).mockReturnValue(configPath)
-        // Both env and config paths are invalid, only standard Git exists
-        mockExistingPaths(gitPath, discoveredPath)
-
-        const result = autoDiscoverGitBash()
-
-        expect(result).toBe(discoveredPath)
-        expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath)
-      })
-    })
-
-    describe('with valid existing config path', () => {
-      it('should validate and return existing path without re-discovering', () => {
-        const existingPath = 'C:\\CustomGit\\bin\\bash.exe'
-
-        vi.mocked(configManager.get).mockReturnValue(existingPath)
-        mockExistingPaths(existingPath)
-
-        const result = autoDiscoverGitBash()
-
-        expect(result).toBe(existingPath)
-        // Should not call findGitBash or persist again
-        expect(configManager.set).not.toHaveBeenCalled()
-        // Should not call execFileSync (which findGitBash would use for discovery)
-        expect(execFileSync).not.toHaveBeenCalled()
-      })
-
-      it('should not override existing valid config with auto-discovery', () => {
-        const existingPath = 'C:\\CustomGit\\bin\\bash.exe'
-        const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe'
-
-        vi.mocked(configManager.get).mockReturnValue(existingPath)
-        mockExistingPaths(existingPath, discoveredPath)
-
-        const result = autoDiscoverGitBash()
-
-        expect(result).toBe(existingPath)
-        expect(configManager.set).not.toHaveBeenCalled()
-      })
-    })
-
-    describe('with invalid existing config path', () => {
-      it('should attempt auto-discovery when existing path does not exist', () => {
-        const existingPath = 'C:\\NonExistent\\bin\\bash.exe'
-        const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe'
-        const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe'
-
-        vi.mocked(configManager.get).mockReturnValue(existingPath)
-        process.env.ProgramFiles = 'C:\\Program Files'
-        // Invalid path doesn't exist, but Git is installed at standard location
-        mockExistingPaths(gitPath, discoveredPath)
-
-        const result = autoDiscoverGitBash()
-
-        // Should discover and return the new path
-        expect(result).toBe(discoveredPath)
-        // Should persist the discovered path (overwrites invalid)
-        expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath)
-      })
-
-      it('should attempt auto-discovery when existing path is not bash.exe', () => {
-        const existingPath = 'C:\\CustomGit\\bin\\git.exe'
-        const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe'
-        const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe'
-
-        vi.mocked(configManager.get).mockReturnValue(existingPath)
-        process.env.ProgramFiles = 'C:\\Program Files'
-        // Invalid path exists but is not bash.exe (validation will fail)
-        // Git is installed at standard location
-        mockExistingPaths(existingPath, gitPath, discoveredPath)
-
-        const result = autoDiscoverGitBash()
-
-        // Should discover and return the new path
-        expect(result).toBe(discoveredPath)
-        // Should persist the discovered path (overwrites invalid)
-        expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath)
-      })
-
-      it('should return null when existing path is invalid and discovery fails', () => {
-        const existingPath = 'C:\\NonExistent\\bin\\bash.exe'
-
-        vi.mocked(configManager.get).mockReturnValue(existingPath)
-        vi.mocked(fs.existsSync).mockReturnValue(false)
-        vi.mocked(execFileSync).mockImplementation(() => {
-          throw new Error('Not found')
-        })
-
-        const result = autoDiscoverGitBash()
-
-        // Both validation and discovery failed
-        expect(result).toBeNull()
-        // Should not persist when discovery fails
-        expect(configManager.set).not.toHaveBeenCalled()
-      })
-    })
-
-    describe('config persistence verification', () => {
-      it('should persist discovered path with correct config key', () => {
-        const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe'
-        const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe'
-
-        vi.mocked(configManager.get).mockReturnValue(undefined)
-        process.env.ProgramFiles = 'C:\\Program Files'
-        mockExistingPaths(gitPath, bashPath)
-
-        autoDiscoverGitBash()
-
-        // Verify the exact call to configManager.set
-        expect(configManager.set).toHaveBeenCalledTimes(1)
-        expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath)
-      })
-
-      it('should persist on each discovery when config remains undefined', () => {
-        const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe'
-        const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe'
-
-        vi.mocked(configManager.get).mockReturnValue(undefined)
-        process.env.ProgramFiles = 'C:\\Program Files'
-        mockExistingPaths(gitPath, bashPath)
-
-        autoDiscoverGitBash()
-        autoDiscoverGitBash()
-
-        // Each call discovers and persists since config remains undefined (mocked)
-        expect(configManager.set).toHaveBeenCalledTimes(2)
-      })
-    })
-
-    describe('real-world scenarios', () => {
-      it('should discover and persist standard Git for Windows installation', () => {
-        const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe'
-        const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe'
-
-        vi.mocked(configManager.get).mockReturnValue(undefined)
-        process.env.ProgramFiles = 'C:\\Program Files'
-        mockExistingPaths(gitPath, bashPath)
-
-        const result = autoDiscoverGitBash()
-
-        expect(result).toBe(bashPath)
-        expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath)
-      })
-
-      it('should discover portable Git via where.exe and persist', () => {
-        const gitPath = 'D:\\PortableApps\\Git\\bin\\git.exe'
-        const bashPath = 'D:\\PortableApps\\Git\\bin\\bash.exe'
-
-        vi.mocked(configManager.get).mockReturnValue(undefined)
-
-        vi.mocked(fs.existsSync).mockImplementation((p) => {
-          const pathStr = p?.toString() || ''
-          // Common git paths don't exist
-          if (pathStr.includes('Program Files\\Git\\cmd\\git.exe')) return false
-          if (pathStr.includes('Program Files (x86)\\Git\\cmd\\git.exe')) return false
-          // Portable bash path exists
-          if (pathStr === bashPath) return true
-          return false
-        })
-
-        vi.mocked(execFileSync).mockReturnValue(gitPath)
-
-        const result = autoDiscoverGitBash()
-
-        expect(result).toBe(bashPath)
-        expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath)
-      })
-
-      it('should respect user-configured path over auto-discovery', () => {
-        const userConfiguredPath = 'D:\\MyGit\\bin\\bash.exe'
-        const systemPath = 'C:\\Program Files\\Git\\bin\\bash.exe'
-
-        vi.mocked(configManager.get).mockReturnValue(userConfiguredPath)
-        mockExistingPaths(userConfiguredPath, systemPath)
-
-        const result = autoDiscoverGitBash()
-
-        expect(result).toBe(userConfiguredPath)
-        expect(configManager.set).not.toHaveBeenCalled()
-        // Verify findGitBash was not called for discovery
-        expect(execFileSync).not.toHaveBeenCalled()
-      })
-    })
-  })
-})
-
-/**
- * Helper to create a mock child process for spawn
- */
-function createMockChildProcess() {
-  const mockChild = new EventEmitter() as EventEmitter & {
-    stdout: EventEmitter
-    stderr: EventEmitter
-    kill: ReturnType<typeof vi.fn>
-  }
-  mockChild.stdout = new EventEmitter()
-  mockChild.stderr = new EventEmitter()
-  mockChild.kill = vi.fn()
-  return mockChild
-}
-
-describe('findCommandInShellEnv', () => {
-  beforeEach(() => {
-    vi.clearAllMocks()
-    // Reset path.isAbsolute to real implementation for these tests
-    vi.mocked(path.isAbsolute).mockImplementation((p) => p.startsWith('/') || /^[A-Z]:/i.test(p))
-  })
-
-  describe('command name validation', () => {
-    it('should reject empty command name', async () => {
-      const result = await findCommandInShellEnv('', {})
-      expect(result).toBeNull()
-      expect(spawn).not.toHaveBeenCalled()
-    })
-
-    it('should reject command names with shell metacharacters', async () => {
-      const maliciousCommands = [
-        'npx; rm -rf /',
-        'npx && malicious',
-        'npx | cat /etc/passwd',
-        'npx`whoami`',
-        '$(whoami)',
-        'npx\nmalicious'
-      ]
-
-      for (const cmd of maliciousCommands) {
-        const result = await findCommandInShellEnv(cmd, {})
-        expect(result).toBeNull()
-        expect(spawn).not.toHaveBeenCalled()
-      }
-    })
-
-    it('should reject command names starting with hyphen', async () => {
-      const result = await findCommandInShellEnv('-npx', {})
-      expect(result).toBeNull()
-      expect(spawn).not.toHaveBeenCalled()
-    })
-
-    it('should reject path traversal attempts', async () => {
-      const pathTraversalCommands = ['../npx', '../../malicious', 'foo/bar', 'foo\\bar']
-
-      for (const cmd of pathTraversalCommands) {
-        const result = await findCommandInShellEnv(cmd, {})
-        expect(result).toBeNull()
-        expect(spawn).not.toHaveBeenCalled()
-      }
-    })
-
-    it('should reject command names exceeding max length', async () => {
-      const longCommand = 'a'.repeat(129)
-      const result = await findCommandInShellEnv(longCommand, {})
-      expect(result).toBeNull()
-      expect(spawn).not.toHaveBeenCalled()
-    })
-
-    it('should accept valid command names', async () => {
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      // Don't await - just start the call
-      const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' })
-
-      // Simulate command not found
-      mockChild.emit('close', 1)
-
-      const result = await resultPromise
-      expect(result).toBeNull()
-      expect(spawn).toHaveBeenCalled()
-    })
-
-    it('should accept command names with underscores and hyphens', async () => {
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      const resultPromise = findCommandInShellEnv('my_command-name', { PATH: '/usr/bin' })
-      mockChild.emit('close', 1)
-
-      await resultPromise
-      expect(spawn).toHaveBeenCalled()
-    })
-
-    it('should accept command names at max length (128 chars)', async () => {
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      const maxLengthCommand = 'a'.repeat(128)
-      const resultPromise = findCommandInShellEnv(maxLengthCommand, { PATH: '/usr/bin' })
-      mockChild.emit('close', 1)
-
-      await resultPromise
-      expect(spawn).toHaveBeenCalled()
-    })
-  })
-
-  describe.skipIf(process.platform === 'win32')('Unix/macOS behavior', () => {
-    it('should find command and return absolute path', async () => {
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' })
-
-      // Simulate successful command -v output
-      mockChild.stdout.emit('data', '/usr/local/bin/npx\n')
-      mockChild.emit('close', 0)
-
-      const result = await resultPromise
-      expect(result).toBe('/usr/local/bin/npx')
-      expect(spawn).toHaveBeenCalledWith('/bin/sh', ['-c', 'command -v "$1"', '--', 'npx'], expect.any(Object))
-    })
-
-    it('should return null for non-absolute paths (aliases/builtins)', async () => {
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      const resultPromise = findCommandInShellEnv('cd', { PATH: '/usr/bin' })
-
-      // Simulate builtin output (just command name)
-      mockChild.stdout.emit('data', 'cd\n')
-      mockChild.emit('close', 0)
-
-      const result = await resultPromise
-      expect(result).toBeNull()
-    })
-
-    it('should return null when command not found', async () => {
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      const resultPromise = findCommandInShellEnv('nonexistent', { PATH: '/usr/bin' })
-
-      // Simulate command not found (exit code 1)
-      mockChild.emit('close', 1)
-
-      const result = await resultPromise
-      expect(result).toBeNull()
-    })
-
-    it('should handle spawn errors gracefully', async () => {
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' })
-
-      // Simulate spawn error
-      mockChild.emit('error', new Error('spawn failed'))
-
-      const result = await resultPromise
-      expect(result).toBeNull()
-    })
-
-    it('should handle timeout gracefully', async () => {
-      vi.useFakeTimers()
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' })
-
-      // Fast-forward past timeout (5000ms)
-      vi.advanceTimersByTime(6000)
-
-      const result = await resultPromise
-      expect(result).toBeNull()
-      expect(mockChild.kill).toHaveBeenCalledWith('SIGKILL')
-
-      vi.useRealTimers()
-    })
-  })
-
-  describe.skipIf(process.platform !== 'win32')('Windows behavior', () => {
-    it('should find .exe files via where command', async () => {
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' })
-
-      // Simulate where output
-      mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.exe\r\n')
-      mockChild.emit('close', 0)
-
-      const result = await resultPromise
-      expect(result).toBe('C:\\Program Files\\nodejs\\npx.exe')
-      expect(spawn).toHaveBeenCalledWith('where', ['npx'], expect.any(Object))
-    })
-
-    it('should reject .cmd files on Windows', async () => {
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' })
-
-      // Simulate where output with only .cmd file
-      mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.cmd\r\n')
-      mockChild.emit('close', 0)
-
-      const result = await resultPromise
-      expect(result).toBeNull()
-    })
-
-    it('should prefer .exe over .cmd when both exist', async () => {
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' })
-
-      // Simulate where output with both .cmd and .exe
-      mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.cmd\r\nC:\\Program Files\\nodejs\\npx.exe\r\n')
-      mockChild.emit('close', 0)
-
-      const result = await resultPromise
-      expect(result).toBe('C:\\Program Files\\nodejs\\npx.exe')
-    })
-
-    it('should handle spawn errors gracefully', async () => {
-      const mockChild = createMockChildProcess()
-      vi.mocked(spawn).mockReturnValue(mockChild as never)
-
-      const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' })
-
-      // Simulate spawn error
-      mockChild.emit('error', new Error('spawn failed'))
-
-      const result = await resultPromise
-      expect(result).toBeNull()
-    })
-  })
-})
 })
@@ -1,5 +1,4 @@
 import { loggerService } from '@logger'
-import type { GitBashPathInfo, GitBashPathSource } from '@shared/config/constant'
 import { HOME_CHERRY_DIR } from '@shared/config/constant'
 import { execFileSync, spawn } from 'child_process'
 import fs from 'fs'
@@ -7,7 +6,6 @@ import os from 'os'
 import path from 'path'
 
 import { isWin } from '../constant'
-import { ConfigKeys, configManager } from '../services/ConfigManager'
 import { getResourcePath } from '.'
 
 const logger = loggerService.withContext('Utils:Process')
@@ -61,146 +59,7 @@ export async function getBinaryPath(name?: string): Promise<string> {
 
 export async function isBinaryExists(name: string): Promise<boolean> {
   const cmd = await getBinaryPath(name)
-  return fs.existsSync(cmd)
-}
-
-// Timeout for command lookup operations (in milliseconds)
-const COMMAND_LOOKUP_TIMEOUT_MS = 5000
-
-// Regex to validate command names - must start with alphanumeric or underscore, max 128 chars
-const VALID_COMMAND_NAME_REGEX = /^[a-zA-Z0-9_][a-zA-Z0-9_-]{0,127}$/
-
-// Maximum output size to prevent buffer overflow (10KB)
-const MAX_OUTPUT_SIZE = 10240
-
-/**
- * Check if a command is available in the user's login shell environment
- * @param command - Command name to check (e.g., 'npx', 'uvx')
- * @param loginShellEnv - The login shell environment from getLoginShellEnvironment()
- * @returns Full path to the command if found, null otherwise
- */
-export async function findCommandInShellEnv(
-  command: string,
-  loginShellEnv: Record<string, string>
-): Promise<string | null> {
-  // Validate command name to prevent command injection
-  if (!VALID_COMMAND_NAME_REGEX.test(command)) {
-    logger.warn(`Invalid command name '${command}' - must only contain alphanumeric characters, underscore, or hyphen`)
-    return null
-  }
-
-  return new Promise((resolve) => {
-    let resolved = false
-
-    const safeResolve = (value: string | null) => {
-      if (resolved) return
-      resolved = true
-      resolve(value)
-    }
-
-    if (isWin) {
-      // On Windows, use 'where' command
-      const child = spawn('where', [command], {
-        env: loginShellEnv,
-        stdio: ['ignore', 'pipe', 'pipe']
-      })
-
-      let output = ''
-      const timeoutId = setTimeout(() => {
-        if (resolved) return
-        child.kill('SIGKILL')
-        logger.debug(`Timeout checking command '${command}' on Windows`)
-        safeResolve(null)
-      }, COMMAND_LOOKUP_TIMEOUT_MS)
-
-      child.stdout.on('data', (data) => {
-        if (output.length < MAX_OUTPUT_SIZE) {
-          output += data.toString()
-        }
-      })
-
-      child.on('close', (code) => {
-        clearTimeout(timeoutId)
-        if (resolved) return
-
-        if (code === 0 && output.trim()) {
-          const paths = output.trim().split(/\r?\n/)
-          // Only accept .exe files on Windows - .cmd/.bat files cannot be executed
-          // with spawn({ shell: false }) which is used by MCP SDK's StdioClientTransport
-          const exePath = paths.find((p) => p.toLowerCase().endsWith('.exe'))
-          if (exePath) {
-            logger.debug(`Found command '${command}' at: ${exePath}`)
-            safeResolve(exePath)
-          } else {
-            logger.debug(`Command '${command}' found but not as .exe (${paths[0]}), treating as not found`)
-            safeResolve(null)
-          }
-        } else {
-          logger.debug(`Command '${command}' not found in shell environment`)
-          safeResolve(null)
-        }
-      })
-
-      child.on('error', (error) => {
-        clearTimeout(timeoutId)
-        if (resolved) return
-        logger.warn(`Error checking command '${command}':`, { error, platform: 'windows' })
-        safeResolve(null)
-      })
-    } else {
-      // Unix/Linux/macOS: use 'command -v' which is POSIX standard
-      // Use /bin/sh for reliability - it's POSIX compliant and always available
-      // This avoids issues with user's custom shell (csh, fish, etc.)
-      // SECURITY: Use positional parameter $1 to prevent command injection
-      const child = spawn('/bin/sh', ['-c', 'command -v "$1"', '--', command], {
-        env: loginShellEnv,
-        stdio: ['ignore', 'pipe', 'pipe']
-      })
-
-      let output = ''
-      const timeoutId = setTimeout(() => {
-        if (resolved) return
-        child.kill('SIGKILL')
-        logger.debug(`Timeout checking command '${command}'`)
-        safeResolve(null)
-      }, COMMAND_LOOKUP_TIMEOUT_MS)
-
-      child.stdout.on('data', (data) => {
-        if (output.length < MAX_OUTPUT_SIZE) {
-          output += data.toString()
-        }
-      })
-
-      child.on('close', (code) => {
-        clearTimeout(timeoutId)
-        if (resolved) return
-
-        if (code === 0 && output.trim()) {
-          const commandPath = output.trim().split('\n')[0]
-
-          // Validate the output is an absolute path (not an alias, function, or builtin)
-          // command -v can return just the command name for aliases/builtins
-          if (path.isAbsolute(commandPath)) {
-            logger.debug(`Found command '${command}' at: ${commandPath}`)
-            safeResolve(commandPath)
-          } else {
-            logger.debug(`Command '${command}' resolved to non-path '${commandPath}', treating as not found`)
-            safeResolve(null)
-          }
-        } else {
-          logger.debug(`Command '${command}' not found in shell environment`)
-          safeResolve(null)
-        }
-      })
-
-      child.on('error', (error) => {
-        clearTimeout(timeoutId)
-        if (resolved) return
-        logger.warn(`Error checking command '${command}':`, { error, platform: 'unix' })
-        safeResolve(null)
-      })
-    }
-  })
+  return await fs.existsSync(cmd)
 }
 
 /**
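The removed Unix branch uses a standard injection-safe pattern: the command name is never interpolated into the shell string; it is passed as a positional parameter after --, so /bin/sh expands it as data ($1) rather than parsing it as syntax. A standalone sketch of the same pattern (POSIX systems only; the example command name is illustrative):

import { spawn } from 'node:child_process'

// Sketch: $1 is substituted by /bin/sh as a value, not re-parsed as shell syntax,
// so a name like 'npx; rm -rf /' cannot smuggle in a second command
const name = 'npx'
const child = spawn('/bin/sh', ['-c', 'command -v "$1"', '--', name], {
  stdio: ['ignore', 'pipe', 'pipe']
})
child.stdout.on('data', (d) => process.stdout.write(d))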
@@ -366,77 +225,3 @@ export function validateGitBashPath(customPath?: string | null): string | null {
   logger.debug('Validated custom Git Bash path', { path: resolved })
   return resolved
 }
-
-/**
- * Auto-discover and persist Git Bash path if not already configured
- * Only called when Git Bash is actually needed
- *
- * Precedence order:
- * 1. CLAUDE_CODE_GIT_BASH_PATH environment variable (highest - runtime override)
- * 2. Configured path from settings (manual or auto)
- * 3. Auto-discovery via findGitBash (only if no valid config exists)
- */
-export function autoDiscoverGitBash(): string | null {
-  if (!isWin) {
-    return null
-  }
-
-  // 1. Check environment variable override first (highest priority)
-  const envOverride = process.env.CLAUDE_CODE_GIT_BASH_PATH
-  if (envOverride) {
-    const validated = validateGitBashPath(envOverride)
-    if (validated) {
-      logger.debug('Using CLAUDE_CODE_GIT_BASH_PATH override', { path: validated })
-      return validated
-    }
-    logger.warn('CLAUDE_CODE_GIT_BASH_PATH provided but path is invalid', { path: envOverride })
-  }
-
-  // 2. Check if a path is already configured
-  const existingPath = configManager.get<string | undefined>(ConfigKeys.GitBashPath)
-  const existingSource = configManager.get<GitBashPathSource | undefined>(ConfigKeys.GitBashPathSource)
-
-  if (existingPath) {
-    const validated = validateGitBashPath(existingPath)
-    if (validated) {
-      return validated
-    }
-    // Existing path is invalid, try to auto-discover
-    logger.warn('Existing Git Bash path is invalid, attempting auto-discovery', {
-      path: existingPath,
-      source: existingSource
-    })
-  }
-
-  // 3. Try to find Git Bash via auto-discovery
-  const discoveredPath = findGitBash()
-  if (discoveredPath) {
-    // Persist the discovered path with 'auto' source
-    configManager.set(ConfigKeys.GitBashPath, discoveredPath)
-    configManager.set(ConfigKeys.GitBashPathSource, 'auto')
-    logger.info('Auto-discovered Git Bash path', { path: discoveredPath })
-  }
-
-  return discoveredPath
-}
-
-/**
- * Get Git Bash path info including source
- * If no path is configured, triggers auto-discovery first
- */
-export function getGitBashPathInfo(): GitBashPathInfo {
-  if (!isWin) {
-    return { path: null, source: null }
-  }
-
-  let path = configManager.get<string | null>(ConfigKeys.GitBashPath) ?? null
-  let source = configManager.get<GitBashPathSource | null>(ConfigKeys.GitBashPathSource) ?? null
-
-  // If no path configured, trigger auto-discovery (handles upgrade from old versions)
-  if (!path) {
-    path = autoDiscoverGitBash()
-    source = path ? 'auto' : null
-  }
-
-  return { path, source }
-}
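Taken together, the removed helpers give callers a single entry point: getGitBashPathInfo resolves the effective path, triggering discovery and persistence lazily on first use. A hedged usage sketch; the ipcMain wiring shown is an assumption about how the handler was registered, not code visible in this diff:

// Sketch: resolve the effective Git Bash path once, lazily
// ipcMain.handle(IpcChannel.System_GetGitBashPathInfo, () => getGitBashPathInfo())
const info = getGitBashPathInfo()
if (info.path) {
  console.log(`Using Git Bash at ${info.path} (source: ${info.source})`)
} else {
  console.warn('Git Bash not found; features that need it stay disabled')
}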
@@ -2,17 +2,9 @@ import type { PermissionUpdate } from '@anthropic-ai/claude-agent-sdk'
 import { electronAPI } from '@electron-toolkit/preload'
 import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
 import type { SpanContext } from '@opentelemetry/api'
-import type { GitBashPathInfo, TerminalConfig, UpgradeChannel } from '@shared/config/constant'
+import type { TerminalConfig, UpgradeChannel } from '@shared/config/constant'
 import type { LogLevel, LogSourceWithContext } from '@shared/config/logger'
-import type {
-  FileChangeEvent,
-  LanClientEvent,
-  LanFileCompleteMessage,
-  LanHandshakeAckMessage,
-  LocalTransferConnectPayload,
-  LocalTransferState,
-  WebviewKeyEvent
-} from '@shared/config/types'
+import type { FileChangeEvent, WebviewKeyEvent } from '@shared/config/types'
 import type { MCPServerLogEntry } from '@shared/config/types'
 import { IpcChannel } from '@shared/IpcChannel'
 import type { Notification } from '@types'
@@ -134,7 +126,6 @@ const api = {
     getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName),
     checkGitBash: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.System_CheckGitBash),
     getGitBashPath: (): Promise<string | null> => ipcRenderer.invoke(IpcChannel.System_GetGitBashPath),
-    getGitBashPathInfo: (): Promise<GitBashPathInfo> => ipcRenderer.invoke(IpcChannel.System_GetGitBashPathInfo),
     setGitBashPath: (newPath: string | null): Promise<boolean> =>
       ipcRenderer.invoke(IpcChannel.System_SetGitBashPath, newPath)
   },
@@ -180,11 +171,7 @@ const api = {
     listS3Files: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_ListS3Files, s3Config),
     deleteS3File: (fileName: string, s3Config: S3Config) =>
       ipcRenderer.invoke(IpcChannel.Backup_DeleteS3File, fileName, s3Config),
-    checkS3Connection: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_CheckS3Connection, s3Config),
-    createLanTransferBackup: (data: string): Promise<string> =>
-      ipcRenderer.invoke(IpcChannel.Backup_CreateLanTransferBackup, data),
-    deleteTempBackup: (filePath: string): Promise<boolean> =>
-      ipcRenderer.invoke(IpcChannel.Backup_DeleteTempBackup, filePath)
+    checkS3Connection: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_CheckS3Connection, s3Config)
   },
   file: {
     select: (options?: OpenDialogOptions): Promise<FileMetadata[] | null> =>
@@ -310,8 +297,7 @@ const api = {
     deleteUser: (userId: string) => ipcRenderer.invoke(IpcChannel.Memory_DeleteUser, userId),
     deleteAllMemoriesForUser: (userId: string) =>
       ipcRenderer.invoke(IpcChannel.Memory_DeleteAllMemoriesForUser, userId),
-    getUsersList: () => ipcRenderer.invoke(IpcChannel.Memory_GetUsersList),
-    migrateMemoryDb: () => ipcRenderer.invoke(IpcChannel.Memory_MigrateMemoryDb)
+    getUsersList: () => ipcRenderer.invoke(IpcChannel.Memory_GetUsersList)
   },
   window: {
     setMinimumSize: (width: number, height: number) =>
@@ -442,7 +428,7 @@ const api = {
       ipcRenderer.invoke(IpcChannel.Nutstore_GetDirectoryContents, token, path)
   },
   searchService: {
-    openSearchWindow: (uid: string, show?: boolean) => ipcRenderer.invoke(IpcChannel.SearchWindow_Open, uid, show),
+    openSearchWindow: (uid: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_Open, uid),
     closeSearchWindow: (uid: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_Close, uid),
     openUrlInSearchWindow: (uid: string, url: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_OpenUrl, uid, url)
   },
@@ -602,32 +588,12 @@ const api = {
     writeContent: (options: WritePluginContentOptions): Promise<PluginResult<void>> =>
       ipcRenderer.invoke(IpcChannel.ClaudeCodePlugin_WriteContent, options)
   },
-  localTransfer: {
-    getState: (): Promise<LocalTransferState> => ipcRenderer.invoke(IpcChannel.LocalTransfer_ListServices),
-    startScan: (): Promise<LocalTransferState> => ipcRenderer.invoke(IpcChannel.LocalTransfer_StartScan),
-    stopScan: (): Promise<LocalTransferState> => ipcRenderer.invoke(IpcChannel.LocalTransfer_StopScan),
-    connect: (payload: LocalTransferConnectPayload): Promise<LanHandshakeAckMessage> =>
-      ipcRenderer.invoke(IpcChannel.LocalTransfer_Connect, payload),
-    disconnect: (): Promise<void> => ipcRenderer.invoke(IpcChannel.LocalTransfer_Disconnect),
-    onServicesUpdated: (callback: (state: LocalTransferState) => void): (() => void) => {
-      const channel = IpcChannel.LocalTransfer_ServicesUpdated
-      const listener = (_: Electron.IpcRendererEvent, state: LocalTransferState) => callback(state)
-      ipcRenderer.on(channel, listener)
-      return () => {
-        ipcRenderer.removeListener(channel, listener)
-      }
-    },
-    onClientEvent: (callback: (event: LanClientEvent) => void): (() => void) => {
-      const channel = IpcChannel.LocalTransfer_ClientEvent
-      const listener = (_: Electron.IpcRendererEvent, event: LanClientEvent) => callback(event)
-      ipcRenderer.on(channel, listener)
-      return () => {
-        ipcRenderer.removeListener(channel, listener)
-      }
-    },
-    sendFile: (filePath: string): Promise<LanFileCompleteMessage> =>
-      ipcRenderer.invoke(IpcChannel.LocalTransfer_SendFile, { filePath }),
-    cancelTransfer: (): Promise<void> => ipcRenderer.invoke(IpcChannel.LocalTransfer_CancelTransfer)
+  webSocket: {
+    start: () => ipcRenderer.invoke(IpcChannel.WebSocket_Start),
+    stop: () => ipcRenderer.invoke(IpcChannel.WebSocket_Stop),
+    status: () => ipcRenderer.invoke(IpcChannel.WebSocket_Status),
+    sendFile: (filePath: string) => ipcRenderer.invoke(IpcChannel.WebSocket_SendFile, filePath),
+    getAllCandidates: () => ipcRenderer.invoke(IpcChannel.WebSocket_GetAllCandidates)
   }
 }
 
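The removed onServicesUpdated/onClientEvent bridges follow the usual preload subscription pattern: register an ipcRenderer listener and return a closure that removes exactly that listener. A hedged renderer-side usage sketch; window.api is how this preload object is typically exposed via contextBridge, but treat the exact surface as an assumption:

// Sketch: subscribe and keep the unsubscribe closure for cleanup
const off = window.api.localTransfer.onServicesUpdated((state) => {
  console.log('LAN peers updated', state)
})
// later, e.g. in a component unmount / effect cleanup:
off() // removes the underlying ipcRenderer listener, avoiding leaks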
@@ -1,38 +0,0 @@
-import { describe, expect, it } from 'vitest'
-
-import { normalizeAzureOpenAIEndpoint } from '../openai/azureOpenAIEndpoint'
-
-describe('normalizeAzureOpenAIEndpoint', () => {
-  it.each([
-    {
-      apiHost: 'https://example.openai.azure.com/openai',
-      expectedEndpoint: 'https://example.openai.azure.com'
-    },
-    {
-      apiHost: 'https://example.openai.azure.com/openai/',
-      expectedEndpoint: 'https://example.openai.azure.com'
-    },
-    {
-      apiHost: 'https://example.openai.azure.com/openai/v1',
-      expectedEndpoint: 'https://example.openai.azure.com'
-    },
-    {
-      apiHost: 'https://example.openai.azure.com/openai/v1/',
-      expectedEndpoint: 'https://example.openai.azure.com'
-    },
-    {
-      apiHost: 'https://example.openai.azure.com',
-      expectedEndpoint: 'https://example.openai.azure.com'
-    },
-    {
-      apiHost: 'https://example.openai.azure.com/',
-      expectedEndpoint: 'https://example.openai.azure.com'
-    },
-    {
-      apiHost: 'https://example.openai.azure.com/OPENAI/V1',
-      expectedEndpoint: 'https://example.openai.azure.com'
-    }
-  ])('strips trailing /openai from $apiHost', ({ apiHost, expectedEndpoint }) => {
-    expect(normalizeAzureOpenAIEndpoint(apiHost)).toBe(expectedEndpoint)
-  })
-})
@@ -10,7 +10,7 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
 import {
   findTokenLimit,
   GEMINI_FLASH_MODEL_REGEX,
-  getModelSupportedReasoningEffortOptions,
+  getThinkModelType,
   isDeepSeekHybridInferenceModel,
   isDoubaoThinkingAutoModel,
   isGPT5SeriesModel,
@@ -33,6 +33,7 @@ import {
   isSupportedThinkingTokenQwenModel,
   isSupportedThinkingTokenZhipuModel,
   isVisionModel,
+  MODEL_SUPPORTED_REASONING_EFFORT,
   ZHIPU_RESULT_TOKENS
 } from '@renderer/config/models'
 import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
@@ -142,10 +143,6 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
       return { thinking: { type: reasoningEffort ? 'enabled' : 'disabled' } }
     }
 
-    if (reasoningEffort === 'default') {
-      return {}
-    }
-
    if (!reasoningEffort) {
      // DeepSeek hybrid inference models, v3.1 and maybe more in the future
      // Different providers control thinking in different ways; handle them all here
@@ -307,15 +304,16 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
     // Grok models/Perplexity models/OpenAI models
     if (isSupportedReasoningEffortModel(model)) {
       // Check whether the model supports the selected option
-      const supportedOptions = getModelSupportedReasoningEffortOptions(model)?.filter((option) => option !== 'default')
-      if (supportedOptions?.includes(reasoningEffort)) {
+      const modelType = getThinkModelType(model)
+      const supportedOptions = MODEL_SUPPORTED_REASONING_EFFORT[modelType]
+      if (supportedOptions.includes(reasoningEffort)) {
         return {
           reasoning_effort: reasoningEffort
         }
       } else {
         // If the option is not supported, fall back to the first supported value
         return {
-          reasoning_effort: supportedOptions?.[0]
+          reasoning_effort: supportedOptions[0]
         }
       }
     }
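Both sides of the hunk above implement the same policy: look up the reasoning-effort options the model supports and, if the requested value is not among them, silently downgrade to the first supported value. A table-driven sketch of that policy; the table entries are illustrative, not the repo's real MODEL_SUPPORTED_REASONING_EFFORT mapping:

// Sketch: clamp a requested reasoning effort to what the model supports
const SUPPORTED: Record<string, string[]> = {
  o3: ['low', 'medium', 'high'], // illustrative entries only
  gpt5: ['minimal', 'low', 'medium', 'high']
}

function clampReasoningEffort(modelType: string, requested: string): string | undefined {
  const options = SUPPORTED[modelType]
  if (!options || options.length === 0) return undefined
  return options.includes(requested) ? requested : options[0]
}

// clampReasoningEffort('o3', 'minimal') => 'low' (first supported value)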
@@ -29,7 +29,6 @@ import { withoutTrailingSlash } from '@renderer/utils/api'
 import { isOllamaProvider } from '@renderer/utils/provider'
 
 import { BaseApiClient } from '../BaseApiClient'
-import { normalizeAzureOpenAIEndpoint } from './azureOpenAIEndpoint'
 
 const logger = loggerService.withContext('OpenAIBaseClient')
 
@@ -70,7 +69,7 @@ export abstract class OpenAIBaseClient<
     const sdk = await this.getSdkInstance()
     const response = (await sdk.request({
       method: 'post',
-      path: '/v1/images/generations',
+      path: '/images/generations',
       signal,
       body: {
         model,
@@ -89,11 +88,7 @@ export abstract class OpenAIBaseClient<
   }
 
   override async getEmbeddingDimensions(model: Model): Promise<number> {
-    let sdk: OpenAI = await this.getSdkInstance()
-    if (isOllamaProvider(this.provider)) {
-      const embedBaseUrl = `${this.provider.apiHost.replace(/(\/(api|v1))\/?$/, '')}/v1`
-      sdk = sdk.withOptions({ baseURL: embedBaseUrl })
-    }
-
+    const sdk = await this.getSdkInstance()
     const data = await sdk.embeddings.create({
       model: model.id,
@@ -214,7 +209,7 @@ export abstract class OpenAIBaseClient<
         dangerouslyAllowBrowser: true,
         apiKey: apiKeyForSdkInstance,
         apiVersion: this.provider.apiVersion,
-        endpoint: normalizeAzureOpenAIEndpoint(this.provider.apiHost)
+        endpoint: this.provider.apiHost
       }) as TSdkInstance
     } else {
       this.sdkInstance = new OpenAI({
@@ -1,4 +0,0 @@
-export function normalizeAzureOpenAIEndpoint(apiHost: string): string {
-  const normalizedHost = apiHost.replace(/\/+$/, '')
-  return normalizedHost.replace(/\/openai(?:\/v1)?$/i, '')
-}

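For reference, the helper deleted above trims trailing slashes plus a trailing /openai or /openai/v1 segment; a quick sketch of its behavior, derived directly from the two regexes:

// Sketch of the removed helper's behavior (derived from the regexes above)
const normalize = (apiHost: string): string =>
  apiHost.replace(/\/+$/, '').replace(/\/openai(?:\/v1)?$/i, '')

normalize('https://example.openai.azure.com/openai/')   // -> 'https://example.openai.azure.com'
normalize('https://example.openai.azure.com/openai/v1') // -> 'https://example.openai.azure.com'
normalize('https://example.openai.azure.com')           // -> 'https://example.openai.azure.com'
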
@@ -2,6 +2,7 @@ import { loggerService } from '@logger'
 import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
 import type { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
 import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
+import { getProviderByModel } from '@renderer/services/AssistantService'
 import { withSpanResult } from '@renderer/services/SpanManagerService'
 import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
 import type { GenerateImageParams, Model, Provider } from '@renderer/types'

@@ -159,6 +160,9 @@ export default class AiProvider {
   public async getEmbeddingDimensions(model: Model): Promise<number> {
     try {
       // Use the SDK instance to test embedding capabilities
+      if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
+        this.apiClient = this.apiClient.getClient(model) as BaseApiClient
+      }
       const dimensions = await this.apiClient.getEmbeddingDimensions(model)
       return dimensions
     } catch (error) {

@@ -7,6 +7,7 @@ import type { Chunk } from '@renderer/types/chunk'
 import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider'
 import type { LanguageModelMiddleware } from 'ai'
 import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
+import { isEmpty } from 'lodash'

 import { getAiSdkProviderId } from '../provider/factory'
 import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'

@@ -15,6 +16,7 @@ import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMidd
 import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
 import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
 import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
+import { toolChoiceMiddleware } from './toolChoiceMiddleware'

 const logger = loggerService.withContext('AiSdkMiddlewareBuilder')

@@ -134,6 +136,15 @@ export class AiSdkMiddlewareBuilder {
 export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] {
   const builder = new AiSdkMiddlewareBuilder()

+  // 0. Forced knowledge-base middleware (must come first, so the first round is forced to call the knowledge base)
+  if (!isEmpty(config.assistant?.knowledge_bases?.map((base) => base.id)) && config.knowledgeRecognition !== 'on') {
+    builder.add({
+      name: 'force-knowledge-first',
+      middleware: toolChoiceMiddleware('builtin_knowledge_search')
+    })
+    logger.debug('Added toolChoice middleware to force knowledge base search on first round')
+  }
+
   // 1. Add provider-specific middlewares
   if (config.provider) {
     addProviderSpecificMiddlewares(builder, config)

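The implementation of toolChoiceMiddleware is not shown in this diff; purely as a hypothetical sketch of what such a middleware could look like with the AI SDK's transformParams hook (the one-shot flag and function body below are assumptions; only the toolChoice shape is grounded in the commented-out line in SearchOrchestrationPlugin further down):

import type { LanguageModelMiddleware } from 'ai'

// Hypothetical sketch: force a specific tool on the first request, then step aside.
function forceToolMiddleware(toolName: string): LanguageModelMiddleware {
  let firstCall = true
  return {
    transformParams: async ({ params }) => {
      if (!firstCall) return params
      firstCall = false
      // Pin tool choice to the named tool (same shape as the commented-out
      // params.toolChoice seen in SearchOrchestrationPlugin below)
      return { ...params, toolChoice: { type: 'tool', toolName } }
    }
  }
}
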
@@ -31,7 +31,7 @@ import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool'

 const logger = loggerService.withContext('SearchOrchestrationPlugin')

-export const getMessageContent = (message: ModelMessage) => {
+const getMessageContent = (message: ModelMessage) => {
   if (typeof message.content === 'string') return message.content
   return message.content.reduce((acc, part) => {
     if (part.type === 'text') {

@@ -266,14 +266,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
     // Decide which kinds of search are needed
     const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
     const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
-    const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
+    const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
     const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
     const shouldWebSearch = !!assistant.webSearchProviderId
     const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
     const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory

     // Run intent analysis
-    if (shouldWebSearch || shouldKnowledgeSearch) {
+    if (shouldWebSearch || hasKnowledgeBase) {
       const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, {
         shouldWebSearch,
         shouldKnowledgeSearch,

@@ -330,25 +330,41 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
       // 📚 Knowledge-base search tool configuration
       const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
       const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
-      const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
-      const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
+      const knowledgeRecognition = assistant.knowledgeRecognition || 'on'

-      if (shouldKnowledgeSearch) {
-        // 'on' mode: decide whether to add the tool based on the intent-analysis result
-        const needsKnowledgeSearch =
-          analysisResult?.knowledge &&
-          analysisResult.knowledge.question &&
-          analysisResult.knowledge.question[0] !== 'not_needed'
-
-        if (needsKnowledgeSearch && analysisResult.knowledge) {
-          // logger.info('📚 Adding knowledge search tool (intent-based)')
+      if (hasKnowledgeBase) {
+        if (knowledgeRecognition === 'off') {
+          // 'off' mode: add the knowledge search tool directly, using the user message as the search keywords
           const userMessage = userMessages[context.requestId]
+          const fallbackKeywords = {
+            question: [getMessageContent(userMessage) || 'search'],
+            rewrite: getMessageContent(userMessage) || 'search'
+          }
+          // logger.info('📚 Adding knowledge search tool (force mode)')
           params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
             assistant,
-            analysisResult.knowledge,
+            fallbackKeywords,
             getMessageContent(userMessage),
             topicId
           )
+          // params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' }
+        } else {
+          // 'on' mode: decide whether to add the tool based on the intent-analysis result
+          const needsKnowledgeSearch =
+            analysisResult?.knowledge &&
+            analysisResult.knowledge.question &&
+            analysisResult.knowledge.question[0] !== 'not_needed'
+
+          if (needsKnowledgeSearch && analysisResult.knowledge) {
+            // logger.info('📚 Adding knowledge search tool (intent-based)')
+            const userMessage = userMessages[context.requestId]
+            params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
+              assistant,
+              analysisResult.knowledge,
+              getMessageContent(userMessage),
+              topicId
+            )
+          }
         }
       }

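The first hunk above cuts off inside getMessageContent's reduce; a plausible completion, assuming text parts are simply concatenated and other part types are skipped (the reduce body is an assumption, not from the diff):

import type { ModelMessage } from 'ai'

// Sketch: plausible completion of the truncated reduce (assumption, not from the diff)
const getMessageContent = (message: ModelMessage) => {
  if (typeof message.content === 'string') return message.content
  return message.content.reduce((acc, part) => {
    if (part.type === 'text') {
      return acc + part.text // concatenate text parts
    }
    return acc // skip non-text parts
  }, '')
}
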
@@ -109,20 +109,6 @@ const createImageBlock = (
   ...overrides
 })

-const createThinkingBlock = (
-  messageId: string,
-  overrides: Partial<Omit<ThinkingMessageBlock, 'type' | 'messageId'>> = {}
-): ThinkingMessageBlock => ({
-  id: overrides.id ?? `thinking-block-${++blockCounter}`,
-  messageId,
-  type: MessageBlockType.THINKING,
-  createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
-  status: overrides.status ?? MessageBlockStatus.SUCCESS,
-  content: overrides.content ?? 'Let me think...',
-  thinking_millsec: overrides.thinking_millsec ?? 1000,
-  ...overrides
-})
-
 describe('messageConverter', () => {
   beforeEach(() => {
     convertFileBlockToFilePartMock.mockReset()

@@ -243,23 +229,6 @@ describe('messageConverter', () => {
       }
     ])
   })

-  it('includes reasoning parts for assistant messages with thinking blocks', async () => {
-    const model = createModel()
-    const message = createMessage('assistant')
-    message.__mockContent = 'Here is my answer'
-    message.__mockThinkingBlocks = [createThinkingBlock(message.id, { content: 'Let me think...' })]
-
-    const result = await convertMessageToSdkParam(message, false, model)
-
-    expect(result).toEqual({
-      role: 'assistant',
-      content: [
-        { type: 'text', text: 'Here is my answer' },
-        { type: 'reasoning', text: 'Let me think...' }
-      ]
-    })
-  })
 })

 describe('convertMessagesToSdkMessages', () => {

@@ -18,7 +18,7 @@ vi.mock('@renderer/services/AssistantService', () => ({
   toolUseMode: assistant.settings?.toolUseMode ?? 'prompt',
   defaultModel: assistant.defaultModel,
   customParameters: assistant.settings?.customParameters ?? [],
-  reasoning_effort: assistant.settings?.reasoning_effort ?? 'default',
+  reasoning_effort: assistant.settings?.reasoning_effort,
   reasoning_effort_cache: assistant.settings?.reasoning_effort_cache,
   qwenThinkMode: assistant.settings?.qwenThinkMode
 })

@@ -3,7 +3,6 @@
  * Convert Cherry Studio messages into the AI SDK message format
  */

-import type { ReasoningPart } from '@ai-sdk/provider-utils'
 import { loggerService } from '@logger'
 import { isImageEnhancementModel, isVisionModel } from '@renderer/config/models'
 import type { Message, Model } from '@renderer/types'

@@ -164,13 +163,13 @@ async function convertMessageToAssistantModelMessage(
   thinkingBlocks: ThinkingMessageBlock[],
   model?: Model
 ): Promise<AssistantModelMessage> {
-  const parts: Array<TextPart | ReasoningPart | FilePart> = []
+  const parts: Array<TextPart | FilePart> = []
   if (content) {
     parts.push({ type: 'text', text: content })
   }

   for (const thinkingBlock of thinkingBlocks) {
-    parts.push({ type: 'reasoning', text: thinkingBlock.content })
+    parts.push({ type: 'text', text: thinkingBlock.content })
   }

   for (const fileBlock of fileBlocks) {

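The main side emits thinking blocks as typed reasoning parts, while v1.7.3 flattens them into extra text parts; a minimal sketch of the resulting assistant messages (shapes grounded in the diff and the deleted test expectation above; the sample strings are illustrative):

// main: thinking content keeps its own part type, so providers can treat it specially
const withReasoning = {
  role: 'assistant' as const,
  content: [
    { type: 'text' as const, text: 'Here is my answer' },
    { type: 'reasoning' as const, text: 'Let me think...' }
  ]
}

// v1.7.3: thinking content is indistinguishable from the answer text
const flattened = {
  role: 'assistant' as const,
  content: [
    { type: 'text' as const, text: 'Here is my answer' },
    { type: 'text' as const, text: 'Let me think...' }
  ]
}
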
@@ -28,14 +28,13 @@ import { getAnthropicThinkingBudget } from '../utils/reasoning'
 * - Disabled for models that do not support temperature.
 * - Disabled for Claude 4.5 reasoning models when TopP is enabled and temperature is disabled.
 * Otherwise, returns the temperature value if the assistant has temperature enabled.
 */
 export function getTemperature(assistant: Assistant, model: Model): number | undefined {
   if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
     return undefined
   }

-  if (!isSupportTemperatureModel(model, assistant)) {
+  if (!isSupportTemperatureModel(model)) {
     return undefined
   }

@@ -47,10 +46,6 @@ export function getTemperature(assistant: Assistant, model: Model): number | undefined
     return undefined
   }

-  return getTemperatureValue(assistant, model)
-}
-
-function getTemperatureValue(assistant: Assistant, model: Model): number | undefined {
   const assistantSettings = getAssistantSettings(assistant)
   let temperature = assistantSettings?.temperature
   if (temperature && isMaxTemperatureOneModel(model)) {

@@ -73,17 +68,13 @@ export function getTopP(assistant: Assistant, model: Model): number | undefined
   if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
     return undefined
   }
-  if (!isSupportTopPModel(model, assistant)) {
+  if (!isSupportTopPModel(model)) {
     return undefined
   }
   if (isTemperatureTopPMutuallyExclusiveModel(model) && assistant.settings?.enableTemperature) {
     return undefined
   }

-  return getTopPValue(assistant)
-}
-
-function getTopPValue(assistant: Assistant): number | undefined {
   const assistantSettings = getAssistantSettings(assistant)
   // FIXME: assistant.settings.enableTopP should be always a boolean value.
   const enableTopP = assistantSettings.enableTopP ?? DEFAULT_ASSISTANT_SETTINGS.enableTopP

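On the main side the guard checks and the raw value lookup are split into a wrapper plus a helper; a condensed sketch of that shape (signatures taken from the diff, elided bodies and the final return are assumptions):

// Condensed sketch of the main-branch split (bodies abridged; elisions are assumptions)
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
  if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) return undefined
  if (!isSupportTemperatureModel(model, assistant)) return undefined
  // ...remaining guards elided...
  return getTemperatureValue(assistant, model) // guards first, then the raw value
}

function getTemperatureValue(assistant: Assistant, model: Model): number | undefined {
  const assistantSettings = getAssistantSettings(assistant)
  let temperature = assistantSettings?.temperature
  // ...clamping for max-temperature-one models elided...
  return temperature
}
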
@@ -21,7 +21,6 @@ import {
   isGrokModel,
   isOpenAIModel,
   isOpenRouterBuiltInWebSearchModel,
-  isPureGenerateImageModel,
   isSupportedReasoningEffortModel,
   isSupportedThinkingTokenModel,
   isWebSearchModel

@@ -34,7 +33,7 @@ import { type Assistant, type MCPTool, type Provider, SystemProviderIds } from '
 import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
 import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
 import { replacePromptVariables } from '@renderer/utils/prompt'
-import { isAIGatewayProvider, isAwsBedrockProvider, isSupportUrlContextProvider } from '@renderer/utils/provider'
+import { isAIGatewayProvider, isAwsBedrockProvider } from '@renderer/utils/provider'
 import type { ModelMessage, Tool } from 'ai'
 import { stepCountIs } from 'ai'

@@ -119,13 +118,7 @@ export async function buildStreamTextParams(
     isOpenRouterBuiltInWebSearchModel(model) ||
     model.id.includes('sonar'))

-  // Validate provider and model support to prevent stale state from triggering urlContext
-  const enableUrlContext = !!(
-    assistant.enableUrlContext &&
-    isSupportUrlContextProvider(provider) &&
-    !isPureGenerateImageModel(model) &&
-    (isGeminiModel(model) || isAnthropicModel(model))
-  )
+  const enableUrlContext = assistant.enableUrlContext || false

   const enableGenerateImage = !!(isGenerateImageModel(model) && assistant.enableGenerateImage)

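The main branch gates urlContext on provider and model support rather than trusting assistant state alone; as a standalone predicate the guard reads like this (assembled from the deleted lines; the helpers are the module's own imports shown above):

// Sketch: the main-branch urlContext guard as a standalone predicate
const canEnableUrlContext = (assistant: Assistant, provider: Provider, model: Model): boolean =>
  !!(
    assistant.enableUrlContext &&
    isSupportUrlContextProvider(provider) && // the provider must support URL context
    !isPureGenerateImageModel(model) &&      // image-only models cannot fetch URLs
    (isGeminiModel(model) || isAnthropicModel(model))
  )
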
@@ -79,7 +79,7 @@ vi.mock('@renderer/services/AssistantService', () => ({
 import { getProviderByModel } from '@renderer/services/AssistantService'
 import type { Model, Provider } from '@renderer/types'
 import { formatApiHost } from '@renderer/utils/api'
-import { isAzureOpenAIProvider, isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
+import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'

 import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
 import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'

@@ -133,17 +133,6 @@ const createPerplexityProvider = (): Provider => ({
   isSystem: false
 })

-const createAzureProvider = (apiVersion: string): Provider => ({
-  id: 'azure-openai',
-  type: 'azure-openai',
-  name: 'Azure OpenAI',
-  apiKey: 'test-key',
-  apiHost: 'https://example.openai.azure.com/openai',
-  apiVersion,
-  models: [],
-  isSystem: true
-})
-
 describe('Copilot responses routing', () => {
   beforeEach(() => {
     ;(globalThis as any).window = {

@@ -515,46 +504,3 @@ describe('Stream options includeUsage configuration', () => {
     expect(config.providerId).toBe('github-copilot-openai-compatible')
   })
 })
-
-describe('Azure OpenAI traditional API routing', () => {
-  beforeEach(() => {
-    ;(globalThis as any).window = {
-      ...(globalThis as any).window,
-      keyv: createWindowKeyv()
-    }
-    mockGetState.mockReturnValue({
-      settings: {
-        openAI: {
-          streamOptions: {
-            includeUsage: undefined
-          }
-        }
-      }
-    })
-
-    vi.mocked(isAzureOpenAIProvider).mockImplementation((provider) => provider.type === 'azure-openai')
-  })
-
-  it('uses deployment-based URLs when apiVersion is a date version', () => {
-    const provider = createAzureProvider('2024-02-15-preview')
-    const config = providerToAiSdkConfig(provider, createModel('gpt-4o', 'GPT-4o', provider.id))
-
-    expect(config.providerId).toBe('azure')
-    expect(config.options.apiVersion).toBe('2024-02-15-preview')
-    expect(config.options.useDeploymentBasedUrls).toBe(true)
-  })
-
-  it('does not force deployment-based URLs for apiVersion v1/preview', () => {
-    const v1Provider = createAzureProvider('v1')
-    const v1Config = providerToAiSdkConfig(v1Provider, createModel('gpt-4o', 'GPT-4o', v1Provider.id))
-    expect(v1Config.providerId).toBe('azure-responses')
-    expect(v1Config.options.apiVersion).toBe('v1')
-    expect(v1Config.options.useDeploymentBasedUrls).toBeUndefined()
-
-    const previewProvider = createAzureProvider('preview')
-    const previewConfig = providerToAiSdkConfig(previewProvider, createModel('gpt-4o', 'GPT-4o', previewProvider.id))
-    expect(previewConfig.providerId).toBe('azure-responses')
-    expect(previewConfig.options.apiVersion).toBe('preview')
-    expect(previewConfig.options.useDeploymentBasedUrls).toBeUndefined()
-  })
-})

@@ -32,7 +32,6 @@ import {
   isSupportStreamOptionsProvider,
   isVertexProvider
 } from '@renderer/utils/provider'
-import { defaultAppHeaders } from '@shared/utils'
 import { cloneDeep, isEmpty } from 'lodash'

 import type { AiSdkConfig } from '../types'

@@ -198,13 +197,18 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
     extraOptions.mode = 'chat'
   }

-  extraOptions.headers = {
-    ...defaultAppHeaders(),
-    ...actualProvider.extra_headers
-  }
-  if (aiSdkProviderId === 'openai') {
-    extraOptions.headers['X-Api-Key'] = baseConfig.apiKey
+  // Add extra headers
+  if (actualProvider.extra_headers) {
+    extraOptions.headers = actualProvider.extra_headers
+  }
+  // copy from openaiBaseClient/openaiResponseApiClient
+  if (aiSdkProviderId === 'openai') {
+    extraOptions.headers = {
+      ...extraOptions.headers,
+      'HTTP-Referer': 'https://cherry-ai.com',
+      'X-Title': 'Cherry Studio',
+      'X-Api-Key': baseConfig.apiKey
+    }
   }
   // azure
   // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest

@@ -214,15 +218,6 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
   } else if (aiSdkProviderId === 'azure') {
     extraOptions.mode = 'chat'
   }
-  if (isAzureOpenAIProvider(actualProvider)) {
-    const apiVersion = actualProvider.apiVersion?.trim()
-    if (apiVersion) {
-      extraOptions.apiVersion = apiVersion
-      if (!['preview', 'v1'].includes(apiVersion)) {
-        extraOptions.useDeploymentBasedUrls = true
-      }
-    }
-  }

   // bedrock
   if (aiSdkProviderId === 'bedrock') {

@@ -259,7 +254,7 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
     // CherryIN API Host
     const cherryinProvider = getProviderById(SystemProviderIds.cherryin)
     if (cherryinProvider) {
-      extraOptions.anthropicBaseURL = cherryinProvider.anthropicApiHost + '/v1'
+      extraOptions.anthropicBaseURL = cherryinProvider.anthropicApiHost
       extraOptions.geminiBaseURL = cherryinProvider.apiHost + '/v1beta/models'
     }
   }

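The Azure branch removed above picks the URL style purely from apiVersion: date-style versions keep the legacy deployment-based URLs, while 'v1' and 'preview' use the newer unified endpoint. A small sketch of that rule (taken from the deleted lines; the helper name is illustrative):

// Sketch of the main-branch rule for Azure OpenAI URL routing
function azureUrlOptions(apiVersion?: string): { apiVersion?: string; useDeploymentBasedUrls?: boolean } {
  const version = apiVersion?.trim()
  if (!version) return {}
  return {
    apiVersion: version,
    useDeploymentBasedUrls: !['preview', 'v1'].includes(version) || undefined
  }
}

azureUrlOptions('2024-02-15-preview') // { apiVersion: '2024-02-15-preview', useDeploymentBasedUrls: true }
azureUrlOptions('v1')                 // { apiVersion: 'v1', useDeploymentBasedUrls: undefined }
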
@@ -24,8 +24,7 @@ export const memorySearchTool = () => {
   }

   const memoryConfig = selectMemoryConfig(store.getState())
-  if (!memoryConfig.llmModel || !memoryConfig.embeddingModel) {
+  if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
     return []
   }

@@ -464,8 +464,7 @@ describe('options utils', () => {
         custom_param: 'custom_value',
         another_param: 123,
         serviceTier: undefined,
-        textVerbosity: undefined,
-        store: false
+        textVerbosity: undefined
       }
     })
   })

@@ -11,7 +11,6 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'

 import {
   getAnthropicReasoningParams,
-  getAnthropicThinkingBudget,
   getBedrockReasoningParams,
   getCustomParameters,
   getGeminiReasoningParams,

@@ -90,8 +89,7 @@ vi.mock('@renderer/config/models', async (importOriginal) => {
     isQwenAlwaysThinkModel: vi.fn(() => false),
     isSupportedThinkingTokenHunyuanModel: vi.fn(() => false),
     isSupportedThinkingTokenModel: vi.fn(() => false),
-    isGPT51SeriesModel: vi.fn(() => false),
-    findTokenLimit: vi.fn(actual.findTokenLimit)
+    isGPT51SeriesModel: vi.fn(() => false)
   }
 })

@@ -598,7 +596,7 @@ describe('reasoning utils', () => {
     expect(result).toEqual({})
   })

-  it('should return disabled thinking when reasoning effort is none', async () => {
+  it('should return disabled thinking when no reasoning effort', async () => {
     const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')

     vi.mocked(isReasoningModel).mockReturnValue(true)

@@ -613,9 +611,7 @@ describe('reasoning utils', () => {
     const assistant: Assistant = {
       id: 'test',
       name: 'Test',
-      settings: {
-        reasoning_effort: 'none'
-      }
+      settings: {}
     } as Assistant

     const result = getAnthropicReasoningParams(assistant, model)

@@ -651,7 +647,7 @@ describe('reasoning utils', () => {
     expect(result).toEqual({
       thinking: {
         type: 'enabled',
-        budgetTokens: 4096
+        budgetTokens: 2048
       }
     })
   })

@@ -679,7 +675,7 @@ describe('reasoning utils', () => {
     expect(result).toEqual({})
   })

-  it('should disable thinking for Flash models when reasoning effort is none', async () => {
+  it('should disable thinking for Flash models without reasoning effort', async () => {
     const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')

     vi.mocked(isReasoningModel).mockReturnValue(true)

@@ -694,9 +690,7 @@ describe('reasoning utils', () => {
     const assistant: Assistant = {
       id: 'test',
       name: 'Test',
-      settings: {
-        reasoning_effort: 'none'
-      }
+      settings: {}
     } as Assistant

     const result = getGeminiReasoningParams(assistant, model)

@@ -731,7 +725,7 @@ describe('reasoning utils', () => {
     const result = getGeminiReasoningParams(assistant, model)
     expect(result).toEqual({
       thinkingConfig: {
-        thinkingBudget: expect.any(Number),
+        thinkingBudget: 16448,
         includeThoughts: true
       }
     })

@@ -895,7 +889,7 @@ describe('reasoning utils', () => {
     expect(result).toEqual({
       reasoningConfig: {
         type: 'enabled',
-        budgetTokens: 4096
+        budgetTokens: 2048
       }
     })
   })

@@ -996,89 +990,4 @@ describe('reasoning utils', () => {
       })
     })
   })

-  describe('getAnthropicThinkingBudget', () => {
-    it('should return undefined when reasoningEffort is undefined', async () => {
-      const result = getAnthropicThinkingBudget(4096, undefined, 'claude-3-7-sonnet')
-      expect(result).toBeUndefined()
-    })
-
-    it('should return undefined when reasoningEffort is none', async () => {
-      const result = getAnthropicThinkingBudget(4096, 'none', 'claude-3-7-sonnet')
-      expect(result).toBeUndefined()
-    })
-
-    it('should return undefined when tokenLimit is not found', async () => {
-      const { findTokenLimit } = await import('@renderer/config/models')
-      vi.mocked(findTokenLimit).mockReturnValue(undefined)
-
-      const result = getAnthropicThinkingBudget(4096, 'medium', 'unknown-model')
-      expect(result).toBeUndefined()
-    })
-
-    it('should calculate budget correctly when maxTokens is provided', async () => {
-      const { findTokenLimit } = await import('@renderer/config/models')
-      vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 })
-
-      const result = getAnthropicThinkingBudget(4096, 'medium', 'claude-3-7-sonnet')
-      // EFFORT_RATIO['medium'] = 0.5
-      // budget = Math.floor((32768 - 1024) * 0.5 + 1024)
-      //        = Math.floor(31744 * 0.5 + 1024) = Math.floor(15872 + 1024) = 16896
-      // budgetTokens = Math.min(16896, 4096) = 4096
-      // result = Math.max(1024, 4096) = 4096
-      expect(result).toBe(4096)
-    })
-
-    it('should use tokenLimit.max when maxTokens is undefined', async () => {
-      const { findTokenLimit } = await import('@renderer/config/models')
-      vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 })
-
-      const result = getAnthropicThinkingBudget(undefined, 'medium', 'claude-3-7-sonnet')
-      // When maxTokens is undefined, budget is not constrained by maxTokens
-      // EFFORT_RATIO['medium'] = 0.5
-      // budget = Math.floor((32768 - 1024) * 0.5 + 1024)
-      //        = Math.floor(31744 * 0.5 + 1024) = Math.floor(15872 + 1024) = 16896
-      // result = Math.max(1024, 16896) = 16896
-      expect(result).toBe(16896)
-    })
-
-    it('should enforce minimum budget of 1024', async () => {
-      const { findTokenLimit } = await import('@renderer/config/models')
-      vi.mocked(findTokenLimit).mockReturnValue({ min: 100, max: 1000 })
-
-      const result = getAnthropicThinkingBudget(500, 'low', 'claude-3-7-sonnet')
-      // EFFORT_RATIO['low'] = 0.05
-      // budget = Math.floor((1000 - 100) * 0.05 + 100)
-      //        = Math.floor(900 * 0.05 + 100) = Math.floor(45 + 100) = 145
-      // budgetTokens = Math.min(145, 500) = 145
-      // result = Math.max(1024, 145) = 1024
-      expect(result).toBe(1024)
-    })
-
-    it('should respect effort ratio for high reasoning effort', async () => {
-      const { findTokenLimit } = await import('@renderer/config/models')
-      vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 })
-
-      const result = getAnthropicThinkingBudget(8192, 'high', 'claude-3-7-sonnet')
-      // EFFORT_RATIO['high'] = 0.8
-      // budget = Math.floor((32768 - 1024) * 0.8 + 1024)
-      //        = Math.floor(31744 * 0.8 + 1024) = Math.floor(25395.2 + 1024) = 26419
-      // budgetTokens = Math.min(26419, 8192) = 8192
-      // result = Math.max(1024, 8192) = 8192
-      expect(result).toBe(8192)
-    })
-
-    it('should use full token limit when maxTokens is undefined and reasoning effort is high', async () => {
-      const { findTokenLimit } = await import('@renderer/config/models')
-      vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 })
-
-      const result = getAnthropicThinkingBudget(undefined, 'high', 'claude-3-7-sonnet')
-      // When maxTokens is undefined, budget is not constrained by maxTokens
-      // EFFORT_RATIO['high'] = 0.8
-      // budget = Math.floor((32768 - 1024) * 0.8 + 1024)
-      //        = Math.floor(31744 * 0.8 + 1024) = Math.floor(25395.2 + 1024) = 26419
-      // result = Math.max(1024, 26419) = 26419
-      expect(result).toBe(26419)
-    })
-  })
 })

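The deleted tests pin down the thinking-budget formula exactly; restated as a compact sketch, with the EFFORT_RATIO values and clamping taken from the test comments above (this is a reconstruction, not the shipped implementation):

// Sketch reconstructed from the deleted test comments
const EFFORT_RATIO = { low: 0.05, medium: 0.5, high: 0.8 } as const

function anthropicThinkingBudget(
  maxTokens: number | undefined,
  effort: keyof typeof EFFORT_RATIO,
  limit: { min: number; max: number }
): number {
  // Scale the model's budget range by the effort ratio
  const budget = Math.floor((limit.max - limit.min) * EFFORT_RATIO[effort] + limit.min)
  // Cap by maxTokens when provided
  const capped = maxTokens === undefined ? budget : Math.min(budget, maxTokens)
  // Never go below the 1024-token floor
  return Math.max(1024, capped)
}

anthropicThinkingBudget(4096, 'medium', { min: 1024, max: 32768 })    // 4096
anthropicThinkingBudget(undefined, 'high', { min: 1024, max: 32768 }) // 26419
anthropicThinkingBudget(500, 'low', { min: 100, max: 1000 })          // 1024
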
@@ -259,7 +259,7 @@ describe('websearch utils', () => {

     expect(result).toEqual({
       xai: {
-        maxSearchResults: 30,
+        maxSearchResults: 50,
         returnCitations: true,
         sources: [{ type: 'web', excludedWebsites: [] }, { type: 'news' }, { type: 'x' }],
         mode: 'on'

@@ -10,9 +10,7 @@ import {
   isAnthropicModel,
   isGeminiModel,
   isGrokModel,
-  isInterleavedThinkingModel,
   isOpenAIModel,
-  isOpenAIOpenWeightModel,
   isQwenMTModel,
   isSupportFlexServiceTierModel,
   isSupportVerbosityModel

@@ -246,7 +244,7 @@ export function buildProviderOptions(
       providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
       break
     case SystemProviderIds.ollama:
-      providerSpecificOptions = buildOllamaProviderOptions(assistant, model, capabilities)
+      providerSpecificOptions = buildOllamaProviderOptions(assistant, capabilities)
       break
     case SystemProviderIds.gateway:
       providerSpecificOptions = buildAIGatewayOptions(assistant, model, capabilities, serviceTier, textVerbosity)

@@ -397,12 +395,10 @@ function buildOpenAIProviderOptions(
     }
   }

-  // TODO: make server-side persistence configurable
   providerOptions = {
     ...providerOptions,
     serviceTier,
-    textVerbosity,
-    store: false
+    textVerbosity
   }

   return {

@@ -568,7 +564,6 @@ function buildBedrockProviderOptions(

 function buildOllamaProviderOptions(
   assistant: Assistant,
-  model: Model,
   capabilities: {
     enableReasoning: boolean
     enableWebSearch: boolean

@@ -579,12 +574,7 @@ function buildOllamaProviderOptions(
   const providerOptions: OllamaCompletionProviderOptions = {}
   const reasoningEffort = assistant.settings?.reasoning_effort
   if (enableReasoning) {
-    if (isOpenAIOpenWeightModel(model)) {
-      // @ts-ignore upstream type error
-      providerOptions.think = reasoningEffort as any
-    } else {
-      providerOptions.think = !['none', undefined].includes(reasoningEffort)
-    }
+    providerOptions.think = !['none', undefined].includes(reasoningEffort)
   }
   return {
     ollama: providerOptions

@@ -604,7 +594,7 @@ function buildGenericProviderOptions(
     enableGenerateImage: boolean
   }
 ): Record<string, any> {
-  const { enableWebSearch, enableReasoning } = capabilities
+  const { enableWebSearch } = capabilities
   let providerOptions: Record<string, any> = {}

   const reasoningParams = getReasoningEffort(assistant, model)

@@ -612,14 +602,6 @@ function buildGenericProviderOptions(
     ...providerOptions,
     ...reasoningParams
   }
-  if (enableReasoning) {
-    if (isInterleavedThinkingModel(model)) {
-      providerOptions = {
-        ...providerOptions,
-        sendReasoning: true
-      }
-    }
-  }

   if (enableWebSearch) {
     const webSearchParams = getWebSearchParams(model)

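On the main side, Ollama's think option is a plain on/off switch for most models but passes the effort level straight through for OpenAI open-weight models; a small sketch of that mapping, reconstructed from the deleted lines (the option set and return typing are assumptions):

// Sketch of the main-branch mapping (reconstructed from the deleted lines above)
type ReasoningEffort = 'none' | 'low' | 'medium' | 'high' | undefined // assumed option set

function ollamaThink(effort: ReasoningEffort, isOpenWeightModel: boolean): boolean | string | undefined {
  if (isOpenWeightModel) {
    return effort // open-weight OpenAI models accept the effort level itself
  }
  // everyone else gets a plain on/off switch
  return !(['none', undefined] as ReasoningEffort[]).includes(effort)
}
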
Some files were not shown because too many files have changed in this diff.