mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-12 00:49:14 +08:00
Merge remote-tracking branch 'origin/main' into feat/proxy-api-server
This commit is contained in:
commit
d74506ac3b
@ -1,140 +0,0 @@
|
||||
diff --git a/dist/index.js b/dist/index.js
|
||||
index 73045a7d38faafdc7f7d2cd79d7ff0e2b031056b..8d948c9ac4ea4b474db9ef3c5491961e7fcf9a07 100644
|
||||
--- a/dist/index.js
|
||||
+++ b/dist/index.js
|
||||
@@ -421,6 +421,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
text: reasoning
|
||||
});
|
||||
}
|
||||
+ if (choice.message.images) {
|
||||
+ for (const image of choice.message.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ content.push({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (choice.message.tool_calls != null) {
|
||||
for (const toolCall of choice.message.tool_calls) {
|
||||
content.push({
|
||||
@@ -598,6 +609,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
delta: delta.content
|
||||
});
|
||||
}
|
||||
+ if (delta.images) {
|
||||
+ for (const image of delta.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ controller.enqueue({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (delta.tool_calls != null) {
|
||||
for (const toolCallDelta of delta.tool_calls) {
|
||||
const index = toolCallDelta.index;
|
||||
@@ -765,6 +787,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({
|
||||
arguments: import_v43.z.string()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: import_v43.z.array(
|
||||
+ import_v43.z.object({
|
||||
+ type: import_v43.z.literal('image_url'),
|
||||
+ image_url: import_v43.z.object({
|
||||
+ url: import_v43.z.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}),
|
||||
finish_reason: import_v43.z.string().nullish()
|
||||
@@ -795,6 +825,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union(
|
||||
arguments: import_v43.z.string().nullish()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: import_v43.z.array(
|
||||
+ import_v43.z.object({
|
||||
+ type: import_v43.z.literal('image_url'),
|
||||
+ image_url: import_v43.z.object({
|
||||
+ url: import_v43.z.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}).nullish(),
|
||||
finish_reason: import_v43.z.string().nullish()
|
||||
diff --git a/dist/index.mjs b/dist/index.mjs
|
||||
index 1c2b9560bbfbfe10cb01af080aeeed4ff59db29c..2c8ddc4fc9bfc5e7e06cfca105d197a08864c427 100644
|
||||
--- a/dist/index.mjs
|
||||
+++ b/dist/index.mjs
|
||||
@@ -405,6 +405,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
text: reasoning
|
||||
});
|
||||
}
|
||||
+ if (choice.message.images) {
|
||||
+ for (const image of choice.message.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ content.push({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (choice.message.tool_calls != null) {
|
||||
for (const toolCall of choice.message.tool_calls) {
|
||||
content.push({
|
||||
@@ -582,6 +593,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
delta: delta.content
|
||||
});
|
||||
}
|
||||
+ if (delta.images) {
|
||||
+ for (const image of delta.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ controller.enqueue({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (delta.tool_calls != null) {
|
||||
for (const toolCallDelta of delta.tool_calls) {
|
||||
const index = toolCallDelta.index;
|
||||
@@ -749,6 +771,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({
|
||||
arguments: z3.string()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: z3.array(
|
||||
+ z3.object({
|
||||
+ type: z3.literal('image_url'),
|
||||
+ image_url: z3.object({
|
||||
+ url: z3.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}),
|
||||
finish_reason: z3.string().nullish()
|
||||
@@ -779,6 +809,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([
|
||||
arguments: z3.string().nullish()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: z3.array(
|
||||
+ z3.object({
|
||||
+ type: z3.literal('image_url'),
|
||||
+ image_url: z3.object({
|
||||
+ url: z3.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}).nullish(),
|
||||
finish_reason: z3.string().nullish()
|
||||
266
.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch
vendored
Normal file
266
.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch
vendored
Normal file
@ -0,0 +1,266 @@
|
||||
diff --git a/dist/index.d.ts b/dist/index.d.ts
|
||||
index 48e2f6263c6ee4c75d7e5c28733e64f6ebe92200..00d0729c4a3cbf9a48e8e1e962c7e2b256b75eba 100644
|
||||
--- a/dist/index.d.ts
|
||||
+++ b/dist/index.d.ts
|
||||
@@ -7,6 +7,7 @@ declare const openaiCompatibleProviderOptions: z.ZodObject<{
|
||||
user: z.ZodOptional<z.ZodString>;
|
||||
reasoningEffort: z.ZodOptional<z.ZodString>;
|
||||
textVerbosity: z.ZodOptional<z.ZodString>;
|
||||
+ sendReasoning: z.ZodOptional<z.ZodBoolean>;
|
||||
}, z.core.$strip>;
|
||||
type OpenAICompatibleProviderOptions = z.infer<typeof openaiCompatibleProviderOptions>;
|
||||
|
||||
diff --git a/dist/index.js b/dist/index.js
|
||||
index da237bb35b7fa8e24b37cd861ee73dfc51cdfc72..b3060fbaf010e30b64df55302807828e5bfe0f9a 100644
|
||||
--- a/dist/index.js
|
||||
+++ b/dist/index.js
|
||||
@@ -41,7 +41,7 @@ function getOpenAIMetadata(message) {
|
||||
var _a, _b;
|
||||
return (_b = (_a = message == null ? void 0 : message.providerOptions) == null ? void 0 : _a.openaiCompatible) != null ? _b : {};
|
||||
}
|
||||
-function convertToOpenAICompatibleChatMessages(prompt) {
|
||||
+function convertToOpenAICompatibleChatMessages({prompt, options}) {
|
||||
const messages = [];
|
||||
for (const { role, content, ...message } of prompt) {
|
||||
const metadata = getOpenAIMetadata({ ...message });
|
||||
@@ -91,6 +91,7 @@ function convertToOpenAICompatibleChatMessages(prompt) {
|
||||
}
|
||||
case "assistant": {
|
||||
let text = "";
|
||||
+ let reasoning_text = "";
|
||||
const toolCalls = [];
|
||||
for (const part of content) {
|
||||
const partMetadata = getOpenAIMetadata(part);
|
||||
@@ -99,6 +100,12 @@ function convertToOpenAICompatibleChatMessages(prompt) {
|
||||
text += part.text;
|
||||
break;
|
||||
}
|
||||
+ case "reasoning": {
|
||||
+ if (options.sendReasoning) {
|
||||
+ reasoning_text += part.text;
|
||||
+ }
|
||||
+ break;
|
||||
+ }
|
||||
case "tool-call": {
|
||||
toolCalls.push({
|
||||
id: part.toolCallId,
|
||||
@@ -116,6 +123,7 @@ function convertToOpenAICompatibleChatMessages(prompt) {
|
||||
messages.push({
|
||||
role: "assistant",
|
||||
content: text,
|
||||
+ reasoning_content: reasoning_text ?? undefined,
|
||||
tool_calls: toolCalls.length > 0 ? toolCalls : void 0,
|
||||
...metadata
|
||||
});
|
||||
@@ -200,7 +208,8 @@ var openaiCompatibleProviderOptions = import_v4.z.object({
|
||||
/**
|
||||
* Controls the verbosity of the generated text. Defaults to `medium`.
|
||||
*/
|
||||
- textVerbosity: import_v4.z.string().optional()
|
||||
+ textVerbosity: import_v4.z.string().optional(),
|
||||
+ sendReasoning: import_v4.z.boolean().optional()
|
||||
});
|
||||
|
||||
// src/openai-compatible-error.ts
|
||||
@@ -378,7 +387,7 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
reasoning_effort: compatibleOptions.reasoningEffort,
|
||||
verbosity: compatibleOptions.textVerbosity,
|
||||
// messages:
|
||||
- messages: convertToOpenAICompatibleChatMessages(prompt),
|
||||
+ messages: convertToOpenAICompatibleChatMessages({prompt, options: compatibleOptions}),
|
||||
// tools:
|
||||
tools: openaiTools,
|
||||
tool_choice: openaiToolChoice
|
||||
@@ -421,6 +430,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
text: reasoning
|
||||
});
|
||||
}
|
||||
+ if (choice.message.images) {
|
||||
+ for (const image of choice.message.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ content.push({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (choice.message.tool_calls != null) {
|
||||
for (const toolCall of choice.message.tool_calls) {
|
||||
content.push({
|
||||
@@ -598,6 +618,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
delta: delta.content
|
||||
});
|
||||
}
|
||||
+ if (delta.images) {
|
||||
+ for (const image of delta.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ controller.enqueue({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (delta.tool_calls != null) {
|
||||
for (const toolCallDelta of delta.tool_calls) {
|
||||
const index = toolCallDelta.index;
|
||||
@@ -765,6 +796,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({
|
||||
arguments: import_v43.z.string()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: import_v43.z.array(
|
||||
+ import_v43.z.object({
|
||||
+ type: import_v43.z.literal('image_url'),
|
||||
+ image_url: import_v43.z.object({
|
||||
+ url: import_v43.z.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}),
|
||||
finish_reason: import_v43.z.string().nullish()
|
||||
@@ -795,6 +834,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union(
|
||||
arguments: import_v43.z.string().nullish()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: import_v43.z.array(
|
||||
+ import_v43.z.object({
|
||||
+ type: import_v43.z.literal('image_url'),
|
||||
+ image_url: import_v43.z.object({
|
||||
+ url: import_v43.z.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}).nullish(),
|
||||
finish_reason: import_v43.z.string().nullish()
|
||||
diff --git a/dist/index.mjs b/dist/index.mjs
|
||||
index a809a7aa0e148bfd43e01dd7b018568b151c8ad5..565b605eeacd9830b2b0e817e58ad0c5700264de 100644
|
||||
--- a/dist/index.mjs
|
||||
+++ b/dist/index.mjs
|
||||
@@ -23,7 +23,7 @@ function getOpenAIMetadata(message) {
|
||||
var _a, _b;
|
||||
return (_b = (_a = message == null ? void 0 : message.providerOptions) == null ? void 0 : _a.openaiCompatible) != null ? _b : {};
|
||||
}
|
||||
-function convertToOpenAICompatibleChatMessages(prompt) {
|
||||
+function convertToOpenAICompatibleChatMessages({prompt, options}) {
|
||||
const messages = [];
|
||||
for (const { role, content, ...message } of prompt) {
|
||||
const metadata = getOpenAIMetadata({ ...message });
|
||||
@@ -73,6 +73,7 @@ function convertToOpenAICompatibleChatMessages(prompt) {
|
||||
}
|
||||
case "assistant": {
|
||||
let text = "";
|
||||
+ let reasoning_text = "";
|
||||
const toolCalls = [];
|
||||
for (const part of content) {
|
||||
const partMetadata = getOpenAIMetadata(part);
|
||||
@@ -81,6 +82,12 @@ function convertToOpenAICompatibleChatMessages(prompt) {
|
||||
text += part.text;
|
||||
break;
|
||||
}
|
||||
+ case "reasoning": {
|
||||
+ if (options.sendReasoning) {
|
||||
+ reasoning_text += part.text;
|
||||
+ }
|
||||
+ break;
|
||||
+ }
|
||||
case "tool-call": {
|
||||
toolCalls.push({
|
||||
id: part.toolCallId,
|
||||
@@ -98,6 +105,7 @@ function convertToOpenAICompatibleChatMessages(prompt) {
|
||||
messages.push({
|
||||
role: "assistant",
|
||||
content: text,
|
||||
+ reasoning_content: reasoning_text ?? undefined,
|
||||
tool_calls: toolCalls.length > 0 ? toolCalls : void 0,
|
||||
...metadata
|
||||
});
|
||||
@@ -182,7 +190,8 @@ var openaiCompatibleProviderOptions = z.object({
|
||||
/**
|
||||
* Controls the verbosity of the generated text. Defaults to `medium`.
|
||||
*/
|
||||
- textVerbosity: z.string().optional()
|
||||
+ textVerbosity: z.string().optional(),
|
||||
+ sendReasoning: z.boolean().optional()
|
||||
});
|
||||
|
||||
// src/openai-compatible-error.ts
|
||||
@@ -362,7 +371,7 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
reasoning_effort: compatibleOptions.reasoningEffort,
|
||||
verbosity: compatibleOptions.textVerbosity,
|
||||
// messages:
|
||||
- messages: convertToOpenAICompatibleChatMessages(prompt),
|
||||
+ messages: convertToOpenAICompatibleChatMessages({prompt, options: compatibleOptions}),
|
||||
// tools:
|
||||
tools: openaiTools,
|
||||
tool_choice: openaiToolChoice
|
||||
@@ -405,6 +414,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
text: reasoning
|
||||
});
|
||||
}
|
||||
+ if (choice.message.images) {
|
||||
+ for (const image of choice.message.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ content.push({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (choice.message.tool_calls != null) {
|
||||
for (const toolCall of choice.message.tool_calls) {
|
||||
content.push({
|
||||
@@ -582,6 +602,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
delta: delta.content
|
||||
});
|
||||
}
|
||||
+ if (delta.images) {
|
||||
+ for (const image of delta.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ controller.enqueue({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (delta.tool_calls != null) {
|
||||
for (const toolCallDelta of delta.tool_calls) {
|
||||
const index = toolCallDelta.index;
|
||||
@@ -749,6 +780,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({
|
||||
arguments: z3.string()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: z3.array(
|
||||
+ z3.object({
|
||||
+ type: z3.literal('image_url'),
|
||||
+ image_url: z3.object({
|
||||
+ url: z3.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}),
|
||||
finish_reason: z3.string().nullish()
|
||||
@@ -779,6 +818,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([
|
||||
arguments: z3.string().nullish()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: z3.array(
|
||||
+ z3.object({
|
||||
+ type: z3.literal('image_url'),
|
||||
+ image_url: z3.object({
|
||||
+ url: z3.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}).nullish(),
|
||||
finish_reason: z3.string().nullish()
|
||||
@ -7,7 +7,7 @@ index 8dd9b498050dbecd8dd6b901acf1aa8ca38a49af..ed644349c9d38fe2a66b2fb44214f7c1
|
||||
type OllamaChatModelId = "athene-v2" | "athene-v2:72b" | "aya-expanse" | "aya-expanse:8b" | "aya-expanse:32b" | "codegemma" | "codegemma:2b" | "codegemma:7b" | "codellama" | "codellama:7b" | "codellama:13b" | "codellama:34b" | "codellama:70b" | "codellama:code" | "codellama:python" | "command-r" | "command-r:35b" | "command-r-plus" | "command-r-plus:104b" | "command-r7b" | "command-r7b:7b" | "deepseek-r1" | "deepseek-r1:1.5b" | "deepseek-r1:7b" | "deepseek-r1:8b" | "deepseek-r1:14b" | "deepseek-r1:32b" | "deepseek-r1:70b" | "deepseek-r1:671b" | "deepseek-coder-v2" | "deepseek-coder-v2:16b" | "deepseek-coder-v2:236b" | "deepseek-v3" | "deepseek-v3:671b" | "devstral" | "devstral:24b" | "dolphin3" | "dolphin3:8b" | "exaone3.5" | "exaone3.5:2.4b" | "exaone3.5:7.8b" | "exaone3.5:32b" | "falcon2" | "falcon2:11b" | "falcon3" | "falcon3:1b" | "falcon3:3b" | "falcon3:7b" | "falcon3:10b" | "firefunction-v2" | "firefunction-v2:70b" | "gemma" | "gemma:2b" | "gemma:7b" | "gemma2" | "gemma2:2b" | "gemma2:9b" | "gemma2:27b" | "gemma3" | "gemma3:1b" | "gemma3:4b" | "gemma3:12b" | "gemma3:27b" | "granite3-dense" | "granite3-dense:2b" | "granite3-dense:8b" | "granite3-guardian" | "granite3-guardian:2b" | "granite3-guardian:8b" | "granite3-moe" | "granite3-moe:1b" | "granite3-moe:3b" | "granite3.1-dense" | "granite3.1-dense:2b" | "granite3.1-dense:8b" | "granite3.1-moe" | "granite3.1-moe:1b" | "granite3.1-moe:3b" | "llama2" | "llama2:7b" | "llama2:13b" | "llama2:70b" | "llama3" | "llama3:8b" | "llama3:70b" | "llama3-chatqa" | "llama3-chatqa:8b" | "llama3-chatqa:70b" | "llama3-gradient" | "llama3-gradient:8b" | "llama3-gradient:70b" | "llama3.1" | "llama3.1:8b" | "llama3.1:70b" | "llama3.1:405b" | "llama3.2" | "llama3.2:1b" | "llama3.2:3b" | "llama3.2-vision" | "llama3.2-vision:11b" | "llama3.2-vision:90b" | "llama3.3" | "llama3.3:70b" | "llama4" | "llama4:16x17b" | "llama4:128x17b" | "llama-guard3" | "llama-guard3:1b" | "llama-guard3:8b" | "llava" | "llava:7b" | "llava:13b" | "llava:34b" | "llava-llama3" | "llava-llama3:8b" | "llava-phi3" | "llava-phi3:3.8b" | "marco-o1" | "marco-o1:7b" | "mistral" | "mistral:7b" | "mistral-large" | "mistral-large:123b" | "mistral-nemo" | "mistral-nemo:12b" | "mistral-small" | "mistral-small:22b" | "mixtral" | "mixtral:8x7b" | "mixtral:8x22b" | "moondream" | "moondream:1.8b" | "openhermes" | "openhermes:v2.5" | "nemotron" | "nemotron:70b" | "nemotron-mini" | "nemotron-mini:4b" | "olmo" | "olmo:7b" | "olmo:13b" | "opencoder" | "opencoder:1.5b" | "opencoder:8b" | "phi3" | "phi3:3.8b" | "phi3:14b" | "phi3.5" | "phi3.5:3.8b" | "phi4" | "phi4:14b" | "qwen" | "qwen:7b" | "qwen:14b" | "qwen:32b" | "qwen:72b" | "qwen:110b" | "qwen2" | "qwen2:0.5b" | "qwen2:1.5b" | "qwen2:7b" | "qwen2:72b" | "qwen2.5" | "qwen2.5:0.5b" | "qwen2.5:1.5b" | "qwen2.5:3b" | "qwen2.5:7b" | "qwen2.5:14b" | "qwen2.5:32b" | "qwen2.5:72b" | "qwen2.5-coder" | "qwen2.5-coder:0.5b" | "qwen2.5-coder:1.5b" | "qwen2.5-coder:3b" | "qwen2.5-coder:7b" | "qwen2.5-coder:14b" | "qwen2.5-coder:32b" | "qwen3" | "qwen3:0.6b" | "qwen3:1.7b" | "qwen3:4b" | "qwen3:8b" | "qwen3:14b" | "qwen3:30b" | "qwen3:32b" | "qwen3:235b" | "qwq" | "qwq:32b" | "sailor2" | "sailor2:1b" | "sailor2:8b" | "sailor2:20b" | "shieldgemma" | "shieldgemma:2b" | "shieldgemma:9b" | "shieldgemma:27b" | "smallthinker" | "smallthinker:3b" | "smollm" | "smollm:135m" | "smollm:360m" | "smollm:1.7b" | "tinyllama" | "tinyllama:1.1b" | "tulu3" | "tulu3:8b" | "tulu3:70b" | (string & {});
|
||||
declare const ollamaProviderOptions: z.ZodObject<{
|
||||
- think: z.ZodOptional<z.ZodBoolean>;
|
||||
+ think: z.ZodOptional<z.ZodUnion<[z.ZodBoolean, z.ZodEnum<['low', 'medium', 'high']>]>>;
|
||||
+ think: z.ZodOptional<z.ZodUnion<[z.ZodBoolean, z.ZodLiteral<"low">, z.ZodLiteral<"medium">, z.ZodLiteral<"high">]>>;
|
||||
options: z.ZodOptional<z.ZodObject<{
|
||||
num_ctx: z.ZodOptional<z.ZodNumber>;
|
||||
repeat_last_n: z.ZodOptional<z.ZodNumber>;
|
||||
@ -29,7 +29,7 @@ index 8dd9b498050dbecd8dd6b901acf1aa8ca38a49af..ed644349c9d38fe2a66b2fb44214f7c1
|
||||
|
||||
declare const ollamaCompletionProviderOptions: z.ZodObject<{
|
||||
- think: z.ZodOptional<z.ZodBoolean>;
|
||||
+ think: z.ZodOptional<z.ZodUnion<[z.ZodBoolean, z.ZodEnum<['low', 'medium', 'high']>]>>;
|
||||
+ think: z.ZodOptional<z.ZodUnion<[z.ZodBoolean, z.ZodLiteral<"low">, z.ZodLiteral<"medium">, z.ZodLiteral<"high">]>>;
|
||||
user: z.ZodOptional<z.ZodString>;
|
||||
suffix: z.ZodOptional<z.ZodString>;
|
||||
echo: z.ZodOptional<z.ZodBoolean>;
|
||||
@ -42,7 +42,7 @@ index 35b5142ce8476ce2549ed7c2ec48e7d8c46c90d9..2ef64dc9a4c2be043e6af608241a6a83
|
||||
// src/completion/ollama-completion-language-model.ts
|
||||
var ollamaCompletionProviderOptions = import_v42.z.object({
|
||||
- think: import_v42.z.boolean().optional(),
|
||||
+ think: import_v42.z.union([import_v42.z.boolean(), import_v42.z.enum(['low', 'medium', 'high'])]).optional(),
|
||||
+ think: import_v42.z.union([import_v42.z.boolean(), import_v42.z.literal('low'), import_v42.z.literal('medium'), import_v42.z.literal('high')]).optional(),
|
||||
user: import_v42.z.string().optional(),
|
||||
suffix: import_v42.z.string().optional(),
|
||||
echo: import_v42.z.boolean().optional()
|
||||
@ -64,7 +64,7 @@ index 35b5142ce8476ce2549ed7c2ec48e7d8c46c90d9..2ef64dc9a4c2be043e6af608241a6a83
|
||||
* Only supported by certain models like DeepSeek R1 and Qwen 3.
|
||||
*/
|
||||
- think: import_v44.z.boolean().optional(),
|
||||
+ think: import_v44.z.union([import_v44.z.boolean(), import_v44.z.enum(['low', 'medium', 'high'])]).optional(),
|
||||
+ think: import_v44.z.union([import_v44.z.boolean(), import_v44.z.literal('low'), import_v44.z.literal('medium'), import_v44.z.literal('high')]).optional(),
|
||||
options: import_v44.z.object({
|
||||
num_ctx: import_v44.z.number().optional(),
|
||||
repeat_last_n: import_v44.z.number().optional(),
|
||||
@ -97,7 +97,7 @@ index e2a634a78d80ac9542f2cc4f96cf2291094b10cf..67b23efce3c1cf4f026693d3ff924698
|
||||
// src/completion/ollama-completion-language-model.ts
|
||||
var ollamaCompletionProviderOptions = z2.object({
|
||||
- think: z2.boolean().optional(),
|
||||
+ think: z2.union([z2.boolean(), z2.enum(['low', 'medium', 'high'])]).optional(),
|
||||
+ think: z2.union([z2.boolean(), z2.literal('low'), z2.literal('medium'), z2.literal('high')]).optional(),
|
||||
user: z2.string().optional(),
|
||||
suffix: z2.string().optional(),
|
||||
echo: z2.boolean().optional()
|
||||
@ -119,7 +119,7 @@ index e2a634a78d80ac9542f2cc4f96cf2291094b10cf..67b23efce3c1cf4f026693d3ff924698
|
||||
* Only supported by certain models like DeepSeek R1 and Qwen 3.
|
||||
*/
|
||||
- think: z4.boolean().optional(),
|
||||
+ think: z4.union([z4.boolean(), z4.enum(['low', 'medium', 'high'])]).optional(),
|
||||
+ think: z4.union([z4.boolean(), z4.literal('low'), z4.literal('medium'), z4.literal('high')]).optional(),
|
||||
options: z4.object({
|
||||
num_ctx: z4.number().optional(),
|
||||
repeat_last_n: z4.number().optional(),
|
||||
|
||||
@ -36,7 +36,7 @@ yarn install
|
||||
### ENV
|
||||
|
||||
```bash
|
||||
copy .env.example .env
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
### Start
|
||||
|
||||
129
docs/en/references/fuzzy-search.md
Normal file
129
docs/en/references/fuzzy-search.md
Normal file
@ -0,0 +1,129 @@
|
||||
# Fuzzy Search for File List
|
||||
|
||||
This document describes the fuzzy search implementation for file listing in Cherry Studio.
|
||||
|
||||
## Overview
|
||||
|
||||
The fuzzy search feature allows users to find files by typing partial or approximate file names/paths. It uses a two-tier file filtering strategy (ripgrep glob pre-filtering with greedy substring fallback) combined with subsequence-based scoring for optimal performance and flexibility.
|
||||
|
||||
## Features
|
||||
|
||||
- **Ripgrep Glob Pre-filtering**: Primary filtering using glob patterns for fast native-level filtering
|
||||
- **Greedy Substring Matching**: Fallback file filtering strategy when ripgrep glob pre-filtering returns no results
|
||||
- **Subsequence-based Segment Scoring**: During scoring, path segments gain additional weight when query characters appear in order
|
||||
- **Relevance Scoring**: Results are sorted by a relevance score derived from multiple factors
|
||||
|
||||
## Matching Strategies
|
||||
|
||||
### 1. Ripgrep Glob Pre-filtering (Primary)
|
||||
|
||||
The query is converted to a glob pattern for ripgrep to do initial filtering:
|
||||
|
||||
```
|
||||
Query: "updater"
|
||||
Glob: "*u*p*d*a*t*e*r*"
|
||||
```
|
||||
|
||||
This leverages ripgrep's native performance for the initial file filtering.
|
||||
|
||||
### 2. Greedy Substring Matching (Fallback)
|
||||
|
||||
When the glob pre-filter returns no results, the system falls back to greedy substring matching. This allows more flexible matching:
|
||||
|
||||
```
|
||||
Query: "updatercontroller"
|
||||
File: "packages/update/src/node/updateController.ts"
|
||||
|
||||
Matching process:
|
||||
1. Find "update" (longest match from start)
|
||||
2. Remaining "rcontroller" → find "r" then "controller"
|
||||
3. All parts matched → Success
|
||||
```
|
||||
|
||||
## Scoring Algorithm
|
||||
|
||||
Results are ranked by a relevance score based on named constants defined in `FileStorage.ts`:
|
||||
|
||||
| Constant | Value | Description |
|
||||
|----------|-------|-------------|
|
||||
| `SCORE_FILENAME_STARTS` | 100 | Filename starts with query (highest priority) |
|
||||
| `SCORE_FILENAME_CONTAINS` | 80 | Filename contains exact query substring |
|
||||
| `SCORE_SEGMENT_MATCH` | 60 | Per path segment that matches query |
|
||||
| `SCORE_WORD_BOUNDARY` | 20 | Query matches start of a word |
|
||||
| `SCORE_CONSECUTIVE_CHAR` | 15 | Per consecutive character match |
|
||||
| `PATH_LENGTH_PENALTY_FACTOR` | 4 | Logarithmic penalty for longer paths |
|
||||
|
||||
### Scoring Strategy
|
||||
|
||||
The scoring prioritizes:
|
||||
1. **Filename matches** (highest): Files where the query appears in the filename are most relevant
|
||||
2. **Path segment matches**: Multiple matching segments indicate stronger relevance
|
||||
3. **Word boundaries**: Matching at word starts (e.g., "upd" matching "update") is preferred
|
||||
4. **Consecutive matches**: Longer consecutive character sequences score higher
|
||||
5. **Path length**: Shorter paths are preferred (logarithmic penalty prevents long paths from dominating)
|
||||
|
||||
### Example Scoring
|
||||
|
||||
For query `updater`:
|
||||
|
||||
| File | Score Factors |
|
||||
|------|---------------|
|
||||
| `RCUpdater.js` | Short path + filename contains "updater" |
|
||||
| `updateController.ts` | Multiple segment matches |
|
||||
| `UpdaterHelper.plist` | Long path penalty |
|
||||
|
||||
## Configuration
|
||||
|
||||
### DirectoryListOptions
|
||||
|
||||
```typescript
|
||||
interface DirectoryListOptions {
|
||||
recursive?: boolean // Default: true
|
||||
maxDepth?: number // Default: 10
|
||||
includeHidden?: boolean // Default: false
|
||||
includeFiles?: boolean // Default: true
|
||||
includeDirectories?: boolean // Default: true
|
||||
maxEntries?: number // Default: 20
|
||||
searchPattern?: string // Default: '.'
|
||||
fuzzy?: boolean // Default: true
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```typescript
|
||||
// Basic fuzzy search
|
||||
const files = await window.api.file.listDirectory(dirPath, {
|
||||
searchPattern: 'updater',
|
||||
fuzzy: true,
|
||||
maxEntries: 20
|
||||
})
|
||||
|
||||
// Disable fuzzy search (exact glob matching)
|
||||
const files = await window.api.file.listDirectory(dirPath, {
|
||||
searchPattern: 'update',
|
||||
fuzzy: false
|
||||
})
|
||||
```
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
1. **Ripgrep Pre-filtering**: Most queries are handled by ripgrep's native glob matching, which is extremely fast
|
||||
2. **Fallback Only When Needed**: Greedy substring matching (which loads all files) only runs when glob matching returns empty results
|
||||
3. **Result Limiting**: Only top 20 results are returned by default
|
||||
4. **Excluded Directories**: Common large directories are automatically excluded:
|
||||
- `node_modules`
|
||||
- `.git`
|
||||
- `dist`, `build`
|
||||
- `.next`, `.nuxt`
|
||||
- `coverage`, `.cache`
|
||||
|
||||
## Implementation Details
|
||||
|
||||
The implementation is located in `src/main/services/FileStorage.ts`:
|
||||
|
||||
- `queryToGlobPattern()`: Converts query to ripgrep glob pattern
|
||||
- `isFuzzyMatch()`: Subsequence matching algorithm
|
||||
- `isGreedySubstringMatch()`: Greedy substring matching fallback
|
||||
- `getFuzzyMatchScore()`: Calculates relevance score
|
||||
- `listDirectoryWithRipgrep()`: Main search orchestration
|
||||
@ -36,7 +36,7 @@ yarn install
|
||||
### ENV
|
||||
|
||||
```bash
|
||||
copy .env.example .env
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
### Start
|
||||
|
||||
129
docs/zh/references/fuzzy-search.md
Normal file
129
docs/zh/references/fuzzy-search.md
Normal file
@ -0,0 +1,129 @@
|
||||
# 文件列表模糊搜索
|
||||
|
||||
本文档描述了 Cherry Studio 中文件列表的模糊搜索实现。
|
||||
|
||||
## 概述
|
||||
|
||||
模糊搜索功能允许用户通过输入部分或近似的文件名/路径来查找文件。它使用两层文件过滤策略(ripgrep glob 预过滤 + 贪婪子串匹配回退),结合基于子序列的评分,以获得最佳性能和灵活性。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **Ripgrep Glob 预过滤**:使用 glob 模式进行快速原生级过滤的主要过滤策略
|
||||
- **贪婪子串匹配**:当 ripgrep glob 预过滤无结果时的回退文件过滤策略
|
||||
- **基于子序列的段评分**:评分时,当查询字符按顺序出现时,路径段获得额外权重
|
||||
- **相关性评分**:结果按多因素相关性分数排序
|
||||
|
||||
## 匹配策略
|
||||
|
||||
### 1. Ripgrep Glob 预过滤(主要)
|
||||
|
||||
查询被转换为 glob 模式供 ripgrep 进行初始过滤:
|
||||
|
||||
```
|
||||
查询: "updater"
|
||||
Glob: "*u*p*d*a*t*e*r*"
|
||||
```
|
||||
|
||||
这利用了 ripgrep 的原生性能进行初始文件过滤。
|
||||
|
||||
### 2. 贪婪子串匹配(回退)
|
||||
|
||||
当 glob 预过滤无结果时,系统回退到贪婪子串匹配。这允许更灵活的匹配:
|
||||
|
||||
```
|
||||
查询: "updatercontroller"
|
||||
文件: "packages/update/src/node/updateController.ts"
|
||||
|
||||
匹配过程:
|
||||
1. 找到 "update"(从开头的最长匹配)
|
||||
2. 剩余 "rcontroller" → 找到 "r" 然后 "controller"
|
||||
3. 所有部分都匹配 → 成功
|
||||
```
|
||||
|
||||
## 评分算法
|
||||
|
||||
结果根据 `FileStorage.ts` 中定义的命名常量进行相关性分数排名:
|
||||
|
||||
| 常量 | 值 | 描述 |
|
||||
|------|-----|------|
|
||||
| `SCORE_FILENAME_STARTS` | 100 | 文件名以查询开头(最高优先级)|
|
||||
| `SCORE_FILENAME_CONTAINS` | 80 | 文件名包含精确查询子串 |
|
||||
| `SCORE_SEGMENT_MATCH` | 60 | 每个匹配查询的路径段 |
|
||||
| `SCORE_WORD_BOUNDARY` | 20 | 查询匹配单词开头 |
|
||||
| `SCORE_CONSECUTIVE_CHAR` | 15 | 每个连续字符匹配 |
|
||||
| `PATH_LENGTH_PENALTY_FACTOR` | 4 | 较长路径的对数惩罚 |
|
||||
|
||||
### 评分策略
|
||||
|
||||
评分优先级:
|
||||
1. **文件名匹配**(最高):查询出现在文件名中的文件最相关
|
||||
2. **路径段匹配**:多个匹配段表示更强的相关性
|
||||
3. **词边界**:在单词开头匹配(如 "upd" 匹配 "update")更优先
|
||||
4. **连续匹配**:更长的连续字符序列得分更高
|
||||
5. **路径长度**:较短路径更优先(对数惩罚防止长路径主导评分)
|
||||
|
||||
### 评分示例
|
||||
|
||||
对于查询 `updater`:
|
||||
|
||||
| 文件 | 评分因素 |
|
||||
|------|----------|
|
||||
| `RCUpdater.js` | 短路径 + 文件名包含 "updater" |
|
||||
| `updateController.ts` | 多个路径段匹配 |
|
||||
| `UpdaterHelper.plist` | 长路径惩罚 |
|
||||
|
||||
## 配置
|
||||
|
||||
### DirectoryListOptions
|
||||
|
||||
```typescript
|
||||
interface DirectoryListOptions {
|
||||
recursive?: boolean // 默认: true
|
||||
maxDepth?: number // 默认: 10
|
||||
includeHidden?: boolean // 默认: false
|
||||
includeFiles?: boolean // 默认: true
|
||||
includeDirectories?: boolean // 默认: true
|
||||
maxEntries?: number // 默认: 20
|
||||
searchPattern?: string // 默认: '.'
|
||||
fuzzy?: boolean // 默认: true
|
||||
}
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
```typescript
|
||||
// 基本模糊搜索
|
||||
const files = await window.api.file.listDirectory(dirPath, {
|
||||
searchPattern: 'updater',
|
||||
fuzzy: true,
|
||||
maxEntries: 20
|
||||
})
|
||||
|
||||
// 禁用模糊搜索(精确 glob 匹配)
|
||||
const files = await window.api.file.listDirectory(dirPath, {
|
||||
searchPattern: 'update',
|
||||
fuzzy: false
|
||||
})
|
||||
```
|
||||
|
||||
## 性能考虑
|
||||
|
||||
1. **Ripgrep 预过滤**:大多数查询由 ripgrep 的原生 glob 匹配处理,速度极快
|
||||
2. **仅在需要时回退**:贪婪子串匹配(加载所有文件)仅在 glob 匹配返回空结果时运行
|
||||
3. **结果限制**:默认只返回前 20 个结果
|
||||
4. **排除目录**:自动排除常见的大型目录:
|
||||
- `node_modules`
|
||||
- `.git`
|
||||
- `dist`、`build`
|
||||
- `.next`、`.nuxt`
|
||||
- `coverage`、`.cache`
|
||||
|
||||
## 实现细节
|
||||
|
||||
实现位于 `src/main/services/FileStorage.ts`:
|
||||
|
||||
- `queryToGlobPattern()`:将查询转换为 ripgrep glob 模式
|
||||
- `isFuzzyMatch()`:子序列匹配算法
|
||||
- `isGreedySubstringMatch()`:贪婪子串匹配回退
|
||||
- `getFuzzyMatchScore()`:计算相关性分数
|
||||
- `listDirectoryWithRipgrep()`:主搜索协调
|
||||
850
docs/zh/references/lan-transfer-protocol.md
Normal file
850
docs/zh/references/lan-transfer-protocol.md
Normal file
@ -0,0 +1,850 @@
|
||||
# Cherry Studio 局域网传输协议规范
|
||||
|
||||
> 版本: 1.0
|
||||
> 最后更新: 2025-12
|
||||
|
||||
本文档定义了 Cherry Studio 桌面客户端(Electron)与移动端(Expo)之间的局域网文件传输协议。
|
||||
|
||||
---
|
||||
|
||||
## 目录
|
||||
|
||||
1. [协议概述](#1-协议概述)
|
||||
2. [服务发现(Bonjour/mDNS)](#2-服务发现bonjourmdns)
|
||||
3. [TCP 连接与握手](#3-tcp-连接与握手)
|
||||
4. [消息格式规范](#4-消息格式规范)
|
||||
5. [文件传输协议](#5-文件传输协议)
|
||||
6. [心跳与连接保活](#6-心跳与连接保活)
|
||||
7. [错误处理](#7-错误处理)
|
||||
8. [常量与配置](#8-常量与配置)
|
||||
9. [完整时序图](#9-完整时序图)
|
||||
10. [移动端实现指南](#10-移动端实现指南)
|
||||
|
||||
---
|
||||
|
||||
## 1. 协议概述
|
||||
|
||||
### 1.1 架构角色
|
||||
|
||||
| 角色 | 平台 | 职责 |
|
||||
| -------------------- | --------------- | ---------------------------- |
|
||||
| **Client(客户端)** | Electron 桌面端 | 扫描服务、发起连接、发送文件 |
|
||||
| **Server(服务端)** | Expo 移动端 | 发布服务、接受连接、接收文件 |
|
||||
|
||||
### 1.2 协议栈(v1)
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ 应用层(文件传输) │
|
||||
├─────────────────────────────────────┤
|
||||
│ 消息层(控制: JSON \n) │
|
||||
│ (数据: 二进制帧) │
|
||||
├─────────────────────────────────────┤
|
||||
│ 传输层(TCP) │
|
||||
├─────────────────────────────────────┤
|
||||
│ 发现层(Bonjour/mDNS) │
|
||||
└─────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 1.3 通信流程概览
|
||||
|
||||
```
|
||||
1. 服务发现 → 移动端发布 mDNS 服务,桌面端扫描发现
|
||||
2. TCP 握手 → 建立连接,交换设备信息(`version=1`)
|
||||
3. 文件传输 → 控制消息使用 JSON,`file_chunk` 使用二进制帧分块传输
|
||||
4. 连接保活 → ping/pong 心跳
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 服务发现(Bonjour/mDNS)
|
||||
|
||||
### 2.1 服务类型
|
||||
|
||||
| 属性 | 值 |
|
||||
| ------------ | -------------------- |
|
||||
| 服务类型 | `cherrystudio` |
|
||||
| 协议 | `tcp` |
|
||||
| 完整服务标识 | `_cherrystudio._tcp` |
|
||||
|
||||
### 2.2 服务发布(移动端)
|
||||
|
||||
移动端需要通过 mDNS/Bonjour 发布服务:
|
||||
|
||||
```typescript
|
||||
// 服务发布参数
|
||||
{
|
||||
name: "Cherry Studio Mobile", // 设备名称
|
||||
type: "cherrystudio", // 服务类型
|
||||
protocol: "tcp", // 协议
|
||||
port: 53317, // TCP 监听端口
|
||||
txt: { // TXT 记录(可选)
|
||||
version: "1",
|
||||
platform: "ios" // 或 "android"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2.3 服务发现(桌面端)
|
||||
|
||||
桌面端扫描并解析服务信息:
|
||||
|
||||
```typescript
|
||||
// 发现的服务信息结构
|
||||
type LocalTransferPeer = {
|
||||
id: string; // 唯一标识符
|
||||
name: string; // 设备名称
|
||||
host?: string; // 主机名
|
||||
fqdn?: string; // 完全限定域名
|
||||
port?: number; // TCP 端口
|
||||
type?: string; // 服务类型
|
||||
protocol?: "tcp" | "udp"; // 协议
|
||||
addresses: string[]; // IP 地址列表
|
||||
txt?: Record<string, string>; // TXT 记录
|
||||
updatedAt: number; // 发现时间戳
|
||||
};
|
||||
```
|
||||
|
||||
### 2.4 IP 地址选择策略
|
||||
|
||||
当服务有多个 IP 地址时,优先选择 IPv4:
|
||||
|
||||
```typescript
|
||||
// 优先选择 IPv4 地址
|
||||
const preferredAddress = addresses.find((addr) => isIPv4(addr)) || addresses[0];
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. TCP 连接与握手
|
||||
|
||||
### 3.1 连接建立
|
||||
|
||||
1. 客户端使用发现的 `host:port` 建立 TCP 连接
|
||||
2. 连接成功后立即发送握手消息
|
||||
3. 等待服务端响应握手确认
|
||||
|
||||
### 3.2 握手消息(协议版本 v1)
|
||||
|
||||
#### Client → Server: `handshake`
|
||||
|
||||
```typescript
|
||||
type LanTransferHandshakeMessage = {
|
||||
type: "handshake";
|
||||
deviceName: string; // 设备名称
|
||||
version: string; // 协议版本,当前为 "1"
|
||||
platform?: string; // 平台:'darwin' | 'win32' | 'linux'
|
||||
appVersion?: string; // 应用版本
|
||||
};
|
||||
```
|
||||
|
||||
**示例:**
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "handshake",
|
||||
"deviceName": "Cherry Studio 1.7.2",
|
||||
"version": "1",
|
||||
"platform": "darwin",
|
||||
"appVersion": "1.7.2"
|
||||
}
|
||||
```
|
||||
|
||||
### 4. 消息格式规范(混合协议)
|
||||
|
||||
v1 使用"控制 JSON + 二进制数据帧"的混合协议(流式传输模式,无 per-chunk ACK):
|
||||
|
||||
- **控制消息**(握手、心跳、file_start/ack、file_end、file_complete):UTF-8 JSON,`\n` 分隔
|
||||
- **数据消息**(`file_chunk`):二进制帧,使用 Magic + 总长度做分帧,不经 Base64
|
||||
|
||||
### 4.1 控制消息编码(JSON + `\n`)
|
||||
|
||||
| 属性 | 规范 |
|
||||
| ---------- | ------------ |
|
||||
| 编码格式 | UTF-8 |
|
||||
| 序列化格式 | JSON |
|
||||
| 消息分隔符 | `\n`(0x0A) |
|
||||
|
||||
```typescript
|
||||
function sendControlMessage(socket: Socket, message: object): void {
|
||||
socket.write(`${JSON.stringify(message)}\n`);
|
||||
}
|
||||
```
|
||||
|
||||
### 4.2 `file_chunk` 二进制帧格式
|
||||
|
||||
为解决 TCP 分包/粘包并消除 Base64 开销,`file_chunk` 采用带总长度的二进制帧:
|
||||
|
||||
```
|
||||
┌──────────┬──────────┬────────┬───────────────┬──────────────┬────────────┬───────────┐
|
||||
│ Magic │ TotalLen │ Type │ TransferId Len│ TransferId │ ChunkIdx │ Data │
|
||||
│ 0x43 0x53│ (4B BE) │ 0x01 │ (2B BE) │ (UTF-8) │ (4B BE) │ (raw) │
|
||||
└──────────┴──────────┴────────┴───────────────┴──────────────┴────────────┴───────────┘
|
||||
```
|
||||
|
||||
| 字段 | 大小 | 说明 |
|
||||
| -------------- | ---- | ------------------------------------------- |
|
||||
| Magic | 2B | 常量 `0x43 0x53` ("CS"), 用于区分 JSON 消息 |
|
||||
| TotalLen | 4B | Big-endian,帧总长度(不含 Magic/TotalLen) |
|
||||
| Type | 1B | `0x01` 代表 `file_chunk` |
|
||||
| TransferId Len | 2B | Big-endian,transferId 字符串长度 |
|
||||
| TransferId | nB | UTF-8 transferId(长度由上一字段给出) |
|
||||
| ChunkIdx | 4B | Big-endian,块索引,从 0 开始 |
|
||||
| Data | mB | 原始文件二进制数据(未编码) |
|
||||
|
||||
> 计算帧总长度:`TotalLen = 1 + 2 + transferIdLen + 4 + dataLen`(即 Type~Data 的长度和)。
|
||||
|
||||
### 4.3 消息解析策略
|
||||
|
||||
1. 读取 socket 数据到缓冲区;
|
||||
2. 若前两字节为 `0x43 0x53` → 按二进制帧解析:
|
||||
- 至少需要 6 字节头(Magic + TotalLen),不足则等待更多数据
|
||||
- 读取 `TotalLen` 判断帧整体长度,缓冲区不足则继续等待
|
||||
- 解析 Type/TransferId/ChunkIdx/Data,并传入文件接收逻辑
|
||||
3. 否则若首字节为 `{` → 按 JSON + `\n` 解析控制消息
|
||||
4. 其它数据丢弃 1 字节并继续循环,避免阻塞。
|
||||
|
||||
### 4.4 消息类型汇总(v1)
|
||||
|
||||
| 类型 | 方向 | 编码 | 用途 |
|
||||
| ---------------- | --------------- | -------- | ----------------------- |
|
||||
| `handshake` | Client → Server | JSON+\n | 握手请求(version=1) |
|
||||
| `handshake_ack` | Server → Client | JSON+\n | 握手响应 |
|
||||
| `ping` | Client → Server | JSON+\n | 心跳请求 |
|
||||
| `pong` | Server → Client | JSON+\n | 心跳响应 |
|
||||
| `file_start` | Client → Server | JSON+\n | 开始文件传输 |
|
||||
| `file_start_ack` | Server → Client | JSON+\n | 文件传输确认 |
|
||||
| `file_chunk` | Client → Server | 二进制帧 | 文件数据块(无 Base64,流式无 per-chunk ACK) |
|
||||
| `file_end` | Client → Server | JSON+\n | 文件传输结束 |
|
||||
| `file_complete` | Server → Client | JSON+\n | 传输完成结果 |
|
||||
|
||||
```
|
||||
{"type":"message_type",...其他字段...}\n
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. 文件传输协议
|
||||
|
||||
### 5.1 传输流程
|
||||
|
||||
```
|
||||
Client (Sender) Server (Receiver)
|
||||
| |
|
||||
|──── 1. file_start ────────────────>|
|
||||
| (文件元数据) |
|
||||
| |
|
||||
|<─── 2. file_start_ack ─────────────|
|
||||
| (接受/拒绝) |
|
||||
| |
|
||||
|══════ 循环发送数据块(流式,无 ACK) ═════|
|
||||
| |
|
||||
|──── 3. file_chunk [0] ────────────>|
|
||||
| |
|
||||
|──── 3. file_chunk [1] ────────────>|
|
||||
| |
|
||||
| ... 重复直到所有块发送完成 ... |
|
||||
| |
|
||||
|══════════════════════════════════════
|
||||
| |
|
||||
|──── 5. file_end ──────────────────>|
|
||||
| (所有块已发送) |
|
||||
| |
|
||||
|<─── 6. file_complete ──────────────|
|
||||
| (最终结果) |
|
||||
```
|
||||
|
||||
### 5.2 消息定义
|
||||
|
||||
#### 5.2.1 `file_start` - 开始传输
|
||||
|
||||
**方向:** Client → Server
|
||||
|
||||
```typescript
|
||||
type LanTransferFileStartMessage = {
|
||||
type: "file_start";
|
||||
transferId: string; // UUID,唯一传输标识
|
||||
fileName: string; // 文件名(含扩展名)
|
||||
fileSize: number; // 文件总字节数
|
||||
mimeType: string; // MIME 类型
|
||||
checksum: string; // 整个文件的 SHA-256 哈希(hex)
|
||||
totalChunks: number; // 总数据块数
|
||||
chunkSize: number; // 每块大小(字节)
|
||||
};
|
||||
```
|
||||
|
||||
**示例:**
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "file_start",
|
||||
"transferId": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"fileName": "backup.zip",
|
||||
"fileSize": 524288000,
|
||||
"mimeType": "application/zip",
|
||||
"checksum": "a1b2c3d4e5f6789012345678901234567890abcdef1234567890abcdef123456",
|
||||
"totalChunks": 8192,
|
||||
"chunkSize": 65536
|
||||
}
|
||||
```
|
||||
|
||||
#### 5.2.2 `file_start_ack` - 传输确认
|
||||
|
||||
**方向:** Server → Client
|
||||
|
||||
```typescript
|
||||
type LanTransferFileStartAckMessage = {
|
||||
type: "file_start_ack";
|
||||
transferId: string; // 对应的传输 ID
|
||||
accepted: boolean; // 是否接受传输
|
||||
message?: string; // 拒绝原因
|
||||
};
|
||||
```
|
||||
|
||||
**接受示例:**
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "file_start_ack",
|
||||
"transferId": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"accepted": true
|
||||
}
|
||||
```
|
||||
|
||||
**拒绝示例:**
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "file_start_ack",
|
||||
"transferId": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"accepted": false,
|
||||
"message": "Insufficient storage space"
|
||||
}
|
||||
```
|
||||
|
||||
#### 5.2.3 `file_chunk` - 数据块
|
||||
|
||||
**方向:** Client → Server(**二进制帧**,见 4.2)
|
||||
|
||||
- 不再使用 JSON/`\n`,也不再使用 Base64
|
||||
- 帧结构:`Magic` + `TotalLen` + `Type` + `TransferId` + `ChunkIdx` + `Data`
|
||||
- `Type` 固定 `0x01`,`Data` 为原始文件二进制数据
|
||||
- 传输完整性依赖 `file_start.checksum`(全文件 SHA-256);分块校验和可选,不在帧中发送
|
||||
|
||||
#### 5.2.4 `file_chunk_ack` - 数据块确认(v1 流式不使用)
|
||||
|
||||
v1 采用流式传输,不发送 per-chunk ACK。本节类型仅保留作为向后兼容参考,实际不会发送。
|
||||
|
||||
#### 5.2.5 `file_end` - 传输结束
|
||||
|
||||
**方向:** Client → Server
|
||||
|
||||
```typescript
|
||||
type LanTransferFileEndMessage = {
|
||||
type: "file_end";
|
||||
transferId: string; // 传输 ID
|
||||
};
|
||||
```
|
||||
|
||||
**示例:**
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "file_end",
|
||||
"transferId": "550e8400-e29b-41d4-a716-446655440000"
|
||||
}
|
||||
```
|
||||
|
||||
#### 5.2.6 `file_complete` - 传输完成
|
||||
|
||||
**方向:** Server → Client
|
||||
|
||||
```typescript
|
||||
type LanTransferFileCompleteMessage = {
|
||||
type: "file_complete";
|
||||
transferId: string; // 传输 ID
|
||||
success: boolean; // 是否成功
|
||||
filePath?: string; // 保存路径(成功时)
|
||||
error?: string; // 错误信息(失败时)
|
||||
};
|
||||
```
|
||||
|
||||
**成功示例:**
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "file_complete",
|
||||
"transferId": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"success": true,
|
||||
"filePath": "/storage/emulated/0/Documents/backup.zip"
|
||||
}
|
||||
```
|
||||
|
||||
**失败示例:**
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "file_complete",
|
||||
"transferId": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"success": false,
|
||||
"error": "File checksum verification failed"
|
||||
}
|
||||
```
|
||||
|
||||
### 5.3 校验和算法
|
||||
|
||||
#### 整个文件校验和(保持不变)
|
||||
|
||||
```typescript
|
||||
async function calculateFileChecksum(filePath: string): Promise<string> {
|
||||
const hash = crypto.createHash("sha256");
|
||||
const stream = fs.createReadStream(filePath);
|
||||
|
||||
for await (const chunk of stream) {
|
||||
hash.update(chunk);
|
||||
}
|
||||
|
||||
return hash.digest("hex");
|
||||
}
|
||||
```
|
||||
|
||||
#### 数据块校验和
|
||||
|
||||
v1 默认 **不传输分块校验和**,依赖最终文件 checksum。若需要,可在应用层自定义(非协议字段)。
|
||||
|
||||
### 5.4 校验流程
|
||||
|
||||
**发送端(Client):**
|
||||
|
||||
1. 发送前计算整个文件的 SHA-256 → `file_start.checksum`
|
||||
2. 分块直接发送原始二进制(无 Base64)
|
||||
|
||||
**接收端(Server):**
|
||||
|
||||
1. 收到 `file_chunk` 后直接使用二进制数据
|
||||
2. 边收边落盘并增量计算 SHA-256(推荐)
|
||||
3. 所有块接收完成后,计算/完成增量哈希,得到最终 SHA-256
|
||||
4. 与 `file_start.checksum` 比对,结果写入 `file_complete`
|
||||
|
||||
### 5.5 数据块大小计算
|
||||
|
||||
```typescript
|
||||
const CHUNK_SIZE = 512 * 1024; // 512KB
|
||||
|
||||
const totalChunks = Math.ceil(fileSize / CHUNK_SIZE);
|
||||
|
||||
// 最后一个块可能小于 CHUNK_SIZE
|
||||
const lastChunkSize = fileSize % CHUNK_SIZE || CHUNK_SIZE;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 心跳与连接保活
|
||||
|
||||
### 6.1 心跳消息
|
||||
|
||||
#### `ping`
|
||||
|
||||
**方向:** Client → Server
|
||||
|
||||
```typescript
|
||||
type LanTransferPingMessage = {
|
||||
type: "ping";
|
||||
payload?: string; // 可选载荷
|
||||
};
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "ping",
|
||||
"payload": "heartbeat"
|
||||
}
|
||||
```
|
||||
|
||||
#### `pong`
|
||||
|
||||
**方向:** Server → Client
|
||||
|
||||
```typescript
|
||||
type LanTransferPongMessage = {
|
||||
type: "pong";
|
||||
received: boolean; // 确认收到
|
||||
payload?: string; // 回传 ping 的载荷
|
||||
};
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "pong",
|
||||
"received": true,
|
||||
"payload": "heartbeat"
|
||||
}
|
||||
```
|
||||
|
||||
### 6.2 心跳策略
|
||||
|
||||
- 握手成功后立即发送一次 `ping` 验证连接
|
||||
- 可选:定期发送心跳保持连接活跃
|
||||
- `pong` 应返回 `ping` 中的 `payload`(可选)
|
||||
|
||||
---
|
||||
|
||||
## 7. 错误处理
|
||||
|
||||
### 7.1 超时配置
|
||||
|
||||
| 操作 | 超时时间 | 说明 |
|
||||
| ---------- | -------- | --------------------- |
|
||||
| TCP 连接 | 10 秒 | 连接建立超时 |
|
||||
| 握手等待 | 10 秒 | 等待 `handshake_ack` |
|
||||
| 传输完成 | 60 秒 | 等待 `file_complete` |
|
||||
|
||||
### 7.2 错误场景处理
|
||||
|
||||
| 场景 | Client 处理 | Server 处理 |
|
||||
| --------------- | ------------------ | ---------------------- |
|
||||
| TCP 连接失败 | 通知 UI,允许重试 | - |
|
||||
| 握手超时 | 断开连接,通知 UI | 关闭 socket |
|
||||
| 握手被拒绝 | 显示拒绝原因 | - |
|
||||
| 数据块处理失败 | 中止传输,清理状态 | 清理临时文件 |
|
||||
| 连接意外断开 | 清理状态,通知 UI | 清理临时文件 |
|
||||
| 存储空间不足 | - | 发送 `accepted: false` |
|
||||
|
||||
### 7.3 资源清理
|
||||
|
||||
**Client 端:**
|
||||
|
||||
```typescript
|
||||
function cleanup(): void {
|
||||
// 1. 销毁文件读取流
|
||||
if (readStream) {
|
||||
readStream.destroy();
|
||||
}
|
||||
// 2. 清理传输状态
|
||||
activeTransfer = undefined;
|
||||
// 3. 关闭 socket(如需要)
|
||||
socket?.destroy();
|
||||
}
|
||||
```
|
||||
|
||||
**Server 端:**
|
||||
|
||||
```typescript
|
||||
function cleanup(): void {
|
||||
// 1. 关闭文件写入流
|
||||
if (writeStream) {
|
||||
writeStream.end();
|
||||
}
|
||||
// 2. 删除未完成的临时文件
|
||||
if (tempFilePath) {
|
||||
fs.unlinkSync(tempFilePath);
|
||||
}
|
||||
// 3. 清理传输状态
|
||||
activeTransfer = undefined;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. 常量与配置
|
||||
|
||||
### 8.1 协议常量
|
||||
|
||||
```typescript
|
||||
// 协议版本(v1 = 控制 JSON + 二进制 chunk + 流式传输)
|
||||
export const LAN_TRANSFER_PROTOCOL_VERSION = "1";
|
||||
|
||||
// 服务发现
|
||||
export const LAN_TRANSFER_SERVICE_TYPE = "cherrystudio";
|
||||
export const LAN_TRANSFER_SERVICE_FULL_NAME = "_cherrystudio._tcp";
|
||||
|
||||
// TCP 端口
|
||||
export const LAN_TRANSFER_TCP_PORT = 53317;
|
||||
|
||||
// 文件传输(与二进制帧一致)
|
||||
export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024; // 512KB
|
||||
export const LAN_TRANSFER_GLOBAL_TIMEOUT_MS = 10 * 60 * 1000; // 10 分钟
|
||||
|
||||
// 超时设置
|
||||
export const LAN_TRANSFER_HANDSHAKE_TIMEOUT_MS = 10_000; // 10秒
|
||||
export const LAN_TRANSFER_CHUNK_TIMEOUT_MS = 30_000; // 30秒
|
||||
export const LAN_TRANSFER_COMPLETE_TIMEOUT_MS = 60_000; // 60秒
|
||||
```
|
||||
|
||||
### 8.2 支持的文件类型
|
||||
|
||||
当前仅支持 ZIP 文件:
|
||||
|
||||
```typescript
|
||||
export const LAN_TRANSFER_ALLOWED_EXTENSIONS = [".zip"];
|
||||
export const LAN_TRANSFER_ALLOWED_MIME_TYPES = [
|
||||
"application/zip",
|
||||
"application/x-zip-compressed",
|
||||
];
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9. 完整时序图
|
||||
|
||||
### 9.1 完整传输流程(v1,流式传输)
|
||||
|
||||
```
|
||||
┌─────────┐ ┌─────────┐ ┌─────────┐
|
||||
│ Renderer│ │ Main │ │ Mobile │
|
||||
│ (UI) │ │ Process │ │ Server │
|
||||
└────┬────┘ └────┬────┘ └────┬────┘
|
||||
│ │ │
|
||||
│ ════════════ 服务发现阶段 ════════════ │
|
||||
│ │ │
|
||||
│ startScan() │ │
|
||||
│────────────────────────────────────>│ │
|
||||
│ │ mDNS browse │
|
||||
│ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─>│
|
||||
│ │ │
|
||||
│ │<─ ─ ─ service discovered ─ ─ ─ ─ ─ ─│
|
||||
│ │ │
|
||||
│<────── onServicesUpdated ───────────│ │
|
||||
│ │ │
|
||||
│ ════════════ 握手连接阶段 ════════════ │
|
||||
│ │ │
|
||||
│ connect(peer) │ │
|
||||
│────────────────────────────────────>│ │
|
||||
│ │──────── TCP Connect ───────────────>│
|
||||
│ │ │
|
||||
│ │──────── handshake ─────────────────>│
|
||||
│ │ │
|
||||
│ │<─────── handshake_ack ──────────────│
|
||||
│ │ │
|
||||
│ │──────── ping ──────────────────────>│
|
||||
│ │<─────── pong ───────────────────────│
|
||||
│ │ │
|
||||
│<────── connect result ──────────────│ │
|
||||
│ │ │
|
||||
│ ════════════ 文件传输阶段 ════════════ │
|
||||
│ │ │
|
||||
│ sendFile(path) │ │
|
||||
│────────────────────────────────────>│ │
|
||||
│ │──────── file_start ────────────────>│
|
||||
│ │ │
|
||||
│ │<─────── file_start_ack ─────────────│
|
||||
│ │ │
|
||||
│ │ │
|
||||
│ │══════ 循环发送数据块 ═══════════════│
|
||||
│ │ │
|
||||
│ │──────── file_chunk[0] (binary) ────>│
|
||||
│<────── progress event ──────────────│ │
|
||||
│ │ │
|
||||
│ │──────── file_chunk[1] (binary) ────>│
|
||||
│<────── progress event ──────────────│ │
|
||||
│ │ │
|
||||
│ │ ... 重复 ... │
|
||||
│ │ │
|
||||
│ │══════════════════════════════════════│
|
||||
│ │ │
|
||||
│ │──────── file_end ──────────────────>│
|
||||
│ │ │
|
||||
│ │<─────── file_complete ──────────────│
|
||||
│ │ │
|
||||
│<────── complete event ──────────────│ │
|
||||
│<────── sendFile result ─────────────│ │
|
||||
│ │ │
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. 移动端实现指南(v1 要点)
|
||||
|
||||
### 10.1 必须实现的功能
|
||||
|
||||
1. **mDNS 服务发布**
|
||||
|
||||
- 发布 `_cherrystudio._tcp` 服务
|
||||
- 提供 TCP 端口号 `53317`
|
||||
- 可选:TXT 记录(版本、平台信息)
|
||||
|
||||
2. **TCP 服务端**
|
||||
|
||||
- 监听指定端口
|
||||
- 支持单连接或多连接
|
||||
|
||||
3. **消息解析**
|
||||
|
||||
- 控制消息:UTF-8 + `\n` JSON
|
||||
- 数据消息:二进制帧(Magic+TotalLen 分帧)
|
||||
|
||||
4. **握手处理**
|
||||
|
||||
- 验证 `handshake` 消息
|
||||
- 发送 `handshake_ack` 响应
|
||||
- 响应 `ping` 消息
|
||||
|
||||
5. **文件接收(流式模式)**
|
||||
- 解析 `file_start`,准备接收
|
||||
- 接收 `file_chunk` 二进制帧,直接写入文件/缓冲并增量哈希
|
||||
- v1 不发送 per-chunk ACK(流式传输)
|
||||
- 处理 `file_end`,完成增量哈希并校验 checksum
|
||||
- 发送 `file_complete` 结果
|
||||
|
||||
### 10.2 推荐的库
|
||||
|
||||
**React Native / Expo:**
|
||||
|
||||
- mDNS: `react-native-zeroconf` 或 `@homielab/react-native-bonjour`
|
||||
- TCP: `react-native-tcp-socket`
|
||||
- Crypto: `expo-crypto` 或 `react-native-quick-crypto`
|
||||
|
||||
### 10.3 接收端伪代码
|
||||
|
||||
```typescript
|
||||
class FileReceiver {
|
||||
private transfer?: {
|
||||
id: string;
|
||||
fileName: string;
|
||||
fileSize: number;
|
||||
checksum: string;
|
||||
totalChunks: number;
|
||||
receivedChunks: number;
|
||||
tempPath: string;
|
||||
// v1: 边收边写文件,避免大文件 OOM
|
||||
// stream: FileSystem writable stream (平台相关封装)
|
||||
};
|
||||
|
||||
handleMessage(message: any) {
|
||||
switch (message.type) {
|
||||
case "handshake":
|
||||
this.handleHandshake(message);
|
||||
break;
|
||||
case "ping":
|
||||
this.sendPong(message);
|
||||
break;
|
||||
case "file_start":
|
||||
this.handleFileStart(message);
|
||||
break;
|
||||
// v1: file_chunk 为二进制帧,不再走 JSON 分支
|
||||
case "file_end":
|
||||
this.handleFileEnd(message);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
handleFileStart(msg: LanTransferFileStartMessage) {
|
||||
// 1. 检查存储空间
|
||||
// 2. 创建临时文件
|
||||
// 3. 初始化传输状态
|
||||
// 4. 发送 file_start_ack
|
||||
}
|
||||
|
||||
// v1: 二进制帧处理在 socket data 流中解析,随后调用 handleBinaryFileChunk
|
||||
handleBinaryFileChunk(transferId: string, chunkIndex: number, data: Buffer) {
|
||||
// 直接使用二进制数据,按 chunkSize/lastChunk 计算长度
|
||||
// 写入文件流并更新增量 SHA-256
|
||||
this.transfer.receivedChunks++;
|
||||
// v1: 流式传输,不发送 per-chunk ACK
|
||||
}
|
||||
|
||||
handleFileEnd(msg: LanTransferFileEndMessage) {
|
||||
// 1. 合并所有数据块
|
||||
// 2. 验证完整文件 checksum
|
||||
// 3. 写入最终位置
|
||||
// 4. 发送 file_complete
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 附录 A:TypeScript 类型定义
|
||||
|
||||
完整的类型定义位于 `packages/shared/config/types.ts`:
|
||||
|
||||
```typescript
|
||||
// 握手消息
|
||||
export interface LanTransferHandshakeMessage {
|
||||
type: "handshake";
|
||||
deviceName: string;
|
||||
version: string;
|
||||
platform?: string;
|
||||
appVersion?: string;
|
||||
}
|
||||
|
||||
export interface LanTransferHandshakeAckMessage {
|
||||
type: "handshake_ack";
|
||||
accepted: boolean;
|
||||
message?: string;
|
||||
}
|
||||
|
||||
// 心跳消息
|
||||
export interface LanTransferPingMessage {
|
||||
type: "ping";
|
||||
payload?: string;
|
||||
}
|
||||
|
||||
export interface LanTransferPongMessage {
|
||||
type: "pong";
|
||||
received: boolean;
|
||||
payload?: string;
|
||||
}
|
||||
|
||||
// 文件传输消息 (Client -> Server)
|
||||
export interface LanTransferFileStartMessage {
|
||||
type: "file_start";
|
||||
transferId: string;
|
||||
fileName: string;
|
||||
fileSize: number;
|
||||
mimeType: string;
|
||||
checksum: string;
|
||||
totalChunks: number;
|
||||
chunkSize: number;
|
||||
}
|
||||
|
||||
export interface LanTransferFileChunkMessage {
|
||||
type: "file_chunk";
|
||||
transferId: string;
|
||||
chunkIndex: number;
|
||||
data: string; // Base64 encoded (v1: 二进制帧模式下不使用)
|
||||
}
|
||||
|
||||
export interface LanTransferFileEndMessage {
|
||||
type: "file_end";
|
||||
transferId: string;
|
||||
}
|
||||
|
||||
// 文件传输响应消息 (Server -> Client)
|
||||
export interface LanTransferFileStartAckMessage {
|
||||
type: "file_start_ack";
|
||||
transferId: string;
|
||||
accepted: boolean;
|
||||
message?: string;
|
||||
}
|
||||
|
||||
// v1 流式不发送 per-chunk ACK,以下类型仅用于向后兼容参考
|
||||
export interface LanTransferFileChunkAckMessage {
|
||||
type: "file_chunk_ack";
|
||||
transferId: string;
|
||||
chunkIndex: number;
|
||||
received: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface LanTransferFileCompleteMessage {
|
||||
type: "file_complete";
|
||||
transferId: string;
|
||||
success: boolean;
|
||||
filePath?: string;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
// 常量
|
||||
export const LAN_TRANSFER_TCP_PORT = 53317;
|
||||
export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024;
|
||||
export const LAN_TRANSFER_CHUNK_TIMEOUT_MS = 30_000;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 附录 B:版本历史
|
||||
|
||||
| 版本 | 日期 | 变更 |
|
||||
| ---- | ------- | ---------------------------------------- |
|
||||
| 1.0 | 2025-12 | 初始发布版本,支持二进制帧格式与流式传输 |
|
||||
@ -134,38 +134,38 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
|
||||
releaseInfo:
|
||||
releaseNotes: |
|
||||
<!--LANG:en-->
|
||||
Cherry Studio 1.7.5 - Filesystem MCP Overhaul & Topic Management
|
||||
Cherry Studio 1.7.8 - Bug Fixes & Performance Improvements
|
||||
|
||||
This release features a completely rewritten filesystem MCP server, new batch topic management, and improved assistant management.
|
||||
This release focuses on bug fixes and performance optimizations.
|
||||
|
||||
✨ New Features
|
||||
- [MCP] Rewrite filesystem MCP server with improved tool set (glob, ls, grep, read, write, edit, delete)
|
||||
- [Topics] Add topic manage mode for batch delete and move operations with search functionality
|
||||
- [Assistants] Merge import/subscribe popups and add export to assistant management
|
||||
- [Knowledge] Use prompt injection for forced knowledge base search (faster response times)
|
||||
- [Settings] Add tool use mode setting (prompt/function) to default assistant settings
|
||||
⚡ Performance
|
||||
- [ModelList] Improve model list loading performance
|
||||
|
||||
🐛 Bug Fixes
|
||||
- [Model] Correct typo in Gemini 3 Pro Image Preview model name
|
||||
- [Installer] Auto-install VC++ Redistributable without user prompt
|
||||
- [Notes] Fix notes directory validation and default path reset for cross-platform restore
|
||||
- [OAuth] Bind OAuth callback server to localhost (127.0.0.1) for security
|
||||
- [Ollama] Fix new users unable to use Ollama models
|
||||
- [Ollama] Improve reasoningEffort handling
|
||||
- [Assistants] Prevent deleting last assistant and add error message
|
||||
- [Shortcut] Fix shortcut icons sorting disorder
|
||||
- [Memory] Fix global memory settings submit failure
|
||||
- [Windows] Fix remember size not working for SelectionAction window
|
||||
- [Anthropic] Fix API base URL handling
|
||||
- [Files] Allow more file extensions
|
||||
|
||||
<!--LANG:zh-CN-->
|
||||
Cherry Studio 1.7.5 - 文件系统 MCP 重构与话题管理
|
||||
Cherry Studio 1.7.8 - 问题修复与性能优化
|
||||
|
||||
本次更新完全重写了文件系统 MCP 服务器,新增批量话题管理功能,并改进了助手管理。
|
||||
本次更新专注于问题修复和性能优化。
|
||||
|
||||
✨ 新功能
|
||||
- [MCP] 重写文件系统 MCP 服务器,提供改进的工具集(glob、ls、grep、read、write、edit、delete)
|
||||
- [话题] 新增话题管理模式,支持批量删除和移动操作,带搜索功能
|
||||
- [助手] 合并导入/订阅弹窗,并在助手管理中添加导出功能
|
||||
- [知识库] 使用提示词注入进行强制知识库搜索(响应更快)
|
||||
- [设置] 在默认助手设置中添加工具使用模式设置(prompt/function)
|
||||
⚡ 性能优化
|
||||
- [模型列表] 提升模型列表加载性能
|
||||
|
||||
🐛 问题修复
|
||||
- [模型] 修正 Gemini 3 Pro Image Preview 模型名称的拼写错误
|
||||
- [安装程序] 自动安装 VC++ 运行库,无需用户确认
|
||||
- [笔记] 修复跨平台恢复场景下的笔记目录验证和默认路径重置逻辑
|
||||
- [OAuth] 将 OAuth 回调服务器绑定到 localhost (127.0.0.1) 以提高安全性
|
||||
- [Ollama] 修复新用户无法使用 Ollama 模型的问题
|
||||
- [Ollama] 改进推理参数处理
|
||||
- [助手] 防止删除最后一个助手并添加错误提示
|
||||
- [快捷方式] 修复快捷方式图标排序混乱
|
||||
- [记忆] 修复全局记忆设置提交失败
|
||||
- [窗口] 修复 SelectionAction 窗口记住尺寸不生效
|
||||
- [Anthropic] 修复 API 地址处理
|
||||
- [文件] 允许更多文件扩展名
|
||||
<!--LANG:END-->
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import react from '@vitejs/plugin-react-swc'
|
||||
import { CodeInspectorPlugin } from 'code-inspector-plugin'
|
||||
import { defineConfig, externalizeDepsPlugin } from 'electron-vite'
|
||||
import { defineConfig } from 'electron-vite'
|
||||
import { resolve } from 'path'
|
||||
import { visualizer } from 'rollup-plugin-visualizer'
|
||||
|
||||
@ -17,7 +17,7 @@ const isProd = process.env.NODE_ENV === 'production'
|
||||
|
||||
export default defineConfig({
|
||||
main: {
|
||||
plugins: [externalizeDepsPlugin(), ...visualizerPlugin('main')],
|
||||
plugins: [...visualizerPlugin('main')],
|
||||
resolve: {
|
||||
alias: {
|
||||
'@main': resolve('src/main'),
|
||||
@ -54,8 +54,7 @@ export default defineConfig({
|
||||
plugins: [
|
||||
react({
|
||||
tsDecorators: true
|
||||
}),
|
||||
externalizeDepsPlugin()
|
||||
})
|
||||
],
|
||||
resolve: {
|
||||
alias: {
|
||||
|
||||
16
package.json
16
package.json
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "CherryStudio",
|
||||
"version": "1.7.5",
|
||||
"version": "1.7.8",
|
||||
"private": true,
|
||||
"description": "A powerful AI assistant for producer.",
|
||||
"main": "./out/main/index.js",
|
||||
@ -27,6 +27,7 @@
|
||||
"scripts": {
|
||||
"start": "electron-vite preview",
|
||||
"dev": "dotenv electron-vite dev",
|
||||
"dev:watch": "dotenv electron-vite dev -- -w",
|
||||
"debug": "electron-vite -- --inspect --sourcemap --remote-debugging-port=9222",
|
||||
"build": "npm run typecheck && electron-vite build",
|
||||
"build:check": "yarn lint && yarn test",
|
||||
@ -87,6 +88,7 @@
|
||||
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
|
||||
"@paymoapp/electron-shutdown-handler": "^1.1.2",
|
||||
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
||||
"bonjour-service": "^1.3.0",
|
||||
"emoji-picker-element-data": "^1",
|
||||
"express": "^5.1.0",
|
||||
"font-list": "^2.0.0",
|
||||
@ -97,10 +99,8 @@
|
||||
"node-stream-zip": "^1.15.0",
|
||||
"officeparser": "^4.2.0",
|
||||
"os-proxy-config": "^1.1.2",
|
||||
"qrcode.react": "^4.2.0",
|
||||
"selection-hook": "^1.0.12",
|
||||
"sharp": "^0.34.3",
|
||||
"socket.io": "^4.8.1",
|
||||
"swagger-jsdoc": "^6.2.8",
|
||||
"swagger-ui-express": "^5.0.1",
|
||||
"tesseract.js": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch",
|
||||
@ -274,7 +274,7 @@
|
||||
"electron-reload": "^2.0.0-alpha.1",
|
||||
"electron-store": "^8.2.0",
|
||||
"electron-updater": "patch:electron-updater@npm%3A6.7.0#~/.yarn/patches/electron-updater-npm-6.7.0-47b11bb0d4.patch",
|
||||
"electron-vite": "4.0.1",
|
||||
"electron-vite": "5.0.0",
|
||||
"electron-window-state": "^5.0.3",
|
||||
"emittery": "^1.0.3",
|
||||
"emoji-picker-element": "^1.22.1",
|
||||
@ -371,7 +371,7 @@
|
||||
"undici": "6.21.2",
|
||||
"unified": "^11.0.5",
|
||||
"uuid": "^13.0.0",
|
||||
"vite": "npm:rolldown-vite@7.1.5",
|
||||
"vite": "npm:rolldown-vite@7.3.0",
|
||||
"vitest": "^3.2.4",
|
||||
"webdav": "^5.8.0",
|
||||
"winston": "^3.17.0",
|
||||
@ -401,7 +401,7 @@
|
||||
"pkce-challenge@npm:^4.1.0": "patch:pkce-challenge@npm%3A4.1.0#~/.yarn/patches/pkce-challenge-npm-4.1.0-fbc51695a3.patch",
|
||||
"tar-fs": "^2.1.4",
|
||||
"undici": "6.21.2",
|
||||
"vite": "npm:rolldown-vite@7.1.5",
|
||||
"vite": "npm:rolldown-vite@7.3.0",
|
||||
"tesseract.js@npm:*": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch",
|
||||
"@ai-sdk/openai@npm:^2.0.52": "patch:@ai-sdk/openai@npm%3A2.0.52#~/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch",
|
||||
"@img/sharp-darwin-arm64": "0.34.3",
|
||||
@ -417,7 +417,9 @@
|
||||
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.85#~/.yarn/patches/@ai-sdk-openai-npm-2.0.85-27483d1d6a.patch",
|
||||
"@ai-sdk/google@npm:^2.0.40": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
|
||||
"@ai-sdk/openai-compatible@npm:^1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
|
||||
"@ai-sdk/google@npm:2.0.49": "patch:@ai-sdk/google@npm%3A2.0.49#~/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch"
|
||||
"@ai-sdk/google@npm:2.0.49": "patch:@ai-sdk/google@npm%3A2.0.49#~/.yarn/patches/@ai-sdk-google-npm-2.0.49-84720f41bd.patch",
|
||||
"@ai-sdk/openai-compatible@npm:1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch",
|
||||
"@ai-sdk/openai-compatible@npm:^1.0.19": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch"
|
||||
},
|
||||
"packageManager": "yarn@4.9.1",
|
||||
"lint-staged": {
|
||||
|
||||
@ -41,7 +41,7 @@
|
||||
"ai": "^5.0.26"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ai-sdk/openai-compatible": "^1.0.28",
|
||||
"@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch",
|
||||
"@ai-sdk/provider": "^2.0.0",
|
||||
"@ai-sdk/provider-utils": "^3.0.17"
|
||||
},
|
||||
|
||||
@ -42,7 +42,7 @@
|
||||
"@ai-sdk/anthropic": "^2.0.49",
|
||||
"@ai-sdk/azure": "^2.0.87",
|
||||
"@ai-sdk/deepseek": "^1.0.31",
|
||||
"@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
|
||||
"@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch",
|
||||
"@ai-sdk/provider": "^2.0.0",
|
||||
"@ai-sdk/provider-utils": "^3.0.17",
|
||||
"@ai-sdk/xai": "^2.0.36",
|
||||
|
||||
@ -22,10 +22,10 @@ const TOOL_USE_TAG_CONFIG: TagConfig = {
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认系统提示符模板(提取自 Cherry Studio)
|
||||
* 默认系统提示符模板
|
||||
*/
|
||||
const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \\
|
||||
You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
||||
export const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \
|
||||
You can use one or more tools per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
||||
|
||||
## Tool Use Formatting
|
||||
|
||||
@ -74,10 +74,13 @@ Here are the rules you should always follow to solve your task:
|
||||
4. Never re-do a tool call that you previously did with the exact same parameters.
|
||||
5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format.
|
||||
|
||||
## Response rules
|
||||
|
||||
Respond in the language of the user's query, unless the user instructions specify additional requirements for the language to be used.
|
||||
|
||||
# User Instructions
|
||||
{{ USER_SYSTEM_PROMPT }}
|
||||
|
||||
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.`
|
||||
`
|
||||
|
||||
/**
|
||||
* 默认工具使用示例(提取自 Cherry Studio)
|
||||
|
||||
@ -233,6 +233,8 @@ export enum IpcChannel {
|
||||
Backup_ListS3Files = 'backup:listS3Files',
|
||||
Backup_DeleteS3File = 'backup:deleteS3File',
|
||||
Backup_CheckS3Connection = 'backup:checkS3Connection',
|
||||
Backup_CreateLanTransferBackup = 'backup:createLanTransferBackup',
|
||||
Backup_DeleteTempBackup = 'backup:deleteTempBackup',
|
||||
|
||||
// zip
|
||||
Zip_Compress = 'zip:compress',
|
||||
@ -316,6 +318,7 @@ export enum IpcChannel {
|
||||
Memory_DeleteUser = 'memory:delete-user',
|
||||
Memory_DeleteAllMemoriesForUser = 'memory:delete-all-memories-for-user',
|
||||
Memory_GetUsersList = 'memory:get-users-list',
|
||||
Memory_MigrateMemoryDb = 'memory:migrate-memory-db',
|
||||
|
||||
// TRACE
|
||||
TRACE_SAVE_DATA = 'trace:saveData',
|
||||
@ -361,6 +364,7 @@ export enum IpcChannel {
|
||||
OCR_ListProviders = 'ocr:list-providers',
|
||||
|
||||
// OVMS
|
||||
Ovms_IsSupported = 'ovms:is-supported',
|
||||
Ovms_AddModel = 'ovms:add-model',
|
||||
Ovms_StopAddModel = 'ovms:stop-addmodel',
|
||||
Ovms_GetModels = 'ovms:get-models',
|
||||
@ -381,10 +385,14 @@ export enum IpcChannel {
|
||||
ClaudeCodePlugin_ReadContent = 'claudeCodePlugin:read-content',
|
||||
ClaudeCodePlugin_WriteContent = 'claudeCodePlugin:write-content',
|
||||
|
||||
// WebSocket
|
||||
WebSocket_Start = 'webSocket:start',
|
||||
WebSocket_Stop = 'webSocket:stop',
|
||||
WebSocket_Status = 'webSocket:status',
|
||||
WebSocket_SendFile = 'webSocket:send-file',
|
||||
WebSocket_GetAllCandidates = 'webSocket:get-all-candidates'
|
||||
// Local Transfer
|
||||
LocalTransfer_ListServices = 'local-transfer:list',
|
||||
LocalTransfer_StartScan = 'local-transfer:start-scan',
|
||||
LocalTransfer_StopScan = 'local-transfer:stop-scan',
|
||||
LocalTransfer_ServicesUpdated = 'local-transfer:services-updated',
|
||||
LocalTransfer_Connect = 'local-transfer:connect',
|
||||
LocalTransfer_Disconnect = 'local-transfer:disconnect',
|
||||
LocalTransfer_ClientEvent = 'local-transfer:client-event',
|
||||
LocalTransfer_SendFile = 'local-transfer:send-file',
|
||||
LocalTransfer_CancelTransfer = 'local-transfer:cancel-transfer'
|
||||
}
|
||||
|
||||
@ -52,3 +52,196 @@ export interface WebSocketCandidatesResponse {
|
||||
interface: string
|
||||
priority: number
|
||||
}
|
||||
|
||||
export type LocalTransferPeer = {
|
||||
id: string
|
||||
name: string
|
||||
host?: string
|
||||
fqdn?: string
|
||||
port?: number
|
||||
type?: string
|
||||
protocol?: 'tcp' | 'udp'
|
||||
addresses: string[]
|
||||
txt?: Record<string, string>
|
||||
updatedAt: number
|
||||
}
|
||||
|
||||
export type LocalTransferState = {
|
||||
services: LocalTransferPeer[]
|
||||
isScanning: boolean
|
||||
lastScanStartedAt?: number
|
||||
lastUpdatedAt: number
|
||||
lastError?: string
|
||||
}
|
||||
|
||||
export type LanHandshakeRequestMessage = {
|
||||
type: 'handshake'
|
||||
deviceName: string
|
||||
version: string
|
||||
platform?: string
|
||||
appVersion?: string
|
||||
}
|
||||
|
||||
export type LanHandshakeAckMessage = {
|
||||
type: 'handshake_ack'
|
||||
accepted: boolean
|
||||
message?: string
|
||||
}
|
||||
|
||||
export type LocalTransferConnectPayload = {
|
||||
peerId: string
|
||||
metadata?: Record<string, string>
|
||||
timeoutMs?: number
|
||||
}
|
||||
|
||||
export type LanClientEvent =
|
||||
| {
|
||||
type: 'ping_sent'
|
||||
payload: string
|
||||
timestamp: number
|
||||
peerId?: string
|
||||
peerName?: string
|
||||
}
|
||||
| {
|
||||
type: 'pong'
|
||||
payload?: string
|
||||
received?: boolean
|
||||
timestamp: number
|
||||
peerId?: string
|
||||
peerName?: string
|
||||
}
|
||||
| {
|
||||
type: 'socket_closed'
|
||||
reason?: string
|
||||
timestamp: number
|
||||
peerId?: string
|
||||
peerName?: string
|
||||
}
|
||||
| {
|
||||
type: 'error'
|
||||
message: string
|
||||
timestamp: number
|
||||
peerId?: string
|
||||
peerName?: string
|
||||
}
|
||||
| {
|
||||
type: 'file_transfer_progress'
|
||||
transferId: string
|
||||
fileName: string
|
||||
bytesSent: number
|
||||
totalBytes: number
|
||||
chunkIndex: number
|
||||
totalChunks: number
|
||||
progress: number // 0-100
|
||||
speed: number // bytes/sec
|
||||
timestamp: number
|
||||
peerId?: string
|
||||
peerName?: string
|
||||
}
|
||||
| {
|
||||
type: 'file_transfer_complete'
|
||||
transferId: string
|
||||
fileName: string
|
||||
success: boolean
|
||||
filePath?: string
|
||||
error?: string
|
||||
timestamp: number
|
||||
peerId?: string
|
||||
peerName?: string
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// LAN File Transfer Protocol Types
|
||||
// =============================================================================
|
||||
|
||||
// Constants for file transfer
|
||||
export const LAN_TRANSFER_TCP_PORT = 53317
|
||||
export const LAN_TRANSFER_CHUNK_SIZE = 512 * 1024 // 512KB
|
||||
export const LAN_TRANSFER_MAX_FILE_SIZE = 500 * 1024 * 1024 // 500MB
|
||||
export const LAN_TRANSFER_COMPLETE_TIMEOUT_MS = 60_000 // 60s - wait for file_complete after file_end
|
||||
export const LAN_TRANSFER_GLOBAL_TIMEOUT_MS = 10 * 60 * 1000 // 10 minutes - global transfer timeout
|
||||
|
||||
// Binary protocol constants (v1)
|
||||
export const LAN_TRANSFER_PROTOCOL_VERSION = '1'
|
||||
export const LAN_BINARY_FRAME_MAGIC = 0x4353 // "CS" as uint16
|
||||
export const LAN_BINARY_TYPE_FILE_CHUNK = 0x01
|
||||
|
||||
// Messages from Electron (Client/Sender) to Mobile (Server/Receiver)
|
||||
|
||||
/** Request to start file transfer */
|
||||
export type LanFileStartMessage = {
|
||||
type: 'file_start'
|
||||
transferId: string
|
||||
fileName: string
|
||||
fileSize: number
|
||||
mimeType: string // 'application/zip'
|
||||
checksum: string // SHA-256 of entire file
|
||||
totalChunks: number
|
||||
chunkSize: number
|
||||
}
|
||||
|
||||
/**
|
||||
* File chunk data (JSON format)
|
||||
* @deprecated Use binary frame format in protocol v1. This type is kept for reference only.
|
||||
*/
|
||||
export type LanFileChunkMessage = {
|
||||
type: 'file_chunk'
|
||||
transferId: string
|
||||
chunkIndex: number
|
||||
data: string // Base64 encoded
|
||||
chunkChecksum: string // SHA-256 of this chunk
|
||||
}
|
||||
|
||||
/** Notification that all chunks have been sent */
|
||||
export type LanFileEndMessage = {
|
||||
type: 'file_end'
|
||||
transferId: string
|
||||
}
|
||||
|
||||
/** Request to cancel file transfer */
|
||||
export type LanFileCancelMessage = {
|
||||
type: 'file_cancel'
|
||||
transferId: string
|
||||
reason?: string
|
||||
}
|
||||
|
||||
// Messages from Mobile (Server/Receiver) to Electron (Client/Sender)
|
||||
|
||||
/** Acknowledgment of file transfer request */
|
||||
export type LanFileStartAckMessage = {
|
||||
type: 'file_start_ack'
|
||||
transferId: string
|
||||
accepted: boolean
|
||||
message?: string // Rejection reason
|
||||
}
|
||||
|
||||
/**
|
||||
* Acknowledgment of file chunk received
|
||||
* @deprecated Protocol v1 uses streaming mode without per-chunk acknowledgment.
|
||||
* This type is kept for backward compatibility reference only.
|
||||
*/
|
||||
export type LanFileChunkAckMessage = {
|
||||
type: 'file_chunk_ack'
|
||||
transferId: string
|
||||
chunkIndex: number
|
||||
received: boolean
|
||||
message?: string
|
||||
}
|
||||
|
||||
/** Final result of file transfer */
|
||||
export type LanFileCompleteMessage = {
|
||||
type: 'file_complete'
|
||||
transferId: string
|
||||
success: boolean
|
||||
filePath?: string // Path where file was saved on mobile
|
||||
error?: string
|
||||
// Enhanced error diagnostics
|
||||
errorCode?: 'CHECKSUM_MISMATCH' | 'INCOMPLETE_TRANSFER' | 'DISK_ERROR' | 'CANCELLED'
|
||||
receivedChunks?: number
|
||||
receivedBytes?: number
|
||||
}
|
||||
|
||||
/** Payload for sending a file via IPC */
|
||||
export type LanFileSendPayload = {
|
||||
filePath: string
|
||||
}
|
||||
|
||||
@ -19,7 +19,8 @@ export const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
|
||||
'azure-openai': 'azure', // Azure OpenAI -> azure
|
||||
'openai-response': 'openai', // OpenAI Responses -> openai
|
||||
grok: 'xai', // Grok -> xai
|
||||
copilot: 'github-copilot-openai-compatible'
|
||||
copilot: 'github-copilot-openai-compatible',
|
||||
tokenflux: 'openrouter'
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -35,3 +35,56 @@ export const defaultAppHeaders = () => {
|
||||
// return value
|
||||
// }
|
||||
// }
|
||||
|
||||
/**
|
||||
* Extracts the trailing API version segment from a URL path.
|
||||
*
|
||||
* This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL.
|
||||
* Only versions at the end of the path are extracted, not versions in the middle.
|
||||
* The returned version string does not include leading or trailing slashes.
|
||||
*
|
||||
* @param {string} url - The URL string to parse.
|
||||
* @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found.
|
||||
*
|
||||
* @example
|
||||
* getTrailingApiVersion('https://api.example.com/v1') // 'v1'
|
||||
* getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta'
|
||||
* getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end)
|
||||
* getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta'
|
||||
* getTrailingApiVersion('https://api.example.com') // undefined
|
||||
*/
|
||||
export function getTrailingApiVersion(url: string): string | undefined {
|
||||
const match = url.match(TRAILING_VERSION_REGEX)
|
||||
|
||||
if (match) {
|
||||
// Extract version without leading slash and trailing slash
|
||||
return match[0].replace(/^\//, '').replace(/\/$/, '')
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Matches an API version at the end of a URL (with optional trailing slash).
|
||||
* Used to detect and extract versions only from the trailing position.
|
||||
*/
|
||||
const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i
|
||||
|
||||
/**
|
||||
* Removes the trailing API version segment from a URL path.
|
||||
*
|
||||
* This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL.
|
||||
* Only versions at the end of the path are removed, not versions in the middle.
|
||||
*
|
||||
* @param {string} url - The URL string to process.
|
||||
* @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found.
|
||||
*
|
||||
* @example
|
||||
* withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com'
|
||||
* withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com'
|
||||
* withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change)
|
||||
* withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com'
|
||||
*/
|
||||
export function withoutTrailingApiVersion(url: string): string {
|
||||
return url.replace(TRAILING_VERSION_REGEX, '')
|
||||
}
|
||||
|
||||
@ -6,8 +6,8 @@ const { downloadWithPowerShell } = require('./download')
|
||||
|
||||
// Base URL for downloading OVMS binaries
|
||||
const OVMS_RELEASE_BASE_URL =
|
||||
'https://storage.openvinotoolkit.org/repositories/openvino_model_server/packages/2025.3.0/ovms_windows_python_on.zip'
|
||||
const OVMS_EX_URL = 'https://gitcode.com/gcw_ggDjjkY3/kjfile/releases/download/download/ovms_25.3_ex.zip'
|
||||
'https://storage.openvinotoolkit.org/repositories/openvino_model_server/packages/2025.4.1/ovms_windows_python_on.zip'
|
||||
const OVMS_EX_URL = 'https://gitcode.com/gcw_ggDjjkY3/kjfile/releases/download/download/ovms_25.4_ex.zip'
|
||||
|
||||
/**
|
||||
* error code:
|
||||
|
||||
@ -19,8 +19,10 @@ import { agentService } from './services/agents'
|
||||
import { apiServerService } from './services/ApiServerService'
|
||||
import { appMenuService } from './services/AppMenuService'
|
||||
import { configManager } from './services/ConfigManager'
|
||||
import { nodeTraceService } from './services/NodeTraceService'
|
||||
import { lanTransferClientService } from './services/lanTransfer'
|
||||
import mcpService from './services/MCPService'
|
||||
import { localTransferService } from './services/LocalTransferService'
|
||||
import { nodeTraceService } from './services/NodeTraceService'
|
||||
import powerMonitorService from './services/PowerMonitorService'
|
||||
import {
|
||||
CHERRY_STUDIO_PROTOCOL,
|
||||
@ -35,6 +37,7 @@ import { versionService } from './services/VersionService'
|
||||
import { windowService } from './services/WindowService'
|
||||
import { initWebviewHotkeys } from './services/WebviewService'
|
||||
import { runAsyncFunction } from './utils'
|
||||
import { isOvmsSupported } from './services/OvmsManager'
|
||||
|
||||
const logger = loggerService.withContext('MainEntry')
|
||||
|
||||
@ -155,7 +158,8 @@ if (!app.requestSingleInstanceLock()) {
|
||||
|
||||
registerShortcuts(mainWindow)
|
||||
|
||||
registerIpc(mainWindow, app)
|
||||
await registerIpc(mainWindow, app)
|
||||
localTransferService.startDiscovery({ resetList: true })
|
||||
|
||||
replaceDevtoolsFont(mainWindow)
|
||||
|
||||
@ -237,16 +241,29 @@ if (!app.requestSingleInstanceLock()) {
|
||||
if (selectionService) {
|
||||
selectionService.quit()
|
||||
}
|
||||
|
||||
lanTransferClientService.dispose()
|
||||
localTransferService.dispose()
|
||||
})
|
||||
|
||||
app.on('will-quit', async () => {
|
||||
// 简单的资源清理,不阻塞退出流程
|
||||
if (isOvmsSupported) {
|
||||
const { ovmsManager } = await import('./services/OvmsManager')
|
||||
if (ovmsManager) {
|
||||
await ovmsManager.stopOvms()
|
||||
} else {
|
||||
logger.warn('Unexpected behavior: undefined ovmsManager, but OVMS should be supported.')
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
await mcpService.cleanup()
|
||||
await apiServerService.stop()
|
||||
} catch (error) {
|
||||
logger.warn('Error cleaning up MCP service:', error as Error)
|
||||
}
|
||||
|
||||
// finish the logger
|
||||
logger.finish()
|
||||
})
|
||||
|
||||
121
src/main/ipc.ts
121
src/main/ipc.ts
@ -18,6 +18,7 @@ import { handleZoomFactor } from '@main/utils/zoom'
|
||||
import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
|
||||
import type { UpgradeChannel } from '@shared/config/constant'
|
||||
import { MIN_WINDOW_HEIGHT, MIN_WINDOW_WIDTH } from '@shared/config/constant'
|
||||
import type { LocalTransferConnectPayload } from '@shared/config/types'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import type { PluginError } from '@types'
|
||||
import type {
|
||||
@ -49,6 +50,8 @@ import { ExportService } from './services/ExportService'
|
||||
import { fileStorage as fileManager } from './services/FileStorage'
|
||||
import FileService from './services/FileSystemService'
|
||||
import KnowledgeService from './services/KnowledgeService'
|
||||
import { lanTransferClientService } from './services/lanTransfer'
|
||||
import { localTransferService } from './services/LocalTransferService'
|
||||
import mcpService from './services/MCPService'
|
||||
import MemoryService from './services/memory/MemoryService'
|
||||
import { openTraceWindow, setTraceWindowTitle } from './services/NodeTraceService'
|
||||
@ -56,7 +59,7 @@ import NotificationService from './services/NotificationService'
|
||||
import * as NutstoreService from './services/NutstoreService'
|
||||
import ObsidianVaultService from './services/ObsidianVaultService'
|
||||
import { ocrService } from './services/ocr/OcrService'
|
||||
import OvmsManager from './services/OvmsManager'
|
||||
import { isOvmsSupported } from './services/OvmsManager'
|
||||
import powerMonitorService from './services/PowerMonitorService'
|
||||
import { proxyManager } from './services/ProxyManager'
|
||||
import { pythonService } from './services/PythonService'
|
||||
@ -80,7 +83,6 @@ import {
|
||||
import storeSyncService from './services/StoreSyncService'
|
||||
import { themeService } from './services/ThemeService'
|
||||
import VertexAIService from './services/VertexAIService'
|
||||
import WebSocketService from './services/WebSocketService'
|
||||
import { setOpenLinkExternal } from './services/WebviewService'
|
||||
import { windowService } from './services/WindowService'
|
||||
import { calculateDirectorySize, getResourcePath } from './utils'
|
||||
@ -95,6 +97,7 @@ import {
|
||||
untildify
|
||||
} from './utils/file'
|
||||
import { updateAppDataConfig } from './utils/init'
|
||||
import { getCpuName, getDeviceType, getHostname } from './utils/system'
|
||||
import { compress, decompress } from './utils/zip'
|
||||
|
||||
const logger = loggerService.withContext('IPC')
|
||||
@ -105,7 +108,6 @@ const obsidianVaultService = new ObsidianVaultService()
|
||||
const vertexAIService = VertexAIService.getInstance()
|
||||
const memoryService = MemoryService.getInstance()
|
||||
const dxtService = new DxtService()
|
||||
const ovmsManager = new OvmsManager()
|
||||
const pluginService = PluginService.getInstance()
|
||||
|
||||
function normalizeError(error: unknown): Error {
|
||||
@ -119,7 +121,7 @@ function extractPluginError(error: unknown): PluginError | null {
|
||||
return null
|
||||
}
|
||||
|
||||
export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
export async function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
const appUpdater = new AppUpdater()
|
||||
const notificationService = new NotificationService()
|
||||
|
||||
@ -497,9 +499,9 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text))
|
||||
|
||||
// system
|
||||
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
|
||||
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
|
||||
ipcMain.handle(IpcChannel.System_GetCpuName, () => require('os').cpus()[0].model)
|
||||
ipcMain.handle(IpcChannel.System_GetDeviceType, getDeviceType)
|
||||
ipcMain.handle(IpcChannel.System_GetHostname, getHostname)
|
||||
ipcMain.handle(IpcChannel.System_GetCpuName, getCpuName)
|
||||
ipcMain.handle(IpcChannel.System_CheckGitBash, () => {
|
||||
if (!isWin) {
|
||||
return true // Non-Windows systems don't need Git Bash
|
||||
@ -583,6 +585,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle(IpcChannel.Backup_ListS3Files, backupManager.listS3Files.bind(backupManager))
|
||||
ipcMain.handle(IpcChannel.Backup_DeleteS3File, backupManager.deleteS3File.bind(backupManager))
|
||||
ipcMain.handle(IpcChannel.Backup_CheckS3Connection, backupManager.checkS3Connection.bind(backupManager))
|
||||
ipcMain.handle(IpcChannel.Backup_CreateLanTransferBackup, backupManager.createLanTransferBackup.bind(backupManager))
|
||||
ipcMain.handle(IpcChannel.Backup_DeleteTempBackup, backupManager.deleteTempBackup.bind(backupManager))
|
||||
|
||||
// file
|
||||
ipcMain.handle(IpcChannel.File_Open, fileManager.open.bind(fileManager))
|
||||
@ -682,36 +686,19 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle(IpcChannel.KnowledgeBase_Check_Quota, KnowledgeService.checkQuota.bind(KnowledgeService))
|
||||
|
||||
// memory
|
||||
ipcMain.handle(IpcChannel.Memory_Add, async (_, messages, config) => {
|
||||
return await memoryService.add(messages, config)
|
||||
})
|
||||
ipcMain.handle(IpcChannel.Memory_Search, async (_, query, config) => {
|
||||
return await memoryService.search(query, config)
|
||||
})
|
||||
ipcMain.handle(IpcChannel.Memory_List, async (_, config) => {
|
||||
return await memoryService.list(config)
|
||||
})
|
||||
ipcMain.handle(IpcChannel.Memory_Delete, async (_, id) => {
|
||||
return await memoryService.delete(id)
|
||||
})
|
||||
ipcMain.handle(IpcChannel.Memory_Update, async (_, id, memory, metadata) => {
|
||||
return await memoryService.update(id, memory, metadata)
|
||||
})
|
||||
ipcMain.handle(IpcChannel.Memory_Get, async (_, memoryId) => {
|
||||
return await memoryService.get(memoryId)
|
||||
})
|
||||
ipcMain.handle(IpcChannel.Memory_SetConfig, async (_, config) => {
|
||||
memoryService.setConfig(config)
|
||||
})
|
||||
ipcMain.handle(IpcChannel.Memory_DeleteUser, async (_, userId) => {
|
||||
return await memoryService.deleteUser(userId)
|
||||
})
|
||||
ipcMain.handle(IpcChannel.Memory_DeleteAllMemoriesForUser, async (_, userId) => {
|
||||
return await memoryService.deleteAllMemoriesForUser(userId)
|
||||
})
|
||||
ipcMain.handle(IpcChannel.Memory_GetUsersList, async () => {
|
||||
return await memoryService.getUsersList()
|
||||
})
|
||||
ipcMain.handle(IpcChannel.Memory_Add, (_, messages, config) => memoryService.add(messages, config))
|
||||
ipcMain.handle(IpcChannel.Memory_Search, (_, query, config) => memoryService.search(query, config))
|
||||
ipcMain.handle(IpcChannel.Memory_List, (_, config) => memoryService.list(config))
|
||||
ipcMain.handle(IpcChannel.Memory_Delete, (_, id) => memoryService.delete(id))
|
||||
ipcMain.handle(IpcChannel.Memory_Update, (_, id, memory, metadata) => memoryService.update(id, memory, metadata))
|
||||
ipcMain.handle(IpcChannel.Memory_Get, (_, memoryId) => memoryService.get(memoryId))
|
||||
ipcMain.handle(IpcChannel.Memory_SetConfig, (_, config) => memoryService.setConfig(config))
|
||||
ipcMain.handle(IpcChannel.Memory_DeleteUser, (_, userId) => memoryService.deleteUser(userId))
|
||||
ipcMain.handle(IpcChannel.Memory_DeleteAllMemoriesForUser, (_, userId) =>
|
||||
memoryService.deleteAllMemoriesForUser(userId)
|
||||
)
|
||||
ipcMain.handle(IpcChannel.Memory_GetUsersList, () => memoryService.getUsersList())
|
||||
ipcMain.handle(IpcChannel.Memory_MigrateMemoryDb, () => memoryService.migrateMemoryDb())
|
||||
|
||||
// window
|
||||
ipcMain.handle(IpcChannel.Windows_SetMinimumSize, (_, width: number, height: number) => {
|
||||
@ -871,8 +858,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
)
|
||||
|
||||
// search window
|
||||
ipcMain.handle(IpcChannel.SearchWindow_Open, async (_, uid: string) => {
|
||||
await searchService.openSearchWindow(uid)
|
||||
ipcMain.handle(IpcChannel.SearchWindow_Open, async (_, uid: string, show?: boolean) => {
|
||||
await searchService.openSearchWindow(uid, show)
|
||||
})
|
||||
ipcMain.handle(IpcChannel.SearchWindow_Close, async (_, uid: string) => {
|
||||
await searchService.closeSearchWindow(uid)
|
||||
@ -988,15 +975,36 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle(IpcChannel.OCR_ListProviders, () => ocrService.listProviderIds())
|
||||
|
||||
// OVMS
|
||||
ipcMain.handle(IpcChannel.Ovms_AddModel, (_, modelName: string, modelId: string, modelSource: string, task: string) =>
|
||||
ovmsManager.addModel(modelName, modelId, modelSource, task)
|
||||
)
|
||||
ipcMain.handle(IpcChannel.Ovms_StopAddModel, () => ovmsManager.stopAddModel())
|
||||
ipcMain.handle(IpcChannel.Ovms_GetModels, () => ovmsManager.getModels())
|
||||
ipcMain.handle(IpcChannel.Ovms_IsRunning, () => ovmsManager.initializeOvms())
|
||||
ipcMain.handle(IpcChannel.Ovms_GetStatus, () => ovmsManager.getOvmsStatus())
|
||||
ipcMain.handle(IpcChannel.Ovms_RunOVMS, () => ovmsManager.runOvms())
|
||||
ipcMain.handle(IpcChannel.Ovms_StopOVMS, () => ovmsManager.stopOvms())
|
||||
ipcMain.handle(IpcChannel.Ovms_IsSupported, () => isOvmsSupported)
|
||||
if (isOvmsSupported) {
|
||||
const { ovmsManager } = await import('./services/OvmsManager')
|
||||
if (ovmsManager) {
|
||||
ipcMain.handle(
|
||||
IpcChannel.Ovms_AddModel,
|
||||
(_, modelName: string, modelId: string, modelSource: string, task: string) =>
|
||||
ovmsManager.addModel(modelName, modelId, modelSource, task)
|
||||
)
|
||||
ipcMain.handle(IpcChannel.Ovms_StopAddModel, () => ovmsManager.stopAddModel())
|
||||
ipcMain.handle(IpcChannel.Ovms_GetModels, () => ovmsManager.getModels())
|
||||
ipcMain.handle(IpcChannel.Ovms_IsRunning, () => ovmsManager.initializeOvms())
|
||||
ipcMain.handle(IpcChannel.Ovms_GetStatus, () => ovmsManager.getOvmsStatus())
|
||||
ipcMain.handle(IpcChannel.Ovms_RunOVMS, () => ovmsManager.runOvms())
|
||||
ipcMain.handle(IpcChannel.Ovms_StopOVMS, () => ovmsManager.stopOvms())
|
||||
} else {
|
||||
logger.error('Unexpected behavior: undefined ovmsManager, but OVMS should be supported.')
|
||||
}
|
||||
} else {
|
||||
const fallback = () => {
|
||||
throw new Error('OVMS is only supported on Windows with intel CPU.')
|
||||
}
|
||||
ipcMain.handle(IpcChannel.Ovms_AddModel, fallback)
|
||||
ipcMain.handle(IpcChannel.Ovms_StopAddModel, fallback)
|
||||
ipcMain.handle(IpcChannel.Ovms_GetModels, fallback)
|
||||
ipcMain.handle(IpcChannel.Ovms_IsRunning, fallback)
|
||||
ipcMain.handle(IpcChannel.Ovms_GetStatus, fallback)
|
||||
ipcMain.handle(IpcChannel.Ovms_RunOVMS, fallback)
|
||||
ipcMain.handle(IpcChannel.Ovms_StopOVMS, fallback)
|
||||
}
|
||||
|
||||
// CherryAI
|
||||
ipcMain.handle(IpcChannel.Cherryai_GetSignature, (_, params) => generateSignature(params))
|
||||
@ -1114,12 +1122,17 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
}
|
||||
})
|
||||
|
||||
// WebSocket
|
||||
ipcMain.handle(IpcChannel.WebSocket_Start, WebSocketService.start)
|
||||
ipcMain.handle(IpcChannel.WebSocket_Stop, WebSocketService.stop)
|
||||
ipcMain.handle(IpcChannel.WebSocket_Status, WebSocketService.getStatus)
|
||||
ipcMain.handle(IpcChannel.WebSocket_SendFile, WebSocketService.sendFile)
|
||||
ipcMain.handle(IpcChannel.WebSocket_GetAllCandidates, WebSocketService.getAllCandidates)
|
||||
ipcMain.handle(IpcChannel.LocalTransfer_ListServices, () => localTransferService.getState())
|
||||
ipcMain.handle(IpcChannel.LocalTransfer_StartScan, () => localTransferService.startDiscovery({ resetList: true }))
|
||||
ipcMain.handle(IpcChannel.LocalTransfer_StopScan, () => localTransferService.stopDiscovery())
|
||||
ipcMain.handle(IpcChannel.LocalTransfer_Connect, (_, payload: LocalTransferConnectPayload) =>
|
||||
lanTransferClientService.connectAndHandshake(payload)
|
||||
)
|
||||
ipcMain.handle(IpcChannel.LocalTransfer_Disconnect, () => lanTransferClientService.disconnect())
|
||||
ipcMain.handle(IpcChannel.LocalTransfer_SendFile, (_, payload: { filePath: string }) =>
|
||||
lanTransferClientService.sendFile(payload.filePath)
|
||||
)
|
||||
ipcMain.handle(IpcChannel.LocalTransfer_CancelTransfer, () => lanTransferClientService.cancelTransfer())
|
||||
|
||||
ipcMain.handle(IpcChannel.APP_CrashRenderProcess, () => {
|
||||
mainWindow.webContents.forcefullyCrashRenderer()
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/**
|
||||
* @deprecated Scheduled for removal in v2.0.0
|
||||
* --------------------------------------------------------------------------
|
||||
* ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
|
||||
* --------------------------------------------------------------------------
|
||||
* STOP: Feature PRs affecting this file are currently BLOCKED.
|
||||
* Only critical bug fixes are accepted during this migration phase.
|
||||
*
|
||||
* This file is being refactored to v2 standards.
|
||||
* Any non-critical changes will conflict with the ongoing work.
|
||||
*
|
||||
* 🔗 Context & Status:
|
||||
* - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
|
||||
* - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
import { loggerService } from '@logger'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import type { WebDavConfig } from '@types'
|
||||
@ -767,6 +783,56 @@ class BackupManager {
|
||||
const s3Client = this.getS3Storage(s3Config)
|
||||
return await s3Client.checkConnection()
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a temporary backup for LAN transfer
|
||||
* Creates a lightweight backup (skipBackupFile=true) in the temp directory
|
||||
* Returns the path to the created ZIP file
|
||||
*/
|
||||
async createLanTransferBackup(_: Electron.IpcMainInvokeEvent, data: string): Promise<string> {
|
||||
const timestamp = new Date()
|
||||
.toISOString()
|
||||
.replace(/[-:T.Z]/g, '')
|
||||
.slice(0, 12)
|
||||
const fileName = `cherry-studio.${timestamp}.zip`
|
||||
const tempPath = path.join(app.getPath('temp'), 'cherry-studio', 'lan-transfer')
|
||||
|
||||
// Ensure temp directory exists
|
||||
await fs.ensureDir(tempPath)
|
||||
|
||||
// Create backup with skipBackupFile=true (no Data folder)
|
||||
const backupedFilePath = await this.backup(_, fileName, data, tempPath, true)
|
||||
|
||||
logger.info(`[BackupManager] Created LAN transfer backup at: ${backupedFilePath}`)
|
||||
return backupedFilePath
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a temporary backup file after LAN transfer completes
|
||||
*/
|
||||
async deleteTempBackup(_: Electron.IpcMainInvokeEvent, filePath: string): Promise<boolean> {
|
||||
try {
|
||||
// Security check: only allow deletion within temp directory
|
||||
const tempBase = path.normalize(path.join(app.getPath('temp'), 'cherry-studio', 'lan-transfer'))
|
||||
const resolvedPath = path.normalize(path.resolve(filePath))
|
||||
|
||||
// Use normalized paths with trailing separator to prevent prefix attacks (e.g., /temp-evil)
|
||||
if (!resolvedPath.startsWith(tempBase + path.sep) && resolvedPath !== tempBase) {
|
||||
logger.warn(`[BackupManager] Attempted to delete file outside temp directory: ${filePath}`)
|
||||
return false
|
||||
}
|
||||
|
||||
if (await fs.pathExists(resolvedPath)) {
|
||||
await fs.remove(resolvedPath)
|
||||
logger.info(`[BackupManager] Deleted temp backup: ${resolvedPath}`)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
} catch (error) {
|
||||
logger.error('[BackupManager] Failed to delete temp backup:', error as Error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default BackupManager
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/**
|
||||
* @deprecated Scheduled for removal in v2.0.0
|
||||
* --------------------------------------------------------------------------
|
||||
* ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
|
||||
* --------------------------------------------------------------------------
|
||||
* STOP: Feature PRs affecting this file are currently BLOCKED.
|
||||
* Only critical bug fixes are accepted during this migration phase.
|
||||
*
|
||||
* This file is being refactored to v2 standards.
|
||||
* Any non-critical changes will conflict with the ongoing work.
|
||||
*
|
||||
* 🔗 Context & Status:
|
||||
* - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
|
||||
* - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
interface CacheItem<T> {
|
||||
data: T
|
||||
timestamp: number
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/**
|
||||
* @deprecated Scheduled for removal in v2.0.0
|
||||
* --------------------------------------------------------------------------
|
||||
* ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
|
||||
* --------------------------------------------------------------------------
|
||||
* STOP: Feature PRs affecting this file are currently BLOCKED.
|
||||
* Only critical bug fixes are accepted during this migration phase.
|
||||
*
|
||||
* This file is being refactored to v2 standards.
|
||||
* Any non-critical changes will conflict with the ongoing work.
|
||||
*
|
||||
* 🔗 Context & Status:
|
||||
* - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
|
||||
* - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
import type { UpgradeChannel } from '@shared/config/constant'
|
||||
import { defaultLanguage, ZOOM_SHORTCUTS } from '@shared/config/constant'
|
||||
import type { LanguageVarious, Shortcut } from '@types'
|
||||
|
||||
@ -2,7 +2,7 @@ import { loggerService } from '@logger'
|
||||
import {
|
||||
checkName,
|
||||
getFilesDir,
|
||||
getFileType,
|
||||
getFileType as getFileTypeByExt,
|
||||
getName,
|
||||
getNotesDir,
|
||||
getTempDir,
|
||||
@ -11,13 +11,13 @@ import {
|
||||
} from '@main/utils/file'
|
||||
import { documentExts, imageExts, KB, MB } from '@shared/config/constant'
|
||||
import type { FileMetadata, NotesTreeNode } from '@types'
|
||||
import { FileTypes } from '@types'
|
||||
import chardet from 'chardet'
|
||||
import type { FSWatcher } from 'chokidar'
|
||||
import chokidar from 'chokidar'
|
||||
import * as crypto from 'crypto'
|
||||
import type { OpenDialogOptions, OpenDialogReturnValue, SaveDialogOptions, SaveDialogReturnValue } from 'electron'
|
||||
import { app } from 'electron'
|
||||
import { dialog, net, shell } from 'electron'
|
||||
import { app, dialog, net, shell } from 'electron'
|
||||
import * as fs from 'fs'
|
||||
import { writeFileSync } from 'fs'
|
||||
import { readFile } from 'fs/promises'
|
||||
@ -130,16 +130,18 @@ interface DirectoryListOptions {
|
||||
includeDirectories?: boolean
|
||||
maxEntries?: number
|
||||
searchPattern?: string
|
||||
fuzzy?: boolean
|
||||
}
|
||||
|
||||
const DEFAULT_DIRECTORY_LIST_OPTIONS: Required<DirectoryListOptions> = {
|
||||
recursive: true,
|
||||
maxDepth: 3,
|
||||
maxDepth: 10,
|
||||
includeHidden: false,
|
||||
includeFiles: true,
|
||||
includeDirectories: true,
|
||||
maxEntries: 10,
|
||||
searchPattern: '.'
|
||||
maxEntries: 20,
|
||||
searchPattern: '.',
|
||||
fuzzy: true
|
||||
}
|
||||
|
||||
class FileStorage {
|
||||
@ -185,7 +187,7 @@ class FileStorage {
|
||||
})
|
||||
}
|
||||
|
||||
findDuplicateFile = async (filePath: string): Promise<FileMetadata | null> => {
|
||||
private findDuplicateFile = async (filePath: string): Promise<FileMetadata | null> => {
|
||||
const stats = fs.statSync(filePath)
|
||||
logger.debug(`stats: ${stats}, filePath: ${filePath}`)
|
||||
const fileSize = stats.size
|
||||
@ -204,6 +206,8 @@ class FileStorage {
|
||||
if (originalHash === storedHash) {
|
||||
const ext = path.extname(file)
|
||||
const id = path.basename(file, ext)
|
||||
const type = await this.getFileType(filePath)
|
||||
|
||||
return {
|
||||
id,
|
||||
origin_name: file,
|
||||
@ -212,7 +216,7 @@ class FileStorage {
|
||||
created_at: storedStats.birthtime.toISOString(),
|
||||
size: storedStats.size,
|
||||
ext,
|
||||
type: getFileType(ext),
|
||||
type,
|
||||
count: 2
|
||||
}
|
||||
}
|
||||
@ -222,6 +226,13 @@ class FileStorage {
|
||||
return null
|
||||
}
|
||||
|
||||
public getFileType = async (filePath: string): Promise<FileTypes> => {
|
||||
const ext = path.extname(filePath)
|
||||
const fileType = getFileTypeByExt(ext)
|
||||
|
||||
return fileType === FileTypes.OTHER && (await this._isTextFile(filePath)) ? FileTypes.TEXT : fileType
|
||||
}
|
||||
|
||||
public selectFile = async (
|
||||
_: Electron.IpcMainInvokeEvent,
|
||||
options?: OpenDialogOptions
|
||||
@ -241,7 +252,7 @@ class FileStorage {
|
||||
const fileMetadataPromises = result.filePaths.map(async (filePath) => {
|
||||
const stats = fs.statSync(filePath)
|
||||
const ext = path.extname(filePath)
|
||||
const fileType = getFileType(ext)
|
||||
const fileType = await this.getFileType(filePath)
|
||||
|
||||
return {
|
||||
id: uuidv4(),
|
||||
@ -307,7 +318,7 @@ class FileStorage {
|
||||
}
|
||||
|
||||
const stats = await fs.promises.stat(destPath)
|
||||
const fileType = getFileType(ext)
|
||||
const fileType = await this.getFileType(destPath)
|
||||
|
||||
const fileMetadata: FileMetadata = {
|
||||
id: uuid,
|
||||
@ -332,8 +343,7 @@ class FileStorage {
|
||||
}
|
||||
|
||||
const stats = fs.statSync(filePath)
|
||||
const ext = path.extname(filePath)
|
||||
const fileType = getFileType(ext)
|
||||
const fileType = await this.getFileType(filePath)
|
||||
|
||||
return {
|
||||
id: uuidv4(),
|
||||
@ -342,7 +352,7 @@ class FileStorage {
|
||||
path: filePath,
|
||||
created_at: stats.birthtime.toISOString(),
|
||||
size: stats.size,
|
||||
ext: ext,
|
||||
ext: path.extname(filePath),
|
||||
type: fileType,
|
||||
count: 1
|
||||
}
|
||||
@ -690,7 +700,7 @@ class FileStorage {
|
||||
created_at: new Date().toISOString(),
|
||||
size: buffer.length,
|
||||
ext: ext.slice(1),
|
||||
type: getFileType(ext),
|
||||
type: getFileTypeByExt(ext),
|
||||
count: 1
|
||||
}
|
||||
} catch (error) {
|
||||
@ -740,7 +750,7 @@ class FileStorage {
|
||||
created_at: new Date().toISOString(),
|
||||
size: stats.size,
|
||||
ext: ext.slice(1),
|
||||
type: getFileType(ext),
|
||||
type: getFileTypeByExt(ext),
|
||||
count: 1
|
||||
}
|
||||
} catch (error) {
|
||||
@ -1038,10 +1048,226 @@ class FileStorage {
|
||||
}
|
||||
|
||||
/**
|
||||
* Search files by content pattern
|
||||
* Fuzzy match: checks if all characters in query appear in text in order (case-insensitive)
|
||||
* Example: "updater" matches "packages/update/src/node/updateController.ts"
|
||||
*/
|
||||
private async searchByContent(resolvedPath: string, options: Required<DirectoryListOptions>): Promise<string[]> {
|
||||
const args: string[] = ['-l']
|
||||
private isFuzzyMatch(text: string, query: string): boolean {
|
||||
let i = 0 // text index
|
||||
let j = 0 // query index
|
||||
const textLower = text.toLowerCase()
|
||||
const queryLower = query.toLowerCase()
|
||||
|
||||
while (i < textLower.length && j < queryLower.length) {
|
||||
if (textLower[i] === queryLower[j]) {
|
||||
j++
|
||||
}
|
||||
i++
|
||||
}
|
||||
return j === queryLower.length
|
||||
}
|
||||
|
||||
/**
|
||||
* Scoring constants for fuzzy match relevance ranking
|
||||
* Higher values = higher priority in search results
|
||||
*/
|
||||
private static readonly SCORE_SEGMENT_MATCH = 60 // Per path segment that matches query
|
||||
private static readonly SCORE_FILENAME_CONTAINS = 80 // Filename contains exact query substring
|
||||
private static readonly SCORE_FILENAME_STARTS = 100 // Filename starts with query (highest priority)
|
||||
private static readonly SCORE_CONSECUTIVE_CHAR = 15 // Per consecutive character match
|
||||
private static readonly SCORE_WORD_BOUNDARY = 20 // Query matches start of a word
|
||||
private static readonly PATH_LENGTH_PENALTY_FACTOR = 4 // Logarithmic penalty multiplier for longer paths
|
||||
|
||||
/**
|
||||
* Calculate fuzzy match score (higher is better)
|
||||
* Scoring factors:
|
||||
* - Consecutive character matches (bonus)
|
||||
* - Match at word boundaries (bonus)
|
||||
* - Shorter path length (bonus)
|
||||
* - Match in filename vs directory (bonus)
|
||||
*/
|
||||
private getFuzzyMatchScore(filePath: string, query: string): number {
|
||||
const pathLower = filePath.toLowerCase()
|
||||
const queryLower = query.toLowerCase()
|
||||
const fileName = filePath.split('/').pop() || ''
|
||||
const fileNameLower = fileName.toLowerCase()
|
||||
|
||||
let score = 0
|
||||
|
||||
// Count how many times query-related words appear in path segments
|
||||
const pathSegments = pathLower.split(/[/\\]/)
|
||||
let segmentMatchCount = 0
|
||||
for (const segment of pathSegments) {
|
||||
if (this.isFuzzyMatch(segment, queryLower)) {
|
||||
segmentMatchCount++
|
||||
}
|
||||
}
|
||||
score += segmentMatchCount * FileStorage.SCORE_SEGMENT_MATCH
|
||||
|
||||
// Bonus for filename starting with query (stronger than generic "contains")
|
||||
if (fileNameLower.startsWith(queryLower)) {
|
||||
score += FileStorage.SCORE_FILENAME_STARTS
|
||||
} else if (fileNameLower.includes(queryLower)) {
|
||||
// Bonus for exact substring match in filename (e.g., "updater" in "RCUpdater.js")
|
||||
score += FileStorage.SCORE_FILENAME_CONTAINS
|
||||
}
|
||||
|
||||
// Calculate consecutive match bonus
|
||||
let i = 0
|
||||
let j = 0
|
||||
let consecutiveCount = 0
|
||||
let maxConsecutive = 0
|
||||
|
||||
while (i < pathLower.length && j < queryLower.length) {
|
||||
if (pathLower[i] === queryLower[j]) {
|
||||
consecutiveCount++
|
||||
maxConsecutive = Math.max(maxConsecutive, consecutiveCount)
|
||||
j++
|
||||
} else {
|
||||
consecutiveCount = 0
|
||||
}
|
||||
i++
|
||||
}
|
||||
score += maxConsecutive * FileStorage.SCORE_CONSECUTIVE_CHAR
|
||||
|
||||
// Bonus for word boundary matches (e.g., "upd" matches start of "update")
|
||||
// Only count once to avoid inflating scores for paths with repeated patterns
|
||||
const boundaryPrefix = queryLower.slice(0, Math.min(3, queryLower.length))
|
||||
const words = pathLower.split(/[/\\._-]/)
|
||||
for (const word of words) {
|
||||
if (word.startsWith(boundaryPrefix)) {
|
||||
score += FileStorage.SCORE_WORD_BOUNDARY
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Penalty for longer paths (prefer shorter, more specific matches)
|
||||
// Use logarithmic scaling to prevent long paths from dominating the score
|
||||
// A 50-char path gets ~-16 penalty, 100-char gets ~-18, 200-char gets ~-21
|
||||
score -= Math.log(filePath.length + 1) * FileStorage.PATH_LENGTH_PENALTY_FACTOR
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert query to glob pattern for ripgrep pre-filtering
|
||||
* e.g., "updater" -> "*u*p*d*a*t*e*r*"
|
||||
*/
|
||||
private queryToGlobPattern(query: string): string {
|
||||
// Escape special glob characters (including ! for negation)
|
||||
const escaped = query.replace(/[[\]{}()*+?.,\\^$|#!]/g, '\\$&')
|
||||
// Convert to fuzzy glob: each char separated by *
|
||||
return '*' + escaped.split('').join('*') + '*'
|
||||
}
|
||||
|
||||
/**
|
||||
* Greedy substring match: check if all characters in query can be matched
|
||||
* by finding consecutive substrings in text (not necessarily single chars)
|
||||
* e.g., "updatercontroller" matches "updateController" by:
|
||||
* "update" + "r" (from Controller) + "controller"
|
||||
*/
|
||||
private isGreedySubstringMatch(text: string, query: string): boolean {
|
||||
const textLower = text.toLowerCase()
|
||||
const queryLower = query.toLowerCase()
|
||||
|
||||
let queryIndex = 0
|
||||
let searchStart = 0
|
||||
|
||||
while (queryIndex < queryLower.length) {
|
||||
// Try to find the longest matching substring starting at queryIndex
|
||||
let bestMatchLen = 0
|
||||
let bestMatchPos = -1
|
||||
|
||||
for (let len = queryLower.length - queryIndex; len >= 1; len--) {
|
||||
const substr = queryLower.slice(queryIndex, queryIndex + len)
|
||||
const foundAt = textLower.indexOf(substr, searchStart)
|
||||
if (foundAt !== -1) {
|
||||
bestMatchLen = len
|
||||
bestMatchPos = foundAt
|
||||
break // Found longest possible match
|
||||
}
|
||||
}
|
||||
|
||||
if (bestMatchLen === 0) {
|
||||
// No substring match found, query cannot be matched
|
||||
return false
|
||||
}
|
||||
|
||||
queryIndex += bestMatchLen
|
||||
searchStart = bestMatchPos + bestMatchLen
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate greedy substring match score (higher is better)
|
||||
* Rewards: fewer match fragments, shorter match span, matches in filename
|
||||
*/
|
||||
private getGreedyMatchScore(filePath: string, query: string): number {
|
||||
const textLower = filePath.toLowerCase()
|
||||
const queryLower = query.toLowerCase()
|
||||
const fileName = filePath.split('/').pop() || ''
|
||||
const fileNameLower = fileName.toLowerCase()
|
||||
|
||||
let queryIndex = 0
|
||||
let searchStart = 0
|
||||
let fragmentCount = 0
|
||||
let firstMatchPos = -1
|
||||
let lastMatchEnd = 0
|
||||
|
||||
while (queryIndex < queryLower.length) {
|
||||
let bestMatchLen = 0
|
||||
let bestMatchPos = -1
|
||||
|
||||
for (let len = queryLower.length - queryIndex; len >= 1; len--) {
|
||||
const substr = queryLower.slice(queryIndex, queryIndex + len)
|
||||
const foundAt = textLower.indexOf(substr, searchStart)
|
||||
if (foundAt !== -1) {
|
||||
bestMatchLen = len
|
||||
bestMatchPos = foundAt
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if (bestMatchLen === 0) {
|
||||
return -Infinity // No match
|
||||
}
|
||||
|
||||
fragmentCount++
|
||||
if (firstMatchPos === -1) firstMatchPos = bestMatchPos
|
||||
lastMatchEnd = bestMatchPos + bestMatchLen
|
||||
queryIndex += bestMatchLen
|
||||
searchStart = lastMatchEnd
|
||||
}
|
||||
|
||||
const matchSpan = lastMatchEnd - firstMatchPos
|
||||
let score = 0
|
||||
|
||||
// Fewer fragments = better (single continuous match is best)
|
||||
// Max bonus when fragmentCount=1, decreases as fragments increase
|
||||
score += Math.max(0, 100 - (fragmentCount - 1) * 30)
|
||||
|
||||
// Shorter span relative to query length = better (tighter match)
|
||||
// Perfect match: span equals query length
|
||||
const spanRatio = queryLower.length / matchSpan
|
||||
score += spanRatio * 50
|
||||
|
||||
// Bonus for match in filename
|
||||
if (this.isGreedySubstringMatch(fileNameLower, queryLower)) {
|
||||
score += 80
|
||||
}
|
||||
|
||||
// Penalty for longer paths
|
||||
score -= Math.log(filePath.length + 1) * 4
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
/**
|
||||
* Build common ripgrep arguments for file listing
|
||||
*/
|
||||
private buildRipgrepBaseArgs(options: Required<DirectoryListOptions>, resolvedPath: string): string[] {
|
||||
const args: string[] = ['--files']
|
||||
|
||||
// Handle hidden files
|
||||
if (!options.includeHidden) {
|
||||
@ -1068,82 +1294,74 @@ class FileStorage {
|
||||
args.push('--max-depth', options.maxDepth.toString())
|
||||
}
|
||||
|
||||
// Handle max count
|
||||
if (options.maxEntries > 0) {
|
||||
args.push('--max-count', options.maxEntries.toString())
|
||||
}
|
||||
|
||||
// Add search pattern (search in content)
|
||||
args.push(options.searchPattern)
|
||||
|
||||
// Add the directory path
|
||||
args.push(resolvedPath)
|
||||
|
||||
const { exitCode, output } = await executeRipgrep(args)
|
||||
|
||||
// Exit code 0 means files found, 1 means no files found (still success), 2+ means error
|
||||
if (exitCode >= 2) {
|
||||
throw new Error(`Ripgrep failed with exit code ${exitCode}: ${output}`)
|
||||
}
|
||||
|
||||
// Parse ripgrep output (already sorted by relevance)
|
||||
const results = output
|
||||
.split('\n')
|
||||
.filter((line) => line.trim())
|
||||
.map((line) => line.replace(/\\/g, '/'))
|
||||
.slice(0, options.maxEntries)
|
||||
|
||||
return results
|
||||
return args
|
||||
}
|
||||
|
||||
private async listDirectoryWithRipgrep(
|
||||
resolvedPath: string,
|
||||
options: Required<DirectoryListOptions>
|
||||
): Promise<string[]> {
|
||||
const maxEntries = options.maxEntries
|
||||
// Fuzzy search mode: use ripgrep glob for pre-filtering, then score in JS
|
||||
if (options.fuzzy && options.searchPattern && options.searchPattern !== '.') {
|
||||
const args = this.buildRipgrepBaseArgs(options, resolvedPath)
|
||||
|
||||
// Step 1: Search by filename first
|
||||
// Insert glob pattern before the path (last element)
|
||||
const globPattern = this.queryToGlobPattern(options.searchPattern)
|
||||
args.splice(args.length - 1, 0, '--iglob', globPattern)
|
||||
|
||||
const { exitCode, output } = await executeRipgrep(args)
|
||||
|
||||
if (exitCode >= 2) {
|
||||
throw new Error(`Ripgrep failed with exit code ${exitCode}: ${output}`)
|
||||
}
|
||||
|
||||
const filteredFiles = output
|
||||
.split('\n')
|
||||
.filter((line) => line.trim())
|
||||
.map((line) => line.replace(/\\/g, '/'))
|
||||
|
||||
// If fuzzy glob found results, validate fuzzy match, sort and return
|
||||
if (filteredFiles.length > 0) {
|
||||
return filteredFiles
|
||||
.filter((file) => this.isFuzzyMatch(file, options.searchPattern))
|
||||
.map((file) => ({ file, score: this.getFuzzyMatchScore(file, options.searchPattern) }))
|
||||
.sort((a, b) => b.score - a.score)
|
||||
.slice(0, options.maxEntries)
|
||||
.map((item) => item.file)
|
||||
}
|
||||
|
||||
// Fallback: if no results, try greedy substring match on all files
|
||||
logger.debug('Fuzzy glob returned no results, falling back to greedy substring match')
|
||||
const fallbackArgs = this.buildRipgrepBaseArgs(options, resolvedPath)
|
||||
|
||||
const fallbackResult = await executeRipgrep(fallbackArgs)
|
||||
|
||||
if (fallbackResult.exitCode >= 2) {
|
||||
return []
|
||||
}
|
||||
|
||||
const allFiles = fallbackResult.output
|
||||
.split('\n')
|
||||
.filter((line) => line.trim())
|
||||
.map((line) => line.replace(/\\/g, '/'))
|
||||
|
||||
const greedyMatched = allFiles.filter((file) => this.isGreedySubstringMatch(file, options.searchPattern))
|
||||
|
||||
return greedyMatched
|
||||
.map((file) => ({ file, score: this.getGreedyMatchScore(file, options.searchPattern) }))
|
||||
.sort((a, b) => b.score - a.score)
|
||||
.slice(0, options.maxEntries)
|
||||
.map((item) => item.file)
|
||||
}
|
||||
|
||||
// Fallback: search by filename only (non-fuzzy mode)
|
||||
logger.debug('Searching by filename pattern', { pattern: options.searchPattern, path: resolvedPath })
|
||||
const filenameResults = await this.searchByFilename(resolvedPath, options)
|
||||
|
||||
logger.debug('Found matches by filename', { count: filenameResults.length })
|
||||
|
||||
// If we have enough filename matches, return them
|
||||
if (filenameResults.length >= maxEntries) {
|
||||
return filenameResults.slice(0, maxEntries)
|
||||
}
|
||||
|
||||
// Step 2: If filename matches are less than maxEntries, search by content to fill up
|
||||
logger.debug('Filename matches insufficient, searching by content to fill up', {
|
||||
filenameCount: filenameResults.length,
|
||||
needed: maxEntries - filenameResults.length
|
||||
})
|
||||
|
||||
// Adjust maxEntries for content search to get enough results
|
||||
const contentOptions = {
|
||||
...options,
|
||||
maxEntries: maxEntries - filenameResults.length + 20 // Request extra to account for duplicates
|
||||
}
|
||||
|
||||
const contentResults = await this.searchByContent(resolvedPath, contentOptions)
|
||||
|
||||
logger.debug('Found matches by content', { count: contentResults.length })
|
||||
|
||||
// Combine results: filename matches first, then content matches (deduplicated)
|
||||
const combined = [...filenameResults]
|
||||
const filenameSet = new Set(filenameResults)
|
||||
|
||||
for (const filePath of contentResults) {
|
||||
if (!filenameSet.has(filePath)) {
|
||||
combined.push(filePath)
|
||||
if (combined.length >= maxEntries) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('Combined results', { total: combined.length, filenameCount: filenameResults.length })
|
||||
return combined.slice(0, maxEntries)
|
||||
return filenameResults.slice(0, options.maxEntries)
|
||||
}
|
||||
|
||||
public validateNotesDirectory = async (_: Electron.IpcMainInvokeEvent, dirPath: string): Promise<boolean> => {
|
||||
@ -1317,7 +1535,7 @@ class FileStorage {
|
||||
await fs.promises.writeFile(destPath, buffer)
|
||||
|
||||
const stats = await fs.promises.stat(destPath)
|
||||
const fileType = getFileType(ext)
|
||||
const fileType = await this.getFileType(destPath)
|
||||
|
||||
return {
|
||||
id: uuid,
|
||||
@ -1604,6 +1822,10 @@ class FileStorage {
|
||||
}
|
||||
|
||||
public isTextFile = async (_: Electron.IpcMainInvokeEvent, filePath: string): Promise<boolean> => {
|
||||
return this._isTextFile(filePath)
|
||||
}
|
||||
|
||||
private _isTextFile = async (filePath: string): Promise<boolean> => {
|
||||
try {
|
||||
const isBinary = await isBinaryFile(filePath)
|
||||
if (isBinary) {
|
||||
|
||||
207
src/main/services/LocalTransferService.ts
Normal file
207
src/main/services/LocalTransferService.ts
Normal file
@ -0,0 +1,207 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { LocalTransferPeer, LocalTransferState } from '@shared/config/types'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import type { Browser, Service } from 'bonjour-service'
|
||||
import Bonjour from 'bonjour-service'
|
||||
|
||||
import { windowService } from './WindowService'
|
||||
|
||||
const SERVICE_TYPE = 'cherrystudio'
|
||||
const SERVICE_PROTOCOL = 'tcp' as const
|
||||
|
||||
const logger = loggerService.withContext('LocalTransferService')
|
||||
|
||||
type StartDiscoveryOptions = {
|
||||
resetList?: boolean
|
||||
}
|
||||
|
||||
class LocalTransferService {
|
||||
private static instance: LocalTransferService
|
||||
private bonjour: Bonjour | null = null
|
||||
private browser: Browser | null = null
|
||||
private services = new Map<string, LocalTransferPeer>()
|
||||
private isScanning = false
|
||||
private lastScanStartedAt?: number
|
||||
private lastUpdatedAt = Date.now()
|
||||
private lastError?: string
|
||||
|
||||
private constructor() {}
|
||||
|
||||
public static getInstance(): LocalTransferService {
|
||||
if (!LocalTransferService.instance) {
|
||||
LocalTransferService.instance = new LocalTransferService()
|
||||
}
|
||||
return LocalTransferService.instance
|
||||
}
|
||||
|
||||
public startDiscovery(options?: StartDiscoveryOptions): LocalTransferState {
|
||||
if (options?.resetList) {
|
||||
this.services.clear()
|
||||
}
|
||||
|
||||
this.isScanning = true
|
||||
this.lastScanStartedAt = Date.now()
|
||||
this.lastUpdatedAt = Date.now()
|
||||
this.lastError = undefined
|
||||
this.restartBrowser()
|
||||
this.broadcastState()
|
||||
return this.getState()
|
||||
}
|
||||
|
||||
public stopDiscovery(): LocalTransferState {
|
||||
if (this.browser) {
|
||||
try {
|
||||
this.browser.stop()
|
||||
} catch (error) {
|
||||
logger.warn('Failed to stop local transfer browser', error as Error)
|
||||
}
|
||||
}
|
||||
this.isScanning = false
|
||||
this.lastUpdatedAt = Date.now()
|
||||
this.broadcastState()
|
||||
return this.getState()
|
||||
}
|
||||
|
||||
public getState(): LocalTransferState {
|
||||
const services = Array.from(this.services.values()).sort((a, b) => a.name.localeCompare(b.name))
|
||||
return {
|
||||
services,
|
||||
isScanning: this.isScanning,
|
||||
lastScanStartedAt: this.lastScanStartedAt,
|
||||
lastUpdatedAt: this.lastUpdatedAt,
|
||||
lastError: this.lastError
|
||||
}
|
||||
}
|
||||
|
||||
public getPeerById(id: string): LocalTransferPeer | undefined {
|
||||
return this.services.get(id)
|
||||
}
|
||||
|
||||
public dispose(): void {
|
||||
this.stopDiscovery()
|
||||
this.services.clear()
|
||||
this.browser?.removeAllListeners()
|
||||
this.browser = null
|
||||
if (this.bonjour) {
|
||||
try {
|
||||
this.bonjour.destroy()
|
||||
} catch (error) {
|
||||
logger.warn('Failed to destroy Bonjour instance', error as Error)
|
||||
}
|
||||
this.bonjour = null
|
||||
}
|
||||
}
|
||||
|
||||
private getBonjour(): Bonjour {
|
||||
if (!this.bonjour) {
|
||||
this.bonjour = new Bonjour()
|
||||
}
|
||||
return this.bonjour
|
||||
}
|
||||
|
||||
private restartBrowser(): void {
|
||||
// Clean up existing browser
|
||||
if (this.browser) {
|
||||
this.browser.removeAllListeners()
|
||||
try {
|
||||
this.browser.stop()
|
||||
} catch (error) {
|
||||
logger.warn('Error while stopping Bonjour browser', error as Error)
|
||||
}
|
||||
this.browser = null
|
||||
}
|
||||
|
||||
// Destroy and recreate Bonjour instance to prevent socket leaks
|
||||
if (this.bonjour) {
|
||||
try {
|
||||
this.bonjour.destroy()
|
||||
} catch (error) {
|
||||
logger.warn('Error while destroying Bonjour instance', error as Error)
|
||||
}
|
||||
this.bonjour = null
|
||||
}
|
||||
|
||||
const browser = this.getBonjour().find({ type: SERVICE_TYPE, protocol: SERVICE_PROTOCOL })
|
||||
this.browser = browser
|
||||
this.bindBrowserEvents(browser)
|
||||
|
||||
try {
|
||||
browser.start()
|
||||
logger.info('Local transfer discovery started')
|
||||
} catch (error) {
|
||||
const err = error instanceof Error ? error : new Error(String(error))
|
||||
this.lastError = err.message
|
||||
logger.error('Failed to start local transfer discovery', err)
|
||||
}
|
||||
}
|
||||
|
||||
private bindBrowserEvents(browser: Browser) {
|
||||
browser.on('up', (service) => {
|
||||
const peer = this.normalizeService(service)
|
||||
logger.info(`LAN peer detected: ${peer.name} (${peer.addresses.join(', ')})`)
|
||||
this.services.set(peer.id, peer)
|
||||
this.lastUpdatedAt = Date.now()
|
||||
this.broadcastState()
|
||||
})
|
||||
|
||||
browser.on('down', (service) => {
|
||||
const key = this.buildServiceKey(service.fqdn || service.name, service.host, service.port)
|
||||
if (this.services.delete(key)) {
|
||||
logger.info(`LAN peer removed: ${service.name}`)
|
||||
this.lastUpdatedAt = Date.now()
|
||||
this.broadcastState()
|
||||
}
|
||||
})
|
||||
|
||||
browser.on('error', (error) => {
|
||||
const err = error instanceof Error ? error : new Error(String(error))
|
||||
logger.error('Local transfer discovery error', err)
|
||||
this.lastError = err.message
|
||||
this.broadcastState()
|
||||
})
|
||||
}
|
||||
|
||||
private normalizeService(service: Service): LocalTransferPeer {
|
||||
const addressCandidates = [...(service.addresses || []), service.referer?.address].filter(
|
||||
(value): value is string => typeof value === 'string' && value.length > 0
|
||||
)
|
||||
const addresses = Array.from(new Set(addressCandidates))
|
||||
const txtEntries = Object.entries(service.txt || {})
|
||||
const txt =
|
||||
txtEntries.length > 0
|
||||
? Object.fromEntries(
|
||||
txtEntries.map(([key, value]) => [key, value === undefined || value === null ? '' : String(value)])
|
||||
)
|
||||
: undefined
|
||||
|
||||
const peer: LocalTransferPeer = {
|
||||
id: this.buildServiceKey(service.fqdn || service.name, service.host, service.port),
|
||||
name: service.name,
|
||||
host: service.host,
|
||||
fqdn: service.fqdn,
|
||||
port: service.port,
|
||||
type: service.type,
|
||||
protocol: service.protocol,
|
||||
addresses,
|
||||
txt,
|
||||
updatedAt: Date.now()
|
||||
}
|
||||
|
||||
return peer
|
||||
}
|
||||
|
||||
private buildServiceKey(name?: string, host?: string, port?: number): string {
|
||||
const raw = [name, host, port?.toString()].filter(Boolean).join('-')
|
||||
return raw || `service-${Date.now()}`
|
||||
}
|
||||
|
||||
private broadcastState() {
|
||||
const mainWindow = windowService.getMainWindow()
|
||||
if (!mainWindow || mainWindow.isDestroyed()) {
|
||||
return
|
||||
}
|
||||
mainWindow.webContents.send(IpcChannel.LocalTransfer_ServicesUpdated, this.getState())
|
||||
}
|
||||
}
|
||||
|
||||
export const localTransferService = LocalTransferService.getInstance()
|
||||
@ -6,7 +6,7 @@ import { loggerService } from '@logger'
|
||||
import { createInMemoryMCPServer } from '@main/mcpServers/factory'
|
||||
import { makeSureDirExists, removeEnvProxy } from '@main/utils'
|
||||
import { buildFunctionCallToolName } from '@main/utils/mcp'
|
||||
import { getBinaryName, getBinaryPath } from '@main/utils/process'
|
||||
import { findCommandInShellEnv, getBinaryName, getBinaryPath, isBinaryExists } from '@main/utils/process'
|
||||
import getLoginShellEnvironment from '@main/utils/shell-env'
|
||||
import { TraceMethod, withSpanFunc } from '@mcp-trace/trace-core'
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||
@ -318,6 +318,10 @@ class McpService {
|
||||
} else if (server.command) {
|
||||
let cmd = server.command
|
||||
|
||||
// Get login shell environment first - needed for command detection and server execution
|
||||
// Note: getLoginShellEnvironment() is memoized, so subsequent calls are fast
|
||||
const loginShellEnv = await getLoginShellEnvironment()
|
||||
|
||||
// For DXT servers, use resolved configuration with platform overrides and variable substitution
|
||||
if (server.dxtPath) {
|
||||
const resolvedConfig = this.dxtService.getResolvedMcpConfig(server.dxtPath)
|
||||
@ -339,18 +343,45 @@ class McpService {
|
||||
}
|
||||
|
||||
if (server.command === 'npx') {
|
||||
cmd = await getBinaryPath('bun')
|
||||
getServerLogger(server).debug(`Using command`, { command: cmd })
|
||||
// First, check if npx is available in user's shell environment
|
||||
const npxPath = await findCommandInShellEnv('npx', loginShellEnv)
|
||||
|
||||
// add -x to args if args exist
|
||||
if (args && args.length > 0) {
|
||||
if (!args.includes('-y')) {
|
||||
args.unshift('-y')
|
||||
}
|
||||
if (!args.includes('x')) {
|
||||
args.unshift('x')
|
||||
if (npxPath) {
|
||||
// Use system npx
|
||||
cmd = npxPath
|
||||
getServerLogger(server).debug(`Using system npx`, { command: cmd })
|
||||
} else {
|
||||
// System npx not found, try bundled bun as fallback
|
||||
getServerLogger(server).debug(`System npx not found, checking for bundled bun`)
|
||||
|
||||
if (await isBinaryExists('bun')) {
|
||||
// Fall back to bundled bun
|
||||
cmd = await getBinaryPath('bun')
|
||||
getServerLogger(server).info(`Using bundled bun as fallback (npx not found in PATH)`, {
|
||||
command: cmd
|
||||
})
|
||||
|
||||
// Transform args for bun x format
|
||||
if (args && args.length > 0) {
|
||||
if (!args.includes('-y')) {
|
||||
args.unshift('-y')
|
||||
}
|
||||
if (!args.includes('x')) {
|
||||
args.unshift('x')
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Neither npx nor bun available
|
||||
throw new Error(
|
||||
'npx not found in PATH and bundled bun is not available. This may indicate an installation issue.\n' +
|
||||
'Please either:\n' +
|
||||
'1. Install Node.js (which includes npx) from https://nodejs.org\n' +
|
||||
'2. Run the MCP dependencies installer from Settings\n' +
|
||||
'3. Restart the application if you recently installed Node.js'
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (server.registryUrl) {
|
||||
server.env = {
|
||||
...server.env,
|
||||
@ -365,7 +396,35 @@ class McpService {
|
||||
}
|
||||
}
|
||||
} else if (server.command === 'uvx' || server.command === 'uv') {
|
||||
cmd = await getBinaryPath(server.command)
|
||||
// First, check if uvx/uv is available in user's shell environment
|
||||
const uvPath = await findCommandInShellEnv(server.command, loginShellEnv)
|
||||
|
||||
if (uvPath) {
|
||||
// Use system uvx/uv
|
||||
cmd = uvPath
|
||||
getServerLogger(server).debug(`Using system ${server.command}`, { command: cmd })
|
||||
} else {
|
||||
// System command not found, try bundled version as fallback
|
||||
getServerLogger(server).debug(`System ${server.command} not found, checking for bundled version`)
|
||||
|
||||
if (await isBinaryExists(server.command)) {
|
||||
// Fall back to bundled version
|
||||
cmd = await getBinaryPath(server.command)
|
||||
getServerLogger(server).info(`Using bundled ${server.command} as fallback (not found in PATH)`, {
|
||||
command: cmd
|
||||
})
|
||||
} else {
|
||||
// Neither system nor bundled available
|
||||
throw new Error(
|
||||
`${server.command} not found in PATH and bundled version is not available. This may indicate an installation issue.\n` +
|
||||
'Please either:\n' +
|
||||
'1. Install uv from https://github.com/astral-sh/uv\n' +
|
||||
'2. Run the MCP dependencies installer from Settings\n' +
|
||||
`3. Restart the application if you recently installed ${server.command}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (server.registryUrl) {
|
||||
server.env = {
|
||||
...server.env,
|
||||
@ -376,8 +435,6 @@ class McpService {
|
||||
}
|
||||
|
||||
getServerLogger(server).debug(`Starting server`, { command: cmd, args })
|
||||
// Logger.info(`[MCP] Environment variables for server:`, server.env)
|
||||
const loginShellEnv = await getLoginShellEnvironment()
|
||||
|
||||
// Bun not support proxy https://github.com/oven-sh/bun/issues/16812
|
||||
if (cmd.includes('bun')) {
|
||||
@ -728,7 +785,7 @@ class McpService {
|
||||
...tool,
|
||||
inputSchema: z.parse(MCPToolInputSchema, tool.inputSchema),
|
||||
outputSchema: tool.outputSchema ? z.parse(MCPToolOutputSchema, tool.outputSchema) : undefined,
|
||||
id: buildFunctionCallToolName(server.name, tool.name, server.id),
|
||||
id: buildFunctionCallToolName(server.name, tool.name),
|
||||
serverId: server.id,
|
||||
serverName: server.name,
|
||||
type: 'mcp'
|
||||
|
||||
@ -3,6 +3,8 @@ import { homedir } from 'node:os'
|
||||
import { promisify } from 'node:util'
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { isWin } from '@main/constant'
|
||||
import { getCpuName } from '@main/utils/system'
|
||||
import { HOME_CHERRY_DIR } from '@shared/config/constant'
|
||||
import * as fs from 'fs-extra'
|
||||
import * as path from 'path'
|
||||
@ -11,6 +13,8 @@ const logger = loggerService.withContext('OvmsManager')
|
||||
|
||||
const execAsync = promisify(exec)
|
||||
|
||||
export const isOvmsSupported = isWin && getCpuName().toLowerCase().includes('intel')
|
||||
|
||||
interface OvmsProcess {
|
||||
pid: number
|
||||
path: string
|
||||
@ -29,6 +33,12 @@ interface OvmsConfig {
|
||||
class OvmsManager {
|
||||
private ovms: OvmsProcess | null = null
|
||||
|
||||
constructor() {
|
||||
if (!isOvmsSupported) {
|
||||
throw new Error('OVMS Manager is only supported on Windows platform with Intel CPU.')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively terminate a process and all its child processes
|
||||
* @param pid Process ID to terminate
|
||||
@ -102,32 +112,10 @@ class OvmsManager {
|
||||
*/
|
||||
public async stopOvms(): Promise<{ success: boolean; message?: string }> {
|
||||
try {
|
||||
// Check if OVMS process is running
|
||||
const psCommand = `Get-Process -Name "ovms" -ErrorAction SilentlyContinue | Select-Object Id, Path | ConvertTo-Json`
|
||||
const { stdout } = await execAsync(`powershell -Command "${psCommand}"`)
|
||||
|
||||
if (!stdout.trim()) {
|
||||
logger.info('OVMS process is not running')
|
||||
return { success: true, message: 'OVMS process is not running' }
|
||||
}
|
||||
|
||||
const processes = JSON.parse(stdout)
|
||||
const processList = Array.isArray(processes) ? processes : [processes]
|
||||
|
||||
if (processList.length === 0) {
|
||||
logger.info('OVMS process is not running')
|
||||
return { success: true, message: 'OVMS process is not running' }
|
||||
}
|
||||
|
||||
// Terminate all OVMS processes using terminalProcess
|
||||
for (const process of processList) {
|
||||
const result = await this.terminalProcess(process.Id)
|
||||
if (!result.success) {
|
||||
logger.error(`Failed to terminate OVMS process with PID: ${process.Id}, ${result.message}`)
|
||||
return { success: false, message: `Failed to terminate OVMS process: ${result.message}` }
|
||||
}
|
||||
logger.info(`Terminated OVMS process with PID: ${process.Id}`)
|
||||
}
|
||||
// close the OVMS process
|
||||
await execAsync(
|
||||
`powershell -Command "Get-WmiObject Win32_Process | Where-Object { $_.CommandLine -like 'ovms.exe*' } | ForEach-Object { Stop-Process -Id $_.ProcessId -Force }"`
|
||||
)
|
||||
|
||||
// Reset the ovms instance
|
||||
this.ovms = null
|
||||
@ -584,4 +572,5 @@ class OvmsManager {
|
||||
}
|
||||
}
|
||||
|
||||
export default OvmsManager
|
||||
// Export singleton instance
|
||||
export const ovmsManager = isOvmsSupported ? new OvmsManager() : undefined
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/**
|
||||
* @deprecated Scheduled for removal in v2.0.0
|
||||
* --------------------------------------------------------------------------
|
||||
* ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
|
||||
* --------------------------------------------------------------------------
|
||||
* STOP: Feature PRs affecting this file are currently BLOCKED.
|
||||
* Only critical bug fixes are accepted during this migration phase.
|
||||
*
|
||||
* This file is being refactored to v2 standards.
|
||||
* Any non-critical changes will conflict with the ongoing work.
|
||||
*
|
||||
* 🔗 Context & Status:
|
||||
* - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
|
||||
* - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
import { loggerService } from '@logger'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { ipcMain } from 'electron'
|
||||
|
||||
@ -14,38 +14,36 @@ export class SearchService {
|
||||
return SearchService.instance
|
||||
}
|
||||
|
||||
constructor() {
|
||||
// Initialize the service
|
||||
}
|
||||
|
||||
private async createNewSearchWindow(uid: string): Promise<BrowserWindow> {
|
||||
private async createNewSearchWindow(uid: string, show: boolean = false): Promise<BrowserWindow> {
|
||||
const newWindow = new BrowserWindow({
|
||||
width: 800,
|
||||
height: 600,
|
||||
show: false,
|
||||
width: 1280,
|
||||
height: 768,
|
||||
show,
|
||||
webPreferences: {
|
||||
nodeIntegration: true,
|
||||
contextIsolation: false,
|
||||
devTools: is.dev
|
||||
}
|
||||
})
|
||||
newWindow.webContents.session.webRequest.onBeforeSendHeaders({ urls: ['*://*/*'] }, (details, callback) => {
|
||||
const headers = {
|
||||
...details.requestHeaders,
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
|
||||
}
|
||||
callback({ requestHeaders: headers })
|
||||
})
|
||||
|
||||
this.searchWindows[uid] = newWindow
|
||||
newWindow.on('closed', () => {
|
||||
delete this.searchWindows[uid]
|
||||
})
|
||||
newWindow.on('closed', () => delete this.searchWindows[uid])
|
||||
|
||||
newWindow.webContents.userAgent =
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36'
|
||||
|
||||
return newWindow
|
||||
}
|
||||
|
||||
public async openSearchWindow(uid: string): Promise<void> {
|
||||
await this.createNewSearchWindow(uid)
|
||||
public async openSearchWindow(uid: string, show: boolean = false): Promise<void> {
|
||||
const existingWindow = this.searchWindows[uid]
|
||||
|
||||
if (existingWindow) {
|
||||
show && existingWindow.show()
|
||||
return
|
||||
}
|
||||
|
||||
await this.createNewSearchWindow(uid, show)
|
||||
}
|
||||
|
||||
public async closeSearchWindow(uid: string): Promise<void> {
|
||||
|
||||
@ -1435,6 +1435,12 @@ export class SelectionService {
|
||||
}
|
||||
|
||||
actionWindow.setBounds({ x, y, width, height })
|
||||
|
||||
// [Windows only] Update remembered window size for custom resize
|
||||
// setBounds() may not trigger the 'resized' event, so we need to update manually
|
||||
if (this.isRemeberWinSize) {
|
||||
this.lastActionWindowSize = { width, height }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/**
|
||||
* @deprecated Scheduled for removal in v2.0.0
|
||||
* --------------------------------------------------------------------------
|
||||
* ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
|
||||
* --------------------------------------------------------------------------
|
||||
* STOP: Feature PRs affecting this file are currently BLOCKED.
|
||||
* Only critical bug fixes are accepted during this migration phase.
|
||||
*
|
||||
* This file is being refactored to v2 standards.
|
||||
* Any non-critical changes will conflict with the ongoing work.
|
||||
*
|
||||
* 🔗 Context & Status:
|
||||
* - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
|
||||
* - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
import { loggerService } from '@logger'
|
||||
import { handleZoomFactor } from '@main/utils/zoom'
|
||||
import type { Shortcut } from '@types'
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/**
|
||||
* @deprecated Scheduled for removal in v2.0.0
|
||||
* --------------------------------------------------------------------------
|
||||
* ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
|
||||
* --------------------------------------------------------------------------
|
||||
* STOP: Feature PRs affecting this file are currently BLOCKED.
|
||||
* Only critical bug fixes are accepted during this migration phase.
|
||||
*
|
||||
* This file is being refactored to v2 standards.
|
||||
* Any non-critical changes will conflict with the ongoing work.
|
||||
*
|
||||
* 🔗 Context & Status:
|
||||
* - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
|
||||
* - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import type { StoreSyncAction } from '@types'
|
||||
import { BrowserWindow, ipcMain } from 'electron'
|
||||
|
||||
@ -1,359 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { WebSocketCandidatesResponse, WebSocketStatusResponse } from '@shared/config/types'
|
||||
import * as fs from 'fs'
|
||||
import { networkInterfaces } from 'os'
|
||||
import * as path from 'path'
|
||||
import type { Socket } from 'socket.io'
|
||||
import { Server } from 'socket.io'
|
||||
|
||||
import { windowService } from './WindowService'
|
||||
|
||||
const logger = loggerService.withContext('WebSocketService')
|
||||
|
||||
class WebSocketService {
|
||||
private io: Server | null = null
|
||||
private isStarted = false
|
||||
private port = 7017
|
||||
private connectedClients = new Set<string>()
|
||||
|
||||
private getLocalIpAddress(): string | undefined {
|
||||
const interfaces = networkInterfaces()
|
||||
|
||||
// 按优先级排序的网络接口名称模式
|
||||
const interfacePriority = [
|
||||
// macOS: 以太网/Wi-Fi 优先
|
||||
/^en[0-9]+$/, // en0, en1 (以太网/Wi-Fi)
|
||||
/^(en|eth)[0-9]+$/, // 以太网接口
|
||||
/^wlan[0-9]+$/, // 无线接口
|
||||
// Windows: 以太网/Wi-Fi 优先
|
||||
/^(Ethernet|Wi-Fi|Local Area Connection)/,
|
||||
/^(Wi-Fi|无线网络连接)/,
|
||||
// Linux: 以太网/Wi-Fi 优先
|
||||
/^(eth|enp|wlp|wlan)[0-9]+/,
|
||||
// 虚拟化接口(低优先级)
|
||||
/^bridge[0-9]+$/, // Docker bridge
|
||||
/^veth[0-9]+$/, // Docker veth
|
||||
/^docker[0-9]+/, // Docker interfaces
|
||||
/^br-[0-9a-f]+/, // Docker bridge
|
||||
/^vmnet[0-9]+$/, // VMware
|
||||
/^vboxnet[0-9]+$/, // VirtualBox
|
||||
// VPN 隧道接口(低优先级)
|
||||
/^utun[0-9]+$/, // macOS VPN
|
||||
/^tun[0-9]+$/, // Linux/Unix VPN
|
||||
/^tap[0-9]+$/, // TAP interfaces
|
||||
/^tailscale[0-9]*$/, // Tailscale VPN
|
||||
/^wg[0-9]+$/ // WireGuard VPN
|
||||
]
|
||||
|
||||
const candidates: Array<{ interface: string; address: string; priority: number }> = []
|
||||
|
||||
for (const [name, ifaces] of Object.entries(interfaces)) {
|
||||
for (const iface of ifaces || []) {
|
||||
if (iface.family === 'IPv4' && !iface.internal) {
|
||||
// 计算接口优先级
|
||||
let priority = 999 // 默认最低优先级
|
||||
for (let i = 0; i < interfacePriority.length; i++) {
|
||||
if (interfacePriority[i].test(name)) {
|
||||
priority = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
candidates.push({
|
||||
interface: name,
|
||||
address: iface.address,
|
||||
priority
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (candidates.length === 0) {
|
||||
logger.warn('无法获取局域网 IP,使用默认 IP: 127.0.0.1')
|
||||
return '127.0.0.1'
|
||||
}
|
||||
|
||||
// 按优先级排序,选择优先级最高的
|
||||
candidates.sort((a, b) => a.priority - b.priority)
|
||||
const best = candidates[0]
|
||||
|
||||
logger.info(`获取局域网 IP: ${best.address} (interface: ${best.interface})`)
|
||||
return best.address
|
||||
}
|
||||
|
||||
public start = async (): Promise<{ success: boolean; port?: number; error?: string }> => {
|
||||
if (this.isStarted && this.io) {
|
||||
return { success: true, port: this.port }
|
||||
}
|
||||
|
||||
try {
|
||||
this.io = new Server(this.port, {
|
||||
cors: {
|
||||
origin: '*',
|
||||
methods: ['GET', 'POST']
|
||||
},
|
||||
transports: ['websocket', 'polling'],
|
||||
allowEIO3: true,
|
||||
pingTimeout: 60000,
|
||||
pingInterval: 25000
|
||||
})
|
||||
|
||||
this.io.on('connection', (socket: Socket) => {
|
||||
this.connectedClients.add(socket.id)
|
||||
|
||||
const mainWindow = windowService.getMainWindow()
|
||||
if (!mainWindow) {
|
||||
logger.error('Main window is null, cannot send connection event')
|
||||
} else {
|
||||
mainWindow.webContents.send('websocket-client-connected', {
|
||||
connected: true,
|
||||
clientId: socket.id
|
||||
})
|
||||
logger.info(`Connection event sent to renderer, total clients: ${this.connectedClients.size}`)
|
||||
}
|
||||
|
||||
socket.on('message', (data) => {
|
||||
logger.info('Received message from mobile:', data)
|
||||
mainWindow?.webContents.send('websocket-message-received', data)
|
||||
socket.emit('message_received', { success: true })
|
||||
})
|
||||
|
||||
socket.on('disconnect', () => {
|
||||
logger.info(`Client disconnected: ${socket.id}`)
|
||||
this.connectedClients.delete(socket.id)
|
||||
|
||||
if (this.connectedClients.size === 0) {
|
||||
mainWindow?.webContents.send('websocket-client-connected', {
|
||||
connected: false,
|
||||
clientId: socket.id
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// Engine 层面的事件监听
|
||||
this.io.engine.on('connection_error', (err) => {
|
||||
logger.error('Engine connection error:', err)
|
||||
})
|
||||
|
||||
this.io.engine.on('connection', (rawSocket) => {
|
||||
const remoteAddr = rawSocket.request.connection.remoteAddress
|
||||
logger.info(`[Engine] Raw connection from: ${remoteAddr}`)
|
||||
logger.info(`[Engine] Transport: ${rawSocket.transport.name}`)
|
||||
|
||||
rawSocket.on('packet', (packet: { type: string; data?: any }) => {
|
||||
logger.info(
|
||||
`[Engine] ← Packet from ${remoteAddr}: type="${packet.type}"`,
|
||||
packet.data ? { data: packet.data } : {}
|
||||
)
|
||||
})
|
||||
|
||||
rawSocket.on('packetCreate', (packet: { type: string; data?: any }) => {
|
||||
logger.info(`[Engine] → Packet to ${remoteAddr}: type="${packet.type}"`)
|
||||
})
|
||||
|
||||
rawSocket.on('close', (reason: string) => {
|
||||
logger.warn(`[Engine] Connection closed from ${remoteAddr}, reason: ${reason}`)
|
||||
})
|
||||
|
||||
rawSocket.on('error', (error: Error) => {
|
||||
logger.error(`[Engine] Connection error from ${remoteAddr}:`, error)
|
||||
})
|
||||
})
|
||||
|
||||
// Socket.IO 握手失败监听
|
||||
this.io.on('connection_error', (err) => {
|
||||
logger.error('[Socket.IO] Connection error during handshake:', err)
|
||||
})
|
||||
|
||||
this.isStarted = true
|
||||
logger.info(`WebSocket server started on port ${this.port}`)
|
||||
|
||||
return { success: true, port: this.port }
|
||||
} catch (error) {
|
||||
logger.error('Failed to start WebSocket server:', error as Error)
|
||||
return {
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : 'Unknown error'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public stop = async (): Promise<{ success: boolean }> => {
|
||||
if (!this.isStarted || !this.io) {
|
||||
return { success: true }
|
||||
}
|
||||
|
||||
try {
|
||||
await new Promise<void>((resolve) => {
|
||||
this.io!.close(() => {
|
||||
resolve()
|
||||
})
|
||||
})
|
||||
|
||||
this.io = null
|
||||
this.isStarted = false
|
||||
this.connectedClients.clear()
|
||||
logger.info('WebSocket server stopped')
|
||||
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error('Failed to stop WebSocket server:', error as Error)
|
||||
return { success: false }
|
||||
}
|
||||
}
|
||||
|
||||
public getStatus = async (): Promise<WebSocketStatusResponse> => {
|
||||
return {
|
||||
isRunning: this.isStarted,
|
||||
port: this.isStarted ? this.port : undefined,
|
||||
ip: this.isStarted ? this.getLocalIpAddress() : undefined,
|
||||
clientConnected: this.connectedClients.size > 0
|
||||
}
|
||||
}
|
||||
|
||||
public getAllCandidates = async (): Promise<WebSocketCandidatesResponse[]> => {
|
||||
const interfaces = networkInterfaces()
|
||||
|
||||
// 按优先级排序的网络接口名称模式
|
||||
const interfacePriority = [
|
||||
// macOS: 以太网/Wi-Fi 优先
|
||||
/^en[0-9]+$/, // en0, en1 (以太网/Wi-Fi)
|
||||
/^(en|eth)[0-9]+$/, // 以太网接口
|
||||
/^wlan[0-9]+$/, // 无线接口
|
||||
// Windows: 以太网/Wi-Fi 优先
|
||||
/^(Ethernet|Wi-Fi|Local Area Connection)/,
|
||||
/^(Wi-Fi|无线网络连接)/,
|
||||
// Linux: 以太网/Wi-Fi 优先
|
||||
/^(eth|enp|wlp|wlan)[0-9]+/,
|
||||
// 虚拟化接口(低优先级)
|
||||
/^bridge[0-9]+$/, // Docker bridge
|
||||
/^veth[0-9]+$/, // Docker veth
|
||||
/^docker[0-9]+/, // Docker interfaces
|
||||
/^br-[0-9a-f]+/, // Docker bridge
|
||||
/^vmnet[0-9]+$/, // VMware
|
||||
/^vboxnet[0-9]+$/, // VirtualBox
|
||||
// VPN 隧道接口(低优先级)
|
||||
/^utun[0-9]+$/, // macOS VPN
|
||||
/^tun[0-9]+$/, // Linux/Unix VPN
|
||||
/^tap[0-9]+$/, // TAP interfaces
|
||||
/^tailscale[0-9]*$/, // Tailscale VPN
|
||||
/^wg[0-9]+$/ // WireGuard VPN
|
||||
]
|
||||
|
||||
const candidates: Array<{ host: string; interface: string; priority: number }> = []
|
||||
|
||||
for (const [name, ifaces] of Object.entries(interfaces)) {
|
||||
for (const iface of ifaces || []) {
|
||||
if (iface.family === 'IPv4' && !iface.internal) {
|
||||
// 计算接口优先级
|
||||
let priority = 999 // 默认最低优先级
|
||||
for (let i = 0; i < interfacePriority.length; i++) {
|
||||
if (interfacePriority[i].test(name)) {
|
||||
priority = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
candidates.push({
|
||||
host: iface.address,
|
||||
interface: name,
|
||||
priority
|
||||
})
|
||||
|
||||
logger.debug(`Found interface: ${name} -> ${iface.address} (priority: ${priority})`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 按优先级排序返回
|
||||
candidates.sort((a, b) => a.priority - b.priority)
|
||||
logger.info(
|
||||
`Found ${candidates.length} IP candidates: ${candidates.map((c) => `${c.host}(${c.interface})`).join(', ')}`
|
||||
)
|
||||
return candidates
|
||||
}
|
||||
|
||||
public sendFile = async (
|
||||
_: Electron.IpcMainInvokeEvent,
|
||||
filePath: string
|
||||
): Promise<{ success: boolean; error?: string }> => {
|
||||
if (!this.isStarted || !this.io) {
|
||||
const errorMsg = 'WebSocket server is not running.'
|
||||
logger.error(errorMsg)
|
||||
return { success: false, error: errorMsg }
|
||||
}
|
||||
|
||||
if (this.connectedClients.size === 0) {
|
||||
const errorMsg = 'No client connected.'
|
||||
logger.error(errorMsg)
|
||||
return { success: false, error: errorMsg }
|
||||
}
|
||||
|
||||
const mainWindow = windowService.getMainWindow()
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
const stats = fs.statSync(filePath)
|
||||
const totalSize = stats.size
|
||||
const filename = path.basename(filePath)
|
||||
const stream = fs.createReadStream(filePath)
|
||||
let bytesSent = 0
|
||||
const startTime = Date.now()
|
||||
|
||||
logger.info(`Starting file transfer: ${filename} (${this.formatFileSize(totalSize)})`)
|
||||
|
||||
// 向客户端发送文件开始的信号,包含文件名和总大小
|
||||
this.io!.emit('zip-file-start', { filename, totalSize })
|
||||
|
||||
stream.on('data', (chunk) => {
|
||||
bytesSent += chunk.length
|
||||
const progress = (bytesSent / totalSize) * 100
|
||||
|
||||
// 向客户端发送文件块
|
||||
this.io!.emit('zip-file-chunk', chunk)
|
||||
|
||||
// 向渲染进程发送进度更新
|
||||
mainWindow?.webContents.send('file-send-progress', { progress })
|
||||
|
||||
// 每10%记录一次进度
|
||||
if (Math.floor(progress) % 10 === 0) {
|
||||
const elapsed = (Date.now() - startTime) / 1000
|
||||
const speed = elapsed > 0 ? bytesSent / elapsed : 0
|
||||
logger.info(`Transfer progress: ${Math.floor(progress)}% (${this.formatFileSize(speed)}/s)`)
|
||||
}
|
||||
})
|
||||
|
||||
stream.on('end', () => {
|
||||
const totalTime = (Date.now() - startTime) / 1000
|
||||
const avgSpeed = totalTime > 0 ? totalSize / totalTime : 0
|
||||
logger.info(
|
||||
`File transfer completed: ${filename} in ${totalTime.toFixed(1)}s (${this.formatFileSize(avgSpeed)}/s)`
|
||||
)
|
||||
|
||||
// 确保发送100%的进度
|
||||
mainWindow?.webContents.send('file-send-progress', { progress: 100 })
|
||||
// 向客户端发送文件结束的信号
|
||||
this.io!.emit('zip-file-end')
|
||||
resolve({ success: true })
|
||||
})
|
||||
|
||||
stream.on('error', (error) => {
|
||||
logger.error(`File transfer failed: ${filename}`, error)
|
||||
reject({
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : 'Unknown error'
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
private formatFileSize(bytes: number): string {
|
||||
if (bytes === 0) return '0 B'
|
||||
const k = 1024
|
||||
const sizes = ['B', 'KB', 'MB', 'GB']
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k))
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
|
||||
}
|
||||
}
|
||||
|
||||
export default new WebSocketService()
|
||||
@ -0,0 +1,274 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
// Use vi.hoisted to define mocks that are available during hoisting
|
||||
const { mockLogger } = vi.hoisted(() => ({
|
||||
mockLogger: {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => mockLogger
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('electron', () => ({
|
||||
app: {
|
||||
getPath: vi.fn((key: string) => {
|
||||
if (key === 'temp') return '/tmp'
|
||||
if (key === 'userData') return '/mock/userData'
|
||||
return '/mock/unknown'
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('fs-extra', () => ({
|
||||
default: {
|
||||
pathExists: vi.fn(),
|
||||
remove: vi.fn(),
|
||||
ensureDir: vi.fn(),
|
||||
copy: vi.fn(),
|
||||
readdir: vi.fn(),
|
||||
stat: vi.fn(),
|
||||
readFile: vi.fn(),
|
||||
writeFile: vi.fn(),
|
||||
createWriteStream: vi.fn(),
|
||||
createReadStream: vi.fn()
|
||||
},
|
||||
pathExists: vi.fn(),
|
||||
remove: vi.fn(),
|
||||
ensureDir: vi.fn(),
|
||||
copy: vi.fn(),
|
||||
readdir: vi.fn(),
|
||||
stat: vi.fn(),
|
||||
readFile: vi.fn(),
|
||||
writeFile: vi.fn(),
|
||||
createWriteStream: vi.fn(),
|
||||
createReadStream: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../WindowService', () => ({
|
||||
windowService: {
|
||||
getMainWindow: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../WebDav', () => ({
|
||||
default: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../S3Storage', () => ({
|
||||
default: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../../utils', () => ({
|
||||
getDataPath: vi.fn(() => '/mock/data')
|
||||
}))
|
||||
|
||||
vi.mock('archiver', () => ({
|
||||
default: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('node-stream-zip', () => ({
|
||||
default: vi.fn()
|
||||
}))
|
||||
|
||||
// Import after mocks
|
||||
import * as fs from 'fs-extra'
|
||||
|
||||
import BackupManager from '../BackupManager'
|
||||
|
||||
describe('BackupManager.deleteTempBackup - Security Tests', () => {
|
||||
let backupManager: BackupManager
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
backupManager = new BackupManager()
|
||||
})
|
||||
|
||||
describe('Normal Operations', () => {
|
||||
it('should delete valid file in allowed directory', async () => {
|
||||
vi.mocked(fs.pathExists).mockResolvedValue(true as never)
|
||||
vi.mocked(fs.remove).mockResolvedValue(undefined as never)
|
||||
|
||||
const validPath = '/tmp/cherry-studio/lan-transfer/backup.zip'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath)
|
||||
|
||||
expect(result).toBe(true)
|
||||
expect(fs.remove).toHaveBeenCalledWith(validPath)
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('Deleted temp backup'))
|
||||
})
|
||||
|
||||
it('should delete file in nested subdirectory', async () => {
|
||||
vi.mocked(fs.pathExists).mockResolvedValue(true as never)
|
||||
vi.mocked(fs.remove).mockResolvedValue(undefined as never)
|
||||
|
||||
const nestedPath = '/tmp/cherry-studio/lan-transfer/sub/dir/file.zip'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, nestedPath)
|
||||
|
||||
expect(result).toBe(true)
|
||||
expect(fs.remove).toHaveBeenCalledWith(nestedPath)
|
||||
})
|
||||
|
||||
it('should return false when file does not exist', async () => {
|
||||
vi.mocked(fs.pathExists).mockResolvedValue(false as never)
|
||||
|
||||
const missingPath = '/tmp/cherry-studio/lan-transfer/missing.zip'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, missingPath)
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(fs.remove).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Path Traversal Attacks', () => {
|
||||
it('should block basic directory traversal attack (../../../../etc/passwd)', async () => {
|
||||
const attackPath = '/tmp/cherry-studio/lan-transfer/../../../../etc/passwd'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(fs.pathExists).not.toHaveBeenCalled()
|
||||
expect(fs.remove).not.toHaveBeenCalled()
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(expect.stringContaining('outside temp directory'))
|
||||
})
|
||||
|
||||
it('should block absolute path escape (/etc/passwd)', async () => {
|
||||
const attackPath = '/etc/passwd'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(fs.remove).not.toHaveBeenCalled()
|
||||
expect(mockLogger.warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should block traversal with multiple slashes', async () => {
|
||||
const attackPath = '/tmp/cherry-studio/lan-transfer/../../../etc/passwd'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(fs.remove).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should block relative path traversal from current directory', async () => {
|
||||
const attackPath = '../../../etc/passwd'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(fs.remove).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should block traversal to parent directory', async () => {
|
||||
const attackPath = '/tmp/cherry-studio/lan-transfer/../backup/secret.zip'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(fs.remove).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Prefix Attacks', () => {
|
||||
it('should block similar prefix attack (lan-transfer-evil)', async () => {
|
||||
const attackPath = '/tmp/cherry-studio/lan-transfer-evil/file.zip'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(fs.remove).not.toHaveBeenCalled()
|
||||
expect(mockLogger.warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should block path without separator (lan-transferx)', async () => {
|
||||
const attackPath = '/tmp/cherry-studio/lan-transferx'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(fs.remove).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should block different temp directory prefix', async () => {
|
||||
const attackPath = '/tmp-evil/cherry-studio/lan-transfer/file.zip'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, attackPath)
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(fs.remove).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should return false and log error on permission denied', async () => {
|
||||
vi.mocked(fs.pathExists).mockResolvedValue(true as never)
|
||||
vi.mocked(fs.remove).mockRejectedValue(new Error('EACCES: permission denied') as never)
|
||||
|
||||
const validPath = '/tmp/cherry-studio/lan-transfer/file.zip'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath)
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(mockLogger.error).toHaveBeenCalledWith(expect.stringContaining('Failed to delete'), expect.any(Error))
|
||||
})
|
||||
|
||||
it('should return false on fs.pathExists error', async () => {
|
||||
vi.mocked(fs.pathExists).mockRejectedValue(new Error('ENOENT') as never)
|
||||
|
||||
const validPath = '/tmp/cherry-studio/lan-transfer/file.zip'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, validPath)
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(mockLogger.error).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle empty path string', async () => {
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, '')
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(fs.remove).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should allow deletion of the temp directory itself', async () => {
|
||||
vi.mocked(fs.pathExists).mockResolvedValue(true as never)
|
||||
vi.mocked(fs.remove).mockResolvedValue(undefined as never)
|
||||
|
||||
const tempDir = '/tmp/cherry-studio/lan-transfer'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, tempDir)
|
||||
|
||||
expect(result).toBe(true)
|
||||
expect(fs.remove).toHaveBeenCalledWith(tempDir)
|
||||
})
|
||||
|
||||
it('should handle path with trailing slash', async () => {
|
||||
vi.mocked(fs.pathExists).mockResolvedValue(true as never)
|
||||
vi.mocked(fs.remove).mockResolvedValue(undefined as never)
|
||||
|
||||
const pathWithSlash = '/tmp/cherry-studio/lan-transfer/sub/'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, pathWithSlash)
|
||||
|
||||
// path.normalize removes trailing slash
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
it('should handle file with special characters in name', async () => {
|
||||
vi.mocked(fs.pathExists).mockResolvedValue(true as never)
|
||||
vi.mocked(fs.remove).mockResolvedValue(undefined as never)
|
||||
|
||||
const specialPath = '/tmp/cherry-studio/lan-transfer/file with spaces & (special).zip'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, specialPath)
|
||||
|
||||
expect(result).toBe(true)
|
||||
expect(fs.remove).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle path with double slashes', async () => {
|
||||
vi.mocked(fs.pathExists).mockResolvedValue(true as never)
|
||||
vi.mocked(fs.remove).mockResolvedValue(undefined as never)
|
||||
|
||||
const doubleSlashPath = '/tmp/cherry-studio//lan-transfer//file.zip'
|
||||
const result = await backupManager.deleteTempBackup({} as Electron.IpcMainInvokeEvent, doubleSlashPath)
|
||||
|
||||
// path.normalize handles double slashes
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
481
src/main/services/__tests__/LocalTransferService.test.ts
Normal file
481
src/main/services/__tests__/LocalTransferService.test.ts
Normal file
@ -0,0 +1,481 @@
|
||||
import { EventEmitter } from 'events'
|
||||
import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest'
|
||||
|
||||
// Create mock objects before vi.mock calls
|
||||
const mockLogger = {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn()
|
||||
}
|
||||
|
||||
let mockMainWindow: {
|
||||
isDestroyed: Mock
|
||||
webContents: { send: Mock }
|
||||
} | null = null
|
||||
|
||||
let mockBrowser: EventEmitter & {
|
||||
start: Mock
|
||||
stop: Mock
|
||||
removeAllListeners: Mock
|
||||
}
|
||||
|
||||
let mockBonjour: {
|
||||
find: Mock
|
||||
destroy: Mock
|
||||
}
|
||||
|
||||
// Mock dependencies before importing the service
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => mockLogger
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../WindowService', () => ({
|
||||
windowService: {
|
||||
getMainWindow: vi.fn(() => mockMainWindow)
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('bonjour-service', () => ({
|
||||
default: vi.fn(() => mockBonjour)
|
||||
}))
|
||||
|
||||
describe('LocalTransferService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.resetModules()
|
||||
|
||||
// Reset mock objects
|
||||
mockMainWindow = {
|
||||
isDestroyed: vi.fn(() => false),
|
||||
webContents: { send: vi.fn() }
|
||||
}
|
||||
|
||||
mockBrowser = Object.assign(new EventEmitter(), {
|
||||
start: vi.fn(),
|
||||
stop: vi.fn(),
|
||||
removeAllListeners: vi.fn()
|
||||
})
|
||||
|
||||
mockBonjour = {
|
||||
find: vi.fn(() => mockBrowser),
|
||||
destroy: vi.fn()
|
||||
}
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
describe('startDiscovery', () => {
|
||||
it('should set isScanning to true and start browser', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
const state = localTransferService.startDiscovery()
|
||||
|
||||
expect(state.isScanning).toBe(true)
|
||||
expect(state.lastScanStartedAt).toBeDefined()
|
||||
expect(mockBonjour.find).toHaveBeenCalledWith({ type: 'cherrystudio', protocol: 'tcp' })
|
||||
expect(mockBrowser.start).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should clear services when resetList is true', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
// First, start discovery and add a service
|
||||
localTransferService.startDiscovery()
|
||||
mockBrowser.emit('up', {
|
||||
name: 'Test Service',
|
||||
host: 'localhost',
|
||||
port: 12345,
|
||||
addresses: ['192.168.1.100'],
|
||||
fqdn: 'test.local'
|
||||
})
|
||||
|
||||
expect(localTransferService.getState().services).toHaveLength(1)
|
||||
|
||||
// Now restart with resetList
|
||||
const state = localTransferService.startDiscovery({ resetList: true })
|
||||
|
||||
expect(state.services).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('should broadcast state after starting discovery', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
expect(mockMainWindow?.webContents.send).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle browser.start() error', async () => {
|
||||
mockBrowser.start.mockImplementation(() => {
|
||||
throw new Error('Failed to start mDNS')
|
||||
})
|
||||
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
const state = localTransferService.startDiscovery()
|
||||
|
||||
expect(state.lastError).toBe('Failed to start mDNS')
|
||||
expect(mockLogger.error).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('stopDiscovery', () => {
|
||||
it('should set isScanning to false and stop browser', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
const state = localTransferService.stopDiscovery()
|
||||
|
||||
expect(state.isScanning).toBe(false)
|
||||
expect(mockBrowser.stop).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle browser.stop() error gracefully', async () => {
|
||||
mockBrowser.stop.mockImplementation(() => {
|
||||
throw new Error('Stop failed')
|
||||
})
|
||||
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
// Should not throw
|
||||
expect(() => localTransferService.stopDiscovery()).not.toThrow()
|
||||
expect(mockLogger.warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should broadcast state after stopping', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
vi.clearAllMocks()
|
||||
|
||||
localTransferService.stopDiscovery()
|
||||
|
||||
expect(mockMainWindow?.webContents.send).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('browser events', () => {
|
||||
it('should add service on "up" event', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
mockBrowser.emit('up', {
|
||||
name: 'Test Service',
|
||||
host: 'localhost',
|
||||
port: 12345,
|
||||
addresses: ['192.168.1.100'],
|
||||
fqdn: 'test.local',
|
||||
type: 'cherrystudio',
|
||||
protocol: 'tcp'
|
||||
})
|
||||
|
||||
const state = localTransferService.getState()
|
||||
expect(state.services).toHaveLength(1)
|
||||
expect(state.services[0].name).toBe('Test Service')
|
||||
expect(state.services[0].port).toBe(12345)
|
||||
expect(state.services[0].addresses).toContain('192.168.1.100')
|
||||
})
|
||||
|
||||
it('should remove service on "down" event', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
// Add service
|
||||
mockBrowser.emit('up', {
|
||||
name: 'Test Service',
|
||||
host: 'localhost',
|
||||
port: 12345,
|
||||
addresses: ['192.168.1.100'],
|
||||
fqdn: 'test.local'
|
||||
})
|
||||
|
||||
expect(localTransferService.getState().services).toHaveLength(1)
|
||||
|
||||
// Remove service
|
||||
mockBrowser.emit('down', {
|
||||
name: 'Test Service',
|
||||
host: 'localhost',
|
||||
port: 12345,
|
||||
fqdn: 'test.local'
|
||||
})
|
||||
|
||||
expect(localTransferService.getState().services).toHaveLength(0)
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('removed'))
|
||||
})
|
||||
|
||||
it('should set lastError on "error" event', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
mockBrowser.emit('error', new Error('Discovery failed'))
|
||||
|
||||
const state = localTransferService.getState()
|
||||
expect(state.lastError).toBe('Discovery failed')
|
||||
expect(mockLogger.error).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle non-Error objects in error event', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
mockBrowser.emit('error', 'String error message')
|
||||
|
||||
const state = localTransferService.getState()
|
||||
expect(state.lastError).toBe('String error message')
|
||||
})
|
||||
})
|
||||
|
||||
describe('getState', () => {
|
||||
it('should return sorted services by name', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
mockBrowser.emit('up', {
|
||||
name: 'Zebra Service',
|
||||
host: 'host1',
|
||||
port: 1001,
|
||||
addresses: ['192.168.1.1']
|
||||
})
|
||||
|
||||
mockBrowser.emit('up', {
|
||||
name: 'Alpha Service',
|
||||
host: 'host2',
|
||||
port: 1002,
|
||||
addresses: ['192.168.1.2']
|
||||
})
|
||||
|
||||
const state = localTransferService.getState()
|
||||
expect(state.services[0].name).toBe('Alpha Service')
|
||||
expect(state.services[1].name).toBe('Zebra Service')
|
||||
})
|
||||
|
||||
it('should include all state properties', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
const state = localTransferService.getState()
|
||||
|
||||
expect(state).toHaveProperty('services')
|
||||
expect(state).toHaveProperty('isScanning')
|
||||
expect(state).toHaveProperty('lastScanStartedAt')
|
||||
expect(state).toHaveProperty('lastUpdatedAt')
|
||||
})
|
||||
})
|
||||
|
||||
describe('getPeerById', () => {
|
||||
it('should return peer when exists', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
mockBrowser.emit('up', {
|
||||
name: 'Test Service',
|
||||
host: 'localhost',
|
||||
port: 12345,
|
||||
addresses: ['192.168.1.100'],
|
||||
fqdn: 'test.local'
|
||||
})
|
||||
|
||||
const services = localTransferService.getState().services
|
||||
const peer = localTransferService.getPeerById(services[0].id)
|
||||
|
||||
expect(peer).toBeDefined()
|
||||
expect(peer?.name).toBe('Test Service')
|
||||
})
|
||||
|
||||
it('should return undefined when peer does not exist', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
const peer = localTransferService.getPeerById('non-existent-id')
|
||||
|
||||
expect(peer).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('normalizeService', () => {
|
||||
it('should deduplicate addresses', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
mockBrowser.emit('up', {
|
||||
name: 'Test Service',
|
||||
host: 'localhost',
|
||||
port: 12345,
|
||||
addresses: ['192.168.1.100', '192.168.1.100', '10.0.0.1'],
|
||||
referer: { address: '192.168.1.100' }
|
||||
})
|
||||
|
||||
const services = localTransferService.getState().services
|
||||
expect(services[0].addresses).toHaveLength(2)
|
||||
expect(services[0].addresses).toContain('192.168.1.100')
|
||||
expect(services[0].addresses).toContain('10.0.0.1')
|
||||
})
|
||||
|
||||
it('should filter empty addresses', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
mockBrowser.emit('up', {
|
||||
name: 'Test Service',
|
||||
host: 'localhost',
|
||||
port: 12345,
|
||||
addresses: ['192.168.1.100', '', null as any]
|
||||
})
|
||||
|
||||
const services = localTransferService.getState().services
|
||||
expect(services[0].addresses).toEqual(['192.168.1.100'])
|
||||
})
|
||||
|
||||
it('should convert txt null/undefined values to empty strings', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
mockBrowser.emit('up', {
|
||||
name: 'Test Service',
|
||||
host: 'localhost',
|
||||
port: 12345,
|
||||
addresses: ['192.168.1.100'],
|
||||
txt: {
|
||||
version: '1.0',
|
||||
nullValue: null,
|
||||
undefinedValue: undefined,
|
||||
numberValue: 42
|
||||
}
|
||||
})
|
||||
|
||||
const services = localTransferService.getState().services
|
||||
expect(services[0].txt).toEqual({
|
||||
version: '1.0',
|
||||
nullValue: '',
|
||||
undefinedValue: '',
|
||||
numberValue: '42'
|
||||
})
|
||||
})
|
||||
|
||||
it('should not include txt when empty', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
mockBrowser.emit('up', {
|
||||
name: 'Test Service',
|
||||
host: 'localhost',
|
||||
port: 12345,
|
||||
addresses: ['192.168.1.100'],
|
||||
txt: {}
|
||||
})
|
||||
|
||||
const services = localTransferService.getState().services
|
||||
expect(services[0].txt).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('dispose', () => {
|
||||
it('should clean up all resources', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
mockBrowser.emit('up', {
|
||||
name: 'Test Service',
|
||||
host: 'localhost',
|
||||
port: 12345,
|
||||
addresses: ['192.168.1.100']
|
||||
})
|
||||
|
||||
localTransferService.dispose()
|
||||
|
||||
expect(localTransferService.getState().services).toHaveLength(0)
|
||||
expect(localTransferService.getState().isScanning).toBe(false)
|
||||
expect(mockBrowser.removeAllListeners).toHaveBeenCalled()
|
||||
expect(mockBonjour.destroy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle bonjour.destroy() error gracefully', async () => {
|
||||
mockBonjour.destroy.mockImplementation(() => {
|
||||
throw new Error('Destroy failed')
|
||||
})
|
||||
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
// Should not throw
|
||||
expect(() => localTransferService.dispose()).not.toThrow()
|
||||
expect(mockLogger.warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should be safe to call multiple times', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
expect(() => {
|
||||
localTransferService.dispose()
|
||||
localTransferService.dispose()
|
||||
}).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('broadcastState', () => {
|
||||
it('should not throw when main window is null', async () => {
|
||||
mockMainWindow = null
|
||||
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
// Should not throw
|
||||
expect(() => localTransferService.startDiscovery()).not.toThrow()
|
||||
})
|
||||
|
||||
it('should not throw when main window is destroyed', async () => {
|
||||
mockMainWindow = {
|
||||
isDestroyed: vi.fn(() => true),
|
||||
webContents: { send: vi.fn() }
|
||||
}
|
||||
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
// Should not throw
|
||||
expect(() => localTransferService.startDiscovery()).not.toThrow()
|
||||
expect(mockMainWindow.webContents.send).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('restartBrowser', () => {
|
||||
it('should destroy old bonjour instance to prevent socket leaks', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
// First start
|
||||
localTransferService.startDiscovery()
|
||||
expect(mockBonjour.destroy).not.toHaveBeenCalled()
|
||||
|
||||
// Restart - should destroy old instance
|
||||
localTransferService.startDiscovery()
|
||||
expect(mockBonjour.destroy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should remove all listeners from old browser', async () => {
|
||||
const { localTransferService } = await import('../LocalTransferService')
|
||||
|
||||
localTransferService.startDiscovery()
|
||||
localTransferService.startDiscovery()
|
||||
|
||||
expect(mockBrowser.removeAllListeners).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -2,6 +2,7 @@ import { loggerService } from '@logger'
|
||||
import { mcpApiService } from '@main/apiServer/services/mcp'
|
||||
import type { ModelValidationError } from '@main/apiServer/utils'
|
||||
import { validateModelId } from '@main/apiServer/utils'
|
||||
import { buildFunctionCallToolName } from '@main/utils/mcp'
|
||||
import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
|
||||
import { objectKeys } from '@types'
|
||||
import fs from 'fs'
|
||||
@ -14,6 +15,17 @@ import { builtinSlashCommands } from './services/claudecode/commands'
|
||||
import { builtinTools } from './services/claudecode/tools'
|
||||
|
||||
const logger = loggerService.withContext('BaseService')
|
||||
const MCP_TOOL_ID_PREFIX = 'mcp__'
|
||||
const MCP_TOOL_LEGACY_PREFIX = 'mcp_'
|
||||
|
||||
const buildMcpToolId = (serverId: string, toolName: string) => `${MCP_TOOL_ID_PREFIX}${serverId}__${toolName}`
|
||||
const toLegacyMcpToolId = (toolId: string) => {
|
||||
if (!toolId.startsWith(MCP_TOOL_ID_PREFIX)) {
|
||||
return null
|
||||
}
|
||||
const rawId = toolId.slice(MCP_TOOL_ID_PREFIX.length)
|
||||
return `${MCP_TOOL_LEGACY_PREFIX}${rawId.replace(/__/g, '_')}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Base service class providing shared utilities for all agent-related services.
|
||||
@ -35,8 +47,12 @@ export abstract class BaseService {
|
||||
'slash_commands'
|
||||
]
|
||||
|
||||
public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> {
|
||||
public async listMcpTools(
|
||||
agentType: AgentType,
|
||||
ids?: string[]
|
||||
): Promise<{ tools: Tool[]; legacyIdMap: Map<string, string> }> {
|
||||
const tools: Tool[] = []
|
||||
const legacyIdMap = new Map<string, string>()
|
||||
if (agentType === 'claude-code') {
|
||||
tools.push(...builtinTools)
|
||||
}
|
||||
@ -46,13 +62,21 @@ export abstract class BaseService {
|
||||
const server = await mcpApiService.getServerInfo(id)
|
||||
if (server) {
|
||||
server.tools.forEach((tool: MCPTool) => {
|
||||
const canonicalId = buildFunctionCallToolName(server.name, tool.name)
|
||||
const serverIdBasedId = buildMcpToolId(id, tool.name)
|
||||
const legacyId = toLegacyMcpToolId(serverIdBasedId)
|
||||
|
||||
tools.push({
|
||||
id: `mcp_${id}_${tool.name}`,
|
||||
id: canonicalId,
|
||||
name: tool.name,
|
||||
type: 'mcp',
|
||||
description: tool.description || '',
|
||||
requirePermissions: true
|
||||
})
|
||||
legacyIdMap.set(serverIdBasedId, canonicalId)
|
||||
if (legacyId) {
|
||||
legacyIdMap.set(legacyId, canonicalId)
|
||||
}
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
@ -64,7 +88,53 @@ export abstract class BaseService {
|
||||
}
|
||||
}
|
||||
|
||||
return tools
|
||||
return { tools, legacyIdMap }
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize MCP tool IDs in allowed_tools to the current format.
|
||||
*
|
||||
* Legacy formats:
|
||||
* - "mcp__<serverId>__<toolName>" (double underscore separators, server ID based)
|
||||
* - "mcp_<serverId>_<toolName>" (single underscore separators)
|
||||
* Current format: "mcp__<serverName>__<toolName>" (double underscore separators).
|
||||
*
|
||||
* This keeps persisted data compatible without requiring a database migration.
|
||||
*/
|
||||
protected normalizeAllowedTools(
|
||||
allowedTools: string[] | undefined,
|
||||
tools: Tool[],
|
||||
legacyIdMap?: Map<string, string>
|
||||
): string[] | undefined {
|
||||
if (!allowedTools || allowedTools.length === 0) {
|
||||
return allowedTools
|
||||
}
|
||||
|
||||
const resolvedLegacyIdMap = new Map<string, string>()
|
||||
|
||||
if (legacyIdMap) {
|
||||
for (const [legacyId, canonicalId] of legacyIdMap) {
|
||||
resolvedLegacyIdMap.set(legacyId, canonicalId)
|
||||
}
|
||||
}
|
||||
|
||||
for (const tool of tools) {
|
||||
if (tool.type !== 'mcp') {
|
||||
continue
|
||||
}
|
||||
const legacyId = toLegacyMcpToolId(tool.id)
|
||||
if (!legacyId) {
|
||||
continue
|
||||
}
|
||||
resolvedLegacyIdMap.set(legacyId, tool.id)
|
||||
}
|
||||
|
||||
if (resolvedLegacyIdMap.size === 0) {
|
||||
return allowedTools
|
||||
}
|
||||
|
||||
const normalized = allowedTools.map((toolId) => resolvedLegacyIdMap.get(toolId) ?? toolId)
|
||||
return Array.from(new Set(normalized))
|
||||
}
|
||||
|
||||
public async listSlashCommands(agentType: AgentType): Promise<SlashCommand[]> {
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/**
|
||||
* @deprecated Scheduled for removal in v2.0.0
|
||||
* --------------------------------------------------------------------------
|
||||
* ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
|
||||
* --------------------------------------------------------------------------
|
||||
* STOP: Feature PRs affecting this file are currently BLOCKED.
|
||||
* Only critical bug fixes are accepted during this migration phase.
|
||||
*
|
||||
* This file is being refactored to v2 standards.
|
||||
* Any non-critical changes will conflict with the ongoing work.
|
||||
*
|
||||
* 🔗 Context & Status:
|
||||
* - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
|
||||
* - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
import { type Client, createClient } from '@libsql/client'
|
||||
import { loggerService } from '@logger'
|
||||
import type { LibSQLDatabase } from 'drizzle-orm/libsql'
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/**
|
||||
* @deprecated Scheduled for removal in v2.0.0
|
||||
* --------------------------------------------------------------------------
|
||||
* ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
|
||||
* --------------------------------------------------------------------------
|
||||
* STOP: Feature PRs affecting this file are currently BLOCKED.
|
||||
* Only critical bug fixes are accepted during this migration phase.
|
||||
*
|
||||
* This file is being refactored to v2 standards.
|
||||
* Any non-critical changes will conflict with the ongoing work.
|
||||
*
|
||||
* 🔗 Context & Status:
|
||||
* - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
|
||||
* - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
/**
|
||||
* Drizzle Kit configuration for agents database
|
||||
*/
|
||||
|
||||
@ -89,7 +89,9 @@ export class AgentService extends BaseService {
|
||||
}
|
||||
|
||||
const agent = this.deserializeJsonFields(result[0]) as GetAgentResponse
|
||||
agent.tools = await this.listMcpTools(agent.type, agent.mcps)
|
||||
const { tools, legacyIdMap } = await this.listMcpTools(agent.type, agent.mcps)
|
||||
agent.tools = tools
|
||||
agent.allowed_tools = this.normalizeAllowedTools(agent.allowed_tools, agent.tools, legacyIdMap)
|
||||
|
||||
// Load installed_plugins from cache file instead of database
|
||||
const workdir = agent.accessible_paths?.[0]
|
||||
@ -134,7 +136,9 @@ export class AgentService extends BaseService {
|
||||
const agents = result.map((row) => this.deserializeJsonFields(row)) as GetAgentResponse[]
|
||||
|
||||
for (const agent of agents) {
|
||||
agent.tools = await this.listMcpTools(agent.type, agent.mcps)
|
||||
const { tools, legacyIdMap } = await this.listMcpTools(agent.type, agent.mcps)
|
||||
agent.tools = tools
|
||||
agent.allowed_tools = this.normalizeAllowedTools(agent.allowed_tools, agent.tools, legacyIdMap)
|
||||
}
|
||||
|
||||
return { agents, total: totalResult[0].count }
|
||||
|
||||
@ -156,7 +156,9 @@ export class SessionService extends BaseService {
|
||||
}
|
||||
|
||||
const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse
|
||||
session.tools = await this.listMcpTools(session.agent_type, session.mcps)
|
||||
const { tools, legacyIdMap } = await this.listMcpTools(session.agent_type, session.mcps)
|
||||
session.tools = tools
|
||||
session.allowed_tools = this.normalizeAllowedTools(session.allowed_tools, session.tools, legacyIdMap)
|
||||
|
||||
// If slash_commands is not in database yet (e.g., first invoke before init message),
|
||||
// fall back to builtin + local commands. Otherwise, use the merged commands from database.
|
||||
@ -202,6 +204,12 @@ export class SessionService extends BaseService {
|
||||
|
||||
const sessions = result.map((row) => this.deserializeJsonFields(row)) as GetAgentSessionResponse[]
|
||||
|
||||
for (const session of sessions) {
|
||||
const { tools, legacyIdMap } = await this.listMcpTools(session.agent_type, session.mcps)
|
||||
session.tools = tools
|
||||
session.allowed_tools = this.normalizeAllowedTools(session.allowed_tools, session.tools, legacyIdMap)
|
||||
}
|
||||
|
||||
return { sessions, total }
|
||||
}
|
||||
|
||||
|
||||
91
src/main/services/agents/tests/BaseService.test.ts
Normal file
91
src/main/services/agents/tests/BaseService.test.ts
Normal file
@ -0,0 +1,91 @@
|
||||
import type { Tool } from '@types'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('@main/apiServer/services/mcp', () => ({
|
||||
mcpApiService: {
|
||||
getServerInfo: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@main/apiServer/utils', () => ({
|
||||
validateModelId: vi.fn()
|
||||
}))
|
||||
|
||||
import { BaseService } from '../BaseService'
|
||||
|
||||
class TestBaseService extends BaseService {
|
||||
public normalize(
|
||||
allowedTools: string[] | undefined,
|
||||
tools: Tool[],
|
||||
legacyIdMap?: Map<string, string>
|
||||
): string[] | undefined {
|
||||
return this.normalizeAllowedTools(allowedTools, tools, legacyIdMap)
|
||||
}
|
||||
}
|
||||
|
||||
const buildMcpTool = (id: string): Tool => ({
|
||||
id,
|
||||
name: id,
|
||||
type: 'mcp',
|
||||
description: 'test tool',
|
||||
requirePermissions: true
|
||||
})
|
||||
|
||||
describe('BaseService.normalizeAllowedTools', () => {
|
||||
const service = new TestBaseService()
|
||||
|
||||
it('returns undefined or empty inputs unchanged', () => {
|
||||
expect(service.normalize(undefined, [])).toBeUndefined()
|
||||
expect(service.normalize([], [])).toEqual([])
|
||||
})
|
||||
|
||||
it('normalizes legacy MCP tool IDs and deduplicates entries', () => {
|
||||
const tools: Tool[] = [
|
||||
buildMcpTool('mcp__server_one__tool_one'),
|
||||
buildMcpTool('mcp__server_two__tool_two'),
|
||||
{ id: 'custom_tool', name: 'custom_tool', type: 'custom' }
|
||||
]
|
||||
|
||||
const legacyIdMap = new Map<string, string>([
|
||||
['mcp__server-1__tool-one', 'mcp__server_one__tool_one'],
|
||||
['mcp_server-1_tool-one', 'mcp__server_one__tool_one'],
|
||||
['mcp__server-2__tool-two', 'mcp__server_two__tool_two']
|
||||
])
|
||||
|
||||
const allowedTools = [
|
||||
'mcp__server-1__tool-one',
|
||||
'mcp_server-1_tool-one',
|
||||
'mcp_server_one_tool_one',
|
||||
'mcp__server_one__tool_one',
|
||||
'custom_tool',
|
||||
'mcp__server_two__tool_two',
|
||||
'mcp_server_two_tool_two',
|
||||
'mcp__server-2__tool-two'
|
||||
]
|
||||
|
||||
expect(service.normalize(allowedTools, tools, legacyIdMap)).toEqual([
|
||||
'mcp__server_one__tool_one',
|
||||
'custom_tool',
|
||||
'mcp__server_two__tool_two'
|
||||
])
|
||||
})
|
||||
|
||||
it('keeps legacy IDs when no matching MCP tool exists', () => {
|
||||
const tools: Tool[] = [buildMcpTool('mcp__server_one__tool_one')]
|
||||
const legacyIdMap = new Map<string, string>([['mcp__server-1__tool-one', 'mcp__server_one__tool_one']])
|
||||
|
||||
const allowedTools = ['mcp__unknown__tool', 'mcp__server_one__tool_one']
|
||||
|
||||
expect(service.normalize(allowedTools, tools, legacyIdMap)).toEqual([
|
||||
'mcp__unknown__tool',
|
||||
'mcp__server_one__tool_one'
|
||||
])
|
||||
})
|
||||
|
||||
it('returns allowed tools unchanged when no MCP tools are available', () => {
|
||||
const allowedTools = ['custom_tool', 'builtin_tool']
|
||||
const tools: Tool[] = [{ id: 'custom_tool', name: 'custom_tool', type: 'custom' }]
|
||||
|
||||
expect(service.normalize(allowedTools, tools)).toEqual(allowedTools)
|
||||
})
|
||||
})
|
||||
525
src/main/services/lanTransfer/LanTransferClientService.ts
Normal file
525
src/main/services/lanTransfer/LanTransferClientService.ts
Normal file
@ -0,0 +1,525 @@
|
||||
import * as crypto from 'node:crypto'
|
||||
import { createConnection, type Socket } from 'node:net'
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import type {
|
||||
LanClientEvent,
|
||||
LanFileCompleteMessage,
|
||||
LanHandshakeAckMessage,
|
||||
LocalTransferConnectPayload,
|
||||
LocalTransferPeer
|
||||
} from '@shared/config/types'
|
||||
import { LAN_TRANSFER_GLOBAL_TIMEOUT_MS } from '@shared/config/types'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
|
||||
import { localTransferService } from '../LocalTransferService'
|
||||
import { windowService } from '../WindowService'
|
||||
import {
|
||||
abortTransfer,
|
||||
buildHandshakeMessage,
|
||||
calculateFileChecksum,
|
||||
cleanupTransfer,
|
||||
createDataHandler,
|
||||
createTransferState,
|
||||
formatFileSize,
|
||||
HANDSHAKE_PROTOCOL_VERSION,
|
||||
pickHost,
|
||||
sendFileEnd,
|
||||
sendFileStart,
|
||||
sendTestPing,
|
||||
streamFileChunks,
|
||||
validateFile,
|
||||
waitForFileComplete,
|
||||
waitForFileStartAck
|
||||
} from './handlers'
|
||||
import { ResponseManager } from './responseManager'
|
||||
import type { ActiveFileTransfer, ConnectionContext, FileTransferContext } from './types'
|
||||
|
||||
const DEFAULT_HANDSHAKE_TIMEOUT_MS = 10_000
|
||||
|
||||
const logger = loggerService.withContext('LanTransferClientService')
|
||||
|
||||
/**
|
||||
* LAN Transfer Client Service
|
||||
*
|
||||
* Handles outgoing file transfers to LAN peers via TCP.
|
||||
* Protocol v1 with streaming mode (no per-chunk acknowledgment).
|
||||
*/
|
||||
class LanTransferClientService {
|
||||
private socket: Socket | null = null
|
||||
private currentPeer?: LocalTransferPeer
|
||||
private dataHandler?: ReturnType<typeof createDataHandler>
|
||||
private responseManager = new ResponseManager()
|
||||
private isConnecting = false
|
||||
private activeTransfer?: ActiveFileTransfer
|
||||
private lastConnectOptions?: LocalTransferConnectPayload
|
||||
private consecutiveJsonErrors = 0
|
||||
private static readonly MAX_CONSECUTIVE_JSON_ERRORS = 3
|
||||
private reconnectPromise: Promise<void> | null = null
|
||||
|
||||
constructor() {
|
||||
this.responseManager.setTimeoutCallback(() => void this.disconnect())
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to a LAN peer and perform handshake.
|
||||
*/
|
||||
public async connectAndHandshake(options: LocalTransferConnectPayload): Promise<LanHandshakeAckMessage> {
|
||||
if (this.isConnecting) {
|
||||
throw new Error('LAN transfer client is busy')
|
||||
}
|
||||
|
||||
const peer = localTransferService.getPeerById(options.peerId)
|
||||
if (!peer) {
|
||||
throw new Error('Selected LAN peer is no longer available')
|
||||
}
|
||||
if (!peer.port) {
|
||||
throw new Error('Selected peer does not expose a TCP port')
|
||||
}
|
||||
|
||||
const host = pickHost(peer)
|
||||
if (!host) {
|
||||
throw new Error('Unable to resolve a reachable host for the peer')
|
||||
}
|
||||
|
||||
await this.disconnect()
|
||||
this.isConnecting = true
|
||||
|
||||
return new Promise<LanHandshakeAckMessage>((resolve, reject) => {
|
||||
const socket = createConnection({ host, port: peer.port as number }, () => {
|
||||
logger.info(`Connected to LAN peer ${peer.name} (${host}:${peer.port})`)
|
||||
socket.setKeepAlive(true, 30_000)
|
||||
this.socket = socket
|
||||
this.currentPeer = peer
|
||||
this.attachSocketListeners(socket)
|
||||
|
||||
this.responseManager.waitForResponse(
|
||||
'handshake_ack',
|
||||
options.timeoutMs ?? DEFAULT_HANDSHAKE_TIMEOUT_MS,
|
||||
(payload) => {
|
||||
const ack = payload as LanHandshakeAckMessage
|
||||
if (!ack.accepted) {
|
||||
const message = ack.message || 'Handshake rejected by remote device'
|
||||
logger.warn(`Handshake rejected by ${peer.name}: ${message}`)
|
||||
this.broadcastClientEvent({
|
||||
type: 'error',
|
||||
message,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
reject(new Error(message))
|
||||
void this.disconnect()
|
||||
return
|
||||
}
|
||||
logger.info(`Handshake accepted by ${peer.name}`)
|
||||
socket.setTimeout(0)
|
||||
this.isConnecting = false
|
||||
this.lastConnectOptions = options
|
||||
sendTestPing(this.createConnectionContext())
|
||||
resolve(ack)
|
||||
},
|
||||
(error) => {
|
||||
this.isConnecting = false
|
||||
reject(error)
|
||||
}
|
||||
)
|
||||
|
||||
const handshakeMessage = buildHandshakeMessage()
|
||||
this.sendControlMessage(handshakeMessage)
|
||||
})
|
||||
|
||||
socket.setTimeout(options.timeoutMs ?? DEFAULT_HANDSHAKE_TIMEOUT_MS, () => {
|
||||
const error = new Error('Handshake timed out')
|
||||
logger.error('LAN transfer socket timeout', error)
|
||||
this.broadcastClientEvent({
|
||||
type: 'error',
|
||||
message: error.message,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
reject(error)
|
||||
socket.destroy(error)
|
||||
void this.disconnect()
|
||||
})
|
||||
|
||||
socket.once('error', (error) => {
|
||||
logger.error('LAN transfer socket error', error as Error)
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
this.broadcastClientEvent({
|
||||
type: 'error',
|
||||
message,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
this.isConnecting = false
|
||||
reject(error instanceof Error ? error : new Error(message))
|
||||
void this.disconnect()
|
||||
})
|
||||
|
||||
socket.once('close', () => {
|
||||
logger.info('LAN transfer socket closed')
|
||||
if (this.socket === socket) {
|
||||
this.socket = null
|
||||
this.dataHandler?.resetBuffer()
|
||||
this.responseManager.rejectAll(new Error('LAN transfer socket closed'))
|
||||
this.currentPeer = undefined
|
||||
abortTransfer(this.activeTransfer, new Error('LAN transfer socket closed'))
|
||||
}
|
||||
this.isConnecting = false
|
||||
this.broadcastClientEvent({
|
||||
type: 'socket_closed',
|
||||
reason: 'connection_closed',
|
||||
timestamp: Date.now()
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Disconnect from the current peer.
|
||||
*/
|
||||
public async disconnect(): Promise<void> {
|
||||
const socket = this.socket
|
||||
if (!socket) {
|
||||
return
|
||||
}
|
||||
|
||||
this.socket = null
|
||||
this.dataHandler?.resetBuffer()
|
||||
this.currentPeer = undefined
|
||||
this.responseManager.rejectAll(new Error('LAN transfer socket disconnected'))
|
||||
abortTransfer(this.activeTransfer, new Error('LAN transfer socket disconnected'))
|
||||
|
||||
const DISCONNECT_TIMEOUT_MS = 3000
|
||||
await new Promise<void>((resolve) => {
|
||||
const timeout = setTimeout(() => {
|
||||
logger.warn('Disconnect timeout, forcing cleanup')
|
||||
socket.removeAllListeners()
|
||||
resolve()
|
||||
}, DISCONNECT_TIMEOUT_MS)
|
||||
|
||||
socket.once('close', () => {
|
||||
clearTimeout(timeout)
|
||||
resolve()
|
||||
})
|
||||
|
||||
socket.destroy()
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispose the service and clean up all resources.
|
||||
*/
|
||||
public dispose(): void {
|
||||
this.responseManager.rejectAll(new Error('LAN transfer client disposed'))
|
||||
cleanupTransfer(this.activeTransfer)
|
||||
this.activeTransfer = undefined
|
||||
if (this.socket) {
|
||||
this.socket.destroy()
|
||||
this.socket = null
|
||||
}
|
||||
this.dataHandler?.resetBuffer()
|
||||
this.isConnecting = false
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a ZIP file to the connected peer.
|
||||
*/
|
||||
public async sendFile(filePath: string): Promise<LanFileCompleteMessage> {
|
||||
await this.ensureConnection()
|
||||
|
||||
if (this.activeTransfer) {
|
||||
throw new Error('A file transfer is already in progress')
|
||||
}
|
||||
|
||||
// Validate file
|
||||
const { stats, fileName } = await validateFile(filePath)
|
||||
|
||||
// Calculate checksum
|
||||
logger.info('Calculating file checksum...')
|
||||
const checksum = await calculateFileChecksum(filePath)
|
||||
logger.info(`File checksum: ${checksum.substring(0, 16)}...`)
|
||||
|
||||
// Connection can drop while validating/checking file; ensure it is still ready before starting transfer.
|
||||
await this.ensureConnection()
|
||||
|
||||
// Initialize transfer state
|
||||
const transferId = crypto.randomUUID()
|
||||
this.activeTransfer = createTransferState(transferId, fileName, stats.size, checksum)
|
||||
|
||||
logger.info(
|
||||
`Starting file transfer: ${fileName} (${formatFileSize(stats.size)}, ${this.activeTransfer.totalChunks} chunks)`
|
||||
)
|
||||
|
||||
// Global timeout
|
||||
const globalTimeoutError = new Error('Transfer timed out (global timeout exceeded)')
|
||||
const globalTimeoutHandle = setTimeout(() => {
|
||||
logger.warn('Global transfer timeout exceeded, aborting transfer', { transferId, fileName })
|
||||
abortTransfer(this.activeTransfer, globalTimeoutError)
|
||||
}, LAN_TRANSFER_GLOBAL_TIMEOUT_MS)
|
||||
|
||||
try {
|
||||
const result = await this.performFileTransfer(filePath, transferId, fileName)
|
||||
return result
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
logger.error(`File transfer failed: ${message}`)
|
||||
|
||||
this.broadcastClientEvent({
|
||||
type: 'file_transfer_complete',
|
||||
transferId,
|
||||
fileName,
|
||||
success: false,
|
||||
error: message,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
|
||||
throw error
|
||||
} finally {
|
||||
clearTimeout(globalTimeoutHandle)
|
||||
cleanupTransfer(this.activeTransfer)
|
||||
this.activeTransfer = undefined
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancel the current file transfer.
|
||||
*/
|
||||
public cancelTransfer(): void {
|
||||
if (!this.activeTransfer) {
|
||||
logger.warn('No active transfer to cancel')
|
||||
return
|
||||
}
|
||||
|
||||
const { transferId, fileName } = this.activeTransfer
|
||||
logger.info(`Cancelling file transfer: ${fileName}`)
|
||||
|
||||
this.activeTransfer.isCancelled = true
|
||||
|
||||
try {
|
||||
this.sendControlMessage({
|
||||
type: 'file_cancel',
|
||||
transferId,
|
||||
reason: 'Cancelled by user'
|
||||
})
|
||||
} catch (error) {
|
||||
// Expected when connection is already broken
|
||||
logger.warn('Failed to send cancel message', error as Error)
|
||||
}
|
||||
|
||||
abortTransfer(this.activeTransfer, new Error('Transfer cancelled by user'))
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Private Methods
|
||||
// =============================================================================
|
||||
|
||||
private async ensureConnection(): Promise<void> {
|
||||
// Check socket is valid and writable (not just undestroyed)
|
||||
if (this.socket && !this.socket.destroyed && this.socket.writable && this.currentPeer) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!this.lastConnectOptions) {
|
||||
throw new Error('No active connection. Please connect to a peer first.')
|
||||
}
|
||||
|
||||
// Prevent concurrent reconnection attempts
|
||||
if (this.reconnectPromise) {
|
||||
logger.debug('Waiting for existing reconnection attempt...')
|
||||
await this.reconnectPromise
|
||||
return
|
||||
}
|
||||
|
||||
logger.info('Connection lost, attempting to reconnect...')
|
||||
this.reconnectPromise = this.connectAndHandshake(this.lastConnectOptions)
|
||||
.then(() => {
|
||||
// Handshake succeeded, connection restored
|
||||
})
|
||||
.finally(() => {
|
||||
this.reconnectPromise = null
|
||||
})
|
||||
|
||||
await this.reconnectPromise
|
||||
}
|
||||
|
||||
private async performFileTransfer(
|
||||
filePath: string,
|
||||
transferId: string,
|
||||
fileName: string
|
||||
): Promise<LanFileCompleteMessage> {
|
||||
const transfer = this.activeTransfer!
|
||||
const ctx = this.createFileTransferContext()
|
||||
|
||||
// Step 1: Send file_start
|
||||
sendFileStart(ctx, transfer)
|
||||
|
||||
// Step 2: Wait for file_start_ack
|
||||
const startAck = await waitForFileStartAck(ctx, transferId, transfer.abortController.signal)
|
||||
if (!startAck.accepted) {
|
||||
throw new Error(startAck.message || 'Transfer rejected by receiver')
|
||||
}
|
||||
logger.info('Received file_start_ack: accepted')
|
||||
|
||||
// Step 3: Stream file chunks
|
||||
await streamFileChunks(this.socket!, filePath, transfer, transfer.abortController.signal, (bytesSent, chunkIndex) =>
|
||||
this.onTransferProgress(transfer, bytesSent, chunkIndex)
|
||||
)
|
||||
|
||||
// Step 4: Send file_end
|
||||
sendFileEnd(ctx, transferId)
|
||||
|
||||
// Step 5: Wait for file_complete
|
||||
const result = await waitForFileComplete(ctx, transferId, transfer.abortController.signal)
|
||||
logger.info(`File transfer ${result.success ? 'completed' : 'failed'}`)
|
||||
|
||||
// Broadcast completion
|
||||
this.broadcastClientEvent({
|
||||
type: 'file_transfer_complete',
|
||||
transferId,
|
||||
fileName,
|
||||
success: result.success,
|
||||
filePath: result.filePath,
|
||||
error: result.error,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
private onTransferProgress(transfer: ActiveFileTransfer, bytesSent: number, chunkIndex: number): void {
|
||||
const progress = (bytesSent / transfer.fileSize) * 100
|
||||
const elapsed = (Date.now() - transfer.startedAt) / 1000
|
||||
const speed = elapsed > 0 ? bytesSent / elapsed : 0
|
||||
|
||||
this.broadcastClientEvent({
|
||||
type: 'file_transfer_progress',
|
||||
transferId: transfer.transferId,
|
||||
fileName: transfer.fileName,
|
||||
bytesSent,
|
||||
totalBytes: transfer.fileSize,
|
||||
chunkIndex,
|
||||
totalChunks: transfer.totalChunks,
|
||||
progress: Math.round(progress * 100) / 100,
|
||||
speed,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
}
|
||||
|
||||
private attachSocketListeners(socket: Socket): void {
|
||||
this.dataHandler = createDataHandler((line) => this.handleControlLine(line))
|
||||
socket.on('data', (chunk: Buffer) => {
|
||||
try {
|
||||
this.dataHandler?.handleData(chunk)
|
||||
} catch (error) {
|
||||
logger.error('Data handler error', error as Error)
|
||||
void this.disconnect()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private handleControlLine(line: string): void {
|
||||
let payload: Record<string, unknown>
|
||||
try {
|
||||
payload = JSON.parse(line)
|
||||
this.consecutiveJsonErrors = 0 // Reset on successful parse
|
||||
} catch {
|
||||
this.consecutiveJsonErrors++
|
||||
logger.warn('Received invalid JSON control message', { line, consecutiveErrors: this.consecutiveJsonErrors })
|
||||
|
||||
if (this.consecutiveJsonErrors >= LanTransferClientService.MAX_CONSECUTIVE_JSON_ERRORS) {
|
||||
const message = `Protocol error: ${this.consecutiveJsonErrors} consecutive invalid messages, disconnecting`
|
||||
logger.error(message)
|
||||
this.broadcastClientEvent({
|
||||
type: 'error',
|
||||
message,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
void this.disconnect()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const type = payload?.type as string | undefined
|
||||
if (!type) {
|
||||
logger.warn('Received control message without type', payload)
|
||||
return
|
||||
}
|
||||
|
||||
// Try to resolve a pending response
|
||||
const transferId = payload?.transferId as string | undefined
|
||||
const chunkIndex = payload?.chunkIndex as number | undefined
|
||||
if (this.responseManager.tryResolve(type, payload, transferId, chunkIndex)) {
|
||||
return
|
||||
}
|
||||
|
||||
logger.info('Received control message', payload)
|
||||
|
||||
if (type === 'pong') {
|
||||
this.broadcastClientEvent({
|
||||
type: 'pong',
|
||||
payload: payload?.payload as string | undefined,
|
||||
received: payload?.received as boolean | undefined,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Ignore late-arriving file transfer messages
|
||||
const fileTransferMessageTypes = ['file_start_ack', 'file_complete']
|
||||
if (fileTransferMessageTypes.includes(type)) {
|
||||
logger.debug('Ignoring late file transfer message', { type, payload })
|
||||
return
|
||||
}
|
||||
|
||||
this.broadcastClientEvent({
|
||||
type: 'error',
|
||||
message: `Unexpected control message type: ${type}`,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
}
|
||||
|
||||
private sendControlMessage(message: Record<string, unknown>): void {
|
||||
if (!this.socket || this.socket.destroyed || !this.socket.writable) {
|
||||
throw new Error('Socket is not connected')
|
||||
}
|
||||
const payload = JSON.stringify(message)
|
||||
this.socket.write(`${payload}\n`)
|
||||
}
|
||||
|
||||
private createConnectionContext(): ConnectionContext {
|
||||
return {
|
||||
socket: this.socket,
|
||||
currentPeer: this.currentPeer,
|
||||
sendControlMessage: (msg) => this.sendControlMessage(msg),
|
||||
broadcastClientEvent: (event) => this.broadcastClientEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
private createFileTransferContext(): FileTransferContext {
|
||||
return {
|
||||
...this.createConnectionContext(),
|
||||
activeTransfer: this.activeTransfer,
|
||||
setActiveTransfer: (transfer) => {
|
||||
this.activeTransfer = transfer
|
||||
},
|
||||
waitForResponse: (type, timeoutMs, resolve, reject, transferId, chunkIndex, abortSignal) => {
|
||||
this.responseManager.waitForResponse(type, timeoutMs, resolve, reject, transferId, chunkIndex, abortSignal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private broadcastClientEvent(event: LanClientEvent): void {
|
||||
const mainWindow = windowService.getMainWindow()
|
||||
if (!mainWindow || mainWindow.isDestroyed()) {
|
||||
return
|
||||
}
|
||||
mainWindow.webContents.send(IpcChannel.LocalTransfer_ClientEvent, {
|
||||
...event,
|
||||
peerId: event.peerId ?? this.currentPeer?.id,
|
||||
peerName: event.peerName ?? this.currentPeer?.name
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export const lanTransferClientService = new LanTransferClientService()
|
||||
|
||||
// Re-export for backward compatibility
|
||||
export { HANDSHAKE_PROTOCOL_VERSION }
|
||||
@ -0,0 +1,133 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
// Mock dependencies before importing the service
|
||||
vi.mock('node:net', async (importOriginal) => {
|
||||
const actual = (await importOriginal()) as Record<string, unknown>
|
||||
return {
|
||||
...actual,
|
||||
createConnection: vi.fn()
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('electron', () => ({
|
||||
app: {
|
||||
getName: vi.fn(() => 'Cherry Studio'),
|
||||
getVersion: vi.fn(() => '1.0.0')
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../LocalTransferService', () => ({
|
||||
localTransferService: {
|
||||
getPeerById: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../WindowService', () => ({
|
||||
windowService: {
|
||||
getMainWindow: vi.fn(() => ({
|
||||
isDestroyed: () => false,
|
||||
webContents: {
|
||||
send: vi.fn()
|
||||
}
|
||||
}))
|
||||
}
|
||||
}))
|
||||
|
||||
// Import after mocks
|
||||
import { localTransferService } from '../../LocalTransferService'
|
||||
|
||||
describe('LanTransferClientService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.resetModules()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
describe('connectAndHandshake - validation', () => {
|
||||
it('should throw error when peer is not found', async () => {
|
||||
vi.mocked(localTransferService.getPeerById).mockReturnValue(undefined)
|
||||
|
||||
const { lanTransferClientService } = await import('../LanTransferClientService')
|
||||
|
||||
await expect(
|
||||
lanTransferClientService.connectAndHandshake({
|
||||
peerId: 'non-existent'
|
||||
})
|
||||
).rejects.toThrow('Selected LAN peer is no longer available')
|
||||
})
|
||||
|
||||
it('should throw error when peer has no port', async () => {
|
||||
vi.mocked(localTransferService.getPeerById).mockReturnValue({
|
||||
id: 'test-peer',
|
||||
name: 'Test Peer',
|
||||
addresses: ['192.168.1.100'],
|
||||
updatedAt: Date.now()
|
||||
})
|
||||
|
||||
const { lanTransferClientService } = await import('../LanTransferClientService')
|
||||
|
||||
await expect(
|
||||
lanTransferClientService.connectAndHandshake({
|
||||
peerId: 'test-peer'
|
||||
})
|
||||
).rejects.toThrow('Selected peer does not expose a TCP port')
|
||||
})
|
||||
|
||||
it('should throw error when no reachable host', async () => {
|
||||
vi.mocked(localTransferService.getPeerById).mockReturnValue({
|
||||
id: 'test-peer',
|
||||
name: 'Test Peer',
|
||||
port: 12345,
|
||||
addresses: [],
|
||||
updatedAt: Date.now()
|
||||
})
|
||||
|
||||
const { lanTransferClientService } = await import('../LanTransferClientService')
|
||||
|
||||
await expect(
|
||||
lanTransferClientService.connectAndHandshake({
|
||||
peerId: 'test-peer'
|
||||
})
|
||||
).rejects.toThrow('Unable to resolve a reachable host for the peer')
|
||||
})
|
||||
})
|
||||
|
||||
describe('cancelTransfer', () => {
|
||||
it('should not throw when no active transfer', async () => {
|
||||
const { lanTransferClientService } = await import('../LanTransferClientService')
|
||||
|
||||
// Should not throw, just log warning
|
||||
expect(() => lanTransferClientService.cancelTransfer()).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('dispose', () => {
|
||||
it('should clean up resources without throwing', async () => {
|
||||
const { lanTransferClientService } = await import('../LanTransferClientService')
|
||||
|
||||
// Should not throw
|
||||
expect(() => lanTransferClientService.dispose()).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('sendFile', () => {
|
||||
it('should throw error when not connected', async () => {
|
||||
const { lanTransferClientService } = await import('../LanTransferClientService')
|
||||
|
||||
await expect(lanTransferClientService.sendFile('/path/to/file.zip')).rejects.toThrow(
|
||||
'No active connection. Please connect to a peer first.'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('HANDSHAKE_PROTOCOL_VERSION', () => {
|
||||
it('should export protocol version', async () => {
|
||||
const { HANDSHAKE_PROTOCOL_VERSION } = await import('../LanTransferClientService')
|
||||
|
||||
expect(HANDSHAKE_PROTOCOL_VERSION).toBe('1')
|
||||
})
|
||||
})
|
||||
})
|
||||
103
src/main/services/lanTransfer/__tests__/binaryProtocol.test.ts
Normal file
103
src/main/services/lanTransfer/__tests__/binaryProtocol.test.ts
Normal file
@ -0,0 +1,103 @@
|
||||
import { EventEmitter } from 'node:events'
|
||||
import type { Socket } from 'node:net'
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { BINARY_TYPE_FILE_CHUNK, sendBinaryChunk } from '../binaryProtocol'
|
||||
|
||||
describe('binaryProtocol', () => {
|
||||
describe('sendBinaryChunk', () => {
|
||||
let mockSocket: Socket
|
||||
let writtenBuffers: Buffer[]
|
||||
|
||||
beforeEach(() => {
|
||||
writtenBuffers = []
|
||||
mockSocket = Object.assign(new EventEmitter(), {
|
||||
destroyed: false,
|
||||
writable: true,
|
||||
write: vi.fn((buffer: Buffer) => {
|
||||
writtenBuffers.push(Buffer.from(buffer))
|
||||
return true
|
||||
}),
|
||||
cork: vi.fn(),
|
||||
uncork: vi.fn()
|
||||
}) as unknown as Socket
|
||||
})
|
||||
|
||||
it('should send binary chunk with correct frame format', () => {
|
||||
const transferId = 'test-uuid-1234'
|
||||
const chunkIndex = 5
|
||||
const data = Buffer.from('test data chunk')
|
||||
|
||||
const result = sendBinaryChunk(mockSocket, transferId, chunkIndex, data)
|
||||
|
||||
expect(result).toBe(true)
|
||||
expect(mockSocket.cork).toHaveBeenCalled()
|
||||
expect(mockSocket.uncork).toHaveBeenCalled()
|
||||
expect(mockSocket.write).toHaveBeenCalledTimes(2)
|
||||
|
||||
// Verify header structure
|
||||
const header = writtenBuffers[0]
|
||||
|
||||
// Magic bytes "CS"
|
||||
expect(header[0]).toBe(0x43)
|
||||
expect(header[1]).toBe(0x53)
|
||||
|
||||
// Type byte
|
||||
const typeOffset = 2 + 4 // magic + totalLen
|
||||
expect(header[typeOffset]).toBe(BINARY_TYPE_FILE_CHUNK)
|
||||
|
||||
// TransferId length
|
||||
const tidLenOffset = typeOffset + 1
|
||||
const tidLen = header.readUInt16BE(tidLenOffset)
|
||||
expect(tidLen).toBe(Buffer.from(transferId).length)
|
||||
|
||||
// ChunkIndex
|
||||
const chunkIdxOffset = tidLenOffset + 2 + tidLen
|
||||
expect(header.readUInt32BE(chunkIdxOffset)).toBe(chunkIndex)
|
||||
|
||||
// Data buffer
|
||||
expect(writtenBuffers[1].toString()).toBe('test data chunk')
|
||||
})
|
||||
|
||||
it('should return false when socket write returns false (backpressure)', () => {
|
||||
;(mockSocket.write as ReturnType<typeof vi.fn>).mockReturnValueOnce(false)
|
||||
|
||||
const result = sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data'))
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should correctly calculate totalLen in frame header', () => {
|
||||
const transferId = 'uuid-1234'
|
||||
const data = Buffer.from('chunk data here')
|
||||
|
||||
sendBinaryChunk(mockSocket, transferId, 0, data)
|
||||
|
||||
const header = writtenBuffers[0]
|
||||
const totalLen = header.readUInt32BE(2) // After magic bytes
|
||||
|
||||
// totalLen = type(1) + tidLen(2) + tid(n) + idx(4) + data(m)
|
||||
const expectedTotalLen = 1 + 2 + Buffer.from(transferId).length + 4 + data.length
|
||||
expect(totalLen).toBe(expectedTotalLen)
|
||||
})
|
||||
|
||||
it('should throw error when socket is not writable', () => {
|
||||
;(mockSocket as any).writable = false
|
||||
|
||||
expect(() => sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data'))).toThrow('Socket is not writable')
|
||||
})
|
||||
|
||||
it('should throw error when socket is destroyed', () => {
|
||||
;(mockSocket as any).destroyed = true
|
||||
|
||||
expect(() => sendBinaryChunk(mockSocket, 'test-id', 0, Buffer.from('data'))).toThrow('Socket is not writable')
|
||||
})
|
||||
})
|
||||
|
||||
describe('BINARY_TYPE_FILE_CHUNK', () => {
|
||||
it('should be 0x01', () => {
|
||||
expect(BINARY_TYPE_FILE_CHUNK).toBe(0x01)
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,265 @@
|
||||
import { EventEmitter } from 'node:events'
|
||||
import type { Socket } from 'node:net'
|
||||
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
buildHandshakeMessage,
|
||||
createDataHandler,
|
||||
getAbortError,
|
||||
HANDSHAKE_PROTOCOL_VERSION,
|
||||
pickHost,
|
||||
waitForSocketDrain
|
||||
} from '../../handlers/connection'
|
||||
|
||||
// Mock electron app
|
||||
vi.mock('electron', () => ({
|
||||
app: {
|
||||
getName: vi.fn(() => 'Cherry Studio'),
|
||||
getVersion: vi.fn(() => '1.0.0')
|
||||
}
|
||||
}))
|
||||
|
||||
describe('connection handlers', () => {
|
||||
describe('buildHandshakeMessage', () => {
|
||||
it('should build handshake message with correct structure', () => {
|
||||
const message = buildHandshakeMessage()
|
||||
|
||||
expect(message.type).toBe('handshake')
|
||||
expect(message.deviceName).toBe('Cherry Studio')
|
||||
expect(message.version).toBe(HANDSHAKE_PROTOCOL_VERSION)
|
||||
expect(message.appVersion).toBe('1.0.0')
|
||||
expect(typeof message.platform).toBe('string')
|
||||
})
|
||||
|
||||
it('should use protocol version 1', () => {
|
||||
expect(HANDSHAKE_PROTOCOL_VERSION).toBe('1')
|
||||
})
|
||||
})
|
||||
|
||||
describe('pickHost', () => {
|
||||
it('should prefer IPv4 addresses', () => {
|
||||
const peer = {
|
||||
id: '1',
|
||||
name: 'Test',
|
||||
addresses: ['fe80::1', '192.168.1.100', '::1'],
|
||||
updatedAt: Date.now()
|
||||
}
|
||||
|
||||
expect(pickHost(peer)).toBe('192.168.1.100')
|
||||
})
|
||||
|
||||
it('should fall back to first address if no IPv4', () => {
|
||||
const peer = {
|
||||
id: '1',
|
||||
name: 'Test',
|
||||
addresses: ['fe80::1', '::1'],
|
||||
updatedAt: Date.now()
|
||||
}
|
||||
|
||||
expect(pickHost(peer)).toBe('fe80::1')
|
||||
})
|
||||
|
||||
it('should fall back to host property if no addresses', () => {
|
||||
const peer = {
|
||||
id: '1',
|
||||
name: 'Test',
|
||||
host: 'example.local',
|
||||
addresses: [],
|
||||
updatedAt: Date.now()
|
||||
}
|
||||
|
||||
expect(pickHost(peer)).toBe('example.local')
|
||||
})
|
||||
|
||||
it('should return undefined if no addresses or host', () => {
|
||||
const peer = {
|
||||
id: '1',
|
||||
name: 'Test',
|
||||
addresses: [],
|
||||
updatedAt: Date.now()
|
||||
}
|
||||
|
||||
expect(pickHost(peer)).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('createDataHandler', () => {
|
||||
it('should parse complete lines from buffer', () => {
|
||||
const lines: string[] = []
|
||||
const handler = createDataHandler((line) => lines.push(line))
|
||||
|
||||
handler.handleData(Buffer.from('{"type":"test"}\n'))
|
||||
|
||||
expect(lines).toEqual(['{"type":"test"}'])
|
||||
})
|
||||
|
||||
it('should handle partial lines across multiple chunks', () => {
|
||||
const lines: string[] = []
|
||||
const handler = createDataHandler((line) => lines.push(line))
|
||||
|
||||
handler.handleData(Buffer.from('{"type":'))
|
||||
handler.handleData(Buffer.from('"test"}\n'))
|
||||
|
||||
expect(lines).toEqual(['{"type":"test"}'])
|
||||
})
|
||||
|
||||
it('should handle multiple lines in single chunk', () => {
|
||||
const lines: string[] = []
|
||||
const handler = createDataHandler((line) => lines.push(line))
|
||||
|
||||
handler.handleData(Buffer.from('{"a":1}\n{"b":2}\n'))
|
||||
|
||||
expect(lines).toEqual(['{"a":1}', '{"b":2}'])
|
||||
})
|
||||
|
||||
it('should reset buffer', () => {
|
||||
const lines: string[] = []
|
||||
const handler = createDataHandler((line) => lines.push(line))
|
||||
|
||||
handler.handleData(Buffer.from('partial'))
|
||||
handler.resetBuffer()
|
||||
handler.handleData(Buffer.from('{"complete":true}\n'))
|
||||
|
||||
expect(lines).toEqual(['{"complete":true}'])
|
||||
})
|
||||
|
||||
it('should trim whitespace from lines', () => {
|
||||
const lines: string[] = []
|
||||
const handler = createDataHandler((line) => lines.push(line))
|
||||
|
||||
handler.handleData(Buffer.from(' {"type":"test"} \n'))
|
||||
|
||||
expect(lines).toEqual(['{"type":"test"}'])
|
||||
})
|
||||
|
||||
it('should skip empty lines', () => {
|
||||
const lines: string[] = []
|
||||
const handler = createDataHandler((line) => lines.push(line))
|
||||
|
||||
handler.handleData(Buffer.from('\n\n{"type":"test"}\n\n'))
|
||||
|
||||
expect(lines).toEqual(['{"type":"test"}'])
|
||||
})
|
||||
|
||||
it('should throw error when buffer exceeds MAX_LINE_BUFFER_SIZE', () => {
|
||||
const handler = createDataHandler(vi.fn())
|
||||
|
||||
// Create a buffer larger than 1MB (MAX_LINE_BUFFER_SIZE)
|
||||
const largeData = 'x'.repeat(1024 * 1024 + 1)
|
||||
|
||||
expect(() => handler.handleData(Buffer.from(largeData))).toThrow('Control message too large')
|
||||
})
|
||||
|
||||
it('should reset buffer after exceeding MAX_LINE_BUFFER_SIZE', () => {
|
||||
const lines: string[] = []
|
||||
const handler = createDataHandler((line) => lines.push(line))
|
||||
|
||||
// Create a buffer larger than 1MB
|
||||
const largeData = 'x'.repeat(1024 * 1024 + 1)
|
||||
|
||||
try {
|
||||
handler.handleData(Buffer.from(largeData))
|
||||
} catch {
|
||||
// Expected error
|
||||
}
|
||||
|
||||
// Buffer should be reset, so lineBuffer should be empty
|
||||
expect(handler.lineBuffer).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('waitForSocketDrain', () => {
|
||||
let mockSocket: Socket & EventEmitter
|
||||
|
||||
beforeEach(() => {
|
||||
mockSocket = Object.assign(new EventEmitter(), {
|
||||
destroyed: false,
|
||||
writable: true,
|
||||
write: vi.fn(),
|
||||
off: vi.fn(),
|
||||
removeAllListeners: vi.fn()
|
||||
}) as unknown as Socket & EventEmitter
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
it('should throw error when abort signal is already aborted', async () => {
|
||||
const abortController = new AbortController()
|
||||
abortController.abort(new Error('Already aborted'))
|
||||
|
||||
await expect(waitForSocketDrain(mockSocket, abortController.signal)).rejects.toThrow('Already aborted')
|
||||
})
|
||||
|
||||
it('should throw error when socket is destroyed', async () => {
|
||||
;(mockSocket as any).destroyed = true
|
||||
const abortController = new AbortController()
|
||||
|
||||
await expect(waitForSocketDrain(mockSocket, abortController.signal)).rejects.toThrow('Socket is closed')
|
||||
})
|
||||
|
||||
it('should resolve when drain event is emitted', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)
|
||||
|
||||
// Emit drain event after a short delay
|
||||
setImmediate(() => mockSocket.emit('drain'))
|
||||
|
||||
await expect(drainPromise).resolves.toBeUndefined()
|
||||
})
|
||||
|
||||
it('should reject when close event is emitted', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)
|
||||
|
||||
setImmediate(() => mockSocket.emit('close'))
|
||||
|
||||
await expect(drainPromise).rejects.toThrow('Socket closed while waiting for drain')
|
||||
})
|
||||
|
||||
it('should reject when error event is emitted', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)
|
||||
|
||||
setImmediate(() => mockSocket.emit('error', new Error('Network error')))
|
||||
|
||||
await expect(drainPromise).rejects.toThrow('Network error')
|
||||
})
|
||||
|
||||
it('should reject when abort signal is triggered', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const drainPromise = waitForSocketDrain(mockSocket, abortController.signal)
|
||||
|
||||
setImmediate(() => abortController.abort(new Error('User cancelled')))
|
||||
|
||||
await expect(drainPromise).rejects.toThrow('User cancelled')
|
||||
})
|
||||
})
|
||||
|
||||
describe('getAbortError', () => {
|
||||
it('should return Error reason directly', () => {
|
||||
const originalError = new Error('Original')
|
||||
const signal = { aborted: true, reason: originalError } as AbortSignal
|
||||
|
||||
expect(getAbortError(signal, 'Fallback')).toBe(originalError)
|
||||
})
|
||||
|
||||
it('should create Error from string reason', () => {
|
||||
const signal = { aborted: true, reason: 'String reason' } as AbortSignal
|
||||
|
||||
expect(getAbortError(signal, 'Fallback').message).toBe('String reason')
|
||||
})
|
||||
|
||||
it('should use fallback for empty reason', () => {
|
||||
const signal = { aborted: true, reason: '' } as AbortSignal
|
||||
|
||||
expect(getAbortError(signal, 'Fallback').message).toBe('Fallback')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,216 @@
|
||||
import { EventEmitter } from 'node:events'
|
||||
import type * as fs from 'node:fs'
|
||||
import type { Socket } from 'node:net'
|
||||
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
abortTransfer,
|
||||
cleanupTransfer,
|
||||
createTransferState,
|
||||
formatFileSize,
|
||||
streamFileChunks
|
||||
} from '../../handlers/fileTransfer'
|
||||
import type { ActiveFileTransfer } from '../../types'
|
||||
|
||||
// Mock binaryProtocol
|
||||
vi.mock('../../binaryProtocol', () => ({
|
||||
sendBinaryChunk: vi.fn().mockReturnValue(true)
|
||||
}))
|
||||
|
||||
// Mock connection handlers
|
||||
vi.mock('./connection', () => ({
|
||||
waitForSocketDrain: vi.fn().mockResolvedValue(undefined),
|
||||
getAbortError: vi.fn((signal, fallback) => {
|
||||
const reason = (signal as AbortSignal & { reason?: unknown }).reason
|
||||
if (reason instanceof Error) return reason
|
||||
if (typeof reason === 'string' && reason.length > 0) return new Error(reason)
|
||||
return new Error(fallback)
|
||||
})
|
||||
}))
|
||||
|
||||
// Note: validateFile and calculateFileChecksum tests are skipped because
|
||||
// the test environment has globally mocked node:fs and node:os modules.
|
||||
// These functions are tested through integration tests instead.
|
||||
|
||||
describe('fileTransfer handlers', () => {
|
||||
describe('createTransferState', () => {
|
||||
it('should create transfer state with correct defaults', () => {
|
||||
const state = createTransferState('uuid-123', 'test.zip', 1024000, 'abc123')
|
||||
|
||||
expect(state.transferId).toBe('uuid-123')
|
||||
expect(state.fileName).toBe('test.zip')
|
||||
expect(state.fileSize).toBe(1024000)
|
||||
expect(state.checksum).toBe('abc123')
|
||||
expect(state.bytesSent).toBe(0)
|
||||
expect(state.currentChunk).toBe(0)
|
||||
expect(state.isCancelled).toBe(false)
|
||||
expect(state.abortController).toBeInstanceOf(AbortController)
|
||||
})
|
||||
|
||||
it('should calculate totalChunks based on chunk size', () => {
|
||||
// 512KB chunk size
|
||||
const state = createTransferState('id', 'test.zip', 1024 * 1024, 'checksum') // 1MB
|
||||
|
||||
expect(state.totalChunks).toBe(2) // 1MB / 512KB = 2
|
||||
})
|
||||
})
|
||||
|
||||
describe('abortTransfer', () => {
|
||||
it('should abort transfer and destroy stream', () => {
|
||||
const mockStream = {
|
||||
destroyed: false,
|
||||
destroy: vi.fn()
|
||||
} as unknown as fs.ReadStream
|
||||
|
||||
const transfer: ActiveFileTransfer = {
|
||||
transferId: 'test',
|
||||
fileName: 'test.zip',
|
||||
fileSize: 1000,
|
||||
checksum: 'abc',
|
||||
totalChunks: 1,
|
||||
chunkSize: 512000,
|
||||
bytesSent: 0,
|
||||
currentChunk: 0,
|
||||
startedAt: Date.now(),
|
||||
stream: mockStream,
|
||||
isCancelled: false,
|
||||
abortController: new AbortController()
|
||||
}
|
||||
|
||||
const error = new Error('Test abort')
|
||||
abortTransfer(transfer, error)
|
||||
|
||||
expect(transfer.isCancelled).toBe(true)
|
||||
expect(transfer.abortController.signal.aborted).toBe(true)
|
||||
expect(mockStream.destroy).toHaveBeenCalledWith(error)
|
||||
})
|
||||
|
||||
it('should handle undefined transfer', () => {
|
||||
expect(() => abortTransfer(undefined, new Error('test'))).not.toThrow()
|
||||
})
|
||||
|
||||
it('should not abort already aborted controller', () => {
|
||||
const transfer: ActiveFileTransfer = {
|
||||
transferId: 'test',
|
||||
fileName: 'test.zip',
|
||||
fileSize: 1000,
|
||||
checksum: 'abc',
|
||||
totalChunks: 1,
|
||||
chunkSize: 512000,
|
||||
bytesSent: 0,
|
||||
currentChunk: 0,
|
||||
startedAt: Date.now(),
|
||||
isCancelled: false,
|
||||
abortController: new AbortController()
|
||||
}
|
||||
|
||||
transfer.abortController.abort()
|
||||
|
||||
// Should not throw when aborting again
|
||||
expect(() => abortTransfer(transfer, new Error('test'))).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('cleanupTransfer', () => {
|
||||
it('should cleanup transfer resources', () => {
|
||||
const mockStream = {
|
||||
destroyed: false,
|
||||
destroy: vi.fn()
|
||||
} as unknown as fs.ReadStream
|
||||
|
||||
const transfer: ActiveFileTransfer = {
|
||||
transferId: 'test',
|
||||
fileName: 'test.zip',
|
||||
fileSize: 1000,
|
||||
checksum: 'abc',
|
||||
totalChunks: 1,
|
||||
chunkSize: 512000,
|
||||
bytesSent: 0,
|
||||
currentChunk: 0,
|
||||
startedAt: Date.now(),
|
||||
stream: mockStream,
|
||||
isCancelled: false,
|
||||
abortController: new AbortController()
|
||||
}
|
||||
|
||||
cleanupTransfer(transfer)
|
||||
|
||||
expect(transfer.abortController.signal.aborted).toBe(true)
|
||||
expect(mockStream.destroy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle undefined transfer', () => {
|
||||
expect(() => cleanupTransfer(undefined)).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('formatFileSize', () => {
|
||||
it('should format 0 bytes', () => {
|
||||
expect(formatFileSize(0)).toBe('0 B')
|
||||
})
|
||||
|
||||
it('should format bytes', () => {
|
||||
expect(formatFileSize(500)).toBe('500 B')
|
||||
})
|
||||
|
||||
it('should format kilobytes', () => {
|
||||
expect(formatFileSize(1024)).toBe('1 KB')
|
||||
expect(formatFileSize(2048)).toBe('2 KB')
|
||||
})
|
||||
|
||||
it('should format megabytes', () => {
|
||||
expect(formatFileSize(1024 * 1024)).toBe('1 MB')
|
||||
expect(formatFileSize(5 * 1024 * 1024)).toBe('5 MB')
|
||||
})
|
||||
|
||||
it('should format gigabytes', () => {
|
||||
expect(formatFileSize(1024 * 1024 * 1024)).toBe('1 GB')
|
||||
})
|
||||
|
||||
it('should format with decimal precision', () => {
|
||||
expect(formatFileSize(1536)).toBe('1.5 KB')
|
||||
expect(formatFileSize(1.5 * 1024 * 1024)).toBe('1.5 MB')
|
||||
})
|
||||
})
|
||||
|
||||
// Note: streamFileChunks tests require careful mocking of fs.createReadStream
|
||||
// which is globally mocked in the test environment. These tests verify the
|
||||
// streaming logic works correctly with mock streams.
|
||||
describe('streamFileChunks', () => {
|
||||
let mockSocket: Socket & EventEmitter
|
||||
let mockProgress: ReturnType<typeof vi.fn>
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
mockSocket = Object.assign(new EventEmitter(), {
|
||||
destroyed: false,
|
||||
writable: true,
|
||||
write: vi.fn().mockReturnValue(true),
|
||||
cork: vi.fn(),
|
||||
uncork: vi.fn()
|
||||
}) as unknown as Socket & EventEmitter
|
||||
|
||||
mockProgress = vi.fn()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
it('should throw when abort signal is already aborted', async () => {
|
||||
const transfer = createTransferState('test-id', 'test.zip', 1024, 'checksum')
|
||||
transfer.abortController.abort(new Error('Already cancelled'))
|
||||
|
||||
await expect(
|
||||
streamFileChunks(mockSocket, '/fake/path.zip', transfer, transfer.abortController.signal, mockProgress)
|
||||
).rejects.toThrow()
|
||||
})
|
||||
|
||||
// Note: Full integration testing of streamFileChunks with actual file streaming
|
||||
// requires a real file system, which cannot be easily mocked in ESM.
|
||||
// The abort signal test above verifies the early abort path.
|
||||
// Additional streaming tests are covered through integration tests.
|
||||
})
|
||||
})
|
||||
177
src/main/services/lanTransfer/__tests__/responseManager.test.ts
Normal file
177
src/main/services/lanTransfer/__tests__/responseManager.test.ts
Normal file
@ -0,0 +1,177 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { ResponseManager } from '../responseManager'
|
||||
|
||||
describe('ResponseManager', () => {
|
||||
let manager: ResponseManager
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
manager = new ResponseManager()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
describe('buildResponseKey', () => {
|
||||
it('should build key with type only', () => {
|
||||
expect(manager.buildResponseKey('handshake_ack')).toBe('handshake_ack')
|
||||
})
|
||||
|
||||
it('should build key with type and transferId', () => {
|
||||
expect(manager.buildResponseKey('file_start_ack', 'uuid-123')).toBe('file_start_ack:uuid-123')
|
||||
})
|
||||
|
||||
it('should build key with type, transferId, and chunkIndex', () => {
|
||||
expect(manager.buildResponseKey('file_chunk_ack', 'uuid-123', 5)).toBe('file_chunk_ack:uuid-123:5')
|
||||
})
|
||||
})
|
||||
|
||||
describe('waitForResponse', () => {
|
||||
it('should resolve when tryResolve is called with matching key', async () => {
|
||||
const resolvePromise = new Promise<unknown>((resolve, reject) => {
|
||||
manager.waitForResponse('handshake_ack', 5000, resolve, reject)
|
||||
})
|
||||
|
||||
const payload = { type: 'handshake_ack', accepted: true }
|
||||
const resolved = manager.tryResolve('handshake_ack', payload)
|
||||
|
||||
expect(resolved).toBe(true)
|
||||
await expect(resolvePromise).resolves.toEqual(payload)
|
||||
})
|
||||
|
||||
it('should reject on timeout', async () => {
|
||||
const resolvePromise = new Promise<unknown>((resolve, reject) => {
|
||||
manager.waitForResponse('handshake_ack', 1000, resolve, reject)
|
||||
})
|
||||
|
||||
vi.advanceTimersByTime(1001)
|
||||
|
||||
await expect(resolvePromise).rejects.toThrow('Timeout waiting for handshake_ack')
|
||||
})
|
||||
|
||||
it('should call onTimeout callback when timeout occurs', async () => {
|
||||
const onTimeout = vi.fn()
|
||||
manager.setTimeoutCallback(onTimeout)
|
||||
|
||||
const resolvePromise = new Promise<unknown>((resolve, reject) => {
|
||||
manager.waitForResponse('test', 1000, resolve, reject)
|
||||
})
|
||||
|
||||
vi.advanceTimersByTime(1001)
|
||||
|
||||
await expect(resolvePromise).rejects.toThrow()
|
||||
expect(onTimeout).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should reject when abort signal is triggered', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const resolvePromise = new Promise<unknown>((resolve, reject) => {
|
||||
manager.waitForResponse('test', 10000, resolve, reject, undefined, undefined, abortController.signal)
|
||||
})
|
||||
|
||||
abortController.abort(new Error('User cancelled'))
|
||||
|
||||
await expect(resolvePromise).rejects.toThrow('User cancelled')
|
||||
})
|
||||
|
||||
it('should replace existing response with same key', async () => {
|
||||
const firstReject = vi.fn()
|
||||
const secondResolve = vi.fn()
|
||||
const secondReject = vi.fn()
|
||||
|
||||
manager.waitForResponse('test', 5000, vi.fn(), firstReject)
|
||||
manager.waitForResponse('test', 5000, secondResolve, secondReject)
|
||||
|
||||
// First should be cleared (no rejection since it's replaced)
|
||||
const payload = { type: 'test' }
|
||||
manager.tryResolve('test', payload)
|
||||
|
||||
expect(secondResolve).toHaveBeenCalledWith(payload)
|
||||
})
|
||||
})
|
||||
|
||||
describe('tryResolve', () => {
|
||||
it('should return false when no matching response', () => {
|
||||
expect(manager.tryResolve('nonexistent', {})).toBe(false)
|
||||
})
|
||||
|
||||
it('should match with transferId', async () => {
|
||||
const resolvePromise = new Promise<unknown>((resolve, reject) => {
|
||||
manager.waitForResponse('file_start_ack', 5000, resolve, reject, 'uuid-123')
|
||||
})
|
||||
|
||||
const payload = { type: 'file_start_ack', transferId: 'uuid-123' }
|
||||
manager.tryResolve('file_start_ack', payload, 'uuid-123')
|
||||
|
||||
await expect(resolvePromise).resolves.toEqual(payload)
|
||||
})
|
||||
})
|
||||
|
||||
describe('rejectAll', () => {
|
||||
it('should reject all pending responses', async () => {
|
||||
const promises = [
|
||||
new Promise<unknown>((resolve, reject) => {
|
||||
manager.waitForResponse('test1', 5000, resolve, reject)
|
||||
}),
|
||||
new Promise<unknown>((resolve, reject) => {
|
||||
manager.waitForResponse('test2', 5000, resolve, reject, 'uuid')
|
||||
})
|
||||
]
|
||||
|
||||
manager.rejectAll(new Error('Connection closed'))
|
||||
|
||||
await expect(promises[0]).rejects.toThrow('Connection closed')
|
||||
await expect(promises[1]).rejects.toThrow('Connection closed')
|
||||
})
|
||||
})
|
||||
|
||||
describe('clearPendingResponse', () => {
|
||||
it('should clear specific response by key', () => {
|
||||
manager.waitForResponse('test', 5000, vi.fn(), vi.fn())
|
||||
|
||||
manager.clearPendingResponse('test')
|
||||
|
||||
expect(manager.tryResolve('test', {})).toBe(false)
|
||||
})
|
||||
|
||||
it('should clear all responses when no key provided', () => {
|
||||
manager.waitForResponse('test1', 5000, vi.fn(), vi.fn())
|
||||
manager.waitForResponse('test2', 5000, vi.fn(), vi.fn())
|
||||
|
||||
manager.clearPendingResponse()
|
||||
|
||||
expect(manager.tryResolve('test1', {})).toBe(false)
|
||||
expect(manager.tryResolve('test2', {})).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getAbortError', () => {
|
||||
it('should return Error reason directly', () => {
|
||||
const originalError = new Error('Original error')
|
||||
const signal = { aborted: true, reason: originalError } as AbortSignal
|
||||
|
||||
const error = manager.getAbortError(signal, 'Fallback')
|
||||
|
||||
expect(error).toBe(originalError)
|
||||
})
|
||||
|
||||
it('should create Error from string reason', () => {
|
||||
const signal = { aborted: true, reason: 'String reason' } as AbortSignal
|
||||
|
||||
const error = manager.getAbortError(signal, 'Fallback')
|
||||
|
||||
expect(error.message).toBe('String reason')
|
||||
})
|
||||
|
||||
it('should use fallback message when no reason', () => {
|
||||
const signal = { aborted: true } as AbortSignal
|
||||
|
||||
const error = manager.getAbortError(signal, 'Fallback message')
|
||||
|
||||
expect(error.message).toBe('Fallback message')
|
||||
})
|
||||
})
|
||||
})
|
||||
67
src/main/services/lanTransfer/binaryProtocol.ts
Normal file
67
src/main/services/lanTransfer/binaryProtocol.ts
Normal file
@ -0,0 +1,67 @@
|
||||
import type { Socket } from 'node:net'
|
||||
|
||||
/**
|
||||
* Binary protocol constants (v1)
|
||||
*/
|
||||
export const BINARY_TYPE_FILE_CHUNK = 0x01
|
||||
|
||||
/**
|
||||
* Send file chunk as binary frame (protocol v1 - streaming mode)
|
||||
*
|
||||
* Frame format:
|
||||
* ```
|
||||
* ┌──────────┬──────────┬──────────┬───────────────┬──────────────┬────────────┬───────────┐
|
||||
* │ Magic │ TotalLen │ Type │ TransferId Len│ TransferId │ ChunkIdx │ Data │
|
||||
* │ 0x43 0x53│ (4B BE) │ 0x01 │ (2B BE) │ (variable) │ (4B BE) │ (raw) │
|
||||
* └──────────┴──────────┴──────────┴───────────────┴──────────────┴────────────┴───────────┘
|
||||
* ```
|
||||
*
|
||||
* @param socket - TCP socket to write to
|
||||
* @param transferId - UUID of the transfer
|
||||
* @param chunkIndex - Index of the chunk (0-based)
|
||||
* @param data - Raw chunk data buffer
|
||||
* @returns true if data was buffered, false if backpressure should be applied
|
||||
*/
|
||||
export function sendBinaryChunk(socket: Socket, transferId: string, chunkIndex: number, data: Buffer): boolean {
|
||||
if (!socket || socket.destroyed || !socket.writable) {
|
||||
throw new Error('Socket is not writable')
|
||||
}
|
||||
|
||||
const tidBuffer = Buffer.from(transferId, 'utf8')
|
||||
const tidLen = tidBuffer.length
|
||||
|
||||
// totalLen = type(1) + tidLen(2) + tid(n) + idx(4) + data(m)
|
||||
const totalLen = 1 + 2 + tidLen + 4 + data.length
|
||||
|
||||
const header = Buffer.allocUnsafe(2 + 4 + 1 + 2 + tidLen + 4)
|
||||
let offset = 0
|
||||
|
||||
// Magic (2 bytes): "CS"
|
||||
header[offset++] = 0x43
|
||||
header[offset++] = 0x53
|
||||
|
||||
// TotalLen (4 bytes, Big-Endian)
|
||||
header.writeUInt32BE(totalLen, offset)
|
||||
offset += 4
|
||||
|
||||
// Type (1 byte)
|
||||
header[offset++] = BINARY_TYPE_FILE_CHUNK
|
||||
|
||||
// TransferId length (2 bytes, Big-Endian)
|
||||
header.writeUInt16BE(tidLen, offset)
|
||||
offset += 2
|
||||
|
||||
// TransferId (variable)
|
||||
tidBuffer.copy(header, offset)
|
||||
offset += tidLen
|
||||
|
||||
// ChunkIndex (4 bytes, Big-Endian)
|
||||
header.writeUInt32BE(chunkIndex, offset)
|
||||
|
||||
socket.cork()
|
||||
const wroteHeader = socket.write(header)
|
||||
const wroteData = socket.write(data)
|
||||
socket.uncork()
|
||||
|
||||
return wroteHeader && wroteData
|
||||
}
|
||||
162
src/main/services/lanTransfer/handlers/connection.ts
Normal file
162
src/main/services/lanTransfer/handlers/connection.ts
Normal file
@ -0,0 +1,162 @@
|
||||
import { isIP, type Socket } from 'node:net'
|
||||
import { platform } from 'node:os'
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import type { LanHandshakeRequestMessage, LocalTransferPeer } from '@shared/config/types'
|
||||
import { app } from 'electron'
|
||||
|
||||
import type { ConnectionContext } from '../types'
|
||||
|
||||
export const HANDSHAKE_PROTOCOL_VERSION = '1'
|
||||
|
||||
/** Maximum size for line buffer to prevent memory exhaustion from malicious peers */
|
||||
const MAX_LINE_BUFFER_SIZE = 1024 * 1024 // 1MB limit for control messages
|
||||
|
||||
const logger = loggerService.withContext('LanTransferConnection')
|
||||
|
||||
/**
|
||||
* Build a handshake request message with device info.
|
||||
*/
|
||||
export function buildHandshakeMessage(): LanHandshakeRequestMessage {
|
||||
return {
|
||||
type: 'handshake',
|
||||
deviceName: app.getName(),
|
||||
version: HANDSHAKE_PROTOCOL_VERSION,
|
||||
platform: platform(),
|
||||
appVersion: app.getVersion()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Pick the best host address from a peer's available addresses.
|
||||
* Prefers IPv4 addresses over IPv6.
|
||||
*/
|
||||
export function pickHost(peer: LocalTransferPeer): string | undefined {
|
||||
const preferred = peer.addresses?.find((addr) => isIP(addr) === 4) || peer.addresses?.[0]
|
||||
return preferred || peer.host
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a test ping message after successful handshake.
|
||||
*/
|
||||
export function sendTestPing(ctx: ConnectionContext): void {
|
||||
const payload = 'hello world'
|
||||
try {
|
||||
ctx.sendControlMessage({ type: 'ping', payload })
|
||||
logger.info('Sent LAN ping test payload')
|
||||
ctx.broadcastClientEvent({
|
||||
type: 'ping_sent',
|
||||
payload,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
logger.error('Failed to send LAN test ping', error as Error)
|
||||
ctx.broadcastClientEvent({
|
||||
type: 'error',
|
||||
message,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Attach data listener to socket for receiving control messages.
|
||||
* Returns a function to parse the line buffer.
|
||||
*/
|
||||
export function createDataHandler(onControlLine: (line: string) => void): {
|
||||
lineBuffer: string
|
||||
handleData: (chunk: Buffer) => void
|
||||
resetBuffer: () => void
|
||||
} {
|
||||
let lineBuffer = ''
|
||||
|
||||
return {
|
||||
get lineBuffer() {
|
||||
return lineBuffer
|
||||
},
|
||||
handleData(chunk: Buffer) {
|
||||
lineBuffer += chunk.toString('utf8')
|
||||
|
||||
// Prevent memory exhaustion from malicious peers sending data without newlines
|
||||
if (lineBuffer.length > MAX_LINE_BUFFER_SIZE) {
|
||||
logger.error('Line buffer exceeded maximum size, resetting')
|
||||
lineBuffer = ''
|
||||
throw new Error('Control message too large')
|
||||
}
|
||||
|
||||
let newlineIndex = lineBuffer.indexOf('\n')
|
||||
while (newlineIndex !== -1) {
|
||||
const line = lineBuffer.slice(0, newlineIndex).trim()
|
||||
lineBuffer = lineBuffer.slice(newlineIndex + 1)
|
||||
if (line.length > 0) {
|
||||
onControlLine(line)
|
||||
}
|
||||
newlineIndex = lineBuffer.indexOf('\n')
|
||||
}
|
||||
},
|
||||
resetBuffer() {
|
||||
lineBuffer = ''
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for socket to drain (backpressure handling).
|
||||
*/
|
||||
export async function waitForSocketDrain(socket: Socket, abortSignal: AbortSignal): Promise<void> {
|
||||
if (abortSignal.aborted) {
|
||||
throw getAbortError(abortSignal, 'Transfer aborted while waiting for socket drain')
|
||||
}
|
||||
if (socket.destroyed) {
|
||||
throw new Error('Socket is closed')
|
||||
}
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const cleanup = () => {
|
||||
socket.off('drain', onDrain)
|
||||
socket.off('close', onClose)
|
||||
socket.off('error', onError)
|
||||
abortSignal.removeEventListener('abort', onAbort)
|
||||
}
|
||||
|
||||
const onDrain = () => {
|
||||
cleanup()
|
||||
resolve()
|
||||
}
|
||||
|
||||
const onClose = () => {
|
||||
cleanup()
|
||||
reject(new Error('Socket closed while waiting for drain'))
|
||||
}
|
||||
|
||||
const onError = (error: Error) => {
|
||||
cleanup()
|
||||
reject(error)
|
||||
}
|
||||
|
||||
const onAbort = () => {
|
||||
cleanup()
|
||||
reject(getAbortError(abortSignal, 'Transfer aborted while waiting for socket drain'))
|
||||
}
|
||||
|
||||
socket.once('drain', onDrain)
|
||||
socket.once('close', onClose)
|
||||
socket.once('error', onError)
|
||||
abortSignal.addEventListener('abort', onAbort, { once: true })
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the error from an abort signal, or create a fallback error.
|
||||
*/
|
||||
export function getAbortError(signal: AbortSignal, fallbackMessage: string): Error {
|
||||
const reason = (signal as AbortSignal & { reason?: unknown }).reason
|
||||
if (reason instanceof Error) {
|
||||
return reason
|
||||
}
|
||||
if (typeof reason === 'string' && reason.length > 0) {
|
||||
return new Error(reason)
|
||||
}
|
||||
return new Error(fallbackMessage)
|
||||
}
|
||||
267
src/main/services/lanTransfer/handlers/fileTransfer.ts
Normal file
267
src/main/services/lanTransfer/handlers/fileTransfer.ts
Normal file
@ -0,0 +1,267 @@
|
||||
import * as crypto from 'node:crypto'
|
||||
import * as fs from 'node:fs'
|
||||
import type { Socket } from 'node:net'
|
||||
import * as path from 'node:path'
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import type {
|
||||
LanFileCompleteMessage,
|
||||
LanFileEndMessage,
|
||||
LanFileStartAckMessage,
|
||||
LanFileStartMessage
|
||||
} from '@shared/config/types'
|
||||
import {
|
||||
LAN_TRANSFER_CHUNK_SIZE,
|
||||
LAN_TRANSFER_COMPLETE_TIMEOUT_MS,
|
||||
LAN_TRANSFER_MAX_FILE_SIZE
|
||||
} from '@shared/config/types'
|
||||
|
||||
import { sendBinaryChunk } from '../binaryProtocol'
|
||||
import type { ActiveFileTransfer, FileTransferContext } from '../types'
|
||||
import { getAbortError, waitForSocketDrain } from './connection'
|
||||
|
||||
const DEFAULT_FILE_START_ACK_TIMEOUT_MS = 30_000 // 30s for file_start_ack
|
||||
|
||||
const logger = loggerService.withContext('LanTransferFileHandler')
|
||||
|
||||
/**
|
||||
* Validate a file for transfer.
|
||||
* Checks existence, type, extension, and size limits.
|
||||
*/
|
||||
export async function validateFile(filePath: string): Promise<{ stats: fs.Stats; fileName: string }> {
|
||||
let stats: fs.Stats
|
||||
try {
|
||||
stats = await fs.promises.stat(filePath)
|
||||
} catch (error) {
|
||||
const nodeError = error as NodeJS.ErrnoException
|
||||
if (nodeError.code === 'ENOENT') {
|
||||
throw new Error(`File not found: ${filePath}`)
|
||||
} else if (nodeError.code === 'EACCES') {
|
||||
throw new Error(`Permission denied: ${filePath}`)
|
||||
} else if (nodeError.code === 'ENOTDIR') {
|
||||
throw new Error(`Invalid path: ${filePath}`)
|
||||
} else {
|
||||
throw new Error(`Cannot access file: ${filePath} (${nodeError.code || 'unknown error'})`)
|
||||
}
|
||||
}
|
||||
|
||||
if (!stats.isFile()) {
|
||||
throw new Error('Path is not a file')
|
||||
}
|
||||
|
||||
const fileName = path.basename(filePath)
|
||||
const ext = path.extname(fileName).toLowerCase()
|
||||
if (ext !== '.zip') {
|
||||
throw new Error('Only ZIP files are supported')
|
||||
}
|
||||
|
||||
if (stats.size > LAN_TRANSFER_MAX_FILE_SIZE) {
|
||||
throw new Error(`File too large. Maximum size is ${formatFileSize(LAN_TRANSFER_MAX_FILE_SIZE)}`)
|
||||
}
|
||||
|
||||
return { stats, fileName }
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate SHA-256 checksum of a file.
|
||||
*/
|
||||
export async function calculateFileChecksum(filePath: string): Promise<string> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const hash = crypto.createHash('sha256')
|
||||
const stream = fs.createReadStream(filePath)
|
||||
stream.on('data', (data) => hash.update(data))
|
||||
stream.on('end', () => resolve(hash.digest('hex')))
|
||||
stream.on('error', reject)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Create initial transfer state for a new file transfer.
|
||||
*/
|
||||
export function createTransferState(
|
||||
transferId: string,
|
||||
fileName: string,
|
||||
fileSize: number,
|
||||
checksum: string
|
||||
): ActiveFileTransfer {
|
||||
const chunkSize = LAN_TRANSFER_CHUNK_SIZE
|
||||
const totalChunks = Math.ceil(fileSize / chunkSize)
|
||||
|
||||
return {
|
||||
transferId,
|
||||
fileName,
|
||||
fileSize,
|
||||
checksum,
|
||||
totalChunks,
|
||||
chunkSize,
|
||||
bytesSent: 0,
|
||||
currentChunk: 0,
|
||||
startedAt: Date.now(),
|
||||
isCancelled: false,
|
||||
abortController: new AbortController()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send file_start message to receiver.
|
||||
*/
|
||||
export function sendFileStart(ctx: FileTransferContext, transfer: ActiveFileTransfer): void {
|
||||
const startMessage: LanFileStartMessage = {
|
||||
type: 'file_start',
|
||||
transferId: transfer.transferId,
|
||||
fileName: transfer.fileName,
|
||||
fileSize: transfer.fileSize,
|
||||
mimeType: 'application/zip',
|
||||
checksum: transfer.checksum,
|
||||
totalChunks: transfer.totalChunks,
|
||||
chunkSize: transfer.chunkSize
|
||||
}
|
||||
ctx.sendControlMessage(startMessage)
|
||||
logger.info('Sent file_start message')
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for file_start_ack from receiver.
|
||||
*/
|
||||
export function waitForFileStartAck(
|
||||
ctx: FileTransferContext,
|
||||
transferId: string,
|
||||
abortSignal?: AbortSignal
|
||||
): Promise<LanFileStartAckMessage> {
|
||||
return new Promise((resolve, reject) => {
|
||||
ctx.waitForResponse(
|
||||
'file_start_ack',
|
||||
DEFAULT_FILE_START_ACK_TIMEOUT_MS,
|
||||
(payload) => resolve(payload as LanFileStartAckMessage),
|
||||
reject,
|
||||
transferId,
|
||||
undefined,
|
||||
abortSignal
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for file_complete from receiver after all chunks sent.
|
||||
*/
|
||||
export function waitForFileComplete(
|
||||
ctx: FileTransferContext,
|
||||
transferId: string,
|
||||
abortSignal?: AbortSignal
|
||||
): Promise<LanFileCompleteMessage> {
|
||||
return new Promise((resolve, reject) => {
|
||||
ctx.waitForResponse(
|
||||
'file_complete',
|
||||
LAN_TRANSFER_COMPLETE_TIMEOUT_MS,
|
||||
(payload) => resolve(payload as LanFileCompleteMessage),
|
||||
reject,
|
||||
transferId,
|
||||
undefined,
|
||||
abortSignal
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Send file_end message to receiver.
|
||||
*/
|
||||
export function sendFileEnd(ctx: FileTransferContext, transferId: string): void {
|
||||
const endMessage: LanFileEndMessage = {
|
||||
type: 'file_end',
|
||||
transferId
|
||||
}
|
||||
ctx.sendControlMessage(endMessage)
|
||||
logger.info('Sent file_end message')
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream file chunks to the receiver (v1 streaming mode - no per-chunk acknowledgment).
|
||||
*/
|
||||
export async function streamFileChunks(
|
||||
socket: Socket,
|
||||
filePath: string,
|
||||
transfer: ActiveFileTransfer,
|
||||
abortSignal: AbortSignal,
|
||||
onProgress: (bytesSent: number, chunkIndex: number) => void
|
||||
): Promise<void> {
|
||||
const { chunkSize, transferId } = transfer
|
||||
|
||||
const stream = fs.createReadStream(filePath, { highWaterMark: chunkSize })
|
||||
transfer.stream = stream
|
||||
|
||||
let chunkIndex = 0
|
||||
let bytesSent = 0
|
||||
|
||||
try {
|
||||
for await (const chunk of stream) {
|
||||
if (abortSignal.aborted) {
|
||||
throw getAbortError(abortSignal, 'Transfer aborted')
|
||||
}
|
||||
|
||||
const buffer = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk)
|
||||
bytesSent += buffer.length
|
||||
|
||||
// Send chunk as binary frame (v1 streaming) with backpressure handling
|
||||
const canContinue = sendBinaryChunk(socket, transferId, chunkIndex, buffer)
|
||||
if (!canContinue) {
|
||||
await waitForSocketDrain(socket, abortSignal)
|
||||
}
|
||||
|
||||
// Update progress
|
||||
transfer.bytesSent = bytesSent
|
||||
transfer.currentChunk = chunkIndex
|
||||
|
||||
onProgress(bytesSent, chunkIndex)
|
||||
chunkIndex++
|
||||
}
|
||||
|
||||
logger.info(`File streaming completed: ${chunkIndex} chunks sent`)
|
||||
} catch (error) {
|
||||
logger.error('File streaming failed', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Abort an active transfer and clean up resources.
|
||||
*/
|
||||
export function abortTransfer(transfer: ActiveFileTransfer | undefined, error: Error): void {
|
||||
if (!transfer) {
|
||||
return
|
||||
}
|
||||
|
||||
transfer.isCancelled = true
|
||||
if (!transfer.abortController.signal.aborted) {
|
||||
transfer.abortController.abort(error)
|
||||
}
|
||||
if (transfer.stream && !transfer.stream.destroyed) {
|
||||
transfer.stream.destroy(error)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up transfer resources without error.
|
||||
*/
|
||||
export function cleanupTransfer(transfer: ActiveFileTransfer | undefined): void {
|
||||
if (!transfer) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!transfer.abortController.signal.aborted) {
|
||||
transfer.abortController.abort()
|
||||
}
|
||||
if (transfer.stream && !transfer.stream.destroyed) {
|
||||
transfer.stream.destroy()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Format bytes into human-readable size string.
|
||||
*/
|
||||
export function formatFileSize(bytes: number): string {
|
||||
if (bytes === 0) return '0 B'
|
||||
const k = 1024
|
||||
const sizes = ['B', 'KB', 'MB', 'GB']
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k))
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
|
||||
}
|
||||
22
src/main/services/lanTransfer/handlers/index.ts
Normal file
22
src/main/services/lanTransfer/handlers/index.ts
Normal file
@ -0,0 +1,22 @@
|
||||
export {
|
||||
buildHandshakeMessage,
|
||||
createDataHandler,
|
||||
getAbortError,
|
||||
HANDSHAKE_PROTOCOL_VERSION,
|
||||
pickHost,
|
||||
sendTestPing,
|
||||
waitForSocketDrain
|
||||
} from './connection'
|
||||
export {
|
||||
abortTransfer,
|
||||
calculateFileChecksum,
|
||||
cleanupTransfer,
|
||||
createTransferState,
|
||||
formatFileSize,
|
||||
sendFileEnd,
|
||||
sendFileStart,
|
||||
streamFileChunks,
|
||||
validateFile,
|
||||
waitForFileComplete,
|
||||
waitForFileStartAck
|
||||
} from './fileTransfer'
|
||||
21
src/main/services/lanTransfer/index.ts
Normal file
21
src/main/services/lanTransfer/index.ts
Normal file
@ -0,0 +1,21 @@
|
||||
/**
|
||||
* LAN Transfer Client Module
|
||||
*
|
||||
* Protocol: v1.0 (streaming mode)
|
||||
*
|
||||
* Features:
|
||||
* - Binary frame format for file chunks (no base64 overhead)
|
||||
* - Streaming mode (no per-chunk acknowledgment)
|
||||
* - JSON messages for control flow (handshake, file_start, file_end, etc.)
|
||||
* - Global timeout protection
|
||||
* - Backpressure handling
|
||||
*
|
||||
* Binary Frame Format:
|
||||
* ┌──────────┬──────────┬──────────┬───────────────┬──────────────┬────────────┬───────────┐
|
||||
* │ Magic │ TotalLen │ Type │ TransferId Len│ TransferId │ ChunkIdx │ Data │
|
||||
* │ 0x43 0x53│ (4B BE) │ 0x01 │ (2B BE) │ (variable) │ (4B BE) │ (raw) │
|
||||
* └──────────┴──────────┴──────────┴───────────────┴──────────────┴────────────┴───────────┘
|
||||
*/
|
||||
|
||||
export { HANDSHAKE_PROTOCOL_VERSION, lanTransferClientService } from './LanTransferClientService'
|
||||
export type { ActiveFileTransfer, ConnectionContext, FileTransferContext, PendingResponse } from './types'
|
||||
144
src/main/services/lanTransfer/responseManager.ts
Normal file
144
src/main/services/lanTransfer/responseManager.ts
Normal file
@ -0,0 +1,144 @@
|
||||
import type { PendingResponse } from './types'
|
||||
|
||||
/**
|
||||
* Manages pending response handlers for awaiting control messages.
|
||||
* Handles timeouts, abort signals, and cleanup.
|
||||
*/
|
||||
export class ResponseManager {
|
||||
private pendingResponses = new Map<string, PendingResponse>()
|
||||
private onTimeout?: () => void
|
||||
|
||||
/**
|
||||
* Set a callback to be called when a response times out.
|
||||
* Typically used to trigger disconnect on timeout.
|
||||
*/
|
||||
setTimeoutCallback(callback: () => void): void {
|
||||
this.onTimeout = callback
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a composite key for identifying pending responses.
|
||||
*/
|
||||
buildResponseKey(type: string, transferId?: string, chunkIndex?: number): string {
|
||||
const parts = [type]
|
||||
if (transferId !== undefined) parts.push(transferId)
|
||||
if (chunkIndex !== undefined) parts.push(String(chunkIndex))
|
||||
return parts.join(':')
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a response listener with timeout and optional abort signal.
|
||||
*/
|
||||
waitForResponse(
|
||||
type: string,
|
||||
timeoutMs: number,
|
||||
resolve: (payload: unknown) => void,
|
||||
reject: (error: Error) => void,
|
||||
transferId?: string,
|
||||
chunkIndex?: number,
|
||||
abortSignal?: AbortSignal
|
||||
): void {
|
||||
const responseKey = this.buildResponseKey(type, transferId, chunkIndex)
|
||||
|
||||
// Clear any existing response with the same key
|
||||
this.clearPendingResponse(responseKey)
|
||||
|
||||
const timeoutHandle = setTimeout(() => {
|
||||
this.clearPendingResponse(responseKey)
|
||||
const error = new Error(`Timeout waiting for ${type}`)
|
||||
reject(error)
|
||||
this.onTimeout?.()
|
||||
}, timeoutMs)
|
||||
|
||||
const pending: PendingResponse = {
|
||||
type,
|
||||
transferId,
|
||||
chunkIndex,
|
||||
resolve,
|
||||
reject,
|
||||
timeoutHandle,
|
||||
abortSignal
|
||||
}
|
||||
|
||||
if (abortSignal) {
|
||||
const abortListener = () => {
|
||||
this.clearPendingResponse(responseKey)
|
||||
reject(this.getAbortError(abortSignal, `Aborted while waiting for ${type}`))
|
||||
}
|
||||
pending.abortListener = abortListener
|
||||
abortSignal.addEventListener('abort', abortListener, { once: true })
|
||||
}
|
||||
|
||||
this.pendingResponses.set(responseKey, pending)
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to resolve a pending response by type and optional identifiers.
|
||||
* Returns true if a matching response was found and resolved.
|
||||
*/
|
||||
tryResolve(type: string, payload: unknown, transferId?: string, chunkIndex?: number): boolean {
|
||||
const responseKey = this.buildResponseKey(type, transferId, chunkIndex)
|
||||
const pendingResponse = this.pendingResponses.get(responseKey)
|
||||
|
||||
if (pendingResponse) {
|
||||
const resolver = pendingResponse.resolve
|
||||
this.clearPendingResponse(responseKey)
|
||||
resolver(payload)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear a single pending response by key, or all responses if no key provided.
|
||||
*/
|
||||
clearPendingResponse(key?: string): void {
|
||||
if (key) {
|
||||
const pending = this.pendingResponses.get(key)
|
||||
if (pending?.timeoutHandle) {
|
||||
clearTimeout(pending.timeoutHandle)
|
||||
}
|
||||
if (pending?.abortSignal && pending.abortListener) {
|
||||
pending.abortSignal.removeEventListener('abort', pending.abortListener)
|
||||
}
|
||||
this.pendingResponses.delete(key)
|
||||
} else {
|
||||
// Clear all pending responses
|
||||
for (const pending of this.pendingResponses.values()) {
|
||||
if (pending.timeoutHandle) {
|
||||
clearTimeout(pending.timeoutHandle)
|
||||
}
|
||||
if (pending.abortSignal && pending.abortListener) {
|
||||
pending.abortSignal.removeEventListener('abort', pending.abortListener)
|
||||
}
|
||||
}
|
||||
this.pendingResponses.clear()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reject all pending responses with the given error.
|
||||
*/
|
||||
rejectAll(error: Error): void {
|
||||
for (const key of Array.from(this.pendingResponses.keys())) {
|
||||
const pending = this.pendingResponses.get(key)
|
||||
this.clearPendingResponse(key)
|
||||
pending?.reject(error)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the abort error from an abort signal, or create a fallback error.
|
||||
*/
|
||||
getAbortError(signal: AbortSignal, fallbackMessage: string): Error {
|
||||
const reason = (signal as AbortSignal & { reason?: unknown }).reason
|
||||
if (reason instanceof Error) {
|
||||
return reason
|
||||
}
|
||||
if (typeof reason === 'string' && reason.length > 0) {
|
||||
return new Error(reason)
|
||||
}
|
||||
return new Error(fallbackMessage)
|
||||
}
|
||||
}
|
||||
65
src/main/services/lanTransfer/types.ts
Normal file
65
src/main/services/lanTransfer/types.ts
Normal file
@ -0,0 +1,65 @@
|
||||
import type * as fs from 'node:fs'
|
||||
import type { Socket } from 'node:net'
|
||||
|
||||
import type { LanClientEvent, LocalTransferPeer } from '@shared/config/types'
|
||||
|
||||
/**
|
||||
* Pending response handler for awaiting control messages
|
||||
*/
|
||||
export type PendingResponse = {
|
||||
type: string
|
||||
transferId?: string
|
||||
chunkIndex?: number
|
||||
resolve: (payload: unknown) => void
|
||||
reject: (error: Error) => void
|
||||
timeoutHandle?: NodeJS.Timeout
|
||||
abortSignal?: AbortSignal
|
||||
abortListener?: () => void
|
||||
}
|
||||
|
||||
/**
|
||||
* Active file transfer state tracking
|
||||
*/
|
||||
export type ActiveFileTransfer = {
|
||||
transferId: string
|
||||
fileName: string
|
||||
fileSize: number
|
||||
checksum: string
|
||||
totalChunks: number
|
||||
chunkSize: number
|
||||
bytesSent: number
|
||||
currentChunk: number
|
||||
startedAt: number
|
||||
stream?: fs.ReadStream
|
||||
isCancelled: boolean
|
||||
abortController: AbortController
|
||||
}
|
||||
|
||||
/**
|
||||
* Context interface for connection handlers
|
||||
* Provides access to service methods without circular dependencies
|
||||
*/
|
||||
export type ConnectionContext = {
|
||||
socket: Socket | null
|
||||
currentPeer?: LocalTransferPeer
|
||||
sendControlMessage: (message: Record<string, unknown>) => void
|
||||
broadcastClientEvent: (event: LanClientEvent) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* Context interface for file transfer handlers
|
||||
* Extends connection context with transfer-specific methods
|
||||
*/
|
||||
export type FileTransferContext = ConnectionContext & {
|
||||
activeTransfer?: ActiveFileTransfer
|
||||
setActiveTransfer: (transfer: ActiveFileTransfer | undefined) => void
|
||||
waitForResponse: (
|
||||
type: string,
|
||||
timeoutMs: number,
|
||||
resolve: (payload: unknown) => void,
|
||||
reject: (error: Error) => void,
|
||||
transferId?: string,
|
||||
chunkIndex?: number,
|
||||
abortSignal?: AbortSignal
|
||||
) => void
|
||||
}
|
||||
@ -1,7 +1,9 @@
|
||||
import type { Client } from '@libsql/client'
|
||||
import { createClient } from '@libsql/client'
|
||||
import { loggerService } from '@logger'
|
||||
import { DATA_PATH } from '@main/config'
|
||||
import Embeddings from '@main/knowledge/embedjs/embeddings/Embeddings'
|
||||
import { makeSureDirExists } from '@main/utils'
|
||||
import type {
|
||||
AddMemoryOptions,
|
||||
AssistantMessage,
|
||||
@ -13,6 +15,7 @@ import type {
|
||||
} from '@types'
|
||||
import crypto from 'crypto'
|
||||
import { app } from 'electron'
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
|
||||
import { MemoryQueries } from './queries'
|
||||
@ -71,6 +74,21 @@ export class MemoryService {
|
||||
return MemoryService.instance
|
||||
}
|
||||
|
||||
/**
|
||||
* Migrate the memory database from the old path to the new path
|
||||
* If the old memory database exists, rename it to the new path
|
||||
*/
|
||||
public migrateMemoryDb(): void {
|
||||
const oldMemoryDbPath = path.join(app.getPath('userData'), 'memories.db')
|
||||
const memoryDbPath = path.join(DATA_PATH, 'Memory', 'memories.db')
|
||||
|
||||
makeSureDirExists(path.dirname(memoryDbPath))
|
||||
|
||||
if (fs.existsSync(oldMemoryDbPath)) {
|
||||
fs.renameSync(oldMemoryDbPath, memoryDbPath)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the database connection and create tables
|
||||
*/
|
||||
@ -80,11 +98,12 @@ export class MemoryService {
|
||||
}
|
||||
|
||||
try {
|
||||
const userDataPath = app.getPath('userData')
|
||||
const dbPath = path.join(userDataPath, 'memories.db')
|
||||
const memoryDbPath = path.join(DATA_PATH, 'Memory', 'memories.db')
|
||||
|
||||
makeSureDirExists(path.dirname(memoryDbPath))
|
||||
|
||||
this.db = createClient({
|
||||
url: `file:${dbPath}`,
|
||||
url: `file:${memoryDbPath}`,
|
||||
intMode: 'number'
|
||||
})
|
||||
|
||||
@ -168,12 +187,13 @@ export class MemoryService {
|
||||
|
||||
// Generate embedding if model is configured
|
||||
let embedding: number[] | null = null
|
||||
const embedderApiClient = this.config?.embedderApiClient
|
||||
if (embedderApiClient) {
|
||||
const embeddingModel = this.config?.embeddingModel
|
||||
|
||||
if (embeddingModel) {
|
||||
try {
|
||||
embedding = await this.generateEmbedding(trimmedMemory)
|
||||
logger.debug(
|
||||
`Generated embedding for restored memory with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})`
|
||||
`Generated embedding for restored memory with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Failed to generate embedding for restored memory:', error as Error)
|
||||
@ -211,11 +231,11 @@ export class MemoryService {
|
||||
|
||||
// Generate embedding if model is configured
|
||||
let embedding: number[] | null = null
|
||||
if (this.config?.embedderApiClient) {
|
||||
if (this.config?.embeddingModel) {
|
||||
try {
|
||||
embedding = await this.generateEmbedding(trimmedMemory)
|
||||
logger.debug(
|
||||
`Generated embedding with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})`
|
||||
`Generated embedding with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})`
|
||||
)
|
||||
|
||||
// Check for similar memories using vector similarity
|
||||
@ -300,7 +320,7 @@ export class MemoryService {
|
||||
|
||||
try {
|
||||
// If we have an embedder model configured, use vector search
|
||||
if (this.config?.embedderApiClient) {
|
||||
if (this.config?.embeddingModel) {
|
||||
try {
|
||||
const queryEmbedding = await this.generateEmbedding(query)
|
||||
return await this.hybridSearch(query, queryEmbedding, { limit, userId, agentId, filters })
|
||||
@ -497,11 +517,11 @@ export class MemoryService {
|
||||
|
||||
// Generate new embedding if model is configured
|
||||
let embedding: number[] | null = null
|
||||
if (this.config?.embedderApiClient) {
|
||||
if (this.config?.embeddingModel) {
|
||||
try {
|
||||
embedding = await this.generateEmbedding(memory)
|
||||
logger.debug(
|
||||
`Updated embedding with dimension: ${embedding.length} (target: ${this.config?.embedderDimensions || MemoryService.UNIFIED_DIMENSION})`
|
||||
`Updated embedding with dimension: ${embedding.length} (target: ${this.config?.embeddingDimensions || MemoryService.UNIFIED_DIMENSION})`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Failed to generate embedding for update:', error as Error)
|
||||
@ -710,21 +730,22 @@ export class MemoryService {
|
||||
* Generate embedding for text
|
||||
*/
|
||||
private async generateEmbedding(text: string): Promise<number[]> {
|
||||
if (!this.config?.embedderApiClient) {
|
||||
if (!this.config?.embeddingModel) {
|
||||
throw new Error('Embedder model not configured')
|
||||
}
|
||||
|
||||
try {
|
||||
// Initialize embeddings instance if needed
|
||||
if (!this.embeddings) {
|
||||
if (!this.config.embedderApiClient) {
|
||||
if (!this.config.embeddingApiClient) {
|
||||
throw new Error('Embedder provider not configured')
|
||||
}
|
||||
|
||||
this.embeddings = new Embeddings({
|
||||
embedApiClient: this.config.embedderApiClient,
|
||||
dimensions: this.config.embedderDimensions
|
||||
embedApiClient: this.config.embeddingApiClient,
|
||||
dimensions: this.config.embeddingDimensions
|
||||
})
|
||||
|
||||
await this.embeddings.init()
|
||||
}
|
||||
|
||||
|
||||
@ -3,194 +3,223 @@ import { describe, expect, it } from 'vitest'
|
||||
import { buildFunctionCallToolName } from '../mcp'
|
||||
|
||||
describe('buildFunctionCallToolName', () => {
|
||||
describe('basic functionality', () => {
|
||||
it('should combine server name and tool name', () => {
|
||||
describe('basic format', () => {
|
||||
it('should return format mcp__{server}__{tool}', () => {
|
||||
const result = buildFunctionCallToolName('github', 'search_issues')
|
||||
expect(result).toContain('github')
|
||||
expect(result).toContain('search')
|
||||
expect(result).toBe('mcp__github__search_issues')
|
||||
})
|
||||
|
||||
it('should sanitize names by replacing dashes with underscores', () => {
|
||||
const result = buildFunctionCallToolName('my-server', 'my-tool')
|
||||
// Input dashes are replaced, but the separator between server and tool is a dash
|
||||
expect(result).toBe('my_serv-my_tool')
|
||||
expect(result).toContain('_')
|
||||
})
|
||||
|
||||
it('should handle empty server names gracefully', () => {
|
||||
const result = buildFunctionCallToolName('', 'tool')
|
||||
expect(result).toBeTruthy()
|
||||
it('should handle simple server and tool names', () => {
|
||||
expect(buildFunctionCallToolName('fetch', 'get_page')).toBe('mcp__fetch__get_page')
|
||||
expect(buildFunctionCallToolName('database', 'query')).toBe('mcp__database__query')
|
||||
expect(buildFunctionCallToolName('cherry_studio', 'search')).toBe('mcp__cherry_studio__search')
|
||||
})
|
||||
})
|
||||
|
||||
describe('uniqueness with serverId', () => {
|
||||
it('should generate different IDs for same server name but different serverIds', () => {
|
||||
const serverId1 = 'server-id-123456'
|
||||
const serverId2 = 'server-id-789012'
|
||||
const serverName = 'github'
|
||||
const toolName = 'search_repos'
|
||||
|
||||
const result1 = buildFunctionCallToolName(serverName, toolName, serverId1)
|
||||
const result2 = buildFunctionCallToolName(serverName, toolName, serverId2)
|
||||
|
||||
expect(result1).not.toBe(result2)
|
||||
expect(result1).toContain('123456')
|
||||
expect(result2).toContain('789012')
|
||||
describe('valid JavaScript identifier', () => {
|
||||
it('should always start with mcp__ prefix (valid JS identifier start)', () => {
|
||||
const result = buildFunctionCallToolName('123server', '456tool')
|
||||
expect(result).toMatch(/^mcp__/)
|
||||
expect(result).toBe('mcp__123server__456tool')
|
||||
})
|
||||
|
||||
it('should generate same ID when serverId is not provided', () => {
|
||||
it('should only contain alphanumeric chars and underscores', () => {
|
||||
const result = buildFunctionCallToolName('my-server', 'my-tool')
|
||||
expect(result).toBe('mcp__my_server__my_tool')
|
||||
expect(result).toMatch(/^[a-zA-Z][a-zA-Z0-9_]*$/)
|
||||
})
|
||||
|
||||
it('should be a valid JavaScript identifier', () => {
|
||||
const testCases = [
|
||||
['github', 'create_issue'],
|
||||
['my-server', 'fetch-data'],
|
||||
['test@server', 'tool#name'],
|
||||
['server.name', 'tool.action'],
|
||||
['123abc', 'def456']
|
||||
]
|
||||
|
||||
for (const [server, tool] of testCases) {
|
||||
const result = buildFunctionCallToolName(server, tool)
|
||||
// Valid JS identifiers match this pattern
|
||||
expect(result).toMatch(/^[a-zA-Z_][a-zA-Z0-9_]*$/)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('character sanitization', () => {
|
||||
it('should replace dashes with underscores', () => {
|
||||
const result = buildFunctionCallToolName('my-server', 'my-tool-name')
|
||||
expect(result).toBe('mcp__my_server__my_tool_name')
|
||||
})
|
||||
|
||||
it('should replace special characters with underscores', () => {
|
||||
const result = buildFunctionCallToolName('test@server!', 'tool#name$')
|
||||
expect(result).toBe('mcp__test_server__tool_name')
|
||||
})
|
||||
|
||||
it('should replace dots with underscores', () => {
|
||||
const result = buildFunctionCallToolName('server.name', 'tool.action')
|
||||
expect(result).toBe('mcp__server_name__tool_action')
|
||||
})
|
||||
|
||||
it('should replace spaces with underscores', () => {
|
||||
const result = buildFunctionCallToolName('my server', 'my tool')
|
||||
expect(result).toBe('mcp__my_server__my_tool')
|
||||
})
|
||||
|
||||
it('should collapse consecutive underscores', () => {
|
||||
const result = buildFunctionCallToolName('my--server', 'my___tool')
|
||||
expect(result).toBe('mcp__my_server__my_tool')
|
||||
expect(result).not.toMatch(/_{3,}/)
|
||||
})
|
||||
|
||||
it('should trim leading and trailing underscores from parts', () => {
|
||||
const result = buildFunctionCallToolName('_server_', '_tool_')
|
||||
expect(result).toBe('mcp__server__tool')
|
||||
})
|
||||
|
||||
it('should handle names with only special characters', () => {
|
||||
const result = buildFunctionCallToolName('---', '###')
|
||||
expect(result).toBe('mcp____')
|
||||
})
|
||||
})
|
||||
|
||||
describe('length constraints', () => {
|
||||
it('should not exceed 63 characters', () => {
|
||||
const longServerName = 'a'.repeat(50)
|
||||
const longToolName = 'b'.repeat(50)
|
||||
const result = buildFunctionCallToolName(longServerName, longToolName)
|
||||
|
||||
expect(result.length).toBeLessThanOrEqual(63)
|
||||
})
|
||||
|
||||
it('should truncate server name to max 20 chars', () => {
|
||||
const longServerName = 'abcdefghijklmnopqrstuvwxyz' // 26 chars
|
||||
const result = buildFunctionCallToolName(longServerName, 'tool')
|
||||
|
||||
expect(result).toBe('mcp__abcdefghijklmnopqrst__tool')
|
||||
expect(result).toContain('abcdefghijklmnopqrst') // First 20 chars
|
||||
expect(result).not.toContain('uvwxyz') // Truncated
|
||||
})
|
||||
|
||||
it('should truncate tool name to max 35 chars', () => {
|
||||
const longToolName = 'a'.repeat(40)
|
||||
const result = buildFunctionCallToolName('server', longToolName)
|
||||
|
||||
const expectedTool = 'a'.repeat(35)
|
||||
expect(result).toBe(`mcp__server__${expectedTool}`)
|
||||
})
|
||||
|
||||
it('should not end with underscores after truncation', () => {
|
||||
// Create a name that would end with underscores after truncation
|
||||
const longServerName = 'a'.repeat(20)
|
||||
const longToolName = 'b'.repeat(35) + '___extra'
|
||||
const result = buildFunctionCallToolName(longServerName, longToolName)
|
||||
|
||||
expect(result).not.toMatch(/_+$/)
|
||||
expect(result.length).toBeLessThanOrEqual(63)
|
||||
})
|
||||
|
||||
it('should handle max length edge case exactly', () => {
|
||||
// mcp__ (5) + server (20) + __ (2) + tool (35) = 62 chars
|
||||
const server = 'a'.repeat(20)
|
||||
const tool = 'b'.repeat(35)
|
||||
const result = buildFunctionCallToolName(server, tool)
|
||||
|
||||
expect(result.length).toBe(62)
|
||||
expect(result).toBe(`mcp__${'a'.repeat(20)}__${'b'.repeat(35)}`)
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle empty server name', () => {
|
||||
const result = buildFunctionCallToolName('', 'tool')
|
||||
expect(result).toBe('mcp____tool')
|
||||
})
|
||||
|
||||
it('should handle empty tool name', () => {
|
||||
const result = buildFunctionCallToolName('server', '')
|
||||
expect(result).toBe('mcp__server__')
|
||||
})
|
||||
|
||||
it('should handle both empty names', () => {
|
||||
const result = buildFunctionCallToolName('', '')
|
||||
expect(result).toBe('mcp____')
|
||||
})
|
||||
|
||||
it('should handle whitespace-only names', () => {
|
||||
const result = buildFunctionCallToolName(' ', ' ')
|
||||
expect(result).toBe('mcp____')
|
||||
})
|
||||
|
||||
it('should trim whitespace from names', () => {
|
||||
const result = buildFunctionCallToolName(' server ', ' tool ')
|
||||
expect(result).toBe('mcp__server__tool')
|
||||
})
|
||||
|
||||
it('should handle unicode characters', () => {
|
||||
const result = buildFunctionCallToolName('服务器', '工具')
|
||||
// Unicode chars are replaced with underscores, then collapsed
|
||||
expect(result).toMatch(/^mcp__/)
|
||||
})
|
||||
|
||||
it('should handle mixed case', () => {
|
||||
const result = buildFunctionCallToolName('MyServer', 'MyTool')
|
||||
expect(result).toBe('mcp__MyServer__MyTool')
|
||||
})
|
||||
})
|
||||
|
||||
describe('deterministic output', () => {
|
||||
it('should produce consistent results for same input', () => {
|
||||
const serverName = 'github'
|
||||
const toolName = 'search_repos'
|
||||
|
||||
const result1 = buildFunctionCallToolName(serverName, toolName)
|
||||
const result2 = buildFunctionCallToolName(serverName, toolName)
|
||||
const result3 = buildFunctionCallToolName(serverName, toolName)
|
||||
|
||||
expect(result1).toBe(result2)
|
||||
expect(result2).toBe(result3)
|
||||
})
|
||||
|
||||
it('should include serverId suffix when provided', () => {
|
||||
const serverId = 'abc123def456'
|
||||
const result = buildFunctionCallToolName('server', 'tool', serverId)
|
||||
it('should produce different results for different inputs', () => {
|
||||
const result1 = buildFunctionCallToolName('server1', 'tool')
|
||||
const result2 = buildFunctionCallToolName('server2', 'tool')
|
||||
const result3 = buildFunctionCallToolName('server', 'tool1')
|
||||
const result4 = buildFunctionCallToolName('server', 'tool2')
|
||||
|
||||
// Should include last 6 chars of serverId
|
||||
expect(result).toContain('ef456')
|
||||
})
|
||||
})
|
||||
|
||||
describe('character sanitization', () => {
|
||||
it('should replace invalid characters with underscores', () => {
|
||||
const result = buildFunctionCallToolName('test@server', 'tool#name')
|
||||
expect(result).not.toMatch(/[@#]/)
|
||||
expect(result).toMatch(/^[a-zA-Z0-9_-]+$/)
|
||||
})
|
||||
|
||||
it('should ensure name starts with a letter', () => {
|
||||
const result = buildFunctionCallToolName('123server', '456tool')
|
||||
expect(result).toMatch(/^[a-zA-Z]/)
|
||||
})
|
||||
|
||||
it('should handle consecutive underscores/dashes', () => {
|
||||
const result = buildFunctionCallToolName('my--server', 'my__tool')
|
||||
expect(result).not.toMatch(/[_-]{2,}/)
|
||||
})
|
||||
})
|
||||
|
||||
describe('length constraints', () => {
|
||||
it('should truncate names longer than 63 characters', () => {
|
||||
const longServerName = 'a'.repeat(50)
|
||||
const longToolName = 'b'.repeat(50)
|
||||
const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456')
|
||||
|
||||
expect(result.length).toBeLessThanOrEqual(63)
|
||||
})
|
||||
|
||||
it('should not end with underscore or dash after truncation', () => {
|
||||
const longServerName = 'a'.repeat(50)
|
||||
const longToolName = 'b'.repeat(50)
|
||||
const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456')
|
||||
|
||||
expect(result).not.toMatch(/[_-]$/)
|
||||
})
|
||||
|
||||
it('should preserve serverId suffix even with long server/tool names', () => {
|
||||
const longServerName = 'a'.repeat(50)
|
||||
const longToolName = 'b'.repeat(50)
|
||||
const serverId = 'server-id-xyz789'
|
||||
|
||||
const result = buildFunctionCallToolName(longServerName, longToolName, serverId)
|
||||
|
||||
// The suffix should be preserved and not truncated
|
||||
expect(result).toContain('xyz789')
|
||||
expect(result.length).toBeLessThanOrEqual(63)
|
||||
})
|
||||
|
||||
it('should ensure two long-named servers with different IDs produce different results', () => {
|
||||
const longServerName = 'a'.repeat(50)
|
||||
const longToolName = 'b'.repeat(50)
|
||||
const serverId1 = 'server-id-abc123'
|
||||
const serverId2 = 'server-id-def456'
|
||||
|
||||
const result1 = buildFunctionCallToolName(longServerName, longToolName, serverId1)
|
||||
const result2 = buildFunctionCallToolName(longServerName, longToolName, serverId2)
|
||||
|
||||
// Both should be within limit
|
||||
expect(result1.length).toBeLessThanOrEqual(63)
|
||||
expect(result2.length).toBeLessThanOrEqual(63)
|
||||
|
||||
// They should be different due to preserved suffix
|
||||
expect(result1).not.toBe(result2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases with serverId', () => {
|
||||
it('should handle serverId with only non-alphanumeric characters', () => {
|
||||
const serverId = '------' // All dashes
|
||||
const result = buildFunctionCallToolName('server', 'tool', serverId)
|
||||
|
||||
// Should still produce a valid unique suffix via fallback hash
|
||||
expect(result).toBeTruthy()
|
||||
expect(result.length).toBeLessThanOrEqual(63)
|
||||
expect(result).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
|
||||
// Should have a suffix (underscore followed by something)
|
||||
expect(result).toMatch(/_[a-z0-9]+$/)
|
||||
})
|
||||
|
||||
it('should produce different results for different non-alphanumeric serverIds', () => {
|
||||
const serverId1 = '------'
|
||||
const serverId2 = '!!!!!!'
|
||||
|
||||
const result1 = buildFunctionCallToolName('server', 'tool', serverId1)
|
||||
const result2 = buildFunctionCallToolName('server', 'tool', serverId2)
|
||||
|
||||
// Should be different because the hash fallback produces different values
|
||||
expect(result1).not.toBe(result2)
|
||||
})
|
||||
|
||||
it('should handle empty string serverId differently from undefined', () => {
|
||||
const resultWithEmpty = buildFunctionCallToolName('server', 'tool', '')
|
||||
const resultWithUndefined = buildFunctionCallToolName('server', 'tool', undefined)
|
||||
|
||||
// Empty string is falsy, so both should behave the same (no suffix)
|
||||
expect(resultWithEmpty).toBe(resultWithUndefined)
|
||||
})
|
||||
|
||||
it('should handle serverId with mixed alphanumeric and special chars', () => {
|
||||
const serverId = 'ab@#cd' // Mixed chars, last 6 chars contain some alphanumeric
|
||||
const result = buildFunctionCallToolName('server', 'tool', serverId)
|
||||
|
||||
// Should extract alphanumeric chars: 'abcd' from 'ab@#cd'
|
||||
expect(result).toContain('abcd')
|
||||
expect(result3).not.toBe(result4)
|
||||
})
|
||||
})
|
||||
|
||||
describe('real-world scenarios', () => {
|
||||
it('should handle GitHub MCP server instances correctly', () => {
|
||||
const serverName = 'github'
|
||||
const toolName = 'search_repositories'
|
||||
|
||||
const githubComId = 'server-github-com-abc123'
|
||||
const gheId = 'server-ghe-internal-xyz789'
|
||||
|
||||
const tool1 = buildFunctionCallToolName(serverName, toolName, githubComId)
|
||||
const tool2 = buildFunctionCallToolName(serverName, toolName, gheId)
|
||||
|
||||
// Should be different
|
||||
expect(tool1).not.toBe(tool2)
|
||||
|
||||
// Both should be valid identifiers
|
||||
expect(tool1).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
|
||||
expect(tool2).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
|
||||
|
||||
// Both should be <= 63 chars
|
||||
expect(tool1.length).toBeLessThanOrEqual(63)
|
||||
expect(tool2.length).toBeLessThanOrEqual(63)
|
||||
it('should handle GitHub MCP server', () => {
|
||||
expect(buildFunctionCallToolName('github', 'create_issue')).toBe('mcp__github__create_issue')
|
||||
expect(buildFunctionCallToolName('github', 'search_repositories')).toBe('mcp__github__search_repositories')
|
||||
expect(buildFunctionCallToolName('github', 'get_pull_request')).toBe('mcp__github__get_pull_request')
|
||||
})
|
||||
|
||||
it('should handle tool names that already include server name prefix', () => {
|
||||
const result = buildFunctionCallToolName('github', 'github_search_repos')
|
||||
expect(result).toBeTruthy()
|
||||
// Should not double the server name
|
||||
expect(result.split('github').length - 1).toBeLessThanOrEqual(2)
|
||||
it('should handle filesystem MCP server', () => {
|
||||
expect(buildFunctionCallToolName('filesystem', 'read_file')).toBe('mcp__filesystem__read_file')
|
||||
expect(buildFunctionCallToolName('filesystem', 'write_file')).toBe('mcp__filesystem__write_file')
|
||||
expect(buildFunctionCallToolName('filesystem', 'list_directory')).toBe('mcp__filesystem__list_directory')
|
||||
})
|
||||
|
||||
it('should handle hyphenated server names (common in npm packages)', () => {
|
||||
expect(buildFunctionCallToolName('cherry-fetch', 'get_page')).toBe('mcp__cherry_fetch__get_page')
|
||||
expect(buildFunctionCallToolName('mcp-server-github', 'search')).toBe('mcp__mcp_server_github__search')
|
||||
})
|
||||
|
||||
it('should handle scoped npm package style names', () => {
|
||||
const result = buildFunctionCallToolName('@anthropic/mcp-server', 'chat')
|
||||
expect(result).toBe('mcp__anthropic_mcp_server__chat')
|
||||
})
|
||||
|
||||
it('should handle tools with long descriptive names', () => {
|
||||
const result = buildFunctionCallToolName('github', 'search_repositories_by_language_and_stars')
|
||||
expect(result.length).toBeLessThanOrEqual(63)
|
||||
expect(result).toMatch(/^mcp__github__search_repositories_by_lan/)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,10 +1,17 @@
|
||||
import { configManager } from '@main/services/ConfigManager'
|
||||
import { execFileSync } from 'child_process'
|
||||
import { execFileSync, spawn } from 'child_process'
|
||||
import { EventEmitter } from 'events'
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { autoDiscoverGitBash, findExecutable, findGitBash, validateGitBashPath } from '../process'
|
||||
import {
|
||||
autoDiscoverGitBash,
|
||||
findCommandInShellEnv,
|
||||
findExecutable,
|
||||
findGitBash,
|
||||
validateGitBashPath
|
||||
} from '../process'
|
||||
|
||||
// Mock configManager
|
||||
vi.mock('@main/services/ConfigManager', () => ({
|
||||
@ -988,3 +995,244 @@ describe.skipIf(process.platform !== 'win32')('process utilities', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
/**
|
||||
* Helper to create a mock child process for spawn
|
||||
*/
|
||||
function createMockChildProcess() {
|
||||
const mockChild = new EventEmitter() as EventEmitter & {
|
||||
stdout: EventEmitter
|
||||
stderr: EventEmitter
|
||||
kill: ReturnType<typeof vi.fn>
|
||||
}
|
||||
mockChild.stdout = new EventEmitter()
|
||||
mockChild.stderr = new EventEmitter()
|
||||
mockChild.kill = vi.fn()
|
||||
return mockChild
|
||||
}
|
||||
|
||||
describe('findCommandInShellEnv', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
// Reset path.isAbsolute to real implementation for these tests
|
||||
vi.mocked(path.isAbsolute).mockImplementation((p) => p.startsWith('/') || /^[A-Z]:/i.test(p))
|
||||
})
|
||||
|
||||
describe('command name validation', () => {
|
||||
it('should reject empty command name', async () => {
|
||||
const result = await findCommandInShellEnv('', {})
|
||||
expect(result).toBeNull()
|
||||
expect(spawn).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should reject command names with shell metacharacters', async () => {
|
||||
const maliciousCommands = [
|
||||
'npx; rm -rf /',
|
||||
'npx && malicious',
|
||||
'npx | cat /etc/passwd',
|
||||
'npx`whoami`',
|
||||
'$(whoami)',
|
||||
'npx\nmalicious'
|
||||
]
|
||||
|
||||
for (const cmd of maliciousCommands) {
|
||||
const result = await findCommandInShellEnv(cmd, {})
|
||||
expect(result).toBeNull()
|
||||
expect(spawn).not.toHaveBeenCalled()
|
||||
}
|
||||
})
|
||||
|
||||
it('should reject command names starting with hyphen', async () => {
|
||||
const result = await findCommandInShellEnv('-npx', {})
|
||||
expect(result).toBeNull()
|
||||
expect(spawn).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should reject path traversal attempts', async () => {
|
||||
const pathTraversalCommands = ['../npx', '../../malicious', 'foo/bar', 'foo\\bar']
|
||||
|
||||
for (const cmd of pathTraversalCommands) {
|
||||
const result = await findCommandInShellEnv(cmd, {})
|
||||
expect(result).toBeNull()
|
||||
expect(spawn).not.toHaveBeenCalled()
|
||||
}
|
||||
})
|
||||
|
||||
it('should reject command names exceeding max length', async () => {
|
||||
const longCommand = 'a'.repeat(129)
|
||||
const result = await findCommandInShellEnv(longCommand, {})
|
||||
expect(result).toBeNull()
|
||||
expect(spawn).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should accept valid command names', async () => {
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
// Don't await - just start the call
|
||||
const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' })
|
||||
|
||||
// Simulate command not found
|
||||
mockChild.emit('close', 1)
|
||||
|
||||
const result = await resultPromise
|
||||
expect(result).toBeNull()
|
||||
expect(spawn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should accept command names with underscores and hyphens', async () => {
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
const resultPromise = findCommandInShellEnv('my_command-name', { PATH: '/usr/bin' })
|
||||
mockChild.emit('close', 1)
|
||||
|
||||
await resultPromise
|
||||
expect(spawn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should accept command names at max length (128 chars)', async () => {
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
const maxLengthCommand = 'a'.repeat(128)
|
||||
const resultPromise = findCommandInShellEnv(maxLengthCommand, { PATH: '/usr/bin' })
|
||||
mockChild.emit('close', 1)
|
||||
|
||||
await resultPromise
|
||||
expect(spawn).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe.skipIf(process.platform === 'win32')('Unix/macOS behavior', () => {
|
||||
it('should find command and return absolute path', async () => {
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' })
|
||||
|
||||
// Simulate successful command -v output
|
||||
mockChild.stdout.emit('data', '/usr/local/bin/npx\n')
|
||||
mockChild.emit('close', 0)
|
||||
|
||||
const result = await resultPromise
|
||||
expect(result).toBe('/usr/local/bin/npx')
|
||||
expect(spawn).toHaveBeenCalledWith('/bin/sh', ['-c', 'command -v "$1"', '--', 'npx'], expect.any(Object))
|
||||
})
|
||||
|
||||
it('should return null for non-absolute paths (aliases/builtins)', async () => {
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
const resultPromise = findCommandInShellEnv('cd', { PATH: '/usr/bin' })
|
||||
|
||||
// Simulate builtin output (just command name)
|
||||
mockChild.stdout.emit('data', 'cd\n')
|
||||
mockChild.emit('close', 0)
|
||||
|
||||
const result = await resultPromise
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null when command not found', async () => {
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
const resultPromise = findCommandInShellEnv('nonexistent', { PATH: '/usr/bin' })
|
||||
|
||||
// Simulate command not found (exit code 1)
|
||||
mockChild.emit('close', 1)
|
||||
|
||||
const result = await resultPromise
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle spawn errors gracefully', async () => {
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' })
|
||||
|
||||
// Simulate spawn error
|
||||
mockChild.emit('error', new Error('spawn failed'))
|
||||
|
||||
const result = await resultPromise
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle timeout gracefully', async () => {
|
||||
vi.useFakeTimers()
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
const resultPromise = findCommandInShellEnv('npx', { PATH: '/usr/bin' })
|
||||
|
||||
// Fast-forward past timeout (5000ms)
|
||||
vi.advanceTimersByTime(6000)
|
||||
|
||||
const result = await resultPromise
|
||||
expect(result).toBeNull()
|
||||
expect(mockChild.kill).toHaveBeenCalledWith('SIGKILL')
|
||||
|
||||
vi.useRealTimers()
|
||||
})
|
||||
})
|
||||
|
||||
describe.skipIf(process.platform !== 'win32')('Windows behavior', () => {
|
||||
it('should find .exe files via where command', async () => {
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' })
|
||||
|
||||
// Simulate where output
|
||||
mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.exe\r\n')
|
||||
mockChild.emit('close', 0)
|
||||
|
||||
const result = await resultPromise
|
||||
expect(result).toBe('C:\\Program Files\\nodejs\\npx.exe')
|
||||
expect(spawn).toHaveBeenCalledWith('where', ['npx'], expect.any(Object))
|
||||
})
|
||||
|
||||
it('should reject .cmd files on Windows', async () => {
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' })
|
||||
|
||||
// Simulate where output with only .cmd file
|
||||
mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.cmd\r\n')
|
||||
mockChild.emit('close', 0)
|
||||
|
||||
const result = await resultPromise
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('should prefer .exe over .cmd when both exist', async () => {
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' })
|
||||
|
||||
// Simulate where output with both .cmd and .exe
|
||||
mockChild.stdout.emit('data', 'C:\\Program Files\\nodejs\\npx.cmd\r\nC:\\Program Files\\nodejs\\npx.exe\r\n')
|
||||
mockChild.emit('close', 0)
|
||||
|
||||
const result = await resultPromise
|
||||
expect(result).toBe('C:\\Program Files\\nodejs\\npx.exe')
|
||||
})
|
||||
|
||||
it('should handle spawn errors gracefully', async () => {
|
||||
const mockChild = createMockChildProcess()
|
||||
vi.mocked(spawn).mockReturnValue(mockChild as never)
|
||||
|
||||
const resultPromise = findCommandInShellEnv('npx', { PATH: 'C:\\nodejs' })
|
||||
|
||||
// Simulate spawn error
|
||||
mockChild.emit('error', new Error('spawn failed'))
|
||||
|
||||
const result = await resultPromise
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,56 +1,28 @@
|
||||
export function buildFunctionCallToolName(serverName: string, toolName: string, serverId?: string) {
|
||||
const sanitizedServer = serverName.trim().replace(/-/g, '_')
|
||||
const sanitizedTool = toolName.trim().replace(/-/g, '_')
|
||||
/**
|
||||
* Builds a valid JavaScript function name for MCP tool calls.
|
||||
* Format: mcp__{server_name}__{tool_name}
|
||||
*
|
||||
* @param serverName - The MCP server name
|
||||
* @param toolName - The tool name from the server
|
||||
* @returns A valid JS identifier in format mcp__{server}__{tool}, max 63 chars
|
||||
*/
|
||||
export function buildFunctionCallToolName(serverName: string, toolName: string): string {
|
||||
// Sanitize to valid JS identifier chars (alphanumeric + underscore only)
|
||||
const sanitize = (str: string): string =>
|
||||
str
|
||||
.trim()
|
||||
.replace(/[^a-zA-Z0-9]/g, '_') // Replace all non-alphanumeric with underscore
|
||||
.replace(/_{2,}/g, '_') // Collapse multiple underscores
|
||||
.replace(/^_+|_+$/g, '') // Trim leading/trailing underscores
|
||||
|
||||
// Calculate suffix first to reserve space for it
|
||||
// Suffix format: "_" + 6 alphanumeric chars = 7 chars total
|
||||
let serverIdSuffix = ''
|
||||
if (serverId) {
|
||||
// Take the last 6 characters of the serverId for brevity
|
||||
serverIdSuffix = serverId.slice(-6).replace(/[^a-zA-Z0-9]/g, '')
|
||||
const server = sanitize(serverName).slice(0, 20) // Keep server name short
|
||||
const tool = sanitize(toolName).slice(0, 35) // More room for tool name
|
||||
|
||||
// Fallback: if suffix becomes empty (all non-alphanumeric chars), use a simple hash
|
||||
if (!serverIdSuffix) {
|
||||
const hash = serverId.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0)
|
||||
serverIdSuffix = hash.toString(36).slice(-6) || 'x'
|
||||
}
|
||||
}
|
||||
let name = `mcp__${server}__${tool}`
|
||||
|
||||
// Reserve space for suffix when calculating max base name length
|
||||
const SUFFIX_LENGTH = serverIdSuffix ? serverIdSuffix.length + 1 : 0 // +1 for underscore
|
||||
const MAX_BASE_LENGTH = 63 - SUFFIX_LENGTH
|
||||
|
||||
// Combine server name and tool name
|
||||
let name = sanitizedTool
|
||||
if (!sanitizedTool.includes(sanitizedServer.slice(0, 7))) {
|
||||
name = `${sanitizedServer.slice(0, 7) || ''}-${sanitizedTool || ''}`
|
||||
}
|
||||
|
||||
// Replace invalid characters with underscores or dashes
|
||||
// Keep a-z, A-Z, 0-9, underscores and dashes
|
||||
name = name.replace(/[^a-zA-Z0-9_-]/g, '_')
|
||||
|
||||
// Ensure name starts with a letter or underscore (for valid JavaScript identifier)
|
||||
if (!/^[a-zA-Z]/.test(name)) {
|
||||
name = `tool-${name}`
|
||||
}
|
||||
|
||||
// Remove consecutive underscores/dashes (optional improvement)
|
||||
name = name.replace(/[_-]{2,}/g, '_')
|
||||
|
||||
// Truncate base name BEFORE adding suffix to ensure suffix is never cut off
|
||||
if (name.length > MAX_BASE_LENGTH) {
|
||||
name = name.slice(0, MAX_BASE_LENGTH)
|
||||
}
|
||||
|
||||
// Handle edge case: ensure we still have a valid name if truncation left invalid chars at edges
|
||||
if (name.endsWith('_') || name.endsWith('-')) {
|
||||
name = name.slice(0, -1)
|
||||
}
|
||||
|
||||
// Now append the suffix - it will always fit within 63 chars
|
||||
if (serverIdSuffix) {
|
||||
name = `${name}_${serverIdSuffix}`
|
||||
// Ensure max 63 chars and clean trailing underscores
|
||||
if (name.length > 63) {
|
||||
name = name.slice(0, 63).replace(/_+$/, '')
|
||||
}
|
||||
|
||||
return name
|
||||
|
||||
@ -64,6 +64,145 @@ export async function isBinaryExists(name: string): Promise<boolean> {
|
||||
return fs.existsSync(cmd)
|
||||
}
|
||||
|
||||
// Timeout for command lookup operations (in milliseconds)
|
||||
const COMMAND_LOOKUP_TIMEOUT_MS = 5000
|
||||
|
||||
// Regex to validate command names - must start with alphanumeric or underscore, max 128 chars
|
||||
const VALID_COMMAND_NAME_REGEX = /^[a-zA-Z0-9_][a-zA-Z0-9_-]{0,127}$/
|
||||
|
||||
// Maximum output size to prevent buffer overflow (10KB)
|
||||
const MAX_OUTPUT_SIZE = 10240
|
||||
|
||||
/**
|
||||
* Check if a command is available in the user's login shell environment
|
||||
* @param command - Command name to check (e.g., 'npx', 'uvx')
|
||||
* @param loginShellEnv - The login shell environment from getLoginShellEnvironment()
|
||||
* @returns Full path to the command if found, null otherwise
|
||||
*/
|
||||
export async function findCommandInShellEnv(
|
||||
command: string,
|
||||
loginShellEnv: Record<string, string>
|
||||
): Promise<string | null> {
|
||||
// Validate command name to prevent command injection
|
||||
if (!VALID_COMMAND_NAME_REGEX.test(command)) {
|
||||
logger.warn(`Invalid command name '${command}' - must only contain alphanumeric characters, underscore, or hyphen`)
|
||||
return null
|
||||
}
|
||||
|
||||
return new Promise((resolve) => {
|
||||
let resolved = false
|
||||
|
||||
const safeResolve = (value: string | null) => {
|
||||
if (resolved) return
|
||||
resolved = true
|
||||
resolve(value)
|
||||
}
|
||||
|
||||
if (isWin) {
|
||||
// On Windows, use 'where' command
|
||||
const child = spawn('where', [command], {
|
||||
env: loginShellEnv,
|
||||
stdio: ['ignore', 'pipe', 'pipe']
|
||||
})
|
||||
|
||||
let output = ''
|
||||
const timeoutId = setTimeout(() => {
|
||||
if (resolved) return
|
||||
child.kill('SIGKILL')
|
||||
logger.debug(`Timeout checking command '${command}' on Windows`)
|
||||
safeResolve(null)
|
||||
}, COMMAND_LOOKUP_TIMEOUT_MS)
|
||||
|
||||
child.stdout.on('data', (data) => {
|
||||
if (output.length < MAX_OUTPUT_SIZE) {
|
||||
output += data.toString()
|
||||
}
|
||||
})
|
||||
|
||||
child.on('close', (code) => {
|
||||
clearTimeout(timeoutId)
|
||||
if (resolved) return
|
||||
|
||||
if (code === 0 && output.trim()) {
|
||||
const paths = output.trim().split(/\r?\n/)
|
||||
// Only accept .exe files on Windows - .cmd/.bat files cannot be executed
|
||||
// with spawn({ shell: false }) which is used by MCP SDK's StdioClientTransport
|
||||
const exePath = paths.find((p) => p.toLowerCase().endsWith('.exe'))
|
||||
if (exePath) {
|
||||
logger.debug(`Found command '${command}' at: ${exePath}`)
|
||||
safeResolve(exePath)
|
||||
} else {
|
||||
logger.debug(`Command '${command}' found but not as .exe (${paths[0]}), treating as not found`)
|
||||
safeResolve(null)
|
||||
}
|
||||
} else {
|
||||
logger.debug(`Command '${command}' not found in shell environment`)
|
||||
safeResolve(null)
|
||||
}
|
||||
})
|
||||
|
||||
child.on('error', (error) => {
|
||||
clearTimeout(timeoutId)
|
||||
if (resolved) return
|
||||
logger.warn(`Error checking command '${command}':`, { error, platform: 'windows' })
|
||||
safeResolve(null)
|
||||
})
|
||||
} else {
|
||||
// Unix/Linux/macOS: use 'command -v' which is POSIX standard
|
||||
// Use /bin/sh for reliability - it's POSIX compliant and always available
|
||||
// This avoids issues with user's custom shell (csh, fish, etc.)
|
||||
// SECURITY: Use positional parameter $1 to prevent command injection
|
||||
const child = spawn('/bin/sh', ['-c', 'command -v "$1"', '--', command], {
|
||||
env: loginShellEnv,
|
||||
stdio: ['ignore', 'pipe', 'pipe']
|
||||
})
|
||||
|
||||
let output = ''
|
||||
const timeoutId = setTimeout(() => {
|
||||
if (resolved) return
|
||||
child.kill('SIGKILL')
|
||||
logger.debug(`Timeout checking command '${command}'`)
|
||||
safeResolve(null)
|
||||
}, COMMAND_LOOKUP_TIMEOUT_MS)
|
||||
|
||||
child.stdout.on('data', (data) => {
|
||||
if (output.length < MAX_OUTPUT_SIZE) {
|
||||
output += data.toString()
|
||||
}
|
||||
})
|
||||
|
||||
child.on('close', (code) => {
|
||||
clearTimeout(timeoutId)
|
||||
if (resolved) return
|
||||
|
||||
if (code === 0 && output.trim()) {
|
||||
const commandPath = output.trim().split('\n')[0]
|
||||
|
||||
// Validate the output is an absolute path (not an alias, function, or builtin)
|
||||
// command -v can return just the command name for aliases/builtins
|
||||
if (path.isAbsolute(commandPath)) {
|
||||
logger.debug(`Found command '${command}' at: ${commandPath}`)
|
||||
safeResolve(commandPath)
|
||||
} else {
|
||||
logger.debug(`Command '${command}' resolved to non-path '${commandPath}', treating as not found`)
|
||||
safeResolve(null)
|
||||
}
|
||||
} else {
|
||||
logger.debug(`Command '${command}' not found in shell environment`)
|
||||
safeResolve(null)
|
||||
}
|
||||
})
|
||||
|
||||
child.on('error', (error) => {
|
||||
clearTimeout(timeoutId)
|
||||
if (resolved) return
|
||||
logger.warn(`Error checking command '${command}':`, { error, platform: 'unix' })
|
||||
safeResolve(null)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Find executable in common paths or PATH environment variable
|
||||
* Based on Claude Code's implementation with security checks
|
||||
|
||||
19
src/main/utils/system.ts
Normal file
19
src/main/utils/system.ts
Normal file
@ -0,0 +1,19 @@
|
||||
import os from 'node:os'
|
||||
|
||||
import { isMac, isWin } from '@main/constant'
|
||||
|
||||
export const getDeviceType = () => (isMac ? 'mac' : isWin ? 'windows' : 'linux')
|
||||
|
||||
export const getHostname = () => os.hostname()
|
||||
|
||||
export const getCpuName = () => {
|
||||
try {
|
||||
const cpus = os.cpus()
|
||||
if (!cpus || cpus.length === 0 || !cpus[0].model) {
|
||||
return 'Unknown CPU'
|
||||
}
|
||||
return cpus[0].model
|
||||
} catch {
|
||||
return 'Unknown CPU'
|
||||
}
|
||||
}
|
||||
@ -4,7 +4,15 @@ import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
|
||||
import type { SpanContext } from '@opentelemetry/api'
|
||||
import type { GitBashPathInfo, TerminalConfig, UpgradeChannel } from '@shared/config/constant'
|
||||
import type { LogLevel, LogSourceWithContext } from '@shared/config/logger'
|
||||
import type { FileChangeEvent, WebviewKeyEvent } from '@shared/config/types'
|
||||
import type {
|
||||
FileChangeEvent,
|
||||
LanClientEvent,
|
||||
LanFileCompleteMessage,
|
||||
LanHandshakeAckMessage,
|
||||
LocalTransferConnectPayload,
|
||||
LocalTransferState,
|
||||
WebviewKeyEvent
|
||||
} from '@shared/config/types'
|
||||
import type { MCPServerLogEntry } from '@shared/config/types'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import type { Notification } from '@types'
|
||||
@ -172,7 +180,11 @@ const api = {
|
||||
listS3Files: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_ListS3Files, s3Config),
|
||||
deleteS3File: (fileName: string, s3Config: S3Config) =>
|
||||
ipcRenderer.invoke(IpcChannel.Backup_DeleteS3File, fileName, s3Config),
|
||||
checkS3Connection: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_CheckS3Connection, s3Config)
|
||||
checkS3Connection: (s3Config: S3Config) => ipcRenderer.invoke(IpcChannel.Backup_CheckS3Connection, s3Config),
|
||||
createLanTransferBackup: (data: string): Promise<string> =>
|
||||
ipcRenderer.invoke(IpcChannel.Backup_CreateLanTransferBackup, data),
|
||||
deleteTempBackup: (filePath: string): Promise<boolean> =>
|
||||
ipcRenderer.invoke(IpcChannel.Backup_DeleteTempBackup, filePath)
|
||||
},
|
||||
file: {
|
||||
select: (options?: OpenDialogOptions): Promise<FileMetadata[] | null> =>
|
||||
@ -298,7 +310,8 @@ const api = {
|
||||
deleteUser: (userId: string) => ipcRenderer.invoke(IpcChannel.Memory_DeleteUser, userId),
|
||||
deleteAllMemoriesForUser: (userId: string) =>
|
||||
ipcRenderer.invoke(IpcChannel.Memory_DeleteAllMemoriesForUser, userId),
|
||||
getUsersList: () => ipcRenderer.invoke(IpcChannel.Memory_GetUsersList)
|
||||
getUsersList: () => ipcRenderer.invoke(IpcChannel.Memory_GetUsersList),
|
||||
migrateMemoryDb: () => ipcRenderer.invoke(IpcChannel.Memory_MigrateMemoryDb)
|
||||
},
|
||||
window: {
|
||||
setMinimumSize: (width: number, height: number) =>
|
||||
@ -327,6 +340,7 @@ const api = {
|
||||
ipcRenderer.invoke(IpcChannel.VertexAI_ClearAuthCache, projectId, clientEmail)
|
||||
},
|
||||
ovms: {
|
||||
isSupported: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.Ovms_IsSupported),
|
||||
addModel: (modelName: string, modelId: string, modelSource: string, task: string) =>
|
||||
ipcRenderer.invoke(IpcChannel.Ovms_AddModel, modelName, modelId, modelSource, task),
|
||||
stopAddModel: () => ipcRenderer.invoke(IpcChannel.Ovms_StopAddModel),
|
||||
@ -429,7 +443,7 @@ const api = {
|
||||
ipcRenderer.invoke(IpcChannel.Nutstore_GetDirectoryContents, token, path)
|
||||
},
|
||||
searchService: {
|
||||
openSearchWindow: (uid: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_Open, uid),
|
||||
openSearchWindow: (uid: string, show?: boolean) => ipcRenderer.invoke(IpcChannel.SearchWindow_Open, uid, show),
|
||||
closeSearchWindow: (uid: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_Close, uid),
|
||||
openUrlInSearchWindow: (uid: string, url: string) => ipcRenderer.invoke(IpcChannel.SearchWindow_OpenUrl, uid, url)
|
||||
},
|
||||
@ -589,12 +603,32 @@ const api = {
|
||||
writeContent: (options: WritePluginContentOptions): Promise<PluginResult<void>> =>
|
||||
ipcRenderer.invoke(IpcChannel.ClaudeCodePlugin_WriteContent, options)
|
||||
},
|
||||
webSocket: {
|
||||
start: () => ipcRenderer.invoke(IpcChannel.WebSocket_Start),
|
||||
stop: () => ipcRenderer.invoke(IpcChannel.WebSocket_Stop),
|
||||
status: () => ipcRenderer.invoke(IpcChannel.WebSocket_Status),
|
||||
sendFile: (filePath: string) => ipcRenderer.invoke(IpcChannel.WebSocket_SendFile, filePath),
|
||||
getAllCandidates: () => ipcRenderer.invoke(IpcChannel.WebSocket_GetAllCandidates)
|
||||
localTransfer: {
|
||||
getState: (): Promise<LocalTransferState> => ipcRenderer.invoke(IpcChannel.LocalTransfer_ListServices),
|
||||
startScan: (): Promise<LocalTransferState> => ipcRenderer.invoke(IpcChannel.LocalTransfer_StartScan),
|
||||
stopScan: (): Promise<LocalTransferState> => ipcRenderer.invoke(IpcChannel.LocalTransfer_StopScan),
|
||||
connect: (payload: LocalTransferConnectPayload): Promise<LanHandshakeAckMessage> =>
|
||||
ipcRenderer.invoke(IpcChannel.LocalTransfer_Connect, payload),
|
||||
disconnect: (): Promise<void> => ipcRenderer.invoke(IpcChannel.LocalTransfer_Disconnect),
|
||||
onServicesUpdated: (callback: (state: LocalTransferState) => void): (() => void) => {
|
||||
const channel = IpcChannel.LocalTransfer_ServicesUpdated
|
||||
const listener = (_: Electron.IpcRendererEvent, state: LocalTransferState) => callback(state)
|
||||
ipcRenderer.on(channel, listener)
|
||||
return () => {
|
||||
ipcRenderer.removeListener(channel, listener)
|
||||
}
|
||||
},
|
||||
onClientEvent: (callback: (event: LanClientEvent) => void): (() => void) => {
|
||||
const channel = IpcChannel.LocalTransfer_ClientEvent
|
||||
const listener = (_: Electron.IpcRendererEvent, event: LanClientEvent) => callback(event)
|
||||
ipcRenderer.on(channel, listener)
|
||||
return () => {
|
||||
ipcRenderer.removeListener(channel, listener)
|
||||
}
|
||||
},
|
||||
sendFile: (filePath: string): Promise<LanFileCompleteMessage> =>
|
||||
ipcRenderer.invoke(IpcChannel.LocalTransfer_SendFile, { filePath }),
|
||||
cancelTransfer: (): Promise<void> => ipcRenderer.invoke(IpcChannel.LocalTransfer_CancelTransfer)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,38 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { normalizeAzureOpenAIEndpoint } from '../openai/azureOpenAIEndpoint'
|
||||
|
||||
describe('normalizeAzureOpenAIEndpoint', () => {
|
||||
it.each([
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/openai',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/openai/',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/openai/v1',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/openai/v1/',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/OPENAI/V1',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
}
|
||||
])('strips trailing /openai from $apiHost', ({ apiHost, expectedEndpoint }) => {
|
||||
expect(normalizeAzureOpenAIEndpoint(apiHost)).toBe(expectedEndpoint)
|
||||
})
|
||||
})
|
||||
@ -46,7 +46,6 @@ import type {
|
||||
GeminiSdkRawOutput,
|
||||
GeminiSdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
import { getTrailingApiVersion, withoutTrailingApiVersion } from '@renderer/utils'
|
||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
import {
|
||||
geminiFunctionCallToMcpTool,
|
||||
@ -56,6 +55,7 @@ import {
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { defaultTimeout, MB } from '@shared/config/constant'
|
||||
import { getTrailingApiVersion, withoutTrailingApiVersion } from '@shared/utils'
|
||||
import { t } from 'i18next'
|
||||
|
||||
import type { GenericChunk } from '../../middleware/schemas'
|
||||
|
||||
@ -29,6 +29,7 @@ import { withoutTrailingSlash } from '@renderer/utils/api'
|
||||
import { isOllamaProvider } from '@renderer/utils/provider'
|
||||
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { normalizeAzureOpenAIEndpoint } from './azureOpenAIEndpoint'
|
||||
|
||||
const logger = loggerService.withContext('OpenAIBaseClient')
|
||||
|
||||
@ -213,7 +214,7 @@ export abstract class OpenAIBaseClient<
|
||||
dangerouslyAllowBrowser: true,
|
||||
apiKey: apiKeyForSdkInstance,
|
||||
apiVersion: this.provider.apiVersion,
|
||||
endpoint: this.provider.apiHost
|
||||
endpoint: normalizeAzureOpenAIEndpoint(this.provider.apiHost)
|
||||
}) as TSdkInstance
|
||||
} else {
|
||||
this.sdkInstance = new OpenAI({
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
export function normalizeAzureOpenAIEndpoint(apiHost: string): string {
|
||||
const normalizedHost = apiHost.replace(/\/+$/, '')
|
||||
return normalizedHost.replace(/\/openai(?:\/v1)?$/i, '')
|
||||
}
|
||||
@ -3,7 +3,8 @@ import { loggerService } from '@logger'
|
||||
import { isSupportedModel } from '@renderer/config/models'
|
||||
import type { Provider } from '@renderer/types'
|
||||
import { objectKeys } from '@renderer/types'
|
||||
import { formatApiHost, withoutTrailingApiVersion } from '@renderer/utils'
|
||||
import { formatApiHost } from '@renderer/utils'
|
||||
import { withoutTrailingApiVersion } from '@shared/utils'
|
||||
|
||||
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
|
||||
|
||||
|
||||
@ -66,6 +66,11 @@ export class ZhipuAPIClient extends OpenAIAPIClient {
|
||||
|
||||
public async listModels(): Promise<OpenAI.Models.Model[]> {
|
||||
const models = [
|
||||
'glm-4.7',
|
||||
'glm-4.6',
|
||||
'glm-4.6v',
|
||||
'glm-4.6v-flash',
|
||||
'glm-4.6v-flashx',
|
||||
'glm-4.5',
|
||||
'glm-4.5-x',
|
||||
'glm-4.5-air',
|
||||
|
||||
@ -21,6 +21,7 @@ import {
|
||||
isGrokModel,
|
||||
isOpenAIModel,
|
||||
isOpenRouterBuiltInWebSearchModel,
|
||||
isPureGenerateImageModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isWebSearchModel
|
||||
@ -33,7 +34,7 @@ import { type Assistant, type MCPTool, type Provider, SystemProviderIds } from '
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
|
||||
import { replacePromptVariables } from '@renderer/utils/prompt'
|
||||
import { isAIGatewayProvider, isAwsBedrockProvider } from '@renderer/utils/provider'
|
||||
import { isAIGatewayProvider, isAwsBedrockProvider, isSupportUrlContextProvider } from '@renderer/utils/provider'
|
||||
import type { ModelMessage, Tool } from 'ai'
|
||||
import { stepCountIs } from 'ai'
|
||||
|
||||
@ -118,7 +119,13 @@ export async function buildStreamTextParams(
|
||||
isOpenRouterBuiltInWebSearchModel(model) ||
|
||||
model.id.includes('sonar'))
|
||||
|
||||
const enableUrlContext = assistant.enableUrlContext || false
|
||||
// Validate provider and model support to prevent stale state from triggering urlContext
|
||||
const enableUrlContext = !!(
|
||||
assistant.enableUrlContext &&
|
||||
isSupportUrlContextProvider(provider) &&
|
||||
!isPureGenerateImageModel(model) &&
|
||||
(isGeminiModel(model) || isAnthropicModel(model))
|
||||
)
|
||||
|
||||
const enableGenerateImage = !!(isGenerateImageModel(model) && assistant.enableGenerateImage)
|
||||
|
||||
|
||||
@ -24,7 +24,8 @@ export const memorySearchTool = () => {
|
||||
}
|
||||
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
||||
|
||||
if (!memoryConfig.llmModel || !memoryConfig.embeddingModel) {
|
||||
return []
|
||||
}
|
||||
|
||||
|
||||
@ -464,7 +464,8 @@ describe('options utils', () => {
|
||||
custom_param: 'custom_value',
|
||||
another_param: 123,
|
||||
serviceTier: undefined,
|
||||
textVerbosity: undefined
|
||||
textVerbosity: undefined,
|
||||
store: false
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@ -10,6 +10,7 @@ import {
|
||||
isAnthropicModel,
|
||||
isGeminiModel,
|
||||
isGrokModel,
|
||||
isInterleavedThinkingModel,
|
||||
isOpenAIModel,
|
||||
isOpenAIOpenWeightModel,
|
||||
isQwenMTModel,
|
||||
@ -396,10 +397,12 @@ function buildOpenAIProviderOptions(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 支持配置是否在服务端持久化
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
serviceTier,
|
||||
textVerbosity
|
||||
textVerbosity,
|
||||
store: false
|
||||
}
|
||||
|
||||
return {
|
||||
@ -577,8 +580,10 @@ function buildOllamaProviderOptions(
|
||||
const reasoningEffort = assistant.settings?.reasoning_effort
|
||||
if (enableReasoning) {
|
||||
if (isOpenAIOpenWeightModel(model)) {
|
||||
// @ts-ignore upstream type error
|
||||
providerOptions.think = reasoningEffort as any
|
||||
// For gpt-oss models, Ollama accepts: 'low' | 'medium' | 'high'
|
||||
if (reasoningEffort === 'low' || reasoningEffort === 'medium' || reasoningEffort === 'high') {
|
||||
providerOptions.think = reasoningEffort
|
||||
}
|
||||
} else {
|
||||
providerOptions.think = !['none', undefined].includes(reasoningEffort)
|
||||
}
|
||||
@ -601,7 +606,7 @@ function buildGenericProviderOptions(
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableWebSearch } = capabilities
|
||||
const { enableWebSearch, enableReasoning } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
const reasoningParams = getReasoningEffort(assistant, model)
|
||||
@ -609,6 +614,14 @@ function buildGenericProviderOptions(
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
}
|
||||
if (enableReasoning) {
|
||||
if (isInterleavedThinkingModel(model)) {
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
sendReasoning: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
const webSearchParams = getWebSearchParams(model)
|
||||
|
||||
@ -14,7 +14,6 @@ import {
|
||||
isDoubaoSeedAfter251015,
|
||||
isDoubaoThinkingAutoModel,
|
||||
isGemini3ThinkingTokenModel,
|
||||
isGPT51SeriesModel,
|
||||
isGrok4FastReasoningModel,
|
||||
isOpenAIDeepResearchModel,
|
||||
isOpenAIModel,
|
||||
@ -32,7 +31,8 @@ import {
|
||||
isSupportedThinkingTokenMiMoModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isSupportedThinkingTokenQwenModel,
|
||||
isSupportedThinkingTokenZhipuModel
|
||||
isSupportedThinkingTokenZhipuModel,
|
||||
isSupportNoneReasoningEffortModel
|
||||
} from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
@ -74,9 +74,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
||||
if (reasoningEffort === 'none') {
|
||||
// openrouter: use reasoning
|
||||
if (model.provider === SystemProviderIds.openrouter) {
|
||||
// 'none' is not an available value for effort for now.
|
||||
// I think they should resolve this issue soon, so I'll just go ahead and use this value.
|
||||
if (isGPT51SeriesModel(model) && reasoningEffort === 'none') {
|
||||
if (isSupportNoneReasoningEffortModel(model) && reasoningEffort === 'none') {
|
||||
return { reasoning: { effort: 'none' } }
|
||||
}
|
||||
return { reasoning: { enabled: false, exclude: true } }
|
||||
@ -120,8 +118,8 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
||||
return { thinking: { type: 'disabled' } }
|
||||
}
|
||||
|
||||
// Specially for GPT-5.1. Suppose this is a OpenAI Compatible provider
|
||||
if (isGPT51SeriesModel(model)) {
|
||||
// GPT 5.1, GPT 5.2, or newer
|
||||
if (isSupportNoneReasoningEffortModel(model)) {
|
||||
return {
|
||||
reasoningEffort: 'none'
|
||||
}
|
||||
|
||||
BIN
src/renderer/src/assets/images/apps/aistudio.png
Normal file
BIN
src/renderer/src/assets/images/apps/aistudio.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.2 KiB |
@ -1,27 +0,0 @@
|
||||
<svg width="256" height="256" viewBox="0 0 256 256" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="256" height="256" rx="32" fill="#0057CE"/>
|
||||
<mask id="path-2-inside-1_4113_89308" fill="white">
|
||||
<path d="M169.6 131.626C173.075 129.641 176.32 128.241 180.1 126.943C183.74 125.695 187.444 124.664 191.186 123.735C194.915 122.806 198.682 122.017 202.449 121.228C206.216 120.439 209.958 119.675 213.598 118.314C231.429 111.619 242.221 93.6357 239.612 74.9396C237.003 56.2435 221.692 41.8237 202.691 40.1564C194.062 39.4055 185.726 41.4164 178.013 44.9418C170.326 48.4545 163.288 53.4435 157.166 59.158C144.795 70.676 135.657 85.4649 130.083 101.208C124.47 117.054 122.37 134.095 123.694 150.806C124.356 159.129 125.883 167.504 128.326 175.509C130.719 183.362 134.181 191.469 138.839 198.342C136.828 185.475 138.559 172.175 143.917 160.262C149.262 148.375 158.121 138.193 169.6 131.626Z"/>
|
||||
</mask>
|
||||
<path d="M169.6 131.626C173.075 129.641 176.32 128.241 180.1 126.943C183.74 125.695 187.444 124.664 191.186 123.735C194.915 122.806 198.682 122.017 202.449 121.228C206.216 120.439 209.958 119.675 213.598 118.314C231.429 111.619 242.221 93.6357 239.612 74.9396C237.003 56.2435 221.692 41.8237 202.691 40.1564C194.062 39.4055 185.726 41.4164 178.013 44.9418C170.326 48.4545 163.288 53.4435 157.166 59.158C144.795 70.676 135.657 85.4649 130.083 101.208C124.47 117.054 122.37 134.095 123.694 150.806C124.356 159.129 125.883 167.504 128.326 175.509C130.719 183.362 134.181 191.469 138.839 198.342C136.828 185.475 138.559 172.175 143.917 160.262C149.262 148.375 158.121 138.193 169.6 131.626Z" fill="white" stroke="white" stroke-width="32" mask="url(#path-2-inside-1_4113_89308)"/>
|
||||
<path d="M162.246 150.4C161.915 153.913 163.073 157.464 165.542 160.06C168.011 162.657 171.499 164.031 174.668 165.253C178.13 166.577 181.12 167.658 184.353 169.529C187.433 171.311 190.157 173.526 192.435 176.262C201.802 187.449 200.937 203.867 190.462 214.049C179.988 224.23 163.379 224.778 152.243 215.321C149.404 212.903 146.884 209.798 144.81 206.756C141.654 186.52 147.775 165.317 162.246 150.4Z" fill="white"/>
|
||||
<mask id="path-4-outside-2_4113_89308" maskUnits="userSpaceOnUse" x="136" y="138.4" width="71" height="92" fill="black">
|
||||
<rect fill="white" x="136" y="138.4" width="71" height="92"/>
|
||||
<path d="M162.246 150.4C165.542 153.666 163.073 157.464 165.542 160.06C168.011 162.657 171.499 164.031 174.668 165.253C178.13 166.577 181.12 167.658 184.353 169.529C187.433 171.311 190.157 173.526 192.435 176.262C201.802 187.449 200.937 203.867 190.462 214.049C179.988 224.23 163.379 224.778 152.243 215.321C149.404 212.903 146.884 209.798 144.81 206.756C141.654 186.52 147.775 165.317 162.246 150.4Z"/>
|
||||
</mask>
|
||||
<path d="M162.246 150.4C165.542 153.666 163.073 157.464 165.542 160.06C168.011 162.657 171.499 164.031 174.668 165.253C178.13 166.577 181.12 167.658 184.353 169.529C187.433 171.311 190.157 173.526 192.435 176.262C201.802 187.449 200.937 203.867 190.462 214.049C179.988 224.23 163.379 224.778 152.243 215.321C149.404 212.903 146.884 209.798 144.81 206.756C141.654 186.52 147.775 165.317 162.246 150.4Z" stroke="#0057CE" stroke-width="16" mask="url(#path-4-outside-2_4113_89308)"/>
|
||||
<mask id="path-5-inside-3_4113_89308" fill="white">
|
||||
<path d="M50.4113 61.9063C63.3547 61.8935 75.9164 69.008 85.0163 76.9879C94.6761 85.4641 102.16 96.2567 107.085 107.991C112.036 119.789 114.416 132.542 114.327 145.282C114.238 157.665 111.769 171.079 106.296 182.394C105.774 167.821 100.123 153.885 90.3107 143.003C88.5926 141.107 86.7981 139.389 84.6599 137.938C82.5218 136.487 80.2691 135.418 77.8382 134.565C73.1164 132.911 67.7838 132.134 62.8711 131.6C57.8057 131.04 52.7149 130.709 47.6622 129.971C42.4695 129.207 37.8114 128.087 33.1787 125.427C19.688 117.715 13.1463 102.009 17.1808 87.1441C21.2153 72.2661 34.846 61.919 50.4113 61.9063Z"/>
|
||||
</mask>
|
||||
<path d="M50.4113 61.9063C63.3547 61.8935 75.9164 69.008 85.0163 76.9879C94.6761 85.4641 102.16 96.2567 107.085 107.991C112.036 119.789 114.416 132.542 114.327 145.282C114.238 157.665 111.769 171.079 106.296 182.394C105.774 167.821 100.123 153.885 90.3107 143.003C88.5926 141.107 86.7981 139.389 84.6599 137.938C82.5218 136.487 80.2691 135.418 77.8382 134.565C73.1164 132.911 67.7838 132.134 62.8711 131.6C57.8057 131.04 52.7149 130.709 47.6622 129.971C42.4695 129.207 37.8114 128.087 33.1787 125.427C19.688 117.715 13.1463 102.009 17.1808 87.1441C21.2153 72.2661 34.846 61.919 50.4113 61.9063Z" fill="white" stroke="white" stroke-width="32" mask="url(#path-5-inside-3_4113_89308)"/>
|
||||
<mask id="path-6-inside-4_4113_89308" fill="white">
|
||||
<path d="M82.5802 149.38C81.3584 148.03 80.0857 146.745 78.673 145.6C80.4294 148.578 80.6075 151.95 79.8694 155.196C79.1312 158.429 77.5021 161.419 75.4403 163.99C73.3149 166.625 70.8204 168.725 68.1095 170.71C65.7423 172.441 62.2932 174.656 60.1551 176.73C53.8679 182.839 52.5824 192.384 57.0369 199.893C61.4914 207.415 70.5277 210.979 78.9912 208.535C83.662 207.186 87.6202 204.144 90.7638 200.67C93.9455 197.157 96.5291 192.983 98.5655 188.757C98.0437 174.185 92.3928 160.261 82.5802 149.38Z"/>
|
||||
</mask>
|
||||
<path d="M82.5802 149.38C81.3584 148.03 80.0857 146.745 78.673 145.6C80.4294 148.578 80.6075 151.95 79.8694 155.196C79.1312 158.429 77.5021 161.419 75.4403 163.99C73.3149 166.625 70.8204 168.725 68.1095 170.71C65.7423 172.441 62.2932 174.656 60.1551 176.73C53.8679 182.839 52.5824 192.384 57.0369 199.893C61.4914 207.415 70.5277 210.979 78.9912 208.535C83.662 207.186 87.6202 204.144 90.7638 200.67C93.9455 197.157 96.5291 192.983 98.5655 188.757C98.0437 174.185 92.3928 160.261 82.5802 149.38Z" stroke="white" stroke-width="24" mask="url(#path-6-inside-4_4113_89308)"/>
|
||||
<mask id="path-7-outside-5_4113_89308" maskUnits="userSpaceOnUse" x="45.3994" y="138.6" width="62" height="79" fill="black">
|
||||
<rect fill="white" x="45.3994" y="138.6" width="62" height="79"/>
|
||||
<path d="M82.5802 149.38C81.3584 148.03 80.0857 146.745 78.673 145.6C80.4294 148.578 80.6075 151.95 79.8694 155.196C79.1312 158.429 77.5021 161.419 75.4403 163.99C73.3149 166.625 70.8204 168.725 68.1095 170.71C65.7423 172.441 62.2932 174.656 60.1551 176.73C53.8679 182.839 52.5824 192.384 57.0369 199.893C61.4914 207.415 70.5277 210.979 78.9912 208.535C83.662 207.186 87.6202 204.144 90.7638 200.67C93.9455 197.157 96.5291 192.983 98.5655 188.757C98.0437 174.185 92.3928 160.261 82.5802 149.38Z"/>
|
||||
</mask>
|
||||
<path d="M82.5802 149.38C81.3584 148.03 80.0857 146.745 78.673 145.6C80.4294 148.578 80.6075 151.95 79.8694 155.196C79.1312 158.429 77.5021 161.419 75.4403 163.99C73.3149 166.625 70.8204 168.725 68.1095 170.71C65.7423 172.441 62.2932 174.656 60.1551 176.73C53.8679 182.839 52.5824 192.384 57.0369 199.893C61.4914 207.415 70.5277 210.979 78.9912 208.535C83.662 207.186 87.6202 204.144 90.7638 200.67C93.9455 197.157 96.5291 192.983 98.5655 188.757C98.0437 174.185 92.3928 160.261 82.5802 149.38Z" fill="white"/>
|
||||
<path d="M82.5802 149.38C81.3584 148.03 80.0857 146.745 78.673 145.6C80.4294 148.578 80.6075 151.95 79.8694 155.196C79.1312 158.429 77.5021 161.419 75.4403 163.99C73.3149 166.625 70.8204 168.725 68.1095 170.71C65.7423 172.441 62.2932 174.656 60.1551 176.73C53.8679 182.839 52.5824 192.384 57.0369 199.893C61.4914 207.415 70.5277 210.979 78.9912 208.535C83.662 207.186 87.6202 204.144 90.7638 200.67C93.9455 197.157 96.5291 192.983 98.5655 188.757C98.0437 174.185 92.3928 160.261 82.5802 149.38Z" stroke="#0057CE" stroke-width="16" mask="url(#path-7-outside-5_4113_89308)"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 7.3 KiB |
1
src/renderer/src/assets/images/search/baidu.svg
Normal file
1
src/renderer/src/assets/images/search/baidu.svg
Normal file
@ -0,0 +1 @@
|
||||
<svg height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Baidu</title><path d="M8.859 11.735c1.017-1.71 4.059-3.083 6.202.286 1.579 2.284 4.284 4.397 4.284 4.397s2.027 1.601.73 4.684c-1.24 2.956-5.64 1.607-6.005 1.49l-.024-.009s-1.746-.568-3.776-.112c-2.026.458-3.773.286-3.773.286l-.045-.001c-.328-.01-2.38-.187-3.001-2.968-.675-3.028 2.365-4.687 2.592-4.968.226-.288 1.802-1.37 2.816-3.085zm.986 1.738v2.032h-1.64s-1.64.138-2.213 2.014c-.2 1.252.177 1.99.242 2.148.067.157.596 1.073 1.927 1.342h3.078v-7.514l-1.394-.022zm3.588 2.191l-1.44.024v3.956s.064.985 1.44 1.344h3.541v-5.3h-1.528v3.979h-1.46s-.466-.068-.553-.447v-3.556zM9.82 16.715v3.06H8.58s-.863-.045-1.126-1.049c-.136-.445.02-.959.088-1.16.063-.203.353-.671.951-.85H9.82zm9.525-9.036c2.086 0 2.646 2.06 2.646 2.742 0 .688.284 3.597-2.309 3.655-2.595.057-2.704-1.77-2.704-3.08 0-1.374.277-3.317 2.367-3.317zM4.24 6.08c1.523-.135 2.645 1.55 2.762 2.513.07.625.393 3.486-1.975 4-2.364.515-3.244-2.249-2.984-3.544 0 0 .28-2.797 2.197-2.969zm8.847-1.483c.14-1.31 1.69-3.316 2.931-3.028 1.236.285 2.367 1.944 2.137 3.37-.224 1.428-1.345 3.313-3.095 3.082-1.748-.226-2.143-1.823-1.973-3.424zM9.425 1c1.307 0 2.364 1.519 2.364 3.398 0 1.879-1.057 3.4-2.364 3.4s-2.367-1.521-2.367-3.4C7.058 2.518 8.118 1 9.425 1z" fill="#2932E1" fill-rule="nonzero"></path></svg>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
1
src/renderer/src/assets/images/search/bing.svg
Normal file
1
src/renderer/src/assets/images/search/bing.svg
Normal file
@ -0,0 +1 @@
|
||||
<svg height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Bing</title><path d="M11.97 7.569a.92.92 0 00-.805.863c-.013.195-.01.209.43 1.347 1 2.59 1.242 3.214 1.283 3.302.099.213.237.413.41.592.134.138.222.212.37.311.26.176.39.224 1.405.527.989.295 1.529.49 1.994.723.603.302 1.024.644 1.29 1.051.191.292.36.815.434 1.342.029.206.029.661 0 .847a2.491 2.491 0 01-.376 1.026c-.1.151-.065.126.081-.058.415-.52.838-1.408 1.054-2.213a6.728 6.728 0 00.102-3.012 6.626 6.626 0 00-3.291-4.53 104.157 104.157 0 00-1.322-.698l-.254-.133a737.941 737.941 0 01-1.575-.827c-.548-.29-.78-.406-.846-.426a1.376 1.376 0 00-.29-.045l-.093.01z" fill="url(#lobe-icons-bing-fill-0)"></path><path d="M13.164 17.24a4.385 4.385 0 00-.202.125 511.45 511.45 0 00-1.795 1.115 163.087 163.087 0 01-.989.614l-.463.288a99.198 99.198 0 01-1.502.941c-.326.2-.704.334-1.09.387-.18.024-.52.024-.7 0a2.807 2.807 0 01-1.318-.538 3.665 3.665 0 01-.543-.545 2.837 2.837 0 01-.506-1.141 2.161 2.161 0 00-.041-.182c-.008-.008.006.138.032.33.027.199.085.487.147.733.482 1.907 1.85 3.457 3.705 4.195a6.31 6.31 0 001.658.412c.22.025.844.035 1.074.017 1.054-.08 1.972-.393 2.913-.992a325.28 325.28 0 01.937-.596l.384-.244.684-.435.234-.149.009-.005.025-.017.013-.007.172-.11.597-.38c.76-.481.987-.65 1.34-.998.148-.146.37-.394.381-.425.002-.007.042-.068.088-.136a2.49 2.49 0 00.373-1.023 4.181 4.181 0 000-.847 4.336 4.336 0 00-.318-1.137c-.224-.472-.7-.9-1.383-1.245a2.972 2.972 0 00-.406-.181c-.01 0-.646.392-1.413.87a7089.171 7089.171 0 00-1.658 1.031l-.439.274z" fill="url(#lobe-icons-bing-fill-1)" fill-rule="nonzero"></path><path d="M4.003 14.946l.004 3.33.042.193c.134.604.366 1.04.77 1.445a2.701 2.701 0 001.955.814c.536 0 1-.135 1.479-.43l.703-.435.556-.346V8.003c0-2.306-.004-3.675-.012-3.782a2.734 2.734 0 00-.797-1.765c-.145-.144-.268-.24-.637-.496A1780.102 1780.102 0 015.762.362C5.406.115 5.38.098 5.271.059a.943.943 0 00-1.254.696C4.003.818 4 1.659 4 6.223v5.394H4l.003 3.329z" fill="url(#lobe-icons-bing-fill-2)" fill-rule="nonzero"></path><defs><radialGradient cx="93.717%" cy="77.818%" fx="93.717%" fy="77.818%" gradientTransform="scale(-1 -.7146) rotate(49.288 2.035 -2.198)" id="lobe-icons-bing-fill-0" r="143.691%"><stop offset="0%" stop-color="#00CACC"></stop><stop offset="100%" stop-color="#048FCE"></stop></radialGradient><radialGradient cx="13.893%" cy="71.448%" fx="13.893%" fy="71.448%" gradientTransform="scale(.6042 1) rotate(-23.34 .184 .494)" id="lobe-icons-bing-fill-1" r="149.21%"><stop offset="0%" stop-color="#00BBEC"></stop><stop offset="100%" stop-color="#2756A9"></stop></radialGradient><linearGradient id="lobe-icons-bing-fill-2" x1="50%" x2="50%" y1="0%" y2="100%"><stop offset="0%" stop-color="#00BBEC"></stop><stop offset="100%" stop-color="#2756A9"></stop></linearGradient></defs></svg>
|
||||
|
After Width: | Height: | Size: 2.8 KiB |
1
src/renderer/src/assets/images/search/google.svg
Normal file
1
src/renderer/src/assets/images/search/google.svg
Normal file
@ -0,0 +1 @@
|
||||
<svg height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Google</title><path d="M23 12.245c0-.905-.075-1.565-.236-2.25h-10.54v4.083h6.186c-.124 1.014-.797 2.542-2.294 3.569l-.021.136 3.332 2.53.23.022C21.779 18.417 23 15.593 23 12.245z" fill="#4285F4"></path><path d="M12.225 23c3.03 0 5.574-.978 7.433-2.665l-3.542-2.688c-.948.648-2.22 1.1-3.891 1.1a6.745 6.745 0 01-6.386-4.572l-.132.011-3.465 2.628-.045.124C4.043 20.531 7.835 23 12.225 23z" fill="#34A853"></path><path d="M5.84 14.175A6.65 6.65 0 015.463 12c0-.758.138-1.491.361-2.175l-.006-.147-3.508-2.67-.115.054A10.831 10.831 0 001 12c0 1.772.436 3.447 1.197 4.938l3.642-2.763z" fill="#FBBC05"></path><path d="M12.225 5.253c2.108 0 3.529.892 4.34 1.638l3.167-3.031C17.787 2.088 15.255 1 12.225 1 7.834 1 4.043 3.469 2.197 7.062l3.63 2.763a6.77 6.77 0 016.398-4.572z" fill="#EB4335"></path></svg>
|
||||
|
After Width: | Height: | Size: 920 B |
@ -263,6 +263,23 @@ export function ZhipuLogo(props: SVGProps<SVGSVGElement>) {
|
||||
</svg>
|
||||
)
|
||||
}
|
||||
export function McpLogo(props: SVGProps<SVGSVGElement>) {
|
||||
return (
|
||||
<svg
|
||||
fill="currentColor"
|
||||
fillRule="evenodd"
|
||||
height="1em"
|
||||
width="1em"
|
||||
viewBox="0 0 24 24"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...props}>
|
||||
<title>ModelContextProtocol</title>
|
||||
<path d="M15.688 2.343a2.588 2.588 0 00-3.61 0l-9.626 9.44a.863.863 0 01-1.203 0 .823.823 0 010-1.18l9.626-9.44a4.313 4.313 0 016.016 0 4.116 4.116 0 011.204 3.54 4.3 4.3 0 013.609 1.18l.05.05a4.115 4.115 0 010 5.9l-8.706 8.537a.274.274 0 000 .393l1.788 1.754a.823.823 0 010 1.18.863.863 0 01-1.203 0l-1.788-1.753a1.92 1.92 0 010-2.754l8.706-8.538a2.47 2.47 0 000-3.54l-.05-.049a2.588 2.588 0 00-3.607-.003l-7.172 7.034-.002.002-.098.097a.863.863 0 01-1.204 0 .823.823 0 010-1.18l7.273-7.133a2.47 2.47 0 00-.003-3.537z"></path>
|
||||
<path d="M14.485 4.703a.823.823 0 000-1.18.863.863 0 00-1.204 0l-7.119 6.982a4.115 4.115 0 000 5.9 4.314 4.314 0 006.016 0l7.12-6.982a.823.823 0 000-1.18.863.863 0 00-1.204 0l-7.119 6.982a2.588 2.588 0 01-3.61 0 2.47 2.47 0 010-3.54l7.12-6.982z"></path>
|
||||
</svg>
|
||||
)
|
||||
}
|
||||
|
||||
export function PoeLogo(props: SVGProps<SVGSVGElement>) {
|
||||
return (
|
||||
<svg
|
||||
|
||||
@ -1,553 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { AppLogo } from '@renderer/config/env'
|
||||
import { SettingHelpText, SettingRow } from '@renderer/pages/settings'
|
||||
import type { WebSocketCandidatesResponse } from '@shared/config/types'
|
||||
import { Alert, Button, Modal, Progress, Spin } from 'antd'
|
||||
import { QRCodeSVG } from 'qrcode.react'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { TopView } from '../TopView'
|
||||
|
||||
const logger = loggerService.withContext('ExportToPhoneLanPopup')
|
||||
|
||||
interface Props {
|
||||
resolve: (data: any) => void
|
||||
}
|
||||
|
||||
type ConnectionPhase = 'initializing' | 'waiting_qr_scan' | 'connecting' | 'connected' | 'disconnected' | 'error'
|
||||
type TransferPhase = 'idle' | 'preparing' | 'sending' | 'completed' | 'error'
|
||||
|
||||
const LoadingQRCode: React.FC = () => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', alignItems: 'center', gap: '12px' }}>
|
||||
<Spin />
|
||||
<span style={{ fontSize: '14px', color: 'var(--color-text-2)' }}>
|
||||
{t('settings.data.export_to_phone.lan.generating_qr')}
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const ScanQRCode: React.FC<{ qrCodeValue: string }> = ({ qrCodeValue }) => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', alignItems: 'center', gap: '12px' }}>
|
||||
<QRCodeSVG
|
||||
marginSize={2}
|
||||
value={qrCodeValue}
|
||||
level="H"
|
||||
size={200}
|
||||
imageSettings={{
|
||||
src: AppLogo,
|
||||
width: 40,
|
||||
height: 40,
|
||||
excavate: true
|
||||
}}
|
||||
/>
|
||||
<span style={{ fontSize: '12px', color: 'var(--color-text-2)' }}>
|
||||
{t('settings.data.export_to_phone.lan.scan_qr')}
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const ConnectingAnimation: React.FC = () => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', alignItems: 'center', gap: '12px' }}>
|
||||
<div
|
||||
style={{
|
||||
width: '160px',
|
||||
height: '160px',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
border: '2px dashed var(--color-status-warning)',
|
||||
borderRadius: '12px',
|
||||
backgroundColor: 'var(--color-status-warning)'
|
||||
}}>
|
||||
<Spin size="large" />
|
||||
<span style={{ fontSize: '14px', color: 'var(--color-text)', marginTop: '12px' }}>
|
||||
{t('settings.data.export_to_phone.lan.status.connecting')}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const ConnectedDisplay: React.FC = () => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', alignItems: 'center', gap: '12px' }}>
|
||||
<div
|
||||
style={{
|
||||
width: '160px',
|
||||
height: '160px',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
border: '2px dashed var(--color-status-success)',
|
||||
borderRadius: '12px',
|
||||
backgroundColor: 'var(--color-status-success)'
|
||||
}}>
|
||||
<span style={{ fontSize: '48px' }}>📱</span>
|
||||
<span style={{ fontSize: '14px', color: 'var(--color-text)', marginTop: '8px' }}>
|
||||
{t('settings.data.export_to_phone.lan.connected')}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const ErrorQRCode: React.FC<{ error: string | null }> = ({ error }) => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
gap: '12px',
|
||||
padding: '20px',
|
||||
border: `1px solid var(--color-error)`,
|
||||
borderRadius: '8px',
|
||||
backgroundColor: 'var(--color-error)'
|
||||
}}>
|
||||
<span style={{ fontSize: '48px' }}>⚠️</span>
|
||||
<span style={{ fontSize: '14px', color: 'var(--color-text)' }}>
|
||||
{t('settings.data.export_to_phone.lan.connection_failed')}
|
||||
</span>
|
||||
{error && <span style={{ fontSize: '12px', color: 'var(--color-text-2)' }}>{error}</span>}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const PopupContainer: React.FC<Props> = ({ resolve }) => {
|
||||
const [isOpen, setIsOpen] = useState(true)
|
||||
const [connectionPhase, setConnectionPhase] = useState<ConnectionPhase>('initializing')
|
||||
const [transferPhase, setTransferPhase] = useState<TransferPhase>('idle')
|
||||
const [qrCodeValue, setQrCodeValue] = useState('')
|
||||
const [selectedFolderPath, setSelectedFolderPath] = useState<string | null>(null)
|
||||
const [sendProgress, setSendProgress] = useState(0)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const [autoCloseCountdown, setAutoCloseCountdown] = useState<number | null>(null)
|
||||
|
||||
const { t } = useTranslation()
|
||||
|
||||
// 派生状态
|
||||
const isConnected = connectionPhase === 'connected'
|
||||
const canSend = isConnected && selectedFolderPath && transferPhase === 'idle'
|
||||
const isSending = transferPhase === 'preparing' || transferPhase === 'sending'
|
||||
|
||||
// 状态文本映射
|
||||
const connectionStatusText = useMemo(() => {
|
||||
const statusMap = {
|
||||
initializing: t('settings.data.export_to_phone.lan.status.initializing'),
|
||||
waiting_qr_scan: t('settings.data.export_to_phone.lan.status.waiting_qr_scan'),
|
||||
connecting: t('settings.data.export_to_phone.lan.status.connecting'),
|
||||
connected: t('settings.data.export_to_phone.lan.status.connected'),
|
||||
disconnected: t('settings.data.export_to_phone.lan.status.disconnected'),
|
||||
error: t('settings.data.export_to_phone.lan.status.error')
|
||||
}
|
||||
return statusMap[connectionPhase]
|
||||
}, [connectionPhase, t])
|
||||
|
||||
const transferStatusText = useMemo(() => {
|
||||
const statusMap = {
|
||||
idle: '',
|
||||
preparing: t('settings.data.export_to_phone.lan.status.preparing'),
|
||||
sending: t('settings.data.export_to_phone.lan.status.sending'),
|
||||
completed: t('settings.data.export_to_phone.lan.status.completed'),
|
||||
error: t('settings.data.export_to_phone.lan.status.error')
|
||||
}
|
||||
return statusMap[transferPhase]
|
||||
}, [transferPhase, t])
|
||||
|
||||
// 状态样式映射
|
||||
const connectionStatusStyles = useMemo(() => {
|
||||
const styleMap = {
|
||||
initializing: {
|
||||
bg: 'var(--color-background-mute)',
|
||||
border: 'var(--color-border-mute)'
|
||||
},
|
||||
waiting_qr_scan: {
|
||||
bg: 'var(--color-primary-mute)',
|
||||
border: 'var(--color-primary-soft)'
|
||||
},
|
||||
connecting: { bg: 'var(--color-status-warning)', border: 'var(--color-status-warning)' },
|
||||
connected: {
|
||||
bg: 'var(--color-status-success)',
|
||||
border: 'var(--color-status-success)'
|
||||
},
|
||||
disconnected: { bg: 'var(--color-error)', border: 'var(--color-error)' },
|
||||
error: { bg: 'var(--color-error)', border: 'var(--color-error)' }
|
||||
}
|
||||
return styleMap[connectionPhase]
|
||||
}, [connectionPhase])
|
||||
|
||||
const initWebSocket = useCallback(async () => {
|
||||
try {
|
||||
setConnectionPhase('initializing')
|
||||
await window.api.webSocket.start()
|
||||
const { port, ip } = await window.api.webSocket.status()
|
||||
|
||||
if (ip && port) {
|
||||
const candidatesData = await window.api.webSocket.getAllCandidates()
|
||||
|
||||
const optimizeConnectionInfo = () => {
|
||||
const ipToNumber = (ip: string) => {
|
||||
return ip.split('.').reduce((acc, octet) => (acc << 8) + parseInt(octet), 0)
|
||||
}
|
||||
|
||||
const compressedData = [
|
||||
'CSA',
|
||||
ipToNumber(ip),
|
||||
candidatesData.map((candidate: WebSocketCandidatesResponse) => ipToNumber(candidate.host)),
|
||||
port, // 端口号
|
||||
Date.now() % 86400000
|
||||
]
|
||||
|
||||
return compressedData
|
||||
}
|
||||
|
||||
const compressedData = optimizeConnectionInfo()
|
||||
const qrCodeValue = JSON.stringify(compressedData)
|
||||
setQrCodeValue(qrCodeValue)
|
||||
setConnectionPhase('waiting_qr_scan')
|
||||
} else {
|
||||
setError(t('settings.data.export_to_phone.lan.error.no_ip'))
|
||||
setConnectionPhase('error')
|
||||
}
|
||||
} catch (error) {
|
||||
setError(
|
||||
`${t('settings.data.export_to_phone.lan.error.init_failed')}: ${error instanceof Error ? error.message : ''}`
|
||||
)
|
||||
setConnectionPhase('error')
|
||||
logger.error('Failed to initialize WebSocket:', error as Error)
|
||||
}
|
||||
}, [t])
|
||||
|
||||
const handleClientConnected = useCallback((_event: any, data: { connected: boolean }) => {
|
||||
logger.info(`Client connection status: ${data.connected ? 'connected' : 'disconnected'}`)
|
||||
if (data.connected) {
|
||||
setConnectionPhase('connected')
|
||||
setError(null)
|
||||
} else {
|
||||
setConnectionPhase('disconnected')
|
||||
}
|
||||
}, [])
|
||||
|
||||
const handleMessageReceived = useCallback((_event: any, data: any) => {
|
||||
logger.info(`Received message from mobile: ${JSON.stringify(data)}`)
|
||||
}, [])
|
||||
|
||||
const handleSendProgress = useCallback(
|
||||
(_event: any, data: { progress: number }) => {
|
||||
const progress = data.progress
|
||||
setSendProgress(progress)
|
||||
|
||||
if (transferPhase === 'preparing' && progress > 0) {
|
||||
setTransferPhase('sending')
|
||||
}
|
||||
|
||||
if (progress >= 100) {
|
||||
setTransferPhase('completed')
|
||||
// 启动 3 秒倒计时自动关闭
|
||||
setAutoCloseCountdown(3)
|
||||
}
|
||||
},
|
||||
[transferPhase]
|
||||
)
|
||||
|
||||
const handleSelectZip = useCallback(async () => {
|
||||
const result = await window.api.file.select()
|
||||
if (result) {
|
||||
setSelectedFolderPath(result[0].path)
|
||||
}
|
||||
}, [])
|
||||
|
||||
const handleSendZip = useCallback(async () => {
|
||||
if (!selectedFolderPath) {
|
||||
setError(t('settings.data.export_to_phone.lan.error.no_file'))
|
||||
return
|
||||
}
|
||||
|
||||
setTransferPhase('preparing')
|
||||
setError(null)
|
||||
setSendProgress(0)
|
||||
|
||||
try {
|
||||
logger.info(`Starting file transfer: ${selectedFolderPath}`)
|
||||
await window.api.webSocket.sendFile(selectedFolderPath)
|
||||
} catch (error) {
|
||||
setError(
|
||||
`${t('settings.data.export_to_phone.lan.error.send_failed')}: ${error instanceof Error ? error.message : ''}`
|
||||
)
|
||||
setTransferPhase('error')
|
||||
logger.error('Failed to send file:', error as Error)
|
||||
}
|
||||
}, [selectedFolderPath, t])
|
||||
|
||||
// 尝试关闭弹窗 - 如果正在传输则显示确认
|
||||
const handleCancel = useCallback(() => {
|
||||
if (isSending) {
|
||||
window.modal.confirm({
|
||||
title: t('settings.data.export_to_phone.lan.confirm_close_title'),
|
||||
content: t('settings.data.export_to_phone.lan.confirm_close_message'),
|
||||
centered: true,
|
||||
okButtonProps: {
|
||||
danger: true
|
||||
},
|
||||
okText: t('settings.data.export_to_phone.lan.force_close'),
|
||||
onOk: () => setIsOpen(false)
|
||||
})
|
||||
} else {
|
||||
setIsOpen(false)
|
||||
}
|
||||
}, [isSending, t])
|
||||
|
||||
// 清理并关闭
|
||||
const handleClose = useCallback(async () => {
|
||||
try {
|
||||
// 主动断开 WebSocket 连接
|
||||
if (isConnected || connectionPhase !== 'disconnected') {
|
||||
logger.info('Closing popup, stopping WebSocket')
|
||||
await window.api.webSocket.stop()
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to stop WebSocket on close:', error as Error)
|
||||
}
|
||||
resolve({})
|
||||
}, [resolve, isConnected, connectionPhase])
|
||||
|
||||
useEffect(() => {
|
||||
initWebSocket()
|
||||
|
||||
const removeClientConnectedListener = window.electron.ipcRenderer.on(
|
||||
'websocket-client-connected',
|
||||
handleClientConnected
|
||||
)
|
||||
const removeMessageReceivedListener = window.electron.ipcRenderer.on(
|
||||
'websocket-message-received',
|
||||
handleMessageReceived
|
||||
)
|
||||
const removeSendProgressListener = window.electron.ipcRenderer.on('file-send-progress', handleSendProgress)
|
||||
|
||||
return () => {
|
||||
removeClientConnectedListener()
|
||||
removeMessageReceivedListener()
|
||||
removeSendProgressListener()
|
||||
window.api.webSocket.stop()
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
// 自动关闭倒计时
|
||||
useEffect(() => {
|
||||
if (autoCloseCountdown === null) return
|
||||
|
||||
if (autoCloseCountdown <= 0) {
|
||||
logger.debug('Auto-closing popup after transfer completion')
|
||||
setIsOpen(false)
|
||||
return
|
||||
}
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
setAutoCloseCountdown(autoCloseCountdown - 1)
|
||||
}, 1000)
|
||||
|
||||
return () => clearTimeout(timer)
|
||||
}, [autoCloseCountdown])
|
||||
|
||||
// 状态指示器组件
|
||||
const StatusIndicator = useCallback(
|
||||
() => (
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
gap: '8px',
|
||||
padding: '5px 12px',
|
||||
width: '100%',
|
||||
backgroundColor: connectionStatusStyles.bg,
|
||||
border: `1px solid ${connectionStatusStyles.border}`,
|
||||
marginBottom: 10
|
||||
}}>
|
||||
<span style={{ fontSize: '14px', fontWeight: '500', color: 'var(--color-text)' }}>{connectionStatusText}</span>
|
||||
</div>
|
||||
),
|
||||
[connectionStatusStyles, connectionStatusText]
|
||||
)
|
||||
|
||||
// 二维码显示组件 - 使用显式条件渲染以避免类型不匹配
|
||||
const QRCodeDisplay = useCallback(() => {
|
||||
switch (connectionPhase) {
|
||||
case 'waiting_qr_scan':
|
||||
case 'disconnected':
|
||||
return <ScanQRCode qrCodeValue={qrCodeValue} />
|
||||
case 'initializing':
|
||||
return <LoadingQRCode />
|
||||
case 'connecting':
|
||||
return <ConnectingAnimation />
|
||||
case 'connected':
|
||||
return <ConnectedDisplay />
|
||||
case 'error':
|
||||
return <ErrorQRCode error={error} />
|
||||
default:
|
||||
return null
|
||||
}
|
||||
}, [connectionPhase, qrCodeValue, error])
|
||||
|
||||
// 传输进度组件
|
||||
const TransferProgress = useCallback(() => {
|
||||
if (!isSending && transferPhase !== 'completed') return null
|
||||
|
||||
return (
|
||||
<div style={{ paddingTop: '20px' }}>
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: '8px',
|
||||
padding: '12px',
|
||||
border: `1px solid var(--color-border)`,
|
||||
borderRadius: '8px',
|
||||
backgroundColor: 'var(--color-background-mute)'
|
||||
}}>
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
fontSize: '14px',
|
||||
fontWeight: '500'
|
||||
}}>
|
||||
<span style={{ color: 'var(--color-text)' }}>
|
||||
{t('settings.data.export_to_phone.lan.transfer_progress')}
|
||||
</span>
|
||||
<span
|
||||
style={{ color: transferPhase === 'completed' ? 'var(--color-status-success)' : 'var(--color-primary)' }}>
|
||||
{transferPhase === 'completed' ? '✅ ' + t('common.completed') : `${Math.round(sendProgress)}%`}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<Progress
|
||||
percent={Math.round(sendProgress)}
|
||||
status={transferPhase === 'completed' ? 'success' : 'active'}
|
||||
showInfo={false}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}, [isSending, transferPhase, sendProgress, t])
|
||||
|
||||
const AutoCloseCountdown = useCallback(() => {
|
||||
if (transferPhase !== 'completed' || autoCloseCountdown === null || autoCloseCountdown <= 0) return null
|
||||
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
fontSize: '12px',
|
||||
color: 'var(--color-text-2)',
|
||||
textAlign: 'center',
|
||||
paddingTop: '4px'
|
||||
}}>
|
||||
{t('settings.data.export_to_phone.lan.auto_close_tip', { seconds: autoCloseCountdown })}
|
||||
</div>
|
||||
)
|
||||
}, [transferPhase, autoCloseCountdown, t])
|
||||
|
||||
// 错误显示组件
|
||||
const ErrorDisplay = useCallback(() => {
|
||||
if (!error || transferPhase !== 'error') return null
|
||||
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
padding: '12px',
|
||||
border: `1px solid var(--color-error)`,
|
||||
borderRadius: '8px',
|
||||
backgroundColor: 'var(--color-error)',
|
||||
textAlign: 'center'
|
||||
}}>
|
||||
<span style={{ fontSize: '14px', color: 'var(--color-text)' }}>❌ {error}</span>
|
||||
</div>
|
||||
)
|
||||
}, [error, transferPhase])
|
||||
|
||||
return (
|
||||
<Modal
|
||||
open={isOpen}
|
||||
onCancel={handleCancel}
|
||||
afterClose={handleClose}
|
||||
title={t('settings.data.export_to_phone.lan.title')}
|
||||
centered
|
||||
closable={!isSending}
|
||||
maskClosable={false}
|
||||
keyboard={true}
|
||||
footer={null}
|
||||
styles={{ body: { paddingBottom: 10 } }}>
|
||||
<SettingRow>
|
||||
<StatusIndicator />
|
||||
</SettingRow>
|
||||
|
||||
<Alert message={t('settings.data.export_to_phone.lan.content')} type="info" style={{ borderRadius: 0 }} />
|
||||
|
||||
<SettingRow style={{ display: 'flex', justifyContent: 'center', minHeight: '180px', marginBlock: 25 }}>
|
||||
<QRCodeDisplay />
|
||||
</SettingRow>
|
||||
|
||||
<SettingRow style={{ display: 'flex', alignItems: 'center', marginBlock: 10 }}>
|
||||
<div style={{ display: 'flex', gap: 10, justifyContent: 'center', width: '100%' }}>
|
||||
<Button onClick={handleSelectZip} disabled={isSending}>
|
||||
{t('settings.data.export_to_phone.lan.selectZip')}
|
||||
</Button>
|
||||
<Button type="primary" onClick={handleSendZip} disabled={!canSend} loading={isSending}>
|
||||
{transferStatusText || t('settings.data.export_to_phone.lan.sendZip')}
|
||||
</Button>
|
||||
</div>
|
||||
</SettingRow>
|
||||
|
||||
<SettingHelpText
|
||||
style={{
|
||||
overflow: 'hidden',
|
||||
textOverflow: 'ellipsis',
|
||||
whiteSpace: 'nowrap',
|
||||
textAlign: 'center'
|
||||
}}>
|
||||
{selectedFolderPath || t('settings.data.export_to_phone.lan.noZipSelected')}
|
||||
</SettingHelpText>
|
||||
|
||||
<TransferProgress />
|
||||
<AutoCloseCountdown />
|
||||
<ErrorDisplay />
|
||||
</Modal>
|
||||
)
|
||||
}
|
||||
|
||||
const TopViewKey = 'ExportToPhoneLanPopup'
|
||||
|
||||
export default class ExportToPhoneLanPopup {
|
||||
static topviewId = 0
|
||||
static hide() {
|
||||
TopView.hide(TopViewKey)
|
||||
}
|
||||
static show() {
|
||||
return new Promise<any>((resolve) => {
|
||||
TopView.show(
|
||||
<PopupContainer
|
||||
resolve={(v) => {
|
||||
resolve(v)
|
||||
TopView.hide(TopViewKey)
|
||||
}}
|
||||
/>,
|
||||
TopViewKey
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,97 @@
|
||||
import { cn } from '@renderer/utils'
|
||||
import type { FC, KeyboardEventHandler } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { ProgressIndicator } from './ProgressIndicator'
|
||||
import type { LanDeviceCardProps } from './types'
|
||||
|
||||
export const LanDeviceCard: FC<LanDeviceCardProps> = ({
|
||||
service,
|
||||
transferState,
|
||||
isConnected,
|
||||
handshakeInProgress,
|
||||
isDisabled,
|
||||
onSendFile
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
// Device info
|
||||
const deviceName = service.txt?.modelName || t('common.unknown')
|
||||
const platform = service.txt?.platform
|
||||
const appVersion = service.txt?.appVersion
|
||||
const platformInfo = [platform, appVersion].filter(Boolean).join(' ')
|
||||
const displayTitle = platformInfo ? `${deviceName} (${platformInfo})` : deviceName
|
||||
|
||||
// Address info
|
||||
const primaryAddress = service.addresses?.[0]
|
||||
const addressesWithPort = primaryAddress ? (service.port ? `${primaryAddress}:${service.port}` : primaryAddress) : ''
|
||||
|
||||
// Progress visibility
|
||||
const shouldShowProgress =
|
||||
transferState && ['selecting', 'transferring', 'completed', 'failed'].includes(transferState.status)
|
||||
|
||||
// Status text
|
||||
const statusText = handshakeInProgress
|
||||
? t('settings.data.export_to_phone.lan.handshake.in_progress')
|
||||
: isConnected
|
||||
? t('settings.data.export_to_phone.lan.connected')
|
||||
: t('settings.data.export_to_phone.lan.send_file')
|
||||
|
||||
// Event handlers
|
||||
const handleClick = () => {
|
||||
if (isDisabled) return
|
||||
onSendFile(service.id)
|
||||
}
|
||||
|
||||
const handleKeyDown: KeyboardEventHandler<HTMLDivElement> = (event) => {
|
||||
if (event.key === 'Enter' || event.key === ' ') {
|
||||
event.preventDefault()
|
||||
handleClick()
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
onClick={handleClick}
|
||||
onKeyDown={handleKeyDown}
|
||||
className={cn(
|
||||
// Base styles
|
||||
'flex cursor-pointer flex-col gap-2 rounded-xl border p-3 outline-none transition-all duration-[120ms]',
|
||||
// Hover state
|
||||
'hover:-translate-y-px hover:border-[var(--color-primary-hover)] hover:shadow-md',
|
||||
// Focus state
|
||||
'focus-visible:border-[var(--color-primary)] focus-visible:shadow-[0_0_0_2px_rgba(24,144,255,0.2)]',
|
||||
// Connected state
|
||||
isConnected
|
||||
? 'border-[var(--color-primary)] bg-[rgba(24,144,255,0.04)]'
|
||||
: 'border-[var(--color-border)] bg-[var(--color-background)]',
|
||||
// Disabled state
|
||||
isDisabled && 'pointer-events-none translate-y-0 opacity-70 shadow-none'
|
||||
)}>
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="break-words font-semibold text-[var(--color-text-1)] text-sm">{displayTitle}</div>
|
||||
<span className="text-[var(--color-text-2)] text-xs">{statusText}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Meta Row - IP Address */}
|
||||
<div className="flex flex-col gap-1">
|
||||
<span className="text-[11px] text-[var(--color-text-3)] uppercase tracking-[0.03em]">
|
||||
{t('settings.data.export_to_phone.lan.ip_addresses')}
|
||||
</span>
|
||||
<span className="break-words text-[var(--color-text)] text-xs">{addressesWithPort || t('common.unknown')}</span>
|
||||
</div>
|
||||
|
||||
{/* Footer with Progress */}
|
||||
<div className="flex flex-wrap items-center justify-between gap-2 text-[11px] text-[var(--color-text-3)]">
|
||||
{shouldShowProgress && transferState && (
|
||||
<ProgressIndicator transferState={transferState} handshakeInProgress={handshakeInProgress} />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@ -0,0 +1,55 @@
|
||||
import { cn } from '@renderer/utils'
|
||||
import type { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import type { ProgressIndicatorProps } from './types'
|
||||
|
||||
export const ProgressIndicator: FC<ProgressIndicatorProps> = ({ transferState, handshakeInProgress }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const progressPercent = Math.min(100, Math.max(0, transferState.progress ?? 0))
|
||||
|
||||
const progressLabel = (() => {
|
||||
if (transferState.status === 'failed') {
|
||||
return transferState.error || t('common.unknown_error')
|
||||
}
|
||||
if (transferState.status === 'selecting') {
|
||||
return handshakeInProgress
|
||||
? t('settings.data.export_to_phone.lan.handshake.in_progress')
|
||||
: t('settings.data.export_to_phone.lan.status.preparing')
|
||||
}
|
||||
return `${Math.round(progressPercent)}%`
|
||||
})()
|
||||
|
||||
const isFailed = transferState.status === 'failed'
|
||||
const isCompleted = transferState.status === 'completed'
|
||||
|
||||
return (
|
||||
<div className="flex min-w-[180px] flex-1 flex-col gap-1">
|
||||
{/* Label Row */}
|
||||
<div
|
||||
className={cn(
|
||||
'flex items-center justify-between gap-1.5 text-[11px]',
|
||||
isFailed ? 'text-[var(--color-error)]' : 'text-[var(--color-text-2)]'
|
||||
)}>
|
||||
<span className="flex-1 overflow-hidden text-ellipsis whitespace-nowrap">{transferState.fileName}</span>
|
||||
<span className="shrink-0 whitespace-nowrap">{progressLabel}</span>
|
||||
</div>
|
||||
|
||||
{/* Progress Track */}
|
||||
<div className="relative h-1.5 w-full overflow-hidden rounded-full bg-[var(--color-border)]">
|
||||
<div
|
||||
className={cn(
|
||||
'h-full rounded-full transition-[width] duration-[120ms]',
|
||||
isFailed
|
||||
? 'bg-[var(--color-error)]'
|
||||
: isCompleted
|
||||
? 'bg-[var(--color-status-success)]'
|
||||
: 'bg-[var(--color-primary)]'
|
||||
)}
|
||||
style={{ width: `${progressPercent}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
397
src/renderer/src/components/Popups/LanTransferPopup/hook.ts
Normal file
397
src/renderer/src/components/Popups/LanTransferPopup/hook.ts
Normal file
@ -0,0 +1,397 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { getBackupData } from '@renderer/services/BackupService'
|
||||
import type { LocalTransferPeer } from '@shared/config/types'
|
||||
import { useCallback, useEffect, useMemo, useReducer, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import type { LanPeerTransferState, LanTransferAction, LanTransferReducerState } from './types'
|
||||
|
||||
const logger = loggerService.withContext('useLanTransfer')
|
||||
|
||||
// ==========================================
|
||||
// Initial State
|
||||
// ==========================================
|
||||
|
||||
export const initialState: LanTransferReducerState = {
|
||||
open: true,
|
||||
lanState: null,
|
||||
lanHandshakePeerId: null,
|
||||
lastHandshakeResult: null,
|
||||
fileTransferState: {},
|
||||
tempBackupPath: null
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// Reducer
|
||||
// ==========================================
|
||||
|
||||
export function lanTransferReducer(state: LanTransferReducerState, action: LanTransferAction): LanTransferReducerState {
|
||||
switch (action.type) {
|
||||
case 'SET_OPEN':
|
||||
return { ...state, open: action.payload }
|
||||
|
||||
case 'SET_LAN_STATE':
|
||||
return { ...state, lanState: action.payload }
|
||||
|
||||
case 'SET_HANDSHAKE_PEER_ID':
|
||||
return { ...state, lanHandshakePeerId: action.payload }
|
||||
|
||||
case 'SET_HANDSHAKE_RESULT':
|
||||
return { ...state, lastHandshakeResult: action.payload }
|
||||
|
||||
case 'SET_TEMP_BACKUP_PATH':
|
||||
return { ...state, tempBackupPath: action.payload }
|
||||
|
||||
case 'UPDATE_TRANSFER_STATE': {
|
||||
const { peerId, state: transferState } = action.payload
|
||||
return {
|
||||
...state,
|
||||
fileTransferState: {
|
||||
...state.fileTransferState,
|
||||
[peerId]: {
|
||||
...(state.fileTransferState[peerId] ?? { progress: 0, status: 'idle' as const }),
|
||||
...transferState
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case 'SET_TRANSFER_STATE': {
|
||||
const { peerId, state: transferState } = action.payload
|
||||
return {
|
||||
...state,
|
||||
fileTransferState: {
|
||||
...state.fileTransferState,
|
||||
[peerId]: transferState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case 'CLEANUP_STALE_PEERS': {
|
||||
const activeIds = action.payload
|
||||
const newFileTransferState: Record<string, LanPeerTransferState> = {}
|
||||
for (const id of Object.keys(state.fileTransferState)) {
|
||||
if (activeIds.has(id)) {
|
||||
newFileTransferState[id] = state.fileTransferState[id]
|
||||
}
|
||||
}
|
||||
return {
|
||||
...state,
|
||||
fileTransferState: newFileTransferState,
|
||||
lastHandshakeResult:
|
||||
state.lastHandshakeResult && activeIds.has(state.lastHandshakeResult.peerId)
|
||||
? state.lastHandshakeResult
|
||||
: null,
|
||||
lanHandshakePeerId:
|
||||
state.lanHandshakePeerId && activeIds.has(state.lanHandshakePeerId) ? state.lanHandshakePeerId : null
|
||||
}
|
||||
}
|
||||
|
||||
case 'RESET_CONNECTION_STATE':
|
||||
return {
|
||||
...state,
|
||||
fileTransferState: {},
|
||||
lastHandshakeResult: null,
|
||||
lanHandshakePeerId: null,
|
||||
tempBackupPath: null
|
||||
}
|
||||
|
||||
default:
|
||||
return state
|
||||
}
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// Hook Return Type
|
||||
// ==========================================
|
||||
|
||||
export interface UseLanTransferReturn {
|
||||
// State
|
||||
state: LanTransferReducerState
|
||||
|
||||
// Derived values
|
||||
lanDevices: LocalTransferPeer[]
|
||||
isAnyTransferring: boolean
|
||||
lastError: string | undefined
|
||||
|
||||
// Actions
|
||||
handleSendFile: (peerId: string) => Promise<void>
|
||||
handleModalCancel: () => void
|
||||
getTransferState: (peerId: string) => LanPeerTransferState | undefined
|
||||
isConnected: (peerId: string) => boolean
|
||||
isHandshakeInProgress: (peerId: string) => boolean
|
||||
|
||||
// Dispatch (for advanced use)
|
||||
dispatch: React.Dispatch<LanTransferAction>
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// Hook
|
||||
// ==========================================
|
||||
|
||||
export function useLanTransfer(): UseLanTransferReturn {
|
||||
const { t } = useTranslation()
|
||||
const [state, dispatch] = useReducer(lanTransferReducer, initialState)
|
||||
const isSendingRef = useRef(false)
|
||||
|
||||
// ==========================================
|
||||
// Derived Values
|
||||
// ==========================================
|
||||
|
||||
const lanDevices = useMemo(() => state.lanState?.services ?? [], [state.lanState])
|
||||
|
||||
const isAnyTransferring = useMemo(
|
||||
() => Object.values(state.fileTransferState).some((s) => s.status === 'transferring' || s.status === 'selecting'),
|
||||
[state.fileTransferState]
|
||||
)
|
||||
|
||||
const lastError = state.lanState?.lastError
|
||||
|
||||
// ==========================================
|
||||
// LAN State Sync
|
||||
// ==========================================
|
||||
|
||||
const syncLanState = useCallback(async () => {
|
||||
if (!window.api?.localTransfer) {
|
||||
logger.warn('Local transfer bridge is unavailable')
|
||||
return
|
||||
}
|
||||
try {
|
||||
const nextState = await window.api.localTransfer.getState()
|
||||
dispatch({ type: 'SET_LAN_STATE', payload: nextState })
|
||||
} catch (error) {
|
||||
logger.error('Failed to sync LAN state', error as Error)
|
||||
}
|
||||
}, [])
|
||||
|
||||
// ==========================================
|
||||
// Send File Handler
|
||||
// ==========================================
|
||||
|
||||
const handleSendFile = useCallback(
|
||||
async (peerId: string) => {
|
||||
if (!window.api?.localTransfer || isSendingRef.current) {
|
||||
return
|
||||
}
|
||||
isSendingRef.current = true
|
||||
|
||||
dispatch({
|
||||
type: 'SET_TRANSFER_STATE',
|
||||
payload: { peerId, state: { progress: 0, status: 'selecting' } }
|
||||
})
|
||||
|
||||
let backupPath: string | null = null
|
||||
|
||||
try {
|
||||
// Step 0: Ensure handshake (connect if needed)
|
||||
if (!state.lastHandshakeResult?.ack.accepted || state.lastHandshakeResult.peerId !== peerId) {
|
||||
dispatch({ type: 'SET_HANDSHAKE_PEER_ID', payload: peerId })
|
||||
try {
|
||||
const ack = await window.api.localTransfer.connect({ peerId })
|
||||
dispatch({
|
||||
type: 'SET_HANDSHAKE_RESULT',
|
||||
payload: { peerId, ack, timestamp: Date.now() }
|
||||
})
|
||||
if (!ack.accepted) {
|
||||
throw new Error(ack.message || t('settings.data.export_to_phone.lan.connection_failed'))
|
||||
}
|
||||
} finally {
|
||||
dispatch({ type: 'SET_HANDSHAKE_PEER_ID', payload: null })
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: Create temporary backup
|
||||
logger.info('Creating temporary backup for LAN transfer...')
|
||||
const backupData = await getBackupData()
|
||||
backupPath = await window.api.backup.createLanTransferBackup(backupData)
|
||||
dispatch({ type: 'SET_TEMP_BACKUP_PATH', payload: backupPath })
|
||||
|
||||
// Extract filename from path
|
||||
const fileName = backupPath.split(/[/\\]/).pop() || 'backup.zip'
|
||||
|
||||
// Step 2: Set transferring state
|
||||
dispatch({
|
||||
type: 'UPDATE_TRANSFER_STATE',
|
||||
payload: { peerId, state: { fileName, progress: 0, status: 'transferring' } }
|
||||
})
|
||||
|
||||
// Step 3: Send file
|
||||
logger.info(`Sending backup file: ${backupPath}`)
|
||||
const result = await window.api.localTransfer.sendFile(backupPath)
|
||||
|
||||
if (result.success) {
|
||||
dispatch({
|
||||
type: 'UPDATE_TRANSFER_STATE',
|
||||
payload: { peerId, state: { progress: 100, status: 'completed' } }
|
||||
})
|
||||
} else {
|
||||
dispatch({
|
||||
type: 'UPDATE_TRANSFER_STATE',
|
||||
payload: { peerId, state: { status: 'failed', error: result.error } }
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
dispatch({
|
||||
type: 'UPDATE_TRANSFER_STATE',
|
||||
payload: { peerId, state: { status: 'failed', error: message } }
|
||||
})
|
||||
logger.error('Failed to send file', error as Error)
|
||||
} finally {
|
||||
// Step 4: Clean up temp file
|
||||
if (backupPath) {
|
||||
try {
|
||||
await window.api.backup.deleteTempBackup(backupPath)
|
||||
logger.info('Cleaned up temporary backup file')
|
||||
} catch (cleanupError) {
|
||||
logger.warn('Failed to clean up temp backup', cleanupError as Error)
|
||||
}
|
||||
dispatch({ type: 'SET_TEMP_BACKUP_PATH', payload: null })
|
||||
}
|
||||
isSendingRef.current = false
|
||||
}
|
||||
},
|
||||
[state.lastHandshakeResult, t]
|
||||
)
|
||||
|
||||
// ==========================================
|
||||
// Teardown
|
||||
// ==========================================
|
||||
|
||||
// Use ref to track temp backup path for cleanup without causing effect re-runs
|
||||
const tempBackupPathRef = useRef<string | null>(null)
|
||||
tempBackupPathRef.current = state.tempBackupPath
|
||||
|
||||
const teardownLan = useCallback(async () => {
|
||||
if (!window.api?.localTransfer) {
|
||||
return
|
||||
}
|
||||
try {
|
||||
await window.api.localTransfer.cancelTransfer?.()
|
||||
} catch (error) {
|
||||
logger.warn('Failed to cancel LAN transfer on close', error as Error)
|
||||
}
|
||||
try {
|
||||
await window.api.localTransfer.disconnect?.()
|
||||
} catch (error) {
|
||||
logger.warn('Failed to disconnect LAN on close', error as Error)
|
||||
}
|
||||
// Clean up temp backup if exists (use ref to get current value)
|
||||
if (tempBackupPathRef.current) {
|
||||
try {
|
||||
await window.api.backup.deleteTempBackup(tempBackupPathRef.current)
|
||||
} catch (error) {
|
||||
logger.warn('Failed to cleanup temp backup on close', error as Error)
|
||||
}
|
||||
}
|
||||
dispatch({ type: 'RESET_CONNECTION_STATE' })
|
||||
}, []) // No dependencies - uses ref for current value
|
||||
|
||||
const handleModalCancel = useCallback(() => {
|
||||
void teardownLan()
|
||||
dispatch({ type: 'SET_OPEN', payload: false })
|
||||
}, [teardownLan])
|
||||
|
||||
// ==========================================
|
||||
// Effects
|
||||
// ==========================================
|
||||
|
||||
// Initial sync and service listener
|
||||
useEffect(() => {
|
||||
if (!window.api?.localTransfer) {
|
||||
return
|
||||
}
|
||||
syncLanState()
|
||||
const removeListener = window.api.localTransfer.onServicesUpdated((lanState) => {
|
||||
dispatch({ type: 'SET_LAN_STATE', payload: lanState })
|
||||
})
|
||||
return () => {
|
||||
removeListener?.()
|
||||
}
|
||||
}, [syncLanState])
|
||||
|
||||
// Client events listener (progress, completion)
|
||||
useEffect(() => {
|
||||
if (!window.api?.localTransfer) {
|
||||
return
|
||||
}
|
||||
const removeListener = window.api.localTransfer.onClientEvent((event) => {
|
||||
const key = event.peerId ?? 'global'
|
||||
|
||||
if (event.type === 'file_transfer_progress') {
|
||||
dispatch({
|
||||
type: 'UPDATE_TRANSFER_STATE',
|
||||
payload: {
|
||||
peerId: key,
|
||||
state: {
|
||||
transferId: event.transferId,
|
||||
fileName: event.fileName,
|
||||
progress: event.progress,
|
||||
speed: event.speed,
|
||||
status: 'transferring'
|
||||
}
|
||||
}
|
||||
})
|
||||
} else if (event.type === 'file_transfer_complete') {
|
||||
dispatch({
|
||||
type: 'UPDATE_TRANSFER_STATE',
|
||||
payload: {
|
||||
peerId: key,
|
||||
state: {
|
||||
progress: event.success ? 100 : undefined,
|
||||
status: event.success ? 'completed' : 'failed',
|
||||
error: event.error
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
return () => {
|
||||
removeListener?.()
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Cleanup stale peers when services change
|
||||
useEffect(() => {
|
||||
const activeIds = new Set(lanDevices.map((s) => s.id))
|
||||
dispatch({ type: 'CLEANUP_STALE_PEERS', payload: activeIds })
|
||||
}, [lanDevices])
|
||||
|
||||
// Cleanup on unmount only (teardownLan is stable with no deps)
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
void teardownLan()
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
// ==========================================
|
||||
// Helper Functions
|
||||
// ==========================================
|
||||
|
||||
const getTransferState = useCallback((peerId: string) => state.fileTransferState[peerId], [state.fileTransferState])
|
||||
|
||||
const isConnected = useCallback(
|
||||
(peerId: string) =>
|
||||
state.lastHandshakeResult?.peerId === peerId && state.lastHandshakeResult?.ack.accepted === true,
|
||||
[state.lastHandshakeResult]
|
||||
)
|
||||
|
||||
const isHandshakeInProgress = useCallback(
|
||||
(peerId: string) => state.lanHandshakePeerId === peerId,
|
||||
[state.lanHandshakePeerId]
|
||||
)
|
||||
|
||||
return {
|
||||
state,
|
||||
lanDevices,
|
||||
isAnyTransferring,
|
||||
lastError,
|
||||
handleSendFile,
|
||||
handleModalCancel,
|
||||
getTransferState,
|
||||
isConnected,
|
||||
isHandshakeInProgress,
|
||||
dispatch
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,37 @@
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
|
||||
import { getHideCallback, PopupContainer } from './popup'
|
||||
import type { PopupResolveData } from './types'
|
||||
|
||||
// Re-export types for external use
|
||||
export type { LanPeerTransferState } from './types'
|
||||
|
||||
const TopViewKey = 'LanTransferPopup'
|
||||
|
||||
export default class LanTransferPopup {
|
||||
static topviewId = 0
|
||||
|
||||
static hide() {
|
||||
// Try to use the registered callback for proper cleanup, fallback to TopView.hide
|
||||
const callback = getHideCallback()
|
||||
if (callback) {
|
||||
callback()
|
||||
} else {
|
||||
TopView.hide(TopViewKey)
|
||||
}
|
||||
}
|
||||
|
||||
static show() {
|
||||
return new Promise<PopupResolveData>((resolve) => {
|
||||
TopView.show(
|
||||
<PopupContainer
|
||||
resolve={(v) => {
|
||||
resolve(v)
|
||||
TopView.hide(TopViewKey)
|
||||
}}
|
||||
/>,
|
||||
TopViewKey
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,88 @@
|
||||
import { Modal } from 'antd'
|
||||
import { TriangleAlert } from 'lucide-react'
|
||||
import type { FC } from 'react'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { useLanTransfer } from './hook'
|
||||
import { LanDeviceCard } from './LanDeviceCard'
|
||||
import type { PopupContainerProps } from './types'
|
||||
|
||||
// Module-level callback for external hide access
|
||||
let hideCallback: (() => void) | null = null
|
||||
export const setHideCallback = (cb: () => void) => {
|
||||
hideCallback = cb
|
||||
}
|
||||
export const getHideCallback = () => hideCallback
|
||||
|
||||
export const PopupContainer: FC<PopupContainerProps> = ({ resolve }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const {
|
||||
state,
|
||||
lanDevices,
|
||||
isAnyTransferring,
|
||||
lastError,
|
||||
handleSendFile,
|
||||
handleModalCancel,
|
||||
getTransferState,
|
||||
isConnected,
|
||||
isHandshakeInProgress
|
||||
} = useLanTransfer()
|
||||
|
||||
const contentTitle = useMemo(() => t('settings.data.export_to_phone.lan.title'), [t])
|
||||
|
||||
const onClose = () => resolve({})
|
||||
|
||||
// Register hide callback for external access
|
||||
setHideCallback(handleModalCancel)
|
||||
|
||||
return (
|
||||
<Modal
|
||||
open={state.open}
|
||||
onCancel={handleModalCancel}
|
||||
afterClose={onClose}
|
||||
footer={null}
|
||||
centered
|
||||
title={contentTitle}
|
||||
transitionName="animation-move-down">
|
||||
<div className="flex flex-col gap-3">
|
||||
{/* Error Display */}
|
||||
{lastError && <div className="text-[var(--color-error)] text-xs">{lastError}</div>}
|
||||
|
||||
{/* Device List */}
|
||||
<div className="mt-2 flex flex-col gap-3">
|
||||
{lanDevices.length === 0 ? (
|
||||
// Warning when no devices
|
||||
<div className="flex w-full items-center gap-2.5 rounded-[10px] border border-[rgba(255,159,41,0.4)] border-dashed bg-[rgba(255,159,41,0.1)] px-3.5 py-3">
|
||||
<TriangleAlert size={20} className="text-orange-400" />
|
||||
<span className="flex-1 text-[#ff9f29] text-[13px] leading-[1.4]">
|
||||
{t('settings.data.export_to_phone.lan.no_connection_warning')}
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
// Device cards
|
||||
lanDevices.map((service) => {
|
||||
const transferState = getTransferState(service.id)
|
||||
const connected = isConnected(service.id)
|
||||
const handshakeInProgress = isHandshakeInProgress(service.id)
|
||||
const isCardDisabled = isAnyTransferring || handshakeInProgress
|
||||
|
||||
return (
|
||||
<LanDeviceCard
|
||||
key={service.id}
|
||||
service={service}
|
||||
transferState={transferState}
|
||||
isConnected={connected}
|
||||
handshakeInProgress={handshakeInProgress}
|
||||
isDisabled={isCardDisabled}
|
||||
onSendFile={handleSendFile}
|
||||
/>
|
||||
)
|
||||
})
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
)
|
||||
}
|
||||
84
src/renderer/src/components/Popups/LanTransferPopup/types.ts
Normal file
84
src/renderer/src/components/Popups/LanTransferPopup/types.ts
Normal file
@ -0,0 +1,84 @@
|
||||
import type { LanHandshakeAckMessage, LocalTransferPeer, LocalTransferState } from '@shared/config/types'
|
||||
|
||||
// ==========================================
|
||||
// Transfer Status
|
||||
// ==========================================
|
||||
|
||||
export type TransferStatus = 'idle' | 'selecting' | 'transferring' | 'completed' | 'failed'
|
||||
|
||||
// ==========================================
|
||||
// Per-Peer Transfer State
|
||||
// ==========================================
|
||||
|
||||
export interface LanPeerTransferState {
|
||||
transferId?: string
|
||||
fileName?: string
|
||||
progress: number
|
||||
speed?: number
|
||||
status: TransferStatus
|
||||
error?: string
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// Handshake Result
|
||||
// ==========================================
|
||||
|
||||
export type HandshakeResult = {
|
||||
peerId: string
|
||||
ack: LanHandshakeAckMessage
|
||||
timestamp: number
|
||||
} | null
|
||||
|
||||
// ==========================================
|
||||
// Reducer State
|
||||
// ==========================================
|
||||
|
||||
export interface LanTransferReducerState {
|
||||
open: boolean
|
||||
lanState: LocalTransferState | null
|
||||
lanHandshakePeerId: string | null
|
||||
lastHandshakeResult: HandshakeResult
|
||||
fileTransferState: Record<string, LanPeerTransferState>
|
||||
tempBackupPath: string | null
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// Reducer Actions
|
||||
// ==========================================
|
||||
|
||||
export type LanTransferAction =
|
||||
| { type: 'SET_OPEN'; payload: boolean }
|
||||
| { type: 'SET_LAN_STATE'; payload: LocalTransferState | null }
|
||||
| { type: 'SET_HANDSHAKE_PEER_ID'; payload: string | null }
|
||||
| { type: 'SET_HANDSHAKE_RESULT'; payload: HandshakeResult }
|
||||
| { type: 'SET_TEMP_BACKUP_PATH'; payload: string | null }
|
||||
| { type: 'UPDATE_TRANSFER_STATE'; payload: { peerId: string; state: Partial<LanPeerTransferState> } }
|
||||
| { type: 'SET_TRANSFER_STATE'; payload: { peerId: string; state: LanPeerTransferState } }
|
||||
| { type: 'CLEANUP_STALE_PEERS'; payload: Set<string> }
|
||||
| { type: 'RESET_CONNECTION_STATE' }
|
||||
|
||||
// ==========================================
|
||||
// Component Props
|
||||
// ==========================================
|
||||
|
||||
export interface LanDeviceCardProps {
|
||||
service: LocalTransferPeer
|
||||
transferState?: LanPeerTransferState
|
||||
isConnected: boolean
|
||||
handshakeInProgress: boolean
|
||||
isDisabled: boolean
|
||||
onSendFile: (peerId: string) => void
|
||||
}
|
||||
|
||||
export interface ProgressIndicatorProps {
|
||||
transferState: LanPeerTransferState
|
||||
handshakeInProgress: boolean
|
||||
}
|
||||
|
||||
export interface PopupResolveData {
|
||||
// Empty for now, can be extended
|
||||
}
|
||||
|
||||
export interface PopupContainerProps {
|
||||
resolve: (data: PopupResolveData) => void
|
||||
}
|
||||
@ -21,7 +21,6 @@ import type { LRUCache } from 'lru-cache'
|
||||
import {
|
||||
FileSearch,
|
||||
Folder,
|
||||
Hammer,
|
||||
Home,
|
||||
Languages,
|
||||
LayoutGrid,
|
||||
@ -99,8 +98,6 @@ const getTabIcon = (
|
||||
return <NotepadText size={14} />
|
||||
case 'knowledge':
|
||||
return <FileSearch size={14} />
|
||||
case 'mcp':
|
||||
return <Hammer size={14} />
|
||||
case 'files':
|
||||
return <Folder size={14} />
|
||||
case 'settings':
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import { loggerService } from '@logger'
|
||||
import ThreeMinTopAppLogo from '@renderer/assets/images/apps/3mintop.png?url'
|
||||
import AbacusLogo from '@renderer/assets/images/apps/abacus.webp?url'
|
||||
import AIStudioLogo from '@renderer/assets/images/apps/aistudio.svg?url'
|
||||
import AIStudioLogo from '@renderer/assets/images/apps/aistudio.png?url'
|
||||
import ApplicationLogo from '@renderer/assets/images/apps/application.png?url'
|
||||
import BaiduAiAppLogo from '@renderer/assets/images/apps/baidu-ai.png?url'
|
||||
import BaiduAiSearchLogo from '@renderer/assets/images/apps/baidu-ai-search.webp?url'
|
||||
|
||||
139
src/renderer/src/config/models/__tests__/openai.test.ts
Normal file
139
src/renderer/src/config/models/__tests__/openai.test.ts
Normal file
@ -0,0 +1,139 @@
|
||||
import type { Model } from '@renderer/types'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { isSupportNoneReasoningEffortModel } from '../openai'
|
||||
|
||||
// Mock store and settings to avoid initialization issues
|
||||
vi.mock('@renderer/store', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
getState: () => ({
|
||||
llm: { providers: [] },
|
||||
settings: {}
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/useStore', () => ({
|
||||
getStoreProviders: vi.fn(() => [])
|
||||
}))
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
id: 'gpt-4o',
|
||||
name: 'gpt-4o',
|
||||
provider: 'openai',
|
||||
group: 'OpenAI',
|
||||
...overrides
|
||||
})
|
||||
|
||||
describe('OpenAI Model Detection', () => {
|
||||
describe('isSupportNoneReasoningEffortModel', () => {
|
||||
describe('should return true for GPT-5.1 and GPT-5.2 reasoning models', () => {
|
||||
it('returns true for GPT-5.1 base model', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.1' }))).toBe(true)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'GPT-5.1' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('returns true for GPT-5.1 mini model', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.1-mini' }))).toBe(true)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.1-mini-preview' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('returns true for GPT-5.1 preview model', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('returns true for GPT-5.2 base model', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.2' }))).toBe(true)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'GPT-5.2' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('returns true for GPT-5.2 mini model', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.2-mini' }))).toBe(true)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.2-mini-preview' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('returns true for GPT-5.2 preview model', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.2-preview' }))).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('should return false for pro variants', () => {
|
||||
it('returns false for GPT-5.1-pro models', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.1-pro' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'GPT-5.1-Pro' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.1-pro-preview' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false for GPT-5.2-pro models', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.2-pro' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'GPT-5.2-Pro' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.2-pro-preview' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('should return false for chat variants', () => {
|
||||
it('returns false for GPT-5.1-chat models', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.1-chat' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'GPT-5.1-Chat' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false for GPT-5.2-chat models', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.2-chat' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'GPT-5.2-Chat' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('should return false for GPT-5 series (non-5.1/5.2)', () => {
|
||||
it('returns false for GPT-5 base model', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false for GPT-5 pro model', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5-pro' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false for GPT-5 preview model', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5-preview' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('should return false for other OpenAI models', () => {
|
||||
it('returns false for GPT-4 models', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-4-turbo' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false for o1 models', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'o1' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'o1-mini' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'o1-preview' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false for o3 models', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'o3' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'o3-mini' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('handles models with version suffixes', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.1-2025-01-01' }))).toBe(true)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.2-latest' }))).toBe(true)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'gpt-5.1-pro-2025-01-01' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('handles models with OpenRouter prefixes', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'openai/gpt-5.1' }))).toBe(true)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'openai/gpt-5.2-mini' }))).toBe(true)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'openai/gpt-5.1-pro' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'openai/gpt-5.1-chat' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('handles mixed case with chat and pro', () => {
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'GPT-5.1-CHAT' }))).toBe(false)
|
||||
expect(isSupportNoneReasoningEffortModel(createModel({ id: 'GPT-5.2-PRO' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -17,6 +17,7 @@ import {
|
||||
isGeminiReasoningModel,
|
||||
isGrok4FastReasoningModel,
|
||||
isHunyuanReasoningModel,
|
||||
isInterleavedThinkingModel,
|
||||
isLingReasoningModel,
|
||||
isMiniMaxReasoningModel,
|
||||
isPerplexityReasoningModel,
|
||||
@ -679,7 +680,12 @@ describe('getThinkModelType - Comprehensive Coverage', () => {
|
||||
expect(getThinkModelType(createModel({ id: 'o3' }))).toBe('o')
|
||||
expect(getThinkModelType(createModel({ id: 'o3-mini' }))).toBe('o')
|
||||
expect(getThinkModelType(createModel({ id: 'o4' }))).toBe('o')
|
||||
expect(getThinkModelType(createModel({ id: 'gpt-oss-reasoning' }))).toBe('o')
|
||||
})
|
||||
|
||||
it('should return gpt_oss for gpt-oss models', () => {
|
||||
expect(getThinkModelType(createModel({ id: 'gpt-oss' }))).toBe('gpt_oss')
|
||||
expect(getThinkModelType(createModel({ id: 'gpt-oss:20b' }))).toBe('gpt_oss')
|
||||
expect(getThinkModelType(createModel({ id: 'gpt-oss-reasoning' }))).toBe('gpt_oss')
|
||||
})
|
||||
})
|
||||
|
||||
@ -1762,6 +1768,21 @@ describe('getModelSupportedReasoningEffortOptions', () => {
|
||||
'medium',
|
||||
'high'
|
||||
])
|
||||
})
|
||||
|
||||
it('should return correct options for gpt-oss models', () => {
|
||||
expect(getModelSupportedReasoningEffortOptions(createModel({ id: 'gpt-oss' }))).toEqual([
|
||||
'default',
|
||||
'low',
|
||||
'medium',
|
||||
'high'
|
||||
])
|
||||
expect(getModelSupportedReasoningEffortOptions(createModel({ id: 'gpt-oss:20b' }))).toEqual([
|
||||
'default',
|
||||
'low',
|
||||
'medium',
|
||||
'high'
|
||||
])
|
||||
expect(getModelSupportedReasoningEffortOptions(createModel({ id: 'gpt-oss-reasoning' }))).toEqual([
|
||||
'default',
|
||||
'low',
|
||||
@ -2157,3 +2178,105 @@ describe('getModelSupportedReasoningEffortOptions', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('isInterleavedThinkingModel', () => {
|
||||
describe('MiniMax models', () => {
|
||||
it('should return true for minimax-m2', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-m2' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for minimax-m2.1', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-m2.1' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for minimax-m2 with suffixes', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-m2-pro' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-m2-preview' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-m2-lite' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-m2-ultra-lite' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for minimax-m2.x with suffixes', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-m2.1-pro' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-m2.2-preview' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-m2.5-lite' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for non-m2 minimax models', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-m1' }))).toBe(false)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-m3' }))).toBe(false)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'minimax-pro' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle case insensitivity', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'MiniMax-M2' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'MINIMAX-M2.1' }))).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('MiMo models', () => {
|
||||
it('should return true for mimo-v2-flash', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'mimo-v2-flash' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for other mimo models', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'mimo-v1-flash' }))).toBe(false)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'mimo-v2' }))).toBe(false)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'mimo-v2-pro' }))).toBe(false)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'mimo-flash' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle case insensitivity', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'MiMo-V2-Flash' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'MIMO-V2-FLASH' }))).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Zhipu GLM models', () => {
|
||||
it('should return true for glm-4.5', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-4.5' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for glm-4.6', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-4.6' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for glm-4.7 and higher versions', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-4.7' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-4.8' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-4.9' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for glm-4.x with suffixes', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-4.5-pro' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-4.6-preview' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-4.7-lite' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-4.8-ultra' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for glm-4 without decimal version', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-4' }))).toBe(false)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-4-pro' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for other glm models', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-3.5' }))).toBe(false)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-5.0' }))).toBe(false)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'glm-zero-preview' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle case insensitivity', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'GLM-4.5' }))).toBe(true)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'Glm-4.6-Pro' }))).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Non-matching models', () => {
|
||||
it('should return false for unrelated models', () => {
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'gpt-4' }))).toBe(false)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'claude-3-opus' }))).toBe(false)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'gemini-pro' }))).toBe(false)
|
||||
expect(isInterleavedThinkingModel(createModel({ id: 'deepseek-v3' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -617,6 +617,30 @@ export const SYSTEM_MODELS: Record<SystemProviderId | 'defaultModel', Model[]> =
|
||||
name: 'GLM-4.6',
|
||||
group: 'GLM-4.6'
|
||||
},
|
||||
{
|
||||
id: 'glm-4.6v',
|
||||
provider: 'zhipu',
|
||||
name: 'GLM-4.6V',
|
||||
group: 'GLM-4.6V'
|
||||
},
|
||||
{
|
||||
id: 'glm-4.6v-flash',
|
||||
provider: 'zhipu',
|
||||
name: 'GLM-4.6V-Flash',
|
||||
group: 'GLM-4.6V'
|
||||
},
|
||||
{
|
||||
id: 'glm-4.6v-flashx',
|
||||
provider: 'zhipu',
|
||||
name: 'GLM-4.6V-FlashX',
|
||||
group: 'GLM-4.6V'
|
||||
},
|
||||
{
|
||||
id: 'glm-4.7',
|
||||
provider: 'zhipu',
|
||||
name: 'GLM-4.7',
|
||||
group: 'GLM-4.7'
|
||||
},
|
||||
{
|
||||
id: 'glm-4.5',
|
||||
provider: 'zhipu',
|
||||
@ -921,6 +945,12 @@ export const SYSTEM_MODELS: Record<SystemProviderId | 'defaultModel', Model[]> =
|
||||
provider: 'minimax',
|
||||
name: 'MiniMax M2 Stable',
|
||||
group: 'minimax-m2'
|
||||
},
|
||||
{
|
||||
id: 'MiniMax-M2.1',
|
||||
provider: 'minimax',
|
||||
name: 'MiniMax M2.1',
|
||||
group: 'minimax-m2'
|
||||
}
|
||||
],
|
||||
hyperbolic: [
|
||||
|
||||
@ -72,7 +72,37 @@ export const isGPT52SeriesModel = (model: Model) => {
|
||||
|
||||
export function isSupportVerbosityModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return (isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat')
|
||||
return (
|
||||
(isGPT5SeriesModel(model) || isGPT51SeriesModel(model) || isGPT52SeriesModel(model)) && !modelId.includes('chat')
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines if a model supports the "none" reasoning effort parameter.
|
||||
*
|
||||
* This applies to GPT-5.1 and GPT-5.2 series reasoning models (non-chat, non-pro variants).
|
||||
* These models allow setting reasoning_effort to "none" to skip reasoning steps.
|
||||
*
|
||||
* @param model - The model to check
|
||||
* @returns true if the model supports "none" reasoning effort, false otherwise
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* // Returns true
|
||||
* isSupportNoneReasoningEffortModel({ id: 'gpt-5.1', provider: 'openai' })
|
||||
* isSupportNoneReasoningEffortModel({ id: 'gpt-5.2-mini', provider: 'openai' })
|
||||
*
|
||||
* // Returns false
|
||||
* isSupportNoneReasoningEffortModel({ id: 'gpt-5.1-pro', provider: 'openai' })
|
||||
* isSupportNoneReasoningEffortModel({ id: 'gpt-5.1-chat', provider: 'openai' })
|
||||
* isSupportNoneReasoningEffortModel({ id: 'gpt-5-pro', provider: 'openai' })
|
||||
* ```
|
||||
*/
|
||||
export function isSupportNoneReasoningEffortModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return (
|
||||
(isGPT51SeriesModel(model) || isGPT52SeriesModel(model)) && !modelId.includes('chat') && !modelId.includes('pro')
|
||||
)
|
||||
}
|
||||
|
||||
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
|
||||
|
||||
@ -17,6 +17,7 @@ import {
|
||||
isGPT52ProModel,
|
||||
isGPT52SeriesModel,
|
||||
isOpenAIDeepResearchModel,
|
||||
isOpenAIOpenWeightModel,
|
||||
isOpenAIReasoningModel,
|
||||
isSupportedReasoningEffortOpenAIModel
|
||||
} from './openai'
|
||||
@ -41,6 +42,7 @@ export const MODEL_SUPPORTED_REASONING_EFFORT = {
|
||||
gpt5_2: ['none', 'low', 'medium', 'high', 'xhigh'] as const,
|
||||
gpt5pro: ['high'] as const,
|
||||
gpt52pro: ['medium', 'high', 'xhigh'] as const,
|
||||
gpt_oss: ['low', 'medium', 'high'] as const,
|
||||
grok: ['low', 'high'] as const,
|
||||
grok4_fast: ['auto'] as const,
|
||||
gemini2_flash: ['low', 'medium', 'high', 'auto'] as const,
|
||||
@ -72,6 +74,7 @@ export const MODEL_SUPPORTED_OPTIONS: ThinkingOptionConfig = {
|
||||
gpt5_2: ['default', ...MODEL_SUPPORTED_REASONING_EFFORT.gpt5_2] as const,
|
||||
gpt5_1_codex_max: ['default', ...MODEL_SUPPORTED_REASONING_EFFORT.gpt5_1_codex_max] as const,
|
||||
gpt52pro: ['default', ...MODEL_SUPPORTED_REASONING_EFFORT.gpt52pro] as const,
|
||||
gpt_oss: ['default', ...MODEL_SUPPORTED_REASONING_EFFORT.gpt_oss] as const,
|
||||
grok: ['default', ...MODEL_SUPPORTED_REASONING_EFFORT.grok] as const,
|
||||
grok4_fast: ['default', 'none', ...MODEL_SUPPORTED_REASONING_EFFORT.grok4_fast] as const,
|
||||
gemini2_flash: ['default', 'none', ...MODEL_SUPPORTED_REASONING_EFFORT.gemini2_flash] as const,
|
||||
@ -127,6 +130,8 @@ const _getThinkModelType = (model: Model): ThinkingModelType => {
|
||||
thinkingModelType = 'gpt5pro'
|
||||
}
|
||||
}
|
||||
} else if (isOpenAIOpenWeightModel(model)) {
|
||||
thinkingModelType = 'gpt_oss'
|
||||
} else if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
thinkingModelType = 'o'
|
||||
} else if (isGrok4FastReasoningModel(model)) {
|
||||
@ -571,7 +576,7 @@ export const isSupportedReasoningEffortPerplexityModel = (model: Model): boolean
|
||||
|
||||
export const isSupportedThinkingTokenZhipuModel = (model: Model): boolean => {
|
||||
const modelId = getLowerBaseModelName(model.id, '/')
|
||||
return ['glm-4.5', 'glm-4.6'].some((id) => modelId.includes(id))
|
||||
return ['glm-4.5', 'glm-4.6', 'glm-4.7'].some((id) => modelId.includes(id))
|
||||
}
|
||||
|
||||
export const isSupportedThinkingTokenMiMoModel = (model: Model): boolean => {
|
||||
@ -632,7 +637,7 @@ export const isMiniMaxReasoningModel = (model?: Model): boolean => {
|
||||
return false
|
||||
}
|
||||
const modelId = getLowerBaseModelName(model.id, '/')
|
||||
return (['minimax-m1', 'minimax-m2'] as const).some((id) => modelId.includes(id))
|
||||
return (['minimax-m1', 'minimax-m2', 'minimax-m2.1'] as const).some((id) => modelId.includes(id))
|
||||
}
|
||||
|
||||
export function isReasoningModel(model?: Model): boolean {
|
||||
@ -738,3 +743,20 @@ export const findTokenLimit = (modelId: string): { min: number; max: number } |
|
||||
*/
|
||||
export const isFixedReasoningModel = (model: Model) =>
|
||||
isReasoningModel(model) && !isSupportedThinkingTokenModel(model) && !isSupportedReasoningEffortModel(model)
|
||||
|
||||
// https://platform.minimaxi.com/docs/guides/text-m2-function-call#openai-sdk
|
||||
// https://docs.z.ai/guides/capabilities/thinking-mode
|
||||
// https://platform.moonshot.cn/docs/guide/use-kimi-k2-thinking-model#%E5%A4%9A%E6%AD%A5%E5%B7%A5%E5%85%B7%E8%B0%83%E7%94%A8
|
||||
const INTERLEAVED_THINKING_MODEL_REGEX =
|
||||
/minimax-m2(.(\d+))?(?:-[\w-]+)?|mimo-v2-flash|glm-4.(\d+)(?:-[\w-]+)?|kimi-k2-thinking?$/i
|
||||
|
||||
/**
|
||||
* Determines whether the given model supports interleaved thinking.
|
||||
*
|
||||
* @param model - The model object to check.
|
||||
* @returns `true` if the model's ID matches the interleaved thinking model pattern; otherwise, `false`.
|
||||
*/
|
||||
export const isInterleavedThinkingModel = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return INTERLEAVED_THINKING_MODEL_REGEX.test(modelId)
|
||||
}
|
||||
|
||||
@ -24,6 +24,7 @@ export const FUNCTION_CALLING_MODELS = [
|
||||
'deepseek',
|
||||
'glm-4(?:-[\\w-]+)?',
|
||||
'glm-4.5(?:-[\\w-]+)?',
|
||||
'glm-4.7(?:-[\\w-]+)?',
|
||||
'learnlm(?:-[\\w-]+)?',
|
||||
'gemini(?:-[\\w-]+)?', // 提前排除了gemini的嵌入模型
|
||||
'grok-3(?:-[\\w-]+)?',
|
||||
@ -32,7 +33,7 @@ export const FUNCTION_CALLING_MODELS = [
|
||||
'kimi-k2(?:-[\\w-]+)?',
|
||||
'ling-\\w+(?:-[\\w-]+)?',
|
||||
'ring-\\w+(?:-[\\w-]+)?',
|
||||
'minimax-m2',
|
||||
'minimax-m2(?:.1)?',
|
||||
'mimo-v2-flash'
|
||||
] as const
|
||||
|
||||
|
||||
@ -75,12 +75,37 @@ const VISION_REGEX = new RegExp(
|
||||
'i'
|
||||
)
|
||||
|
||||
// For middleware to identify models that must use the dedicated Image API
|
||||
// All dedicated image generation models (only generate images, no text chat capability)
|
||||
// These models need:
|
||||
// 1. Route to dedicated image generation API
|
||||
// 2. Exclude from reasoning/websearch/tooluse selection
|
||||
const DEDICATED_IMAGE_MODELS = [
|
||||
'grok-2-image(?:-[\\w-]+)?',
|
||||
// OpenAI series
|
||||
'dall-e(?:-[\\w-]+)?',
|
||||
'gpt-image-1(?:-[\\w-]+)?',
|
||||
'imagen(?:-[\\w-]+)?'
|
||||
'gpt-image(?:-[\\w-]+)?',
|
||||
// xAI
|
||||
'grok-2-image(?:-[\\w-]+)?',
|
||||
// Google
|
||||
'imagen(?:-[\\w-]+)?',
|
||||
// Stable Diffusion series
|
||||
'flux(?:-[\\w-]+)?',
|
||||
'stable-?diffusion(?:-[\\w-]+)?',
|
||||
'stabilityai(?:-[\\w-]+)?',
|
||||
'sd-[\\w-]+',
|
||||
'sdxl(?:-[\\w-]+)?',
|
||||
// zhipu
|
||||
'cogview(?:-[\\w-]+)?',
|
||||
// Alibaba
|
||||
'qwen-image(?:-[\\w-]+)?',
|
||||
// Others
|
||||
'janus(?:-[\\w-]+)?',
|
||||
'midjourney(?:-[\\w-]+)?',
|
||||
'mj-[\\w-]+',
|
||||
'z-image(?:-[\\w-]+)?',
|
||||
'longcat-image(?:-[\\w-]+)?',
|
||||
'hunyuanimage(?:-[\\w-]+)?',
|
||||
'seedream(?:-[\\w-]+)?',
|
||||
'kandinsky(?:-[\\w-]+)?'
|
||||
]
|
||||
|
||||
const IMAGE_ENHANCEMENT_MODELS = [
|
||||
@ -133,13 +158,23 @@ const GENERATE_IMAGE_MODELS_REGEX = new RegExp(GENERATE_IMAGE_MODELS.join('|'),
|
||||
|
||||
const MODERN_GENERATE_IMAGE_MODELS_REGEX = new RegExp(MODERN_IMAGE_MODELS.join('|'), 'i')
|
||||
|
||||
export const isDedicatedImageGenerationModel = (model: Model): boolean => {
|
||||
/**
|
||||
* Check if the model is a dedicated image generation model
|
||||
* Dedicated image generation models can only generate images, no text chat capability
|
||||
*
|
||||
* These models need:
|
||||
* 1. Route to dedicated image generation API
|
||||
* 2. Exclude from reasoning/websearch/tooluse selection
|
||||
*/
|
||||
export function isDedicatedImageModel(model: Model): boolean {
|
||||
if (!model) return false
|
||||
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return DEDICATED_IMAGE_MODELS_REGEX.test(modelId)
|
||||
}
|
||||
|
||||
// Backward compatible aliases
|
||||
export const isDedicatedImageGenerationModel = isDedicatedImageModel
|
||||
|
||||
export const isAutoEnableImageGenerationModel = (model: Model): boolean => {
|
||||
if (!model) return false
|
||||
|
||||
@ -195,14 +230,8 @@ export function isPureGenerateImageModel(model: Model): boolean {
|
||||
return !OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS.some((m) => modelId.includes(m))
|
||||
}
|
||||
|
||||
// TODO: refine the regex
|
||||
// Text to image models
|
||||
const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|imagen|gpt-image/i
|
||||
|
||||
export function isTextToImageModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return TEXT_TO_IMAGE_REGEX.test(modelId)
|
||||
}
|
||||
// Backward compatible alias - now uses unified dedicated image model detection
|
||||
export const isTextToImageModel = isDedicatedImageModel
|
||||
|
||||
/**
|
||||
* 判断模型是否支持图片增强(包括编辑、增强、修复等)
|
||||
|
||||
@ -107,7 +107,7 @@ export const SYSTEM_PROVIDERS_CONFIG: Record<SystemProviderId, SystemProvider> =
|
||||
type: 'openai',
|
||||
apiKey: '',
|
||||
apiHost: 'https://aihubmix.com',
|
||||
anthropicApiHost: 'https://aihubmix.com/anthropic',
|
||||
anthropicApiHost: 'https://aihubmix.com',
|
||||
models: SYSTEM_MODELS.aihubmix,
|
||||
isSystem: true,
|
||||
enabled: false
|
||||
@ -200,7 +200,8 @@ export const SYSTEM_PROVIDERS_CONFIG: Record<SystemProviderId, SystemProvider> =
|
||||
name: 'TokenFlux',
|
||||
type: 'openai',
|
||||
apiKey: '',
|
||||
apiHost: 'https://tokenflux.ai',
|
||||
apiHost: 'https://api.tokenflux.ai/openai/v1',
|
||||
anthropicApiHost: 'https://api.tokenflux.ai/anthropic',
|
||||
models: SYSTEM_MODELS.tokenflux,
|
||||
isSystem: true,
|
||||
enabled: false
|
||||
@ -289,7 +290,7 @@ export const SYSTEM_PROVIDERS_CONFIG: Record<SystemProviderId, SystemProvider> =
|
||||
ollama: {
|
||||
id: 'ollama',
|
||||
name: 'Ollama',
|
||||
type: 'openai',
|
||||
type: 'ollama',
|
||||
apiKey: '',
|
||||
apiHost: 'http://localhost:11434',
|
||||
models: SYSTEM_MODELS.ollama,
|
||||
@ -1088,7 +1089,7 @@ export const PROVIDER_URLS: Record<SystemProviderId, ProviderUrls> = {
|
||||
websites: {
|
||||
official: 'https://platform.minimaxi.com/',
|
||||
apiKey: 'https://platform.minimaxi.com/user-center/basic-information/interface-key',
|
||||
docs: 'https://platform.minimaxi.com/document/Announcement',
|
||||
docs: 'https://platform.minimaxi.com/docs/api-reference/text-openai-api',
|
||||
models: 'https://platform.minimaxi.com/document/Models'
|
||||
}
|
||||
},
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/**
|
||||
* @deprecated Scheduled for removal in v2.0.0
|
||||
* --------------------------------------------------------------------------
|
||||
* ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
|
||||
* --------------------------------------------------------------------------
|
||||
* STOP: Feature PRs affecting this file are currently BLOCKED.
|
||||
* Only critical bug fixes are accepted during this migration phase.
|
||||
*
|
||||
* This file is being refactored to v2 standards.
|
||||
* Any non-critical changes will conflict with the ongoing work.
|
||||
*
|
||||
* 🔗 Context & Status:
|
||||
* - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
|
||||
* - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
import type {
|
||||
CustomTranslateLanguage,
|
||||
FileMetadata,
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/**
|
||||
* @deprecated Scheduled for removal in v2.0.0
|
||||
* --------------------------------------------------------------------------
|
||||
* ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
|
||||
* --------------------------------------------------------------------------
|
||||
* STOP: Feature PRs affecting this file are currently BLOCKED.
|
||||
* Only critical bug fixes are accepted during this migration phase.
|
||||
*
|
||||
* This file is being refactored to v2 standards.
|
||||
* Any non-critical changes will conflict with the ongoing work.
|
||||
*
|
||||
* 🔗 Context & Status:
|
||||
* - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
|
||||
* - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
import { loggerService } from '@logger'
|
||||
import { LanguagesEnum } from '@renderer/config/translate'
|
||||
import type { LegacyMessage as OldMessage, Topic, TranslateLanguageCode } from '@renderer/types'
|
||||
|
||||
@ -268,9 +268,7 @@ export function useAppInit() {
|
||||
// Update memory service configuration when it changes
|
||||
useEffect(() => {
|
||||
const memoryService = MemoryService.getInstance()
|
||||
memoryService.updateConfig().catch((error) => {
|
||||
logger.error('Failed to update memory config:', error)
|
||||
})
|
||||
memoryService.updateConfig().catch((error) => logger.error('Failed to update memory config:', error))
|
||||
}, [memoryConfig])
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/**
|
||||
* @deprecated Scheduled for removal in v2.0.0
|
||||
* --------------------------------------------------------------------------
|
||||
* ⚠️ NOTICE: V2 DATA&UI REFACTORING (by 0xfullex)
|
||||
* --------------------------------------------------------------------------
|
||||
* STOP: Feature PRs affecting this file are currently BLOCKED.
|
||||
* Only critical bug fixes are accepted during this migration phase.
|
||||
*
|
||||
* This file is being refactored to v2 standards.
|
||||
* Any non-critical changes will conflict with the ongoing work.
|
||||
*
|
||||
* 🔗 Context & Status:
|
||||
* - Contribution Hold: https://github.com/CherryHQ/cherry-studio/issues/10954
|
||||
* - v2 Refactor PR : https://github.com/CherryHQ/cherry-studio/pull/10162
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
|
||||
import type { AssistantIconType, SendMessageShortcut, SettingsState } from '@renderer/store/settings'
|
||||
import {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user