From c2561726e08e4f502382a056dd8134dfd06b0e13 Mon Sep 17 00:00:00 2001 From: Phantom <59059173+EurFelux@users.noreply.github.com> Date: Fri, 15 Aug 2025 14:31:49 +0800 Subject: [PATCH 01/14] style(Inputbar): use primary color for buttons (#9174) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit style(Inputbar): 统一按钮激活状态颜色为主题色 将输入栏中多个按钮的激活状态颜色从链接色(--color-link)统一为主题色(--color-primary),保持UI一致性 --- src/renderer/src/pages/home/Inputbar/GenerateImageButton.tsx | 2 +- src/renderer/src/pages/home/Inputbar/KnowledgeBaseButton.tsx | 5 ++++- src/renderer/src/pages/home/Inputbar/MentionModelsButton.tsx | 2 +- src/renderer/src/pages/home/Inputbar/UrlContextbutton.tsx | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/renderer/src/pages/home/Inputbar/GenerateImageButton.tsx b/src/renderer/src/pages/home/Inputbar/GenerateImageButton.tsx index c7d930bf5d..eadcea8b82 100644 --- a/src/renderer/src/pages/home/Inputbar/GenerateImageButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/GenerateImageButton.tsx @@ -24,7 +24,7 @@ const GenerateImageButton: FC = ({ model, ToolbarButton, assistant, onEna mouseLeaveDelay={0} arrow> - + ) diff --git a/src/renderer/src/pages/home/Inputbar/KnowledgeBaseButton.tsx b/src/renderer/src/pages/home/Inputbar/KnowledgeBaseButton.tsx index 8e4782c8a7..735796dfad 100644 --- a/src/renderer/src/pages/home/Inputbar/KnowledgeBaseButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/KnowledgeBaseButton.tsx @@ -87,7 +87,10 @@ const KnowledgeBaseButton: FC = ({ ref, selectedBases, onSelect, disabled return ( - + 0 ? 'var(--color-primary)' : 'var(--color-icon)'} + /> ) diff --git a/src/renderer/src/pages/home/Inputbar/MentionModelsButton.tsx b/src/renderer/src/pages/home/Inputbar/MentionModelsButton.tsx index 2aff40c0de..822d52fef6 100644 --- a/src/renderer/src/pages/home/Inputbar/MentionModelsButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/MentionModelsButton.tsx @@ -195,7 +195,7 @@ const MentionModelsButton: FC = ({ return ( - + 0 ? 'var(--color-primary)' : 'var(--color-icon)'} /> ) diff --git a/src/renderer/src/pages/home/Inputbar/UrlContextbutton.tsx b/src/renderer/src/pages/home/Inputbar/UrlContextbutton.tsx index 2f5a228bb8..11cfde3f72 100644 --- a/src/renderer/src/pages/home/Inputbar/UrlContextbutton.tsx +++ b/src/renderer/src/pages/home/Inputbar/UrlContextbutton.tsx @@ -33,7 +33,7 @@ const UrlContextButton: FC = ({ assistant, ToolbarButton }) => { From 748ac600fa7149639173492fbef1afff4dea8918 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=A4=A9=E5=AF=92?= Date: Fri, 15 Aug 2025 15:13:48 +0800 Subject: [PATCH 02/14] fix(aws-bedrock): support thinking mode (#9172) * fix(aws-bedrock): support thinking mode * fix(aws-bedrock): fix code review suggestions * fix(aws-bedrock): Add thinking processing for other models --- .../__tests__/ApiClientFactory.test.ts | 23 + .../aiCore/clients/aws/AwsBedrockAPIClient.ts | 432 ++++++++++++++++-- src/renderer/src/types/sdk.ts | 19 + 3 files changed, 426 insertions(+), 48 deletions(-) diff --git a/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts b/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts index 172798dc38..5ec3bf6404 100644 --- a/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts +++ b/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts @@ -5,6 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { AihubmixAPIClient } from '../AihubmixAPIClient' import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient' import { ApiClientFactory } from '../ApiClientFactory' +import { AwsBedrockAPIClient } from '../aws/AwsBedrockAPIClient' import { GeminiAPIClient } from '../gemini/GeminiAPIClient' import { VertexAPIClient } from '../gemini/VertexAPIClient' import { NewAPIClient } from '../NewAPIClient' @@ -54,6 +55,19 @@ vi.mock('../openai/OpenAIResponseAPIClient', () => ({ vi.mock('../ppio/PPIOAPIClient', () => ({ PPIOAPIClient: vi.fn().mockImplementation(() => ({})) })) +vi.mock('../aws/AwsBedrockAPIClient', () => ({ + AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({})) +})) + +// Mock the models config to prevent circular dependency issues +vi.mock('@renderer/config/models', () => ({ + findTokenLimit: vi.fn(), + isReasoningModel: vi.fn(), + SYSTEM_MODELS: { + silicon: [], + defaultModel: [] + } +})) describe('ApiClientFactory', () => { beforeEach(() => { @@ -144,6 +158,15 @@ describe('ApiClientFactory', () => { expect(client).toBeDefined() }) + it('should create AwsBedrockAPIClient for aws-bedrock type', () => { + const provider = createTestProvider('aws-bedrock', 'aws-bedrock') + + const client = ApiClientFactory.create(provider) + + expect(AwsBedrockAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + // 测试默认情况 it('should create OpenAIAPIClient as default for unknown type', () => { const provider = createTestProvider('unknown', 'unknown-type') diff --git a/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts b/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts index 4e29a0cb5c..d9bd9af9c8 100644 --- a/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts +++ b/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts @@ -2,19 +2,23 @@ import { BedrockClient, ListFoundationModelsCommand, ListInferenceProfilesComman import { BedrockRuntimeClient, ConverseCommand, - ConverseStreamCommand, - InvokeModelCommand + InvokeModelCommand, + InvokeModelWithResponseStreamCommand } from '@aws-sdk/client-bedrock-runtime' import { loggerService } from '@logger' import { GenericChunk } from '@renderer/aiCore/middleware/schemas' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' +import { findTokenLimit, isReasoningModel } from '@renderer/config/models' import { getAwsBedrockAccessKeyId, getAwsBedrockRegion, getAwsBedrockSecretAccessKey } from '@renderer/hooks/useAwsBedrock' +import { getAssistantSettings } from '@renderer/services/AssistantService' import { estimateTextTokens } from '@renderer/services/TokenService' import { + Assistant, + EFFORT_RATIO, GenerateImageParams, MCPCallToolResponse, MCPTool, @@ -23,7 +27,13 @@ import { Provider, ToolCallResponse } from '@renderer/types' -import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk' +import { + ChunkType, + MCPToolCreatedChunk, + TextDeltaChunk, + ThinkingDeltaChunk, + ThinkingStartChunk +} from '@renderer/types/chunk' import { Message } from '@renderer/types/newMessage' import { AwsBedrockSdkInstance, @@ -33,6 +43,7 @@ import { AwsBedrockSdkRawOutput, AwsBedrockSdkTool, AwsBedrockSdkToolCall, + AwsBedrockStreamChunk, SdkModel } from '@renderer/types/sdk' import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils' @@ -103,46 +114,65 @@ export class AwsBedrockAPIClient extends BaseApiClient< override async createCompletions(payload: AwsBedrockSdkParams): Promise { const sdk = await this.getSdkInstance() - // 转换消息格式到AWS SDK原生格式 + // 转换消息格式(用于 InvokeModelWithResponseStreamCommand) const awsMessages = payload.messages.map((msg) => ({ role: msg.role, content: msg.content.map((content) => { if (content.text) { - return { text: content.text } + return { type: 'text', text: content.text } } if (content.image) { + // 处理图片数据,将 Uint8Array 或数字数组转换为 base64 字符串 + let base64Data = '' + if (content.image.source.bytes) { + if (typeof content.image.source.bytes === 'string') { + // 如果已经是字符串,直接使用 + base64Data = content.image.source.bytes + } else { + // 如果是数组或 Uint8Array,转换为 base64 + const uint8Array = new Uint8Array(Object.values(content.image.source.bytes)) + const binaryString = Array.from(uint8Array) + .map((byte) => String.fromCharCode(byte)) + .join('') + base64Data = btoa(binaryString) + } + } + return { - image: { - format: content.image.format, - source: content.image.source + type: 'image', + source: { + type: 'base64', + media_type: `image/${content.image.format}`, + data: base64Data } } } if (content.toolResult) { return { - toolResult: { - toolUseId: content.toolResult.toolUseId, - content: content.toolResult.content, - status: content.toolResult.status - } + type: 'tool_result', + tool_use_id: content.toolResult.toolUseId, + content: content.toolResult.content } } if (content.toolUse) { return { - toolUse: { - toolUseId: content.toolUse.toolUseId, - name: content.toolUse.name, - input: content.toolUse.input - } + type: 'tool_use', + id: content.toolUse.toolUseId, + name: content.toolUse.name, + input: content.toolUse.input } } - // 返回符合AWS SDK ContentBlock类型的对象 - return { text: 'Unknown content type' } + return { type: 'text', text: 'Unknown content type' } }) })) logger.info('Creating completions with model ID:', { modelId: payload.modelId }) + const excludeKeys = ['modelId', 'messages', 'system', 'maxTokens', 'temperature', 'topP', 'stream', 'tools'] + const additionalParams = Object.keys(payload) + .filter((key) => !excludeKeys.includes(key)) + .reduce((acc, key) => ({ ...acc, [key]: payload[key] }), {}) + const commonParams = { modelId: payload.modelId, messages: awsMessages as any, @@ -162,10 +192,18 @@ export class AwsBedrockAPIClient extends BaseApiClient< try { if (payload.stream) { - const command = new ConverseStreamCommand(commonParams) + // 根据模型类型选择正确的 API 格式 + const requestBody = this.createRequestBodyForModel(commonParams, additionalParams) + + const command = new InvokeModelWithResponseStreamCommand({ + modelId: commonParams.modelId, + body: JSON.stringify(requestBody), + contentType: 'application/json', + accept: 'application/json' + }) + const response = await sdk.client.send(command) - // 直接返回AWS Bedrock流式响应的异步迭代器 - return this.createStreamIterator(response) + return this.createInvokeModelStreamIterator(response) } else { const command = new ConverseCommand(commonParams) const response = await sdk.client.send(command) @@ -177,32 +215,236 @@ export class AwsBedrockAPIClient extends BaseApiClient< } } - private async *createStreamIterator(response: any): AsyncIterable { - try { - if (response.stream) { - for await (const chunk of response.stream) { - logger.debug('AWS Bedrock chunk received:', chunk) + /** + * 根据模型类型创建请求体 + */ + private createRequestBodyForModel(commonParams: any, additionalParams: any): any { + const modelId = commonParams.modelId.toLowerCase() - // AWS Bedrock的流式响应格式转换为标准格式 - if (chunk.contentBlockDelta?.delta?.text) { - yield { - contentBlockDelta: { - delta: { text: chunk.contentBlockDelta.delta.text } + // Claude 系列模型使用 Anthropic API 格式 + if (modelId.includes('claude')) { + return { + anthropic_version: 'bedrock-2023-05-31', + max_tokens: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + top_p: commonParams.inferenceConfig.topP, + messages: commonParams.messages, + ...(commonParams.system && commonParams.system[0]?.text ? { system: commonParams.system[0].text } : {}), + ...(commonParams.toolConfig?.tools ? { tools: commonParams.toolConfig.tools } : {}), + ...additionalParams + } + } + + // OpenAI 系列模型 + if (modelId.includes('gpt') || modelId.includes('openai')) { + const messages: any[] = [] + + // 添加系统消息 + if (commonParams.system && commonParams.system[0]?.text) { + messages.push({ + role: 'system', + content: commonParams.system[0].text + }) + } + + // 转换消息格式 + for (const message of commonParams.messages) { + const content: any[] = [] + for (const part of message.content) { + if (part.text) { + content.push({ type: 'text', text: part.text }) + } else if (part.image) { + content.push({ + type: 'image_url', + image_url: { + url: `data:image/${part.image.format};base64,${part.image.source.bytes}` + } + }) + } + } + messages.push({ + role: message.role, + content: content.length === 1 && content[0].type === 'text' ? content[0].text : content + }) + } + + const baseBody: any = { + model: commonParams.modelId, + messages: messages, + max_tokens: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + top_p: commonParams.inferenceConfig.topP, + stream: true, + ...(commonParams.toolConfig?.tools ? { tools: commonParams.toolConfig.tools } : {}) + } + + // OpenAI 模型的 thinking 参数格式 + if (additionalParams.reasoning_effort) { + baseBody.reasoning_effort = additionalParams.reasoning_effort + delete additionalParams.reasoning_effort + } + + return { + ...baseBody, + ...additionalParams + } + } + + // Llama 系列模型 + if (modelId.includes('llama')) { + const baseBody: any = { + prompt: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), + max_gen_len: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + top_p: commonParams.inferenceConfig.topP + } + + // Llama 模型的 thinking 参数格式 + if (additionalParams.thinking_mode) { + baseBody.thinking_mode = additionalParams.thinking_mode + delete additionalParams.thinking_mode + } + + return { + ...baseBody, + ...additionalParams + } + } + + // Amazon Titan 系列模型 + if (modelId.includes('titan')) { + const textGenerationConfig: any = { + maxTokenCount: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + topP: commonParams.inferenceConfig.topP + } + + // 将 thinking 相关参数添加到 textGenerationConfig 中 + if (additionalParams.thinking) { + textGenerationConfig.thinking = additionalParams.thinking + delete additionalParams.thinking + } + + return { + inputText: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), + textGenerationConfig: { + ...textGenerationConfig, + ...Object.keys(additionalParams).reduce((acc, key) => { + if (['thinking_tokens', 'reasoning_mode'].includes(key)) { + acc[key] = additionalParams[key] + delete additionalParams[key] + } + return acc + }, {} as any) + }, + ...additionalParams + } + } + + // Cohere Command 系列模型 + if (modelId.includes('cohere') || modelId.includes('command')) { + const baseBody: any = { + message: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), + max_tokens: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + p: commonParams.inferenceConfig.topP + } + + // Cohere 模型的 thinking 参数格式 + if (additionalParams.thinking) { + baseBody.thinking = additionalParams.thinking + delete additionalParams.thinking + } + if (additionalParams.reasoning_tokens) { + baseBody.reasoning_tokens = additionalParams.reasoning_tokens + delete additionalParams.reasoning_tokens + } + + return { + ...baseBody, + ...additionalParams + } + } + + // 默认使用通用格式 + const baseBody: any = { + prompt: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), + max_tokens: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + top_p: commonParams.inferenceConfig.topP + } + + return { + ...baseBody, + ...additionalParams + } + } + + /** + * 将消息转换为简单的 prompt 格式 + */ + private convertMessagesToPrompt(messages: any[], system?: any[]): string { + let prompt = '' + + // 添加系统消息 + if (system && system[0]?.text) { + prompt += `System: ${system[0].text}\n\n` + } + + // 添加对话消息 + for (const message of messages) { + const role = message.role === 'assistant' ? 'Assistant' : 'Human' + let content = '' + + for (const part of message.content) { + if (part.text) { + content += part.text + } else if (part.image) { + content += '[Image]' + } + } + + prompt += `${role}: ${content}\n\n` + } + + prompt += 'Assistant:' + return prompt + } + + private async *createInvokeModelStreamIterator(response: any): AsyncIterable { + try { + if (response.body) { + for await (const event of response.body) { + if (event.chunk) { + const chunk: AwsBedrockStreamChunk = JSON.parse(new TextDecoder().decode(event.chunk.bytes)) + + // 转换为标准格式 + if (chunk.type === 'content_block_delta') { + yield { + contentBlockDelta: { + delta: chunk.delta, + contentBlockIndex: chunk.index + } + } + } else if (chunk.type === 'message_start') { + yield { messageStart: chunk } + } else if (chunk.type === 'message_stop') { + yield { messageStop: chunk } + } else if (chunk.type === 'content_block_start') { + yield { + contentBlockStart: { + start: chunk.content_block, + contentBlockIndex: chunk.index + } + } + } else if (chunk.type === 'content_block_stop') { + yield { + contentBlockStop: { + contentBlockIndex: chunk.index + } } } } - - if (chunk.messageStart) { - yield { messageStart: chunk.messageStart } - } - - if (chunk.messageStop) { - yield { messageStop: chunk.messageStop } - } - - if (chunk.metadata) { - yield { metadata: chunk.metadata } - } } } } catch (error) { @@ -485,6 +727,38 @@ export class AwsBedrockAPIClient extends BaseApiClient< } } + // 获取推理预算token(对所有支持推理的模型) + const budgetTokens = this.getBudgetToken(assistant, model) + + // 构建基础自定义参数 + const customParams: Record = + coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {} + + // 根据模型类型添加 thinking 参数 + if (budgetTokens) { + const modelId = model.id.toLowerCase() + + if (modelId.includes('claude')) { + // Claude 模型使用 Anthropic 格式 + customParams.thinking = { type: 'enabled', budget_tokens: budgetTokens } + } else if (modelId.includes('gpt') || modelId.includes('openai')) { + // OpenAI 模型格式 + customParams.reasoning_effort = assistant?.settings?.reasoning_effort + } else if (modelId.includes('llama')) { + // Llama 模型格式 + customParams.thinking_mode = true + customParams.thinking_tokens = budgetTokens + } else if (modelId.includes('titan')) { + // Titan 模型格式 + customParams.thinking = { enabled: true } + customParams.thinking_tokens = budgetTokens + } else if (modelId.includes('cohere') || modelId.includes('command')) { + // Cohere 模型格式 + customParams.thinking = { enabled: true } + customParams.reasoning_tokens = budgetTokens + } + } + const payload: AwsBedrockSdkParams = { modelId: model.id, messages: @@ -497,9 +771,7 @@ export class AwsBedrockAPIClient extends BaseApiClient< topP: this.getTopP(assistant, model), stream: streamOutput !== false, tools: tools.length > 0 ? tools : undefined, - // 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 - // 注意:用户自定义参数总是应该覆盖其他参数 - ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}) + ...customParams } const timeout = this.getTimeout(model) @@ -511,6 +783,7 @@ export class AwsBedrockAPIClient extends BaseApiClient< getResponseChunkTransformer(): ResponseChunkTransformer { return () => { let hasStartedText = false + let hasStartedThinking = false let accumulatedJson = '' const toolCalls: Record = {} @@ -570,6 +843,24 @@ export class AwsBedrockAPIClient extends BaseApiClient< } as TextDeltaChunk) } + // 处理thinking增量 + if ( + rawChunk.contentBlockDelta?.delta?.type === 'thinking_delta' && + rawChunk.contentBlockDelta?.delta?.thinking + ) { + if (!hasStartedThinking) { + controller.enqueue({ + type: ChunkType.THINKING_START + } as ThinkingStartChunk) + hasStartedThinking = true + } + + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: rawChunk.contentBlockDelta.delta.thinking + } as ThinkingDeltaChunk) + } + // 处理内容块停止事件 - 参考 Anthropic 的 content_block_stop 处理 if (rawChunk.contentBlockStop) { const blockIndex = rawChunk.contentBlockStop.contentBlockIndex || 0 @@ -708,4 +999,49 @@ export class AwsBedrockAPIClient extends BaseApiClient< extractMessagesFromSdkPayload(sdkPayload: AwsBedrockSdkParams): AwsBedrockSdkMessageParam[] { return sdkPayload.messages || [] } + + /** + * 获取 AWS Bedrock 的推理工作量预算token + * @param assistant - The assistant + * @param model - The model + * @returns The budget tokens for reasoning effort + */ + private getBudgetToken(assistant: Assistant, model: Model): number | undefined { + try { + if (!isReasoningModel(model)) { + return undefined + } + + const { maxTokens } = getAssistantSettings(assistant) + const reasoningEffort = assistant?.settings?.reasoning_effort + + if (reasoningEffort === undefined) { + return undefined + } + + const effortRatio = EFFORT_RATIO[reasoningEffort] + const tokenLimits = findTokenLimit(model.id) + + if (tokenLimits) { + // 使用模型特定的 token 限制 + const budgetTokens = Math.max( + 1024, + Math.floor( + Math.min( + (tokenLimits.max - tokenLimits.min) * effortRatio + tokenLimits.min, + (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio + ) + ) + ) + return budgetTokens + } else { + // 对于没有特定限制的模型,使用简化计算 + const budgetTokens = Math.max(1024, Math.floor((maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)) + return budgetTokens + } + } catch (error) { + logger.warn('Failed to calculate budget tokens for reasoning effort:', error as Error) + return undefined + } + } } diff --git a/src/renderer/src/types/sdk.ts b/src/renderer/src/types/sdk.ts index a5413e54fb..36608ab9fe 100644 --- a/src/renderer/src/types/sdk.ts +++ b/src/renderer/src/types/sdk.ts @@ -162,6 +162,7 @@ export interface AwsBedrockSdkParams { topP?: number stream?: boolean tools?: AwsBedrockSdkTool[] + [key: string]: any // Allow any additional custom parameters } export interface AwsBedrockSdkMessageParam { @@ -206,6 +207,22 @@ export interface AwsBedrockSdkMessageParam { }> } +export interface AwsBedrockStreamChunk { + type: string + delta?: { + text?: string + toolUse?: { input?: string } + type?: string + thinking?: string + } + index?: number + content_block?: any + usage?: { + inputTokens?: number + outputTokens?: number + } +} + export interface AwsBedrockSdkRawChunk { contentBlockStart?: { start?: { @@ -222,6 +239,8 @@ export interface AwsBedrockSdkRawChunk { toolUse?: { input?: string } + type?: string // 支持 'thinking_delta' 等类型 + thinking?: string // 支持 thinking 内容 } contentBlockIndex?: number } From 4a62bb6ad71614d3553511992cf3bed3984cda26 Mon Sep 17 00:00:00 2001 From: beyondkmp Date: Fri, 15 Aug 2025 22:48:22 +0800 Subject: [PATCH 03/14] refactor: replace axios and node fetch with electron's net module (#9212) * refactor: replace axios and node fetch with electron's net module for network requests in preprocess providers - Updated Doc2xPreprocessProvider and MineruPreprocessProvider to use net.fetch instead of axios for making HTTP requests. - Improved error handling for network responses across various methods. - Removed unnecessary AxiosRequestConfig and related code to streamline the implementation. * lint * refactor(Doc2xPreprocessProvider): enhance file validation and upload process - Added file size validation to prevent loading files larger than 300MB into memory. - Implemented file size check before reading the PDF to ensure efficient memory usage. - Updated the file upload method to use a stream, setting the 'Content-Length' header for better handling of large files. * refactor(brave-search): update net.fetch calls to use url.toString() - Modified all instances of net.fetch to use url.toString() for better URL handling. - Ensured consistency in how URLs are passed to the fetch method across various functions. * refactor(MCPService): improve URL handling in net.fetch calls - Updated net.fetch to use url.toString() for better type handling of URLs. - Ensured consistent URL processing across the MCPService class. * feat(ProxyManager): integrate axios with fetch proxy support - Added axios as a dependency to enable fetch proxy usage. - Implemented logic to set axios's adapter to 'fetch' for proxy handling. - Preserved original axios adapter for restoration when disabling the proxy. --- .../preprocess/Doc2xPreprocessProvider.ts | 136 ++++++++++++------ .../preprocess/MineruPreprocessProvider.ts | 18 ++- .../knowledge/reranker/GeneralReranker.ts | 14 +- src/main/mcpServers/brave-search.ts | 9 +- src/main/mcpServers/dify-knowledge.ts | 5 +- src/main/mcpServers/fetch.ts | 3 +- src/main/services/AppUpdater.ts | 6 +- src/main/services/CopilotService.ts | 70 +++++---- src/main/services/FileStorage.ts | 3 +- src/main/services/MCPService.ts | 4 +- src/main/services/NutstoreService.ts | 3 +- src/main/utils/ipService.ts | 3 +- 12 files changed, 182 insertions(+), 92 deletions(-) diff --git a/src/main/knowledge/preprocess/Doc2xPreprocessProvider.ts b/src/main/knowledge/preprocess/Doc2xPreprocessProvider.ts index afc8d1ba9b..834ff2f27e 100644 --- a/src/main/knowledge/preprocess/Doc2xPreprocessProvider.ts +++ b/src/main/knowledge/preprocess/Doc2xPreprocessProvider.ts @@ -5,7 +5,7 @@ import { loggerService } from '@logger' import { fileStorage } from '@main/services/FileStorage' import { FileMetadata, PreprocessProvider } from '@types' import AdmZip from 'adm-zip' -import axios, { AxiosRequestConfig } from 'axios' +import { net } from 'electron' import BasePreprocessProvider from './BasePreprocessProvider' @@ -38,19 +38,24 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider { } private async validateFile(filePath: string): Promise { - const pdfBuffer = await fs.promises.readFile(filePath) + // 首先检查文件大小,避免读取大文件到内存 + const stats = await fs.promises.stat(filePath) + const fileSizeBytes = stats.size + // 文件大小小于300MB + if (fileSizeBytes >= 300 * 1024 * 1024) { + const fileSizeMB = Math.round(fileSizeBytes / (1024 * 1024)) + throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 300MB`) + } + + // 只有在文件大小合理的情况下才读取文件内容检查页数 + const pdfBuffer = await fs.promises.readFile(filePath) const doc = await this.readPdf(pdfBuffer) // 文件页数小于1000页 if (doc.numPages >= 1000) { throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 1000 pages`) } - // 文件大小小于300MB - if (pdfBuffer.length >= 300 * 1024 * 1024) { - const fileSizeMB = Math.round(pdfBuffer.length / (1024 * 1024)) - throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 300MB`) - } } public async parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata }> { @@ -160,11 +165,23 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider { * @returns 预上传响应的url和uid */ private async preupload(): Promise { - const config = this.createAuthConfig() const endpoint = `${this.provider.apiHost}/api/v2/parse/preupload` try { - const { data } = await axios.post>(endpoint, null, config) + const response = await net.fetch(endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${this.provider.apiKey}` + }, + body: null + }) + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } + + const data = (await response.json()) as ApiResponse if (data.code === 'success' && data.data) { return data.data @@ -178,17 +195,29 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider { } /** - * 上传文件 + * 上传文件(使用流式上传) * @param filePath 文件路径 * @param url 预上传响应的url */ private async putFile(filePath: string, url: string): Promise { try { - const fileStream = fs.createReadStream(filePath) - const response = await axios.put(url, fileStream) + // 获取文件大小用于设置 Content-Length + const stats = await fs.promises.stat(filePath) + const fileSize = stats.size - if (response.status !== 200) { - throw new Error(`HTTP status ${response.status}: ${response.statusText}`) + // 创建可读流 + const fileStream = fs.createReadStream(filePath) + + const response = await net.fetch(url, { + method: 'PUT', + body: fileStream as any, // TypeScript 类型转换,net.fetch 支持 ReadableStream + headers: { + 'Content-Length': fileSize.toString() + } + }) + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) } } catch (error) { logger.error(`Failed to upload file ${filePath}: ${error instanceof Error ? error.message : String(error)}`) @@ -197,16 +226,25 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider { } private async getStatus(uid: string): Promise { - const config = this.createAuthConfig() const endpoint = `${this.provider.apiHost}/api/v2/parse/status?uid=${uid}` try { - const response = await axios.get>(endpoint, config) + const response = await net.fetch(endpoint, { + method: 'GET', + headers: { + Authorization: `Bearer ${this.provider.apiKey}` + } + }) - if (response.data.code === 'success' && response.data.data) { - return response.data.data + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } + + const data = (await response.json()) as ApiResponse + if (data.code === 'success' && data.data) { + return data.data } else { - throw new Error(`API returned error: ${response.data.message || JSON.stringify(response.data)}`) + throw new Error(`API returned error: ${data.message || JSON.stringify(data)}`) } } catch (error) { logger.error(`Failed to get status for uid ${uid}: ${error instanceof Error ? error.message : String(error)}`) @@ -221,13 +259,6 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider { */ private async convertFile(uid: string, filePath: string): Promise { const fileName = path.parse(filePath).name - const config = { - ...this.createAuthConfig(), - headers: { - ...this.createAuthConfig().headers, - 'Content-Type': 'application/json' - } - } const payload = { uid, @@ -239,10 +270,22 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider { const endpoint = `${this.provider.apiHost}/api/v2/convert/parse` try { - const response = await axios.post>(endpoint, payload, config) + const response = await net.fetch(endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${this.provider.apiKey}` + }, + body: JSON.stringify(payload) + }) - if (response.data.code !== 'success') { - throw new Error(`API returned error: ${response.data.message || JSON.stringify(response.data)}`) + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } + + const data = (await response.json()) as ApiResponse + if (data.code !== 'success') { + throw new Error(`API returned error: ${data.message || JSON.stringify(data)}`) } } catch (error) { logger.error(`Failed to convert file ${filePath}: ${error instanceof Error ? error.message : String(error)}`) @@ -256,16 +299,25 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider { * @returns 解析后的文件信息 */ private async getParsedFile(uid: string): Promise { - const config = this.createAuthConfig() const endpoint = `${this.provider.apiHost}/api/v2/convert/parse/result?uid=${uid}` try { - const response = await axios.get>(endpoint, config) + const response = await net.fetch(endpoint, { + method: 'GET', + headers: { + Authorization: `Bearer ${this.provider.apiKey}` + } + }) - if (response.status === 200 && response.data.data) { - return response.data.data + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } + + const data = (await response.json()) as ApiResponse + if (data.data) { + return data.data } else { - throw new Error(`HTTP status ${response.status}: ${response.statusText}`) + throw new Error(`No data in response`) } } catch (error) { logger.error( @@ -295,8 +347,12 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider { try { // 下载文件 - const response = await axios.get(url, { responseType: 'arraybuffer' }) - fs.writeFileSync(zipPath, response.data) + const response = await net.fetch(url, { method: 'GET' }) + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } + const arrayBuffer = await response.arrayBuffer() + fs.writeFileSync(zipPath, Buffer.from(arrayBuffer)) // 确保提取目录存在 if (!fs.existsSync(extractPath)) { @@ -318,14 +374,6 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider { } } - private createAuthConfig(): AxiosRequestConfig { - return { - headers: { - Authorization: `Bearer ${this.provider.apiKey}` - } - } - } - public checkQuota(): Promise { throw new Error('Method not implemented.') } diff --git a/src/main/knowledge/preprocess/MineruPreprocessProvider.ts b/src/main/knowledge/preprocess/MineruPreprocessProvider.ts index 0e29a6443f..1976f64c05 100644 --- a/src/main/knowledge/preprocess/MineruPreprocessProvider.ts +++ b/src/main/knowledge/preprocess/MineruPreprocessProvider.ts @@ -5,7 +5,7 @@ import { loggerService } from '@logger' import { fileStorage } from '@main/services/FileStorage' import { FileMetadata, PreprocessProvider } from '@types' import AdmZip from 'adm-zip' -import axios from 'axios' +import { net } from 'electron' import BasePreprocessProvider from './BasePreprocessProvider' @@ -95,7 +95,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider { public async checkQuota() { try { - const quota = await fetch(`${this.provider.apiHost}/api/v4/quota`, { + const quota = await net.fetch(`${this.provider.apiHost}/api/v4/quota`, { method: 'GET', headers: { 'Content-Type': 'application/json', @@ -179,8 +179,12 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider { try { // 下载ZIP文件 - const response = await axios.get(zipUrl, { responseType: 'arraybuffer' }) - fs.writeFileSync(zipPath, Buffer.from(response.data)) + const response = await net.fetch(zipUrl, { method: 'GET' }) + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } + const arrayBuffer = await response.arrayBuffer() + fs.writeFileSync(zipPath, Buffer.from(arrayBuffer)) logger.info(`Downloaded ZIP file: ${zipPath}`) // 确保提取目录存在 @@ -236,7 +240,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider { } try { - const response = await fetch(endpoint, { + const response = await net.fetch(endpoint, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -271,7 +275,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider { try { const fileBuffer = await fs.promises.readFile(filePath) - const response = await fetch(uploadUrl, { + const response = await net.fetch(uploadUrl, { method: 'PUT', body: fileBuffer, headers: { @@ -316,7 +320,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider { const endpoint = `${this.provider.apiHost}/api/v4/extract-results/batch/${batchId}` try { - const response = await fetch(endpoint, { + const response = await net.fetch(endpoint, { method: 'GET', headers: { 'Content-Type': 'application/json', diff --git a/src/main/knowledge/reranker/GeneralReranker.ts b/src/main/knowledge/reranker/GeneralReranker.ts index 1252ecad57..5a0e240a9d 100644 --- a/src/main/knowledge/reranker/GeneralReranker.ts +++ b/src/main/knowledge/reranker/GeneralReranker.ts @@ -1,6 +1,6 @@ import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' import { KnowledgeBaseParams } from '@types' -import axios from 'axios' +import { net } from 'electron' import BaseReranker from './BaseReranker' @@ -15,7 +15,17 @@ export default class GeneralReranker extends BaseReranker { const requestBody = this.getRerankRequestBody(query, searchResults) try { - const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() }) + const response = await net.fetch(url, { + method: 'POST', + headers: this.defaultHeaders(), + body: JSON.stringify(requestBody) + }) + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } + + const data = await response.json() const rerankResults = this.extractRerankResult(data) return this.getRerankResult(searchResults, rerankResults) diff --git a/src/main/mcpServers/brave-search.ts b/src/main/mcpServers/brave-search.ts index 6f219e1eb8..d11a4f2580 100644 --- a/src/main/mcpServers/brave-search.ts +++ b/src/main/mcpServers/brave-search.ts @@ -3,6 +3,7 @@ import { Server } from '@modelcontextprotocol/sdk/server/index.js' import { CallToolRequestSchema, ListToolsRequestSchema, Tool } from '@modelcontextprotocol/sdk/types.js' +import { net } from 'electron' const WEB_SEARCH_TOOL: Tool = { name: 'brave_web_search', @@ -159,7 +160,7 @@ async function performWebSearch(apiKey: string, query: string, count: number = 1 url.searchParams.set('count', Math.min(count, 20).toString()) // API limit url.searchParams.set('offset', offset.toString()) - const response = await fetch(url, { + const response = await net.fetch(url.toString(), { headers: { Accept: 'application/json', 'Accept-Encoding': 'gzip', @@ -192,7 +193,7 @@ async function performLocalSearch(apiKey: string, query: string, count: number = webUrl.searchParams.set('result_filter', 'locations') webUrl.searchParams.set('count', Math.min(count, 20).toString()) - const webResponse = await fetch(webUrl, { + const webResponse = await net.fetch(webUrl.toString(), { headers: { Accept: 'application/json', 'Accept-Encoding': 'gzip', @@ -225,7 +226,7 @@ async function getPoisData(apiKey: string, ids: string[]): Promise url.searchParams.append('ids', id)) - const response = await fetch(url, { + const response = await net.fetch(url.toString(), { headers: { Accept: 'application/json', 'Accept-Encoding': 'gzip', @@ -244,7 +245,7 @@ async function getDescriptionsData(apiKey: string, ids: string[]): Promise url.searchParams.append('ids', id)) - const response = await fetch(url, { + const response = await net.fetch(url.toString(), { headers: { Accept: 'application/json', 'Accept-Encoding': 'gzip', diff --git a/src/main/mcpServers/dify-knowledge.ts b/src/main/mcpServers/dify-knowledge.ts index 2bd2c4adda..83a352fd4f 100644 --- a/src/main/mcpServers/dify-knowledge.ts +++ b/src/main/mcpServers/dify-knowledge.ts @@ -2,6 +2,7 @@ import { loggerService } from '@logger' import { Server } from '@modelcontextprotocol/sdk/server/index.js' import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js' +import { net } from 'electron' import * as z from 'zod/v4' const logger = loggerService.withContext('DifyKnowledgeServer') @@ -134,7 +135,7 @@ class DifyKnowledgeServer { private async performListKnowledges(difyKey: string, apiHost: string): Promise { try { const url = `${apiHost.replace(/\/$/, '')}/datasets` - const response = await fetch(url, { + const response = await net.fetch(url, { method: 'GET', headers: { Authorization: `Bearer ${difyKey}` @@ -186,7 +187,7 @@ class DifyKnowledgeServer { try { const url = `${apiHost.replace(/\/$/, '')}/datasets/${id}/retrieve` - const response = await fetch(url, { + const response = await net.fetch(url, { method: 'POST', headers: { Authorization: `Bearer ${difyKey}`, diff --git a/src/main/mcpServers/fetch.ts b/src/main/mcpServers/fetch.ts index 04839d8a92..e55b114776 100644 --- a/src/main/mcpServers/fetch.ts +++ b/src/main/mcpServers/fetch.ts @@ -2,6 +2,7 @@ import { Server } from '@modelcontextprotocol/sdk/server/index.js' import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js' +import { net } from 'electron' import { JSDOM } from 'jsdom' import TurndownService from 'turndown' import { z } from 'zod' @@ -16,7 +17,7 @@ export type RequestPayload = z.infer export class Fetcher { private static async _fetch({ url, headers }: RequestPayload): Promise { try { - const response = await fetch(url, { + const response = await net.fetch(url, { headers: { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36', diff --git a/src/main/services/AppUpdater.ts b/src/main/services/AppUpdater.ts index ea3b1f3f1e..bdfb8e3cc8 100644 --- a/src/main/services/AppUpdater.ts +++ b/src/main/services/AppUpdater.ts @@ -6,7 +6,7 @@ import { generateUserAgent } from '@main/utils/systemInfo' import { FeedUrl, UpgradeChannel } from '@shared/config/constant' import { IpcChannel } from '@shared/IpcChannel' import { CancellationToken, UpdateInfo } from 'builder-util-runtime' -import { app, BrowserWindow, dialog } from 'electron' +import { app, BrowserWindow, dialog, net } from 'electron' import { AppUpdater as _AppUpdater, autoUpdater, Logger, NsisUpdater, UpdateCheckResult } from 'electron-updater' import path from 'path' import semver from 'semver' @@ -75,7 +75,7 @@ export default class AppUpdater { } try { logger.info(`get release version from github: ${channel}`) - const responses = await fetch('https://api.github.com/repos/CherryHQ/cherry-studio/releases?per_page=8', { + const responses = await net.fetch('https://api.github.com/repos/CherryHQ/cherry-studio/releases?per_page=8', { headers }) const data = (await responses.json()) as GithubReleaseInfo[] @@ -99,7 +99,7 @@ export default class AppUpdater { if (mightHaveLatest) { logger.info(`might have latest release, get latest release`) - const latestReleaseResponse = await fetch( + const latestReleaseResponse = await net.fetch( 'https://api.github.com/repos/CherryHQ/cherry-studio/releases/latest', { headers diff --git a/src/main/services/CopilotService.ts b/src/main/services/CopilotService.ts index bb54e74932..f5c773a7cc 100644 --- a/src/main/services/CopilotService.ts +++ b/src/main/services/CopilotService.ts @@ -1,6 +1,5 @@ import { loggerService } from '@logger' -import { AxiosRequestConfig } from 'axios' -import axios from 'axios' +import { net } from 'electron' import { app, safeStorage } from 'electron' import fs from 'fs/promises' import path from 'path' @@ -86,7 +85,8 @@ class CopilotService { */ public getUser = async (_: Electron.IpcMainInvokeEvent, token: string): Promise => { try { - const config: AxiosRequestConfig = { + const response = await net.fetch(CONFIG.API_URLS.GITHUB_USER, { + method: 'GET', headers: { Connection: 'keep-alive', 'user-agent': 'Visual Studio Code (desktop)', @@ -95,12 +95,16 @@ class CopilotService { 'Sec-Fetch-Dest': 'empty', authorization: `token ${token}` } + }) + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) } - const response = await axios.get(CONFIG.API_URLS.GITHUB_USER, config) + const data = await response.json() return { - login: response.data.login, - avatar: response.data.avatar_url + login: data.login, + avatar: data.avatar_url } } catch (error) { logger.error('Failed to get user information:', error as Error) @@ -118,16 +122,23 @@ class CopilotService { try { this.updateHeaders(headers) - const response = await axios.post( - CONFIG.API_URLS.GITHUB_DEVICE_CODE, - { + const response = await net.fetch(CONFIG.API_URLS.GITHUB_DEVICE_CODE, { + method: 'POST', + headers: { + ...this.headers, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ client_id: CONFIG.GITHUB_CLIENT_ID, scope: 'read:user' - }, - { headers: this.headers } - ) + }) + }) - return response.data + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } + + return (await response.json()) as AuthResponse } catch (error) { logger.error('Failed to get auth message:', error as Error) throw new CopilotServiceError('无法获取GitHub授权信息', error) @@ -150,17 +161,25 @@ class CopilotService { await this.delay(currentDelay) try { - const response = await axios.post( - CONFIG.API_URLS.GITHUB_ACCESS_TOKEN, - { + const response = await net.fetch(CONFIG.API_URLS.GITHUB_ACCESS_TOKEN, { + method: 'POST', + headers: { + ...this.headers, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ client_id: CONFIG.GITHUB_CLIENT_ID, device_code, grant_type: 'urn:ietf:params:oauth:grant-type:device_code' - }, - { headers: this.headers } - ) + }) + }) - const { access_token } = response.data + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } + + const data = (await response.json()) as TokenResponse + const { access_token } = data if (access_token) { return { access_token } } @@ -205,16 +224,19 @@ class CopilotService { const encryptedToken = await fs.readFile(this.tokenFilePath) const access_token = safeStorage.decryptString(Buffer.from(encryptedToken)) - const config: AxiosRequestConfig = { + const response = await net.fetch(CONFIG.API_URLS.COPILOT_TOKEN, { + method: 'GET', headers: { ...this.headers, authorization: `token ${access_token}` } + }) + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) } - const response = await axios.get(CONFIG.API_URLS.COPILOT_TOKEN, config) - - return response.data + return (await response.json()) as CopilotTokenResponse } catch (error) { logger.error('Failed to get Copilot token:', error as Error) throw new CopilotServiceError('无法获取Copilot令牌,请重新授权', error) diff --git a/src/main/services/FileStorage.ts b/src/main/services/FileStorage.ts index e34c51f299..f5df9ed3f7 100644 --- a/src/main/services/FileStorage.ts +++ b/src/main/services/FileStorage.ts @@ -5,6 +5,7 @@ import { FileMetadata } from '@types' import * as crypto from 'crypto' import { dialog, + net, OpenDialogOptions, OpenDialogReturnValue, SaveDialogOptions, @@ -509,7 +510,7 @@ class FileStorage { isUseContentType?: boolean ): Promise => { try { - const response = await fetch(url) + const response = await net.fetch(url) if (!response.ok) { throw new Error(`HTTP error! status: ${response.status}`) } diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index d3909cc86f..a7f907f65f 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -29,7 +29,7 @@ import { } from '@modelcontextprotocol/sdk/types.js' import { nanoid } from '@reduxjs/toolkit' import type { GetResourceResponse, MCPCallToolResponse, MCPPrompt, MCPResource, MCPServer, MCPTool } from '@types' -import { app } from 'electron' +import { app, net } from 'electron' import { EventEmitter } from 'events' import { memoize } from 'lodash' import { v4 as uuidv4 } from 'uuid' @@ -205,7 +205,7 @@ class McpService { } } - return fetch(url, { ...init, headers }) + return net.fetch(typeof url === 'string' ? url : url.toString(), { ...init, headers }) } }, requestInit: { diff --git a/src/main/services/NutstoreService.ts b/src/main/services/NutstoreService.ts index 4422ea8a07..f4ad5a2c33 100644 --- a/src/main/services/NutstoreService.ts +++ b/src/main/services/NutstoreService.ts @@ -2,6 +2,7 @@ import path from 'node:path' import { loggerService } from '@logger' import { NUTSTORE_HOST } from '@shared/config/nutstore' +import { net } from 'electron' import { XMLParser } from 'fast-xml-parser' import { isNil, partial } from 'lodash' import { type FileStat } from 'webdav' @@ -62,7 +63,7 @@ export async function getDirectoryContents(token: string, target: string): Promi let currentUrl = `${NUTSTORE_HOST}${target}` while (true) { - const response = await fetch(currentUrl, { + const response = await net.fetch(currentUrl, { method: 'PROPFIND', headers: { Authorization: `Basic ${token}`, diff --git a/src/main/utils/ipService.ts b/src/main/utils/ipService.ts index ec5ab78215..3180f9457c 100644 --- a/src/main/utils/ipService.ts +++ b/src/main/utils/ipService.ts @@ -1,4 +1,5 @@ import { loggerService } from '@logger' +import { net } from 'electron' const logger = loggerService.withContext('IpService') @@ -12,7 +13,7 @@ export async function getIpCountry(): Promise { const controller = new AbortController() const timeoutId = setTimeout(() => controller.abort(), 5000) - const ipinfo = await fetch('https://ipinfo.io/json', { + const ipinfo = await net.fetch('https://ipinfo.io/json', { signal: controller.signal, headers: { 'User-Agent': From e0dbd2d2dbea06ca1abf335066964c6bdc34c669 Mon Sep 17 00:00:00 2001 From: SuYao Date: Fri, 15 Aug 2025 22:56:40 +0800 Subject: [PATCH 04/14] fix/9165 (#9194) * fix/9165 * fix: early return --- .../aiCore/clients/openai/OpenAIApiClient.ts | 12 ++++--- src/renderer/src/config/models.ts | 36 ++++++++++--------- src/renderer/src/types/index.ts | 1 + 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts index 60173551b4..7568ee69be 100644 --- a/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts +++ b/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts @@ -9,6 +9,7 @@ import { isGPT5SeriesModel, isGrokReasoningModel, isNotSupportSystemMessageModel, + isOpenAIReasoningModel, isQwenAlwaysThinkModel, isQwenMTModel, isQwenReasoningModel, @@ -146,7 +147,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient< return {} } // Don't disable reasoning for models that require it - if (isGrokReasoningModel(model)) { + if (isGrokReasoningModel(model) || isOpenAIReasoningModel(model)) { return {} } return { reasoning: { enabled: false, exclude: true } } @@ -524,12 +525,13 @@ export class OpenAIAPIClient extends OpenAIBaseClient< } // 1. 处理系统消息 - let systemMessage = { role: 'system', content: assistant.prompt || '' } + const systemMessage = { role: 'system', content: assistant.prompt || '' } if (isSupportedReasoningEffortOpenAIModel(model)) { - systemMessage = { - role: isSupportDeveloperRoleProvider(this.provider) ? 'developer' : 'system', - content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}` + if (isSupportDeveloperRoleProvider(this.provider)) { + systemMessage.role = 'developer' + } else { + systemMessage.role = 'system' } } diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index b5164dfe5b..e6276cbb98 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -292,6 +292,7 @@ export const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp( // 模型类型到支持的reasoning_effort的映射表 export const MODEL_SUPPORTED_REASONING_EFFORT: ReasoningEffortConfig = { default: ['low', 'medium', 'high'] as const, + o: ['low', 'medium', 'high'] as const, gpt5: ['minimal', 'low', 'medium', 'high'] as const, grok: ['low', 'high'] as const, gemini: ['low', 'medium', 'high', 'auto'] as const, @@ -307,7 +308,8 @@ export const MODEL_SUPPORTED_REASONING_EFFORT: ReasoningEffortConfig = { // 模型类型到支持选项的映射表 export const MODEL_SUPPORTED_OPTIONS: ThinkingOptionConfig = { default: ['off', ...MODEL_SUPPORTED_REASONING_EFFORT.default] as const, - gpt5: ['off', ...MODEL_SUPPORTED_REASONING_EFFORT.gpt5] as const, + o: MODEL_SUPPORTED_REASONING_EFFORT.o, + gpt5: [...MODEL_SUPPORTED_REASONING_EFFORT.gpt5] as const, grok: MODEL_SUPPORTED_REASONING_EFFORT.grok, gemini: ['off', ...MODEL_SUPPORTED_REASONING_EFFORT.gemini] as const, gemini_pro: MODEL_SUPPORTED_REASONING_EFFORT.gemini_pro, @@ -320,28 +322,28 @@ export const MODEL_SUPPORTED_OPTIONS: ThinkingOptionConfig = { } as const export const getThinkModelType = (model: Model): ThinkingModelType => { + let thinkingModelType: ThinkingModelType = 'default' if (isGPT5SeriesModel(model)) { - return 'gpt5' - } - if (isSupportedThinkingTokenGeminiModel(model)) { + thinkingModelType = 'gpt5' + } else if (isSupportedReasoningEffortOpenAIModel(model)) { + thinkingModelType = 'o' + } else if (isSupportedThinkingTokenGeminiModel(model)) { if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) { - return 'gemini' + thinkingModelType = 'gemini' } else { - return 'gemini_pro' + thinkingModelType = 'gemini_pro' } - } - if (isSupportedReasoningEffortGrokModel(model)) return 'grok' - if (isSupportedThinkingTokenQwenModel(model)) { + } else if (isSupportedReasoningEffortGrokModel(model)) thinkingModelType = 'grok' + else if (isSupportedThinkingTokenQwenModel(model)) { if (isQwenAlwaysThinkModel(model)) { - return 'qwen_thinking' + thinkingModelType = 'qwen_thinking' } - return 'qwen' - } - if (isSupportedThinkingTokenDoubaoModel(model)) return 'doubao' - if (isSupportedThinkingTokenHunyuanModel(model)) return 'hunyuan' - if (isSupportedReasoningEffortPerplexityModel(model)) return 'perplexity' - if (isSupportedThinkingTokenZhipuModel(model)) return 'zhipu' - return 'default' + thinkingModelType = 'qwen' + } else if (isSupportedThinkingTokenDoubaoModel(model)) thinkingModelType = 'doubao' + else if (isSupportedThinkingTokenHunyuanModel(model)) thinkingModelType = 'hunyuan' + else if (isSupportedReasoningEffortPerplexityModel(model)) thinkingModelType = 'perplexity' + else if (isSupportedThinkingTokenZhipuModel(model)) thinkingModelType = 'zhipu' + return thinkingModelType } export function isFunctionCallingModel(model?: Model): boolean { diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index ca4a1fe4fd..5f649c5e8f 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -56,6 +56,7 @@ export type ReasoningEffortOption = NonNullable | 'auto' export type ThinkingOption = ReasoningEffortOption | 'off' export type ThinkingModelType = | 'default' + | 'o' | 'gpt5' | 'grok' | 'gemini' From a02b4b3955d742fb5b10455bacf357c9fd9dec82 Mon Sep 17 00:00:00 2001 From: Pleasure1234 <3196812536@qq.com> Date: Sat, 16 Aug 2025 04:00:32 +0800 Subject: [PATCH 05/14] fix: websearch (#9222) Update LocalSearchProvider.ts --- .../providers/WebSearchProvider/LocalSearchProvider.ts | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts b/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts index abdf9fc826..7911646630 100644 --- a/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts @@ -89,15 +89,9 @@ export default class LocalSearchProvider extends BaseWebSearchProvider { * @returns 带有语言过滤的查询 */ protected applyLanguageFilter(query: string, language: string): string { - if (this.provider.id.includes('local-google')) { + if (this.provider.id.includes('local-google') || this.provider.id.includes('local-bing')) { return `${query} lang:${language.split('-')[0]}` } - if (this.provider.id.includes('local-bing')) { - return `${query} language:${language}` - } - if (this.provider.id.includes('local-baidu')) { - return `${query} language:${language.split('-')[0]}` - } return query } From 04326eba21ddf14fe62199ff11d8b1c50fa51858 Mon Sep 17 00:00:00 2001 From: miro <16189212+miroklarin@users.noreply.github.com> Date: Sat, 16 Aug 2025 04:21:29 +0100 Subject: [PATCH 06/14] feat: Use different window name for Quick Assistant (#9217) Co-authored-by: Miro Klarin --- src/renderer/miniWindow.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/renderer/miniWindow.html b/src/renderer/miniWindow.html index 83b108b8a4..7f3b936444 100644 --- a/src/renderer/miniWindow.html +++ b/src/renderer/miniWindow.html @@ -6,7 +6,7 @@ - Cherry Studio + Cherry Studio Quick Assistant